From 1d5c3e3a06f634b3fa4283d0735988c87dab9ba5 Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Fri, 6 Dec 2024 13:00:05 +1100 Subject: [PATCH] feat: Introduce assertions for run validation --- sdk-node/package.json | 4 +- sdk-node/src/Inferable.ts | 32 ++++++-- sdk-node/src/assertions.test.ts | 132 ++++++++++++++++++++++++++++++++ sdk-node/src/assertions.ts | 58 ++++++++++++++ 4 files changed, 219 insertions(+), 7 deletions(-) create mode 100644 sdk-node/src/assertions.test.ts create mode 100644 sdk-node/src/assertions.ts diff --git a/sdk-node/package.json b/sdk-node/package.json index 1fa989c5..6e08dde1 100644 --- a/sdk-node/package.json +++ b/sdk-node/package.json @@ -8,7 +8,7 @@ "clean": "rm -rf ./bin", "prepare": "husky", "test": "jest ./src --runInBand --forceExit --setupFiles dotenv/config", - "test:dev": "jest ./src --watch --setupFiles dotenv/config" + "test:dev": "jest dotenv/config ./src --watch --setupFiles dotenv/config" }, "author": "Inferable, Inc.", "license": "MIT", @@ -54,4 +54,4 @@ "email": "hi@inferable.ai", "url": "https://github.com/inferablehq/inferable/issues" } -} +} \ No newline at end of file diff --git a/sdk-node/src/Inferable.ts b/sdk-node/src/Inferable.ts index bce5c287..7cfe07ef 100644 --- a/sdk-node/src/Inferable.ts +++ b/sdk-node/src/Inferable.ts @@ -23,6 +23,7 @@ import { validateFunctionSchema, validateServiceName, } from "./util"; +import { assert, Assertions, assertRun } from "./assertions"; // Custom json formatter debug.formatters.J = (json) => { @@ -108,8 +109,8 @@ export class Inferable { */ constructor(options?: { apiSecret?: string; - endpoint?: string - machineId?: string + endpoint?: string; + machineId?: string; }) { if (options?.apiSecret && process.env.INFERABLE_API_SECRET) { log( @@ -257,14 +258,18 @@ export class Inferable { } } - return { + const returnable = { id: runResult.body.id, /** * Polls until the run reaches a terminal state (!= "pending" && != "running" && != "paused") or maxWaitTime is reached. * @param maxWaitTime The maximum amount of time to wait for the run to reach a terminal state. Defaults to 60 seconds. * @param interval The amount of time to wait between polling attempts. Defaults to 500ms. */ - poll: async (options?: { maxWaitTime?: number; interval?: number }) => { + poll: async (options?: { + maxWaitTime?: number; + interval?: number; + assertions?: Assertions; + }): Promise => { if (!this.clusterId) { throw new InferableError( "Cluster ID must be provided to manage runs", @@ -299,8 +304,23 @@ export class Inferable { continue; } - return pollResult.body; + const assertionsResult = await assertRun({ + clusterId: this.clusterId, + runId: runResult.body.id, + client: this.client, + result: pollResult.body.result as T, + functionCalls: [], // TODO: Add function calls + assertions: options?.assertions || [], + }); + + if (!assertionsResult.assertionsPassed) { + return returnable.poll(options); + } + + return pollResult.body.result as T; // Cast the result to the inferred type } + + throw new InferableError("Run timed out"); }, /** * Retrieves the messages for a run. @@ -325,6 +345,8 @@ export class Inferable { }; }, }; + + return returnable; } /** diff --git a/sdk-node/src/assertions.test.ts b/sdk-node/src/assertions.test.ts new file mode 100644 index 00000000..1fd88bbb --- /dev/null +++ b/sdk-node/src/assertions.test.ts @@ -0,0 +1,132 @@ +import { Inferable } from "./inferable"; +import { z } from "zod"; +import assert from "assert"; +import { TEST_ENDPOINT } from "./tests/utils"; +import { TEST_API_SECRET } from "./tests/utils"; + +describe("assertions", () => { + it("should be able to assert a run", async () => { + const client = new Inferable({ + apiSecret: TEST_API_SECRET, + endpoint: TEST_ENDPOINT, + }); + + client.default.register({ + name: "fetch", + func: async ({ url }: { url: string }) => { + if (url === "https://news.ycombinator.com/show") { + return ``; + } else { + const randomPastYear = Math.min( + 2024, + Math.floor(Math.random() * 100) + 1900, + ); + return `
+

Comments

+
    +
  • +
    This is surprising because I observe the same thing in ${randomPastYear}
    +
  • +
+
`; + } + }, + }); + + await client.default.start(); + + const resultSchema = z.object({ + topPosts: z.array( + z.object({ + id: z.string().describe("The ID of the post"), + title: z.string().describe("The title of the post"), + }), + ), + comments: z.array( + z.object({ + commentsPageUrl: z.string().describe("The URL of the comments page"), + text: z.string().describe("The text of the comment"), + }), + ), + }); + + const run = await client.run({ + initialPrompt: + "Get the top comment for the top 10 posts on Show HN: https://news.ycombinator.com/show", + resultSchema: resultSchema, + }); + + const result = await run.poll>({ + assertions: [ + function assertCorrect(result) { + const missingComments = result.topPosts.filter( + (post) => + !result.comments.some((comment) => + comment.commentsPageUrl.includes(post.id), + ), + ); + + const duplicateComments = result.comments.filter( + (comment, index, self) => + self.findIndex((c) => c.text === comment.text) !== index, + ); + + assert( + missingComments.length === 0, + `Some posts were missing comments: ${missingComments.map((m) => m.id).join(", ")}`, + ); + assert( + duplicateComments.length === 0, + `Detected duplicate comments: ${duplicateComments.map((d) => d.text).join(", ")}`, + ); + }, + ], + }); + + console.log(result); + }); +}); diff --git a/sdk-node/src/assertions.ts b/sdk-node/src/assertions.ts new file mode 100644 index 00000000..57ff3cb3 --- /dev/null +++ b/sdk-node/src/assertions.ts @@ -0,0 +1,58 @@ +import { createApiClient } from "./create-client"; + +type FunctionCall = { + service: string; + function: string; +}; + +type AssertionFunction = ( + result: T, + functionCalls: FunctionCall[], +) => void | Promise; + +export type Assertions = AssertionFunction[]; + +export async function assertRun({ + clusterId, + runId, + client, + result, + functionCalls, + assertions, +}: { + clusterId: string; + runId: string; + client: ReturnType; + result: T; + functionCalls: FunctionCall[]; + assertions: Assertions; +}): Promise<{ + assertionsPassed: boolean; +}> { + const results = await Promise.allSettled( + assertions.map((a) => a(result, functionCalls)), + ); + + const hasRejections = results.some((r) => r.status === "rejected"); + + if (hasRejections) { + await client.createMessage({ + body: { + message: `You attempted to return a result, but I have determined the result is possibly incorrect due to failing assertions.`, + type: "human", + }, + params: { + clusterId, + runId, + }, + }); + + return { + assertionsPassed: false, + }; + } + + return { + assertionsPassed: true, + }; +}