From 06ec30ba080c6b6838d099c3922af02fac97627c Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Tue, 24 Sep 2024 18:41:17 +0100 Subject: [PATCH 1/4] Add ingress clients --- client/client.go | 508 ++++++++++++++++++++ client/workflow.go | 83 ++++ examples/ticketreservation/client/client.go | 65 +++ facilitators.go | 2 +- internal/options/options.go | 53 ++ options.go | 43 +- 6 files changed, 752 insertions(+), 2 deletions(-) create mode 100644 client/client.go create mode 100644 client/workflow.go create mode 100644 examples/ticketreservation/client/client.go diff --git a/client/client.go b/client/client.go new file mode 100644 index 0000000..dc2194d --- /dev/null +++ b/client/client.go @@ -0,0 +1,508 @@ +package client + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + + "github.com/restatedev/sdk-go/encoding" + "github.com/restatedev/sdk-go/internal/options" +) + +type ingressContextKey struct{} + +func Connect(ctx context.Context, ingressURL string, opts ...options.ConnectOption) (context.Context, error) { + o := options.ConnectOptions{} + for _, opt := range opts { + opt.BeforeConnect(&o) + } + if o.Client == nil { + o.Client = http.DefaultClient + } + + url, err := url.Parse(ingressURL) + if err != nil { + return nil, err + } + + resp, err := o.Client.Get(url.JoinPath("restate", "health").String()) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Ingress is not healthy: status %d", resp.StatusCode) + } + + return context.WithValue(ctx, ingressContextKey{}, &connection{o, url}), nil +} + +type connection struct { + options.ConnectOptions + url *url.URL +} + +func fromContext(ctx context.Context) (*connection, bool) { + if val := ctx.Value(ingressContextKey{}); val != nil { + c, ok := val.(*connection) + return c, ok + } + return nil, false +} + +func fromContextOrPanic(ctx context.Context) *connection { + conn, ok := fromContext(ctx) + if !ok { + panic("Not connected to Restate ingress; provided ctx must have been returned from client.Connect") + } + + return conn +} + +// Client represents all the different ways you can invoke a particular service-method. +type IngressClient[I any, O any] interface { + // RequestFuture makes a call and returns a handle on a future response + RequestFuture(input I, options ...options.RequestOption) IngressResponseFuture[O] + // Request makes a call and blocks on getting the response + Request(input I, options ...options.RequestOption) (O, error) + IngressSendClient[I] +} + +type ingressClient[I any, O any] struct { + ctx context.Context + conn *connection + opts options.IngressClientOptions + url *url.URL +} + +func (c *ingressClient[I, O]) RequestFuture(input I, opts ...options.RequestOption) IngressResponseFuture[O] { + o := options.RequestOptions{} + for _, opt := range opts { + opt.BeforeRequest(&o) + } + + headers := make(http.Header, len(c.conn.Headers)+len(o.Headers)+1) + for k, v := range c.conn.Headers { + headers.Add(k, v) + } + for k, v := range o.Headers { + headers.Add(k, v) + } + if o.IdempotencyKey != "" { + headers.Set("Idempotency-Key", o.IdempotencyKey) + } + + done := make(chan struct{}) + f := &ingressResponseFuture[O]{done: done, codec: c.opts.Codec} + go func() { + defer close(done) + + data, err := encoding.Marshal(c.opts.Codec, input) + if err != nil { + f.r.err = err + return + } + + if len(data) > 0 { + var i I + if p := encoding.InputPayloadFor(c.opts.Codec, i); p != nil && p.ContentType != nil { + headers.Add("Content-Type", *p.ContentType) + } + } + + request, err := http.NewRequestWithContext(c.ctx, "POST", c.url.String(), bytes.NewReader(data)) + if err != nil { + f.r.err = err + return + } + request.Header = headers + + f.r.Response, f.r.err = c.conn.Client.Do(request) + }() + + return f +} + +func (c *ingressClient[I, O]) Request(input I, opts ...options.RequestOption) (O, error) { + return c.RequestFuture(input, opts...).Response() +} + +func (c *ingressClient[I, O]) Send(input I, opts ...options.SendOption) (Send, error) { + o := options.SendOptions{} + for _, opt := range opts { + opt.BeforeSend(&o) + } + + headers := make(http.Header, len(c.conn.Headers)+len(o.Headers)+2) + for k, v := range c.conn.Headers { + headers.Add(k, v) + } + for k, v := range o.Headers { + headers.Add(k, v) + } + if o.IdempotencyKey != "" { + headers.Set("Idempotency-Key", o.IdempotencyKey) + } + + data, err := encoding.Marshal(c.opts.Codec, input) + if err != nil { + return Send{}, err + } + + if len(data) > 0 { + var i I + if p := encoding.InputPayloadFor(c.opts.Codec, i); p != nil && p.ContentType != nil { + headers.Add("Content-Type", *p.ContentType) + } + } + + url := c.url.JoinPath("send") + url.Query().Add("delay", fmt.Sprintf("%dms", o.Delay.Milliseconds())) + + request, err := http.NewRequestWithContext(c.ctx, "POST", url.String(), bytes.NewReader(data)) + if err != nil { + return Send{}, err + } + request.Header = headers + + resp, err := c.conn.Client.Do(request) + if err != nil { + return Send{}, err + } + + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode > 299 { + body, _ := io.ReadAll(resp.Body) + return Send{}, fmt.Errorf("Send request failed: status %d\n%s", resp.StatusCode, string(body)) + } + + bytes, err := io.ReadAll(resp.Body) + if err != nil { + return Send{}, err + } + + send := Send{Attachable: o.IdempotencyKey != ""} + return send, encoding.Unmarshal(encoding.JSONCodec, bytes, &send) +} + +// IngressResponseFuture is a handle on a potentially not-yet completed outbound call. +type IngressResponseFuture[O any] interface { + // Response blocks on the response to the call and returns it + Response() (O, error) +} + +type ingressResponseFuture[O any] struct { + done <-chan struct{} // guards access to r + r response + codec encoding.Codec +} + +func (f *ingressResponseFuture[O]) Response() (o O, err error) { + <-f.done + + return o, f.r.Decode(f.codec, &o) +} + +type response struct { + *http.Response + err error +} + +func (r *response) Decode(codec encoding.Codec, v any) error { + if r.err != nil { + return r.err + } + + defer r.Body.Close() + + if r.StatusCode < 200 || r.StatusCode > 299 { + body, _ := io.ReadAll(r.Body) + return fmt.Errorf("Request failed: status %d\n%s", r.StatusCode, string(body)) + } + + bytes, err := io.ReadAll(r.Body) + if err != nil { + return err + } + + return encoding.Unmarshal(codec, bytes, v) +} + +// IngressSendClient allows making one-way invocations +type IngressSendClient[I any] interface { + // Send makes a one-way call which is executed in the background + Send(input I, options ...options.SendOption) (Send, error) +} + +type SendStatus uint16 + +const ( + SendStatusUnknown SendStatus = iota + SendStatusAccepted + SendStatusPreviouslyAccepted +) + +func (s *SendStatus) UnmarshalJSON(data []byte) (err error) { + var ss string + if err := json.Unmarshal(data, &ss); err != nil { + return err + } + switch ss { + case "Accepted": + *s = SendStatusAccepted + case "PreviouslyAccepted": + *s = SendStatusPreviouslyAccepted + default: + *s = SendStatusUnknown + } + return nil +} + +// Send is an object describing a submitted invocation to Restate, which can be attached to with [Attach] +type Send struct { + InvocationId string `json:"invocationID"` + Status SendStatus `json:"status"` + Attachable bool `json:"-"` +} + +func (s Send) attachable() bool { + return s.Attachable +} + +func (s Send) attachUrl(connURL *url.URL) *url.URL { + return connURL.JoinPath("restate", "invocation", s.InvocationId, "attach") +} + +func (s Send) outputUrl(connURL *url.URL) *url.URL { + return connURL.JoinPath("restate", "invocation", s.InvocationId, "output") +} + +var _ Attacher = Send{} + +// Attacher is implemented by [Send], [WorkflowSubmission] and [WorkflowIdentifier] +type Attacher interface { + attachable() bool + attachUrl(connURL *url.URL) *url.URL + outputUrl(connURL *url.URL) *url.URL +} + +// Attach attaches to the attachable invocation and returns its response. The invocation must have been created with an idempotency key +// or by a workflow submission. +// It must be called with a context returned from [Connect] +func Attach[O any](ctx context.Context, attacher Attacher, opts ...options.IngressClientOption) (o O, err error) { + conn := fromContextOrPanic(ctx) + + if !attacher.attachable() { + return o, fmt.Errorf("Unable to fetch the result.\nA service's result is stored only when an idempotencyKey is supplied when invoking the service, or if its a workflow submission.") + } + + os := options.IngressClientOptions{} + for _, opt := range opts { + opt.BeforeIngressClient(&os) + } + if os.Codec == nil { + os.Codec = encoding.JSONCodec + } + + headers := make(http.Header, len(conn.Headers)+1) + for k, v := range conn.Headers { + headers.Add(k, v) + } + + url := attacher.attachUrl(conn.url) + + request, err := http.NewRequestWithContext(ctx, "GET", url.String(), nil) + if err != nil { + return o, err + } + request.Header = headers + + resp, err := conn.Client.Do(request) + if err != nil { + return o, err + } + + return o, (&response{Response: resp}).Decode(os.Codec, &o) +} + +// GetOutput gets the output of the attachable invocation and returns its response if it has completed. The invocation must have been created with an idempotency key +// or by a workflow submission. +// It must be called with a context returned from [Connect]. +func GetOutput[O any](ctx context.Context, attacher Attacher, opts ...options.IngressClientOption) (o O, ready bool, err error) { + conn := fromContextOrPanic(ctx) + + if !attacher.attachable() { + return o, false, fmt.Errorf("Unable to fetch the result.\nA service's result is stored only when an idempotencyKey is supplied when invoking the service, or if its a workflow submission.") + } + + os := options.IngressClientOptions{} + for _, opt := range opts { + opt.BeforeIngressClient(&os) + } + if os.Codec == nil { + os.Codec = encoding.JSONCodec + } + + headers := make(http.Header, len(conn.Headers)+1) + for k, v := range conn.Headers { + headers.Add(k, v) + } + + url := attacher.outputUrl(conn.url) + + request, err := http.NewRequestWithContext(ctx, "GET", url.String(), nil) + if err != nil { + return o, false, err + } + request.Header = headers + + resp, err := conn.Client.Do(request) + if err != nil { + return o, false, err + } + defer resp.Body.Close() + + if resp.StatusCode == 470 { + // special status code used by restate to say that the result is not ready + return o, false, nil + } + + if resp.StatusCode < 200 || resp.StatusCode > 299 { + body, _ := io.ReadAll(resp.Body) + return o, false, fmt.Errorf("Request failed: status %d\n%s", resp.StatusCode, string(body)) + } + + bytes, err := io.ReadAll(resp.Body) + if err != nil { + return o, false, err + } + + if err := encoding.Unmarshal(os.Codec, bytes, &o); err != nil { + return o, false, err + } + + return o, true, nil +} + +func getClient[O any](ctx context.Context, conn *connection, url *url.URL, opts ...options.IngressClientOption) IngressClient[any, O] { + o := options.IngressClientOptions{} + for _, opt := range opts { + opt.BeforeIngressClient(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + + return &ingressClient[any, O]{ctx, conn, o, url} +} + +// Service gets a Service request client by service and method name +// It must be called with a context returned from [Connect] +func Service[O any](ctx context.Context, service string, method string, opts ...options.IngressClientOption) IngressClient[any, O] { + conn := fromContextOrPanic(ctx) + url := conn.url.JoinPath(service, method) + return getClient[O](ctx, conn, url, opts...) +} + +// Service gets a Service send client by service and method name +// It must be called with a context returned from [Connect] +func ServiceSend(ctx context.Context, service string, method string, opts ...options.IngressClientOption) IngressSendClient[any] { + return Service[any](ctx, service, method, opts...) +} + +// Object gets an Object request client by service name, key and method name +// It must be called with a context returned from [Connect] +func Object[O any](ctx context.Context, service string, key string, method string, opts ...options.IngressClientOption) IngressClient[any, O] { + conn := fromContextOrPanic(ctx) + url := conn.url.JoinPath(service, key, method) + + return getClient[O](ctx, conn, url, opts...) +} + +// ObjectSend gets an Object send client by service name, key and method name +// It must be called with a context returned from [Connect] +func ObjectSend(ctx context.Context, service string, key string, method string, opts ...options.IngressClientOption) IngressSendClient[any] { + return Object[any](ctx, service, key, method, opts...) +} + +// ResolveAwakeable allows an awakeable to be resolved with a particular value. +// It must be called with a context returned from [Connect] +func ResolveAwakeable[T any](ctx context.Context, id string, value T, opts ...options.ResolveAwakeableOption) error { + conn := fromContextOrPanic(ctx) + + o := options.ResolveAwakeableOptions{} + for _, opt := range opts { + opt.BeforeResolveAwakeable(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + + headers := make(http.Header, len(conn.Headers)+1) + for k, v := range conn.Headers { + headers.Add(k, v) + } + + data, err := encoding.Marshal(o.Codec, value) + if err != nil { + return err + } + + url := conn.url.JoinPath("restate", "a", id, "resolve") + + request, err := http.NewRequestWithContext(ctx, "POST", url.String(), bytes.NewReader(data)) + if err != nil { + return err + } + request.Header = headers + + resp, err := conn.Client.Do(request) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode > 299 { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("Resolve awakeable request failed: status %d\n%s", resp.StatusCode, string(body)) + } + return nil +} + +// RejectAwakeable allows an awakeable to be rejected with a particular error. +// It must be called with a context returned from [Connect] +func RejectAwakeable(ctx context.Context, id string, reason error) error { + conn := fromContextOrPanic(ctx) + + headers := make(http.Header, len(conn.Headers)+1) + for k, v := range conn.Headers { + headers.Add(k, v) + } + + data := []byte(reason.Error()) + + url := conn.url.JoinPath("restate", "a", id, "reject") + + request, err := http.NewRequestWithContext(ctx, "POST", url.String(), bytes.NewReader(data)) + if err != nil { + return err + } + request.Header = headers + + resp, err := conn.Client.Do(request) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode > 299 { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("Reject awakeable request failed: status %d\n%s", resp.StatusCode, string(body)) + } + return nil +} diff --git a/client/workflow.go b/client/workflow.go new file mode 100644 index 0000000..b1f1e31 --- /dev/null +++ b/client/workflow.go @@ -0,0 +1,83 @@ +package client + +import ( + "context" + "net/url" + + "github.com/restatedev/sdk-go/internal/options" +) + +// Workflow gets a Workflow request client by service name, workflow ID and method name +// It must be called with a context returned from [Connect] +func Workflow[O any](ctx context.Context, service string, workflowID string, method string, opts ...options.IngressClientOption) IngressClient[any, O] { + return Object[O](ctx, service, workflowID, method, opts...) +} + +type WorkflowSubmission struct { + InvocationId string +} + +func (w WorkflowSubmission) attachable() bool { + return true +} + +func (w WorkflowSubmission) attachUrl(connURL *url.URL) *url.URL { + return connURL.JoinPath("restate", "invocation", w.InvocationId, "attach") +} + +func (w WorkflowSubmission) outputUrl(connURL *url.URL) *url.URL { + return connURL.JoinPath("restate", "invocation", w.InvocationId, "output") +} + +var _ Attacher = WorkflowSubmission{} + +// WorkflowSubmit submits a workflow, defaulting to 'Run' as the main handler name, but this is configurable with [restate.WithWorkflowRun] +// It must be called with a context returned from [Connect] +func WorkflowSubmit[I any](ctx context.Context, service string, workflowID string, input I, opts ...options.WorkflowSubmitOption) (WorkflowSubmission, error) { + os := options.WorkflowSubmitOptions{} + for _, opt := range opts { + opt.BeforeWorkflowSubmit(&os) + } + if os.RunHandler == "" { + os.RunHandler = "Run" + } + + send, err := Workflow[I](ctx, service, workflowID, os.RunHandler, os).Send(input, os) + if err != nil { + return WorkflowSubmission{}, err + } + return WorkflowSubmission{InvocationId: send.InvocationId}, nil +} + +type WorkflowIdentifier struct { + Service string + WorkflowID string +} + +var _ Attacher = WorkflowIdentifier{} + +func (w WorkflowIdentifier) attachable() bool { + return true +} + +func (w WorkflowIdentifier) attachUrl(connURL *url.URL) *url.URL { + return connURL.JoinPath("restate", "workflow", w.Service, w.WorkflowID, "attach") +} + +func (w WorkflowIdentifier) outputUrl(connURL *url.URL) *url.URL { + return connURL.JoinPath("restate", "workflow", w.Service, w.WorkflowID, "output") +} + +// WorkflowAttach attaches to a workflow, waiting for it to complete and returning the result. +// It is only possible to 'attach' to a workflow that has been previously submitted. +// This operation is safe to retry many times, and it will always return the same result. +// It must be called with a context returned from [Connect] +func WorkflowAttach[O any](ctx context.Context, service string, workflowID string, opts ...options.IngressClientOption) (O, error) { + return Attach[O](ctx, WorkflowIdentifier{service, workflowID}, opts...) +} + +// WorkflowOutput tries to retrieve the output of a workflow if it has already completed. Otherwise, [ready] will be false. +// It must be called with a context returned from [Connect] +func WorkflowOutput[O any](ctx context.Context, service string, workflowID string, opts ...options.IngressClientOption) (o O, ready bool, err error) { + return GetOutput[O](ctx, WorkflowIdentifier{service, workflowID}, opts...) +} diff --git a/examples/ticketreservation/client/client.go b/examples/ticketreservation/client/client.go new file mode 100644 index 0000000..7226af0 --- /dev/null +++ b/examples/ticketreservation/client/client.go @@ -0,0 +1,65 @@ +package main + +import ( + "context" + "fmt" + + restate "github.com/restatedev/sdk-go" + "github.com/restatedev/sdk-go/client" +) + +func main() { + ctx, err := client.Connect(context.Background(), "http://127.0.0.1:8080") + if err != nil { + panic(err) + } + + if ok, err := AddTicketSend(ctx, "user-1", "ticket-1"); err != nil { + panic(err) + } else if !ok { + fmt.Println("Ticket-1 was not available") + } else { + fmt.Println("Added ticket-1 to user-1 basket") + } + + if ok, err := Checkout(ctx, "user-1", "ticket-1"); err != nil { + panic(err) + } else if !ok { + fmt.Println("Nothing to check out") + } else { + fmt.Println("Checked out") + } +} + +func AddTicket(ctx context.Context, userId, ticketId string) (bool, error) { + return client. + Object[bool](ctx, "UserSession", userId, "AddTicket"). + Request(ticketId) +} + +func AddTicketSend(ctx context.Context, userId, ticketId string) (bool, error) { + send, err := client. + Object[bool](ctx, "UserSession", userId, "AddTicket"). + Send(ticketId, restate.WithIdempotencyKey(fmt.Sprintf("%s/%s", userId, ticketId))) + if err != nil { + return false, err + } + + fmt.Println("Submitted AddTicket with ID", send.InvocationId) + + o, ready, err := client.GetOutput[bool](ctx, send) + if err != nil { + return false, err + } + if ready { + return o, nil + } + + return client.Attach[bool](ctx, send) +} + +func Checkout(ctx context.Context, userId, ticketId string) (bool, error) { + return client. + Object[bool](ctx, "UserSession", userId, "Checkout"). + Request(restate.Void{}) +} diff --git a/facilitators.go b/facilitators.go index 11fa196..ed2ee4a 100644 --- a/facilitators.go +++ b/facilitators.go @@ -176,7 +176,7 @@ func ResolveAwakeable[T any](ctx Context, id string, value T, options ...options ctx.inner().ResolveAwakeable(id, value, options...) } -// ResolveAwakeable allows an awakeable (not necessarily from this service) to be +// RejectAwakeable allows an awakeable (not necessarily from this service) to be // rejected with a particular error. func RejectAwakeable(ctx Context, id string, reason error) { ctx.inner().RejectAwakeable(id, reason) diff --git a/internal/options/options.go b/internal/options/options.go index 278a1c7..5d8e287 100644 --- a/internal/options/options.go +++ b/internal/options/options.go @@ -1,6 +1,7 @@ package options import ( + "net/http" "time" "github.com/restatedev/sdk-go/encoding" @@ -54,8 +55,18 @@ type ClientOption interface { BeforeClient(*ClientOptions) } +type IngressClientOptions struct { + Codec encoding.PayloadCodec +} + +type IngressClientOption interface { + BeforeIngressClient(*IngressClientOptions) +} + type RequestOptions struct { Headers map[string]string + // IdempotencyKey is currently only supported in ingress clients + IdempotencyKey string } type RequestOption interface { @@ -65,12 +76,45 @@ type RequestOption interface { type SendOptions struct { Headers map[string]string Delay time.Duration + // IdempotencyKey is currently only supported in ingress clients + IdempotencyKey string } type SendOption interface { BeforeSend(*SendOptions) } +type WorkflowSubmitOptions struct { + IngressClientOptions + SendOptions + RunHandler string +} + +var _ SendOption = WorkflowSubmitOptions{} +var _ IngressClientOption = WorkflowSubmitOptions{} + +func (w WorkflowSubmitOptions) BeforeSend(opts *SendOptions) { + if w.SendOptions.Headers != nil { + opts.Headers = w.SendOptions.Headers + } + if w.SendOptions.Delay != 0 { + opts.Delay = w.SendOptions.Delay + } + if w.SendOptions.IdempotencyKey != "" { + opts.IdempotencyKey = w.SendOptions.IdempotencyKey + } +} + +func (w WorkflowSubmitOptions) BeforeIngressClient(opts *IngressClientOptions) { + if w.IngressClientOptions.Codec != nil { + opts.Codec = w.IngressClientOptions.Codec + } +} + +type WorkflowSubmitOption interface { + BeforeWorkflowSubmit(*WorkflowSubmitOptions) +} + type RunOptions struct { Codec encoding.Codec } @@ -94,3 +138,12 @@ type ServiceDefinitionOptions struct { type ServiceDefinitionOption interface { BeforeServiceDefinition(*ServiceDefinitionOptions) } + +type ConnectOptions struct { + Headers map[string]string + Client *http.Client +} + +type ConnectOption interface { + BeforeConnect(*ConnectOptions) +} diff --git a/options.go b/options.go index 88621ff..e48b8cb 100644 --- a/options.go +++ b/options.go @@ -46,6 +46,7 @@ type withPayloadCodec struct { var _ options.HandlerOption = withPayloadCodec{} var _ options.ServiceDefinitionOption = withPayloadCodec{} +var _ options.IngressClientOption = withPayloadCodec{} func (w withPayloadCodec) BeforeHandler(opts *options.HandlerOptions) { opts.Codec = w.codec @@ -53,6 +54,9 @@ func (w withPayloadCodec) BeforeHandler(opts *options.HandlerOptions) { func (w withPayloadCodec) BeforeServiceDefinition(opts *options.ServiceDefinitionOptions) { opts.DefaultCodec = w.codec } +func (w withPayloadCodec) BeforeIngressClient(opts *options.IngressClientOptions) { + opts.Codec = w.codec +} // WithPayloadCodec is an option that can be provided to handler/service options // in order to specify a custom [encoding.PayloadCodec] with which to (de)serialise and @@ -105,7 +109,44 @@ func (w withDelay) BeforeSend(opts *options.SendOptions) { opts.Delay = w.delay } -// WithDelay is an [SendOption] to specify the duration to delay the request +// WithDelay is a [SendOption] to specify the duration to delay the request func WithDelay(delay time.Duration) withDelay { return withDelay{delay} } + +type withIdempotencyKey struct { + idempotencyKey string +} + +var _ options.RequestOption = withIdempotencyKey{} +var _ options.SendOption = withIdempotencyKey{} + +func (w withIdempotencyKey) BeforeRequest(opts *options.RequestOptions) { + opts.IdempotencyKey = w.idempotencyKey +} + +func (w withIdempotencyKey) BeforeSend(opts *options.SendOptions) { + opts.IdempotencyKey = w.idempotencyKey +} + +// WithIdempotencyKey is a [SendOption] to specify an idempotency key for the request +// Currently this key is only used by the ingress client +func WithIdempotencyKey(idempotencyKey string) withIdempotencyKey { + return withIdempotencyKey{idempotencyKey} +} + +type withWorkflowRun struct { + runHandler string +} + +var _ options.WorkflowSubmitOption = withWorkflowRun{} + +func (w withWorkflowRun) BeforeWorkflowSubmit(opts *options.WorkflowSubmitOptions) { + opts.RunHandler = w.runHandler +} + +// WithWorkflowRun is a [WorkflowSubmitOption] to specify an idempotency key for the request +// Currently this key is only used by the ingress client +func WithWorkflowRun(idempotencyKey string) withIdempotencyKey { + return withIdempotencyKey{idempotencyKey} +} From 1615efa79c895135da7fa01ca581dc144032a12e Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Thu, 7 Nov 2024 13:02:25 +0000 Subject: [PATCH 2/4] Fix WithWorkflowRun --- options.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/options.go b/options.go index e48b8cb..b8bbf70 100644 --- a/options.go +++ b/options.go @@ -129,7 +129,7 @@ func (w withIdempotencyKey) BeforeSend(opts *options.SendOptions) { opts.IdempotencyKey = w.idempotencyKey } -// WithIdempotencyKey is a [SendOption] to specify an idempotency key for the request +// WithIdempotencyKey is an option to specify an idempotency key for the request // Currently this key is only used by the ingress client func WithIdempotencyKey(idempotencyKey string) withIdempotencyKey { return withIdempotencyKey{idempotencyKey} @@ -145,8 +145,8 @@ func (w withWorkflowRun) BeforeWorkflowSubmit(opts *options.WorkflowSubmitOption opts.RunHandler = w.runHandler } -// WithWorkflowRun is a [WorkflowSubmitOption] to specify an idempotency key for the request -// Currently this key is only used by the ingress client -func WithWorkflowRun(idempotencyKey string) withIdempotencyKey { - return withIdempotencyKey{idempotencyKey} +// WithWorkflowRun is a [WorkflowSubmitOption] to specify a different handler name than 'Run' for the +// workflows main handler. +func WithWorkflowRun(runHandler string) withWorkflowRun { + return withWorkflowRun{runHandler} } From 07b19db35312a48a653156ad4eb5a57b95954d85 Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Thu, 7 Nov 2024 15:51:04 +0000 Subject: [PATCH 3/4] proto clients --- client/client.go | 26 ++++ .../codegen/proto/helloworld_restate.pb.go | 87 ++++++++++++ protoc-gen-go-restate/restate.go | 129 +++++++++++++++--- 3 files changed, 226 insertions(+), 16 deletions(-) diff --git a/client/client.go b/client/client.go index dc2194d..3e1b8f2 100644 --- a/client/client.go +++ b/client/client.go @@ -13,6 +13,9 @@ import ( "github.com/restatedev/sdk-go/internal/options" ) +// re-export for use in generated code +type IngressClientOption = options.IngressClientOption + type ingressContextKey struct{} func Connect(ctx context.Context, ingressURL string, opts ...options.ConnectOption) (context.Context, error) { @@ -506,3 +509,26 @@ func RejectAwakeable(ctx context.Context, id string, reason error) error { } return nil } + +type withRequestType[I any, O any] struct { + inner IngressClient[any, O] +} + +func (w withRequestType[I, O]) Request(input I, options ...options.RequestOption) (O, error) { + return w.inner.Request(input, options...) +} + +func (w withRequestType[I, O]) RequestFuture(input I, options ...options.RequestOption) IngressResponseFuture[O] { + return w.inner.RequestFuture(input, options...) +} + +func (w withRequestType[I, O]) Send(input I, options ...options.SendOption) (Send, error) { + return w.inner.Send(input, options...) +} + +// WithRequestType is primarily intended to be called from generated code, to provide +// type safety of input types. In other contexts it's generally less cumbersome to use [Object] and [Service], +// as the output type can be inferred. +func WithRequestType[I any, O any](inner IngressClient[any, O]) IngressClient[I, O] { + return withRequestType[I, O]{inner} +} diff --git a/examples/codegen/proto/helloworld_restate.pb.go b/examples/codegen/proto/helloworld_restate.pb.go index 19d78f0..09a95b5 100644 --- a/examples/codegen/proto/helloworld_restate.pb.go +++ b/examples/codegen/proto/helloworld_restate.pb.go @@ -7,8 +7,10 @@ package proto import ( + context "context" fmt "fmt" sdk_go "github.com/restatedev/sdk-go" + client "github.com/restatedev/sdk-go/client" ) // GreeterClient is the client API for Greeter service. @@ -36,6 +38,32 @@ func (c *greeterClient) SayHello(opts ...sdk_go.ClientOption) sdk_go.Client[*Hel return sdk_go.WithRequestType[*HelloRequest](sdk_go.Service[*HelloResponse](c.ctx, "Greeter", "SayHello", cOpts...)) } +// GreeterIngressClient is the ingress client API for Greeter service. +type GreeterIngressClient interface { + SayHello(opts ...client.IngressClientOption) client.IngressClient[*HelloRequest, *HelloResponse] +} + +type greeterIngressClient struct { + ctx context.Context + options []client.IngressClientOption +} + +// NewGreeterIngressClient must be called with a ctx returned from github.com/restatedev/sdk-go/client.Connect +func NewGreeterIngressClient(ctx context.Context, opts ...client.IngressClientOption) GreeterIngressClient { + cOpts := append([]client.IngressClientOption{sdk_go.WithProtoJSON}, opts...) + return &greeterIngressClient{ + ctx, + cOpts, + } +} +func (c *greeterIngressClient) SayHello(opts ...client.IngressClientOption) client.IngressClient[*HelloRequest, *HelloResponse] { + cOpts := c.options + if len(opts) > 0 { + cOpts = append(append([]client.IngressClientOption{}, cOpts...), opts...) + } + return client.WithRequestType[*HelloRequest](client.Service[*HelloResponse](c.ctx, "Greeter", "SayHello", cOpts...)) +} + // GreeterServer is the server API for Greeter service. // All implementations should embed UnimplementedGreeterServer // for forward compatibility. @@ -134,6 +162,65 @@ func (c *counterClient) Watch(opts ...sdk_go.ClientOption) sdk_go.Client[*WatchR return sdk_go.WithRequestType[*WatchRequest](sdk_go.Object[*GetResponse](c.ctx, "Counter", c.key, "Watch", cOpts...)) } +// CounterIngressClient is the ingress client API for Counter service. +type CounterIngressClient interface { + // Mutate the value + Add(opts ...client.IngressClientOption) client.IngressClient[*AddRequest, *GetResponse] + // Get the current value + Get(opts ...client.IngressClientOption) client.IngressClient[*GetRequest, *GetResponse] + // Internal method to store an awakeable ID for the Watch method + AddWatcher(opts ...client.IngressClientOption) client.IngressClient[*AddWatcherRequest, *AddWatcherResponse] + // Wait for the counter to change and then return the new value + Watch(opts ...client.IngressClientOption) client.IngressClient[*WatchRequest, *GetResponse] +} + +type counterIngressClient struct { + ctx context.Context + key string + options []client.IngressClientOption +} + +// NewCounterIngressClient must be called with a ctx returned from github.com/restatedev/sdk-go/client.Connect +func NewCounterIngressClient(ctx context.Context, key string, opts ...client.IngressClientOption) CounterIngressClient { + cOpts := append([]client.IngressClientOption{sdk_go.WithProtoJSON}, opts...) + return &counterIngressClient{ + ctx, + key, + cOpts, + } +} +func (c *counterIngressClient) Add(opts ...client.IngressClientOption) client.IngressClient[*AddRequest, *GetResponse] { + cOpts := c.options + if len(opts) > 0 { + cOpts = append(append([]client.IngressClientOption{}, cOpts...), opts...) + } + return client.WithRequestType[*AddRequest](client.Object[*GetResponse](c.ctx, "Counter", c.key, "Add", cOpts...)) +} + +func (c *counterIngressClient) Get(opts ...client.IngressClientOption) client.IngressClient[*GetRequest, *GetResponse] { + cOpts := c.options + if len(opts) > 0 { + cOpts = append(append([]client.IngressClientOption{}, cOpts...), opts...) + } + return client.WithRequestType[*GetRequest](client.Object[*GetResponse](c.ctx, "Counter", c.key, "Get", cOpts...)) +} + +func (c *counterIngressClient) AddWatcher(opts ...client.IngressClientOption) client.IngressClient[*AddWatcherRequest, *AddWatcherResponse] { + cOpts := c.options + if len(opts) > 0 { + cOpts = append(append([]client.IngressClientOption{}, cOpts...), opts...) + } + return client.WithRequestType[*AddWatcherRequest](client.Object[*AddWatcherResponse](c.ctx, "Counter", c.key, "AddWatcher", cOpts...)) +} + +func (c *counterIngressClient) Watch(opts ...client.IngressClientOption) client.IngressClient[*WatchRequest, *GetResponse] { + cOpts := c.options + if len(opts) > 0 { + cOpts = append(append([]client.IngressClientOption{}, cOpts...), opts...) + } + return client.WithRequestType[*WatchRequest](client.Object[*GetResponse](c.ctx, "Counter", c.key, "Watch", cOpts...)) +} + // CounterServer is the server API for Counter service. // All implementations should embed UnimplementedCounterServer // for forward compatibility. diff --git a/protoc-gen-go-restate/restate.go b/protoc-gen-go-restate/restate.go index b78db68..fb4aaa7 100644 --- a/protoc-gen-go-restate/restate.go +++ b/protoc-gen-go-restate/restate.go @@ -12,8 +12,10 @@ import ( ) const ( - fmtPackage = protogen.GoImportPath("fmt") - sdkPackage = protogen.GoImportPath("github.com/restatedev/sdk-go") + fmtPackage = protogen.GoImportPath("fmt") + contextPackage = protogen.GoImportPath("context") + sdkPackage = protogen.GoImportPath("github.com/restatedev/sdk-go") + clientPackage = protogen.GoImportPath("github.com/restatedev/sdk-go/client") ) type serviceGenerateHelper struct{} @@ -32,6 +34,17 @@ func generateClientStruct(g *protogen.GeneratedFile, service *protogen.Service, g.P("}") } +func generateIngressClientStruct(g *protogen.GeneratedFile, service *protogen.Service, clientName string) { + g.P("type ", unexport(clientName), " struct {") + g.P("ctx ", contextPackage.Ident("Context")) + serviceType := proto.GetExtension(service.Desc.Options().(*descriptorpb.ServiceOptions), sdk.E_ServiceType).(sdk.ServiceType) + if serviceType == sdk.ServiceType_VIRTUAL_OBJECT { + g.P("key string") + } + g.P("options []", clientPackage.Ident("IngressClientOption")) + g.P("}") +} + func generateNewClientDefinitions(g *protogen.GeneratedFile, service *protogen.Service, clientName string) { g.P("cOpts := append([]", sdkPackage.Ident("ClientOption"), "{", sdkPackage.Ident("WithProtoJSON"), "}, opts...)") g.P("return &", unexport(clientName), "{") @@ -47,6 +60,18 @@ func generateNewClientDefinitions(g *protogen.GeneratedFile, service *protogen.S g.P("}") } +func generateNewIngressClientDefinitions(g *protogen.GeneratedFile, service *protogen.Service, clientName string) { + g.P("cOpts := append([]", clientPackage.Ident("IngressClientOption"), "{", sdkPackage.Ident("WithProtoJSON"), "}, opts...)") + g.P("return &", unexport(clientName), "{") + g.P("ctx,") + serviceType := proto.GetExtension(service.Desc.Options().(*descriptorpb.ServiceOptions), sdk.E_ServiceType).(sdk.ServiceType) + if serviceType == sdk.ServiceType_VIRTUAL_OBJECT { + g.P("key,") + } + g.P("cOpts,") + g.P("}") +} + func generateUnimplementedServerType(gen *protogen.Plugin, g *protogen.GeneratedFile, service *protogen.Service) { serverType := service.GoName + "Server" mustOrShould := "must" @@ -162,7 +187,7 @@ func genService(gen *protogen.Plugin, g *protogen.GeneratedFile, service *protog g.P(deprecationComment) } g.P(method.Comments.Leading, - clientSignature(g, method)) + clientSignature(g, method, false)) } g.P("}") g.P() @@ -188,17 +213,66 @@ func genService(gen *protogen.Plugin, g *protogen.GeneratedFile, service *protog generateNewClientDefinitions(g, service, clientName) g.P("}") - var methodIndex int // Client method implementations. for _, method := range service.Methods { if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { - genClientMethod(gen, g, method) - methodIndex++ + genClientMethod(gen, g, method, false) } else { gen.Error(fmt.Errorf("streaming methods are not currently supported in Restate.")) } } + // Ingress client interface. + ingressClientName := service.GoName + "IngressClient" + + g.P("// ", ingressClientName, " is the ingress client API for ", service.GoName, " service.") + g.P("//") + + // Copy comments from proto file. + genServiceComments(g, service) + + if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { + g.P("//") + g.P(deprecationComment) + } + g.AnnotateSymbol(ingressClientName, protogen.Annotation{Location: service.Location}) + g.P("type ", ingressClientName, " interface {") + for _, method := range service.Methods { + g.AnnotateSymbol(ingressClientName+"."+method.GoName, protogen.Annotation{Location: method.Location}) + if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { + g.P(deprecationComment) + } + g.P(method.Comments.Leading, + clientSignature(g, method, true)) + } + g.P("}") + g.P() + + // Ingress client structure. + generateIngressClientStruct(g, service, ingressClientName) + + // NewIngressClient factory. + if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { + g.P(deprecationComment) + } + g.P("// New", ingressClientName, " must be called with a ctx returned from github.com/restatedev/sdk-go/client.Connect") + newIngressClientSignature := "New" + ingressClientName + " (ctx " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + if serviceType == sdk.ServiceType_VIRTUAL_OBJECT { + newIngressClientSignature += ", key string" + } + newIngressClientSignature += ", opts..." + g.QualifiedGoIdent(clientPackage.Ident("IngressClientOption")) + ") " + ingressClientName + + g.P("func ", newIngressClientSignature, " {") + generateNewIngressClientDefinitions(g, service, ingressClientName) + g.P("}") + + // Ingress method implementations. + for _, method := range service.Methods { + if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { + genClientMethod(gen, g, method, true) + } + } + mustOrShould := "must" if !*requireUnimplemented { mustOrShould = "should" @@ -268,41 +342,64 @@ func genService(gen *protogen.Plugin, g *protogen.GeneratedFile, service *protog g.P() } -func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string { +func clientSignature(g *protogen.GeneratedFile, method *protogen.Method, ingress bool) string { + var optionName protogen.GoIdent + var clientName protogen.GoIdent + if ingress { + optionName = clientPackage.Ident("IngressClientOption") + clientName = clientPackage.Ident("IngressClient") + } else { + optionName = sdkPackage.Ident("ClientOption") + clientName = sdkPackage.Ident("Client") + } + s := method.GoName + "(" - s += "opts ..." + g.QualifiedGoIdent(sdkPackage.Ident("ClientOption")) + ") (" - s += g.QualifiedGoIdent(sdkPackage.Ident("Client")) + "[" + "*" + g.QualifiedGoIdent(method.Input.GoIdent) + ", *" + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + s += "opts ..." + g.QualifiedGoIdent(optionName) + ") (" + s += g.QualifiedGoIdent(clientName) + "[" + "*" + g.QualifiedGoIdent(method.Input.GoIdent) + ", *" + g.QualifiedGoIdent(method.Output.GoIdent) + "]" s += ")" return s } -func genClientMethod(gen *protogen.Plugin, g *protogen.GeneratedFile, method *protogen.Method) { +func genClientMethod(gen *protogen.Plugin, g *protogen.GeneratedFile, method *protogen.Method, ingress bool) { + var pack protogen.GoImportPath + var clientSuffix string + var optionName protogen.GoIdent + if ingress { + pack = clientPackage + clientSuffix = "IngressClient" + optionName = clientPackage.Ident("IngressClientOption") + } else { + pack = sdkPackage + clientSuffix = "Client" + optionName = sdkPackage.Ident("ClientOption") + } + service := method.Parent serviceType := proto.GetExtension(service.Desc.Options().(*descriptorpb.ServiceOptions), sdk.E_ServiceType).(sdk.ServiceType) if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { g.P(deprecationComment) } - g.P("func (c *", unexport(service.GoName), "Client) ", clientSignature(g, method), "{") + g.P("func (c *", unexport(service.GoName), clientSuffix, ") ", clientSignature(g, method, ingress), "{") g.P("cOpts := c.options") g.P("if len(opts) > 0 {") - g.P("cOpts = append(append([]sdk_go.ClientOption{}, cOpts...), opts...)") + g.P("cOpts = append(append([]", optionName, "{}, cOpts...), opts...)") g.P("}") var getClient string switch serviceType { case sdk.ServiceType_SERVICE: - getClient = g.QualifiedGoIdent(sdkPackage.Ident("Service")) + `[*` + g.QualifiedGoIdent(method.Output.GoIdent) + `]` + `(c.ctx, "` + service.GoName + `",` + getClient = g.QualifiedGoIdent(pack.Ident("Service")) + `[*` + g.QualifiedGoIdent(method.Output.GoIdent) + `]` + `(c.ctx, "` + service.GoName + `",` case sdk.ServiceType_VIRTUAL_OBJECT: - getClient = g.QualifiedGoIdent(sdkPackage.Ident("Object")) + `[*` + g.QualifiedGoIdent(method.Output.GoIdent) + `]` + `(c.ctx, "` + service.GoName + `", c.key,` + getClient = g.QualifiedGoIdent(pack.Ident("Object")) + `[*` + g.QualifiedGoIdent(method.Output.GoIdent) + `]` + `(c.ctx, "` + service.GoName + `", c.key,` case sdk.ServiceType_WORKFLOW: - getClient = g.QualifiedGoIdent(sdkPackage.Ident("Workflow")) + `[*` + g.QualifiedGoIdent(method.Output.GoIdent) + `]` + `(c.ctx, "` + service.GoName + `", c.workflowID,` + getClient = g.QualifiedGoIdent(pack.Ident("Workflow")) + `[*` + g.QualifiedGoIdent(method.Output.GoIdent) + `]` + `(c.ctx, "` + service.GoName + `", c.workflowID,` default: gen.Error(fmt.Errorf("Unexpected service type: %s", serviceType.String())) return } getClient += `"` + method.GoName + `", cOpts...)` - g.P("return ", sdkPackage.Ident("WithRequestType"), "[*", method.Input.GoIdent, "]", `(`, getClient, `)`) + g.P("return ", pack.Ident("WithRequestType"), "[*", method.Input.GoIdent, "]", `(`, getClient, `)`) g.P("}") g.P() return From 0e33acb41700f087d2520a1dc61973539fde1f30 Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Thu, 7 Nov 2024 17:34:36 +0000 Subject: [PATCH 4/4] Write a client in the codegen example --- examples/codegen/client/client.go | 38 ++++++++++++++ .../codegen/proto/helloworld_restate.pb.go | 49 +++++++++++++++++++ protoc-gen-go-restate/restate.go | 15 ++++-- 3 files changed, 99 insertions(+), 3 deletions(-) create mode 100644 examples/codegen/client/client.go diff --git a/examples/codegen/client/client.go b/examples/codegen/client/client.go new file mode 100644 index 0000000..0a95644 --- /dev/null +++ b/examples/codegen/client/client.go @@ -0,0 +1,38 @@ +package main + +import ( + "context" + "fmt" + + "github.com/restatedev/sdk-go/client" + helloworld "github.com/restatedev/sdk-go/examples/codegen/proto" +) + +func main() { + ctx, err := client.Connect(context.Background(), "http://127.0.0.1:8080") + if err != nil { + panic(err) + } + + greeter := helloworld.NewGreeterIngressClient(ctx) + greeting, err := greeter.SayHello().Request(&helloworld.HelloRequest{Name: "world"}) + if err != nil { + panic(err) + } + fmt.Println(greeting.Message) + + workflow := helloworld.NewWorkflowIngressClient(ctx, "my-workflow") + send, err := workflow.Run().Send(&helloworld.RunRequest{}) + if err != nil { + panic(err) + } + status, err := workflow.Status().Request(&helloworld.StatusRequest{}) + if err != nil { + panic(err) + } + fmt.Println("workflow running with invocation id", send.InvocationId, "and status", status.Status) + + if _, err := workflow.Finish().Request(&helloworld.FinishRequest{}); err != nil { + panic(err) + } +} diff --git a/examples/codegen/proto/helloworld_restate.pb.go b/examples/codegen/proto/helloworld_restate.pb.go index 09a95b5..4c632b3 100644 --- a/examples/codegen/proto/helloworld_restate.pb.go +++ b/examples/codegen/proto/helloworld_restate.pb.go @@ -328,6 +328,55 @@ func (c *workflowClient) Status(opts ...sdk_go.ClientOption) sdk_go.Client[*Stat return sdk_go.WithRequestType[*StatusRequest](sdk_go.Workflow[*StatusResponse](c.ctx, "Workflow", c.workflowID, "Status", cOpts...)) } +// WorkflowIngressClient is the ingress client API for Workflow service. +type WorkflowIngressClient interface { + // Execute the workflow + Run(opts ...client.IngressClientOption) client.IngressClient[*RunRequest, *RunResponse] + // Unblock the workflow + Finish(opts ...client.IngressClientOption) client.IngressClient[*FinishRequest, *FinishResponse] + // Check the current status + Status(opts ...client.IngressClientOption) client.IngressClient[*StatusRequest, *StatusResponse] +} + +type workflowIngressClient struct { + ctx context.Context + workflowID string + options []client.IngressClientOption +} + +// NewWorkflowIngressClient must be called with a ctx returned from github.com/restatedev/sdk-go/client.Connect +func NewWorkflowIngressClient(ctx context.Context, workflowID string, opts ...client.IngressClientOption) WorkflowIngressClient { + cOpts := append([]client.IngressClientOption{sdk_go.WithProtoJSON}, opts...) + return &workflowIngressClient{ + ctx, + workflowID, + cOpts, + } +} +func (c *workflowIngressClient) Run(opts ...client.IngressClientOption) client.IngressClient[*RunRequest, *RunResponse] { + cOpts := c.options + if len(opts) > 0 { + cOpts = append(append([]client.IngressClientOption{}, cOpts...), opts...) + } + return client.WithRequestType[*RunRequest](client.Workflow[*RunResponse](c.ctx, "Workflow", c.workflowID, "Run", cOpts...)) +} + +func (c *workflowIngressClient) Finish(opts ...client.IngressClientOption) client.IngressClient[*FinishRequest, *FinishResponse] { + cOpts := c.options + if len(opts) > 0 { + cOpts = append(append([]client.IngressClientOption{}, cOpts...), opts...) + } + return client.WithRequestType[*FinishRequest](client.Workflow[*FinishResponse](c.ctx, "Workflow", c.workflowID, "Finish", cOpts...)) +} + +func (c *workflowIngressClient) Status(opts ...client.IngressClientOption) client.IngressClient[*StatusRequest, *StatusResponse] { + cOpts := c.options + if len(opts) > 0 { + cOpts = append(append([]client.IngressClientOption{}, cOpts...), opts...) + } + return client.WithRequestType[*StatusRequest](client.Workflow[*StatusResponse](c.ctx, "Workflow", c.workflowID, "Status", cOpts...)) +} + // WorkflowServer is the server API for Workflow service. // All implementations should embed UnimplementedWorkflowServer // for forward compatibility. diff --git a/protoc-gen-go-restate/restate.go b/protoc-gen-go-restate/restate.go index fb4aaa7..0a4c17a 100644 --- a/protoc-gen-go-restate/restate.go +++ b/protoc-gen-go-restate/restate.go @@ -38,8 +38,11 @@ func generateIngressClientStruct(g *protogen.GeneratedFile, service *protogen.Se g.P("type ", unexport(clientName), " struct {") g.P("ctx ", contextPackage.Ident("Context")) serviceType := proto.GetExtension(service.Desc.Options().(*descriptorpb.ServiceOptions), sdk.E_ServiceType).(sdk.ServiceType) - if serviceType == sdk.ServiceType_VIRTUAL_OBJECT { + switch serviceType { + case sdk.ServiceType_VIRTUAL_OBJECT: g.P("key string") + case sdk.ServiceType_WORKFLOW: + g.P("workflowID string") } g.P("options []", clientPackage.Ident("IngressClientOption")) g.P("}") @@ -65,8 +68,11 @@ func generateNewIngressClientDefinitions(g *protogen.GeneratedFile, service *pro g.P("return &", unexport(clientName), "{") g.P("ctx,") serviceType := proto.GetExtension(service.Desc.Options().(*descriptorpb.ServiceOptions), sdk.E_ServiceType).(sdk.ServiceType) - if serviceType == sdk.ServiceType_VIRTUAL_OBJECT { + switch serviceType { + case sdk.ServiceType_VIRTUAL_OBJECT: g.P("key,") + case sdk.ServiceType_WORKFLOW: + g.P("workflowID,") } g.P("cOpts,") g.P("}") @@ -257,8 +263,11 @@ func genService(gen *protogen.Plugin, g *protogen.GeneratedFile, service *protog } g.P("// New", ingressClientName, " must be called with a ctx returned from github.com/restatedev/sdk-go/client.Connect") newIngressClientSignature := "New" + ingressClientName + " (ctx " + g.QualifiedGoIdent(contextPackage.Ident("Context")) - if serviceType == sdk.ServiceType_VIRTUAL_OBJECT { + switch serviceType { + case sdk.ServiceType_VIRTUAL_OBJECT: newIngressClientSignature += ", key string" + case sdk.ServiceType_WORKFLOW: + newIngressClientSignature += ", workflowID string" } newIngressClientSignature += ", opts..." + g.QualifiedGoIdent(clientPackage.Ident("IngressClientOption")) + ") " + ingressClientName