From 1d5c3e3a06f634b3fa4283d0735988c87dab9ba5 Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Fri, 6 Dec 2024 13:00:05 +1100 Subject: [PATCH 1/4] feat: Introduce assertions for run validation --- sdk-node/package.json | 4 +- sdk-node/src/Inferable.ts | 32 ++++++-- sdk-node/src/assertions.test.ts | 132 ++++++++++++++++++++++++++++++++ sdk-node/src/assertions.ts | 58 ++++++++++++++ 4 files changed, 219 insertions(+), 7 deletions(-) create mode 100644 sdk-node/src/assertions.test.ts create mode 100644 sdk-node/src/assertions.ts diff --git a/sdk-node/package.json b/sdk-node/package.json index 1fa989c5..6e08dde1 100644 --- a/sdk-node/package.json +++ b/sdk-node/package.json @@ -8,7 +8,7 @@ "clean": "rm -rf ./bin", "prepare": "husky", "test": "jest ./src --runInBand --forceExit --setupFiles dotenv/config", - "test:dev": "jest ./src --watch --setupFiles dotenv/config" + "test:dev": "jest dotenv/config ./src --watch --setupFiles dotenv/config" }, "author": "Inferable, Inc.", "license": "MIT", @@ -54,4 +54,4 @@ "email": "hi@inferable.ai", "url": "https://github.com/inferablehq/inferable/issues" } -} +} \ No newline at end of file diff --git a/sdk-node/src/Inferable.ts b/sdk-node/src/Inferable.ts index bce5c287..7cfe07ef 100644 --- a/sdk-node/src/Inferable.ts +++ b/sdk-node/src/Inferable.ts @@ -23,6 +23,7 @@ import { validateFunctionSchema, validateServiceName, } from "./util"; +import { assert, Assertions, assertRun } from "./assertions"; // Custom json formatter debug.formatters.J = (json) => { @@ -108,8 +109,8 @@ export class Inferable { */ constructor(options?: { apiSecret?: string; - endpoint?: string - machineId?: string + endpoint?: string; + machineId?: string; }) { if (options?.apiSecret && process.env.INFERABLE_API_SECRET) { log( @@ -257,14 +258,18 @@ export class Inferable { } } - return { + const returnable = { id: runResult.body.id, /** * Polls until the run reaches a terminal state (!= "pending" && != "running" && != "paused") or maxWaitTime is reached. * @param maxWaitTime The maximum amount of time to wait for the run to reach a terminal state. Defaults to 60 seconds. * @param interval The amount of time to wait between polling attempts. Defaults to 500ms. */ - poll: async (options?: { maxWaitTime?: number; interval?: number }) => { + poll: async (options?: { + maxWaitTime?: number; + interval?: number; + assertions?: Assertions; + }): Promise => { if (!this.clusterId) { throw new InferableError( "Cluster ID must be provided to manage runs", @@ -299,8 +304,23 @@ export class Inferable { continue; } - return pollResult.body; + const assertionsResult = await assertRun({ + clusterId: this.clusterId, + runId: runResult.body.id, + client: this.client, + result: pollResult.body.result as T, + functionCalls: [], // TODO: Add function calls + assertions: options?.assertions || [], + }); + + if (!assertionsResult.assertionsPassed) { + return returnable.poll(options); + } + + return pollResult.body.result as T; // Cast the result to the inferred type } + + throw new InferableError("Run timed out"); }, /** * Retrieves the messages for a run. @@ -325,6 +345,8 @@ export class Inferable { }; }, }; + + return returnable; } /** diff --git a/sdk-node/src/assertions.test.ts b/sdk-node/src/assertions.test.ts new file mode 100644 index 00000000..1fd88bbb --- /dev/null +++ b/sdk-node/src/assertions.test.ts @@ -0,0 +1,132 @@ +import { Inferable } from "./inferable"; +import { z } from "zod"; +import assert from "assert"; +import { TEST_ENDPOINT } from "./tests/utils"; +import { TEST_API_SECRET } from "./tests/utils"; + +describe("assertions", () => { + it("should be able to assert a run", async () => { + const client = new Inferable({ + apiSecret: TEST_API_SECRET, + endpoint: TEST_ENDPOINT, + }); + + client.default.register({ + name: "fetch", + func: async ({ url }: { url: string }) => { + if (url === "https://news.ycombinator.com/show") { + return ``; + } else { + const randomPastYear = Math.min( + 2024, + Math.floor(Math.random() * 100) + 1900, + ); + return `
+

