Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Use Json schema for model calls #308

Merged
merged 1 commit into from
Dec 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 13 additions & 60 deletions control-plane/src/modules/models/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import AsyncRetry from "async-retry";
import { zodToJsonSchema } from "zod-to-json-schema";
import { JsonSchema7Type } from "zod-to-json-schema";
import Anthropic from "@anthropic-ai/sdk";
import { ZodError, z } from "zod";
import { ToolUseBlock } from "@anthropic-ai/sdk/resources";
import {
ChatIdentifiers,
Expand Down Expand Up @@ -33,27 +32,20 @@ type CallOutput = {
raw: Anthropic.Message;
};


type StructuredCallInput = CallInput & {
schema: z.ZodType;
schema: JsonSchema7Type;
};

type StructuredCallOutput<T extends StructuredCallInput> = CallOutput & {
parsed:
| {
success: true;
data: z.infer<T["schema"]>;
}
| {
success: false;
error: ZodError<T["schema"]>;
};
type StructuredCallOutput = CallOutput & {
structured: unknown
};

export type Model = {
call: (options: CallInput) => Promise<CallOutput>;
structured: <T extends StructuredCallInput>(
options: T,
) => Promise<StructuredCallOutput<T>>;
) => Promise<StructuredCallOutput>;
identifier: ChatIdentifiers | EmbeddingIdentifiers;
embedQuery: (input: string) => Promise<number[]>;
};
Expand Down Expand Up @@ -219,9 +211,7 @@ export const buildModel = ({
// This is enforced above
...(tools as Anthropic.Tool[]),
{
input_schema: zodToJsonSchema(
options.schema,
) as Anthropic.Tool.InputSchema,
input_schema: options.schema as Anthropic.Tool.InputSchema,
name: "extract",
},
],
Expand Down Expand Up @@ -259,7 +249,7 @@ export const buildModel = ({
throw new Error("Model did not return output");
}

return parseStructuredResponse({ response, options });
return parseStructuredResponse({ response });
},
};
};
Expand Down Expand Up @@ -297,10 +287,8 @@ const handleErrror = async ({

const parseStructuredResponse = ({
response,
options,
}: {
response: Anthropic.Message;
options: StructuredCallInput;
}): Awaited<ReturnType<Model["structured"]>> => {
const toolCalls = response.content.filter((m) => m.type === "tool_use");

Expand All @@ -311,29 +299,9 @@ const parseStructuredResponse = ({
throw new Error("Model did not return structured output");
}

const extractToolResult = options.schema.safeParse(extractResult.input);

const returnVal = {
raw: response,
parsed: {},
};

if (extractToolResult.success) {
return {
...returnVal,
parsed: {
success: true,
data: extractToolResult.data,
},
};
}

return {
...returnVal,
parsed: {
success: false,
error: extractToolResult.error,
},
raw: response,
structured: extractResult.input,
};
};

Expand All @@ -353,36 +321,21 @@ export const buildMockModel = ({
call: async () => {
throw new Error("Not implemented");
},
structured: async (options) => {
structured: async () => {
if (responseCount >= mockResponses.length) {
throw new Error("Mock model ran out of responses");
}

const parsed = options.schema.safeParse(
JSON.parse(mockResponses[responseCount]),
);
const data = JSON.parse(mockResponses[responseCount]);

// Sleep for between 500 and 1500 ms
await new Promise((resolve) =>
setTimeout(resolve, Math.random() * 1000 + 500),
);

if (!parsed.success) {
return {
raw: { content: [] } as unknown as Anthropic.Message,
parsed: {
success: false,
error: parsed.error,
},
};
}

return {
raw: { content: [] } as unknown as Anthropic.Message,
parsed: {
success: true,
data: parsed.data,
},
structured: data,
};
},
};
Expand Down
5 changes: 2 additions & 3 deletions control-plane/src/modules/router.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ import {
import { callsRouter } from "./calls/router";
import { buildModel } from "./models";
import {
deserializeFunctionSchema,
getServiceDefinitions,
} from "./service-definitions";
import { integrationsRouter } from "./integrations/router";
Expand Down Expand Up @@ -906,12 +905,12 @@ export const router = initServer().router(contract, {

const result = await model.structured({
messages: [{ role: "user", content: prompt }],
schema: deserializeFunctionSchema(resultSchema),
schema: resultSchema,
});

return {
status: 200,
body: result.parsed,
body: result.structured,
};
},
getServerStats: async () => {
Expand Down
90 changes: 32 additions & 58 deletions control-plane/src/modules/workflows/agent/nodes/model-call.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,9 @@ describe("handleModelCall", () => {
raw: {
content: [],
},
parsed: {
success: true,
data: {
done: true,
result: { reason: "nothing to do" },
},
structured: {
done: true,
message: "nothing to do"
},
});

Expand All @@ -95,19 +92,16 @@ describe("handleModelCall", () => {
raw: {
content: [],
},
parsed: {
success: true,
data: {
done: true,
result: { reason: "nothing to do" },
invocations: [
{
toolName: "notify",
input: { message: "A message" },
reasoning: "notify the system",
},
],
},
structured: {
done: true,
message: "nothing to do",
invocations: [
{
toolName: "notify",
input: { message: "A message" },
reasoning: "notify the system",
},
],
},
});

Expand Down Expand Up @@ -139,11 +133,8 @@ describe("handleModelCall", () => {
raw: {
content: [],
},
parsed: {
success: true,
data: {
result: { reason: "nothing to do" },
},
structured: {
message: "nothing to do",
},
});

Expand All @@ -159,7 +150,7 @@ describe("handleModelCall", () => {
expect(result.messages![0].data).toHaveProperty(
"details",
expect.objectContaining({
result: { reason: "nothing to do" },
message: "nothing to do",
}),
);

Expand All @@ -177,11 +168,8 @@ describe("handleModelCall", () => {
raw: {
content: [],
},
parsed: {
success: true,
data: {
done: true,
},
structured: {
done: true,
},
});

Expand Down Expand Up @@ -258,16 +246,8 @@ describe("handleModelCall", () => {
raw: {
content: [],
},
parsed: {
success: false,
error: {
errors: [
{
path: [""],
message: "Test error",
},
],
},
structured: {
randomStuff: "123",
},
});

Expand Down Expand Up @@ -295,11 +275,8 @@ describe("handleModelCall", () => {
describe("additional tool calls", () => {
it("should add call to empty invocations array", async () => {
mockWithStructuredOutput.mockReturnValueOnce({
parsed: {
success: true,
data: {
done: false,
},
structured: {
done: false,
},
raw: {
content: [
Expand Down Expand Up @@ -344,20 +321,17 @@ describe("handleModelCall", () => {

it("should add to existing invocations array", async () => {
mockWithStructuredOutput.mockReturnValueOnce({
parsed: {
success: true,
data: {
done: false,
invocations: [
{
toolName: "notify",
reasoning: "notify the system",
input: {
message: "the first notification",
},
structured: {
done: false,
invocations: [
{
toolName: "notify",
reasoning: "notify the system",
input: {
message: "the first notification",
},
],
},
},
],
},
raw: {
content: [
Expand Down
6 changes: 4 additions & 2 deletions control-plane/src/modules/workflows/agent/nodes/model-call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import { JsonSchemaInput } from "inferable/bin/types";
import { Model } from "../../../models";
import { ToolUseBlock } from "@anthropic-ai/sdk/resources";

import { zodToJsonSchema } from "zod-to-json-schema";

type WorkflowStateUpdate = Partial<WorkflowAgentState>;

export const MODEL_CALL_NODE_NAME = "model";
Expand Down Expand Up @@ -154,7 +156,7 @@ const _handleModelCall = async (
const response = await model.structured({
messages: renderedMessages,
system: systemPrompt,
schema: modelSchema,
schema: zodToJsonSchema(modelSchema),
});

if (!response) {
Expand All @@ -165,7 +167,7 @@ const _handleModelCall = async (
.filter((m) => m.type === "tool_use" && m.name !== "extract")
.map((m) => m as ToolUseBlock);

const parsed = response.parsed;
const parsed = modelSchema.safeParse(response.structured);

if (!parsed.success) {
logger.info("Model provided invalid response object", {
Expand Down
19 changes: 2 additions & 17 deletions control-plane/src/modules/workflows/agent/tools/mock-function.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { AgentError } from "../../../../utilities/errors";
import { logger } from "../../../observability/logger";
import {
deserializeFunctionSchema,
serviceFunctionEmbeddingId,
} from "../../../service-definitions";
import { AgentTool } from "../tool";
Expand All @@ -20,31 +19,17 @@ export const buildMockFunctionTool = ({
functionName: string;
serviceName: string;
description?: string;
schema: unknown;
schema?: string;
mockResult: unknown;
}): AgentTool => {
const toolName = serviceFunctionEmbeddingId({ serviceName, functionName });

let deserialized = null;

try {
deserialized = deserializeFunctionSchema(schema);
} catch (e) {
logger.error(
`Failed to deserialize schema for ${toolName} (${serviceName}.${functionName})`,
{ schema, error: e },
);
throw new AgentError(
`Failed to deserialize schema for ${toolName} (${serviceName}.${functionName})`,
);
}

return new AgentTool({
name: toolName,
description: (
description ?? `${serviceName}-${functionName} function`
).substring(0, 1024),
schema: deserialized,
schema,
func: async (input: unknown) => {
logger.info("Mock tool call", { toolName, input });

Expand Down
5 changes: 3 additions & 2 deletions control-plane/src/modules/workflows/summarization.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ export const generateTitle = async (
schema,
});

const parsed = response.parsed;
const parsed = schema.safeParse(response.structured);

if (!parsed.success) {
logger.error("Model did not return valid output", {
errors: parsed.error.errors,
errors: parsed.error.issues,
});

throw new RetryableError("Invalid title output from model");
Expand Down
Loading