From 30fa57021104eeff37d8b912be8f3a91fd132842 Mon Sep 17 00:00:00 2001 From: John Smith Date: Tue, 19 Nov 2024 10:31:04 +1030 Subject: [PATCH] feat: Initial function initiated approvals --- sdk-node/src/contract.ts | 70 ++++++++++++++++++++++++++++++++++++-- sdk-node/src/execute-fn.ts | 13 ++++++- sdk-node/src/index.ts | 3 ++ sdk-node/src/util.ts | 54 +++++++++++++++++++++++++++++ 4 files changed, 137 insertions(+), 3 deletions(-) diff --git a/sdk-node/src/contract.ts b/sdk-node/src/contract.ts index 8ea6b174..dc9cc39b 100644 --- a/sdk-node/src/contract.ts +++ b/sdk-node/src/contract.ts @@ -33,6 +33,19 @@ export const blobSchema = z.object({ workflowId: z.string().nullable(), }); +export const interruptSchema = z.discriminatedUnion("type", [ + z.object({ + type: z.literal("adhoc"), + }), + z.object({ + type: z.literal("request"), + }), + z.object({ + type: z.literal("sleep"), + seconds: z.number(), + }) +]) + export const VersionedTextsSchema = z.object({ current: z.object({ version: z.string(), @@ -46,6 +59,24 @@ export const VersionedTextsSchema = z.object({ ), }); +export const integrationSchema = z.object({ + toolhouse: z + .object({ + apiKey: z.string(), + }) + .optional() + .nullable(), + langfuse: z + .object({ + publicKey: z.string(), + secretKey: z.string(), + baseUrl: z.string(), + sendMessagePayloads: z.boolean(), + }) + .optional() + .nullable(), +}); + export const genericMessageDataSchema = z .object({ message: z.string(), @@ -233,6 +264,37 @@ export const definition = { .describe("Human readable description of the cluster"), }), }, + upsertIntegrations: { + method: "PUT", + path: "/clusters/:clusterId/integrations", + headers: z.object({ + authorization: z.string(), + }), + responses: { + 200: z.undefined(), + 401: z.undefined(), + 400: z.object({ + issues: z.array(z.any()), + }), + }, + pathParams: z.object({ + clusterId: z.string(), + }), + body: integrationSchema, + }, + getIntegrations: { + method: "GET", + path: "/clusters/:clusterId/integrations", + headers: z.object({ + authorization: z.string(), + }), + responses: { + 200: integrationSchema, + }, + pathParams: z.object({ + clusterId: z.string(), + }), + }, deleteCluster: { method: "DELETE", path: "/clusters/:clusterId", @@ -266,6 +328,8 @@ export const definition = { .describe( "Enable additional logging (Including prompts and results) for use by Inferable support", ), + enableRunConfigs: z.boolean().optional(), + enableKnowledgebase: z.boolean().optional(), }), }, getCluster: { @@ -282,6 +346,8 @@ export const definition = { additionalContext: VersionedTextsSchema.nullable(), createdAt: z.date(), debug: z.boolean(), + enableRunConfigs: z.boolean(), + enableKnowledgebase: z.boolean(), lastPingAt: z.date().nullable(), }), 401: z.undefined(), @@ -1344,7 +1410,7 @@ export const definition = { 200: z.object({ id: z.string(), result: z.any().nullable(), - resultType: z.enum(["resolution", "rejection"]).nullable(), + resultType: z.enum(["resolution", "rejection", "interrupt"]).nullable(), status: z.enum(["pending", "running", "success", "failure", "stalled"]), }), }, @@ -1366,7 +1432,7 @@ export const definition = { }, body: z.object({ result: z.any(), - resultType: z.enum(["resolution", "rejection"]), + resultType: z.enum(["resolution", "rejection", "interrupt"]), meta: z.object({ functionExecutionTime: z.number().optional(), }), diff --git a/sdk-node/src/execute-fn.ts b/sdk-node/src/execute-fn.ts index 0cd4c8ae..5c6d2d89 100644 --- a/sdk-node/src/execute-fn.ts +++ b/sdk-node/src/execute-fn.ts @@ -1,9 +1,10 @@ import { serializeError } from "./serialize-error"; import { FunctionRegistration } from "./types"; +import { extractInterrupt } from "./util"; export type Result = { content: T; - type: "resolution" | "rejection"; + type: "resolution" | "rejection" | "interrupt"; functionExecutionTime?: number; }; @@ -15,6 +16,16 @@ export const executeFn = async ( try { const result = await fn(...args); + const interupt = extractInterrupt(result); + + if (interupt) { + return { + content: interupt, + type: "interrupt", + functionExecutionTime: Date.now() - start, + }; + } + return { content: result, type: "resolution", diff --git a/sdk-node/src/index.ts b/sdk-node/src/index.ts index 354a208e..14ae0a52 100644 --- a/sdk-node/src/index.ts +++ b/sdk-node/src/index.ts @@ -31,6 +31,9 @@ export { validateFunctionSchema, validateFunctionArgs, blob, + interrupt, + sleep, + requestApproval } from "./util"; export { createApiClient } from "./create-client"; diff --git a/sdk-node/src/util.ts b/sdk-node/src/util.ts index f8df3a0e..1f515561 100644 --- a/sdk-node/src/util.ts +++ b/sdk-node/src/util.ts @@ -276,3 +276,57 @@ export const blob = ({ }, }; }; + + +export const INTERRUPT_KEY = "__inferable_interrupt"; +const interruptResultSchema = z.discriminatedUnion("type", [ + z.object({ + type: z.literal("adhoc"), + }), + z.object({ + type: z.literal("approval"), + }), + z.object({ + type: z.literal("sleep"), + seconds: z.number(), + }) +]) + +export const extractInterrupt = (input: unknown): z.infer | undefined => { + if (input && typeof input === "object" && INTERRUPT_KEY in input) { + const parsedInterrupt = interruptResultSchema.safeParse(input[INTERRUPT_KEY]); + + if (!parsedInterrupt.success) { + throw new InferableError("Found invalid Interrupt data"); + } + + return parsedInterrupt.data; + } +} + + +export const interrupt = () => { + return { + [INTERRUPT_KEY]: { + type: "adhoc", + }, + }; +}; + +export const sleep = (seconds: number) => { + return { + [INTERRUPT_KEY]: { + type: "sleep", + seconds, + }, + }; +}; + +export const requestApproval = () => { + return { + [INTERRUPT_KEY]: { + type: "approval", + }, + }; +}; +