diff --git a/connect.go b/connect.go index 203f689a..d4444bab 100644 --- a/connect.go +++ b/connect.go @@ -9,18 +9,20 @@ import ( "github.com/coder/websocket" "github.com/inngest/inngest/pkg/connect/types" "github.com/inngest/inngest/pkg/connect/wsproto" - "github.com/pbnjay/memory" - "google.golang.org/protobuf/proto" - "runtime" - "github.com/inngest/inngest/pkg/enums" "github.com/inngest/inngest/pkg/publicerr" + "github.com/inngest/inngest/pkg/syscode" connectproto "github.com/inngest/inngest/proto/gen/connect/v1" sdkerrors "github.com/inngest/inngestgo/errors" + "github.com/pbnjay/memory" + "google.golang.org/protobuf/proto" + "io" + "net/url" + "runtime" + "sync" "github.com/inngest/inngestgo/internal/sdkrequest" "github.com/oklog/ulid/v2" - "net/url" "os" "strings" "time" @@ -30,6 +32,15 @@ type connectHandler struct { h *handler connectionId ulid.ULID + + messageBuffer []*connectproto.ConnectMessage + messageBufferLock sync.Mutex +} + +// authContext is a wrapper for information related to authentication +type authContext struct { + signingKey string + fallback bool } func (h *connectHandler) connectURLs() []string { @@ -91,13 +102,11 @@ func (h *connectHandler) instanceId() string { } func (h *connectHandler) Connect(ctx context.Context) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - signingKey := h.h.GetSigningKey() if signingKey == "" { return fmt.Errorf("must provide signing key") } + auth := authContext{signingKey: signingKey} numCpuCores := runtime.NumCPU() totalMem := memory.TotalMemory() @@ -122,17 +131,77 @@ func (h *connectHandler) Connect(ctx context.Context) error { return fmt.Errorf("failed to serialize connect config: %w", err) } + var attempts int + for { + attempts++ + + shouldReconnect, err := h.connect(ctx, connectionEstablishData{ + signingKey: auth.signingKey, + numCpuCores: int32(numCpuCores), + totalMem: int64(totalMem), + marshaledFns: marshaledFns, + marshaledCapabilities: marshaledCapabilities, + }) + + 
h.h.Logger.Error("connect failed", "err", err, "reconnect", shouldReconnect) + + if !shouldReconnect { + return err + } + + closeErr := websocket.CloseError{} + if errors.As(err, &closeErr) { + switch closeErr.Reason { + // If auth failed, retry with fallback key + case syscode.CodeConnectAuthFailed: + if auth.fallback { + return fmt.Errorf("failed to authenticate with fallback key, exiting") + } + + signingKeyFallback := h.h.GetSigningKeyFallback() + if signingKeyFallback != "" { + auth = authContext{signingKey: signingKeyFallback, fallback: true} + } + + continue + + // Retry on the following error codes + case syscode.CodeConnectGatewayClosing, syscode.CodeConnectInternal, syscode.CodeConnectWorkerHelloTimeout: + continue + + default: + // If we received a reason that's non-retriable, stop here. + return fmt.Errorf("connect failed with error code %q", closeErr.Reason) + } + } + } +} + +type connectionEstablishData struct { + signingKey string + numCpuCores int32 + totalMem int64 + marshaledFns []byte + marshaledCapabilities []byte +} + +func (h *connectHandler) connect(ctx context.Context, data connectionEstablishData) (reconnect bool, err error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + connectTimeout, cancelConnectTimeout := context.WithTimeout(ctx, 10*time.Second) defer cancelConnectTimeout() ws, err := h.connectToGateway(connectTimeout) if err != nil { - return fmt.Errorf("could not connect: %w", err) + return false, fmt.Errorf("could not connect: %w", err) } defer func() { + // TODO Do we need to include a reason here? 
If we only use this for unexpected disconnects, probably not _ = ws.CloseNow() }() + // Connection ID is unique per connection, reconnections should get a new ID h.connectionId = ulid.MustNew(ulid.Now(), rand.Reader) h.h.Logger.Debug("connection established") @@ -144,11 +213,11 @@ func (h *connectHandler) Connect(ctx context.Context) error { var helloMessage connectproto.ConnectMessage err = wsproto.Read(initialMessageTimeout, ws, &helloMessage) if err != nil { - return fmt.Errorf("did not receive gateway hello message: %w", err) + return true, fmt.Errorf("did not receive gateway hello message: %w", err) } if helloMessage.Kind != connectproto.GatewayMessageType_GATEWAY_HELLO { - return fmt.Errorf("expected gateway hello message, got %s", helloMessage.Kind) + return true, fmt.Errorf("expected gateway hello message, got %s", helloMessage.Kind) } h.h.Logger.Debug("received gateway hello message") @@ -156,9 +225,9 @@ func (h *connectHandler) Connect(ctx context.Context) error { // Send connect message { - hashedKey, err := hashedSigningKey([]byte(signingKey)) + hashedKey, err := hashedSigningKey([]byte(data.signingKey)) if err != nil { - return fmt.Errorf("could not hash signing key: %w", err) + return false, fmt.Errorf("could not hash signing key: %w", err) } apiOrigin := defaultAPIOrigin @@ -177,13 +246,13 @@ func (h *connectHandler) Connect(ctx context.Context) error { }, AppName: h.h.appName, Config: &connectproto.ConfigDetails{ - Capabilities: marshaledCapabilities, - Functions: marshaledFns, + Capabilities: data.marshaledCapabilities, + Functions: data.marshaledFns, ApiOrigin: apiOrigin, }, SystemAttributes: &connectproto.SystemAttributes{ - CpuCores: int32(numCpuCores), - MemBytes: int64(totalMem), + CpuCores: data.numCpuCores, + MemBytes: data.totalMem, Os: runtime.GOOS, }, Environment: h.h.Env, @@ -192,7 +261,7 @@ func (h *connectHandler) Connect(ctx context.Context) error { SdkLanguage: SDKLanguage, }) if err != nil { - return fmt.Errorf("could not serialize 
sdk connect message: %w", err) + return false, fmt.Errorf("could not serialize sdk connect message: %w", err) } err = wsproto.Write(ctx, ws, &connectproto.ConnectMessage{ @@ -200,71 +269,156 @@ func (h *connectHandler) Connect(ctx context.Context) error { Payload: data, }) if err != nil { - return fmt.Errorf("could not send initial message") + return true, fmt.Errorf("could not send initial message") } } - for { - if ctx.Err() != nil { - break - } + // TODO Read gateway ready - var msg connectproto.ConnectMessage - err = wsproto.Read(ctx, ws, &msg) - if err != nil { - closeErr := websocket.CloseError{} - if errors.As(err, &closeErr) { - h.h.Logger.Error("connection closed unexpectedly", "reason", closeErr.Reason) - return err + // TODO Send buffered but unsent messages if connection was re-established + + inProgress := sync.WaitGroup{} + + var ( + closeErr error + closeErrLock sync.Mutex + ) + + readLoopCtx, cancelReadLoop := context.WithCancel(ctx) + go func() { + // Close connection if run loop ends + defer cancelReadLoop() + + for { + if readLoopCtx.Err() != nil { + break } - // TODO Handle issues reading message: Should we re-establish the connection? - return err - } - h.h.Logger.Debug("received gateway request", "msg", &msg) + var msg connectproto.ConnectMessage + err = wsproto.Read(readLoopCtx, ws, &msg) + if err != nil { + // connection lost with reason + cerr := websocket.CloseError{} + if errors.As(err, &cerr) { + h.h.Logger.Error("connection closed unexpectedly", "reason", cerr.Reason) + closeErrLock.Lock() + closeErr = cerr + closeErrLock.Unlock() + // Reconnect! 
+ return + } + + // connection lost without reason + if errors.Is(err, io.EOF) { + h.h.Logger.Error("failed to read message from gateway, lost connection unexpectedly", "err", err) + return + } + + h.h.Logger.Error("failed to read message", "err", err) + + // The connection may still be active, but for some reason we couldn't read the message + return + } - switch msg.Kind { - case connectproto.GatewayMessageType_GATEWAY_EXECUTOR_REQUEST: - // TODO: this should be a pool instead of dynamic goroutines - // Handle invoke in a non-blocking way to allow for other messages to be processed - go h.handleInvokeMessage(ctx, ws, &msg) - default: - h.h.Logger.Error("got unknown gateway request", "err", err) - continue + h.h.Logger.Debug("received gateway request", "msg", &msg) + + switch msg.Kind { + case connectproto.GatewayMessageType_GATEWAY_EXECUTOR_REQUEST: + // TODO: this should be a pool instead of dynamic goroutines + // Handle invoke in a non-blocking way to allow for other messages to be processed + inProgress.Add(1) + go func() { + defer inProgress.Done() + + // Always make sure the invoke finishes properly + processCtx := context.Background() + + err := h.handleInvokeMessage(processCtx, ws, &msg) + + // When we encounter an error, we cannot retry the connection from inside the goroutine. + // If we're dealing with connection loss, the next read loop will fail with the same error + // and handle the reconnection. 
+ if err != nil { + cerr := websocket.CloseError{} + if errors.As(err, &cerr) { + h.h.Logger.Error("gateway connection closed with reason", "reason", cerr.Reason) + return + } + + if errors.Is(err, io.EOF) { + h.h.Logger.Error("gateway connection closed unexpectedly", "err", err) + return + } + } + }() + default: + h.h.Logger.Error("got unknown gateway request", "err", err) + continue + } } + }() + + <-readLoopCtx.Done() + + // In case the gateway intentionally closed the connection, we'll receive a close error + if closeErr != nil { + return true, fmt.Errorf("connection closed unexpectedly: %w", closeErr) + } + + // If read loop ended, this could be for two reasons + // - Connection loss (io.EOF), read loop terminated intentionally + // - Worker shutdown, parent context got canceled + if ctx.Err() == nil { + return true, fmt.Errorf("connection closed unexpectedly") } - // TODO Perform graceful shutdown routine - _ = ws.Close(websocket.StatusNormalClosure, "") + // Perform graceful shutdown routine - return nil + // TODO Signal gateway that we won't process additional messages! + + // Wait until all in-progress requests are completed + inProgress.Wait() + + // TODO Send out buffered messages, using new connection if necessary! + + _ = ws.Close(websocket.StatusNormalClosure, connectproto.WorkerDisconnectReason_WORKER_SHUTDOWN.String()) + + return false, nil } -func (h *connectHandler) handleInvokeMessage(ctx context.Context, ws *websocket.Conn, msg *connectproto.ConnectMessage) { +func (h *connectHandler) handleInvokeMessage(ctx context.Context, ws *websocket.Conn, msg *connectproto.ConnectMessage) error { resp, err := h.connectInvoke(ctx, msg) if err != nil { h.h.Logger.Error("failed to handle sdk request", "err", err) - // TODO Should we drop the connection? Continue receiving messages? handle error - return + // TODO Should we drop the connection? Continue receiving messages? 
+ return fmt.Errorf("could not handle sdk request: %w", err) } data, err := proto.Marshal(resp) if err != nil { h.h.Logger.Error("failed to serialize sdk response", "err", err) // TODO This should never happen; Signal that we should retry - return + return fmt.Errorf("could not serialize sdk response: %w", err) } - err = wsproto.Write(ctx, ws, &connectproto.ConnectMessage{ + responseMessage := &connectproto.ConnectMessage{ Kind: connectproto.GatewayMessageType_WORKER_REPLY, Payload: data, - }) + } + + err = wsproto.Write(ctx, ws, responseMessage) if err != nil { h.h.Logger.Error("failed to send sdk response", "err", err) - // TODO This should never happen; Signal that we should retry - // continue - return + + // Buffer message to retry + h.messageBufferLock.Lock() + h.messageBuffer = append(h.messageBuffer, responseMessage) + h.messageBufferLock.Unlock() + + return fmt.Errorf("could not send sdk response: %w", err) } + + return nil } // connectInvoke is the counterpart to invoke for connect