diff --git a/sdk-node/src/assertions.test.ts b/sdk-node/src/assertions.test.ts index 1fd88bbb..f0533886 100644 --- a/sdk-node/src/assertions.test.ts +++ b/sdk-node/src/assertions.test.ts @@ -1,6 +1,5 @@ import { Inferable } from "./inferable"; import { z } from "zod"; -import assert from "assert"; import { TEST_ENDPOINT } from "./tests/utils"; import { TEST_API_SECRET } from "./tests/utils"; @@ -11,122 +10,41 @@ describe("assertions", () => { endpoint: TEST_ENDPOINT, }); + let timesRun = 0; + client.default.register({ - name: "fetch", - func: async ({ url }: { url: string }) => { - if (url === "https://news.ycombinator.com/show") { - return `
-

Top Posts

- -
`; - } else { - const randomPastYear = Math.min( - 2024, - Math.floor(Math.random() * 100) + 1900, - ); - return `
-

Comments

- -
`; - } + name: "generateRandomNumber", + func: async ({ seed }: { seed: number }) => { + timesRun++; + + return seed * timesRun; }, }); await client.default.start(); const resultSchema = z.object({ - topPosts: z.array( - z.object({ - id: z.string().describe("The ID of the post"), - title: z.string().describe("The title of the post"), - }), - ), - comments: z.array( - z.object({ - commentsPageUrl: z.string().describe("The URL of the comments page"), - text: z.string().describe("The text of the comment"), - }), - ), + result: z.number().describe("The result of the function"), }); const run = await client.run({ initialPrompt: - "Get the top comment for the top 10 posts on Show HN: https://news.ycombinator.com/show", + "Use the available functions to generate a random number between 0 and 100", resultSchema: resultSchema, }); const result = await run.poll>({ assertions: [ function assertCorrect(result) { - const missingComments = result.topPosts.filter( - (post) => - !result.comments.some((comment) => - comment.commentsPageUrl.includes(post.id), - ), - ); - - const duplicateComments = result.comments.filter( - (comment, index, self) => - self.findIndex((c) => c.text === comment.text) !== index, - ); - - assert( - missingComments.length === 0, - `Some posts were missing comments: ${missingComments.map((m) => m.id).join(", ")}`, - ); - assert( - duplicateComments.length === 0, - `Detected duplicate comments: ${duplicateComments.map((d) => d.text).join(", ")}`, - ); + if (timesRun === 1) { + throw new Error( + `The result ${result.result} is unacceptable. Try again with a different seed.`, + ); + } }, ], }); - console.log(result); + expect(result.result).toBeGreaterThan(0); }); }); diff --git a/sdk-node/src/assertions.ts b/sdk-node/src/assertions.ts index 57ff3cb3..53cbaa41 100644 --- a/sdk-node/src/assertions.ts +++ b/sdk-node/src/assertions.ts @@ -30,7 +30,7 @@ export async function assertRun({ assertionsPassed: boolean; }> { const results = await Promise.allSettled( - assertions.map((a) => a(result, functionCalls)), + assertions.map(async (a) => await a(result, functionCalls)), ); const hasRejections = results.some((r) => r.status === "rejected"); @@ -38,7 +38,14 @@ export async function assertRun({ if (hasRejections) { await client.createMessage({ body: { - message: `You attempted to return a result, but I have determined the result is possibly incorrect due to failing assertions.`, + message: [ + `You attempted to return a result, but I have determined the result is possibly incorrect due to failing assertions.`, + ``, + ...results + .filter((r) => r.status === "rejected") + .map((r) => r.reason), + ``, + ].join("\n"), type: "human", }, params: {