Skip to content

Commit

Permalink
Socket handoff prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
rjlaine committed Jan 24, 2025
1 parent 0b70c76 commit d594516
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 2 deletions.
181 changes: 181 additions & 0 deletions go/vt/vtgateproxy/handoff.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
package vtgateproxy

import (
"errors"
"fmt"
"log/slog"

Check failure on line 6 in go/vt/vtgateproxy/handoff.go

View workflow job for this annotation

GitHub Actions / Docker Test Cluster 10

package log/slog is not in GOROOT (/usr/local/go/src/log/slog)

Check failure on line 6 in go/vt/vtgateproxy/handoff.go

View workflow job for this annotation

GitHub Actions / Docker Test Cluster 10

package log/slog is not in GOROOT (/usr/local/go/src/log/slog)

Check failure on line 6 in go/vt/vtgateproxy/handoff.go

View workflow job for this annotation

GitHub Actions / Docker Test Cluster 10

package log/slog is not in GOROOT (/usr/local/go/src/log/slog)

Check failure on line 6 in go/vt/vtgateproxy/handoff.go

View workflow job for this annotation

GitHub Actions / Docker Test Cluster 10

package log/slog is not in GOROOT (/usr/local/go/src/log/slog)

Check failure on line 6 in go/vt/vtgateproxy/handoff.go

View workflow job for this annotation

GitHub Actions / Docker Test Cluster 10

package log/slog is not in GOROOT (/usr/local/go/src/log/slog)

Check failure on line 6 in go/vt/vtgateproxy/handoff.go

View workflow job for this annotation

GitHub Actions / Docker Test Cluster 10

package log/slog is not in GOROOT (/usr/local/go/src/log/slog)

Check failure on line 6 in go/vt/vtgateproxy/handoff.go

View workflow job for this annotation

GitHub Actions / Docker Test Cluster 10

package log/slog is not in GOROOT (/usr/local/go/src/log/slog)

Check failure on line 6 in go/vt/vtgateproxy/handoff.go

View workflow job for this annotation

GitHub Actions / Docker Test Cluster 10

package log/slog is not in GOROOT (/usr/local/go/src/log/slog)

Check failure on line 6 in go/vt/vtgateproxy/handoff.go

View workflow job for this annotation

GitHub Actions / Docker Test Cluster 10

package log/slog is not in GOROOT (/usr/local/go/src/log/slog)
"net"
"os"
"syscall"
"time"

"golang.org/x/sys/unix"
)

// handoff implements a no-downtime handoff of a TCP listener from one running
// process to another. It can be used for no-downtime deploys of HTTP servers
// on a single host/port.

// ListenForHandoff opens a unix domain socket at socketPath and waits for
// handoff requests. When a valid request is received, the file descriptor
// underlying `listener` is handed off to the requester over that socket.
//
// Error handling:
//   - If the unix domain socket cannot be opened, the error is returned.
//   - If serving a single request fails (bad magic packet, timeout, failed
//     send), the error is logged and ListenForHandoff resumes waiting for
//     the next request.
//   - If the failure is a closed socket (net.ErrClosed), the error is
//     returned instead of retried: no later attempt can succeed, and
//     retrying would only spin and log forever.
//
// ListenForHandoff returns nil once a handoff has completed. Any file left at
// socketPath is removed before listening, and the socket is closed and
// removed again on return.
//
// Callers should drain any servers connected to the net.Listener, and
// in-flight requests should be resolved before shutting down.
func ListenForHandoff(socketPath string, listener net.Listener) error {
	// Clean up any leftover sockets that might have gotten left from previous
	// processes. This is best effort: the path normally does not exist, and a
	// real problem will surface from net.Listen below.
	_ = os.Remove(socketPath)

	unixListener, err := net.Listen("unix", socketPath)
	if err != nil {
		return err
	}
	defer func() {
		unixListener.Close()
		os.Remove(socketPath)
	}()

	for {
		err := listen(unixListener, listener)
		if err == nil {
			return nil
		}
		if errors.Is(err, net.ErrClosed) {
			// Permanent failure; do not busy-loop on it.
			return err
		}
		slog.Error("handoff socket error", "error", err)
	}
}

// magicPacket is the payload a client writes to the handoff socket to request
// the listener (see RequestHandoff). The serving side compares the bytes it
// reads against this value and rejects the connection on a mismatch.
var magicPacket = "handoff"

