Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Allow inline structs in function registration #30

Merged
merged 3 commits into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ jobs:
uses: actions/setup-go@v4
with:
go-version: "1.22"
- name: Check formatting
run: |
if [ "$(gofmt -l . | wc -l)" -gt 0 ]; then
echo "The following files are not formatted correctly:"
gofmt -l .
exit 1
fi
- name: Get dependencies
run: go mod download
- name: Build
Expand Down
2 changes: 1 addition & 1 deletion sdk-go/inferable.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (i *Inferable) RegisterService(serviceName string) (*service, error) {

func (i *Inferable) getRun(runID string) (*runResult, error) {
if i.clusterID == "" {
return nil, fmt.Errorf("Cluster ID must be provided to manage runs")
return nil, fmt.Errorf("cluster ID must be provided to manage runs")
}

// Prepare headers
Expand Down
6 changes: 3 additions & 3 deletions sdk-go/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ func (c *Client) FetchData(options FetchDataOptions) (string, http.Header, error

resp, err := c.httpClient.Do(req)
if err != nil {
if resp == nil {
return "", nil, fmt.Errorf("error making request: %v", err), -1
}
if resp == nil {
return "", nil, fmt.Errorf("error making request: %v", err), -1
}
return "", nil, fmt.Errorf("error making request: %v", err), resp.StatusCode
}
defer resp.Body.Close()
Expand Down
5 changes: 3 additions & 2 deletions sdk-go/internal/util/test_util.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package util

import (
"github.com/joho/godotenv"
"os"

"github.com/joho/godotenv"
)

func GetTestVars() (string, string, string, string) {
Expand All @@ -17,7 +18,7 @@ func GetTestVars() (string, string, string, string) {
apiEndpoint := os.Getenv("INFERABLE_TEST_API_ENDPOINT")

if apiEndpoint == "" {
panic("INFERABLE_TEST_API_ENDPOINT is not available")
panic("INFERABLE_TEST_API_ENDPOINT is not available")
}
if machineSecret == "" {
panic("INFERABLE_TEST_API_SECRET is not available")
Expand Down
2 changes: 1 addition & 1 deletion sdk-go/internal/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

const (
MachineIDFile = "inferable_machine_id.json"
MachineIDFile = "inferable_machine_id.json"
)

func GetMachineID() string {
Expand Down
147 changes: 77 additions & 70 deletions sdk-go/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ type EchoInput struct {
Input string
}


func echo(input EchoInput) string {
return input.Input
}
Expand Down Expand Up @@ -150,73 +149,81 @@ func TestInferableFunctions(t *testing.T) {

// This should match the example in the readme
func TestInferableE2E(t *testing.T) {
machineSecret, _, clusterID, apiEndpoint := util.GetTestVars()

client, err := New(InferableOptions{
APIEndpoint: apiEndpoint,
APISecret: machineSecret,
ClusterID: clusterID,
})

if err != nil {
t.Fatalf("Error creating Inferable instance: %v", err)
}

didCallSayHello := false
didCallResultHandler := false

sayHello, err := client.Default.RegisterFunc(Function{
Func: func(input EchoInput) string {
didCallSayHello = true
return "Hello " + input.Input
},
Name: "SayHello",
Description: "A simple greeting function",
})

resultHandler, err := client.Default.RegisterFunc(Function{
Func: func(input OnStatusChangeInput) string {
didCallResultHandler = true
fmt.Println("OnStatusChange: ", input)
return ""
},
Name: "ResultHandler",
})

client.Default.Start()

run, err := client.CreateRun(CreateRunInput{
Message: "Say hello to John Smith",
AttachedFunctions: []*FunctionReference{
sayHello,
},
OnStatusChange: &OnStatusChange{
Function: resultHandler,
},
})

if err != nil {
panic(err)
}

fmt.Println("Run started: ", run.ID)
result, err := run.Poll(nil)
if err != nil {
panic(err)
}
fmt.Println("Run Result: ", result)

time.Sleep(1000 * time.Millisecond)

if result == nil {
t.Error("Result is nil")
}

if !didCallSayHello {
t.Error("SayHello function was not called")
}

if !didCallResultHandler {
t.Error("OnStatusChange function was not called")
}
machineSecret, _, clusterID, apiEndpoint := util.GetTestVars()

client, err := New(InferableOptions{
APIEndpoint: apiEndpoint,
APISecret: machineSecret,
ClusterID: clusterID,
})

if err != nil {
t.Fatalf("Error creating Inferable instance: %v", err)
}

didCallSayHello := false
didCallResultHandler := false

sayHello, err := client.Default.RegisterFunc(Function{
Func: func(input EchoInput) string {
didCallSayHello = true
return "Hello " + input.Input
},
Name: "SayHello",
Description: "A simple greeting function",
})

if err != nil {
t.Fatalf("Error registering SayHello function: %v", err)
}

resultHandler, err := client.Default.RegisterFunc(Function{
Func: func(input OnStatusChangeInput) string {
didCallResultHandler = true
fmt.Println("OnStatusChange: ", input)
return ""
},
Name: "ResultHandler",
})

if err != nil {
t.Fatalf("Error registering ResultHandler function: %v", err)
}

client.Default.Start()

run, err := client.CreateRun(CreateRunInput{
Message: "Say hello to John Smith",
AttachedFunctions: []*FunctionReference{
sayHello,
},
OnStatusChange: &OnStatusChange{
Function: resultHandler,
},
})

if err != nil {
panic(err)
}

fmt.Println("Run started: ", run.ID)
result, err := run.Poll(nil)
if err != nil {
panic(err)
}
fmt.Println("Run Result: ", result)

time.Sleep(1000 * time.Millisecond)

if result == nil {
t.Error("Result is nil")
}

if !didCallSayHello {
t.Error("SayHello function was not called")
}

if !didCallResultHandler {
t.Error("OnStatusChange function was not called")
}
}
44 changes: 24 additions & 20 deletions sdk-go/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,31 +69,31 @@ type FunctionReference struct {
//
// Example:
//
// // Create a new Inferable instance with an API secret
// client := inferable.New(InferableOptions{
// ApiSecret: "API_SECRET",
// })
// // Create a new Inferable instance with an API secret
// client := inferable.New(InferableOptions{
// ApiSecret: "API_SECRET",
// })
//
// // Define and register the service
// service := client.Service("MyService")
// // Define and register the service
// service := client.Service("MyService")
//
// sayHello, err := service.RegisterFunc(Function{
// Func: func(input EchoInput) string {
// didCallSayHello = true
// return "Hello " + input.Input
// },
// Name: "SayHello",
// Description: "A simple greeting function",
// })
// sayHello, err := service.RegisterFunc(Function{
// Func: func(input EchoInput) string {
// didCallSayHello = true
// return "Hello " + input.Input
// },
// Name: "SayHello",
// Description: "A simple greeting function",
// })
//
// // Start the service
// service.Start()
// // Start the service
// service.Start()
//
// // Stop the service on shutdown
// defer service.Stop()
// // Stop the service on shutdown
// defer service.Stop()
func (s *service) RegisterFunc(fn Function) (*FunctionReference, error) {
if s.isPolling() {
return nil, fmt.Errorf("functions must be registered before starting the service.")
return nil, fmt.Errorf("functions must be registered before starting the service")
}

if _, exists := s.Functions[fn.Name]; exists {
Expand All @@ -120,8 +120,12 @@ func (s *service) RegisterFunc(fn Function) (*FunctionReference, error) {

// Extract the relevant part of the schema
defs, ok := schema.Definitions[argType.Name()]

// If the definition is not found, use the whole schema.
// This tends to happen for inline structs.
// For example: func(input struct { A int `json:"a"` }) int
if !ok {
return nil, fmt.Errorf("failed to find schema definition for %s", argType.Name())
defs = schema
}

defsString, err := json.Marshal(defs)
Expand Down
38 changes: 38 additions & 0 deletions sdk-go/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,44 @@ func TestRegisterFunc(t *testing.T) {
assert.Error(t, err)
}

func TestRegisterFuncWithInlineStruct(t *testing.T) {
_, _, _, apiEndpoint := util.GetTestVars()

i, _ := New(InferableOptions{
APIEndpoint: apiEndpoint,
APISecret: "test-secret",
})
service, _ := i.RegisterService("TestService1")

testFunc := func(input struct {
A int `json:"a"`
B int `json:"b"`
}) int {
return input.A + input.B
}
_, err := service.RegisterFunc(Function{
Func: testFunc,
Name: "TestFunc",
Description: "Test function",
})
require.NoError(t, err)

// Try to register the same function again
_, err = service.RegisterFunc(Function{
Func: testFunc,
Name: "TestFunc",
})
assert.Error(t, err)

// Try to register a function with invalid input
invalidFunc := func(a, b int) int { return a + b }
_, err = service.RegisterFunc(Function{
Func: invalidFunc,
Name: "InvalidFunc",
})
assert.Error(t, err)
}

func TestRegistrationAndConfig(t *testing.T) {
machineSecret, _, _, apiEndpoint := util.GetTestVars()

Expand Down
36 changes: 0 additions & 36 deletions workflows/.github/workflows/build.yml

This file was deleted.

Loading
Loading