diff --git a/sdk-node/src/Inferable.test.ts b/sdk-node/src/Inferable.test.ts index 60523a69..bc3c40e7 100644 --- a/sdk-node/src/Inferable.test.ts +++ b/sdk-node/src/Inferable.test.ts @@ -289,30 +289,4 @@ describe("Inferable SDK End to End Test", () => { await client.default.stop(); } }); - - describe("api", () => { - it("should be able to call the api directly", async () => { - const client = inferableInstance(); - - const result = await client.api.createStructuredOutput<{ - capital: string; - }>({ - prompt: "What is the capital of France?", - modelId: "claude-3-5-sonnet", - resultSchema: { - type: "object", - properties: { - capital: { type: "string" }, - }, - }, - }); - - expect(result).toMatchObject({ - success: true, - data: { - capital: "Paris", - }, - }); - }); - }); }); diff --git a/sdk-node/src/Inferable.ts b/sdk-node/src/Inferable.ts index bce5c287..7f49f395 100644 --- a/sdk-node/src/Inferable.ts +++ b/sdk-node/src/Inferable.ts @@ -514,25 +514,6 @@ export class Inferable { this.functionRegistry[registration.name] = registration; } - public get api() { - return { - createStructuredOutput: async ( - input: Parameters[0]["body"], - ) => - this.client - .createStructuredOutput({ - params: { - clusterId: await this.getClusterId(), - }, - body: input, - }) - .then((r) => r.body) as Promise<{ - success: boolean; - data?: T; - }>, - }; - } - public async getClusterId() { if (!this.clusterId) { // Call register machine without any services to test API key and get clusterId diff --git a/sdk-node/src/execute-fn.test.ts b/sdk-node/src/execute-fn.test.ts index 9c1468ea..c8d28c4e 100644 --- a/sdk-node/src/execute-fn.test.ts +++ b/sdk-node/src/execute-fn.test.ts @@ -1,7 +1,34 @@ +import { executeFn } from "./execute-fn"; +import { Interrupt } from "./util"; + describe("executeFn", () => { it("should run a function with arguments", async () => { const fn = (val: { [key: string]: string }) => Promise.resolve(val.foo); - const result = await fn({ foo: "bar" }); - expect(result).toBe("bar"); + const result = await executeFn(fn, [{foo: "bar"}] as any); + expect(result).toEqual({ + content: "bar", + type: "resolution", + functionExecutionTime: expect.any(Number), + }); + }); + + it("should extract interrupt from resolution", async () => { + const fn = (_: string) => Promise.resolve(Interrupt.approval()); + const result = await executeFn(fn, [{}] as any); + expect(result).toEqual({ + content: { type: "approval" }, + type: "interrupt", + functionExecutionTime: expect.any(Number), + }); + }); + + it("should extract interrupt from rejection", async () => { + const fn = () => Promise.reject(Interrupt.approval()); + const result = await executeFn(fn, [{}] as any); + expect(result).toEqual({ + content: { type: "approval" }, + type: "interrupt", + functionExecutionTime: expect.any(Number), + }); }); }); diff --git a/sdk-node/src/execute-fn.ts b/sdk-node/src/execute-fn.ts index 5c6d2d89..40842142 100644 --- a/sdk-node/src/execute-fn.ts +++ b/sdk-node/src/execute-fn.ts @@ -32,6 +32,14 @@ export const executeFn = async ( functionExecutionTime: Date.now() - start, }; } catch (e) { + const interupt = extractInterrupt(e); + if (interupt) { + return { + content: interupt, + type: "interrupt", + functionExecutionTime: Date.now() - start, + }; + } const functionExecutionTime = Date.now() - start; if (e instanceof Error) { return { diff --git a/sdk-node/src/util.test.ts b/sdk-node/src/util.test.ts index 2cd19ae3..6794d26e 100644 --- a/sdk-node/src/util.test.ts +++ b/sdk-node/src/util.test.ts @@ -2,6 +2,8 @@ import { ajvErrorToFailures, blob, extractBlobs, + extractInterrupt, + Interrupt, validateFunctionSchema, } from "./util"; @@ -64,6 +66,21 @@ describe("ajvErrorToFailures", () => { }); }); +describe("extractInterrupt", () => { + it("should extract extract interrupt", () => { + const interrupt = extractInterrupt(Interrupt.approval()); + expect(interrupt).toEqual({ + type: "approval", + }) + }) + + it("should not extract interrupt from non-interrupt", () => { + const interrupt = extractInterrupt({ foo: "bar" }); + expect(interrupt).toBeUndefined(); + }) + +}); + describe("extractBlobs", () => { it("should extract blobs from content", () => { const initialContent = { diff --git a/sdk-node/src/util.ts b/sdk-node/src/util.ts index 1016834a..358c37ef 100644 --- a/sdk-node/src/util.ts +++ b/sdk-node/src/util.ts @@ -279,6 +279,7 @@ export const blob = ({ export const INTERRUPT_KEY = "__inferable_interrupt"; +type VALID_INTERRUPT_TYPES = "approval"; const interruptResultSchema = z.discriminatedUnion("type", [ z.object({ type: z.literal("approval"), @@ -297,6 +298,25 @@ export const extractInterrupt = (input: unknown): z.infer { return { [INTERRUPT_KEY]: {