From f121e3b4f150cb0651c07a4d873eb6fd8ea895fc Mon Sep 17 00:00:00 2001 From: John Smith Date: Thu, 19 Dec 2024 17:16:28 +1030 Subject: [PATCH] feat: Support Context object in go-sdk --- sdk-go/README.md | 2 +- sdk-go/inferable_test.go | 26 ++++++++++++++++-------- sdk-go/main_test.go | 6 +++--- sdk-go/service.go | 44 +++++++++++++++++++++++++++++----------- sdk-go/service_test.go | 10 ++++----- 5 files changed, 58 insertions(+), 30 deletions(-) diff --git a/sdk-go/README.md b/sdk-go/README.md index c483af7b..baf7fb96 100644 --- a/sdk-go/README.md +++ b/sdk-go/README.md @@ -84,7 +84,7 @@ type UserInput struct { Tags []string `json:"tags" jsonschema:"uniqueItems=true"` } -func createUser(input UserInput) string { +func createUser(input UserInput, ctx inferable.ContextInput) string { // Function implementation } diff --git a/sdk-go/inferable_test.go b/sdk-go/inferable_test.go index 559b17eb..1be63736 100644 --- a/sdk-go/inferable_test.go +++ b/sdk-go/inferable_test.go @@ -58,13 +58,15 @@ func TestCallFunc(t *testing.T) { B int `json:"b"` } - testFunc := func(input TestInput) int { return input.A + input.B } - i.Default.RegisterFunc(Function{ + testFunc := func(input TestInput, ctx ContextInput) int { return input.A + input.B } + _, err := i.Default.RegisterFunc(Function{ Func: testFunc, Name: "TestFunc", }) - result, err := i.callFunc("default", "TestFunc", TestInput{A: 2, B: 3}) + assert.NoError(t, err) + + result, err := i.callFunc("default", "TestFunc", TestInput{A: 2, B: 3}, ContextInput{}) require.NoError(t, err) assert.Equal(t, 5, result[0].Interface()) @@ -85,13 +87,15 @@ func TestToJSONDefinition(t *testing.T) { B int `json:"b"` } - testFunc := func(input TestInput) int { return input.A + input.B } - service.RegisterFunc(Function{ + testFunc := func(input TestInput, ctx ContextInput) int { return input.A + input.B } + _, err := service.RegisterFunc(Function{ Func: testFunc, Name: "TestFunc", Description: "Test function", }) + assert.NoError(t, err) + jsonDef, err := i.toJSONDefinition() require.NoError(t, err) @@ -166,12 +170,14 @@ func TestGetSchema(t *testing.T) { B int `json:"b"` } - testFunc := func(input TestInput) int { return input.A + input.B } - service.RegisterFunc(Function{ + testFunc := func(input TestInput, ctx ContextInput) int { return input.A + input.B } + _, err := service.RegisterFunc(Function{ Func: testFunc, Name: "TestFunc", }) + assert.NoError(t, err) + type TestInput2 struct { C struct { D int `json:"d"` @@ -179,12 +185,14 @@ func TestGetSchema(t *testing.T) { } `json:"c"` } - testFunc2 := func(input TestInput2) int { return input.C.D * 2 } - service.RegisterFunc(Function{ + testFunc2 := func(input TestInput2, ctx ContextInput) int { return input.C.D * 2 } + _, err = service.RegisterFunc(Function{ Func: testFunc2, Name: "TestFunc2", }) + assert.NoError(t, err) + schema, err := service.getSchema() require.NoError(t, err) assert.Equal(t, "TestFunc", schema["TestFunc"].(map[string]interface{})["name"]) diff --git a/sdk-go/main_test.go b/sdk-go/main_test.go index 3ebe87b3..23d38be5 100644 --- a/sdk-go/main_test.go +++ b/sdk-go/main_test.go @@ -12,7 +12,7 @@ type EchoInput struct { Input string } -func echo(input EchoInput) string { +func echo(input EchoInput, ctx ContextInput) string { return input.Input } @@ -161,7 +161,7 @@ func TestInferableE2E(t *testing.T) { didCallResultHandler := false sayHello, err := client.Default.RegisterFunc(Function{ - Func: func(input EchoInput) string { + Func: func(input EchoInput, ctx ContextInput) string { didCallSayHello = true return "Hello " + input.Input }, @@ -174,7 +174,7 @@ func TestInferableE2E(t *testing.T) { } resultHandler, err := client.Default.RegisterFunc(Function{ - Func: func(input OnStatusChangeInput) string { + Func: func(input OnStatusChangeInput, ctx ContextInput) string { didCallResultHandler = true fmt.Println("OnStatusChange: ", input) return "" diff --git a/sdk-go/service.go b/sdk-go/service.go index 0c9cd8f7..ed54b5e0 100644 --- a/sdk-go/service.go +++ b/sdk-go/service.go @@ -28,6 +28,12 @@ type Function struct { Func interface{} } +type ContextInput struct { + AuthContext interface{} `json:"authContext,omitempty"` + RunContext interface{} `json:"runContext,omitempty"` + approved bool `json:"approved"` +} + type service struct { Name string Functions map[string]Function @@ -38,9 +44,12 @@ type service struct { } type callMessage struct { - Id string `json:"id"` - Function string `json:"function"` - Input interface{} `json:"input"` + Id string `json:"id"` + Function string `json:"function"` + Input interface{} `json:"input"` + AuthContext interface{} `json:"authContext,omitempty"` + RunContext interface{} `json:"runContext,omitempty"` + Approved bool `json:"approved"` } type callResultMeta struct { @@ -101,30 +110,35 @@ func (s *service) RegisterFunc(fn Function) (*FunctionReference, error) { // Validate that the function has exactly one argument and it's a struct fnType := reflect.TypeOf(fn.Func) - if fnType.NumIn() != 1 { - return nil, fmt.Errorf("function '%s' must have exactly one argument", fn.Name) + if fnType.NumIn() != 2 { + return nil, fmt.Errorf("function '%s' must have exactly two arguments", fn.Name) + } + arg1Type := fnType.In(0) + arg2Type := fnType.In(1) + + if arg2Type.Kind() != reflect.Struct { + return nil, fmt.Errorf("function '%s' second argument must be a struct (ContextInput)", fn.Name) } - argType := fnType.In(0) // Set the argument type to the referenced type - if argType.Kind() == reflect.Ptr { - argType = argType.Elem() + if arg1Type.Kind() == reflect.Ptr { + arg1Type = arg1Type.Elem() } - if argType.Kind() != reflect.Struct { + if arg1Type.Kind() != reflect.Struct { return nil, fmt.Errorf("function '%s' first argument must be a struct or a pointer to a struct", fn.Name) } // Get the schema for the input struct reflector := jsonschema.Reflector{DoNotReference: true, Anonymous: true, AllowAdditionalProperties: false} - schema := reflector.Reflect(reflect.New(argType).Interface()) + schema := reflector.Reflect(reflect.New(arg1Type).Interface()) if schema == nil { return nil, fmt.Errorf("failed to get schema for function '%s'", fn.Name) } // Extract the relevant part of the schema - defs, ok := schema.Definitions[argType.Name()] + defs, ok := schema.Definitions[arg1Type.Name()] // If the definition is not found, use the whole schema. // This tends to happen for inline structs. @@ -279,10 +293,16 @@ func (s *service) handleMessage(msg callMessage) error { return fmt.Errorf("failed to unmarshal input: %v", err) } + context := ContextInput{ + AuthContext: msg.AuthContext, + RunContext: msg.RunContext, + approved: msg.Approved, + } + start := time.Now() // Call the function with the unmarshaled argument fnValue := reflect.ValueOf(fn.Func) - returnValues := fnValue.Call([]reflect.Value{argPtr.Elem()}) + returnValues := fnValue.Call([]reflect.Value{argPtr.Elem(), reflect.ValueOf(context)}) resultType := "resolution" resultValue := returnValues[0].Interface() diff --git a/sdk-go/service_test.go b/sdk-go/service_test.go index 74803320..9a406e9d 100644 --- a/sdk-go/service_test.go +++ b/sdk-go/service_test.go @@ -28,7 +28,7 @@ func TestRegisterFunc(t *testing.T) { B int `json:"b"` } - testFunc := func(input TestInput) int { return input.A + input.B } + testFunc := func(input TestInput, ctx ContextInput) int { return input.A + input.B } _, err := service.RegisterFunc(Function{ Func: testFunc, Name: "TestFunc", @@ -64,7 +64,7 @@ func TestRegisterFuncWithInlineStruct(t *testing.T) { testFunc := func(input struct { A int `json:"a"` B int `json:"b"` - }) int { + }, ctx ContextInput) int { return input.A + input.B } _, err := service.RegisterFunc(Function{ @@ -118,7 +118,7 @@ func TestRegistrationAndConfig(t *testing.T) { } `json:"c"` } - testFunc := func(input TestInput) int { return input.A + input.B } + testFunc := func(input TestInput, ctx ContextInput) int { return input.A + input.B } _, err = service.RegisterFunc(Function{ Func: testFunc, @@ -155,7 +155,7 @@ func TestServiceStartAndReceiveMessage(t *testing.T) { Message string `json:"message"` } - testFunc := func(input TestInput) string { return "Received: " + input.Message } + testFunc := func(input TestInput, ctx ContextInput) string { return "Received: " + input.Message } _, err = service.RegisterFunc(Function{ Func: testFunc, @@ -231,7 +231,7 @@ func TestServiceStartAndReceiveFailingMessage(t *testing.T) { } // Purposfuly failing function - testFailingFunc := func(input TestInput) (*string, error) { return nil, fmt.Errorf("test error") } + testFailingFunc := func(input TestInput, ctx ContextInput) (*string, error) { return nil, fmt.Errorf("test error") } _, err = service.RegisterFunc(Function{ Func: testFailingFunc,