// listen serves a single handoff request. It accepts one connection on
// unixListener, gives it a one-second deadline, reads the magic packet, and,
// if the packet matches, passes the connection to handoff so that listener's
// file descriptor is sent to the peer. The accepted connection is always
// closed before returning.
//
// It returns the accept, deadline, or read error if one occurs, a
// "bad magic packet" error when the bytes read differ from magicPacket, and
// otherwise whatever handoff returns.
func listen(unixListener, listener net.Listener) error {
	peer, acceptErr := unixListener.Accept()
	if acceptErr != nil {
		return acceptErr
	}
	defer peer.Close()

	// Bound the whole exchange so a stalled client cannot hold up the loop.
	deadline := time.Now().Add(1 * time.Second)
	if err := peer.SetDeadline(deadline); err != nil {
		return err
	}

	buf := make([]byte, len(magicPacket))
	got, readErr := peer.Read(buf)
	switch {
	case readErr != nil:
		return readErr
	case string(buf[:got]) != magicPacket:
		return errors.New("bad magic packet")
	}

	return handoff(peer, listener)
}

// handoff sends the file descriptor underlying listener to the peer on conn.
// The descriptor travels as a socket control message built by
// unix.UnixRights and written with unix.Sendmsg; no regular payload is
// supplied.
//
// conn must be a *net.UnixConn and listener must be a *net.TCPListener. Any
// other concrete type (for example a listener that has been wrapped by
// another net.Listener implementation) yields an error rather than a panic.
// Errors from obtaining either descriptor or from the send are returned
// unchanged.
func handoff(conn net.Conn, listener net.Listener) error {
	unixConn, ok := conn.(*net.UnixConn)
	if !ok {
		return fmt.Errorf("handoff connection is %T, want *net.UnixConn", conn)
	}
	tcpListener, ok := listener.(*net.TCPListener)
	if !ok {
		return fmt.Errorf("handoff listener is %T, want *net.TCPListener", listener)
	}

	unixFD, err := getFD(unixConn)
	if err != nil {
		return err
	}

	tcpFD, err := getFD(tcpListener)
	if err != nil {
		return err
	}

	rights := unix.UnixRights(tcpFD)
	return unix.Sendmsg(unixFD, nil, rights, nil, 0)
}

// RequestHandoff dials the unix domain socket at `socketPath`, sends the
// magic packet, and waits for the process on the other end to reply with the
// file descriptor of its listening socket. That descriptor is converted into
// a net.Listener and returned to the caller for immediate use.
//
// The whole exchange is bounded by a one-second deadline on the connection.
// The reply is read through the connection itself (ReadMsgUnix) so that the
// deadline also covers the wait for the descriptor.
//
// During the time between socket handoff and startup of the new server,
// requests to the socket will block. Requests will only fail if the client
// timeout is shorter than the duration of the handoff period.
//
// If socketPath is empty, or nothing is listening on the other end of the
// unix domain socket, an error wrapping ErrNoHandoff is returned. Clients
// should check for this condition with errors.Is, and open the TCP socket
// themselves. Any other failure (send error, timeout, malformed or missing
// control message, descriptor that cannot be turned into a listener) is
// returned as an error that does not wrap ErrNoHandoff.
func RequestHandoff(socketPath string) (net.Listener, error) {
	if socketPath == "" {
		return nil, ErrNoHandoff
	}
	conn, err := net.Dial("unix", socketPath)
	if err != nil {
		return nil, fmt.Errorf("%w: %v", ErrNoHandoff, err)
	}
	defer conn.Close()

	unixConn, ok := conn.(*net.UnixConn)
	if !ok {
		return nil, fmt.Errorf("handoff connection is %T, want *net.UnixConn", conn)
	}
	err = unixConn.SetDeadline(time.Now().Add(1 * time.Second))
	if err != nil {
		return nil, err
	}

	_, err = unixConn.Write([]byte(magicPacket))
	if err != nil {
		return nil, fmt.Errorf("%w: failed to send magic packet", err)
	}

	// One byte of regular data accompanies the control message on a stream
	// socket; oob is sized for a control message carrying a single 4-byte
	// descriptor.
	data := make([]byte, 1)
	oob := make([]byte, unix.CmsgSpace(4))
	//nolint:dogsled
	_, oobn, _, _, err := unixConn.ReadMsgUnix(data, oob)
	if err != nil {
		return nil, fmt.Errorf("%w: msg not received", err)
	}

	// Parse only the bytes actually received, and verify something arrived
	// before indexing into the results.
	cmsgs, err := unix.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return nil, fmt.Errorf("%w: control msg not parsed", err)
	}
	if len(cmsgs) == 0 {
		return nil, errors.New("handoff reply carried no control message")
	}
	fds, err := unix.ParseUnixRights(&cmsgs[0])
	if err != nil {
		return nil, fmt.Errorf("%w: invalid unix rights", err)
	}
	if len(fds) == 0 {
		return nil, errors.New("handoff reply carried no file descriptor")
	}

	// net.FileListener works on its own duplicate of the descriptor, so the
	// received one must be closed here or it stays open alongside the
	// returned listener.
	listenerFD := os.NewFile(uintptr(fds[0]), "listener")
	defer listenerFD.Close()

	l, err := net.FileListener(listenerFD)
	if err != nil {
		return nil, fmt.Errorf("%w: failed to acquire new fd", err)
	}

	return l, nil
}

