Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement zero downtime reconnects for workers draining #79

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
733 changes: 49 additions & 684 deletions connect.go

Large diffs are not rendered by default.

307 changes: 307 additions & 0 deletions connect/connection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
package connect

import (
"context"
"crypto/rand"
"errors"
"fmt"
"github.com/coder/websocket"
"github.com/inngest/inngest/pkg/connect/types"
"github.com/inngest/inngest/pkg/connect/wsproto"
connectproto "github.com/inngest/inngest/proto/gen/connect/v1"
"github.com/oklog/ulid/v2"
"golang.org/x/sync/errgroup"
"io"
"net"
"time"
)

// connectReport is the value sent on the "connect done" channel by connect
// once a connection attempt has failed or an established connection has ended.
type connectReport struct {
// reconnect carries the flag returned by prepareConnection/handleConnection.
// NOTE(review): presumably the receiver uses it to decide whether to open a
// new connection — confirm against the consumer of the channel.
reconnect bool
// err is the wrapped failure, or nil when handleConnection returned without error.
err error
}

// connect establishes a single gateway connection and drives it until it ends.
//
// Once the connection is prepared (dial + handshake), a signal is sent on
// notifyConnectedChan. When the attempt fails or the connection terminates, a
// connectReport is sent on notifyConnectDoneChan. The one exception is a
// draining gateway: handleConnection has already started a replacement
// connection in that case, so no report is emitted here.
func (h *connectHandler) connect(ctx context.Context, data connectionEstablishData, notifyConnectedChan chan struct{}, notifyConnectDoneChan chan connectReport) {
	// Small helper so every exit path reports in the same shape.
	report := func(shouldReconnect bool, cause error) {
		notifyConnectDoneChan <- connectReport{
			reconnect: shouldReconnect,
			err:       cause,
		}
	}

	// Dial the gateway and run the connect handshake protocol.
	prepared, shouldReconnect, prepErr := h.prepareConnection(ctx, data)
	if prepErr != nil {
		h.logger.Error("could not establish connection", "err", prepErr)
		report(shouldReconnect, fmt.Errorf("could not establish connection: %w", prepErr))
		return
	}

	// Tell the listener that the connection is up.
	notifyConnectedChan <- struct{}{}

	// Run the connection lifecycle (receiving messages, handling requests, etc.).
	shouldReconnect, handleErr := h.handleConnection(ctx, data, prepared.ws, prepared.gatewayHost, notifyConnectedChan, notifyConnectDoneChan)
	if handleErr == nil {
		report(shouldReconnect, nil)
		return
	}

	h.logger.Error("could not handle connection", "err", handleErr)

	// A draining gateway means the old connection was closed and the
	// reconnect already happened inside handleConnection: nothing to report.
	if errors.Is(handleErr, errGatewayDraining) {
		return
	}

	report(shouldReconnect, fmt.Errorf("could not handle connection: %w", handleErr))
}

// connectionEstablishData bundles the values handed to prepareConnection and
// forwarded from there to performConnectHandshake.
// NOTE(review): the meaning of most fields below is inferred from their names
// only; confirm against performConnectHandshake.
type connectionEstablishData struct {
// presumably the hashed signing key used to authenticate — TODO confirm
hashedSigningKey []byte
// presumably the worker's CPU core count — TODO confirm
numCpuCores int32
// presumably the worker's total memory (unit not visible here) — TODO confirm
totalMem int64
// presumably the serialized function definitions — TODO confirm
marshaledFns []byte
// presumably the serialized SDK capabilities — TODO confirm
marshaledCapabilities []byte
// withTemporaryConnection sets this to true to prevent the connection from
// receiving work.
manualReadinessAck bool
}

// preparedConnection is the result of a successful prepareConnection call:
// a dialed WebSocket on which the connect handshake has completed.
type preparedConnection struct {
// ws is the WebSocket connection to the gateway.
ws *websocket.Conn
// gatewayHost is the host that was picked and dialed for this connection.
gatewayHost string
// connectionId is the ULID string generated for this connection; a new one
// is created on every prepareConnection call.
connectionId string
}