Comments

+
    +
  • +
    This is surprising because I observe the same thing in ${randomPastYear}
    +
  • +
+
`; + } + }, + }); + + await client.default.start(); + + const resultSchema = z.object({ + topPosts: z.array( + z.object({ + id: z.string().describe("The ID of the post"), + title: z.string().describe("The title of the post"), + }), + ), + comments: z.array( + z.object({ + commentsPageUrl: z.string().describe("The URL of the comments page"), + text: z.string().describe("The text of the comment"), + }), + ), + }); + + const run = await client.run({ + initialPrompt: + "Get the top comment for the top 10 posts on Show HN: https://news.ycombinator.com/show", + resultSchema: resultSchema, + }); + + const result = await run.poll>({ + assertions: [ + function assertCorrect(result) { + const missingComments = result.topPosts.filter( + (post) => + !result.comments.some((comment) => + comment.commentsPageUrl.includes(post.id), + ), + ); + + const duplicateComments = result.comments.filter( + (comment, index, self) => + self.findIndex((c) => c.text === comment.text) !== index, + ); + + assert( + missingComments.length === 0, + `Some posts were missing comments: ${missingComments.map((m) => m.id).join(", ")}`, + ); + assert( + duplicateComments.length === 0, + `Detected duplicate comments: ${duplicateComments.map((d) => d.text).join(", ")}`, + ); + }, + ], + }); + + console.log(result); + }); +}); diff --git a/sdk-node/src/assertions.ts b/sdk-node/src/assertions.ts new file mode 100644 index 00000000..57ff3cb3 --- /dev/null +++ b/sdk-node/src/assertions.ts @@ -0,0 +1,58 @@ +import { createApiClient } from "./create-client"; + +type FunctionCall = { + service: string; + function: string; +}; + +type AssertionFunction = ( + result: T, + functionCalls: FunctionCall[], +) => void | Promise; + +export type Assertions = AssertionFunction[]; + +export async function assertRun({ + clusterId, + runId, + client, + result, + functionCalls, + assertions, +}: { + clusterId: string; + runId: string; + client: ReturnType; + result: T; + functionCalls: FunctionCall[]; + assertions: Assertions; +}): Promise<{ + assertionsPassed: boolean; +}> { + const results = await Promise.allSettled( + assertions.map((a) => a(result, functionCalls)), + ); + + const hasRejections = results.some((r) => r.status === "rejected"); + + if (hasRejections) { + await client.createMessage({ + body: { + message: `You attempted to return a result, but I have determined the result is possibly incorrect due to failing assertions.`, + type: "human", + }, + params: { + clusterId, + runId, + }, + }); + + return { + assertionsPassed: false, + }; + } + + return { + assertionsPassed: true, + }; +} From 8985d3f389661b9dca0b33b65bcef2fbae77912d Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Fri, 6 Dec 2024 13:01:32 +1100 Subject: [PATCH 2/4] update --- sdk-node/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk-node/package.json b/sdk-node/package.json index 6e08dde1..b93cfde7 100644 --- a/sdk-node/package.json +++ b/sdk-node/package.json @@ -8,7 +8,7 @@ "clean": "rm -rf ./bin", "prepare": "husky", "test": "jest ./src --runInBand --forceExit --setupFiles dotenv/config", - "test:dev": "jest dotenv/config ./src --watch --setupFiles dotenv/config" + "test:dev": "jest ./src --watch --setupFiles dotenv/config" }, "author": "Inferable, Inc.", "license": "MIT", From 1ae0144791e9a84b79f8460ee1edc91f33d12b53 Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Sat, 7 Dec 2024 16:42:47 +1100 Subject: [PATCH 3/4] fix: A better assertions test --- sdk-node/src/Inferable.ts | 2 +- sdk-node/src/assertions.test.ts | 112 +++++--------------------------- sdk-node/src/assertions.ts | 11 +++- 3 files changed, 25 insertions(+), 100 deletions(-) diff --git a/sdk-node/src/Inferable.ts b/sdk-node/src/Inferable.ts index 7cfe07ef..bc2cddaf 100644 --- a/sdk-node/src/Inferable.ts +++ b/sdk-node/src/Inferable.ts @@ -23,7 +23,7 @@ import { validateFunctionSchema, validateServiceName, } from "./util"; -import { assert, Assertions, assertRun } from "./assertions"; +import { Assertions, assertRun } from "./assertions"; // Custom json formatter debug.formatters.J = (json) => { diff --git a/sdk-node/src/assertions.test.ts b/sdk-node/src/assertions.test.ts index 1fd88bbb..f0533886 100644 --- a/sdk-node/src/assertions.test.ts +++ b/sdk-node/src/assertions.test.ts @@ -1,6 +1,5 @@ import { Inferable } from "./inferable"; import { z } from "zod"; -import assert from "assert"; import { TEST_ENDPOINT } from "./tests/utils"; import { TEST_API_SECRET } from "./tests/utils"; @@ -11,122 +10,41 @@ describe("assertions", () => { endpoint: TEST_ENDPOINT, }); + let timesRun = 0; + client.default.register({ - name: "fetch", - func: async ({ url }: { url: string }) => { - if (url === "https://news.ycombinator.com/show") { - return ``; - } else { - const randomPastYear = Math.min( - 2024, - Math.floor(Math.random() * 100) + 1900, - ); - return `
-

