Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Introduce assertions for run validation #241

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk-node/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@
"email": "[email protected]",
"url": "https://github.com/inferablehq/inferable/issues"
}
}
}
32 changes: 27 additions & 5 deletions sdk-node/src/Inferable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {
validateFunctionSchema,
validateServiceName,
} from "./util";
import { Assertions, assertRun } from "./assertions";

// Custom json formatter
debug.formatters.J = (json) => {
Expand Down Expand Up @@ -108,8 +109,8 @@ export class Inferable {
*/
constructor(options?: {
apiSecret?: string;
endpoint?: string
machineId?: string
endpoint?: string;
machineId?: string;
}) {
if (options?.apiSecret && process.env.INFERABLE_API_SECRET) {
log(
Expand Down Expand Up @@ -257,14 +258,18 @@ export class Inferable {
}
}

return {
const returnable = {
id: runResult.body.id,
/**
* Polls until the run reaches a terminal state (!= "pending" && != "running" && != "paused") or maxWaitTime is reached.
* @param maxWaitTime The maximum amount of time to wait for the run to reach a terminal state. Defaults to 60 seconds.
* @param interval The amount of time to wait between polling attempts. Defaults to 500ms.
*/
poll: async (options?: { maxWaitTime?: number; interval?: number }) => {
poll: async <T = unknown>(options?: {
maxWaitTime?: number;
interval?: number;
assertions?: Assertions<T>;
}): Promise<T> => {
if (!this.clusterId) {
throw new InferableError(
"Cluster ID must be provided to manage runs",
Expand Down Expand Up @@ -299,8 +304,23 @@ export class Inferable {
continue;
}

return pollResult.body;
const assertionsResult = await assertRun<T>({
clusterId: this.clusterId,
runId: runResult.body.id,
client: this.client,
result: pollResult.body.result as T,
functionCalls: [], // TODO: Add function calls
assertions: options?.assertions || [],
});

if (!assertionsResult.assertionsPassed) {
return returnable.poll(options);
}

return pollResult.body.result as T; // Cast the result to the inferred type
}

throw new InferableError("Run timed out");
},
/**
* Retrieves the messages for a run.
Expand All @@ -325,6 +345,8 @@ export class Inferable {
};
},
};

return returnable;
}

/**
Expand Down
50 changes: 50 additions & 0 deletions sdk-node/src/assertions.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import { Inferable } from "./Inferable";
import { z } from "zod";
import { TEST_ENDPOINT } from "./tests/utils";
import { TEST_API_SECRET } from "./tests/utils";

/**
 * Integration test for run assertions: registers a function whose output
 * depends on how many times it has been invoked, starts a run, and polls it
 * with an assertion that throws while the function has been called exactly once.
 */
describe("assertions", () => {
  it("should be able to assert a run", async () => {
    const inferable = new Inferable({
      apiSecret: TEST_API_SECRET,
      endpoint: TEST_ENDPOINT,
    });

    // Counts invocations of the registered function; the assertion below reads it.
    let invocationCount = 0;

    const generateRandomNumber = async ({ seed }: { seed: number }) => {
      invocationCount += 1;
      return seed * invocationCount;
    };

    inferable.default.register({
      name: "generateRandomNumber",
      func: generateRandomNumber,
    });

    await inferable.default.start();

    const resultSchema = z.object({
      result: z.number().describe("The result of the function"),
    });
    type RunResult = z.infer<typeof resultSchema>;

    const run = await inferable.run({
      initialPrompt:
        "Use the available functions to generate a random number between 0 and 100",
      resultSchema,
    });

    // Throws while the function has only been called once.
    function assertCorrect(result: RunResult) {
      if (invocationCount !== 1) {
        return;
      }

      throw new Error(
        `The result ${result.result} is unacceptable. Try again with a different seed.`,
      );
    }

    const finalResult = await run.poll<RunResult>({
      assertions: [assertCorrect],
    });

    expect(finalResult.result).toBeGreaterThan(0);
  });
});
65 changes: 65 additions & 0 deletions sdk-node/src/assertions.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import { createApiClient } from "./create-client";

/**
 * Identifies one function invocation made during a run.
 * NOTE(review): presumably `service` and `function` are the registered
 * service and function names — confirm against the caller.
 */
type FunctionCall = {
  service: string;
  function: string;
};

/**
 * A single check applied to a run's result.
 *
 * Receives the result and the function calls associated with the run.
 * Signal a failure by throwing (or returning a promise that rejects);
 * returning normally means the check passed.
 */
type AssertionFunction<T> = (
  result: T,
  functionCalls: FunctionCall[],
) => Promise<void> | void;

/** The list of checks evaluated against a run's result. */
export type Assertions<T> = Array<AssertionFunction<T>>;

/**
 * Renders an assertion failure reason as a single line of text.
 *
 * Errors and primitives keep their normal string form (e.g. `Error: boom`).
 * Plain objects, which would otherwise render as `[object Object]`, are
 * serialized as JSON so the failure detail is not lost.
 */
function describeAssertionFailure(reason: unknown): string {
  let text: string | undefined;
  try {
    text = String(reason);
  } catch {
    // e.g. objects without a usable toString; fall through to JSON below.
    text = undefined;
  }

  if (text !== undefined && text !== "[object Object]") {
    return text;
  }

  try {
    const json: string | undefined = JSON.stringify(reason);
    if (json !== undefined) {
      return json;
    }
  } catch {
    // Circular structures etc.; fall back to whatever text we have.
  }

  return text ?? "[unserializable failure]";
}

/**
 * Evaluates every assertion against a run's result.
 *
 * All assertions are run (a failing one does not stop the others). If at
 * least one throws or rejects, a single "human" message listing every failure
 * inside a `<failures>` block is posted to the run via `client.createMessage`.
 *
 * @param clusterId Cluster that owns the run; forwarded as a request param.
 * @param runId Run the failure message is posted to.
 * @param client API client used to post the failure message.
 * @param result The run result handed to each assertion.
 * @param functionCalls Function calls handed to each assertion.
 * @param assertions Checks to evaluate; an empty list always passes.
 * @returns `assertionsPassed: true` when no assertion failed, otherwise
 *   `false` (after the failure message has been sent).
 * @throws Whatever `client.createMessage` throws; assertion errors themselves
 *   are captured, never rethrown.
 */
export async function assertRun<T>({
  clusterId,
  runId,
  client,
  result,
  functionCalls,
  assertions,
}: {
  clusterId: string;
  runId: string;
  client: ReturnType<typeof createApiClient>;
  result: T;
  functionCalls: FunctionCall[];
  assertions: Assertions<T>;
}): Promise<{
  assertionsPassed: boolean;
}> {
  // The async wrapper turns synchronous throws into rejections so that
  // allSettled observes them instead of `map` blowing up.
  const settled = await Promise.allSettled(
    assertions.map(async (assertion) => assertion(result, functionCalls)),
  );

  // The explicit type predicate narrows to rejected results on every
  // TypeScript version, making `.reason` safe to read.
  const failures = settled
    .filter(
      (outcome): outcome is PromiseRejectedResult =>
        outcome.status === "rejected",
    )
    .map((outcome) => describeAssertionFailure(outcome.reason));

  if (failures.length === 0) {
    return {
      assertionsPassed: true,
    };
  }

  await client.createMessage({
    body: {
      message: [
        `You attempted to return a result, but I have determined the result is possibly incorrect due to failing assertions.`,
        `<failures>`,
        ...failures,
        `</failures>`,
      ].join("\n"),
      type: "human",
    },
    params: {
      clusterId,
      runId,
    },
  });

  return {
    assertionsPassed: false,
  };
}
Loading