diff --git a/sdk-go/README.md b/sdk-go/README.md index c483af7..baf7fb9 100644 --- a/sdk-go/README.md +++ b/sdk-go/README.md @@ -84,7 +84,7 @@ type UserInput struct { Tags []string `json:"tags" jsonschema:"uniqueItems=true"` } -func createUser(input UserInput) string { +func createUser(input UserInput, ctx inferable.ContextInput) string { // Function implementation } diff --git a/sdk-go/inferable_test.go b/sdk-go/inferable_test.go index 559b17e..eedf41b 100644 --- a/sdk-go/inferable_test.go +++ b/sdk-go/inferable_test.go @@ -58,12 +58,14 @@ func TestCallFunc(t *testing.T) { B int `json:"b"` } - testFunc := func(input TestInput) int { return input.A + input.B } - i.Default.RegisterFunc(Function{ + testFunc := func(input TestInput, ctx ContextInput) int { return input.A + input.B } + _, err := i.Default.RegisterFunc(Function{ Func: testFunc, Name: "TestFunc", }) + assert.NoError(t, err) + result, err := i.callFunc("default", "TestFunc", TestInput{A: 2, B: 3}) require.NoError(t, err) assert.Equal(t, 5, result[0].Interface()) @@ -85,13 +87,15 @@ func TestToJSONDefinition(t *testing.T) { B int `json:"b"` } - testFunc := func(input TestInput) int { return input.A + input.B } - service.RegisterFunc(Function{ + testFunc := func(input TestInput, ctx ContextInput) int { return input.A + input.B } + _, err := service.RegisterFunc(Function{ Func: testFunc, Name: "TestFunc", Description: "Test function", }) + assert.NoError(t, err) + jsonDef, err := i.toJSONDefinition() require.NoError(t, err) @@ -166,12 +170,14 @@ func TestGetSchema(t *testing.T) { B int `json:"b"` } - testFunc := func(input TestInput) int { return input.A + input.B } - service.RegisterFunc(Function{ + testFunc := func(input TestInput, ctx ContextInput) int { return input.A + input.B } + _, err := service.RegisterFunc(Function{ Func: testFunc, Name: "TestFunc", }) + assert.NoError(t, err) + type TestInput2 struct { C struct { D int `json:"d"` @@ -179,12 +185,14 @@ func TestGetSchema(t *testing.T) { } `json:"c"` } - testFunc2 := func(input TestInput2) int { return input.C.D * 2 } - service.RegisterFunc(Function{ + testFunc2 := func(input TestInput2, ctx ContextInput) int { return input.C.D * 2 } + _, err = service.RegisterFunc(Function{ Func: testFunc2, Name: "TestFunc2", }) + assert.NoError(t, err) + schema, err := service.getSchema() require.NoError(t, err) assert.Equal(t, "TestFunc", schema["TestFunc"].(map[string]interface{})["name"]) diff --git a/sdk-go/service.go b/sdk-go/service.go index 0c9cd8f..ed54b5e 100644 --- a/sdk-go/service.go +++ b/sdk-go/service.go @@ -28,6 +28,12 @@ type Function struct { Func interface{} } +type ContextInput struct { + AuthContext interface{} `json:"authContext,omitempty"` + RunContext interface{} `json:"runContext,omitempty"` + approved bool `json:"approved"` +} + type service struct { Name string Functions map[string]Function @@ -38,9 +44,12 @@ type service struct { } type callMessage struct { - Id string `json:"id"` - Function string `json:"function"` - Input interface{} `json:"input"` + Id string `json:"id"` + Function string `json:"function"` + Input interface{} `json:"input"` + AuthContext interface{} `json:"authContext,omitempty"` + RunContext interface{} `json:"runContext,omitempty"` + Approved bool `json:"approved"` } type callResultMeta struct { @@ -101,30 +110,35 @@ func (s *service) RegisterFunc(fn Function) (*FunctionReference, error) { // Validate that the function has exactly one argument and it's a struct fnType := reflect.TypeOf(fn.Func) - if fnType.NumIn() != 1 { - return nil, fmt.Errorf("function '%s' must have exactly one argument", fn.Name) + if fnType.NumIn() != 2 { + return nil, fmt.Errorf("function '%s' must have exactly two arguments", fn.Name) + } + arg1Type := fnType.In(0) + arg2Type := fnType.In(1) + + if arg2Type.Kind() != reflect.Struct { + return nil, fmt.Errorf("function '%s' second argument must be a struct (ContextInput)", fn.Name) } - argType := fnType.In(0) // Set the argument type to the referenced type - if argType.Kind() == reflect.Ptr { - argType = argType.Elem() + if arg1Type.Kind() == reflect.Ptr { + arg1Type = arg1Type.Elem() } - if argType.Kind() != reflect.Struct { + if arg1Type.Kind() != reflect.Struct { return nil, fmt.Errorf("function '%s' first argument must be a struct or a pointer to a struct", fn.Name) } // Get the schema for the input struct reflector := jsonschema.Reflector{DoNotReference: true, Anonymous: true, AllowAdditionalProperties: false} - schema := reflector.Reflect(reflect.New(argType).Interface()) + schema := reflector.Reflect(reflect.New(arg1Type).Interface()) if schema == nil { return nil, fmt.Errorf("failed to get schema for function '%s'", fn.Name) } // Extract the relevant part of the schema - defs, ok := schema.Definitions[argType.Name()] + defs, ok := schema.Definitions[arg1Type.Name()] // If the definition is not found, use the whole schema. // This tends to happen for inline structs. @@ -279,10 +293,16 @@ func (s *service) handleMessage(msg callMessage) error { return fmt.Errorf("failed to unmarshal input: %v", err) } + context := ContextInput{ + AuthContext: msg.AuthContext, + RunContext: msg.RunContext, + approved: msg.Approved, + } + start := time.Now() // Call the function with the unmarshaled argument fnValue := reflect.ValueOf(fn.Func) - returnValues := fnValue.Call([]reflect.Value{argPtr.Elem()}) + returnValues := fnValue.Call([]reflect.Value{argPtr.Elem(), reflect.ValueOf(context)}) resultType := "resolution" resultValue := returnValues[0].Interface()