// ErrNoHandoff indicates that no handoff was performed. RequestHandoff
// returns it directly when given an empty socket path, and wrapped (together
// with the dial error) when the unix domain socket cannot be dialed, so
// callers should test for it with errors.Is.
var ErrNoHandoff = errors.New("no handoff")

// getFD returns the raw file descriptor behind conn.
//
// It obtains the syscall.RawConn and records the descriptor passed to the
// Control callback. If SyscallConn fails, it returns -1 and that error; if
// Control fails, it returns whatever was recorded (zero when the callback
// never ran) together with the Control error.
//
// NOTE(review): the descriptor is used by callers after Control has
// returned, so nothing here keeps it open; callers presumably rely on conn
// staying open for the duration of use.
func getFD(conn syscall.Conn) (fd int, err error) {
	rawConn, scErr := conn.SyscallConn()
	if scErr != nil {
		return -1, scErr
	}

	capture := func(descriptor uintptr) {
		fd = int(descriptor)
	}
	err = rawConn.Control(capture)
	return fd, err
}
41 changes: 39 additions & 2 deletions go/vt/vtgateproxy/mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,20 @@ package vtgateproxy

import (
"context"
"errors"
"flag"
"fmt"
"net"
"os"
"os/signal"
"regexp"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"

"github.com/pires/go-proxyproto"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/vtgateconn"
Expand All @@ -50,6 +52,7 @@ import (
var (
mysqlServerPort = flag.Int("mysql_server_port", -1, "If set, also listen for MySQL binary protocol connections on this port.")
mysqlServerBindAddress = flag.String("mysql_server_bind_address", "", "Binds on this address when listening to MySQL binary protocol. Useful to restrict listening to 'localhost' only for instance.")
mysqlServerHandoffSocket = flag.String("mysql_server_handoff_sock", "", "Opens a unix domain socket for no-downtime handoff of the mysql listener during deploys of the proxy")
mysqlServerSocketPath = flag.String("mysql_server_socket_path", "", "This option specifies the Unix socket file to use when listening for local connections. By default it will be empty and it won't listen to a unix socket")
mysqlTCPVersion = flag.String("mysql_tcp_version", "tcp", "Select tcp, tcp4, or tcp6 to control the socket type.")
mysqlAuthServerImpl = flag.String("mysql_auth_server_impl", "static", "Which auth server implementation to use. Options: none, ldap, clientcert, static, vault.")
Expand Down Expand Up @@ -439,8 +442,42 @@ func initMySQLProtocol() {
var err error
proxyHandle = newProxyHandler(vtGateProxy)
if *mysqlServerPort >= 0 {
// Request socket from an already running process, or start a new
// listener to serve requests on.
listener, err := RequestHandoff(*mysqlServerHandoffSocket)
if errors.Is(err, ErrNoHandoff) {
listener, err = net.Listen(*mysqlTCPVersion, net.JoinHostPort(*mysqlServerBindAddress, strconv.Itoa(*mysqlServerPort)))
}
if err != nil {
log.Exitf("Failed to open listener: %v", err)
}

// Advertise unix domain socket for handoff by future processes.
if *mysqlServerHandoffSocket != "" {
go func() {
// Shut down the server as soon as the handoff is complete.
defer func() {
// TODO better shutdown hook
shutdownMysqlProtocolAndDrain()
rollbackAtShutdown()
os.Exit(0)
}()

err := ListenForHandoff(*mysqlServerHandoffSocket, listener)
if err != nil {
log.Errorf("Handoff failed: %v", err)
return
}

log.Info("Handed off socket and shutting down")
}()
}

log.Infof("Mysql Server listening on Port %d", *mysqlServerPort)
mysqlListener, err = mysql.NewListener(*mysqlTCPVersion, net.JoinHostPort(*mysqlServerBindAddress, fmt.Sprintf("%v", *mysqlServerPort)), authServer, proxyHandle, *mysqlConnReadTimeout, *mysqlConnWriteTimeout, *mysqlProxyProtocol, *mysqlConnBufferPooling)
if *mysqlProxyProtocol {
listener = &proxyproto.Listener{Listener: listener}
}
mysqlListener, err = mysql.NewFromListener(listener, authServer, proxyHandle, *mysqlConnReadTimeout, *mysqlConnWriteTimeout, *mysqlConnBufferPooling)
if err != nil {
log.Exitf("mysql.NewListener failed: %v", err)
}
Expand Down

0 comments on commit d594516

Please sign in to comment.