From 2b0104e0b344ce15f85a8fe8df52dd95af934276 Mon Sep 17 00:00:00 2001 From: John Smith Date: Sun, 15 Dec 2024 14:46:05 +1030 Subject: [PATCH] feat: Use Json schema for model calls --- control-plane/src/modules/models/index.ts | 73 +++------------ control-plane/src/modules/router.ts | 5 +- .../workflows/agent/nodes/model-call.test.ts | 90 +++++++------------ .../workflows/agent/nodes/model-call.ts | 6 +- .../workflows/agent/tools/mock-function.ts | 19 +--- .../src/modules/workflows/summarization.ts | 5 +- 6 files changed, 56 insertions(+), 142 deletions(-) diff --git a/control-plane/src/modules/models/index.ts b/control-plane/src/modules/models/index.ts index d3aa4113..188b690f 100644 --- a/control-plane/src/modules/models/index.ts +++ b/control-plane/src/modules/models/index.ts @@ -1,7 +1,6 @@ import AsyncRetry from "async-retry"; -import { zodToJsonSchema } from "zod-to-json-schema"; +import { JsonSchema7Type } from "zod-to-json-schema"; import Anthropic from "@anthropic-ai/sdk"; -import { ZodError, z } from "zod"; import { ToolUseBlock } from "@anthropic-ai/sdk/resources"; import { ChatIdentifiers, @@ -33,27 +32,20 @@ type CallOutput = { raw: Anthropic.Message; }; + type StructuredCallInput = CallInput & { - schema: z.ZodType; + schema: JsonSchema7Type; }; -type StructuredCallOutput = CallOutput & { - parsed: - | { - success: true; - data: z.infer; - } - | { - success: false; - error: ZodError; - }; +type StructuredCallOutput = CallOutput & { + structured: unknown }; export type Model = { call: (options: CallInput) => Promise; structured: ( options: T, - ) => Promise>; + ) => Promise; identifier: ChatIdentifiers | EmbeddingIdentifiers; embedQuery: (input: string) => Promise; }; @@ -219,9 +211,7 @@ export const buildModel = ({ // This is enforced above ...(tools as Anthropic.Tool[]), { - input_schema: zodToJsonSchema( - options.schema, - ) as Anthropic.Tool.InputSchema, + input_schema: options.schema as Anthropic.Tool.InputSchema, name: "extract", }, ], @@ -259,7 +249,7 @@ export const buildModel = ({ throw new Error("Model did not return output"); } - return parseStructuredResponse({ response, options }); + return parseStructuredResponse({ response }); }, }; }; @@ -297,10 +287,8 @@ const handleErrror = async ({ const parseStructuredResponse = ({ response, - options, }: { response: Anthropic.Message; - options: StructuredCallInput; }): Awaited> => { const toolCalls = response.content.filter((m) => m.type === "tool_use"); @@ -311,29 +299,9 @@ const parseStructuredResponse = ({ throw new Error("Model did not return structured output"); } - const extractToolResult = options.schema.safeParse(extractResult.input); - - const returnVal = { - raw: response, - parsed: {}, - }; - - if (extractToolResult.success) { - return { - ...returnVal, - parsed: { - success: true, - data: extractToolResult.data, - }, - }; - } - return { - ...returnVal, - parsed: { - success: false, - error: extractToolResult.error, - }, + raw: response, + structured: extractResult.input, }; }; @@ -353,36 +321,21 @@ export const buildMockModel = ({ call: async () => { throw new Error("Not implemented"); }, - structured: async (options) => { + structured: async () => { if (responseCount >= mockResponses.length) { throw new Error("Mock model ran out of responses"); } - const parsed = options.schema.safeParse( - JSON.parse(mockResponses[responseCount]), - ); + const data = JSON.parse(mockResponses[responseCount]); // Sleep for between 500 and 1500 ms await new Promise((resolve) => setTimeout(resolve, Math.random() * 1000 + 500), ); - if (!parsed.success) { - return { - raw: { content: [] } as unknown as Anthropic.Message, - parsed: { - success: false, - error: parsed.error, - }, - }; - } - return { raw: { content: [] } as unknown as Anthropic.Message, - parsed: { - success: true, - data: parsed.data, - }, + structured: data, }; }, }; diff --git a/control-plane/src/modules/router.ts b/control-plane/src/modules/router.ts index 8ad72ebb..05442b5c 100644 --- a/control-plane/src/modules/router.ts +++ b/control-plane/src/modules/router.ts @@ -45,7 +45,6 @@ import { import { callsRouter } from "./calls/router"; import { buildModel } from "./models"; import { - deserializeFunctionSchema, getServiceDefinitions, } from "./service-definitions"; import { integrationsRouter } from "./integrations/router"; @@ -906,12 +905,12 @@ export const router = initServer().router(contract, { const result = await model.structured({ messages: [{ role: "user", content: prompt }], - schema: deserializeFunctionSchema(resultSchema), + schema: resultSchema, }); return { status: 200, - body: result.parsed, + body: result.structured, }; }, getServerStats: async () => { diff --git a/control-plane/src/modules/workflows/agent/nodes/model-call.test.ts b/control-plane/src/modules/workflows/agent/nodes/model-call.test.ts index 23ac0202..3e1b1919 100644 --- a/control-plane/src/modules/workflows/agent/nodes/model-call.test.ts +++ b/control-plane/src/modules/workflows/agent/nodes/model-call.test.ts @@ -69,12 +69,9 @@ describe("handleModelCall", () => { raw: { content: [], }, - parsed: { - success: true, - data: { - done: true, - result: { reason: "nothing to do" }, - }, + structured: { + done: true, + message: "nothing to do" }, }); @@ -95,19 +92,16 @@ describe("handleModelCall", () => { raw: { content: [], }, - parsed: { - success: true, - data: { - done: true, - result: { reason: "nothing to do" }, - invocations: [ - { - toolName: "notify", - input: { message: "A message" }, - reasoning: "notify the system", - }, - ], - }, + structured: { + done: true, + message: "nothing to do", + invocations: [ + { + toolName: "notify", + input: { message: "A message" }, + reasoning: "notify the system", + }, + ], }, }); @@ -139,11 +133,8 @@ describe("handleModelCall", () => { raw: { content: [], }, - parsed: { - success: true, - data: { - result: { reason: "nothing to do" }, - }, + structured: { + message: "nothing to do", }, }); @@ -159,7 +150,7 @@ describe("handleModelCall", () => { expect(result.messages![0].data).toHaveProperty( "details", expect.objectContaining({ - result: { reason: "nothing to do" }, + message: "nothing to do", }), ); @@ -177,11 +168,8 @@ describe("handleModelCall", () => { raw: { content: [], }, - parsed: { - success: true, - data: { - done: true, - }, + structured: { + done: true, }, }); @@ -258,16 +246,8 @@ describe("handleModelCall", () => { raw: { content: [], }, - parsed: { - success: false, - error: { - errors: [ - { - path: [""], - message: "Test error", - }, - ], - }, + structured: { + randomStuff: "123", }, }); @@ -295,11 +275,8 @@ describe("handleModelCall", () => { describe("additional tool calls", () => { it("should add call to empty invocations array", async () => { mockWithStructuredOutput.mockReturnValueOnce({ - parsed: { - success: true, - data: { - done: false, - }, + structured: { + done: false, }, raw: { content: [ @@ -344,20 +321,17 @@ describe("handleModelCall", () => { it("should add to existing invocations array", async () => { mockWithStructuredOutput.mockReturnValueOnce({ - parsed: { - success: true, - data: { - done: false, - invocations: [ - { - toolName: "notify", - reasoning: "notify the system", - input: { - message: "the first notification", - }, + structured: { + done: false, + invocations: [ + { + toolName: "notify", + reasoning: "notify the system", + input: { + message: "the first notification", }, - ], - }, + }, + ], }, raw: { content: [ diff --git a/control-plane/src/modules/workflows/agent/nodes/model-call.ts b/control-plane/src/modules/workflows/agent/nodes/model-call.ts index f0d799dd..841cad96 100644 --- a/control-plane/src/modules/workflows/agent/nodes/model-call.ts +++ b/control-plane/src/modules/workflows/agent/nodes/model-call.ts @@ -16,6 +16,8 @@ import { JsonSchemaInput } from "inferable/bin/types"; import { Model } from "../../../models"; import { ToolUseBlock } from "@anthropic-ai/sdk/resources"; +import { zodToJsonSchema } from "zod-to-json-schema"; + type WorkflowStateUpdate = Partial; export const MODEL_CALL_NODE_NAME = "model"; @@ -154,7 +156,7 @@ const _handleModelCall = async ( const response = await model.structured({ messages: renderedMessages, system: systemPrompt, - schema: modelSchema, + schema: zodToJsonSchema(modelSchema), }); if (!response) { @@ -165,7 +167,7 @@ const _handleModelCall = async ( .filter((m) => m.type === "tool_use" && m.name !== "extract") .map((m) => m as ToolUseBlock); - const parsed = response.parsed; + const parsed = modelSchema.safeParse(response.structured); if (!parsed.success) { logger.info("Model provided invalid response object", { diff --git a/control-plane/src/modules/workflows/agent/tools/mock-function.ts b/control-plane/src/modules/workflows/agent/tools/mock-function.ts index 09efb073..de31d68d 100644 --- a/control-plane/src/modules/workflows/agent/tools/mock-function.ts +++ b/control-plane/src/modules/workflows/agent/tools/mock-function.ts @@ -1,7 +1,6 @@ import { AgentError } from "../../../../utilities/errors"; import { logger } from "../../../observability/logger"; import { - deserializeFunctionSchema, serviceFunctionEmbeddingId, } from "../../../service-definitions"; import { AgentTool } from "../tool"; @@ -20,31 +19,17 @@ export const buildMockFunctionTool = ({ functionName: string; serviceName: string; description?: string; - schema: unknown; + schema?: string; mockResult: unknown; }): AgentTool => { const toolName = serviceFunctionEmbeddingId({ serviceName, functionName }); - let deserialized = null; - - try { - deserialized = deserializeFunctionSchema(schema); - } catch (e) { - logger.error( - `Failed to deserialize schema for ${toolName} (${serviceName}.${functionName})`, - { schema, error: e }, - ); - throw new AgentError( - `Failed to deserialize schema for ${toolName} (${serviceName}.${functionName})`, - ); - } - return new AgentTool({ name: toolName, description: ( description ?? `${serviceName}-${functionName} function` ).substring(0, 1024), - schema: deserialized, + schema, func: async (input: unknown) => { logger.info("Mock tool call", { toolName, input }); diff --git a/control-plane/src/modules/workflows/summarization.ts b/control-plane/src/modules/workflows/summarization.ts index e6a445a9..4017b49e 100644 --- a/control-plane/src/modules/workflows/summarization.ts +++ b/control-plane/src/modules/workflows/summarization.ts @@ -42,10 +42,11 @@ export const generateTitle = async ( schema, }); - const parsed = response.parsed; + const parsed = schema.safeParse(response.structured); + if (!parsed.success) { logger.error("Model did not return valid output", { - errors: parsed.error.errors, + errors: parsed.error.issues, }); throw new RetryableError("Invalid title output from model");