From a8fc02f5c40c7eb48da955857f996ea25b367d39 Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Wed, 13 Nov 2024 09:17:51 +1100 Subject: [PATCH] feat: Implement direct API call with structured output --- sdk-node/src/Inferable.test.ts | 29 +++++++++++++++++++++++++++++ sdk-node/src/Inferable.ts | 12 ++++++------ sdk-node/src/contract.ts | 26 ++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 6 deletions(-) diff --git a/sdk-node/src/Inferable.test.ts b/sdk-node/src/Inferable.test.ts index bc3c40e7..c706ffb8 100644 --- a/sdk-node/src/Inferable.test.ts +++ b/sdk-node/src/Inferable.test.ts @@ -289,4 +289,33 @@ describe("Inferable SDK End to End Test", () => { await client.default.stop(); } }); + + describe("api", () => { + it("should be able to call the api directly", async () => { + const client = inferableInstance(); + + expect( + await client.api.createStructuredOutput({ + params: { + clusterId: TEST_CLUSTER_ID, + }, + body: { + prompt: "What is the capital of France?", + modelId: "claude-3-5-sonnet", + resultSchema: { + type: "object", + properties: { + capital: { type: "string" }, + }, + }, + }, + }), + ).toEqual({ + success: true, + result: { + capital: "Paris", + }, + }); + }); + }); }); diff --git a/sdk-node/src/Inferable.ts b/sdk-node/src/Inferable.ts index 3fb4f7b3..55f637b0 100644 --- a/sdk-node/src/Inferable.ts +++ b/sdk-node/src/Inferable.ts @@ -110,10 +110,7 @@ export class Inferable { * const client = new Inferable(); * ``` */ - constructor(options?: { - apiSecret?: string; - endpoint?: string; - }) { + constructor(options?: { apiSecret?: string; endpoint?: string }) { if (options?.apiSecret && process.env.INFERABLE_API_SECRET) { log( "API Secret was provided as an option and environment variable. Constructor argument will be used.", @@ -217,7 +214,6 @@ export class Inferable { * ``` */ public async run(input: RunInput) { - let resultSchema: JsonSchemaInput | undefined; if (!!input.resultSchema && isZodType(input.resultSchema)) { @@ -494,6 +490,10 @@ export class Inferable { this.functionRegistry[registration.name] = registration; } + public get api() { + return this.client; + } + private async getClusterId() { if (!this.clusterId) { // Call register machine without any services to test API key and get clusterId @@ -501,6 +501,6 @@ export class Inferable { this.clusterId = registerResult.clusterId; } - return this.clusterId! + return this.clusterId!; } } diff --git a/sdk-node/src/contract.ts b/sdk-node/src/contract.ts index ef49693d..1427f7ae 100644 --- a/sdk-node/src/contract.ts +++ b/sdk-node/src/contract.ts @@ -1424,6 +1424,32 @@ export const definition = { artifactId: z.string(), }), }, + createStructuredOutput: { + method: "POST", + path: "/clusters/:clusterId/structured-output", + headers: z.object({ authorization: z.string() }), + body: z.object({ + prompt: z.string(), + resultSchema: anyObject + .optional() + .describe( + "A JSON schema definition which the result object should conform to. By default the result will be a JSON object which does not conform to any schema", + ), + modelId: z.enum(["claude-3-5-sonnet", "claude-3-haiku"]), + temperature: z + .number() + .optional() + .describe("The temperature to use for the model") + .default(0.5), + }), + responses: { + 200: z.unknown(), + 401: z.undefined(), + }, + pathParams: z.object({ + clusterId: z.string(), + }), + }, } as const; export const contract = c.router(definition);