Skip to content

Commit

Permalink
clean up return types
Browse files Browse the repository at this point in the history
  • Loading branch information
BrunoScheufler committed Nov 25, 2024
1 parent c3096e1 commit 1abd900
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 33 deletions.
51 changes: 26 additions & 25 deletions connect/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ type connectReport struct {

func (h *connectHandler) connect(ctx context.Context, data connectionEstablishData) {
// Set up connection (including connect handshake protocol)
preparedConn, reconnect, err := h.prepareConnection(ctx, data)
preparedConn, err := h.prepareConnection(ctx, data)
if err != nil {
h.logger.Error("could not establish connection", "err", err)

h.notifyConnectDoneChan <- connectReport{
reconnect: reconnect,
reconnect: shouldReconnect(err),
err: fmt.Errorf("could not establish connection: %w", err),
}
return
Expand All @@ -38,7 +38,7 @@ func (h *connectHandler) connect(ctx context.Context, data connectionEstablishDa
h.notifyConnectedChan <- struct{}{}

// Set up connection lifecycle logic (receiving messages, handling requests, etc.)
reconnect, err = h.handleConnection(ctx, data, preparedConn.ws, preparedConn.gatewayHost)
err = h.handleConnection(ctx, data, preparedConn.ws, preparedConn.gatewayHost)
if err != nil {
h.logger.Error("could not handle connection", "err", err)

Expand All @@ -48,16 +48,13 @@ func (h *connectHandler) connect(ctx context.Context, data connectionEstablishDa
}

h.notifyConnectDoneChan <- connectReport{
reconnect: reconnect,
reconnect: shouldReconnect(err),
err: fmt.Errorf("could not handle connection: %w", err),
}
return
}

h.notifyConnectDoneChan <- connectReport{
reconnect: reconnect,
err: nil,
}
h.notifyConnectDoneChan <- connectReport{}
}

type connectionEstablishData struct {
Expand All @@ -69,13 +66,13 @@ type connectionEstablishData struct {
manualReadinessAck bool
}

type preparedConnection struct {
type connection struct {
ws *websocket.Conn
gatewayHost string
connectionId string
}

func (h *connectHandler) prepareConnection(ctx context.Context, data connectionEstablishData) (preparedConnection, bool, error) {
func (h *connectHandler) prepareConnection(ctx context.Context, data connectionEstablishData) (connection, error) {
connectTimeout, cancelConnectTimeout := context.WithTimeout(ctx, 10*time.Second)
defer cancelConnectTimeout()

Expand All @@ -84,7 +81,7 @@ func (h *connectHandler) prepareConnection(ctx context.Context, data connectionE
// All gateways have been tried, reset the internal state to retry
h.hostsManager.resetGateways()

return preparedConnection{}, true, fmt.Errorf("no available gateway hosts")
return connection{}, reconnectError{fmt.Errorf("no available gateway hosts")}
}

// Establish WebSocket connection to one of the gateways
Expand All @@ -95,23 +92,27 @@ func (h *connectHandler) prepareConnection(ctx context.Context, data connectionE
})
if err != nil {
h.hostsManager.markUnreachableGateway(gatewayHost)
return preparedConnection{}, true, fmt.Errorf("could not connect to gateway: %w", err)
return connection{}, reconnectError{fmt.Errorf("could not connect to gateway: %w", err)}
}

// Connection ID is unique per connection, reconnections should get a new ID
connectionId := ulid.MustNew(ulid.Now(), rand.Reader)

h.logger.Debug("websocket connection established", "gateway_host", gatewayHost)

reconnect, err := h.performConnectHandshake(ctx, connectionId.String(), ws, gatewayHost, data)
err = h.performConnectHandshake(ctx, connectionId.String(), ws, gatewayHost, data)
if err != nil {
return preparedConnection{}, reconnect, fmt.Errorf("could not perform connect handshake: %w", err)
return connection{}, reconnectError{fmt.Errorf("could not perform connect handshake: %w", err)}
}

return preparedConnection{ws, gatewayHost, connectionId.String()}, false, nil
return connection{
ws: ws,
gatewayHost: gatewayHost,
connectionId: connectionId.String(),
}, nil
}

func (h *connectHandler) handleConnection(ctx context.Context, data connectionEstablishData, ws *websocket.Conn, gatewayHost string) (reconnect bool, err error) {
func (h *connectHandler) handleConnection(ctx context.Context, data connectionEstablishData, ws *websocket.Conn, gatewayHost string) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

Expand All @@ -130,9 +131,9 @@ func (h *connectHandler) handleConnection(ctx context.Context, data connectionEs
// Send buffered but unsent messages if connection was re-established
if len(h.messageBuffer) > 0 {
h.logger.Debug("sending buffered messages", "count", len(h.messageBuffer))
err = h.sendBufferedMessages(ws)
err := h.sendBufferedMessages(ws)
if err != nil {
return true, fmt.Errorf("could not send buffered messages: %w", err)
return reconnectError{fmt.Errorf("could not send buffered messages: %w", err)}
}
}

Expand All @@ -159,7 +160,7 @@ func (h *connectHandler) handleConnection(ctx context.Context, data connectionEs
eg.Go(func() error {
for {
var msg connectproto.ConnectMessage
err = wsproto.Read(context.Background(), ws, &msg)
err := wsproto.Read(context.Background(), ws, &msg)
if err != nil {
h.logger.Error("failed to read message", "err", err)

Expand Down Expand Up @@ -222,7 +223,7 @@ func (h *connectHandler) handleConnection(ctx context.Context, data connectionEs
}

// By returning, we will close the old connection
return false, errGatewayDraining
return errGatewayDraining
}

h.logger.Debug("read loop ended with error", "err", err)
Expand All @@ -233,17 +234,17 @@ func (h *connectHandler) handleConnection(ctx context.Context, data connectionEs
h.logger.Error("connection closed with reason", "reason", cerr.Reason)

// Reconnect!
return true, fmt.Errorf("connection closed with reason %q: %w", cerr.Reason, cerr)
return reconnectError{fmt.Errorf("connection closed with reason %q: %w", cerr.Reason, cerr)}
}

// connection closed without reason
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
h.logger.Error("failed to read message from gateway, lost connection unexpectedly", "err", err)
return true, fmt.Errorf("connection closed unexpectedly: %w", cerr)
return reconnectError{fmt.Errorf("connection closed unexpectedly: %w", cerr)}
}

// If this is not a worker shutdown, we should reconnect
return true, fmt.Errorf("connection closed unexpectedly: %w", ctx.Err())
return reconnectError{fmt.Errorf("connection closed unexpectedly: %w", ctx.Err())}
}

// Perform graceful shutdown routine (context was cancelled)
Expand All @@ -268,7 +269,7 @@ func (h *connectHandler) handleConnection(ctx context.Context, data connectionEs
// Attempt to shut down connection if not already done
_ = ws.Close(websocket.StatusNormalClosure, connectproto.WorkerDisconnectReason_WORKER_SHUTDOWN.String())

return false, nil
return nil
}

func (h *connectHandler) withTemporaryConnection(data connectionEstablishData, handler func(ws *websocket.Conn) error) error {
Expand All @@ -284,7 +285,7 @@ func (h *connectHandler) withTemporaryConnection(data connectionEstablishData, h
return fmt.Errorf("could not establish connection after %d attempts", maxAttempts)
}

ws, _, err := h.prepareConnection(context.Background(), data)
ws, err := h.prepareConnection(context.Background(), data)
if err != nil {
attempts++
continue
Expand Down
33 changes: 25 additions & 8 deletions connect/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package connect

import (
"context"
"errors"
"fmt"
"github.com/coder/websocket"
"github.com/inngest/inngest/pkg/connect/wsproto"
Expand All @@ -11,7 +12,23 @@ import (
"time"
)

func (h *connectHandler) performConnectHandshake(ctx context.Context, connectionId string, ws *websocket.Conn, gatewayHost string, data connectionEstablishData) (bool, error) {
type reconnectError struct {
err error
}

func (e reconnectError) Unwrap() error {
return e.err
}

func (e reconnectError) Error() string {
return fmt.Sprintf("reconnect error: %v", e.err)
}

func shouldReconnect(err error) bool {
return errors.Is(err, reconnectError{})
}

func (h *connectHandler) performConnectHandshake(ctx context.Context, connectionId string, ws *websocket.Conn, gatewayHost string, data connectionEstablishData) error {
// Wait for gateway hello message
{
initialMessageTimeout, cancelInitialTimeout := context.WithTimeout(ctx, 5*time.Second)
Expand All @@ -20,12 +37,12 @@ func (h *connectHandler) performConnectHandshake(ctx context.Context, connection
err := wsproto.Read(initialMessageTimeout, ws, &helloMessage)
if err != nil {
h.hostsManager.markUnreachableGateway(gatewayHost)
return true, fmt.Errorf("did not receive gateway hello message: %w", err)
return reconnectError{fmt.Errorf("did not receive gateway hello message: %w", err)}
}

if helloMessage.Kind != connectproto.GatewayMessageType_GATEWAY_HELLO {
h.hostsManager.markUnreachableGateway(gatewayHost)
return true, fmt.Errorf("expected gateway hello message, got %s", helloMessage.Kind)
return reconnectError{fmt.Errorf("expected gateway hello message, got %s", helloMessage.Kind)}
}

h.logger.Debug("received gateway hello message")
Expand Down Expand Up @@ -66,15 +83,15 @@ func (h *connectHandler) performConnectHandshake(ctx context.Context, connection
WorkerManualReadinessAck: data.manualReadinessAck,
})
if err != nil {
return false, fmt.Errorf("could not serialize sdk connect message: %w", err)
return fmt.Errorf("could not serialize sdk connect message: %w", err)
}

err = wsproto.Write(ctx, ws, &connectproto.ConnectMessage{
Kind: connectproto.GatewayMessageType_WORKER_CONNECT,
Payload: data,
})
if err != nil {
return true, fmt.Errorf("could not send initial message")
return reconnectError{fmt.Errorf("could not send initial message")}
}
}

Expand All @@ -85,15 +102,15 @@ func (h *connectHandler) performConnectHandshake(ctx context.Context, connection
var connectionReadyMsg connectproto.ConnectMessage
err := wsproto.Read(connectionReadyTimeout, ws, &connectionReadyMsg)
if err != nil {
return true, fmt.Errorf("did not receive gateway connection ready message: %w", err)
return reconnectError{fmt.Errorf("did not receive gateway connection ready message: %w", err)}
}

if connectionReadyMsg.Kind != connectproto.GatewayMessageType_GATEWAY_CONNECTION_READY {
return true, fmt.Errorf("expected gateway connection ready message, got %s", connectionReadyMsg.Kind)
return reconnectError{fmt.Errorf("expected gateway connection ready message, got %s", connectionReadyMsg.Kind)}
}

h.logger.Debug("received gateway connection ready message")
}

return false, nil
return nil
}

0 comments on commit 1abd900

Please sign in to comment.