Skip to content

Commit

Permalink
feat: Initial function initiated approvals (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnjcsmith authored Nov 21, 2024
1 parent ee22fc4 commit 9e54ea7
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 6 deletions.
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,10 @@ For language-specific quick start guides, please refer to the README in each SDK

| Feature | Node.js | Go | .NET |
| ------------------------------------------------------------------------------------------------------ | :-----: | :-: | :--: |
| [Blob](https://docs.inferable.ai/pages/functions#blob) results ||||
| [Mask](https://docs.inferable.ai/pages/functions#masked) results ||||
| [Cached](https://docs.inferable.ai/pages/functions#config-cache) results ||||
| Call [Timeouts](https://docs.inferable.ai/pages/functions#config-timeoutseconds) ||||
| Call [Retries](https://docs.inferable.ai/pages/functions#config-retrycountonstall) ||||
| Call [Approval](https://docs.inferable.ai/pages/functions#config-requiresapproval) (Human in the loop) ||||
| Call [Approval](https://docs.inferable.ai/pages/functions#approvalrequest) (Human in the loop) ||||
| [Auth / Run Context](https://docs.inferable.ai/pages/runs#context) ||||

## Documentation
Expand Down
106 changes: 104 additions & 2 deletions sdk-node/src/contract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ const functionReference = z.object({

const anyObject = z.object({}).passthrough();

export const interruptSchema = z.discriminatedUnion("type", [
z.object({
type: z.literal("approval"),
}),
]);

export const blobSchema = z.object({
id: z.string(),
name: z.string(),
Expand All @@ -46,6 +52,24 @@ export const VersionedTextsSchema = z.object({
),
});

export const integrationSchema = z.object({
toolhouse: z
.object({
apiKey: z.string(),
})
.optional()
.nullable(),
langfuse: z
.object({
publicKey: z.string(),
secretKey: z.string(),
baseUrl: z.string(),
sendMessagePayloads: z.boolean(),
})
.optional()
.nullable(),
});

export const genericMessageDataSchema = z
.object({
message: z.string(),
Expand Down Expand Up @@ -118,6 +142,7 @@ export const FunctionConfigSchema = z.object({
.optional(),
retryCountOnStall: z.number().optional(),
timeoutSeconds: z.number().optional(),
// Deprecated
requiresApproval: z.boolean().default(false).optional(),
private: z.boolean().default(false).optional(),
});
Expand Down Expand Up @@ -159,6 +184,26 @@ export const definition = {
}),
},
},
createCallApproval: {
method: "POST",
path: "/clusters/:clusterId/calls/:callId/approval",
headers: z.object({
authorization: z.string(),
}),
pathParams: z.object({
clusterId: z.string(),
callId: z.string(),
}),
responses: {
204: z.undefined(),
404: z.object({
message: z.string(),
}),
},
body: z.object({
approved: z.boolean(),
}),
},
createCallBlob: {
method: "POST",
path: "/clusters/:clusterId/calls/:callId/blobs",
Expand Down Expand Up @@ -233,6 +278,37 @@ export const definition = {
.describe("Human readable description of the cluster"),
}),
},
upsertIntegrations: {
method: "PUT",
path: "/clusters/:clusterId/integrations",
headers: z.object({
authorization: z.string(),
}),
responses: {
200: z.undefined(),
401: z.undefined(),
400: z.object({
issues: z.array(z.any()),
}),
},
pathParams: z.object({
clusterId: z.string(),
}),
body: integrationSchema,
},
getIntegrations: {
method: "GET",
path: "/clusters/:clusterId/integrations",
headers: z.object({
authorization: z.string(),
}),
responses: {
200: integrationSchema,
},
pathParams: z.object({
clusterId: z.string(),
}),
},
deleteCluster: {
method: "DELETE",
path: "/clusters/:clusterId",
Expand Down Expand Up @@ -266,6 +342,8 @@ export const definition = {
.describe(
"Enable additional logging (Including prompts and results) for use by Inferable support",
),
enableRunConfigs: z.boolean().optional(),
enableKnowledgebase: z.boolean().optional(),
}),
},
getCluster: {
Expand All @@ -282,6 +360,8 @@ export const definition = {
additionalContext: VersionedTextsSchema.nullable(),
createdAt: z.date(),
debug: z.boolean(),
enableRunConfigs: z.boolean(),
enableKnowledgebase: z.boolean(),
lastPingAt: z.date().nullable(),
}),
401: z.undefined(),
Expand Down Expand Up @@ -938,6 +1018,8 @@ export const definition = {
service: z.string(),
resultType: z.string().nullable(),
createdAt: z.date(),
approved: z.boolean().nullable(),
approvalRequested: z.boolean().nullable(),
}),
),
inputRequests: z.array(
Expand Down Expand Up @@ -1306,6 +1388,25 @@ export const definition = {
401: z.undefined(),
},
},
getAllKnowledgeArtifacts: {
method: "GET",
path: "/clusters/:clusterId/knowledge-export",
headers: z.object({ authorization: z.string() }),
responses: {
200: z.array(
z.object({
id: z.string(),
data: z.string(),
tags: z.array(z.string()),
title: z.string(),
}),
),
401: z.undefined(),
},
pathParams: z.object({
clusterId: z.string(),
}),
},
createRunRetry: {
method: "POST",
path: "/clusters/:clusterId/runs/:runId/retry",
Expand Down Expand Up @@ -1344,7 +1445,7 @@ export const definition = {
200: z.object({
id: z.string(),
result: z.any().nullable(),
resultType: z.enum(["resolution", "rejection"]).nullable(),
resultType: z.enum(["resolution", "rejection", "interrupt"]).nullable(),
status: z.enum(["pending", "running", "success", "failure", "stalled"]),
}),
},
Expand All @@ -1366,7 +1467,7 @@ export const definition = {
},
body: z.object({
result: z.any(),
resultType: z.enum(["resolution", "rejection"]),
resultType: z.enum(["resolution", "rejection", "interrupt"]),
meta: z.object({
functionExecutionTime: z.number().optional(),
}),
Expand Down Expand Up @@ -1405,6 +1506,7 @@ export const definition = {
input: z.any(),
authContext: z.any().nullable(),
runContext: z.any().nullable(),
approved: z.boolean(),
}),
),
},
Expand Down
13 changes: 12 additions & 1 deletion sdk-node/src/execute-fn.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { serializeError } from "./serialize-error";
import { FunctionRegistration } from "./types";
import { extractInterrupt } from "./util";

export type Result<T = unknown> = {
content: T;
type: "resolution" | "rejection";
type: "resolution" | "rejection" | "interrupt";
functionExecutionTime?: number;
};

Expand All @@ -15,6 +16,16 @@ export const executeFn = async (
try {
const result = await fn(...args);

const interupt = extractInterrupt(result);

if (interupt) {
return {
content: interupt,
type: "interrupt",
functionExecutionTime: Date.now() - start,
};
}

return {
content: result,
type: "resolution",
Expand Down
1 change: 1 addition & 0 deletions sdk-node/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export {
validateFunctionSchema,
validateFunctionArgs,
blob,
approvalRequest
} from "./util";

export { createApiClient } from "./create-client";
2 changes: 2 additions & 0 deletions sdk-node/src/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type CallMessage = {
input?: unknown;
authContext?: unknown;
runContext?: string;
approved: boolean;
};

export class Service {
Expand Down Expand Up @@ -273,6 +274,7 @@ export class Service {
[args, {
authContext: call.authContext,
runContext: call.runContext,
approved: call.approved,
}],
);

Expand Down
1 change: 1 addition & 0 deletions sdk-node/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { FunctionConfigSchema } from "./contract";
export type ContextInput = {
authContext?: unknown;
runContext?: unknown;
approved: boolean;
};

export type FunctionConfig = z.infer<typeof FunctionConfigSchema>;
Expand Down
29 changes: 29 additions & 0 deletions sdk-node/src/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,32 @@ export const blob = ({
},
};
};


export const INTERRUPT_KEY = "__inferable_interrupt";
const interruptResultSchema = z.discriminatedUnion("type", [
z.object({
type: z.literal("approval"),
}),
])

export const extractInterrupt = (input: unknown): z.infer<typeof interruptResultSchema> | undefined => {
if (input && typeof input === "object" && INTERRUPT_KEY in input) {
const parsedInterrupt = interruptResultSchema.safeParse(input[INTERRUPT_KEY]);

if (!parsedInterrupt.success) {
throw new InferableError("Found invalid Interrupt data");
}

return parsedInterrupt.data;
}
}

export const approvalRequest = () => {
return {
[INTERRUPT_KEY]: {
type: "approval",
},
};
};

0 comments on commit 9e54ea7

Please sign in to comment.