diff --git a/README.md b/README.md index b8314d60..c5e2d885 100644 --- a/README.md +++ b/README.md @@ -49,9 +49,9 @@ export const d = new Inferable({ }); ``` -### 2. Hello World Service +### 2. Hello World Function -In a separate file, create the "Hello World" service. This file will import the Inferable instance from i.ts and define the service. +In a separate file, register a "sayHello" [function](https://docs.inferable.ai/pages/functions). This file will import the Inferable instance from `i.ts` and register the [function](https://docs.inferable.ai/pages/functions) with the [control-plane](https://docs.inferable.ai/pages/control-plane). ```typescript // service.ts @@ -63,12 +63,8 @@ const sayHello = async ({ to }: { to: string }) => { return `Hello, ${to}!`; }; -// Create the service -export const helloWorldService = d.service({ - name: "helloWorld", -}); - -helloWorldService.register({ +// Register the function (using the 'default' service) +const sayHelloFunction = i.default.register({ name: "sayHello", func: sayHello, schema: { @@ -77,16 +73,47 @@ helloWorldService.register({ }), }, }); + +// Start the 'default' service +i.default.start(); ``` ### 3. Running the Service -To run the service, simply run the file with the service definition. This will start the service and make it available to your Inferable agent. +To run the service, simply run the file with the [function](https://docs.inferable.ai/pages/functions) definition. This will start the `default` [service](https://docs.inferable.ai/pages/services) and make it available to the Inferable agent. ```bash tsx service.ts ``` +### 4. Trigger a run + +The following code will create an [Inferable run](https://docs.inferable.ai/pages/runs) with the prompt "Say hello to John" and the `sayHello` function attached. 
+ +> You can inspect the progress of the run: +> +> - in the [playground UI](https://app.inferable.ai/) via `inf app` +> - in the [CLI](https://www.npmjs.com/package/@inferable/cli) via `inf runs list` + +```typescript +const run = await i.run({ + message: "Say hello to John", + functions: [sayHelloFunction], + // Alternatively, subscribe an Inferable function as a result handler which will be called when the run is complete. + //result: { handler: YOUR_HANDLER_FUNCTION } +}); + +console.log("Started Run", { + id: run.id, +}); + +console.log("Run result", { + result: await run.poll(), +}); +``` + +> Runs can also be triggered via the [API](https://docs.inferable.ai/pages/invoking-a-run-api), [CLI](https://www.npmjs.com/package/@inferable/cli) or [playground UI](https://app.inferable.ai/). + ## Documentation - [Inferable documentation](https://docs.inferable.ai/) contains all the information you need to get started with Inferable. diff --git a/src/Inferable.test.ts b/src/Inferable.test.ts index a35a4f9f..2dcdd250 100644 --- a/src/Inferable.test.ts +++ b/src/Inferable.test.ts @@ -74,10 +74,6 @@ describe("Inferable", () => { expect(() => new Inferable({ apiSecret: "invalid" })).toThrow(); }); - it("should throw if incorrect API secret is provided", () => { - expect(() => new Inferable({ apiSecret: TEST_CONSUME_SECRET })).toThrow(); - }); - it("should register a function", async () => { const d = inferableInstance(); diff --git a/src/Inferable.ts b/src/Inferable.ts index f4f2d72d..fe6c2509 100644 --- a/src/Inferable.ts +++ b/src/Inferable.ts @@ -21,6 +21,7 @@ import { validateServiceName, } from "./util"; import * as links from "./links"; +import { createApiClient } from "./create-client"; // Custom json formatter debug.formatters.J = (json) => { @@ -29,6 +30,25 @@ debug.formatters.J = (json) => { export const log = debug("inferable:client"); +type FunctionIdentifier = { + service: string; + function: string; + event?: "result"; +}; + +type RunInput = { + functions?: 
FunctionIdentifier[] | undefined; +} & Omit< + Required< + Parameters["createRun"]>[0] + >["body"], + "attachedFunctions" +>; + +type TemplateRunInput = Omit & { + input: Record; +}; + /** * The Inferable client. This is the main entry point for using Inferable. * @@ -72,10 +92,14 @@ export class Inferable { return require(path.join(__dirname, "..", "package.json")).version; } + private clusterId?: string; + private apiSecret: string; private endpoint: string; private machineId: string; + private client: ReturnType; + private services: Service[] = []; private functionRegistry: { [key: string]: FunctionRegistration } = {}; @@ -111,6 +135,7 @@ export class Inferable { constructor(options?: { apiSecret?: string; endpoint?: string; + clusterId?: string; jobPollWaitTime?: number; }) { if (options?.apiSecret && process.env.INFERABLE_API_SECRET) { @@ -119,6 +144,8 @@ export class Inferable { ); } + this.clusterId = options?.clusterId || process.env.INFERABLE_CLUSTER_ID; + const apiSecret = options?.apiSecret || process.env.INFERABLE_API_SECRET; if (!apiSecret) { @@ -127,13 +154,7 @@ export class Inferable { ); } - if (!apiSecret.startsWith("sk_cluster_machine")) { - if (apiSecret.startsWith("sk_")) { - throw new InferableError( - `Provided non-Machine API Secret. Please see ${links.DOCS_AUTH}`, - ); - } - + if (!apiSecret.startsWith("sk_cluster_")) { throw new InferableError( `Invalid API Secret. Please see ${links.DOCS_AUTH}`, ); @@ -146,6 +167,12 @@ export class Inferable { process.env.INFERABLE_API_ENDPOINT || "https://api.inferable.ai"; this.machineId = machineId(); + + this.client = createApiClient({ + baseUrl: this.endpoint, + machineId: this.machineId, + apiSecret: this.apiSecret, + }); } /** @@ -198,6 +225,130 @@ export class Inferable { }); } + /** + * Returns a template instance. This can be used to trigger runs of a template. + * @param input The template definition. + * @returns A registered template instance. 
+ * @example + * ```ts + * const d = new Inferable({ apiSecret: "API_SECRET", clusterId: "CLUSTER_ID" }); + * + * const template = await d.template({ id: "template-id" }); + * + * await template.run({ input: { name: "John Smith" } }); + * ``` + */ + public async template({ id }: { id: string }) { + if (!this.clusterId) { + throw new InferableError( + "Cluster ID must be provided to manage templates", + ); + } + const existingResult = await this.client.getPromptTemplate({ + params: { + clusterId: this.clusterId, + templateId: id, + }, + }); + + if (existingResult.status != 200) { + throw new InferableError(`Failed to get prompt template`, { + body: existingResult.body, + status: existingResult.status, + }); + } + + return { + id, + run: (input: TemplateRunInput) => + this.run({ + ...input, + template: { id, input: input.input }, + }), + }; + } + + /** + * Creates a run. + * @param input The run definition. + * @returns A run handle. + * @example + * ```ts + * const d = new Inferable({ apiSecret: "API_SECRET", clusterId: "CLUSTER_ID" }); + * + * const run = await d.run({ message: "Hello world", functions: [{ service: "my-service", function: "hello" }] }); + * + * console.log("Started run with ID:", run.id); + * + * const result = await run.poll(); + * + * console.log("Run result:", result); + * ``` + */ + public async run(input: RunInput) { + if (!this.clusterId) { + throw new InferableError("Cluster ID must be provided to manage runs"); + } + const runResult = await this.client.createRun({ + params: { + clusterId: this.clusterId, + }, + body: { + ...input, + attachedFunctions: input.functions?.map((f) => { + if (typeof f === "string") { + return f; + } + return `${f.service}_${f.function}`; + }), + }, + }); + + if (runResult.status != 201) { + throw new InferableError("Failed to create run", { + body: runResult.body, + status: runResult.status, + }); + } + + return { + id: runResult.body.id, + /** + * Polls until the run reaches a terminal state (!= "pending" && != "running") or maxWaitTime is reached. 
+ * @param maxWaitTime The maximum time in milliseconds to wait for the run to reach a terminal state (defaults to 60_000); resolves to undefined if it elapses first. + * @param delay The time in milliseconds to wait between polling attempts (defaults to 500). + */ + poll: async (maxWaitTime?: number, delay?: number) => { + const start = Date.now(); + const end = start + (maxWaitTime || 60_000); + + while (Date.now() < end) { + const pollResult = await this.client.getRun({ + params: { + clusterId: this.clusterId!, + runId: runResult.body.id, + }, + }); + + if (pollResult.status !== 200) { + throw new InferableError("Failed to poll for run", { + body: pollResult.body, + status: pollResult.status, + }); + } + if (["pending", "running"].includes(pollResult.body.status ?? "")) { + await new Promise((resolve) => { + setTimeout(resolve, delay || 500); + }); + continue; + } + + return pollResult.body; + } + }, + }; + } + /** * Registers a service with Inferable. This will register all functions on the service. * @param input The service definition. @@ -248,6 +399,11 @@ export class Inferable { config, description, }); + + return { + service: input.name, + function: name, + }; }; return { @@ -296,13 +452,6 @@ export class Inferable { }; } - /** - * The cluster ID for this Inferable instance. 
- */ - get clusterId(): string | null { - return this.services[0]?.clusterId || null; - } - private registerFunction({ name, authenticate, diff --git a/src/contract.ts b/src/contract.ts index cc631350..28ee1401 100644 --- a/src/contract.ts +++ b/src/contract.ts @@ -421,16 +421,10 @@ export const definition = { .string() .optional() .describe("The name of the run, if not provided it will be generated"), - reasoningTraces: z - .boolean() - .default(true) + model: z + .enum(["claude-3-5-sonnet", "claude-3-5-sonnet:beta", "claude-3-haiku"]) .optional() - .describe("Enable reasoning traces for the run"), - enableSummarization: z - .boolean() - .default(true) - .optional() - .describe("Bypass summarization for the run"), + .describe("The model identifier for the run"), result: z .object({ handler: z @@ -449,6 +443,7 @@ export const definition = { .describe("The JSON schema which the result should conform to"), }) .optional(), + // TODO: Replace with `functions` attachedFunctions: z .array(z.string()) .optional() @@ -492,6 +487,16 @@ export const definition = { }) .optional() .describe("A prompt template which the run should be created from"), + reasoningTraces: z + .boolean() + .default(true) + .optional() + .describe("Enable reasoning traces"), + callSummarization: z + .boolean() + .default(true) + .optional() + .describe("Enable summarization of oversized call results"), }), responses: { 201: z.object({ diff --git a/src/service.ts b/src/service.ts index d8ea5fa6..db85beb1 100644 --- a/src/service.ts +++ b/src/service.ts @@ -6,8 +6,6 @@ import { serializeError } from "./serialize-error"; import { executeFn, Result } from "./execute-fn"; import { FunctionRegistration } from "./types"; import { extractBlobs, validateFunctionArgs } from "./util"; -import { isZodType } from "@ts-rest/core"; -import ajv, { Ajv } from "ajv"; const MAX_CONSECUTIVE_POLL_FAILURES = 50; const DEFAULT_RETRY_AFTER_SECONDS = 10; diff --git a/src/types.ts b/src/types.ts index 0e923c8a..096ea413 
100644 --- a/src/types.ts +++ b/src/types.ts @@ -75,7 +75,7 @@ export type RegisteredService = { */ register: ( input: FunctionRegistrationInput, - ) => void; + ) => { service: string; function: string }; start: () => Promise; stop: () => Promise; };