Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Include paused state in transient states for polling #28

Merged
merged 3 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk-dotnet/src/Inferable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ async public Task<RunReference> CreateRunAsync(CreateRunInput input)
while (DateTime.Now < end) {
var pollResult = await this._client.GetRun(this._clusterId, result.ID);

var transientStates = new List<string> { "pending", "running" };
var transientStates = new List<string> { "pending", "running", "paused" };
if (transientStates.Contains(pollResult.Status)) {
await Task.Delay(Interval);
continue;
Expand Down
2 changes: 1 addition & 1 deletion sdk-dotnet/tests/Inferable.Tests/InferableTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

Assert.NotNull(inferable);

Assert.NotNull(inferable.Default);

Check warning on line 60 in sdk-dotnet/tests/Inferable.Tests/InferableTest.cs

View workflow job for this annotation

GitHub Actions / build-dotnet

Do not use Assert.NotNull() on value type 'RegisteredService'. Remove this assert. (https://xunit.net/xunit.analyzers/rules/xUnit2002)

Check warning on line 60 in sdk-dotnet/tests/Inferable.Tests/InferableTest.cs

View workflow job for this annotation

GitHub Actions / build-dotnet

Do not use Assert.NotNull() on value type 'RegisteredService'. Remove this assert. (https://xunit.net/xunit.analyzers/rules/xUnit2002)

Check warning on line 60 in sdk-dotnet/tests/Inferable.Tests/InferableTest.cs

View workflow job for this annotation

GitHub Actions / test-dotnet

Do not use Assert.NotNull() on value type 'RegisteredService'. Remove this assert. (https://xunit.net/xunit.analyzers/rules/xUnit2002)
}

[Fact]
Expand All @@ -67,7 +67,7 @@

var service = inferable.RegisterService("test");

Assert.NotNull(service);

Check warning on line 70 in sdk-dotnet/tests/Inferable.Tests/InferableTest.cs

View workflow job for this annotation

GitHub Actions / build-dotnet

Do not use Assert.NotNull() on value type 'RegisteredService'. Remove this assert. (https://xunit.net/xunit.analyzers/rules/xUnit2002)

Check warning on line 70 in sdk-dotnet/tests/Inferable.Tests/InferableTest.cs

View workflow job for this annotation

GitHub Actions / build-dotnet

Do not use Assert.NotNull() on value type 'RegisteredService'. Remove this assert. (https://xunit.net/xunit.analyzers/rules/xUnit2002)

Check warning on line 70 in sdk-dotnet/tests/Inferable.Tests/InferableTest.cs

View workflow job for this annotation

GitHub Actions / test-dotnet

Do not use Assert.NotNull() on value type 'RegisteredService'. Remove this assert. (https://xunit.net/xunit.analyzers/rules/xUnit2002)
}

[Fact]
Expand Down Expand Up @@ -278,7 +278,7 @@

var result = await run.PollAsync(null);

await Task.Delay(500);
await Task.Delay(5000);

Assert.NotNull(result);
Assert.True(didCallSayHello);
Expand Down
250 changes: 125 additions & 125 deletions sdk-go/inferable.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,33 +30,33 @@ type Inferable struct {
functionRegistry functionRegistry
machineID string
clusterID string
// Convenience reference to a service with the name 'default'.
//
// Returns:
// A registered service instance.
//
// Example:
//
// // Create a new Inferable instance with an API secret
// client := inferable.New(InferableOptions{
// ApiSecret: "API_SECRET",
// })
//
// client.Default.RegisterFunc(Function{
// Func: func(input EchoInput) string {
// didCallSayHello = true
// return "Hello " + input.Input
// },
// Name: "SayHello",
// Description: "A simple greeting function",
// })
//
// // Start the service
// client.Default.Start()
//
// // Stop the service on shutdown
// defer client.Default.Stop()
Default *service
// Convenience reference to a service with the name 'default'.
//
// Returns:
// A registered service instance.
//
// Example:
//
// // Create a new Inferable instance with an API secret
// client := inferable.New(InferableOptions{
// ApiSecret: "API_SECRET",
// })
//
// client.Default.RegisterFunc(Function{
// Func: func(input EchoInput) string {
// didCallSayHello = true
// return "Hello " + input.Input
// },
// Name: "SayHello",
// Description: "A simple greeting function",
// })
//
// // Start the service
// client.Default.Start()
//
// // Stop the service on shutdown
// defer client.Default.Stop()
Default *service
}

type InferableOptions struct {
Expand All @@ -78,7 +78,7 @@ type OnStatusChangeInput struct {
type runResult = OnStatusChangeInput

type RunTemplate struct {
ID string `json:"id"`
ID string `json:"id"`
Input map[string]interface{} `json:"input"`
}

Expand Down Expand Up @@ -150,28 +150,28 @@ func New(options InferableOptions) (*Inferable, error) {
//
// Example:
//
// // Create a new Inferable instance with an API secret
// client := inferable.New(InferableOptions{
// ApiSecret: "API_SECRET",
// })
// // Create a new Inferable instance with an API secret
// client := inferable.New(InferableOptions{
// ApiSecret: "API_SECRET",
// })
//
// // Define and register the service
// service := client.Service("MyService")
// // Define and register the service
// service := client.Service("MyService")
//
// sayHello, err := service.RegisterFunc(Function{
// Func: func(input EchoInput) string {
// didCallSayHello = true
// return "Hello " + input.Input
// },
// Name: "SayHello",
// Description: "A simple greeting function",
// })
// sayHello, err := service.RegisterFunc(Function{
// Func: func(input EchoInput) string {
// didCallSayHello = true
// return "Hello " + input.Input
// },
// Name: "SayHello",
// Description: "A simple greeting function",
// })
//
// // Start the service
// service.Start()
// // Start the service
// service.Start()
//
// // Stop the service on shutdown
// defer service.Stop()
// // Stop the service on shutdown
// defer service.Stop()
func (i *Inferable) RegisterService(serviceName string) (*service, error) {
if _, exists := i.functionRegistry.services[serviceName]; exists {
return nil, fmt.Errorf("service with name '%s' already registered", serviceName)
Expand All @@ -186,34 +186,34 @@ func (i *Inferable) RegisterService(serviceName string) (*service, error) {
}

func (i *Inferable) getRun(runID string) (*runResult, error) {
if i.clusterID == "" {
return nil, fmt.Errorf("Cluster ID must be provided to manage runs")
}

// Prepare headers
headers := map[string]string{
"Authorization": "Bearer " + i.apiSecret,
"X-Machine-ID": i.machineID,
"X-Machine-SDK-Version": Version,
"X-Machine-SDK-Language": "go",
}

options := client.FetchDataOptions{
Path: fmt.Sprintf("/clusters/%s/runs/%s", i.clusterID, runID),
Method: "GET",
Headers: headers,
}

responseData, _, err, _ := i.fetchData(options)
if err != nil {
return nil, fmt.Errorf("failed to get run: %v", err)
}
var result runResult
err = json.Unmarshal(responseData, &result)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %v", err)
}
return &result, nil
if i.clusterID == "" {
return nil, fmt.Errorf("Cluster ID must be provided to manage runs")
}

// Prepare headers
headers := map[string]string{
"Authorization": "Bearer " + i.apiSecret,
"X-Machine-ID": i.machineID,
"X-Machine-SDK-Version": Version,
"X-Machine-SDK-Language": "go",
}

options := client.FetchDataOptions{
Path: fmt.Sprintf("/clusters/%s/runs/%s", i.clusterID, runID),
Method: "GET",
Headers: headers,
}

responseData, _, err, _ := i.fetchData(options)
if err != nil {
return nil, fmt.Errorf("failed to get run: %v", err)
}
var result runResult
err = json.Unmarshal(responseData, &result)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %v", err)
}
return &result, nil
}

// Creates a run and returns a reference to it.
Expand All @@ -226,27 +226,27 @@ func (i *Inferable) getRun(runID string) (*runResult, error) {
//
// Example:
//
// // Create a new Inferable instance with an API secret
// client := inferable.New(InferableOptions{
// ApiSecret: "API_SECRET",
// })
// // Create a new Inferable instance with an API secret
// client := inferable.New(InferableOptions{
// ApiSecret: "API_SECRET",
// })
//
// run, err := client.Run(CreateRunInput{
// Message: "Hello world",
// })
// run, err := client.Run(CreateRunInput{
// Message: "Hello world",
// })
//
// if err != nil {
// log.Fatal("Failed to create run:", err)
// }
// if err != nil {
// log.Fatal("Failed to create run:", err)
// }
//
// fmt.Println("Started run with ID:", run.ID)
// fmt.Println("Started run with ID:", run.ID)
//
// result, err := run.Poll()
// if err != nil {
// log.Fatal("Failed to poll run result:", err)
// }
// result, err := run.Poll()
// if err != nil {
// log.Fatal("Failed to poll run result:", err)
// }
//
// fmt.Println("Run result:", result)
// fmt.Println("Run result:", result)
func (i *Inferable) CreateRun(input CreateRunInput) (*runReference, error) {
if i.clusterID == "" {
return nil, fmt.Errorf("cluster ID must be provided to manage runs")
Expand Down Expand Up @@ -290,41 +290,41 @@ func (i *Inferable) CreateRun(input CreateRunInput) (*runReference, error) {
}

return &runReference{
ID: response.ID,
Poll: func(options *PollOptions) (*runResult, error) {
// Default values for polling options
maxWaitTime := 60 * time.Second
interval := 500 * time.Millisecond

if options != nil {
if options.MaxWaitTime != nil {
maxWaitTime = *options.MaxWaitTime
}

if options.Interval != nil {
interval = *options.Interval
}
}

start := time.Now()
end := start.Add(maxWaitTime)

for time.Now().Before(end) {
pollResult, err := i.getRun(response.ID)
if err != nil {
return nil, fmt.Errorf("failed to poll for run: %w", err)
}

if pollResult.Status != "pending" && pollResult.Status != "running" {
return pollResult, nil
}

time.Sleep(interval)
}

return nil, fmt.Errorf("max wait time reached, polling stopped")
},
}, nil
ID: response.ID,
Poll: func(options *PollOptions) (*runResult, error) {
// Default values for polling options
maxWaitTime := 60 * time.Second
interval := 500 * time.Millisecond

if options != nil {
if options.MaxWaitTime != nil {
maxWaitTime = *options.MaxWaitTime
}

if options.Interval != nil {
interval = *options.Interval
}
}

start := time.Now()
end := start.Add(maxWaitTime)

for time.Now().Before(end) {
pollResult, err := i.getRun(response.ID)
if err != nil {
return nil, fmt.Errorf("failed to poll for run: %w", err)
}

if pollResult.Status != "pending" && pollResult.Status != "running" && pollResult.Status != "paused" {
return pollResult, nil
}

time.Sleep(interval)
}

return nil, fmt.Errorf("max wait time reached, polling stopped")
},
}, nil
}

func (i *Inferable) callFunc(serviceName, funcName string, args ...interface{}) ([]reflect.Value, error) {
Expand Down
10 changes: 4 additions & 6 deletions sdk-node/src/Inferable.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ describe("Inferable", () => {
});

it("should initialize without optional args", () => {
expect(
() => new Inferable({ apiSecret: TEST_API_SECRET }),
).not.toThrow();
expect(() => new Inferable({ apiSecret: TEST_API_SECRET })).not.toThrow();
});

it("should initialize with API secret in environment", () => {
Expand Down Expand Up @@ -243,7 +241,7 @@ describe("Functions", () => {
// This should match the example in the readme
describe("Inferable SDK End to End Test", () => {
it("should trigger a run, call a function, and call a status change function", async () => {
const client = inferableInstance();
const client = inferableInstance();

let didCallSayHello = false;
let didCallOnStatusChange = false;
Expand Down Expand Up @@ -279,15 +277,15 @@ describe("Inferable SDK End to End Test", () => {
attachedFunctions: [sayHello],
// Optional: Define a schema for the result to conform to
resultSchema: z.object({
didSayHello: z.boolean()
didSayHello: z.boolean(),
}),
// Optional: Subscribe an Inferable function to receive notifications when the run status changes
onStatusChange: { function: onStatusChange },
});

const result = await run.poll();

await new Promise((resolve) => setTimeout(resolve, 500));
await new Promise((resolve) => setTimeout(resolve, 5000));

expect(result).not.toBeNull();
expect(didCallSayHello).toBe(true);
Expand Down
Loading
Loading