Skip to content

Commit

Permalink
fix: Expose internal inferable client as client.api (#92)
Browse files Browse the repository at this point in the history
* feat: Implement direct API call with structured output

* fix: Remove non-null assertion from return statement in Inferable.ts

* update
  • Loading branch information
nadeesha authored Nov 13, 2024
1 parent 7dcc6f2 commit ca2d3b3
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 8 deletions.
3 changes: 2 additions & 1 deletion sdk-node/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion sdk-node/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"devDependencies": {
"@babel/preset-env": "^7.26.0",
"@babel/preset-typescript": "^7.22.11",
"@types/jest": "^29.5.4",
"@types/jest": "^29.5.14",
"@types/node-os-utils": "^1.3.4",
"@typescript-eslint/eslint-plugin": "^8.12.2",
"@typescript-eslint/parser": "^8.12.2",
Expand Down
26 changes: 26 additions & 0 deletions sdk-node/src/Inferable.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -289,4 +289,30 @@ describe("Inferable SDK End to End Test", () => {
await client.default.stop();
}
});

describe("api", () => {
it("should be able to call the api directly", async () => {
const client = inferableInstance();

const result = await client.api.createStructuredOutput<{
capital: string;
}>({
prompt: "What is the capital of France?",
modelId: "claude-3-5-sonnet",
resultSchema: {
type: "object",
properties: {
capital: { type: "string" },
},
},
});

expect(result).toMatchObject({
success: true,
data: {
capital: "Paris",
},
});
});
});
});
27 changes: 21 additions & 6 deletions sdk-node/src/Inferable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,7 @@ export class Inferable {
* const client = new Inferable();
* ```
*/
constructor(options?: {
apiSecret?: string;
endpoint?: string;
}) {
constructor(options?: { apiSecret?: string; endpoint?: string }) {
if (options?.apiSecret && process.env.INFERABLE_API_SECRET) {
log(
"API Secret was provided as an option and environment variable. Constructor argument will be used.",
Expand Down Expand Up @@ -217,7 +214,6 @@ export class Inferable {
* ```
*/
public async run(input: RunInput) {

let resultSchema: JsonSchemaInput | undefined;

if (!!input.resultSchema && isZodType(input.resultSchema)) {
Expand Down Expand Up @@ -494,13 +490,32 @@ export class Inferable {
this.functionRegistry[registration.name] = registration;
}

public get api() {
return {
createStructuredOutput: async <T>(
input: Parameters<typeof this.client.createStructuredOutput>[0]["body"],
) =>
this.client
.createStructuredOutput({
params: {
clusterId: await this.getClusterId(),
},
body: input,
})
.then((r) => r.body) as Promise<{
success: boolean;
data?: T;
}>,
};
}

private async getClusterId() {
if (!this.clusterId) {
// Call register machine without any services to test API key and get clusterId
const registerResult = await registerMachine(this.client);
this.clusterId = registerResult.clusterId;
}

return this.clusterId!
return this.clusterId;
}
}
26 changes: 26 additions & 0 deletions sdk-node/src/contract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1424,6 +1424,32 @@ export const definition = {
artifactId: z.string(),
}),
},
createStructuredOutput: {
method: "POST",
path: "/clusters/:clusterId/structured-output",
headers: z.object({ authorization: z.string() }),
body: z.object({
prompt: z.string(),
resultSchema: anyObject
.optional()
.describe(
"A JSON schema definition which the result object should conform to. By default the result will be a JSON object which does not conform to any schema",
),
modelId: z.enum(["claude-3-5-sonnet", "claude-3-haiku"]),
temperature: z
.number()
.optional()
.describe("The temperature to use for the model")
.default(0.5),
}),
responses: {
200: z.unknown(),
401: z.undefined(),
},
pathParams: z.object({
clusterId: z.string(),
}),
},
} as const;

export const contract = c.router(definition);

0 comments on commit ca2d3b3

Please sign in to comment.