diff --git a/context.go b/context.go index ca638fe..6c934cd 100644 --- a/context.go +++ b/context.go @@ -27,7 +27,7 @@ type Context interface { inner() *state.Context } -// ObjectContext is an extension of [Context] which is passed to shared-mode Virtual Object handlers, +// ObjectSharedContext is an extension of [Context] which is passed to shared-mode Virtual Object handlers, // giving read-only access to a snapshot of state. type ObjectSharedContext interface { Context @@ -40,3 +40,18 @@ type ObjectContext interface { ObjectSharedContext exclusiveObject() } + +// WorkflowSharedContext is an extension of [ObjectSharedContext] which is passed to shared-mode Workflow handlers, +// giving read-only access to a snapshot of state. +type WorkflowSharedContext interface { + ObjectSharedContext + workflow() +} + +// WorkflowContext is an extension of [WorkflowSharedContext] and [ObjectContext] which is passed to Workflow 'run' handlers, +// giving mutable access to state. +type WorkflowContext interface { + WorkflowSharedContext + ObjectContext + runWorkflow() +} diff --git a/examples/codegen/main.go b/examples/codegen/main.go index 757a4eb..a558145 100644 --- a/examples/codegen/main.go +++ b/examples/codegen/main.go @@ -118,10 +118,37 @@ func (c counter) Watch(ctx restate.ObjectSharedContext, req *helloworld.WatchReq return &helloworld.GetResponse{Value: next}, nil } +type workflow struct { + helloworld.UnimplementedWorkflowServer +} + +func (workflow) Run(ctx restate.WorkflowContext, _ *helloworld.RunRequest) (*helloworld.RunResponse, error) { + restate.Set(ctx, "status", "waiting") + _, err := restate.Promise[restate.Void](ctx, "promise").Result() + if err != nil { + return nil, err + } + restate.Set(ctx, "status", "finished") + return &helloworld.RunResponse{Status: "finished"}, nil +} + +func (workflow) Finish(ctx restate.WorkflowSharedContext, _ *helloworld.FinishRequest) (*helloworld.FinishResponse, error) { + return nil, restate.Promise[restate.Void](ctx, "promise").Resolve(restate.Void{}) +} + +func (workflow) Status(ctx restate.WorkflowSharedContext, _ *helloworld.StatusRequest) (*helloworld.StatusResponse, error) { + status, err := restate.Get[string](ctx, "status") + if err != nil { + return nil, err + } + return &helloworld.StatusResponse{Status: status}, nil +} + func main() { server := server.NewRestate(). Bind(helloworld.NewGreeterServer(greeter{})). - Bind(helloworld.NewCounterServer(counter{})) + Bind(helloworld.NewCounterServer(counter{})). + Bind(helloworld.NewWorkflowServer(workflow{})) if err := server.Start(context.Background(), ":9080"); err != nil { slog.Error("application exited unexpectedly", "err", err.Error()) diff --git a/examples/codegen/proto/helloworld.pb.go b/examples/codegen/proto/helloworld.pb.go index c1e7b34..8139585 100644 --- a/examples/codegen/proto/helloworld.pb.go +++ b/examples/codegen/proto/helloworld.pb.go @@ -379,6 +379,252 @@ func (x *WatchRequest) GetTimeoutMillis() int64 { return 0 } +type RunRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *RunRequest) Reset() { + *x = RunRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_helloworld_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RunRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RunRequest) ProtoMessage() {} + +func (x *RunRequest) ProtoReflect() protoreflect.Message { + mi := &file_proto_helloworld_proto_msgTypes[8] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RunRequest.ProtoReflect.Descriptor instead. +func (*RunRequest) Descriptor() ([]byte, []int) { + return file_proto_helloworld_proto_rawDescGZIP(), []int{8} +} + +type RunResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Status string `protobuf:"bytes,1,opt,name=status,proto3" json:"status,omitempty"` +} + +func (x *RunResponse) Reset() { + *x = RunResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_helloworld_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RunResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RunResponse) ProtoMessage() {} + +func (x *RunResponse) ProtoReflect() protoreflect.Message { + mi := &file_proto_helloworld_proto_msgTypes[9] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RunResponse.ProtoReflect.Descriptor instead. +func (*RunResponse) Descriptor() ([]byte, []int) { + return file_proto_helloworld_proto_rawDescGZIP(), []int{9} +} + +func (x *RunResponse) GetStatus() string { + if x != nil { + return x.Status + } + return "" +} + +type StatusRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *StatusRequest) Reset() { + *x = StatusRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_helloworld_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *StatusRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StatusRequest) ProtoMessage() {} + +func (x *StatusRequest) ProtoReflect() protoreflect.Message { + mi := &file_proto_helloworld_proto_msgTypes[10] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StatusRequest.ProtoReflect.Descriptor instead. +func (*StatusRequest) Descriptor() ([]byte, []int) { + return file_proto_helloworld_proto_rawDescGZIP(), []int{10} +} + +type StatusResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Status string `protobuf:"bytes,1,opt,name=status,proto3" json:"status,omitempty"` +} + +func (x *StatusResponse) Reset() { + *x = StatusResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_helloworld_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *StatusResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StatusResponse) ProtoMessage() {} + +func (x *StatusResponse) ProtoReflect() protoreflect.Message { + mi := &file_proto_helloworld_proto_msgTypes[11] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StatusResponse.ProtoReflect.Descriptor instead. +func (*StatusResponse) Descriptor() ([]byte, []int) { + return file_proto_helloworld_proto_rawDescGZIP(), []int{11} +} + +func (x *StatusResponse) GetStatus() string { + if x != nil { + return x.Status + } + return "" +} + +type FinishRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *FinishRequest) Reset() { + *x = FinishRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_helloworld_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *FinishRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FinishRequest) ProtoMessage() {} + +func (x *FinishRequest) ProtoReflect() protoreflect.Message { + mi := &file_proto_helloworld_proto_msgTypes[12] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FinishRequest.ProtoReflect.Descriptor instead. +func (*FinishRequest) Descriptor() ([]byte, []int) { + return file_proto_helloworld_proto_rawDescGZIP(), []int{12} +} + +type FinishResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *FinishResponse) Reset() { + *x = FinishResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_helloworld_proto_msgTypes[13] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *FinishResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FinishResponse) ProtoMessage() {} + +func (x *FinishResponse) ProtoReflect() protoreflect.Message { + mi := &file_proto_helloworld_proto_msgTypes[13] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FinishResponse.ProtoReflect.Descriptor instead. +func (*FinishResponse) Descriptor() ([]byte, []int) { + return file_proto_helloworld_proto_rawDescGZIP(), []int{13} +} + var File_proto_helloworld_proto protoreflect.FileDescriptor var file_proto_helloworld_proto_rawDesc = []byte{ @@ -404,40 +650,63 @@ var file_proto_helloworld_proto_rawDesc = []byte{ 0x73, 0x65, 0x22, 0x35, 0x0a, 0x0c, 0x57, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x5f, 0x6d, 0x69, 0x6c, 0x6c, 0x69, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x74, 0x69, 0x6d, 0x65, - 0x6f, 0x75, 0x74, 0x4d, 0x69, 0x6c, 0x6c, 0x69, 0x73, 0x32, 0x4c, 0x0a, 0x07, 0x47, 0x72, 0x65, - 0x65, 0x74, 0x65, 0x72, 0x12, 0x41, 0x0a, 0x08, 0x53, 0x61, 0x79, 0x48, 0x65, 0x6c, 0x6c, 0x6f, - 0x12, 0x18, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48, 0x65, - 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x68, 0x65, 0x6c, - 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x32, 0x98, 0x02, 0x0a, 0x07, 0x43, 0x6f, 0x75, 0x6e, - 0x74, 0x65, 0x72, 0x12, 0x38, 0x0a, 0x03, 0x41, 0x64, 0x64, 0x12, 0x16, 0x2e, 0x68, 0x65, 0x6c, - 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x41, 0x64, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, - 0x47, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x3c, 0x0a, - 0x03, 0x47, 0x65, 0x74, 0x12, 0x16, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, - 0x64, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x68, - 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x04, 0x98, 0x80, 0x01, 0x02, 0x12, 0x4d, 0x0a, 0x0a, 0x41, - 0x64, 0x64, 0x57, 0x61, 0x74, 0x63, 0x68, 0x65, 0x72, 0x12, 0x1d, 0x2e, 0x68, 0x65, 0x6c, 0x6c, - 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x41, 0x64, 0x64, 0x57, 0x61, 0x74, 0x63, 0x68, 0x65, - 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, - 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x41, 0x64, 0x64, 0x57, 0x61, 0x74, 0x63, 0x68, 0x65, 0x72, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x40, 0x0a, 0x05, 0x57, 0x61, - 0x74, 0x63, 0x68, 0x12, 0x18, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, - 0x2e, 0x57, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, - 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x04, 0x98, 0x80, 0x01, 0x02, 0x1a, 0x04, 0x98, 0x80, - 0x01, 0x01, 0x42, 0x9e, 0x01, 0x0a, 0x0e, 0x63, 0x6f, 0x6d, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, - 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x42, 0x0f, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, - 0x64, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, - 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x73, 0x74, 0x61, 0x74, 0x65, 0x64, 0x65, 0x76, 0x2f, - 0x73, 0x64, 0x6b, 0x2d, 0x67, 0x6f, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, - 0x63, 0x6f, 0x64, 0x65, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0xa2, 0x02, 0x03, - 0x48, 0x58, 0x58, 0xaa, 0x02, 0x0a, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, - 0xca, 0x02, 0x0a, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0xe2, 0x02, 0x16, - 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, - 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x0a, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, - 0x72, 0x6c, 0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6f, 0x75, 0x74, 0x4d, 0x69, 0x6c, 0x6c, 0x69, 0x73, 0x22, 0x0c, 0x0a, 0x0a, 0x52, 0x75, 0x6e, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x25, 0x0a, 0x0b, 0x52, 0x75, 0x6e, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x0f, + 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, + 0x28, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x0f, 0x0a, 0x0d, 0x46, 0x69, 0x6e, + 0x69, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x10, 0x0a, 0x0e, 0x46, 0x69, + 0x6e, 0x69, 0x73, 0x68, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x4c, 0x0a, 0x07, + 0x47, 0x72, 0x65, 0x65, 0x74, 0x65, 0x72, 0x12, 0x41, 0x0a, 0x08, 0x53, 0x61, 0x79, 0x48, 0x65, + 0x6c, 0x6c, 0x6f, 0x12, 0x18, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, + 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, + 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x32, 0x98, 0x02, 0x0a, 0x07, 0x43, + 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x12, 0x38, 0x0a, 0x03, 0x41, 0x64, 0x64, 0x12, 0x16, 0x2e, + 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x41, 0x64, 0x64, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, + 0x6c, 0x64, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, + 0x12, 0x3c, 0x0a, 0x03, 0x47, 0x65, 0x74, 0x12, 0x16, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, + 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x17, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x47, 0x65, 0x74, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x04, 0x98, 0x80, 0x01, 0x02, 0x12, 0x4d, + 0x0a, 0x0a, 0x41, 0x64, 0x64, 0x57, 0x61, 0x74, 0x63, 0x68, 0x65, 0x72, 0x12, 0x1d, 0x2e, 0x68, + 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x41, 0x64, 0x64, 0x57, 0x61, 0x74, + 0x63, 0x68, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x68, 0x65, + 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x41, 0x64, 0x64, 0x57, 0x61, 0x74, 0x63, + 0x68, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x40, 0x0a, + 0x05, 0x57, 0x61, 0x74, 0x63, 0x68, 0x12, 0x18, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, + 0x72, 0x6c, 0x64, 0x2e, 0x57, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x17, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x47, 0x65, + 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x04, 0x98, 0x80, 0x01, 0x02, 0x1a, + 0x04, 0x98, 0x80, 0x01, 0x01, 0x32, 0xd0, 0x01, 0x0a, 0x08, 0x57, 0x6f, 0x72, 0x6b, 0x66, 0x6c, + 0x6f, 0x77, 0x12, 0x38, 0x0a, 0x03, 0x52, 0x75, 0x6e, 0x12, 0x16, 0x2e, 0x68, 0x65, 0x6c, 0x6c, + 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x52, 0x75, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x17, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x52, + 0x75, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x41, 0x0a, 0x06, + 0x46, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x12, 0x19, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, + 0x72, 0x6c, 0x64, 0x2e, 0x46, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x1a, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x46, + 0x69, 0x6e, 0x69, 0x73, 0x68, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, + 0x41, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x19, 0x2e, 0x68, 0x65, 0x6c, 0x6c, + 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, + 0x64, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x22, 0x00, 0x1a, 0x04, 0x98, 0x80, 0x01, 0x02, 0x42, 0x9e, 0x01, 0x0a, 0x0e, 0x63, 0x6f, 0x6d, + 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x42, 0x0f, 0x48, 0x65, 0x6c, + 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x33, + 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x73, 0x74, 0x61, + 0x74, 0x65, 0x64, 0x65, 0x76, 0x2f, 0x73, 0x64, 0x6b, 0x2d, 0x67, 0x6f, 0x2f, 0x65, 0x78, 0x61, + 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0xa2, 0x02, 0x03, 0x48, 0x58, 0x58, 0xaa, 0x02, 0x0a, 0x48, 0x65, 0x6c, 0x6c, + 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0xca, 0x02, 0x0a, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, + 0x72, 0x6c, 0x64, 0xe2, 0x02, 0x16, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, + 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x0a, 0x48, + 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, } var ( @@ -452,7 +721,7 @@ func file_proto_helloworld_proto_rawDescGZIP() []byte { return file_proto_helloworld_proto_rawDescData } -var file_proto_helloworld_proto_msgTypes = make([]protoimpl.MessageInfo, 8) +var file_proto_helloworld_proto_msgTypes = make([]protoimpl.MessageInfo, 14) var file_proto_helloworld_proto_goTypes = []any{ (*HelloRequest)(nil), // 0: helloworld.HelloRequest (*HelloResponse)(nil), // 1: helloworld.HelloResponse @@ -462,23 +731,35 @@ var file_proto_helloworld_proto_goTypes = []any{ (*AddWatcherRequest)(nil), // 5: helloworld.AddWatcherRequest (*AddWatcherResponse)(nil), // 6: helloworld.AddWatcherResponse (*WatchRequest)(nil), // 7: helloworld.WatchRequest + (*RunRequest)(nil), // 8: helloworld.RunRequest + (*RunResponse)(nil), // 9: helloworld.RunResponse + (*StatusRequest)(nil), // 10: helloworld.StatusRequest + (*StatusResponse)(nil), // 11: helloworld.StatusResponse + (*FinishRequest)(nil), // 12: helloworld.FinishRequest + (*FinishResponse)(nil), // 13: helloworld.FinishResponse } var file_proto_helloworld_proto_depIdxs = []int32{ - 0, // 0: helloworld.Greeter.SayHello:input_type -> helloworld.HelloRequest - 2, // 1: helloworld.Counter.Add:input_type -> helloworld.AddRequest - 3, // 2: helloworld.Counter.Get:input_type -> helloworld.GetRequest - 5, // 3: helloworld.Counter.AddWatcher:input_type -> helloworld.AddWatcherRequest - 7, // 4: helloworld.Counter.Watch:input_type -> helloworld.WatchRequest - 1, // 5: helloworld.Greeter.SayHello:output_type -> helloworld.HelloResponse - 4, // 6: helloworld.Counter.Add:output_type -> helloworld.GetResponse - 4, // 7: helloworld.Counter.Get:output_type -> helloworld.GetResponse - 6, // 8: helloworld.Counter.AddWatcher:output_type -> helloworld.AddWatcherResponse - 4, // 9: helloworld.Counter.Watch:output_type -> helloworld.GetResponse - 5, // [5:10] is the sub-list for method output_type - 0, // [0:5] is the sub-list for method input_type - 0, // [0:0] is the sub-list for extension type_name - 0, // [0:0] is the sub-list for extension extendee - 0, // [0:0] is the sub-list for field type_name + 0, // 0: helloworld.Greeter.SayHello:input_type -> helloworld.HelloRequest + 2, // 1: helloworld.Counter.Add:input_type -> helloworld.AddRequest + 3, // 2: helloworld.Counter.Get:input_type -> helloworld.GetRequest + 5, // 3: helloworld.Counter.AddWatcher:input_type -> helloworld.AddWatcherRequest + 7, // 4: helloworld.Counter.Watch:input_type -> helloworld.WatchRequest + 8, // 5: helloworld.Workflow.Run:input_type -> helloworld.RunRequest + 12, // 6: helloworld.Workflow.Finish:input_type -> helloworld.FinishRequest + 10, // 7: helloworld.Workflow.Status:input_type -> helloworld.StatusRequest + 1, // 8: helloworld.Greeter.SayHello:output_type -> helloworld.HelloResponse + 4, // 9: helloworld.Counter.Add:output_type -> helloworld.GetResponse + 4, // 10: helloworld.Counter.Get:output_type -> helloworld.GetResponse + 6, // 11: helloworld.Counter.AddWatcher:output_type -> helloworld.AddWatcherResponse + 4, // 12: helloworld.Counter.Watch:output_type -> helloworld.GetResponse + 9, // 13: helloworld.Workflow.Run:output_type -> helloworld.RunResponse + 13, // 14: helloworld.Workflow.Finish:output_type -> helloworld.FinishResponse + 11, // 15: helloworld.Workflow.Status:output_type -> helloworld.StatusResponse + 8, // [8:16] is the sub-list for method output_type + 0, // [0:8] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name } func init() { file_proto_helloworld_proto_init() } @@ -583,6 +864,78 @@ func file_proto_helloworld_proto_init() { return nil } } + file_proto_helloworld_proto_msgTypes[8].Exporter = func(v any, i int) any { + switch v := v.(*RunRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_helloworld_proto_msgTypes[9].Exporter = func(v any, i int) any { + switch v := v.(*RunResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_helloworld_proto_msgTypes[10].Exporter = func(v any, i int) any { + switch v := v.(*StatusRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_helloworld_proto_msgTypes[11].Exporter = func(v any, i int) any { + switch v := v.(*StatusResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_helloworld_proto_msgTypes[12].Exporter = func(v any, i int) any { + switch v := v.(*FinishRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_helloworld_proto_msgTypes[13].Exporter = func(v any, i int) any { + switch v := v.(*FinishResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -590,9 +943,9 @@ func file_proto_helloworld_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_proto_helloworld_proto_rawDesc, NumEnums: 0, - NumMessages: 8, + NumMessages: 14, NumExtensions: 0, - NumServices: 2, + NumServices: 3, }, GoTypes: file_proto_helloworld_proto_goTypes, DependencyIndexes: file_proto_helloworld_proto_depIdxs, diff --git a/examples/codegen/proto/helloworld.proto b/examples/codegen/proto/helloworld.proto index 127d1e0..693f0fe 100644 --- a/examples/codegen/proto/helloworld.proto +++ b/examples/codegen/proto/helloworld.proto @@ -26,6 +26,16 @@ service Counter { } } +service Workflow { + option (dev.restate.sdk.go.service_type) = WORKFLOW; + // Execute the workflow + rpc Run (RunRequest) returns (RunResponse) {} + // Unblock the workflow + rpc Finish(FinishRequest) returns (FinishResponse) {} + // Check the current status + rpc Status (StatusRequest) returns (StatusResponse) {} +} + message HelloRequest { string name = 1; } @@ -53,3 +63,19 @@ message AddWatcherResponse {} message WatchRequest { int64 timeout_millis = 1; } + +message RunRequest {} + +message RunResponse { + string status = 1; +} + +message StatusRequest {} + +message StatusResponse { + string status = 1; +} + +message FinishRequest {} + +message FinishResponse {} diff --git a/examples/codegen/proto/helloworld_restate.pb.go b/examples/codegen/proto/helloworld_restate.pb.go index 01f3c21..19d78f0 100644 --- a/examples/codegen/proto/helloworld_restate.pb.go +++ b/examples/codegen/proto/helloworld_restate.pb.go @@ -192,3 +192,104 @@ func NewCounterServer(srv CounterServer, opts ...sdk_go.ServiceDefinitionOption) router = router.Handler("Watch", sdk_go.NewObjectSharedHandler(srv.Watch)) return router } + +// WorkflowClient is the client API for Workflow service. +type WorkflowClient interface { + // Execute the workflow + Run(opts ...sdk_go.ClientOption) sdk_go.Client[*RunRequest, *RunResponse] + // Unblock the workflow + Finish(opts ...sdk_go.ClientOption) sdk_go.Client[*FinishRequest, *FinishResponse] + // Check the current status + Status(opts ...sdk_go.ClientOption) sdk_go.Client[*StatusRequest, *StatusResponse] +} + +type workflowClient struct { + ctx sdk_go.Context + workflowID string + options []sdk_go.ClientOption +} + +func NewWorkflowClient(ctx sdk_go.Context, workflowID string, opts ...sdk_go.ClientOption) WorkflowClient { + cOpts := append([]sdk_go.ClientOption{sdk_go.WithProtoJSON}, opts...) + return &workflowClient{ + ctx, + workflowID, + cOpts, + } +} +func (c *workflowClient) Run(opts ...sdk_go.ClientOption) sdk_go.Client[*RunRequest, *RunResponse] { + cOpts := c.options + if len(opts) > 0 { + cOpts = append(append([]sdk_go.ClientOption{}, cOpts...), opts...) + } + return sdk_go.WithRequestType[*RunRequest](sdk_go.Workflow[*RunResponse](c.ctx, "Workflow", c.workflowID, "Run", cOpts...)) +} + +func (c *workflowClient) Finish(opts ...sdk_go.ClientOption) sdk_go.Client[*FinishRequest, *FinishResponse] { + cOpts := c.options + if len(opts) > 0 { + cOpts = append(append([]sdk_go.ClientOption{}, cOpts...), opts...) + } + return sdk_go.WithRequestType[*FinishRequest](sdk_go.Workflow[*FinishResponse](c.ctx, "Workflow", c.workflowID, "Finish", cOpts...)) +} + +func (c *workflowClient) Status(opts ...sdk_go.ClientOption) sdk_go.Client[*StatusRequest, *StatusResponse] { + cOpts := c.options + if len(opts) > 0 { + cOpts = append(append([]sdk_go.ClientOption{}, cOpts...), opts...) + } + return sdk_go.WithRequestType[*StatusRequest](sdk_go.Workflow[*StatusResponse](c.ctx, "Workflow", c.workflowID, "Status", cOpts...)) +} + +// WorkflowServer is the server API for Workflow service. +// All implementations should embed UnimplementedWorkflowServer +// for forward compatibility. +type WorkflowServer interface { + // Execute the workflow + Run(ctx sdk_go.WorkflowContext, req *RunRequest) (*RunResponse, error) + // Unblock the workflow + Finish(ctx sdk_go.WorkflowSharedContext, req *FinishRequest) (*FinishResponse, error) + // Check the current status + Status(ctx sdk_go.WorkflowSharedContext, req *StatusRequest) (*StatusResponse, error) +} + +// UnimplementedWorkflowServer should be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedWorkflowServer struct{} + +func (UnimplementedWorkflowServer) Run(ctx sdk_go.WorkflowContext, req *RunRequest) (*RunResponse, error) { + return nil, sdk_go.TerminalError(fmt.Errorf("method Run not implemented"), 501) +} +func (UnimplementedWorkflowServer) Finish(ctx sdk_go.WorkflowSharedContext, req *FinishRequest) (*FinishResponse, error) { + return nil, sdk_go.TerminalError(fmt.Errorf("method Finish not implemented"), 501) +} +func (UnimplementedWorkflowServer) Status(ctx sdk_go.WorkflowSharedContext, req *StatusRequest) (*StatusResponse, error) { + return nil, sdk_go.TerminalError(fmt.Errorf("method Status not implemented"), 501) +} +func (UnimplementedWorkflowServer) testEmbeddedByValue() {} + +// UnsafeWorkflowServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to WorkflowServer will +// result in compilation errors. +type UnsafeWorkflowServer interface { + mustEmbedUnimplementedWorkflowServer() +} + +func NewWorkflowServer(srv WorkflowServer, opts ...sdk_go.ServiceDefinitionOption) sdk_go.ServiceDefinition { + // If the following call panics, it indicates UnimplementedWorkflowServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + sOpts := append([]sdk_go.ServiceDefinitionOption{sdk_go.WithProtoJSON}, opts...) + router := sdk_go.NewWorkflow("Workflow", sOpts...) + router = router.Handler("Run", sdk_go.NewWorkflowHandler(srv.Run)) + router = router.Handler("Finish", sdk_go.NewWorkflowSharedHandler(srv.Finish)) + router = router.Handler("Status", sdk_go.NewWorkflowSharedHandler(srv.Status)) + return router +} diff --git a/facilitators.go b/facilitators.go index fbc12ac..11fa196 100644 --- a/facilitators.go +++ b/facilitators.go @@ -59,6 +59,16 @@ func ObjectSend(ctx Context, service string, key string, method string, options return ctx.inner().Object(service, key, method, options...) } +// Workflow gets a Workflow request client by service name, workflow ID and method name +func Workflow[O any](ctx Context, service string, workflowID string, method string, options ...options.ClientOption) Client[any, O] { + return outputClient[O]{ctx.inner().Workflow(service, workflowID, method, options...)} +} + +// WorkflowSend gets a Workflow send client by service name, workflow ID and method name +func WorkflowSend[O any](ctx Context, service string, workflowID string, method string, options ...options.ClientOption) SendClient[any] { + return ctx.inner().Workflow(service, workflowID, method, options...) +} + // Client represents all the different ways you can invoke a particular service-method. type Client[I any, O any] interface { // RequestFuture makes a call and returns a handle on a future response @@ -143,8 +153,8 @@ func Awakeable[T any](ctx Context, options ...options.AwakeableOption) Awakeable type AwakeableFuture[T any] interface { // Id returns the awakeable ID, which can be stored or sent to a another service Id() string - // Result blocks on receiving the result of the awakeable, storing the value it was - // resolved with in output or otherwise returning the error it was rejected with. + // Result blocks on receiving the result of the awakeable, returning the value it was + // resolved or otherwise returning the error it was rejected with. // It is *not* safe to call this in a goroutine - use Context.Select if you // want to wait on multiple results at once. Result() (T, error) @@ -237,3 +247,46 @@ func Clear(ctx ObjectContext, key string) { func ClearAll(ctx ObjectContext) { ctx.inner().ClearAll() } + +// Promise returns a named Restate durable Promise that can be resolved or rejected during the workflow execution. +// The promise is bound to the workflow and will be persisted across suspensions and retries. +func Promise[T any](ctx WorkflowSharedContext, name string, options ...options.PromiseOption) DurablePromise[T] { + return durablePromise[T]{ctx.inner().Promise(name, options...)} +} + +type DurablePromise[T any] interface { + // Result blocks on receiving the result of the Promise, returning the value it was + // resolved or otherwise returning the error it was rejected with or a cancellation error. + // It is *not* safe to call this in a goroutine - use Context.Select if you + // want to wait on multiple results at once. + Result() (T, error) + // Peek returns the value of the promise if it has been resolved. If it has not been resolved, + // the zero value of T is returned. To check explicitly for this case pass a pointer eg *string as T. + // If the promise was rejected or the invocation was cancelled, an error is returned. + Peek() (T, error) + // Resolve resolves the promise with a value, returning an error if it was already completed + // or if the invocation was cancelled. + Resolve(value T) error + // Reject rejects the promise with an error, returning an error if it was already completed + // or if the invocation was cancelled. + Reject(reason error) error + futures.Selectable +} + +type durablePromise[T any] struct { + state.DecodingPromise +} + +func (t durablePromise[T]) Result() (output T, err error) { + err = t.DecodingPromise.Result(&output) + return +} + +func (t durablePromise[T]) Peek() (output T, err error) { + _, err = t.DecodingPromise.Peek(&output) + return +} + +func (t durablePromise[T]) Resolve(value T) (err error) { + return t.DecodingPromise.Resolve(value) +} diff --git a/handler.go b/handler.go index 2510b3d..7201d47 100644 --- a/handler.go +++ b/handler.go @@ -22,12 +22,18 @@ type Void = encoding.Void // ServiceHandlerFn is the signature for a Service handler function type ServiceHandlerFn[I any, O any] func(ctx Context, input I) (O, error) -// ObjectHandlerFn is the signature for a Virtual Object exclusive-mode handler function +// ObjectHandlerFn is the signature for a Virtual Object exclusive handler function type ObjectHandlerFn[I any, O any] func(ctx ObjectContext, input I) (O, error) -// ObjectHandlerFn is the signature for a Virtual Object shared-mode handler function +// ObjectSharedHandlerFn is the signature for a Virtual Object shared-mode handler function type ObjectSharedHandlerFn[I any, O any] func(ctx ObjectSharedContext, input I) (O, error) +// ObjectHandlerFn is the signature for a Workflow 'Run' handler function +type WorkflowHandlerFn[I any, O any] func(ctx WorkflowContext, input I) (O, error) + +// WorkflowSharedHandlerFn is the signature for a Workflow shared handler function +type WorkflowSharedHandlerFn[I any, O any] func(ctx WorkflowSharedContext, input I) (O, error) + type serviceHandler[I any, O any] struct { fn ServiceHandlerFn[I, O] options options.HandlerOptions @@ -135,6 +141,8 @@ func (o ctxWrapper) inner() *state.Context { } func (o ctxWrapper) object() {} func (o ctxWrapper) exclusiveObject() {} +func (o ctxWrapper) workflow() {} +func (o ctxWrapper) runWorkflow() {} func (h *objectHandler[I, O]) Call(ctx *state.Context, bytes []byte) ([]byte, error) { var input I @@ -186,3 +194,92 @@ func (h *objectHandler[I, O]) GetOptions() *options.HandlerOptions { func (h *objectHandler[I, O]) HandlerType() *internal.ServiceHandlerType { return &h.handlerType } + +type workflowHandler[I any, O any] struct { + // only one of workflowFn or sharedFn should be set, as indicated by handlerType + workflowFn WorkflowHandlerFn[I, O] + sharedFn WorkflowSharedHandlerFn[I, O] + options options.HandlerOptions + handlerType internal.ServiceHandlerType +} + +var _ state.Handler = (*workflowHandler[struct{}, struct{}])(nil) + +// NewWorkflowHandler converts a function of signature [WorkflowHandlerFn] into the 'Run' handler on a Workflow. +// The handler will have access to a full [WorkflowContext] which may mutate state. +func NewWorkflowHandler[I any, O any](fn WorkflowHandlerFn[I, O], opts ...options.HandlerOption) *workflowHandler[I, O] { + o := options.HandlerOptions{} + for _, opt := range opts { + opt.BeforeHandler(&o) + } + return &workflowHandler[I, O]{ + workflowFn: fn, + options: o, + handlerType: internal.ServiceHandlerType_WORKFLOW, + } +} + +// NewWorkflowSharedHandler converts a function of signature [ObjectSharedHandlerFn] into a shared-mode handler on a Workflow. +// The handler will only have access to a [WorkflowSharedContext] which can only read a snapshot of state. +func NewWorkflowSharedHandler[I any, O any](fn WorkflowSharedHandlerFn[I, O], opts ...options.HandlerOption) *workflowHandler[I, O] { + o := options.HandlerOptions{} + for _, opt := range opts { + opt.BeforeHandler(&o) + } + return &workflowHandler[I, O]{ + sharedFn: fn, + options: o, + handlerType: internal.ServiceHandlerType_SHARED, + } +} + +func (h *workflowHandler[I, O]) Call(ctx *state.Context, bytes []byte) ([]byte, error) { + var input I + if err := encoding.Unmarshal(h.options.Codec, bytes, &input); err != nil { + return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) + } + + var output O + var err error + switch h.handlerType { + case internal.ServiceHandlerType_WORKFLOW: + output, err = h.workflowFn( + ctxWrapper{ctx}, + input, + ) + case internal.ServiceHandlerType_SHARED: + output, err = h.sharedFn( + ctxWrapper{ctx}, + input, + ) + } + if err != nil { + return nil, err + } + + bytes, err = encoding.Marshal(h.options.Codec, output) + if err != nil { + // we don't use a terminal error here as this is hot-fixable by changing the return type + return nil, fmt.Errorf("failed to serialize output: %w", err) + } + + return bytes, nil +} + +func (h *workflowHandler[I, O]) InputPayload() *encoding.InputPayload { + var i I + return encoding.InputPayloadFor(h.options.Codec, i) +} + +func (h *workflowHandler[I, O]) OutputPayload() *encoding.OutputPayload { + var o O + return encoding.OutputPayloadFor(h.options.Codec, o) +} + +func (h *workflowHandler[I, O]) GetOptions() *options.HandlerOptions { + return &h.options +} + +func (h *workflowHandler[I, O]) HandlerType() *internal.ServiceHandlerType { + return &h.handlerType +} diff --git a/internal/futures/futures.go b/internal/futures/futures.go index d40eb71..ed1ca6f 100644 --- a/internal/futures/futures.go +++ b/internal/futures/futures.go @@ -107,3 +107,38 @@ func (r *ResponseFuture) Response() ([]byte, error) { func (r *ResponseFuture) getEntry() (wire.CompleteableMessage, uint32) { return r.entry, r.entryIndex } + +type Promise struct { + suspensionCtx context.Context + invocationID []byte + entry *wire.GetPromiseEntryMessage + entryIndex uint32 + getPromise func() (*wire.GetPromiseEntryMessage, uint32) +} + +func NewPromise(suspensionCtx context.Context, invocationID []byte, getPromise func() (*wire.GetPromiseEntryMessage, uint32)) *Promise { + return &Promise{suspensionCtx, invocationID, nil, 0, getPromise} +} + +func (c *Promise) Result() ([]byte, error) { + c.getEntry() + + c.entry.Await(c.suspensionCtx, c.entryIndex) + + switch result := c.entry.Result.(type) { + case *protocol.GetPromiseEntryMessage_Value: + return result.Value, nil + case *protocol.GetPromiseEntryMessage_Failure: + return nil, errors.ErrorFromFailure(result.Failure) + default: + return nil, fmt.Errorf("unexpected result in completed get promise entry: %v", c.entry.Result) + } +} + +func (c *Promise) getEntry() (wire.CompleteableMessage, uint32) { + if c.entry == nil { + c.entry, c.entryIndex = c.getPromise() + } + + return c.entry, c.entryIndex +} diff --git a/internal/options/options.go b/internal/options/options.go index 4eec3a6..278a1c7 100644 --- a/internal/options/options.go +++ b/internal/options/options.go @@ -14,6 +14,14 @@ type AwakeableOption interface { BeforeAwakeable(*AwakeableOptions) } +type PromiseOptions struct { + Codec encoding.Codec +} + +type PromiseOption interface { + BeforePromise(*PromiseOptions) +} + type ResolveAwakeableOptions struct { Codec encoding.Codec } diff --git a/internal/state/call.go b/internal/state/call.go index baf9637..febbfa1 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -31,7 +31,7 @@ func (c *Client) RequestFuture(input any, opts ...options.RequestOption) Decodin bytes, err := encoding.Marshal(c.options.Codec, input) if err != nil { - panic(c.machine.newCodecFailure(fmt.Errorf("failed to marshal RequestFuture input: %w", err))) + panic(c.machine.newCodecFailure(wire.CallEntryMessageType, fmt.Errorf("failed to marshal RequestFuture input: %w", err))) } entry, entryIndex := c.machine.doCall(c.service, c.key, c.method, o.Headers, bytes) @@ -56,7 +56,7 @@ func (d DecodingResponseFuture) Response(output any) (err error) { } if err := encoding.Unmarshal(d.options.Codec, bytes, output); err != nil { - panic(d.machine.newCodecFailure(fmt.Errorf("failed to unmarshal Call response into output: %w", err))) + panic(d.machine.newCodecFailure(wire.CallEntryMessageType, fmt.Errorf("failed to unmarshal Call response into output: %w", err))) } return nil @@ -76,7 +76,7 @@ func (c *Client) Send(input any, opts ...options.SendOption) { bytes, err := encoding.Marshal(c.options.Codec, input) if err != nil { - panic(c.machine.newCodecFailure(fmt.Errorf("failed to marshal Send input: %w", err))) + panic(c.machine.newCodecFailure(wire.OneWayCallEntryMessageType, fmt.Errorf("failed to marshal Send input: %w", err))) } c.machine.sendCall(c.service, c.key, c.method, o.Headers, bytes, o.Delay) return diff --git a/internal/state/promise.go b/internal/state/promise.go new file mode 100644 index 0000000..a50357b --- /dev/null +++ b/internal/state/promise.go @@ -0,0 +1,230 @@ +package state + +import ( + "bytes" + "fmt" + + "github.com/restatedev/sdk-go/encoding" + protocol "github.com/restatedev/sdk-go/generated/dev/restate/service" + "github.com/restatedev/sdk-go/internal/errors" + "github.com/restatedev/sdk-go/internal/futures" + "github.com/restatedev/sdk-go/internal/options" + "github.com/restatedev/sdk-go/internal/wire" +) + +func (c *Context) Promise(key string, opts ...options.PromiseOption) DecodingPromise { + o := options.PromiseOptions{} + for _, opt := range opts { + opt.BeforePromise(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + return DecodingPromise{futures.NewPromise(c.machine.suspensionCtx, c.machine.request.ID, func() (*wire.GetPromiseEntryMessage, uint32) { + return c.machine.getPromise(key) + }), key, c.machine, o.Codec} +} + +type DecodingPromise struct { + *futures.Promise + key string + machine *Machine + codec encoding.Codec +} + +func (d DecodingPromise) Result(output any) (err error) { + bytes, err := d.Promise.Result() + if err != nil { + return err + } + if err := encoding.Unmarshal(d.codec, bytes, output); err != nil { + panic(d.machine.newCodecFailure(wire.GetPromiseEntryMessageType, fmt.Errorf("failed to unmarshal Promise result into output: %w", err))) + } + return +} + +func (d DecodingPromise) Peek(output any) (ok bool, err error) { + bytes, ok, err := d.machine.peekPromise(d.key) + if err != nil || !ok { + return ok, err + } + if err := encoding.Unmarshal(d.codec, bytes, output); err != nil { + panic(d.machine.newCodecFailure(wire.PeekPromiseEntryMessageType, fmt.Errorf("failed to unmarshal Promise result into output: %w", err))) + } + return +} + +func (d DecodingPromise) Resolve(value any) error { + bytes, err := encoding.Marshal(d.codec, value) + if err != nil { + panic(d.machine.newCodecFailure(wire.CompletePromiseEntryMessageType, fmt.Errorf("failed to marshal Promise Resolve value: %w", err))) + } + return d.machine.resolvePromise(d.key, bytes) +} + +func (d DecodingPromise) Reject(reason error) error { + return d.machine.rejectPromise(d.key, reason) +} + +func (m *Machine) getPromise(key string) (*wire.GetPromiseEntryMessage, uint32) { + return replayOrNew( + m, + func(entry *wire.GetPromiseEntryMessage) *wire.GetPromiseEntryMessage { + if entry.Key != key { + panic(m.newEntryMismatch(&wire.GetPromiseEntryMessage{ + GetPromiseEntryMessage: protocol.GetPromiseEntryMessage{ + Key: key, + }, + }, entry)) + } + return entry + }, + func() *wire.GetPromiseEntryMessage { + return m._getPromise(key) + }, + ) +} + +func (c *Machine) _getPromise(key string) *wire.GetPromiseEntryMessage { + msg := &wire.GetPromiseEntryMessage{ + GetPromiseEntryMessage: protocol.GetPromiseEntryMessage{ + Key: key, + }, + } + c.Write(msg) + return msg +} + +func (m *Machine) peekPromise(key string) ([]byte, bool, error) { + entry, entryIndex := replayOrNew( + m, + func(entry *wire.PeekPromiseEntryMessage) *wire.PeekPromiseEntryMessage { + if entry.Key != key { + panic(m.newEntryMismatch(&wire.PeekPromiseEntryMessage{ + PeekPromiseEntryMessage: protocol.PeekPromiseEntryMessage{ + Key: key, + }, + }, entry)) + } + return entry + }, + func() *wire.PeekPromiseEntryMessage { + return m._peekPromise(key) + }, + ) + + entry.Await(m.suspensionCtx, entryIndex) + + switch value := entry.Result.(type) { + case *protocol.PeekPromiseEntryMessage_Empty: + return nil, false, nil + case *protocol.PeekPromiseEntryMessage_Failure: + return nil, false, errors.ErrorFromFailure(value.Failure) + case *protocol.PeekPromiseEntryMessage_Value: + return value.Value, true, nil + default: + panic(m.newProtocolViolation(entry, fmt.Errorf("peek promise entry had invalid result: %v", entry.Result))) + } +} + +func (c *Machine) _peekPromise(key string) *wire.PeekPromiseEntryMessage { + msg := &wire.PeekPromiseEntryMessage{ + PeekPromiseEntryMessage: protocol.PeekPromiseEntryMessage{ + Key: key, + }, + } + c.Write(msg) + return msg +} + +func (m *Machine) resolvePromise(key string, value []byte) error { + entry, entryIndex := replayOrNew( + m, + func(entry *wire.CompletePromiseEntryMessage) *wire.CompletePromiseEntryMessage { + messageValue, ok := entry.Completion.(*protocol.CompletePromiseEntryMessage_CompletionValue) + if entry.Key != key || !ok || !bytes.Equal(messageValue.CompletionValue, value) { + panic(m.newEntryMismatch(&wire.CompletePromiseEntryMessage{ + CompletePromiseEntryMessage: protocol.CompletePromiseEntryMessage{ + Key: key, + Completion: &protocol.CompletePromiseEntryMessage_CompletionValue{CompletionValue: value}, + }, + }, entry)) + } + return entry + }, + func() *wire.CompletePromiseEntryMessage { + return m._resolvePromise(key, value) + }, + ) + + entry.Await(m.suspensionCtx, entryIndex) + + switch value := entry.Result.(type) { + case *protocol.CompletePromiseEntryMessage_Empty: + return nil + case *protocol.CompletePromiseEntryMessage_Failure: + return errors.ErrorFromFailure(value.Failure) + default: + panic(m.newProtocolViolation(entry, fmt.Errorf("complete promise entry had invalid result: %v", entry.Result))) + } +} + +func (c *Machine) _resolvePromise(key string, value []byte) *wire.CompletePromiseEntryMessage { + msg := &wire.CompletePromiseEntryMessage{ + CompletePromiseEntryMessage: protocol.CompletePromiseEntryMessage{ + Key: key, + Completion: &protocol.CompletePromiseEntryMessage_CompletionValue{CompletionValue: value}, + }, + } + c.Write(msg) + return msg +} + +func (m *Machine) rejectPromise(key string, reason error) error { + entry, entryIndex := replayOrNew( + m, + func(entry *wire.CompletePromiseEntryMessage) *wire.CompletePromiseEntryMessage { + messageFailure, ok := entry.Result.(*protocol.CompletePromiseEntryMessage_Failure) + if entry.Key != key || !ok || messageFailure.Failure.Code != uint32(errors.ErrorCode(reason)) || messageFailure.Failure.Message != reason.Error() { + panic(m.newEntryMismatch(&wire.CompletePromiseEntryMessage{ + CompletePromiseEntryMessage: protocol.CompletePromiseEntryMessage{ + Key: key, + Completion: &protocol.CompletePromiseEntryMessage_CompletionFailure{CompletionFailure: &protocol.Failure{ + Code: uint32(errors.ErrorCode(reason)), + Message: reason.Error(), + }}, + }, + }, entry)) + } + return entry + }, + func() *wire.CompletePromiseEntryMessage { + return m._rejectPromise(key, reason) + }, + ) + + entry.Await(m.suspensionCtx, entryIndex) + + switch value := entry.Result.(type) { + case *protocol.CompletePromiseEntryMessage_Empty: + return nil + case *protocol.CompletePromiseEntryMessage_Failure: + return errors.ErrorFromFailure(value.Failure) + default: + panic(m.newProtocolViolation(entry, fmt.Errorf("complete promise entry had invalid result: %v", entry.Result))) + } +} + +func (c *Machine) _rejectPromise(key string, reason error) *wire.CompletePromiseEntryMessage { + msg := &wire.CompletePromiseEntryMessage{ + CompletePromiseEntryMessage: protocol.CompletePromiseEntryMessage{ + Key: key, + Completion: &protocol.CompletePromiseEntryMessage_CompletionFailure{CompletionFailure: &protocol.Failure{ + Code: uint32(errors.ErrorCode(reason)), + Message: reason.Error(), + }}, + }, + } + c.Write(msg) + return msg +} diff --git a/internal/state/state.go b/internal/state/state.go index 98312e6..080dd76 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -60,7 +60,7 @@ func (c *Context) Set(key string, value any, opts ...options.SetOption) { bytes, err := encoding.Marshal(o.Codec, value) if err != nil { - panic(c.machine.newCodecFailure(fmt.Errorf("failed to marshal Set value: %w", err))) + panic(c.machine.newCodecFailure(wire.SetStateEntryMessageType, fmt.Errorf("failed to marshal Set value: %w", err))) } c.machine.set(key, bytes) @@ -93,7 +93,7 @@ func (c *Context) Get(key string, output any, opts ...options.GetOption) (bool, } if err := encoding.Unmarshal(o.Codec, bytes, output); err != nil { - panic(c.machine.newCodecFailure(fmt.Errorf("failed to unmarshal Get state into output: %w", err))) + panic(c.machine.newCodecFailure(wire.GetStateEntryMessageType, fmt.Errorf("failed to unmarshal Get state into output: %w", err))) } return true, nil @@ -146,6 +146,24 @@ func (c *Context) Object(service, key, method string, opts ...options.ClientOpti } } +func (c *Context) Workflow(service, workflowID, method string, opts ...options.ClientOption) *Client { + o := options.ClientOptions{} + for _, opt := range opts { + opt.BeforeClient(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + + return &Client{ + options: o, + machine: c.machine, + service: service, + key: workflowID, + method: method, + } +} + func (c *Context) Run(fn func(ctx RunContext) (any, error), output any, opts ...options.RunOption) error { o := options.RunOptions{} for _, opt := range opts { @@ -163,7 +181,7 @@ func (c *Context) Run(fn func(ctx RunContext) (any, error), output any, opts ... bytes, err := encoding.Marshal(o.Codec, output) if err != nil { - panic(c.machine.newCodecFailure(fmt.Errorf("failed to marshal Run output: %w", err))) + panic(c.machine.newCodecFailure(wire.RunEntryMessageType, fmt.Errorf("failed to marshal Run output: %w", err))) } return bytes, nil @@ -173,20 +191,12 @@ func (c *Context) Run(fn func(ctx RunContext) (any, error), output any, opts ... } if err := encoding.Unmarshal(o.Codec, bytes, output); err != nil { - panic(c.machine.newCodecFailure(fmt.Errorf("failed to unmarshal Run output: %w", err))) + panic(c.machine.newCodecFailure(wire.RunEntryMessageType, fmt.Errorf("failed to unmarshal Run output: %w", err))) } return nil } -type awakeableOptions struct { - codec encoding.Codec -} - -type AwakeableOption interface { - beforeAwakeable(*awakeableOptions) -} - func (c *Context) Awakeable(opts ...options.AwakeableOption) DecodingAwakeable { o := options.AwakeableOptions{} for _, opt := range opts { @@ -211,7 +221,7 @@ func (d DecodingAwakeable) Result(output any) (err error) { return err } if err := encoding.Unmarshal(d.codec, bytes, output); err != nil { - panic(d.machine.newCodecFailure(fmt.Errorf("failed to unmarshal Awakeable result into output: %w", err))) + panic(d.machine.newCodecFailure(wire.AwakeableEntryMessageType, fmt.Errorf("failed to unmarshal Awakeable result into output: %w", err))) } return } @@ -226,7 +236,7 @@ func (c *Context) ResolveAwakeable(id string, value any, opts ...options.Resolve } bytes, err := encoding.Marshal(o.Codec, value) if err != nil { - panic(c.machine.newCodecFailure(fmt.Errorf("failed to marshal ResolveAwakeable value: %w", err))) + panic(c.machine.newCodecFailure(wire.CompleteAwakeableEntryMessageType, fmt.Errorf("failed to marshal ResolveAwakeable value: %w", err))) } c.machine.resolveAwakeable(id, bytes) } @@ -420,7 +430,7 @@ The journal entry at position %d was: Code: uint32(errors.ErrorCode(typ.err)), Message: typ.err.Error(), RelatedEntryIndex: &typ.entryIndex, - RelatedEntryType: wire.AwakeableEntryMessageType.UInt32(), + RelatedEntryType: wire.RunEntryMessageType.UInt32(), }, }); err != nil { m.log.LogAttrs(m.ctx, slog.LevelError, "Error sending failure message", log.Error(typ.err)) @@ -435,7 +445,7 @@ The journal entry at position %d was: Code: uint32(errors.ErrorCode(typ.err)), Message: typ.err.Error(), RelatedEntryIndex: &typ.entryIndex, - RelatedEntryType: wire.AwakeableEntryMessageType.UInt32(), + RelatedEntryType: typ.entryType.UInt32(), }, }); err != nil { m.log.LogAttrs(m.ctx, slog.LevelError, "Error sending failure message", log.Error(typ.err)) diff --git a/internal/state/sys.go b/internal/state/sys.go index 190c0b9..f4fae38 100644 --- a/internal/state/sys.go +++ b/internal/state/sys.go @@ -378,12 +378,13 @@ func (m *Machine) newRunFailure(err error) *runFailure { } type codecFailure struct { + entryType wire.Type entryIndex uint32 err error } -func (m *Machine) newCodecFailure(err error) *codecFailure { - c := &codecFailure{m.entryIndex, err} +func (m *Machine) newCodecFailure(entryType wire.Type, err error) *codecFailure { + c := &codecFailure{entryType, m.entryIndex, err} m.failure = c return c } diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 4cf991c..8b23c70 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -47,7 +47,12 @@ const ( ClearAllStateEntryMessageType Type = 0x0800 + 3 GetStateKeysEntryMessageType Type = 0x0800 + 4 - //SysCalls + // Promises + GetPromiseEntryMessageType Type = 0x0800 + 8 + PeekPromiseEntryMessageType Type = 0x0800 + 9 + CompletePromiseEntryMessageType Type = 0x0800 + 10 + + // SysCalls SleepEntryMessageType Type = 0x0C00 CallEntryMessageType Type = 0x0C00 + 1 OneWayCallEntryMessageType Type = 0x0C00 + 2 @@ -181,6 +186,12 @@ func MessageType(message Message) Type { return AwakeableEntryMessageType case *CompleteAwakeableEntryMessage: return CompleteAwakeableEntryMessageType + case *GetPromiseEntryMessage: + return GetPromiseEntryMessageType + case *PeekPromiseEntryMessage: + return PeekPromiseEntryMessageType + case *CompletePromiseEntryMessage: + return CompletePromiseEntryMessageType case *RunEntryMessage: return RunEntryMessageType case *SelectorEntryMessage: @@ -415,6 +426,33 @@ var ( return msg, proto.Unmarshal(bytes, msg) }, + GetPromiseEntryMessageType: func(header Header, bytes []byte) (Message, error) { + msg := &GetPromiseEntryMessage{} + + if header.Flag.Completed() { + msg.completable.complete() + } + + return msg, proto.Unmarshal(bytes, msg) + }, + PeekPromiseEntryMessageType: func(header Header, bytes []byte) (Message, error) { + msg := &PeekPromiseEntryMessage{} + + if header.Flag.Completed() { + msg.completable.complete() + } + + return msg, proto.Unmarshal(bytes, msg) + }, + CompletePromiseEntryMessageType: func(header Header, bytes []byte) (Message, error) { + msg := &CompletePromiseEntryMessage{} + + if header.Flag.Completed() { + msg.completable.complete() + } + + return msg, proto.Unmarshal(bytes, msg) + }, RunEntryMessageType: func(header Header, bytes []byte) (Message, error) { msg := &RunEntryMessage{} @@ -532,6 +570,69 @@ func (a *GetStateKeysEntryMessage) Complete(c *protocol.CompletionMessage) error return nil } +type GetPromiseEntryMessage struct { + completable + protocol.GetPromiseEntryMessage +} + +var _ CompleteableMessage = (*GetPromiseEntryMessage)(nil) + +func (a *GetPromiseEntryMessage) Complete(c *protocol.CompletionMessage) error { + switch result := c.Result.(type) { + case *protocol.CompletionMessage_Value: + a.Result = &protocol.GetPromiseEntryMessage_Value{Value: result.Value} + case *protocol.CompletionMessage_Failure: + a.Result = &protocol.GetPromiseEntryMessage_Failure{Failure: result.Failure} + case *protocol.CompletionMessage_Empty: + return fmt.Errorf("received empty completion for getpromise") + } + + a.complete() + return nil +} + +type PeekPromiseEntryMessage struct { + completable + protocol.PeekPromiseEntryMessage +} + +var _ CompleteableMessage = (*PeekPromiseEntryMessage)(nil) + +func (a *PeekPromiseEntryMessage) Complete(c *protocol.CompletionMessage) error { + switch result := c.Result.(type) { + case *protocol.CompletionMessage_Value: + a.Result = &protocol.PeekPromiseEntryMessage_Value{Value: result.Value} + case *protocol.CompletionMessage_Failure: + a.Result = &protocol.PeekPromiseEntryMessage_Failure{Failure: result.Failure} + case *protocol.CompletionMessage_Empty: + a.Result = &protocol.PeekPromiseEntryMessage_Empty{} + } + + a.complete() + return nil +} + +type CompletePromiseEntryMessage struct { + completable + protocol.CompletePromiseEntryMessage +} + +var _ CompleteableMessage = (*CompletePromiseEntryMessage)(nil) + +func (a *CompletePromiseEntryMessage) Complete(c *protocol.CompletionMessage) error { + switch result := c.Result.(type) { + case *protocol.CompletionMessage_Failure: + a.Result = &protocol.CompletePromiseEntryMessage_Failure{Failure: result.Failure} + case *protocol.CompletionMessage_Empty: + a.Result = &protocol.CompletePromiseEntryMessage_Empty{} + case *protocol.CompletionMessage_Value: + return fmt.Errorf("received value completion for complete promise") + } + + a.complete() + return nil +} + type CompletionMessage struct { Header protocol.CompletionMessage diff --git a/protoc-gen-go-restate/restate.go b/protoc-gen-go-restate/restate.go index b768eaf..b78db68 100644 --- a/protoc-gen-go-restate/restate.go +++ b/protoc-gen-go-restate/restate.go @@ -22,8 +22,11 @@ func generateClientStruct(g *protogen.GeneratedFile, service *protogen.Service, g.P("type ", unexport(clientName), " struct {") g.P("ctx ", sdkPackage.Ident("Context")) serviceType := proto.GetExtension(service.Desc.Options().(*descriptorpb.ServiceOptions), sdk.E_ServiceType).(sdk.ServiceType) - if serviceType == sdk.ServiceType_VIRTUAL_OBJECT { + switch serviceType { + case sdk.ServiceType_VIRTUAL_OBJECT: g.P("key string") + case sdk.ServiceType_WORKFLOW: + g.P("workflowID string") } g.P("options []", sdkPackage.Ident("ClientOption")) g.P("}") @@ -34,8 +37,11 @@ func generateNewClientDefinitions(g *protogen.GeneratedFile, service *protogen.S g.P("return &", unexport(clientName), "{") g.P("ctx,") serviceType := proto.GetExtension(service.Desc.Options().(*descriptorpb.ServiceOptions), sdk.E_ServiceType).(sdk.ServiceType) - if serviceType == sdk.ServiceType_VIRTUAL_OBJECT { + switch serviceType { + case sdk.ServiceType_VIRTUAL_OBJECT: g.P("key,") + case sdk.ServiceType_WORKFLOW: + g.P("workflowID,") } g.P("cOpts,") g.P("}") @@ -170,8 +176,11 @@ func genService(gen *protogen.Plugin, g *protogen.GeneratedFile, service *protog } newClientSignature := "New" + clientName + " (ctx " + g.QualifiedGoIdent(sdkPackage.Ident("Context")) serviceType := proto.GetExtension(service.Desc.Options().(*descriptorpb.ServiceOptions), sdk.E_ServiceType).(sdk.ServiceType) - if serviceType == sdk.ServiceType_VIRTUAL_OBJECT { + switch serviceType { + case sdk.ServiceType_VIRTUAL_OBJECT: newClientSignature += ", key string" + case sdk.ServiceType_WORKFLOW: + newClientSignature += ", workflowID string" } newClientSignature += ", opts..." + g.QualifiedGoIdent(sdkPackage.Ident("ClientOption")) + ") " + clientName @@ -286,6 +295,8 @@ func genClientMethod(gen *protogen.Plugin, g *protogen.GeneratedFile, method *pr getClient = g.QualifiedGoIdent(sdkPackage.Ident("Service")) + `[*` + g.QualifiedGoIdent(method.Output.GoIdent) + `]` + `(c.ctx, "` + service.GoName + `",` case sdk.ServiceType_VIRTUAL_OBJECT: getClient = g.QualifiedGoIdent(sdkPackage.Ident("Object")) + `[*` + g.QualifiedGoIdent(method.Output.GoIdent) + `]` + `(c.ctx, "` + service.GoName + `", c.key,` + case sdk.ServiceType_WORKFLOW: + getClient = g.QualifiedGoIdent(sdkPackage.Ident("Workflow")) + `[*` + g.QualifiedGoIdent(method.Output.GoIdent) + `]` + `(c.ctx, "` + service.GoName + `", c.workflowID,` default: gen.Error(fmt.Errorf("Unexpected service type: %s", serviceType.String())) return @@ -328,6 +339,22 @@ func contextType(gen *protogen.Plugin, g *protogen.GeneratedFile, method *protog gen.Error(fmt.Errorf("Handlers in services of type VIRTUAL_OBJECT must have type SHARED, EXCLUSIVE, or unset (defaults to EXCLUSIVE)")) return "" } + case sdk.ServiceType_WORKFLOW: + switch handlerType { + case sdk.HandlerType_SHARED: + return g.QualifiedGoIdent(sdkPackage.Ident("WorkflowSharedContext")) + case sdk.HandlerType_WORKFLOW_RUN: + return g.QualifiedGoIdent(sdkPackage.Ident("WorkflowContext")) + case sdk.HandlerType_UNSET: + if method.GoName == "Run" { + return g.QualifiedGoIdent(sdkPackage.Ident("WorkflowContext")) + } else { + return g.QualifiedGoIdent(sdkPackage.Ident("WorkflowSharedContext")) + } + default: + gen.Error(fmt.Errorf("Handlers in services of type WORKFLOW must have type SHARED, WORKFLOW_RUN, or unset (defaults to SHARED unless the method name is 'Run')")) + return "" + } default: gen.Error(fmt.Errorf("Unexpected service type: %s", serviceType.String())) return "" @@ -341,6 +368,8 @@ func newRouterType(gen *protogen.Plugin, g *protogen.GeneratedFile, service *pro return g.QualifiedGoIdent(sdkPackage.Ident("NewService")) case sdk.ServiceType_VIRTUAL_OBJECT: return g.QualifiedGoIdent(sdkPackage.Ident("NewObject")) + case sdk.ServiceType_WORKFLOW: + return g.QualifiedGoIdent(sdkPackage.Ident("NewWorkflow")) default: gen.Error(fmt.Errorf("Unexpected service type: %s", serviceType.String())) return "" @@ -369,6 +398,22 @@ func newHandlerType(gen *protogen.Plugin, g *protogen.GeneratedFile, method *pro gen.Error(fmt.Errorf("Handlers in services of type VIRTUAL_OBJECT must have type SHARED, EXCLUSIVE, or unset (defaults to EXCLUSIVE)")) return "" } + case sdk.ServiceType_WORKFLOW: + switch handlerType { + case sdk.HandlerType_SHARED: + return g.QualifiedGoIdent(sdkPackage.Ident("NewWorkflowSharedHandler")) + case sdk.HandlerType_WORKFLOW_RUN: + return g.QualifiedGoIdent(sdkPackage.Ident("NewWorkflowHandler")) + case sdk.HandlerType_UNSET: + if method.GoName == "Run" { + return g.QualifiedGoIdent(sdkPackage.Ident("NewWorkflowHandler")) + } else { + return g.QualifiedGoIdent(sdkPackage.Ident("NewWorkflowSharedHandler")) + } + default: + gen.Error(fmt.Errorf("Handlers in services of type WORKFLOW must have type SHARED, WORKFLOW_RUN, or unset (defaults to SHARED unless the method name is 'Run')")) + return "" + } default: gen.Error(fmt.Errorf("Unexpected service type: %s", serviceType.String())) return "" diff --git a/reflect.go b/reflect.go index 5daf551..e868fe1 100644 --- a/reflect.go +++ b/reflect.go @@ -16,10 +16,12 @@ type serviceNamer interface { } var ( - typeOfContext = reflect.TypeOf((*Context)(nil)).Elem() - typeOfObjectContext = reflect.TypeOf((*ObjectContext)(nil)).Elem() - typeOfSharedObjectContext = reflect.TypeOf((*ObjectSharedContext)(nil)).Elem() - typeOfError = reflect.TypeOf((*error)(nil)).Elem() + typeOfContext = reflect.TypeOf((*Context)(nil)).Elem() + typeOfObjectContext = reflect.TypeOf((*ObjectContext)(nil)).Elem() + typeOfSharedObjectContext = reflect.TypeOf((*ObjectSharedContext)(nil)).Elem() + typeOfWorkflowContext = reflect.TypeOf((*WorkflowContext)(nil)).Elem() + typeOfSharedWorkflowContext = reflect.TypeOf((*WorkflowSharedContext)(nil)).Elem() + typeOfError = reflect.TypeOf((*error)(nil)).Elem() ) // Reflect converts a struct with methods into a service definition where each correctly-typed @@ -34,9 +36,10 @@ var ( // - (ctx) (error) // - (ctx) (O) // - (ctx) (O, error) -// Where ctx is [ObjectContext], [ObjectSharedContext] or [Context]. Other signatures are ignored. +// Where ctx is [WorkflowContext], [WorkflowSharedContext], [ObjectContext], [ObjectSharedContext] or [Context]. Other signatures are ignored. // Signatures without an I or O type will be treated as if [Void] was provided. -// This function will panic if a mixture of object and service method signatures or opts are provided. +// This function will panic if a mixture of object service and workflow method signatures or opts are provided, or if multiple WorkflowContext +// methods are defined. // // Input types will be deserialised with the provided codec (defaults to JSON) except when they are [Void], // in which case no input bytes or content type may be sent. @@ -53,6 +56,7 @@ func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinitio } var definition ServiceDefinition + var foundWorkflowRun bool for m := 0; m < typ.NumMethod(); m++ { method := typ.Method(m) @@ -98,6 +102,23 @@ func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinitio panic("found a mix of object context arguments and other context arguments") } handlerType = internal.ServiceHandlerType_SHARED + case typeOfWorkflowContext: + if definition == nil { + definition = NewWorkflow(name, opts...) + } else if definition.Type() != internal.ServiceType_WORKFLOW { + panic("found a mix of workflow context arguments and other context arguments") + } else if foundWorkflowRun { + panic("found more than one WorkflowContext argument; a workflow may only have one 'Run' method, the rest must be WorkflowSharedContext.") + } + handlerType = internal.ServiceHandlerType_WORKFLOW + foundWorkflowRun = true + case typeOfSharedWorkflowContext: + if definition == nil { + definition = NewWorkflow(name, opts...) + } else if definition.Type() != internal.ServiceType_WORKFLOW { + panic("found a mix of object context arguments and other context arguments") + } + handlerType = internal.ServiceHandlerType_SHARED default: // first parameter is not a context continue @@ -155,6 +176,17 @@ func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinitio handlerType: &handlerType, }, ) + case *workflow: + def.Handler(mname, &reflectHandler{ + fn: method.Func, + receiver: val, + input: input, + output: output, + hasError: hasError, + options: options.HandlerOptions{}, + handlerType: &handlerType, + }, + ) } } @@ -162,6 +194,10 @@ func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinitio panic("no valid handlers could be found within the exported methods on this struct") } + if definition.Type() == internal.ServiceType_WORKFLOW && !foundWorkflowRun { + panic("no WorkflowContext method found; a workflow must have exactly one 'Run' handler") + } + return definition } @@ -198,6 +234,7 @@ func (h *reflectHandler) HandlerType() *internal.ServiceHandlerType { } func (h *reflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, error) { + var args []reflect.Value if h.input != nil { input := reflect.New(h.input) diff --git a/reflect_test.go b/reflect_test.go index 4e4725a..954ab86 100644 --- a/reflect_test.go +++ b/reflect_test.go @@ -22,6 +22,7 @@ type reflectTestParams struct { type expectedMethods = map[string]*internal.ServiceHandlerType var exclusive = internal.ServiceHandlerType_EXCLUSIVE +var workflowRun = internal.ServiceHandlerType_WORKFLOW var shared = internal.ServiceHandlerType_SHARED var tests []reflectTestParams = []reflectTestParams{ @@ -42,6 +43,10 @@ var tests []reflectTestParams = []reflectTestParams{ {rcvr: namedService{}, serviceName: "foobar", expectedMethods: expectedMethods{ "Greet": nil, }}, + {rcvr: validWorkflow{}, serviceName: "validWorkflow", expectedMethods: expectedMethods{ + "Run": &workflowRun, + "Status": &shared, + }}, {rcvr: mixed{}, shouldPanic: true}, {rcvr: empty{}, shouldPanic: true}, } @@ -138,6 +143,16 @@ func (namedService) ServiceName() string { return "foobar" } +type validWorkflow struct{} + +func (validWorkflow) Run(ctx restate.WorkflowContext, _ string) (string, error) { + return "", nil +} + +func (validWorkflow) Status(ctx restate.WorkflowSharedContext, _ string) (string, error) { + return "", nil +} + func (namedService) Greet(ctx restate.Context, _ string) (string, error) { return "", nil } diff --git a/router.go b/router.go index 66680e5..b76d96a 100644 --- a/router.go +++ b/router.go @@ -103,3 +103,35 @@ func (r *object) Handler(name string, handler state.Handler) *object { r.handlers[name] = handler return r } + +type workflow struct { + serviceDefinition +} + +// NewWorkflow creates a new named Workflow +func NewWorkflow(name string, opts ...options.ServiceDefinitionOption) *workflow { + o := options.ServiceDefinitionOptions{} + for _, opt := range opts { + opt.BeforeServiceDefinition(&o) + } + if o.DefaultCodec == nil { + o.DefaultCodec = encoding.JSONCodec + } + return &workflow{ + serviceDefinition: serviceDefinition{ + name: name, + handlers: make(map[string]state.Handler), + options: o, + typ: internal.ServiceType_WORKFLOW, + }, + } +} + +// Handler registers a new Workflow handler by name +func (r *workflow) Handler(name string, handler state.Handler) *workflow { + if handler.GetOptions().Codec == nil { + handler.GetOptions().Codec = r.options.DefaultCodec + } + r.handlers[name] = handler + return r +} diff --git a/test-services/awakeableholder.go b/test-services/awakeableholder.go index c677d31..70e636b 100644 --- a/test-services/awakeableholder.go +++ b/test-services/awakeableholder.go @@ -18,11 +18,11 @@ func init() { })). Handler("hasAwakeable", restate.NewObjectHandler( func(ctx restate.ObjectContext, _ restate.Void) (bool, error) { - _, err := restate.Get[string](ctx, ID_KEY) + id, err := restate.Get[string](ctx, ID_KEY) if err != nil { return false, err } - return err == nil, nil + return id != "", nil })). Handler("unlock", restate.NewObjectHandler( func(ctx restate.ObjectContext, payload string) (restate.Void, error) { @@ -34,6 +34,7 @@ func init() { return restate.Void{}, restate.TerminalError(fmt.Errorf("No awakeable registered"), 404) } restate.ResolveAwakeable(ctx, id, payload) + restate.Clear(ctx, ID_KEY) return restate.Void{}, nil }))) } diff --git a/test-services/blockandwaitworkflow.go b/test-services/blockandwaitworkflow.go new file mode 100644 index 0000000..18eecf5 --- /dev/null +++ b/test-services/blockandwaitworkflow.go @@ -0,0 +1,36 @@ +package main + +import ( + restate "github.com/restatedev/sdk-go" +) + +const MY_STATE = "my-state" +const MY_DURABLE_PROMISE = "durable-promise" + +func init() { + REGISTRY.AddDefinition( + restate.NewWorkflow("BlockAndWaitWorkflow"). + Handler("run", restate.NewWorkflowHandler( + func(ctx restate.WorkflowContext, input string) (string, error) { + restate.Set(ctx, MY_STATE, input) + output, err := restate.Promise[string](ctx, MY_DURABLE_PROMISE).Result() + if err != nil { + return "", err + } + + peek, err := restate.Promise[*string](ctx, MY_DURABLE_PROMISE).Peek() + if peek == nil { + return "", restate.TerminalErrorf("Durable promise should be completed") + } + + return output, nil + })). + Handler("unblock", restate.NewWorkflowSharedHandler( + func(ctx restate.WorkflowSharedContext, output string) (restate.Void, error) { + return restate.Void{}, restate.Promise[string](ctx, MY_DURABLE_PROMISE).Resolve(output) + })). + Handler("getState", restate.NewWorkflowSharedHandler( + func(ctx restate.WorkflowSharedContext, input restate.Void) (*string, error) { + return restate.Get[*string](ctx, MY_STATE) + }))) +} diff --git a/test-services/exclusions.yaml b/test-services/exclusions.yaml index c88eb13..4a1fbb8 100644 --- a/test-services/exclusions.yaml +++ b/test-services/exclusions.yaml @@ -1,7 +1,4 @@ exclusions: - "default": - - "dev.restate.sdktesting.tests.WorkflowAPI" - "alwaysSuspending": - - "dev.restate.sdktesting.tests.WorkflowAPI" - "singleThreadSinglePartition": - - "dev.restate.sdktesting.tests.WorkflowAPI" + "default": [] + "alwaysSuspending": [] + "singleThreadSinglePartition": []