Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reconnect on connection loss #73

Merged
merged 3 commits into from
Nov 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
262 changes: 208 additions & 54 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,20 @@ import (
"github.com/coder/websocket"
"github.com/inngest/inngest/pkg/connect/types"
"github.com/inngest/inngest/pkg/connect/wsproto"
"github.com/pbnjay/memory"
"google.golang.org/protobuf/proto"
"runtime"

"github.com/inngest/inngest/pkg/enums"
"github.com/inngest/inngest/pkg/publicerr"
"github.com/inngest/inngest/pkg/syscode"
connectproto "github.com/inngest/inngest/proto/gen/connect/v1"
sdkerrors "github.com/inngest/inngestgo/errors"
"github.com/pbnjay/memory"
"google.golang.org/protobuf/proto"
"io"
"net/url"
"runtime"
"sync"

"github.com/inngest/inngestgo/internal/sdkrequest"
"github.com/oklog/ulid/v2"
"net/url"
"os"
"strings"
"time"
Expand All @@ -30,6 +32,15 @@ type connectHandler struct {
h *handler

connectionId ulid.ULID

messageBuffer []*connectproto.ConnectMessage
messageBufferLock sync.Mutex
}

// authContext is wrapper for information related to authentication
type authContext struct {
signingKey string
fallback bool
}

func (h *connectHandler) connectURLs() []string {
Expand Down Expand Up @@ -91,13 +102,11 @@ func (h *connectHandler) instanceId() string {
}

func (h *connectHandler) Connect(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

signingKey := h.h.GetSigningKey()
if signingKey == "" {
return fmt.Errorf("must provide signing key")
}
auth := authContext{signingKey: signingKey}

numCpuCores := runtime.NumCPU()
totalMem := memory.TotalMemory()
Expand All @@ -122,17 +131,77 @@ func (h *connectHandler) Connect(ctx context.Context) error {
return fmt.Errorf("failed to serialize connect config: %w", err)
}

var attempts int
for {
attempts++

shouldReconnect, err := h.connect(ctx, connectionEstablishData{
signingKey: auth.signingKey,
numCpuCores: int32(numCpuCores),
totalMem: int64(totalMem),
marshaledFns: marshaledFns,
marshaledCapabilities: marshaledCapabilities,
})

h.h.Logger.Error("connect failed", "err", err, "reconnect", shouldReconnect)

if !shouldReconnect {
return err
}

closeErr := websocket.CloseError{}
if errors.As(err, &closeErr) {
switch closeErr.Reason {
// If auth failed, retry with fallback key
case syscode.CodeConnectAuthFailed:
if auth.fallback {
return fmt.Errorf("failed to authenticate with fallback key, exiting")
}

signingKeyFallback := h.h.GetSigningKeyFallback()
if signingKeyFallback != "" {
auth = authContext{signingKey: signingKeyFallback, fallback: true}
}

continue

// Retry on the following error codes
case syscode.CodeConnectGatewayClosing, syscode.CodeConnectInternal, syscode.CodeConnectWorkerHelloTimeout:
continue

default:
// If we received a reason that's non-retriable, stop here.
return fmt.Errorf("connect failed with error code %q", closeErr.Reason)
}
}
}
}

type connectionEstablishData struct {
signingKey string
numCpuCores int32
totalMem int64
marshaledFns []byte
marshaledCapabilities []byte
}

func (h *connectHandler) connect(ctx context.Context, data connectionEstablishData) (reconnect bool, err error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

connectTimeout, cancelConnectTimeout := context.WithTimeout(ctx, 10*time.Second)
defer cancelConnectTimeout()

ws, err := h.connectToGateway(connectTimeout)
if err != nil {
return fmt.Errorf("could not connect: %w", err)
return false, fmt.Errorf("could not connect: %w", err)
}
defer func() {
// TODO Do we need to include a reason here? If we only use this for unexpected disconnects, probably not
_ = ws.CloseNow()
}()

// Connection ID is unique per connection, reconnections should get a new ID
h.connectionId = ulid.MustNew(ulid.Now(), rand.Reader)

h.h.Logger.Debug("connection established")
Expand All @@ -144,21 +213,21 @@ func (h *connectHandler) Connect(ctx context.Context) error {
var helloMessage connectproto.ConnectMessage
err = wsproto.Read(initialMessageTimeout, ws, &helloMessage)
if err != nil {
return fmt.Errorf("did not receive gateway hello message: %w", err)
return true, fmt.Errorf("did not receive gateway hello message: %w", err)
}

if helloMessage.Kind != connectproto.GatewayMessageType_GATEWAY_HELLO {
return fmt.Errorf("expected gateway hello message, got %s", helloMessage.Kind)
return true, fmt.Errorf("expected gateway hello message, got %s", helloMessage.Kind)
}

h.h.Logger.Debug("received gateway hello message")
}

// Send connect message
{
hashedKey, err := hashedSigningKey([]byte(signingKey))
hashedKey, err := hashedSigningKey([]byte(data.signingKey))
if err != nil {
return fmt.Errorf("could not hash signing key: %w", err)
return false, fmt.Errorf("could not hash signing key: %w", err)
}

apiOrigin := defaultAPIOrigin
Expand All @@ -177,13 +246,13 @@ func (h *connectHandler) Connect(ctx context.Context) error {
},
AppName: h.h.appName,
Config: &connectproto.ConfigDetails{
Capabilities: marshaledCapabilities,
Functions: marshaledFns,
Capabilities: data.marshaledCapabilities,
Functions: data.marshaledFns,
ApiOrigin: apiOrigin,
},
SystemAttributes: &connectproto.SystemAttributes{
CpuCores: int32(numCpuCores),
MemBytes: int64(totalMem),
CpuCores: data.numCpuCores,
MemBytes: data.totalMem,
Os: runtime.GOOS,
},
Environment: h.h.Env,
Expand All @@ -192,79 +261,164 @@ func (h *connectHandler) Connect(ctx context.Context) error {
SdkLanguage: SDKLanguage,
})
if err != nil {
return fmt.Errorf("could not serialize sdk connect message: %w", err)
return false, fmt.Errorf("could not serialize sdk connect message: %w", err)
}

err = wsproto.Write(ctx, ws, &connectproto.ConnectMessage{
Kind: connectproto.GatewayMessageType_WORKER_CONNECT,
Payload: data,
})
if err != nil {
return fmt.Errorf("could not send initial message")
return true, fmt.Errorf("could not send initial message")
}
}

for {
if ctx.Err() != nil {
break
}
// TODO Read gateway ready

var msg connectproto.ConnectMessage
err = wsproto.Read(ctx, ws, &msg)
if err != nil {
closeErr := websocket.CloseError{}
if errors.As(err, &closeErr) {
h.h.Logger.Error("connection closed unexpectedly", "reason", closeErr.Reason)
return err
// TODO Send buffered but unsent messages if connection was re-established

inProgress := sync.WaitGroup{}

var (
closeErr error
closeErrLock sync.Mutex
)

readLoopCtx, cancelReadLoop := context.WithCancel(ctx)
go func() {
// Close connection if run loop ends
defer cancelReadLoop()

for {
if readLoopCtx.Err() != nil {
break
}
// TODO Handle issues reading message: Should we re-establish the connection?
return err
}

h.h.Logger.Debug("received gateway request", "msg", &msg)
var msg connectproto.ConnectMessage
err = wsproto.Read(readLoopCtx, ws, &msg)
if err != nil {
// connection lost with reason
cerr := websocket.CloseError{}
if errors.As(err, &cerr) {
h.h.Logger.Error("connection closed unexpectedly", "reason", cerr.Reason)
closeErrLock.Lock()
closeErr = cerr
closeErrLock.Unlock()
// Reconnect!
return
}

// connection lost without reason
if errors.Is(err, io.EOF) {
h.h.Logger.Error("failed to read message from gateway, lost connection unexpectedly", "err", err)
return
}

h.h.Logger.Error("failed to read message", "err", err)

// The connection may still be active, but for some reason we couldn't read the message
return
}

switch msg.Kind {
case connectproto.GatewayMessageType_GATEWAY_EXECUTOR_REQUEST:
// TODO: this should be a pool instead of dynamic goroutines
// Handle invoke in a non-blocking way to allow for other messages to be processed
go h.handleInvokeMessage(ctx, ws, &msg)
default:
h.h.Logger.Error("got unknown gateway request", "err", err)
continue
h.h.Logger.Debug("received gateway request", "msg", &msg)

switch msg.Kind {
case connectproto.GatewayMessageType_GATEWAY_EXECUTOR_REQUEST:
// TODO: this should be a pool instead of dynamic goroutines
// Handle invoke in a non-blocking way to allow for other messages to be processed
inProgress.Add(1)
go func() {
defer inProgress.Done()

// Always make sure the invoke finishes properly
processCtx := context.Background()

err := h.handleInvokeMessage(processCtx, ws, &msg)

// When we encounter an error, we cannot retry the connection from inside the goroutine.
// If we're dealing with connection loss, the next read loop will fail with the same error
// and handle the reconnection.
if err != nil {
cerr := websocket.CloseError{}
if errors.As(err, &cerr) {
h.h.Logger.Error("gateway connection closed with reason", "reason", cerr.Reason)
return
}

if errors.Is(err, io.EOF) {
h.h.Logger.Error("gateway connection closed unexpectedly", "err", err)
return
}
}
}()
default:
h.h.Logger.Error("got unknown gateway request", "err", err)
continue
}
}
}()

<-readLoopCtx.Done()

// In case the gateway intentionally closed the connection, we'll receive a close error
if closeErr != nil {
return true, fmt.Errorf("connection closed unexpectedly: %w", closeErr)
}

// If read loop ended, this could be for two reasons
// - Connection loss (io.EOF), read loop terminated intentionally
// - Worker shutdown, parent context got canceled
if ctx.Err() == nil {
return true, fmt.Errorf("connection closed unexpectedly")
}

// TODO Perform graceful shutdown routine
_ = ws.Close(websocket.StatusNormalClosure, "")
// Perform graceful shutdown routine

return nil
// TODO Signal gateway that we won't process additional messages!

// Wait until all in-progress requests are completed
inProgress.Wait()

// TODO Send out buffered messages, using new connection if necessary!

_ = ws.Close(websocket.StatusNormalClosure, connectproto.WorkerDisconnectReason_WORKER_SHUTDOWN.String())

return false, nil
}

func (h *connectHandler) handleInvokeMessage(ctx context.Context, ws *websocket.Conn, msg *connectproto.ConnectMessage) {
func (h *connectHandler) handleInvokeMessage(ctx context.Context, ws *websocket.Conn, msg *connectproto.ConnectMessage) error {
resp, err := h.connectInvoke(ctx, msg)
if err != nil {
h.h.Logger.Error("failed to handle sdk request", "err", err)
// TODO Should we drop the connection? Continue receiving messages? handle error
return
// TODO Should we drop the connection? Continue receiving messages?
return fmt.Errorf("could not handle sdk request: %w", err)
}

data, err := proto.Marshal(resp)
if err != nil {
h.h.Logger.Error("failed to serialize sdk response", "err", err)
// TODO This should never happen; Signal that we should retry
return
return fmt.Errorf("could not serialize sdk response: %w", err)
}

err = wsproto.Write(ctx, ws, &connectproto.ConnectMessage{
responseMessage := &connectproto.ConnectMessage{
Kind: connectproto.GatewayMessageType_WORKER_REPLY,
Payload: data,
})
}

err = wsproto.Write(ctx, ws, responseMessage)
if err != nil {
h.h.Logger.Error("failed to send sdk response", "err", err)
// TODO This should never happen; Signal that we should retry
// continue
return

// Buffer message to retry
h.messageBufferLock.Lock()
h.messageBuffer = append(h.messageBuffer, responseMessage)
h.messageBufferLock.Unlock()

return fmt.Errorf("could not send sdk response: %w", err)
}

return nil
}

// connectInvoke is the counterpart to invoke for connect
Expand Down
Loading