From d5401525e113cfa3226f42896bc2c2624bc2fd11 Mon Sep 17 00:00:00 2001 From: John Smith Date: Wed, 30 Oct 2024 21:23:55 +1030 Subject: [PATCH] chore(node): Use latest createRun semantics (#13) --- sdk-node/README.md | 15 ++- sdk-node/src/Inferable.ts | 121 +++++----------------- sdk-node/src/contract.ts | 211 ++++++++++---------------------------- sdk-node/src/index.ts | 2 +- sdk-node/src/types.ts | 8 +- 5 files changed, 96 insertions(+), 261 deletions(-) diff --git a/sdk-node/README.md b/sdk-node/README.md index c5e2d885..5e860d1e 100644 --- a/sdk-node/README.md +++ b/sdk-node/README.md @@ -98,15 +98,24 @@ The following code will create an [Inferable run](https://docs.inferable.ai/page ```typescript const run = await i.run({ message: "Say hello to John", - functions: [sayHello], - // Alternatively, subscribe an Inferable function as a result handler which will be called when the run is complete. - //result: { handler: YOUR_HANDLER_FUNCTION } + // Optional: Explicitly attach the `sayHello` function (All functions attached by default) + attachedFunctions: [{ + function: "sayHello", + service: "default", + }], + // Optional: Define a schema for the result to conform to + resultSchema: z.object({ + didSayHello: z.boolean() + }), + // Optional: Subscribe an Inferable function to receive notifications when the run status changes + //onStatusChange: { function: { function: "handler", service: "default" } }, }); console.log("Started Run", { result: run.id, }); +// Wait for the run to complete and log. console.log("Run result", { result: await run.poll(), }); diff --git a/sdk-node/src/Inferable.ts b/sdk-node/src/Inferable.ts index 0759370c..d3b28eb1 100644 --- a/sdk-node/src/Inferable.ts +++ b/sdk-node/src/Inferable.ts @@ -30,29 +30,18 @@ debug.formatters.J = (json) => { export const log = debug("inferable:client"); -type FunctionIdentifier = { - service: string; - function: string; - event?: "result"; -}; -type RunInput = { - functions?: FunctionIdentifier[] | undefined; -} & Omit< - Required< +type RunInput = Omit["createRun"]>[0] - >["body"], - "attachedFunctions" -> & { id?: string }; + >["body"], "resultSchema"> & { + id?: string, + resultSchema?: z.ZodType | JsonSchemaInput +}; type TemplateRunInput = Omit & { input: Record; }; -type UpsertTemplateInput = Required< - Parameters["upsertPromptTemplate"]>[0] - >["body"] & { id: string, structuredOutput: z.ZodTypeAny }; - /** * The Inferable client. This is the main entry point for using Inferable. * @@ -140,7 +129,6 @@ export class Inferable { apiSecret?: string; endpoint?: string; clusterId?: string; - jobPollWaitTime?: number; }) { if (options?.apiSecret && process.env.INFERABLE_API_SECRET) { log( @@ -229,93 +217,28 @@ export class Inferable { }); } - /** - * Registers or references a template instance. This can be used to trigger runs of a template. - * @param input The template definition or reference. - * @returns A registered template instance. - * @example - * ```ts - * const d = new Inferable({apiSecret: "API_SECRET"}); - * - * const template = await d.template({ - * id: "new-template-id", - * name: "my-template", - * attachedFunctions: ["my-service.hello"], - * prompt: "Hello {{name}}", - * structuredOutput: { greeting: z.string() } - * }); - * - * await template.run({ input: { name: "Jane Doe" } }); - * ``` - */ - public async template(input: UpsertTemplateInput) { - if (!this.clusterId) { - throw new InferableError( - "Cluster ID must be provided to manage templates", - ); - } - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let jsonSchema: any = undefined; - - if (!!input.structuredOutput) { - try { - jsonSchema = zodToJsonSchema(input.structuredOutput); - } catch (e) { - throw new InferableError("structuredOutput must be a valid JSON schema"); - } - } - - const upserted = await this.client.upsertPromptTemplate({ - body: { - ...input, - structuredOutput: jsonSchema, - }, - params: { - clusterId: this.clusterId, - templateId: input.id, - }, - }); - - if (upserted.status != 200) { - throw new InferableError(`Failed to register prompt template`, { - body: upserted.body, - status: upserted.status, - }); - } - - return { - id: input.id, - run: (input: TemplateRunInput) => - this.run({ - ...input, - template: { id: upserted.body.id, input: input.input }, - }), - }; - } - /** * Creates a template reference. This can be used to trigger runs of a template that was previously registered via the UI. * @param id The ID of the template to reference. * @returns A referenced template instance. */ - public async templateReference(id: string) { + public async template(template: { id: string }) { return { - id, + id: template.id, run: (input: TemplateRunInput) => - this.run({ ...input, template: { id, input: input.input } }), + this.run({ ...input, template: { id: template.id, input: input.input } }), }; } /** - * Creates a run. + * Creates a run (or retrieves an existing run if an ID is provided) and returns a reference to it. * @param input The run definition. * @returns A run handle. * @example * ```ts * const d = new Inferable({apiSecret: "API_SECRET"}); * - * const run = await d.run({ message: "Hello world", functions: ["my-service.hello"] }); + * const run = await d.run({ message: "Hello world" }); * * console.log("Started run with ID:", run.id); * @@ -329,6 +252,14 @@ export class Inferable { throw new InferableError("Cluster ID must be provided to manage runs"); } + let resultSchema: JsonSchemaInput | undefined; + + if (!!input.resultSchema && isZodType(input.resultSchema)) { + resultSchema = zodToJsonSchema(input.resultSchema) as JsonSchemaInput; + } else { + resultSchema = input.resultSchema; + } + let runResult; if (input.id) { runResult = await this.client.getRun({ @@ -351,15 +282,11 @@ export class Inferable { }, body: { ...input, - attachedFunctions: input.functions?.map((f) => { - if (typeof f === "string") { - return f; - } - return `${f.service}_${f.function}`; - }), + resultSchema, }, }); + if (runResult.status != 201) { throw new InferableError("Failed to create run", { body: runResult.body, @@ -373,11 +300,11 @@ export class Inferable { /** * Polls until the run reaches a terminal state (!= "pending" && != "running") or maxWaitTime is reached. * @param maxWaitTime The maximum amount of time to wait for the run to reach a terminal state. Defaults to 60 seconds. - * @param delay The amount of time to wait between polling attempts. Defaults to 500ms. + * @param interval The amount of time to wait between polling attempts. Defaults to 500ms. */ - poll: async (maxWaitTime?: number, delay?: number) => { + poll: async (options: { maxWaitTime?: number, interval?: number }) => { const start = Date.now(); - const end = start + (maxWaitTime || 60_000); + const end = start + (options.maxWaitTime || 60_000); while (Date.now() < end) { const pollResult = await this.client.getRun({ @@ -395,7 +322,7 @@ export class Inferable { } if (["pending", "running"].includes(pollResult.body.status ?? "")) { await new Promise((resolve) => { - setTimeout(resolve, delay || 500); + setTimeout(resolve, options.interval || 500); }); continue; } diff --git a/sdk-node/src/contract.ts b/sdk-node/src/contract.ts index 958a4f12..8ba02a34 100644 --- a/sdk-node/src/contract.ts +++ b/sdk-node/src/contract.ts @@ -15,6 +15,13 @@ const machineHeaders = { // Alphanumeric, underscore, hyphen, no whitespace. From 6 to 128 characters. const userDefinedIdRegex = /^[a-zA-Z0-9-]{6,128}$/; +const functionReference = z.object({ + service: z.string(), + function: z.string(), +}); + +const anyObject = z.object({}).passthrough(); + export const blobSchema = z.object({ id: z.string(), name: z.string(), @@ -79,7 +86,7 @@ export const learningSchema = z.object({ export const agentDataSchema = z .object({ done: z.boolean().optional(), - result: z.any().optional(), + result: anyObject.optional(), summary: z.string().optional(), learnings: z.array(learningSchema).optional(), issue: z.string().optional(), @@ -111,7 +118,6 @@ export const FunctionConfigSchema = z.object({ .optional(), retryCountOnStall: z.number().optional(), timeoutSeconds: z.number().optional(), - executionIdPath: z.string().optional(), requiresApproval: z.boolean().default(false).optional(), private: z.boolean().default(false).optional(), }); @@ -249,27 +255,11 @@ export const definition = { 401: z.undefined(), }, body: z.object({ - name: z.string(), - description: z.string(), - additionalContext: z - .object({ - current: z - .object({ - version: z.string(), - content: z.string(), - }) - .describe("Current cluster context version"), - history: z - .array( - z.object({ - version: z.string(), - content: z.string(), - }), - ) - .describe("History of the cluster context versions"), - }) - .optional() - .describe("Additional cluster context which is included in all runs"), + name: z.string().optional(), + description: z.string().optional(), + additionalContext: VersionedTextsSchema.optional().describe( + "Additional cluster context which is included in all runs", + ), debug: z .boolean() .optional() @@ -289,6 +279,7 @@ export const definition = { id: z.string(), name: z.string(), description: z.string().nullable(), + additionalContext: VersionedTextsSchema.nullable(), createdAt: z.date(), debug: z.boolean(), lastPingAt: z.date().nullable(), @@ -300,56 +291,6 @@ export const definition = { clusterId: z.string(), }), }, - getService: { - method: "GET", - path: "/clusters/:clusterId/service/:serviceName", - headers: z.object({ - authorization: z.string(), - }), - responses: { - 200: z.object({ - jobs: z.array( - z.object({ - id: z.string(), - targetFn: z.string(), - service: z.string().nullable(), - status: z.string(), - resultType: z.string().nullable(), - createdAt: z.date(), - functionExecutionTime: z.number().nullable(), - }), - ), - definition: z - .object({ - name: z.string(), - functions: z - .array( - z.object({ - name: z.string(), - rate: z - .object({ - per: z.enum(["minute", "hour"]), - limit: z.number(), - }) - .optional(), - cacheTTL: z.number().optional(), - }), - ) - .optional(), - }) - .nullable(), - }), - 401: z.undefined(), - 404: z.undefined(), - }, - pathParams: z.object({ - clusterId: z.string(), - serviceName: z.string(), - }), - query: z.object({ - limit: z.coerce.number().min(100).max(5000).default(2000), - }), - }, listEvents: { method: "GET", path: "/clusters/:clusterId/events", @@ -430,6 +371,7 @@ export const definition = { .describe("The model identifier for the run"), result: z .object({ + // TODO: Remove in favour of onStatusChange handler: z .object({ service: z.string(), @@ -439,6 +381,7 @@ export const definition = { .describe( "The Inferable function which will be used to return result for the run", ), + //TODO: Remove in favour of resultSchema schema: z .object({}) .passthrough() @@ -446,12 +389,26 @@ export const definition = { .describe("The JSON schema which the result should conform to"), }) .optional(), - // TODO: Replace with `functions` + resultSchema: anyObject + .optional() + .describe( + "A JSON schema definition which the result object should conform to. By default the result will be a JSON object which does not conform to any schema", + ), attachedFunctions: z - .array(z.string()) + .array(functionReference) .optional() .describe( - "An array of attached functions (Keys should be service in the format _)", + "An array of functions to make available to the run. By default all functions in the cluster will be available", + ), + onStatusChange: z + .object({ + function: functionReference.describe( + "A function to call when the run status changes", + ), + }) + .optional() + .describe( + "Mechanism for receiving notifications when the run status changes", ), metadata: z .record(z.string()) @@ -655,7 +612,7 @@ export const definition = { test: z.boolean(), feedbackComment: z.string().nullable(), feedbackScore: z.number().nullable(), - result: z.string().nullable(), + result: anyObject.nullable(), summary: z.string().nullable(), metadata: z.record(z.string()).nullable(), attachedFunctions: z.array(z.string()).nullable(), @@ -900,23 +857,7 @@ export const definition = { 204: z.undefined(), }, }, - getClusterContext: { - method: "GET", - path: "/clusters/:clusterId/additional-context", - headers: z.object({ - authorization: z.string(), - }), - responses: { - 200: z.object({ - additionalContext: VersionedTextsSchema.nullable(), - }), - 401: z.undefined(), - 404: z.undefined(), - }, - pathParams: z.object({ - clusterId: z.string(), - }), - }, + listMachines: { method: "GET", path: "/clusters/:clusterId/machines", @@ -936,6 +877,7 @@ export const definition = { clusterId: z.string(), }), }, + listServices: { method: "GET", path: "/clusters/:clusterId/services", @@ -1063,75 +1005,47 @@ export const definition = { 404: z.undefined(), }, }, - upsertToolMetadata: { + upsertFunctionMetadata: { method: "PUT", - path: "/clusters/:clusterId/tool-metadata/:service/:function_name", + path: "/clusters/:clusterId/:service/:function/metadata", headers: z.object({ authorization: z.string(), }), pathParams: z.object({ clusterId: z.string(), service: z.string(), - function_name: z.string(), + function: z.string(), }), body: z.object({ - user_defined_context: z.string().nullable(), - result_schema: z.unknown().nullable(), + additionalContext: z.string().optional(), }), responses: { 204: z.undefined(), 401: z.undefined(), }, }, - getToolMetadata: { + getFunctionMetadata: { method: "GET", - path: "/clusters/:clusterId/tool-metadata/:service/:function_name", + path: "/clusters/:clusterId/:service/:function/metadata", headers: z.object({ authorization: z.string(), }), pathParams: z.object({ clusterId: z.string(), service: z.string(), - function_name: z.string(), + function: z.string(), }), responses: { 200: z.object({ - cluster_id: z.string(), - service: z.string(), - function_name: z.string(), - user_defined_context: z.string().nullable(), - result_schema: z.unknown().nullable(), + additionalContext: z.string().nullable(), }), 401: z.undefined(), 404: z.object({ message: z.string() }), }, }, - getAllToolMetadataForService: { - method: "GET", - path: "/clusters/:clusterId/tool-metadata/:service", - headers: z.object({ - authorization: z.string(), - }), - pathParams: z.object({ - clusterId: z.string(), - service: z.string(), - }), - responses: { - 200: z.array( - z.object({ - cluster_id: z.string(), - service: z.string(), - function_name: z.string(), - user_defined_context: z.string().nullable(), - result_schema: z.unknown().nullable(), - }), - ), - 401: z.undefined(), - }, - }, - deleteToolMetadata: { + deleteFunctionMetadata: { method: "DELETE", - path: "/clusters/:clusterId/tool-metadata/:service/:function_name", + path: "/clusters/:clusterId/:service/:function/metadata", headers: z.object({ authorization: z.string(), }), @@ -1159,7 +1073,7 @@ export const definition = { name: z.string(), prompt: z.string(), attachedFunctions: z.array(z.string()), - structuredOutput: z.unknown().nullable(), + resultSchema: anyObject.nullable().optional(), }), }, }, @@ -1174,7 +1088,7 @@ export const definition = { name: z.string(), prompt: z.string(), attachedFunctions: z.array(z.string()), - structuredOutput: z.unknown().nullable(), + resultSchema: anyObject.nullable(), createdAt: z.date(), updatedAt: z.date(), versions: z.array( @@ -1183,7 +1097,7 @@ export const definition = { name: z.string(), prompt: z.string(), attachedFunctions: z.array(z.string()), - structuredOutput: z.unknown().nullable(), + resultSchema: anyObject.nullable(), }), ), }), @@ -1205,8 +1119,8 @@ export const definition = { body: z.object({ name: z.string(), prompt: z.string(), - attachedFunctions: z.array(z.string()), - structuredOutput: z.object({}).passthrough().optional(), + attachedFunctions: z.array(z.string()).optional(), + resultSchema: anyObject.optional(), }), responses: { 201: z.object({ id: z.string() }), @@ -1228,7 +1142,7 @@ export const definition = { name: z.string().optional(), prompt: z.string().optional(), attachedFunctions: z.array(z.string()).optional(), - structuredOutput: z.object({}).passthrough().optional().nullable(), + resultSchema: z.object({}).passthrough().optional().nullable(), }), responses: { 200: z.object({ @@ -1237,7 +1151,7 @@ export const definition = { name: z.string(), prompt: z.string(), attachedFunctions: z.array(z.string()), - structuredOutput: z.unknown().nullable(), + resultSchema: z.unknown().nullable(), createdAt: z.date(), updatedAt: z.date(), }), @@ -1272,7 +1186,7 @@ export const definition = { name: z.string(), prompt: z.string(), attachedFunctions: z.array(z.string()), - structuredOutput: z.unknown().nullable(), + resultSchema: z.unknown().nullable(), createdAt: z.date(), updatedAt: z.date(), }), @@ -1298,7 +1212,7 @@ export const definition = { name: z.string(), prompt: z.string(), attachedFunctions: z.array(z.string()), - structuredOutput: z.unknown().nullable(), + resultSchema: z.unknown().nullable(), createdAt: z.date(), updatedAt: z.date(), similarity: z.number(), @@ -1310,21 +1224,6 @@ export const definition = { clusterId: z.string(), }), }, - upsertClusterContext: { - method: "PUT", - path: "/clusters/:clusterId/additional-context", - headers: z.object({ authorization: z.string() }), - body: z.object({ - additionalContext: VersionedTextsSchema, - }), - responses: { - 204: z.undefined(), - 401: z.undefined(), - }, - pathParams: z.object({ - clusterId: z.string(), - }), - }, getTemplateMetrics: { method: "GET", path: "/clusters/:clusterId/prompt-templates/:templateId/metrics", diff --git a/sdk-node/src/index.ts b/sdk-node/src/index.ts index c242c035..b8f0c7e1 100644 --- a/sdk-node/src/index.ts +++ b/sdk-node/src/index.ts @@ -22,7 +22,7 @@ export const masked = () => { }; export * as InferablePromptfooProvider from "./eval/promptfoo"; -export { resultHandlerSchema } from "./types"; +export { statusChangeSchema } from "./types"; export { validateDescription, diff --git a/sdk-node/src/types.ts b/sdk-node/src/types.ts index 096ea413..d9ab6bab 100644 --- a/sdk-node/src/types.ts +++ b/sdk-node/src/types.ts @@ -11,13 +11,13 @@ export type FunctionInput = : // eslint-disable-next-line @typescript-eslint/no-explicit-any any; -export const resultHandlerSchema = { +export const statusChangeSchema = { input: z.object({ runId: z.string(), status: z.enum(["pending", "running", "paused", "done", "failed"]), - result: z.object({}).passthrough().nullable(), - summary: z.string().nullable(), - metadata: z.record(z.string()).nullable(), + result: z.object({}).passthrough().nullable().optional(), + summary: z.string().nullable().optional(), + metadata: z.record(z.string()).nullable().optional(), }), };