// prepareConnection picks an available gateway, dials it over WebSocket and
// performs the connect handshake.
//
// It returns the prepared connection, a flag that is forwarded to the caller
// as the "reconnect" hint, and an error. On error the returned
// preparedConnection is the zero value and any socket that was opened has
// already been closed.
//
// If no gateway is available, the hosts manager is reset so that a subsequent
// call can try all gateways again.
func (h *connectHandler) prepareConnection(ctx context.Context, data connectionEstablishData) (preparedConnection, bool, error) {
	// The timeout only bounds the dial; the handshake runs on the caller's context.
	connectTimeout, cancelConnectTimeout := context.WithTimeout(ctx, 10*time.Second)
	defer cancelConnectTimeout()

	gatewayHost := h.hostsManager.pickAvailableGateway()
	if gatewayHost == "" {
		// All gateways have been tried, reset the internal state to retry
		h.hostsManager.resetGateways()

		return preparedConnection{}, true, fmt.Errorf("no available gateway hosts")
	}

	// Establish WebSocket connection to one of the gateways
	ws, _, err := websocket.Dial(connectTimeout, gatewayHost, &websocket.DialOptions{
		Subprotocols: []string{
			types.GatewaySubProtocol,
		},
	})
	if err != nil {
		h.hostsManager.markUnreachableGateway(gatewayHost)
		return preparedConnection{}, true, fmt.Errorf("could not connect to gateway: %w", err)
	}

	// Connection ID is unique per connection, reconnections should get a new ID
	connectionID := ulid.MustNew(ulid.Now(), rand.Reader).String()

	h.logger.Debug("websocket connection established", "gateway_host", gatewayHost)

	reconnect, err := h.performConnectHandshake(ctx, connectionID, ws, gatewayHost, data)
	if err != nil {
		// The socket is not handed to anyone on this path, so release it here
		// instead of leaking it. CloseNow on an already-closed connection only
		// returns an error, which is safe to ignore.
		_ = ws.CloseNow()

		return preparedConnection{}, reconnect, fmt.Errorf("could not perform connect handshake: %w", err)
	}

	return preparedConnection{
		ws:           ws,
		gatewayHost:  gatewayHost,
		connectionId: connectionID,
	}, false, nil
}

