From 460d0cdb1750cf8606d486a9e70661d4b0832650 Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Fri, 19 Jul 2024 17:08:34 +0100 Subject: [PATCH] Add a mechanism to obtain immutable request parameters (#14) ID, headers, attempt headers and body should all be obtainable if you need them. --- context.go | 18 ++++++++++++++++++ internal/state/awakeable.go | 2 +- internal/state/state.go | 34 ++++++++++++++++++++++++---------- internal/state/sys.go | 8 +++++--- server/restate.go | 2 +- 5 files changed, 49 insertions(+), 15 deletions(-) diff --git a/context.go b/context.go index ac38a4a..5dea68e 100644 --- a/context.go +++ b/context.go @@ -122,6 +122,24 @@ type RunContext interface { // By default, this logger will not output messages if the invocation is currently replaying // The log handler can be set with `.WithLogger()` on the server object Log() *slog.Logger + + // Request gives extra information about the request that started this invocation + Request() *Request +} + +type Request struct { + // The unique id that identifies the current function invocation. This id is guaranteed to be + // unique across invocations, but constant across reties and suspensions. + ID []byte + // Request headers - the following headers capture the original invocation headers, as provided to + // the ingress. + Headers map[string]string + // Attempt headers - the following headers are sent by the restate runtime. + // These headers are attempt specific, generated by the restate runtime uniquely for each attempt. + // These headers might contain information such as the W3C trace context, and attempt specific information. + AttemptHeaders map[string][]string + // Raw unparsed request body + Body []byte } // After is a handle on a Sleep operation which allows you to do other work concurrently diff --git a/internal/state/awakeable.go b/internal/state/awakeable.go index 85f6671..7eb3a9f 100644 --- a/internal/state/awakeable.go +++ b/internal/state/awakeable.go @@ -18,7 +18,7 @@ func (c *Machine) awakeable() *futures.Awakeable { c._awakeable, ) - return futures.NewAwakeable(c.suspensionCtx, c.id, entry, entryIndex) + return futures.NewAwakeable(c.suspensionCtx, c.request.ID, entry, entryIndex) } func (c *Machine) _awakeable() *wire.AwakeableEntryMessage { diff --git a/internal/state/state.go b/internal/state/state.go index be95066..bc8f1d1 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -47,6 +47,10 @@ func (c *Context) Log() *slog.Logger { return c.machine.userLog } +func (c *Context) Request() *restate.Request { + return &c.machine.request +} + func (c *Context) Rand() *rand.Rand { return c.machine.rand } @@ -266,8 +270,8 @@ type Machine struct { protocol *wire.Protocol // state - id []byte - key string + key string + request restate.Request partial bool current map[string][]byte @@ -289,12 +293,15 @@ type Machine struct { failure any } -func NewMachine(handler restate.Handler, conn io.ReadWriter) *Machine { +func NewMachine(handler restate.Handler, conn io.ReadWriter, attemptHeaders map[string][]string) *Machine { m := &Machine{ handler: handler, current: make(map[string][]byte), pendingAcks: map[uint32]wire.AckableMessage{}, pendingCompletions: map[uint32]wire.CompleteableMessage{}, + request: restate.Request{ + AttemptHeaders: attemptHeaders, + }, } m.protocol = wire.NewProtocol(conn) return m @@ -317,8 +324,8 @@ func (m *Machine) Start(inner context.Context, dropReplayLogs bool, logHandler s m.ctx = inner m.suspensionCtx, m.suspend = context.WithCancelCause(m.ctx) - m.id = start.Id - m.rand = rand.New(m.id) + m.request.ID = start.Id + m.rand = rand.New(m.request.ID) m.key = start.Key logHandler = logHandler.WithAttrs([]slog.Attr{slog.String("invocationID", start.DebugId)}) @@ -331,7 +338,7 @@ func (m *Machine) Start(inner context.Context, dropReplayLogs bool, logHandler s return m.process(ctx, start) } -func (m *Machine) invoke(ctx *Context, input []byte, outputSeen bool) error { +func (m *Machine) invoke(ctx *Context, outputSeen bool) error { // always terminate the invocation with // an end message. // this will always terminate the connection @@ -485,9 +492,9 @@ The journal entry at position %d was: var err error switch handler := m.handler.(type) { case restate.ObjectHandler: - bytes, err = handler.Call(ctx, input) + bytes, err = handler.Call(ctx, m.request.Body) case restate.ServiceHandler: - bytes, err = handler.Call(ctx, input) + bytes, err = handler.Call(ctx, m.request.Body) } if err != nil && restate.IsTerminalError(err) { @@ -580,9 +587,16 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error { go m.handleCompletionsAcks() inputMsg := msg.(*wire.InputEntryMessage) - value := inputMsg.GetValue() - return m.invoke(ctx, value, outputSeen) + m.request.Body = inputMsg.GetValue() + + if len(inputMsg.GetHeaders()) > 0 { + m.request.Headers = make(map[string]string, len(inputMsg.Headers)) + for _, header := range inputMsg.Headers { + m.request.Headers[header.Key] = header.Value + } + } + return m.invoke(ctx, outputSeen) } func (c *Machine) currentEntry() (wire.Message, bool) { diff --git a/internal/state/sys.go b/internal/state/sys.go index 0409049..279f121 100644 --- a/internal/state/sys.go +++ b/internal/state/sys.go @@ -307,13 +307,15 @@ func (m *Machine) run(fn func(restate.RunContext) ([]byte, error)) ([]byte, erro type runContext struct { context.Context - log *slog.Logger + log *slog.Logger + request *restate.Request } -func (r runContext) Log() *slog.Logger { return r.log } +func (r runContext) Log() *slog.Logger { return r.log } +func (r runContext) Request() *restate.Request { return r.request } func (m *Machine) _run(fn func(restate.RunContext) ([]byte, error)) *wire.RunEntryMessage { - bytes, err := fn(runContext{m.ctx, m.userLog}) + bytes, err := fn(runContext{m.ctx, m.userLog, &m.request}) if err != nil { if restate.IsTerminalError(err) { diff --git a/server/restate.go b/server/restate.go index d40f5f0..d5e493b 100644 --- a/server/restate.go +++ b/server/restate.go @@ -262,7 +262,7 @@ func (r *Restate) callHandler(serviceProtocolVersion protocol.ServiceProtocolVer defer conn.Close() - machine := state.NewMachine(handler, conn) + machine := state.NewMachine(handler, conn, request.Header) if err := machine.Start(request.Context(), r.dropReplayLogs, r.logHandler); err != nil { r.systemLog.LogAttrs(request.Context(), slog.LevelError, "Failed to handle invocation", log.Error(err))