Skip to content

Commit

Permalink
feat: Use new Template upsert endpoint (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnjcsmith authored Oct 27, 2024
1 parent b8ab4eb commit 9cbc1f1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 89 deletions.
40 changes: 16 additions & 24 deletions sdk-node/src/Inferable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ type TemplateRunInput = Omit<RunInput, "template" | "message" | "id"> & {
input: Record<string, unknown>;
};

type UpsertTemplateInput = Required<
Parameters<ReturnType<typeof createApiClient>["upsertPromptTemplate"]>[0]
>["body"] & { id: string, structuredOutput: z.ZodTypeAny };

/**
* The Inferable client. This is the main entry point for using Inferable.
*
Expand Down Expand Up @@ -244,44 +248,32 @@ export class Inferable {
* await template.run({ input: { name: "Jane Doe" } });
* ```
*/
public async template({
id,
attachedFunctions,
name,
prompt,
structuredOutput,
}: {
id: string;
attachedFunctions: string[];
name: string;
prompt: string;
structuredOutput: z.ZodTypeAny;
}) {
public async template(input: UpsertTemplateInput) {
if (!this.clusterId) {
throw new InferableError(
"Cluster ID must be provided to manage templates",
);
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
let jsonSchema: any;
let jsonSchema: any = undefined;

try {
jsonSchema = zodToJsonSchema(structuredOutput);
} catch (e) {
throw new InferableError("structuredOutput must be a valid JSON schema");
if (!!input.structuredOutput) {
try {
jsonSchema = zodToJsonSchema(input.structuredOutput);
} catch (e) {
throw new InferableError("structuredOutput must be a valid JSON schema");
}
}

const upserted = await this.client.upsertPromptTemplate({
body: {
id,
attachedFunctions,
name,
prompt,
...input,
structuredOutput: jsonSchema,
},
params: {
clusterId: this.clusterId,
templateId: input.id,
},
});

Expand All @@ -293,11 +285,11 @@ export class Inferable {
}

return {
id,
id: input.id,
run: (input: TemplateRunInput) =>
this.run({
...input,
template: { id, input: input.input },
template: { id: upserted.body.id, input: input.input },
}),
};
}
Expand Down
96 changes: 31 additions & 65 deletions sdk-node/src/contract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const machineHeaders = {
};

// Alphanumeric, underscore, hyphen, no whitespace. From 6 to 128 characters.
const userDefinedIdRegex = /^[a-zA-Z0-9_-]{6,128}$/;
const userDefinedIdRegex = /^[a-zA-Z0-9-]{6,128}$/;

export const blobSchema = z.object({
id: z.string(),
Expand Down Expand Up @@ -860,7 +860,6 @@ export const definition = {
}),
body: z.object({
name: z.string(),
type: z.enum(["cluster_manage", "cluster_consume", "cluster_machine"]),
}),
responses: {
200: z.object({
Expand All @@ -881,11 +880,6 @@ export const definition = {
z.object({
id: z.string(),
name: z.string(),
type: z.enum([
"cluster_manage",
"cluster_consume",
"cluster_machine",
]),
createdAt: z.date(),
createdBy: z.string(),
revokedAt: z.date().nullable(),
Expand Down Expand Up @@ -1071,16 +1065,16 @@ export const definition = {
},
upsertToolMetadata: {
method: "PUT",
path: "/clusters/:clusterId/tool-metadata",
path: "/clusters/:clusterId/tool-metadata/:service/:function_name",
headers: z.object({
authorization: z.string(),
}),
pathParams: z.object({
clusterId: z.string(),
}),
body: z.object({
service: z.string(),
function_name: z.string(),
}),
body: z.object({
user_defined_context: z.string().nullable(),
result_schema: z.unknown().nullable(),
}),
Expand Down Expand Up @@ -1169,52 +1163,6 @@ export const definition = {
}),
},
},
createPromptTemplate: {
method: "POST",
path: "/clusters/:clusterId/prompt-templates",
headers: z.object({ authorization: z.string() }),
body: z.object({
name: z.string(),
prompt: z.string(),
attachedFunctions: z.array(z.string()),
structuredOutput: z.object({}).passthrough().optional(),
}),
responses: {
201: z.object({ id: z.string() }),
401: z.undefined(),
},
pathParams: z.object({
clusterId: z.string(),
}),
},
upsertPromptTemplate: {
method: "PUT",
path: "/clusters/:clusterId/prompt-templates",
headers: z.object({ authorization: z.string() }),
body: z.object({
id: z.string().regex(userDefinedIdRegex),
name: z.string(),
prompt: z.string(),
attachedFunctions: z.array(z.string()),
structuredOutput: z.object({}).passthrough().optional(),
}),
responses: {
201: z.object({
id: z.string(),
clusterId: z.string(),
name: z.string(),
prompt: z.string(),
attachedFunctions: z.array(z.string()),
structuredOutput: z.unknown().nullable(),
createdAt: z.date(),
updatedAt: z.date(),
}),
401: z.undefined(),
},
pathParams: z.object({
clusterId: z.string(),
}),
},
getPromptTemplate: {
method: "GET",
path: "/clusters/:clusterId/prompt-templates/:templateId",
Expand Down Expand Up @@ -1250,15 +1198,37 @@ export const definition = {
withPreviousVersions: z.enum(["true", "false"]).default("false"),
}),
},
updatePromptTemplate: {
createPromptTemplate: {
method: "POST",
path: "/clusters/:clusterId/prompt-templates",
headers: z.object({ authorization: z.string() }),
body: z.object({
name: z.string(),
prompt: z.string(),
attachedFunctions: z.array(z.string()),
structuredOutput: z.object({}).passthrough().optional(),
}),
responses: {
201: z.object({ id: z.string() }),
401: z.undefined(),
},
pathParams: z.object({
clusterId: z.string(),
}),
},
upsertPromptTemplate: {
method: "PUT",
path: "/clusters/:clusterId/prompt-templates/:templateId",
headers: z.object({ authorization: z.string() }),
pathParams: z.object({
clusterId: z.string(),
templateId: z.string().regex(userDefinedIdRegex),
}),
body: z.object({
name: z.string().optional(),
prompt: z.string().optional(),
attachedFunctions: z.array(z.string()).optional(),
structuredOutput: z.object({}).passthrough().optional(),
structuredOutput: z.object({}).passthrough().optional().nullable(),
}),
responses: {
200: z.object({
Expand All @@ -1274,10 +1244,6 @@ export const definition = {
401: z.undefined(),
404: z.object({ message: z.string() }),
},
pathParams: z.object({
clusterId: z.string(),
templateId: z.string(),
}),
},
deletePromptTemplate: {
method: "DELETE",
Expand Down Expand Up @@ -1344,7 +1310,7 @@ export const definition = {
clusterId: z.string(),
}),
},
updateClusterContext: {
upsertClusterContext: {
method: "PUT",
path: "/clusters/:clusterId/additional-context",
headers: z.object({ authorization: z.string() }),
Expand Down Expand Up @@ -1427,10 +1393,9 @@ export const definition = {
},
upsertKnowledgeArtifact: {
method: "PUT",
path: "/clusters/:clusterId/knowledge",
path: "/clusters/:clusterId/knowledge/:artifactId",
headers: z.object({ authorization: z.string() }),
body: z.object({
id: z.string(),
data: z.string(),
tags: z.array(z.string()),
title: z.string(),
Expand All @@ -1443,6 +1408,7 @@ export const definition = {
},
pathParams: z.object({
clusterId: z.string(),
artifactId: z.string().regex(userDefinedIdRegex),
}),
},
deleteKnowledgeArtifact: {
Expand Down

0 comments on commit 9cbc1f1

Please sign in to comment.