// handleConnection runs the lifecycle of an established gateway connection.
//
// It flushes buffered messages, starts a heartbeat goroutine, and reads
// gateway messages until the read loop ends. The read loop ends when
//   - the gateway announces it is draining (GATEWAY_CLOSING),
//   - reading fails (close frame, lost connection, or another error), or
//   - the context is cancelled, which closes the socket and unblocks the read.
//
// Return values:
//   - (false, errGatewayDraining): the gateway is draining; a replacement
//     connection has been started via h.connect and this one was closed.
//   - (true, err): the connection ended unexpectedly.
//   - (false, nil): the context was cancelled and the graceful shutdown
//     routine completed.
//
// The WebSocket is always closed by the time this function returns.
func (h *connectHandler) handleConnection(ctx context.Context, data connectionEstablishData, ws *websocket.Conn, gatewayHost string, notifyConnectedChan chan struct{}, notifyConnectDoneChan chan connectReport) (reconnect bool, err error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	defer func() {
		// This is a fallback safeguard to always close the WebSocket connection at the end of the function
		// Usually, we provide a specific reason, so this is only necessary for unhandled errors
		_ = ws.CloseNow()
	}()

	// When shutting down the worker, close the connection with a reason.
	// This also unblocks the read loop below, which reads without a deadline.
	go func() {
		<-ctx.Done()
		_ = ws.Close(websocket.StatusNormalClosure, connectproto.WorkerDisconnectReason_WORKER_SHUTDOWN.String())
	}()

	// Send buffered but unsent messages if connection was re-established
	if len(h.messageBuffer) > 0 {
		h.logger.Debug("sending buffered messages", "count", len(h.messageBuffer))
		if err := h.sendBufferedMessages(ws); err != nil {
			return true, fmt.Errorf("could not send buffered messages: %w", err)
		}
	}

	// Heartbeat loop; stops when the connection context is cancelled.
	go func() {
		heartbeatTicker := time.NewTicker(WorkerHeartbeatInterval)
		defer heartbeatTicker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-heartbeatTicker.C:
				err := wsproto.Write(ctx, ws, &connectproto.ConnectMessage{
					Kind: connectproto.GatewayMessageType_WORKER_HEARTBEAT,
				})
				if err != nil {
					h.logger.Error("failed to send worker heartbeat", "err", err)
				}
			}
		}
	}()

	eg := errgroup.Group{}
	eg.Go(func() error {
		for {
			var msg connectproto.ConnectMessage
			// Use a goroutine-local error: the named result must not be
			// written from this goroutine.
			if err := wsproto.Read(context.Background(), ws, &msg); err != nil {
				h.logger.Error("failed to read message", "err", err)

				// The connection may still be active, but for some reason we couldn't read the message
				return err
			}

			h.logger.Debug("received gateway request", "msg", &msg)

			switch msg.Kind {
			case connectproto.GatewayMessageType_GATEWAY_CLOSING:
				// Stop the read loop: We will not receive any further messages and should establish a new connection
				// We can still use the old connection to send replies to the gateway
				return errGatewayDraining
			case connectproto.GatewayMessageType_GATEWAY_EXECUTOR_REQUEST:
				// Handle invoke in a non-blocking way to allow for other messages to be processed
				h.workerPool.Add(workerPoolMsg{
					msg: &msg,
					ws:  ws,
				})
			default:
				// There is no error on this path; log the offending kind instead.
				h.logger.Error("got unknown gateway request", "kind", msg.Kind)
			}
		}
	})

	h.logger.Debug("waiting for read loop to end")

	// The read loop ends either because of the connection itself (gateway
	// draining, close frame, lost connection, unexpected error) or because the
	// context was cancelled (worker shutdown). Only the former is handled in
	// this branch.
	readErr := eg.Wait()
	if readErr != nil && ctx.Err() == nil {
		if errors.Is(readErr, errGatewayDraining) {
			h.hostsManager.markDrainingGateway(gatewayHost)

			// Gateway is draining and will not accept new connections.
			// We must reconnect to a different gateway, only then can we close the old connection.
			waitUntilConnected, doneWaiting := context.WithCancel(context.Background())
			defer doneWaiting()

			// Intercept the connected signal of the replacement connection:
			// the new connect call sends on the intercept channel, we unblock
			// the wait below and then pass the signal on to the original
			// listener.
			// NOTE(review): the forward blocks until notifyConnectedChan is
			// received from; confirm the listener keeps receiving after the
			// first connection.
			notifyConnectedInterceptChan := make(chan struct{})
			go func() {
				<-notifyConnectedInterceptChan
				doneWaiting()
				notifyConnectedChan <- struct{}{}
			}()

			// Establish new connection and pass close reports back to the main goroutine
			// NOTE(review): the replacement runs on context.Background(), so it
			// is not tied to the context passed to this function — confirm
			// that this is intended.
			go h.connect(context.Background(), data, notifyConnectedInterceptChan, notifyConnectDoneChan)

			// Wait until the new connection is established before closing the old one.
			// The connection context is intentionally NOT cancelled before this
			// wait: cancelling it would close the old socket right away.
			select {
			case <-waitUntilConnected.Done():
			case <-time.After(10 * time.Second):
				h.logger.Error("timed out waiting for new connection to be established")
			}

			// Handover finished (or timed out): close the old connection with a reason.
			_ = ws.Close(websocket.StatusNormalClosure, connectproto.WorkerDisconnectReason_WORKER_SHUTDOWN.String())

			return false, errGatewayDraining
		}

		h.logger.Debug("read loop ended with error", "err", readErr)

		// In case the gateway intentionally closed the connection, we'll receive a close error
		cerr := websocket.CloseError{}
		if errors.As(readErr, &cerr) {
			h.logger.Error("connection closed with reason", "reason", cerr.Reason)

			// Reconnect!
			return true, fmt.Errorf("connection closed with reason %q: %w", cerr.Reason, cerr)
		}

		// connection closed without reason
		if errors.Is(readErr, net.ErrClosed) || errors.Is(readErr, io.EOF) {
			h.logger.Error("failed to read message from gateway, lost connection unexpectedly", "err", readErr)
			return true, fmt.Errorf("connection closed unexpectedly: %w", readErr)
		}

		// If this is not a worker shutdown, we should reconnect.
		// Wrap the read error; ctx.Err() is nil in this branch.
		return true, fmt.Errorf("connection closed unexpectedly: %w", readErr)
	}

	// Perform graceful shutdown routine (context was cancelled)

	// Signal gateway that we won't process additional messages!
	{
		h.logger.Debug("sending worker pause message")
		err := wsproto.Write(context.Background(), ws, &connectproto.ConnectMessage{
			Kind: connectproto.GatewayMessageType_WORKER_PAUSE,
		})
		if err != nil {
			// We should not exit here, as we're already in the shutdown routine
			h.logger.Error("failed to serialize worker pause msg", "err", err)
		}
	}

	h.logger.Debug("waiting for in-progress requests to finish")

	// Wait until all in-progress requests are completed
	h.workerPool.Wait()

	// Attempt to shut down connection if not already done
	_ = ws.Close(websocket.StatusNormalClosure, connectproto.WorkerDisconnectReason_WORKER_SHUTDOWN.String())

	return false, nil
}

