diff --git a/sdk-node/package.json b/sdk-node/package.json
index 1fa989c5..b93cfde7 100644
--- a/sdk-node/package.json
+++ b/sdk-node/package.json
@@ -54,4 +54,4 @@
     "email": "hi@inferable.ai",
     "url": "https://github.com/inferablehq/inferable/issues"
   }
-}
+}
\ No newline at end of file
diff --git a/sdk-node/src/Inferable.ts b/sdk-node/src/Inferable.ts
index bce5c287..bc2cddaf 100644
--- a/sdk-node/src/Inferable.ts
+++ b/sdk-node/src/Inferable.ts
@@ -23,6 +23,7 @@ import {
   validateFunctionSchema,
   validateServiceName,
 } from "./util";
+import { Assertions, assertRun } from "./assertions";
 
 // Custom json formatter
 debug.formatters.J = (json) => {
@@ -108,8 +109,8 @@ export class Inferable {
    */
   constructor(options?: {
     apiSecret?: string;
-    endpoint?: string
-    machineId?: string
+    endpoint?: string;
+    machineId?: string;
   }) {
     if (options?.apiSecret && process.env.INFERABLE_API_SECRET) {
       log(
@@ -257,14 +258,21 @@ export class Inferable {
       }
     }
 
-    return {
+    const returnable = {
       id: runResult.body.id,
       /**
        * Polls until the run reaches a terminal state (!= "pending" && != "running" && != "paused") or maxWaitTime is reached.
        * @param maxWaitTime The maximum amount of time to wait for the run to reach a terminal state. Defaults to 60 seconds.
        * @param interval The amount of time to wait between polling attempts. Defaults to 500ms.
+       * @param assertions Functions run against the polled result. If any of them throws, the failures are sent back to the run as a message and polling continues.
+       * @returns The run's result, cast to `T`.
+       * @throws InferableError ("Run timed out") when the polling loop ends without returning a result.
        */
-      poll: async (options?: { maxWaitTime?: number; interval?: number }) => {
+      poll: async <T>(options?: {
+        maxWaitTime?: number;
+        interval?: number;
+        assertions?: Assertions<T>;
+      }): Promise<T> => {
         if (!this.clusterId) {
           throw new InferableError(
             "Cluster ID must be provided to manage runs",
@@ -299,8 +307,26 @@ export class Inferable {
             continue;
           }
 
-          return pollResult.body;
+          const assertionsResult = await assertRun({
+            clusterId: this.clusterId,
+            runId: runResult.body.id,
+            client: this.client,
+            result: pollResult.body.result as T,
+            functionCalls: [], // TODO: Add function calls
+            assertions: options?.assertions ?? [],
+          });
+
+          if (!assertionsResult.assertionsPassed) {
+            // The failures were posted back to the run as a message. Keep
+            // polling in this loop rather than re-entering poll(), which
+            // would restart the maxWaitTime budget on every failed attempt.
+            continue;
+          }
+
+          return pollResult.body.result as T; // Cast the result to the inferred type
         }
+
+        throw new InferableError("Run timed out");
       },
       /**
        * Retrieves the messages for a run.
@@ -325,6 +351,8 @@ export class Inferable {
         };
       },
     };
+
+    return returnable;
   }
 
   /**
diff --git a/sdk-node/src/assertions.test.ts b/sdk-node/src/assertions.test.ts
new file mode 100644
index 00000000..e8193d2f
--- /dev/null
+++ b/sdk-node/src/assertions.test.ts
@@ -0,0 +1,50 @@
+import { Inferable } from "./Inferable";
+import { z } from "zod";
+import { TEST_API_SECRET, TEST_ENDPOINT } from "./tests/utils";
+
+describe("assertions", () => {
+  it("should be able to assert a run", async () => {
+    const client = new Inferable({
+      apiSecret: TEST_API_SECRET,
+      endpoint: TEST_ENDPOINT,
+    });
+
+    let timesRun = 0;
+
+    client.default.register({
+      name: "generateRandomNumber",
+      func: async ({ seed }: { seed: number }) => {
+        timesRun++;
+
+        return seed * timesRun;
+      },
+    });
+
+    await client.default.start();
+
+    const resultSchema = z.object({
+      result: z.number().describe("The result of the function"),
+    });
+
+    const run = await client.run({
+      initialPrompt:
+        "Use the available functions to generate a random number between 0 and 100",
+      resultSchema,
+    });
+
+    const result = await run.poll<z.infer<typeof resultSchema>>({
+      assertions: [
+        // Rejects the result while the registered function has been called exactly once.
+        function assertCorrect(result) {
+          if (timesRun === 1) {
+            throw new Error(
+              `The result ${result.result} is unacceptable. Try again with a different seed.`,
+            );
+          }
+        },
+      ],
+    });
+
+    expect(result.result).toBeGreaterThan(0);
+  });
+});
diff --git a/sdk-node/src/assertions.ts b/sdk-node/src/assertions.ts
new file mode 100644
index 00000000..53cbaa41
--- /dev/null
+++ b/sdk-node/src/assertions.ts
@@ -0,0 +1,79 @@
+import { createApiClient } from "./create-client";
+
+/** Identifies a function by its service and function name. */
+type FunctionCall = {
+  service: string;
+  function: string;
+};
+
+/**
+ * A user-supplied check on a run's result. Throwing (or rejecting) marks the
+ * assertion as failed; returning normally marks it as passed.
+ */
+type AssertionFunction<T> = (
+  result: T,
+  functionCalls: FunctionCall[],
+) => void | Promise<void>;
+
+export type Assertions<T> = AssertionFunction<T>[];
+
+/**
+ * Runs every assertion against a run's result. If any assertion fails, the
+ * failure reasons are posted back to the run as a "human" message.
+ *
+ * @returns `assertionsPassed: true` when no assertion threw or rejected.
+ */
+export async function assertRun<T>({
+  clusterId,
+  runId,
+  client,
+  result,
+  functionCalls,
+  assertions,
+}: {
+  clusterId: string;
+  runId: string;
+  client: ReturnType<typeof createApiClient>;
+  result: T;
+  functionCalls: FunctionCall[];
+  assertions: Assertions<T>;
+}): Promise<{
+  assertionsPassed: boolean;
+}> {
+  const results = await Promise.allSettled(
+    // The async wrapper turns a synchronous throw into a rejection so that
+    // allSettled records it instead of the map call itself throwing.
+    assertions.map(async (assertion) => assertion(result, functionCalls)),
+  );
+
+  // The type predicate narrows to PromiseRejectedResult so `reason` is typed.
+  const failureReasons = results
+    .filter((r): r is PromiseRejectedResult => r.status === "rejected")
+    .map((r) => String(r.reason));
+
+  if (failureReasons.length === 0) {
+    return {
+      assertionsPassed: true,
+    };
+  }
+
+  await client.createMessage({
+    body: {
+      message: [
+        `You attempted to return a result, but I have determined the result is possibly incorrect due to failing assertions.`,
+        ``,
+        ...failureReasons,
+        ``,
+      ].join("\n"),
+      type: "human",
+    },
+    params: {
+      clusterId,
+      runId,
+    },
+  });
+
+  return {
+    assertionsPassed: false,
+  };
+}