Skip to content

Commit

Permalink
fix: Allow inline structs in function registration (#30)
Browse files Browse the repository at this point in the history
* fix: Enhance error handling in function registration

* fix: Add gofmt check

* fix: Fix the gofmt errors
  • Loading branch information
nadeesha authored Nov 2, 2024
1 parent a495ffe commit 8c60d3c
Show file tree
Hide file tree
Showing 16 changed files with 154 additions and 479 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ jobs:
uses: actions/setup-go@v4
with:
go-version: "1.22"
- name: Check formatting
run: |
if [ "$(gofmt -l . | wc -l)" -gt 0 ]; then
echo "The following files are not formatted correctly:"
gofmt -l .
exit 1
fi
- name: Get dependencies
run: go mod download
- name: Build
Expand Down
2 changes: 1 addition & 1 deletion sdk-go/inferable.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (i *Inferable) RegisterService(serviceName string) (*service, error) {

func (i *Inferable) getRun(runID string) (*runResult, error) {
if i.clusterID == "" {
return nil, fmt.Errorf("Cluster ID must be provided to manage runs")
return nil, fmt.Errorf("cluster ID must be provided to manage runs")
}

// Prepare headers
Expand Down
6 changes: 3 additions & 3 deletions sdk-go/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ func (c *Client) FetchData(options FetchDataOptions) (string, http.Header, error

resp, err := c.httpClient.Do(req)
if err != nil {
if resp == nil {
return "", nil, fmt.Errorf("error making request: %v", err), -1
}
if resp == nil {
return "", nil, fmt.Errorf("error making request: %v", err), -1
}
return "", nil, fmt.Errorf("error making request: %v", err), resp.StatusCode
}
defer resp.Body.Close()
Expand Down
5 changes: 3 additions & 2 deletions sdk-go/internal/util/test_util.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package util

import (
"github.com/joho/godotenv"
"os"

"github.com/joho/godotenv"
)

func GetTestVars() (string, string, string, string) {
Expand All @@ -17,7 +18,7 @@ func GetTestVars() (string, string, string, string) {
apiEndpoint := os.Getenv("INFERABLE_TEST_API_ENDPOINT")

if apiEndpoint == "" {
panic("INFERABLE_TEST_API_ENDPOINT is not available")
panic("INFERABLE_TEST_API_ENDPOINT is not available")
}
if machineSecret == "" {
panic("INFERABLE_TEST_API_SECRET is not available")
Expand Down
2 changes: 1 addition & 1 deletion sdk-go/internal/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

const (
MachineIDFile = "inferable_machine_id.json"
MachineIDFile = "inferable_machine_id.json"
)

func GetMachineID() string {
Expand Down
147 changes: 77 additions & 70 deletions sdk-go/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ type EchoInput struct {
Input string
}


func echo(input EchoInput) string {
return input.Input
}
Expand Down Expand Up @@ -150,73 +149,81 @@ func TestInferableFunctions(t *testing.T) {

// This should match the example in the readme
func TestInferableE2E(t *testing.T) {
machineSecret, _, clusterID, apiEndpoint := util.GetTestVars()

client, err := New(InferableOptions{
APIEndpoint: apiEndpoint,
APISecret: machineSecret,
ClusterID: clusterID,
})

if err != nil {
t.Fatalf("Error creating Inferable instance: %v", err)
}

didCallSayHello := false
didCallResultHandler := false

sayHello, err := client.Default.RegisterFunc(Function{
Func: func(input EchoInput) string {
didCallSayHello = true
return "Hello " + input.Input
},
Name: "SayHello",
Description: "A simple greeting function",
})

resultHandler, err := client.Default.RegisterFunc(Function{
Func: func(input OnStatusChangeInput) string {
didCallResultHandler = true
fmt.Println("OnStatusChange: ", input)
return ""
},
Name: "ResultHandler",
})

client.Default.Start()

run, err := client.CreateRun(CreateRunInput{
Message: "Say hello to John Smith",
AttachedFunctions: []*FunctionReference{
sayHello,
},
OnStatusChange: &OnStatusChange{
Function: resultHandler,
},
})

if err != nil {
panic(err)
}

fmt.Println("Run started: ", run.ID)
result, err := run.Poll(nil)
if err != nil {
panic(err)
}
fmt.Println("Run Result: ", result)

time.Sleep(1000 * time.Millisecond)

if result == nil {
t.Error("Result is nil")
}

if !didCallSayHello {
t.Error("SayHello function was not called")
}

if !didCallResultHandler {
t.Error("OnStatusChange function was not called")
}
machineSecret, _, clusterID, apiEndpoint := util.GetTestVars()

client, err := New(InferableOptions{
APIEndpoint: apiEndpoint,
APISecret: machineSecret,
ClusterID: clusterID,
})

if err != nil {
t.Fatalf("Error creating Inferable instance: %v", err)
}

didCallSayHello := false
didCallResultHandler := false

sayHello, err := client.Default.RegisterFunc(Function{
Func: func(input EchoInput) string {
didCallSayHello = true
return "Hello " + input.Input
},
Name: "SayHello",
Description: "A simple greeting function",
})

if err != nil {
t.Fatalf("Error registering SayHello function: %v", err)
}

resultHandler, err := client.Default.RegisterFunc(Function{
Func: func(input OnStatusChangeInput) string {
didCallResultHandler = true
fmt.Println("OnStatusChange: ", input)
return ""
},
Name: "ResultHandler",
})

if err != nil {
t.Fatalf("Error registering ResultHandler function: %v", err)
}

client.Default.Start()

run, err := client.CreateRun(CreateRunInput{
Message: "Say hello to John Smith",
AttachedFunctions: []*FunctionReference{
sayHello,
},
OnStatusChange: &OnStatusChange{
Function: resultHandler,
},
})

if err != nil {
panic(err)
}

fmt.Println("Run started: ", run.ID)
result, err := run.Poll(nil)
if err != nil {
panic(err)
}
fmt.Println("Run Result: ", result)

time.Sleep(1000 * time.Millisecond)

if result == nil {
t.Error("Result is nil")
}

if !didCallSayHello {
t.Error("SayHello function was not called")
}

if !didCallResultHandler {
t.Error("OnStatusChange function was not called")
}
}
44 changes: 24 additions & 20 deletions sdk-go/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,31 +69,31 @@ type FunctionReference struct {
//
// Example:
//
// // Create a new Inferable instance with an API secret
// client := inferable.New(InferableOptions{
// ApiSecret: "API_SECRET",
// })
// // Create a new Inferable instance with an API secret
// client := inferable.New(InferableOptions{
// ApiSecret: "API_SECRET",
// })
//
// // Define and register the service
// service := client.Service("MyService")
// // Define and register the service
// service := client.Service("MyService")
//
// sayHello, err := service.RegisterFunc(Function{
// Func: func(input EchoInput) string {
// didCallSayHello = true
// return "Hello " + input.Input
// },
// Name: "SayHello",
// Description: "A simple greeting function",
// })
// sayHello, err := service.RegisterFunc(Function{
// Func: func(input EchoInput) string {
// didCallSayHello = true
// return "Hello " + input.Input
// },
// Name: "SayHello",
// Description: "A simple greeting function",
// })
//
// // Start the service
// service.Start()
// // Start the service
// service.Start()
//
// // Stop the service on shutdown
// defer service.Stop()
// // Stop the service on shutdown
// defer service.Stop()
func (s *service) RegisterFunc(fn Function) (*FunctionReference, error) {
if s.isPolling() {
return nil, fmt.Errorf("functions must be registered before starting the service.")
return nil, fmt.Errorf("functions must be registered before starting the service")
}

if _, exists := s.Functions[fn.Name]; exists {
Expand All @@ -120,8 +120,12 @@ func (s *service) RegisterFunc(fn Function) (*FunctionReference, error) {

// Extract the relevant part of the schema
defs, ok := schema.Definitions[argType.Name()]

// If the definition is not found, use the whole schema.
// This tends to happen for inline structs.
// For example: func(input struct { A int `json:"a"` }) int
if !ok {
return nil, fmt.Errorf("failed to find schema definition for %s", argType.Name())
defs = schema
}

defsString, err := json.Marshal(defs)
Expand Down
38 changes: 38 additions & 0 deletions sdk-go/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,44 @@ func TestRegisterFunc(t *testing.T) {
assert.Error(t, err)
}

func TestRegisterFuncWithInlineStruct(t *testing.T) {
_, _, _, apiEndpoint := util.GetTestVars()

i, _ := New(InferableOptions{
APIEndpoint: apiEndpoint,
APISecret: "test-secret",
})
service, _ := i.RegisterService("TestService1")

testFunc := func(input struct {
A int `json:"a"`
B int `json:"b"`
}) int {
return input.A + input.B
}
_, err := service.RegisterFunc(Function{
Func: testFunc,
Name: "TestFunc",
Description: "Test function",
})
require.NoError(t, err)

// Try to register the same function again
_, err = service.RegisterFunc(Function{
Func: testFunc,
Name: "TestFunc",
})
assert.Error(t, err)

// Try to register a function with invalid input
invalidFunc := func(a, b int) int { return a + b }
_, err = service.RegisterFunc(Function{
Func: invalidFunc,
Name: "InvalidFunc",
})
assert.Error(t, err)
}

func TestRegistrationAndConfig(t *testing.T) {
machineSecret, _, _, apiEndpoint := util.GetTestVars()

Expand Down
36 changes: 0 additions & 36 deletions workflows/.github/workflows/build.yml

This file was deleted.

Loading

0 comments on commit 8c60d3c

Please sign in to comment.