Skip to content

Commit

Permalink
feat: Support Context object in go-sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
johnjcsmith committed Dec 19, 2024
1 parent b2ee27b commit e984fe7
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 21 deletions.
2 changes: 1 addition & 1 deletion sdk-go/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ type UserInput struct {
Tags []string `json:"tags" jsonschema:"uniqueItems=true"`
}

func createUser(input UserInput) string {
func createUser(input UserInput, ctx inferable.ContextInput) string {
// Function implementation
}

Expand Down
24 changes: 16 additions & 8 deletions sdk-go/inferable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,14 @@ func TestCallFunc(t *testing.T) {
B int `json:"b"`
}

testFunc := func(input TestInput) int { return input.A + input.B }
i.Default.RegisterFunc(Function{
testFunc := func(input TestInput, ctx ContextInput) int { return input.A + input.B }
_, err := i.Default.RegisterFunc(Function{
Func: testFunc,
Name: "TestFunc",
})

assert.NoError(t, err)

result, err := i.callFunc("default", "TestFunc", TestInput{A: 2, B: 3})
require.NoError(t, err)
assert.Equal(t, 5, result[0].Interface())
Expand All @@ -85,13 +87,15 @@ func TestToJSONDefinition(t *testing.T) {
B int `json:"b"`
}

testFunc := func(input TestInput) int { return input.A + input.B }
service.RegisterFunc(Function{
testFunc := func(input TestInput, ctx ContextInput) int { return input.A + input.B }
_, err := service.RegisterFunc(Function{
Func: testFunc,
Name: "TestFunc",
Description: "Test function",
})

assert.NoError(t, err)

jsonDef, err := i.toJSONDefinition()
require.NoError(t, err)

Expand Down Expand Up @@ -166,25 +170,29 @@ func TestGetSchema(t *testing.T) {
B int `json:"b"`
}

testFunc := func(input TestInput) int { return input.A + input.B }
service.RegisterFunc(Function{
testFunc := func(input TestInput, ctx ContextInput) int { return input.A + input.B }
_, err := service.RegisterFunc(Function{
Func: testFunc,
Name: "TestFunc",
})

assert.NoError(t, err)

type TestInput2 struct {
C struct {
D int `json:"d"`
E []int `json:"e"`
} `json:"c"`
}

testFunc2 := func(input TestInput2) int { return input.C.D * 2 }
service.RegisterFunc(Function{
testFunc2 := func(input TestInput2, ctx ContextInput) int { return input.C.D * 2 }
_, err = service.RegisterFunc(Function{
Func: testFunc2,
Name: "TestFunc2",
})

assert.NoError(t, err)

schema, err := service.getSchema()
require.NoError(t, err)
assert.Equal(t, "TestFunc", schema["TestFunc"].(map[string]interface{})["name"])
Expand Down
44 changes: 32 additions & 12 deletions sdk-go/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ type Function struct {
Func interface{}
}

type ContextInput struct {
AuthContext interface{} `json:"authContext,omitempty"`
RunContext interface{} `json:"runContext,omitempty"`
approved bool `json:"approved"`
}

type service struct {
Name string
Functions map[string]Function
Expand All @@ -38,9 +44,12 @@ type service struct {
}

type callMessage struct {
Id string `json:"id"`
Function string `json:"function"`
Input interface{} `json:"input"`
Id string `json:"id"`
Function string `json:"function"`
Input interface{} `json:"input"`
AuthContext interface{} `json:"authContext,omitempty"`
RunContext interface{} `json:"runContext,omitempty"`
Approved bool `json:"approved"`
}

type callResultMeta struct {
Expand Down Expand Up @@ -101,30 +110,35 @@ func (s *service) RegisterFunc(fn Function) (*FunctionReference, error) {

// Validate that the function has exactly one argument and it's a struct
fnType := reflect.TypeOf(fn.Func)
if fnType.NumIn() != 1 {
return nil, fmt.Errorf("function '%s' must have exactly one argument", fn.Name)
if fnType.NumIn() != 2 {
return nil, fmt.Errorf("function '%s' must have exactly two arguments", fn.Name)
}
arg1Type := fnType.In(0)
arg2Type := fnType.In(1)

if arg2Type.Kind() != reflect.Struct {
return nil, fmt.Errorf("function '%s' second argument must be a struct (ContextInput)", fn.Name)
}
argType := fnType.In(0)

// Set the argument type to the referenced type
if argType.Kind() == reflect.Ptr {
argType = argType.Elem()
if arg1Type.Kind() == reflect.Ptr {
arg1Type = arg1Type.Elem()
}

if argType.Kind() != reflect.Struct {
if arg1Type.Kind() != reflect.Struct {
return nil, fmt.Errorf("function '%s' first argument must be a struct or a pointer to a struct", fn.Name)
}

// Get the schema for the input struct
reflector := jsonschema.Reflector{DoNotReference: true, Anonymous: true, AllowAdditionalProperties: false}
schema := reflector.Reflect(reflect.New(argType).Interface())
schema := reflector.Reflect(reflect.New(arg1Type).Interface())

if schema == nil {
return nil, fmt.Errorf("failed to get schema for function '%s'", fn.Name)
}

// Extract the relevant part of the schema
defs, ok := schema.Definitions[argType.Name()]
defs, ok := schema.Definitions[arg1Type.Name()]

// If the definition is not found, use the whole schema.
// This tends to happen for inline structs.
Expand Down Expand Up @@ -279,10 +293,16 @@ func (s *service) handleMessage(msg callMessage) error {
return fmt.Errorf("failed to unmarshal input: %v", err)
}

context := ContextInput{
AuthContext: msg.AuthContext,
RunContext: msg.RunContext,
approved: msg.Approved,
}

start := time.Now()
// Call the function with the unmarshaled argument
fnValue := reflect.ValueOf(fn.Func)
returnValues := fnValue.Call([]reflect.Value{argPtr.Elem()})
returnValues := fnValue.Call([]reflect.Value{argPtr.Elem(), reflect.ValueOf(context)})

resultType := "resolution"
resultValue := returnValues[0].Interface()
Expand Down

0 comments on commit e984fe7

Please sign in to comment.