// withTemporaryConnection opens a short-lived gateway connection, runs handler
// on it and closes the connection afterwards.
//
// The connection is prepared with manualReadinessAck set so that it does not
// receive work. Establishing the connection is attempted up to four times; if
// every attempt fails, the returned error wraps the error of the last attempt.
// Otherwise the result of handler is returned.
func (h *connectHandler) withTemporaryConnection(data connectionEstablishData, handler func(ws *websocket.Conn) error) error {
	// Prevent this connection from receiving work
	data.manualReadinessAck = true

	const maxAttempts = 4

	var (
		conn    *websocket.Conn
		lastErr error
	)
	for attempt := 0; attempt < maxAttempts; attempt++ {
		prepared, _, err := h.prepareConnection(context.Background(), data)
		if err != nil {
			// Remember the failure so the final error explains why we gave up.
			lastErr = err
			continue
		}

		conn = prepared.ws
		break
	}
	if conn == nil {
		return fmt.Errorf("could not establish connection after %d attempts: %w", maxAttempts, lastErr)
	}

	defer func() {
		_ = conn.Close(websocket.StatusNormalClosure, connectproto.WorkerDisconnectReason_WORKER_SHUTDOWN.String())
	}()

	return handler(conn)
}
59 changes: 59 additions & 0 deletions connect/gateway_hosts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package connect

import "sync"

// hostsManager tracks the configured gateway hosts and sorts them into
// available, draining and unreachable sets. The sets are read and modified
// under hostsLock by the methods of this type.
type hostsManager struct {
// gatewayHosts is the full configured list; resetGateways repopulates the
// available set from it.
gatewayHosts []string
// availableGatewayHosts holds the hosts pickAvailableGateway may return.
availableGatewayHosts map[string]struct{}
// drainingGatewayHosts holds hosts moved out of the available set by markDrainingGateway.
drainingGatewayHosts map[string]struct{}
// unreachableGatewayHosts holds hosts moved out of the available set by markUnreachableGateway.
unreachableGatewayHosts map[string]struct{}
// hostsLock guards the three sets above.
hostsLock sync.RWMutex
}

// newHostsManager builds a hostsManager for the given gateway hosts with every
// host initially marked as available and empty draining/unreachable sets.
func newHostsManager(gatewayHosts []string) *hostsManager {
	manager := &hostsManager{gatewayHosts: gatewayHosts}

	// resetGateways allocates all three sets and fills the available one.
	manager.resetGateways()

	return manager
}

// pickAvailableGateway returns one host from the available set, or the empty
// string when the set is empty. Which host is returned depends on map
// iteration order and is therefore unspecified.
func (h *hostsManager) pickAvailableGateway() string {
	h.hostsLock.RLock()
	defer h.hostsLock.RUnlock()

	var picked string
	for candidate := range h.availableGatewayHosts {
		// Any entry will do; take the first one the iteration yields.
		picked = candidate
		break
	}

	return picked
}

// markDrainingGateway moves host from the available set to the draining set,
// so pickAvailableGateway no longer returns it until the next resetGateways.
func (h *hostsManager) markDrainingGateway(host string) {
	h.hostsLock.Lock()
	defer h.hostsLock.Unlock()

	// Both updates happen under the same write lock, so their order is not observable.
	h.drainingGatewayHosts[host] = struct{}{}
	delete(h.availableGatewayHosts, host)
}

// markUnreachableGateway moves host from the available set to the unreachable
// set, so pickAvailableGateway no longer returns it until the next resetGateways.
func (h *hostsManager) markUnreachableGateway(host string) {
	h.hostsLock.Lock()
	defer h.hostsLock.Unlock()

	// Both updates happen under the same write lock, so their order is not observable.
	h.unreachableGatewayHosts[host] = struct{}{}
	delete(h.availableGatewayHosts, host)
}

// resetGateways discards all tracking state: the draining and unreachable sets
// are emptied and every configured gateway host becomes available again.
func (h *hostsManager) resetGateways() {
	h.hostsLock.Lock()
	defer h.hostsLock.Unlock()

	// Build the new available set first, then swap all three sets in.
	available := make(map[string]struct{})
	for i := range h.gatewayHosts {
		available[h.gatewayHosts[i]] = struct{}{}
	}

	h.availableGatewayHosts = available
	h.drainingGatewayHosts = map[string]struct{}{}
	h.unreachableGatewayHosts = map[string]struct{}{}
}
Loading
Loading