diff --git a/sdk-dotnet/src/Inferable.cs b/sdk-dotnet/src/Inferable.cs index 01768437..6adc3ec7 100644 --- a/sdk-dotnet/src/Inferable.cs +++ b/sdk-dotnet/src/Inferable.cs @@ -219,7 +219,7 @@ async public Task CreateRunAsync(CreateRunInput input) while (DateTime.Now < end) { var pollResult = await this._client.GetRun(this._clusterId, result.ID); - var transientStates = new List { "pending", "running" }; + var transientStates = new List { "pending", "running", "paused" }; if (transientStates.Contains(pollResult.Status)) { await Task.Delay(Interval); continue; diff --git a/sdk-dotnet/tests/Inferable.Tests/InferableTest.cs b/sdk-dotnet/tests/Inferable.Tests/InferableTest.cs index 04e014a7..81cdae5e 100644 --- a/sdk-dotnet/tests/Inferable.Tests/InferableTest.cs +++ b/sdk-dotnet/tests/Inferable.Tests/InferableTest.cs @@ -278,7 +278,7 @@ async public void Inferable_Run_E2E() var result = await run.PollAsync(null); - await Task.Delay(500); + await Task.Delay(5000); Assert.NotNull(result); Assert.True(didCallSayHello); diff --git a/sdk-go/inferable.go b/sdk-go/inferable.go index 61149c31..c492a67f 100644 --- a/sdk-go/inferable.go +++ b/sdk-go/inferable.go @@ -30,33 +30,33 @@ type Inferable struct { functionRegistry functionRegistry machineID string clusterID string - // Convenience reference to a service with the name 'default'. - // - // Returns: - // A registered service instance. - // - // Example: - // - // // Create a new Inferable instance with an API secret - // client := inferable.New(InferableOptions{ - // ApiSecret: "API_SECRET", - // }) - // - // client.Default.RegisterFunc(Function{ - // Func: func(input EchoInput) string { - // didCallSayHello = true - // return "Hello " + input.Input - // }, - // Name: "SayHello", - // Description: "A simple greeting function", - // }) - // - // // Start the service - // client.Default.Start() - // - // // Stop the service on shutdown - // defer client.Default.Stop() - Default *service + // Convenience reference to a service with the name 'default'. + // + // Returns: + // A registered service instance. + // + // Example: + // + // // Create a new Inferable instance with an API secret + // client := inferable.New(InferableOptions{ + // ApiSecret: "API_SECRET", + // }) + // + // client.Default.RegisterFunc(Function{ + // Func: func(input EchoInput) string { + // didCallSayHello = true + // return "Hello " + input.Input + // }, + // Name: "SayHello", + // Description: "A simple greeting function", + // }) + // + // // Start the service + // client.Default.Start() + // + // // Stop the service on shutdown + // defer client.Default.Stop() + Default *service } type InferableOptions struct { @@ -78,7 +78,7 @@ type OnStatusChangeInput struct { type runResult = OnStatusChangeInput type RunTemplate struct { - ID string `json:"id"` + ID string `json:"id"` Input map[string]interface{} `json:"input"` } @@ -150,28 +150,28 @@ func New(options InferableOptions) (*Inferable, error) { // // Example: // -// // Create a new Inferable instance with an API secret -// client := inferable.New(InferableOptions{ -// ApiSecret: "API_SECRET", -// }) +// // Create a new Inferable instance with an API secret +// client := inferable.New(InferableOptions{ +// ApiSecret: "API_SECRET", +// }) // -// // Define and register the service -// service := client.Service("MyService") +// // Define and register the service +// service := client.Service("MyService") // -// sayHello, err := service.RegisterFunc(Function{ -// Func: func(input EchoInput) string { -// didCallSayHello = true -// return "Hello " + input.Input -// }, -// Name: "SayHello", -// Description: "A simple greeting function", -// }) +// sayHello, err := service.RegisterFunc(Function{ +// Func: func(input EchoInput) string { +// didCallSayHello = true +// return "Hello " + input.Input +// }, +// Name: "SayHello", +// Description: "A simple greeting function", +// }) // -// // Start the service -// service.Start() +// // Start the service +// service.Start() // -// // Stop the service on shutdown -// defer service.Stop() +// // Stop the service on shutdown +// defer service.Stop() func (i *Inferable) RegisterService(serviceName string) (*service, error) { if _, exists := i.functionRegistry.services[serviceName]; exists { return nil, fmt.Errorf("service with name '%s' already registered", serviceName) @@ -186,34 +186,34 @@ func (i *Inferable) RegisterService(serviceName string) (*service, error) { } func (i *Inferable) getRun(runID string) (*runResult, error) { - if i.clusterID == "" { - return nil, fmt.Errorf("Cluster ID must be provided to manage runs") - } - - // Prepare headers - headers := map[string]string{ - "Authorization": "Bearer " + i.apiSecret, - "X-Machine-ID": i.machineID, - "X-Machine-SDK-Version": Version, - "X-Machine-SDK-Language": "go", - } - - options := client.FetchDataOptions{ - Path: fmt.Sprintf("/clusters/%s/runs/%s", i.clusterID, runID), - Method: "GET", - Headers: headers, - } - - responseData, _, err, _ := i.fetchData(options) - if err != nil { - return nil, fmt.Errorf("failed to get run: %v", err) - } - var result runResult - err = json.Unmarshal(responseData, &result) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %v", err) - } - return &result, nil + if i.clusterID == "" { + return nil, fmt.Errorf("Cluster ID must be provided to manage runs") + } + + // Prepare headers + headers := map[string]string{ + "Authorization": "Bearer " + i.apiSecret, + "X-Machine-ID": i.machineID, + "X-Machine-SDK-Version": Version, + "X-Machine-SDK-Language": "go", + } + + options := client.FetchDataOptions{ + Path: fmt.Sprintf("/clusters/%s/runs/%s", i.clusterID, runID), + Method: "GET", + Headers: headers, + } + + responseData, _, err, _ := i.fetchData(options) + if err != nil { + return nil, fmt.Errorf("failed to get run: %v", err) + } + var result runResult + err = json.Unmarshal(responseData, &result) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %v", err) + } + return &result, nil } // Creates a run and returns a reference to it. @@ -226,27 +226,27 @@ func (i *Inferable) getRun(runID string) (*runResult, error) { // // Example: // -// // Create a new Inferable instance with an API secret -// client := inferable.New(InferableOptions{ -// ApiSecret: "API_SECRET", -// }) +// // Create a new Inferable instance with an API secret +// client := inferable.New(InferableOptions{ +// ApiSecret: "API_SECRET", +// }) // -// run, err := client.Run(CreateRunInput{ -// Message: "Hello world", -// }) +// run, err := client.Run(CreateRunInput{ +// Message: "Hello world", +// }) // -// if err != nil { -// log.Fatal("Failed to create run:", err) -// } +// if err != nil { +// log.Fatal("Failed to create run:", err) +// } // -// fmt.Println("Started run with ID:", run.ID) +// fmt.Println("Started run with ID:", run.ID) // -// result, err := run.Poll() -// if err != nil { -// log.Fatal("Failed to poll run result:", err) -// } +// result, err := run.Poll() +// if err != nil { +// log.Fatal("Failed to poll run result:", err) +// } // -// fmt.Println("Run result:", result) +// fmt.Println("Run result:", result) func (i *Inferable) CreateRun(input CreateRunInput) (*runReference, error) { if i.clusterID == "" { return nil, fmt.Errorf("cluster ID must be provided to manage runs") @@ -290,41 +290,41 @@ func (i *Inferable) CreateRun(input CreateRunInput) (*runReference, error) { } return &runReference{ - ID: response.ID, - Poll: func(options *PollOptions) (*runResult, error) { - // Default values for polling options - maxWaitTime := 60 * time.Second - interval := 500 * time.Millisecond - - if options != nil { - if options.MaxWaitTime != nil { - maxWaitTime = *options.MaxWaitTime - } - - if options.Interval != nil { - interval = *options.Interval - } - } - - start := time.Now() - end := start.Add(maxWaitTime) - - for time.Now().Before(end) { - pollResult, err := i.getRun(response.ID) - if err != nil { - return nil, fmt.Errorf("failed to poll for run: %w", err) - } - - if pollResult.Status != "pending" && pollResult.Status != "running" { - return pollResult, nil - } - - time.Sleep(interval) - } - - return nil, fmt.Errorf("max wait time reached, polling stopped") - }, - }, nil + ID: response.ID, + Poll: func(options *PollOptions) (*runResult, error) { + // Default values for polling options + maxWaitTime := 60 * time.Second + interval := 500 * time.Millisecond + + if options != nil { + if options.MaxWaitTime != nil { + maxWaitTime = *options.MaxWaitTime + } + + if options.Interval != nil { + interval = *options.Interval + } + } + + start := time.Now() + end := start.Add(maxWaitTime) + + for time.Now().Before(end) { + pollResult, err := i.getRun(response.ID) + if err != nil { + return nil, fmt.Errorf("failed to poll for run: %w", err) + } + + if pollResult.Status != "pending" && pollResult.Status != "running" && pollResult.Status != "paused" { + return pollResult, nil + } + + time.Sleep(interval) + } + + return nil, fmt.Errorf("max wait time reached, polling stopped") + }, + }, nil } func (i *Inferable) callFunc(serviceName, funcName string, args ...interface{}) ([]reflect.Value, error) { diff --git a/sdk-node/src/Inferable.test.ts b/sdk-node/src/Inferable.test.ts index ff1d5687..f26a6e7e 100644 --- a/sdk-node/src/Inferable.test.ts +++ b/sdk-node/src/Inferable.test.ts @@ -56,9 +56,7 @@ describe("Inferable", () => { }); it("should initialize without optional args", () => { - expect( - () => new Inferable({ apiSecret: TEST_API_SECRET }), - ).not.toThrow(); + expect(() => new Inferable({ apiSecret: TEST_API_SECRET })).not.toThrow(); }); it("should initialize with API secret in environment", () => { @@ -243,7 +241,7 @@ describe("Functions", () => { // This should match the example in the readme describe("Inferable SDK End to End Test", () => { it("should trigger a run, call a function, and call a status change function", async () => { - const client = inferableInstance(); + const client = inferableInstance(); let didCallSayHello = false; let didCallOnStatusChange = false; @@ -279,7 +277,7 @@ describe("Inferable SDK End to End Test", () => { attachedFunctions: [sayHello], // Optional: Define a schema for the result to conform to resultSchema: z.object({ - didSayHello: z.boolean() + didSayHello: z.boolean(), }), // Optional: Subscribe an Inferable function to receive notifications when the run status changes onStatusChange: { function: onStatusChange }, @@ -287,7 +285,7 @@ describe("Inferable SDK End to End Test", () => { const result = await run.poll(); - await new Promise((resolve) => setTimeout(resolve, 500)); + await new Promise((resolve) => setTimeout(resolve, 5000)); expect(result).not.toBeNull(); expect(didCallSayHello).toBe(true); diff --git a/sdk-node/src/Inferable.ts b/sdk-node/src/Inferable.ts index f9844239..42f654bf 100644 --- a/sdk-node/src/Inferable.ts +++ b/sdk-node/src/Inferable.ts @@ -30,11 +30,14 @@ debug.formatters.J = (json) => { export const log = debug("inferable:client"); -type RunInput = Omit["createRun"]>[0] - >["body"], "resultSchema"> & { - id?: string, - resultSchema?: z.ZodType | JsonSchemaInput + >["body"], + "resultSchema" +> & { + id?: string; + resultSchema?: z.ZodType | JsonSchemaInput; }; /** @@ -254,7 +257,6 @@ export class Inferable { }, }); - if (runResult.status != 201) { throw new InferableError("Failed to create run", { body: runResult.body, @@ -266,13 +268,15 @@ export class Inferable { return { id: runResult.body.id, /** - * Polls until the run reaches a terminal state (!= "pending" && != "running") or maxWaitTime is reached. + * Polls until the run reaches a terminal state (!= "pending" && != "running" && != "paused") or maxWaitTime is reached. * @param maxWaitTime The maximum amount of time to wait for the run to reach a terminal state. Defaults to 60 seconds. * @param interval The amount of time to wait between polling attempts. Defaults to 500ms. */ - poll: async (options?: { maxWaitTime?: number, interval?: number }) => { + poll: async (options?: { maxWaitTime?: number; interval?: number }) => { if (!this.clusterId) { - throw new InferableError("Cluster ID must be provided to manage runs"); + throw new InferableError( + "Cluster ID must be provided to manage runs", + ); } const start = Date.now(); @@ -292,7 +296,11 @@ export class Inferable { status: pollResult.status, }); } - if (["pending", "running"].includes(pollResult.body.status ?? "")) { + if ( + ["pending", "running", "paused"].includes( + pollResult.body.status ?? "", + ) + ) { await new Promise((resolve) => { setTimeout(resolve, options?.interval || 500); });