Comments

-
    -
  • -
    This is surprising because I observe the same thing in ${randomPastYear}
    -
  • -
-
`; - } + name: "generateRandomNumber", + func: async ({ seed }: { seed: number }) => { + timesRun++; + + return seed * timesRun; }, }); await client.default.start(); const resultSchema = z.object({ - topPosts: z.array( - z.object({ - id: z.string().describe("The ID of the post"), - title: z.string().describe("The title of the post"), - }), - ), - comments: z.array( - z.object({ - commentsPageUrl: z.string().describe("The URL of the comments page"), - text: z.string().describe("The text of the comment"), - }), - ), + result: z.number().describe("The result of the function"), }); const run = await client.run({ initialPrompt: - "Get the top comment for the top 10 posts on Show HN: https://news.ycombinator.com/show", + "Use the available functions to generate a random number between 0 and 100", resultSchema: resultSchema, }); const result = await run.poll>({ assertions: [ function assertCorrect(result) { - const missingComments = result.topPosts.filter( - (post) => - !result.comments.some((comment) => - comment.commentsPageUrl.includes(post.id), - ), - ); - - const duplicateComments = result.comments.filter( - (comment, index, self) => - self.findIndex((c) => c.text === comment.text) !== index, - ); - - assert( - missingComments.length === 0, - `Some posts were missing comments: ${missingComments.map((m) => m.id).join(", ")}`, - ); - assert( - duplicateComments.length === 0, - `Detected duplicate comments: ${duplicateComments.map((d) => d.text).join(", ")}`, - ); + if (timesRun === 1) { + throw new Error( + `The result ${result.result} is unacceptable. Try again with a different seed.`, + ); + } }, ], }); - console.log(result); + expect(result.result).toBeGreaterThan(0); }); }); diff --git a/sdk-node/src/assertions.ts b/sdk-node/src/assertions.ts index 57ff3cb3..53cbaa41 100644 --- a/sdk-node/src/assertions.ts +++ b/sdk-node/src/assertions.ts @@ -30,7 +30,7 @@ export async function assertRun({ assertionsPassed: boolean; }> { const results = await Promise.allSettled( - assertions.map((a) => a(result, functionCalls)), + assertions.map(async (a) => await a(result, functionCalls)), ); const hasRejections = results.some((r) => r.status === "rejected"); @@ -38,7 +38,14 @@ export async function assertRun({ if (hasRejections) { await client.createMessage({ body: { - message: `You attempted to return a result, but I have determined the result is possibly incorrect due to failing assertions.`, + message: [ + `You attempted to return a result, but I have determined the result is possibly incorrect due to failing assertions.`, + ``, + ...results + .filter((r) => r.status === "rejected") + .map((r) => r.reason), + ``, + ].join("\n"), type: "human", }, params: { From 5473d1fbd43af2249c65689eb0f6df83295f49c9 Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Sat, 7 Dec 2024 21:37:34 +1100 Subject: [PATCH 4/4] fix: Correct import case sensitivity in assertions.test.ts --- sdk-node/src/assertions.test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk-node/src/assertions.test.ts b/sdk-node/src/assertions.test.ts index f0533886..e8193d2f 100644 --- a/sdk-node/src/assertions.test.ts +++ b/sdk-node/src/assertions.test.ts @@ -1,4 +1,4 @@ -import { Inferable } from "./inferable"; +import { Inferable } from "./Inferable"; import { z } from "zod"; import { TEST_ENDPOINT } from "./tests/utils"; import { TEST_API_SECRET } from "./tests/utils";