Skip to content

Commit

Permalink
feat: Support Context object in go-sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
johnjcsmith committed Dec 19, 2024
1 parent b2ee27b commit 08d0aad
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 31 deletions.
2 changes: 1 addition & 1 deletion sdk-go/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ type UserInput struct {
Tags []string `json:"tags" jsonschema:"uniqueItems=true"`
}

func createUser(input UserInput) string {
func createUser(input UserInput, ctx inferable.ContextInput) string {
// Function implementation
}

Expand Down
26 changes: 17 additions & 9 deletions sdk-go/inferable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,15 @@ func TestCallFunc(t *testing.T) {
B int `json:"b"`
}

testFunc := func(input TestInput) int { return input.A + input.B }
i.Default.RegisterFunc(Function{
testFunc := func(input TestInput, ctx ContextInput) int { return input.A + input.B }
_, err := i.Default.RegisterFunc(Function{
Func: testFunc,
Name: "TestFunc",
})

result, err := i.callFunc("default", "TestFunc", TestInput{A: 2, B: 3})
assert.NoError(t, err)

result, err := i.callFunc("default", "TestFunc", TestInput{A: 2, B: 3}, ContextInput{})
require.NoError(t, err)
assert.Equal(t, 5, result[0].Interface())

Expand All @@ -85,13 +87,15 @@ func TestToJSONDefinition(t *testing.T) {
B int `json:"b"`
}

testFunc := func(input TestInput) int { return input.A + input.B }
service.RegisterFunc(Function{
testFunc := func(input TestInput, ctx ContextInput) int { return input.A + input.B }
_, err := service.RegisterFunc(Function{
Func: testFunc,
Name: "TestFunc",
Description: "Test function",
})

assert.NoError(t, err)

jsonDef, err := i.toJSONDefinition()
require.NoError(t, err)

Expand Down Expand Up @@ -166,25 +170,29 @@ func TestGetSchema(t *testing.T) {
B int `json:"b"`
}

testFunc := func(input TestInput) int { return input.A + input.B }
service.RegisterFunc(Function{
testFunc := func(input TestInput, ctx ContextInput) int { return input.A + input.B }
_, err := service.RegisterFunc(Function{
Func: testFunc,
Name: "TestFunc",
})

assert.NoError(t, err)

type TestInput2 struct {
C struct {
D int `json:"d"`
E []int `json:"e"`
} `json:"c"`
}

testFunc2 := func(input TestInput2) int { return input.C.D * 2 }
service.RegisterFunc(Function{
testFunc2 := func(input TestInput2, ctx ContextInput) int { return input.C.D * 2 }
_, err = service.RegisterFunc(Function{
Func: testFunc2,
Name: "TestFunc2",
})

assert.NoError(t, err)

schema, err := service.getSchema()
require.NoError(t, err)
assert.Equal(t, "TestFunc", schema["TestFunc"].(map[string]interface{})["name"])
Expand Down
8 changes: 4 additions & 4 deletions sdk-go/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ type EchoInput struct {
Input string
}

func echo(input EchoInput) string {
func echo(input EchoInput, ctx ContextInput) string {
return input.Input
}

type ReverseInput struct {
Input string
}

func reverse(input ReverseInput) string {
func reverse(input ReverseInput, ctx ContextInput) string {
runes := []rune(input.Input)
for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 {
runes[i], runes[j] = runes[j], runes[i]
Expand Down Expand Up @@ -161,7 +161,7 @@ func TestInferableE2E(t *testing.T) {
didCallResultHandler := false

sayHello, err := client.Default.RegisterFunc(Function{
Func: func(input EchoInput) string {
Func: func(input EchoInput, ctx ContextInput) string {
didCallSayHello = true
return "Hello " + input.Input
},
Expand All @@ -174,7 +174,7 @@ func TestInferableE2E(t *testing.T) {
}

resultHandler, err := client.Default.RegisterFunc(Function{
Func: func(input OnStatusChangeInput) string {
Func: func(input OnStatusChangeInput, ctx ContextInput) string {
didCallResultHandler = true
fmt.Println("OnStatusChange: ", input)
return ""
Expand Down
44 changes: 32 additions & 12 deletions sdk-go/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ type Function struct {
Func interface{}
}

type ContextInput struct {
AuthContext interface{} `json:"authContext,omitempty"`
RunContext interface{} `json:"runContext,omitempty"`
approved bool `json:"approved"`
}

type service struct {
Name string
Functions map[string]Function
Expand All @@ -38,9 +44,12 @@ type service struct {
}

type callMessage struct {
Id string `json:"id"`
Function string `json:"function"`
Input interface{} `json:"input"`
Id string `json:"id"`
Function string `json:"function"`
Input interface{} `json:"input"`
AuthContext interface{} `json:"authContext,omitempty"`
RunContext interface{} `json:"runContext,omitempty"`
Approved bool `json:"approved"`
}

type callResultMeta struct {
Expand Down Expand Up @@ -101,30 +110,35 @@ func (s *service) RegisterFunc(fn Function) (*FunctionReference, error) {

// Validate that the function has exactly one argument and it's a struct
fnType := reflect.TypeOf(fn.Func)
if fnType.NumIn() != 1 {
return nil, fmt.Errorf("function '%s' must have exactly one argument", fn.Name)
if fnType.NumIn() != 2 {
return nil, fmt.Errorf("function '%s' must have exactly two arguments", fn.Name)
}
arg1Type := fnType.In(0)
arg2Type := fnType.In(1)

if arg2Type.Kind() != reflect.Struct {
return nil, fmt.Errorf("function '%s' second argument must be a struct (ContextInput)", fn.Name)
}
argType := fnType.In(0)

// Set the argument type to the referenced type
if argType.Kind() == reflect.Ptr {
argType = argType.Elem()
if arg1Type.Kind() == reflect.Ptr {
arg1Type = arg1Type.Elem()
}

if argType.Kind() != reflect.Struct {
if arg1Type.Kind() != reflect.Struct {
return nil, fmt.Errorf("function '%s' first argument must be a struct or a pointer to a struct", fn.Name)
}

// Get the schema for the input struct
reflector := jsonschema.Reflector{DoNotReference: true, Anonymous: true, AllowAdditionalProperties: false}
schema := reflector.Reflect(reflect.New(argType).Interface())
schema := reflector.Reflect(reflect.New(arg1Type).Interface())

if schema == nil {
return nil, fmt.Errorf("failed to get schema for function '%s'", fn.Name)
}

// Extract the relevant part of the schema
defs, ok := schema.Definitions[argType.Name()]
defs, ok := schema.Definitions[arg1Type.Name()]

// If the definition is not found, use the whole schema.
// This tends to happen for inline structs.
Expand Down Expand Up @@ -279,10 +293,16 @@ func (s *service) handleMessage(msg callMessage) error {
return fmt.Errorf("failed to unmarshal input: %v", err)
}

context := ContextInput{
AuthContext: msg.AuthContext,
RunContext: msg.RunContext,
approved: msg.Approved,
}

start := time.Now()
// Call the function with the unmarshaled argument
fnValue := reflect.ValueOf(fn.Func)
returnValues := fnValue.Call([]reflect.Value{argPtr.Elem()})
returnValues := fnValue.Call([]reflect.Value{argPtr.Elem(), reflect.ValueOf(context)})

resultType := "resolution"
resultValue := returnValues[0].Interface()
Expand Down
10 changes: 5 additions & 5 deletions sdk-go/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestRegisterFunc(t *testing.T) {
B int `json:"b"`
}

testFunc := func(input TestInput) int { return input.A + input.B }
testFunc := func(input TestInput, ctx ContextInput) int { return input.A + input.B }
_, err := service.RegisterFunc(Function{
Func: testFunc,
Name: "TestFunc",
Expand Down Expand Up @@ -64,7 +64,7 @@ func TestRegisterFuncWithInlineStruct(t *testing.T) {
testFunc := func(input struct {
A int `json:"a"`
B int `json:"b"`
}) int {
}, ctx ContextInput) int {
return input.A + input.B
}
_, err := service.RegisterFunc(Function{
Expand Down Expand Up @@ -118,7 +118,7 @@ func TestRegistrationAndConfig(t *testing.T) {
} `json:"c"`
}

testFunc := func(input TestInput) int { return input.A + input.B }
testFunc := func(input TestInput, ctx ContextInput) int { return input.A + input.B }

_, err = service.RegisterFunc(Function{
Func: testFunc,
Expand Down Expand Up @@ -155,7 +155,7 @@ func TestServiceStartAndReceiveMessage(t *testing.T) {
Message string `json:"message"`
}

testFunc := func(input TestInput) string { return "Received: " + input.Message }
testFunc := func(input TestInput, ctx ContextInput) string { return "Received: " + input.Message }

_, err = service.RegisterFunc(Function{
Func: testFunc,
Expand Down Expand Up @@ -231,7 +231,7 @@ func TestServiceStartAndReceiveFailingMessage(t *testing.T) {
}

// Purposfuly failing function
testFailingFunc := func(input TestInput) (*string, error) { return nil, fmt.Errorf("test error") }
testFailingFunc := func(input TestInput, ctx ContextInput) (*string, error) { return nil, fmt.Errorf("test error") }

_, err = service.RegisterFunc(Function{
Func: testFailingFunc,
Expand Down

0 comments on commit 08d0aad

Please sign in to comment.