Skip to content

Commit

Permalink
feat: Update interrupt semantics (#335)
Browse files Browse the repository at this point in the history
* feat: Update interrupt semantics

* chore: Remove structured output API
  • Loading branch information
johnjcsmith authored Dec 19, 2024
1 parent b2ee27b commit 2a935e1
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 47 deletions.
26 changes: 0 additions & 26 deletions sdk-node/src/Inferable.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -289,30 +289,4 @@ describe("Inferable SDK End to End Test", () => {
await client.default.stop();
}
});

describe("api", () => {
it("should be able to call the api directly", async () => {
const client = inferableInstance();

const result = await client.api.createStructuredOutput<{
capital: string;
}>({
prompt: "What is the capital of France?",
modelId: "claude-3-5-sonnet",
resultSchema: {
type: "object",
properties: {
capital: { type: "string" },
},
},
});

expect(result).toMatchObject({
success: true,
data: {
capital: "Paris",
},
});
});
});
});
19 changes: 0 additions & 19 deletions sdk-node/src/Inferable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -514,25 +514,6 @@ export class Inferable {
this.functionRegistry[registration.name] = registration;
}

public get api() {
return {
createStructuredOutput: async <T>(
input: Parameters<typeof this.client.createStructuredOutput>[0]["body"],
) =>
this.client
.createStructuredOutput({
params: {
clusterId: await this.getClusterId(),
},
body: input,
})
.then((r) => r.body) as Promise<{
success: boolean;
data?: T;
}>,
};
}

public async getClusterId() {
if (!this.clusterId) {
// Call register machine without any services to test API key and get clusterId
Expand Down
31 changes: 29 additions & 2 deletions sdk-node/src/execute-fn.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,34 @@
import { executeFn } from "./execute-fn";
import { Interrupt } from "./util";

describe("executeFn", () => {
it("should run a function with arguments", async () => {
const fn = (val: { [key: string]: string }) => Promise.resolve(val.foo);
const result = await fn({ foo: "bar" });
expect(result).toBe("bar");
const result = await executeFn(fn, [{foo: "bar"}] as any);
expect(result).toEqual({
content: "bar",
type: "resolution",
functionExecutionTime: expect.any(Number),
});
});

it("should extract interrupt from resolution", async () => {
const fn = (_: string) => Promise.resolve(Interrupt.approval());
const result = await executeFn(fn, [{}] as any);
expect(result).toEqual({
content: { type: "approval" },
type: "interrupt",
functionExecutionTime: expect.any(Number),
});
});

it("should extract interrupt from rejection", async () => {
const fn = () => Promise.reject(Interrupt.approval());
const result = await executeFn(fn, [{}] as any);
expect(result).toEqual({
content: { type: "approval" },
type: "interrupt",
functionExecutionTime: expect.any(Number),
});
});
});
8 changes: 8 additions & 0 deletions sdk-node/src/execute-fn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ export const executeFn = async (
functionExecutionTime: Date.now() - start,
};
} catch (e) {
const interupt = extractInterrupt(e);
if (interupt) {
return {
content: interupt,
type: "interrupt",
functionExecutionTime: Date.now() - start,
};
}
const functionExecutionTime = Date.now() - start;
if (e instanceof Error) {
return {
Expand Down
17 changes: 17 additions & 0 deletions sdk-node/src/util.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import {
ajvErrorToFailures,
blob,
extractBlobs,
extractInterrupt,
Interrupt,
validateFunctionSchema,
} from "./util";

Expand Down Expand Up @@ -64,6 +66,21 @@ describe("ajvErrorToFailures", () => {
});
});

describe("extractInterrupt", () => {
it("should extract extract interrupt", () => {
const interrupt = extractInterrupt(Interrupt.approval());
expect(interrupt).toEqual({
type: "approval",
})
})

it("should not extract interrupt from non-interrupt", () => {
const interrupt = extractInterrupt({ foo: "bar" });
expect(interrupt).toBeUndefined();
})

});

describe("extractBlobs", () => {
it("should extract blobs from content", () => {
const initialContent = {
Expand Down
20 changes: 20 additions & 0 deletions sdk-node/src/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ export const blob = ({


export const INTERRUPT_KEY = "__inferable_interrupt";
type VALID_INTERRUPT_TYPES = "approval";
const interruptResultSchema = z.discriminatedUnion("type", [
z.object({
type: z.literal("approval"),
Expand All @@ -297,6 +298,25 @@ export const extractInterrupt = (input: unknown): z.infer<typeof interruptResult
}
}

export class Interrupt {
[INTERRUPT_KEY]: {
type: VALID_INTERRUPT_TYPES
}

constructor(type: VALID_INTERRUPT_TYPES) {
this[INTERRUPT_KEY] = {
type
}
}

static approval() {
return new Interrupt("approval");
}
}

/**
* @deprecated Use Interrupt.approval() instea
*/
export const approvalRequest = () => {
return {
[INTERRUPT_KEY]: {
Expand Down

0 comments on commit 2a935e1

Please sign in to comment.