From 58e4fff65beb38e6ea19abe4c3c8fad1b4ce6869 Mon Sep 17 00:00:00 2001 From: John Smith Date: Mon, 16 Dec 2024 08:09:38 +1030 Subject: [PATCH] feat: Remove `zod-to-json-schema` usage (#310) * chore: Remove json-schema-to-zod * feat: Build model schema with JsonSchema * chore: Add agent result test --- control-plane/package-lock.json | 10 -- control-plane/package.json | 1 - .../src/modules/service-definitions.test.ts | 72 ---------- .../src/modules/service-definitions.ts | 53 +------- .../modules/workflows/agent/agent.ai.test.ts | 50 +++++++ .../workflows/agent/nodes/model-call.ts | 91 +++---------- .../agent/nodes/model-output.test.ts | 127 ++++++++++++++++++ .../workflows/agent/nodes/model-output.ts | 101 ++++++++++++++ 8 files changed, 299 insertions(+), 206 deletions(-) create mode 100644 control-plane/src/modules/workflows/agent/nodes/model-output.test.ts create mode 100644 control-plane/src/modules/workflows/agent/nodes/model-output.ts diff --git a/control-plane/package-lock.json b/control-plane/package-lock.json index 428c9d3b..24756c2f 100644 --- a/control-plane/package-lock.json +++ b/control-plane/package-lock.json @@ -36,7 +36,6 @@ "inferable": "^0.30.59", "jest": "^29.6.4", "js-tiktoken": "^1.0.12", - "json-schema-to-zod": "^2.1.0", "jsonpath": "^1.1.1", "jsonschema": "^1.4.1", "jsonwebtoken": "^9.0.2", @@ -17541,15 +17540,6 @@ "fast-deep-equal": "^3.1.3" } }, - "node_modules/json-schema-to-zod": { - "version": "2.4.1", - "resolved": "https://registry.npmjs.org/json-schema-to-zod/-/json-schema-to-zod-2.4.1.tgz", - "integrity": "sha512-aMoez9TxgnfLAIZaWTPaQ+j7rOt1K9Ew/TBI85XcnhcFlo/47b1MDgpi4r07XndLSZWOX/KsJiRJvhdzSvo2Dw==", - "license": "ISC", - "bin": { - "json-schema-to-zod": "dist/cjs/cli.js" - } - }, "node_modules/json-schema-traverse": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", diff --git a/control-plane/package.json b/control-plane/package.json index 59aa6d90..d43ffe7a 100644 --- a/control-plane/package.json +++ b/control-plane/package.json @@ -46,7 +46,6 @@ "inferable": "^0.30.59", "jest": "^29.6.4", "js-tiktoken": "^1.0.12", - "json-schema-to-zod": "^2.1.0", "jsonpath": "^1.1.1", "jsonschema": "^1.4.1", "jsonwebtoken": "^9.0.2", diff --git a/control-plane/src/modules/service-definitions.test.ts b/control-plane/src/modules/service-definitions.test.ts index af33e1e4..4c464d63 100644 --- a/control-plane/src/modules/service-definitions.test.ts +++ b/control-plane/src/modules/service-definitions.test.ts @@ -1,8 +1,6 @@ -import { dereferenceSync, JSONSchema } from "dereference-json-schema"; import { InvalidJobArgumentsError, InvalidServiceRegistrationError } from "../utilities/errors"; import { packer } from "./packer"; import { - deserializeFunctionSchema, embeddableServiceFunction, parseJobArgs, serviceFunctionEmbeddingId, @@ -214,76 +212,6 @@ describe("parseJobArgs", () => { }); }); -describe("deserializeFunctionSchema", () => { - const jsonSchema = { - $schema: "http://json-schema.org/draft-04/schema#", - title: "ExtractResult", - type: "object", - additionalProperties: false, - properties: { - posts: { - type: "array", - items: { - $ref: "#/definitions/Post", - }, - }, - }, - definitions: { - Post: { - type: "object", - additionalProperties: false, - properties: { - id: { - type: "string", - }, - title: { - type: "string", - }, - points: { - type: "string", - }, - comments_url: { - type: "string", - }, - }, - }, - }, - }; - - it("should convert a JSON schema to a Zod schema", () => { - const zodSchema = deserializeFunctionSchema( - dereferenceSync(jsonSchema as any), - ); - const jsonSchema2 = zodToJsonSchema(zodSchema); - const dereferenced = dereferenceSync(jsonSchema2 as JSONSchema); - expect(dereferenced).toMatchObject({ - properties: { - posts: { - type: "array", - items: { - type: "object", - additionalProperties: false, - properties: { - id: { - type: "string", - }, - title: { - type: "string", - }, - points: { - type: "string", - }, - comments_url: { - type: "string", - }, - }, - }, - }, - }, - }); - }); -}); - describe("validateServiceRegistration", () => { it("should reject invalid schema", () => { expect(() => { diff --git a/control-plane/src/modules/service-definitions.ts b/control-plane/src/modules/service-definitions.ts index fcdd7430..05b25d1d 100644 --- a/control-plane/src/modules/service-definitions.ts +++ b/control-plane/src/modules/service-definitions.ts @@ -5,7 +5,6 @@ import { validateFunctionSchema, validateServiceName, } from "inferable"; -import jsonSchemaToZod, { JsonSchema } from "json-schema-to-zod"; import { Validator } from "jsonschema"; import { z } from "zod"; import { @@ -465,9 +464,8 @@ export const validateServiceRegistration = ({ } // Check that the schema accepts and expected value - const zodSchema = deserializeFunctionSchema(fn.schema); - const schema = zodSchema.safeParse({ token: "test" }); - if (!schema.success) { + const result = validator.validate({ token: "test" }, JSON.parse(fn.schema)); + if (!result.valid) { throw new InvalidServiceRegistrationError( `${fn.name} schema is not valid`, "https://docs.inferable.ai/pages/auth#handlecustomerauth", @@ -486,53 +484,6 @@ export const start = () => }, ); // 10 seconds -/** - * Convert a JSON schema (Object or String) to a Zod schema object - */ -export const deserializeFunctionSchema = (schema: unknown) => { - if (typeof schema === "object") { - let zodSchema; - - try { - zodSchema = jsonSchemaToZod(schema as JsonSchema); - } catch (e) { - logger.error("Failed to convert schema to Zod", { schema, error: e }); - throw new Error("Failed to load the tool definition"); - } - - return eval(` -const { z } = require("zod"); -${zodSchema} -`); - } else if (typeof schema === "string") { - let parsed; - - try { - parsed = JSON.parse(schema); - } catch (e) { - logger.error("Failed to parse schema", { schema, error: e }); - throw new Error("Failed to parse the tool definition"); - } - - let zodSchema; - - try { - zodSchema = jsonSchemaToZod(parsed); - } catch (e) { - logger.error("Failed to convert schema to Zod", { schema, error: e }); - throw new Error("Failed to load the tool definition"); - } - - return eval(` -const { z } = require("zod"); -${zodSchema} -`); - } else { - logger.error("Invalid schema", { schema }); - throw new Error("Invalid schema"); - } -}; - export const normalizeFunctionReference = ( fn: string | { service: string; function: string }, ) => diff --git a/control-plane/src/modules/workflows/agent/agent.ai.test.ts b/control-plane/src/modules/workflows/agent/agent.ai.test.ts index 01c70920..5cac8d2b 100644 --- a/control-plane/src/modules/workflows/agent/agent.ai.test.ts +++ b/control-plane/src/modules/workflows/agent/agent.ai.test.ts @@ -190,6 +190,56 @@ describe("Agent", () => { }); }); + describe("result schema", () => { + jest.setTimeout(120000); + + it("should result result schema", async () => { + + const app = await createWorkflowAgent({ + workflow: { + ...workflow, + resultSchema: { + type: "object", + properties: { + word: { + type: "string" + } + } + } + }, + findRelevantTools: async () => tools, + getTool: async () => tools[0], + postStepSave: async () => {}, + }); + + const messages = [ + { + type: "human", + data: { + message: "Return the word 'hello'", + }, + }, + ]; + + const outputState = await app.invoke({ + workflow, + messages, + }); + + expect(outputState.messages).toHaveLength(2); + expect(outputState.messages[0]).toHaveProperty("type", "human"); + expect(outputState.messages[1]).toHaveProperty("type", "agent"); + expect(outputState.messages[1].data.result).toHaveProperty( + "word", + "hello", + ); + + expect(outputState.result).toEqual({ + word: "hello" + }); + }); + }); + describe("early exit", () => { jest.setTimeout(120000); diff --git a/control-plane/src/modules/workflows/agent/nodes/model-call.ts b/control-plane/src/modules/workflows/agent/nodes/model-call.ts index 841cad96..a98e73a8 100644 --- a/control-plane/src/modules/workflows/agent/nodes/model-call.ts +++ b/control-plane/src/modules/workflows/agent/nodes/model-call.ts @@ -7,21 +7,21 @@ import { withSpan, } from "../../../observability/tracer"; import { AgentError } from "../../../../utilities/errors"; -import { z } from "zod"; import { ulid } from "ulid"; -import { deserializeFunctionSchema } from "../../../service-definitions"; import { validateFunctionSchema } from "inferable"; import { JsonSchemaInput } from "inferable/bin/types"; import { Model } from "../../../models"; import { ToolUseBlock } from "@anthropic-ai/sdk/resources"; -import { zodToJsonSchema } from "zod-to-json-schema"; +import { Schema, Validator } from "jsonschema"; +import { buildModelSchema, ModelOutput } from "./model-output"; type WorkflowStateUpdate = Partial; export const MODEL_CALL_NODE_NAME = "model"; +const validator = new Validator(); export const handleModelCall = ( state: WorkflowAgentState, model: Model, @@ -56,64 +56,11 @@ const _handleModelCall = async ( } } - const resultSchema = state.workflow.resultSchema - ? deserializeFunctionSchema(state.workflow.resultSchema) - : null; - - const modelSchema = z - .object({ - done: z - .boolean() - .describe( - "Whether the workflow is done. All tasks have been completed or you can not progress further.", - ) - .optional(), - - // If we have a result schema, specify it as the result output - ...(!!resultSchema - ? { - result: resultSchema - .optional() - .describe( - "Structrued object describing The final result of the workflow, only provided once all tasks have been completed.", - ), - } - : {}), - - // Otherwise request a string message - ...(!resultSchema - ? { - message: z.string().optional(), - } - : {}), - - issue: z - .string() - .describe( - "Describe any issues you have encountered in this step. Specifically related to the tools you are using.", - ) - .optional(), - - invocations: z - .array( - z.object({ - // @ts-expect-error: We don't care about the type information here, but we want to constrain the model's `toolName` choices. - toolName: z.enum([ - ...relevantSchemas.map((tool) => tool.name), - ...state.allAvailableTools, - ] as string[] as const), - ...(state.workflow.reasoningTraces - ? { reasoning: z.string() } - : {}), - input: z.object({}).passthrough(), - }), - ) - .optional() - .describe( - "Any tools calls you need to make. If multiple are provided, they will be executed in parallel (Do this where possible). DO NOT describe previous tool calls.", - ), - }) - .strict(); + const schema = buildModelSchema({ + state, + relevantSchemas, + resultSchema: state.workflow.resultSchema as JsonSchemaInput, + }); const schemaString = relevantSchemas.map((tool) => { return `${tool.name} - ${tool.description} ${tool.schema}`; @@ -156,7 +103,7 @@ const _handleModelCall = async ( const response = await model.structured({ messages: renderedMessages, system: systemPrompt, - schema: zodToJsonSchema(modelSchema), + schema, }); if (!response) { @@ -167,11 +114,12 @@ const _handleModelCall = async ( .filter((m) => m.type === "tool_use" && m.name !== "extract") .map((m) => m as ToolUseBlock); - const parsed = modelSchema.safeParse(response.structured); + const validation = validator.validate(response.structured, schema as Schema); + const data = response.structured as ModelOutput; - if (!parsed.success) { + if (!validation.valid) { logger.info("Model provided invalid response object", { - error: parsed.error, + errors: validation.errors, }); return { messages: [ @@ -191,7 +139,7 @@ const _handleModelCall = async ( type: "supervisor", data: { message: "Provided object was invalid, check your input", - details: { errors: parsed.error.errors }, + details: { errors: validation.errors }, }, runId: state.workflow.id, clusterId: state.workflow.clusterId, @@ -216,12 +164,12 @@ const _handleModelCall = async ( .filter(Boolean); if (invocations && invocations.length > 0) { - if (!parsed.data.invocations || !Array.isArray(parsed.data.invocations)) { - parsed.data.invocations = []; + if (!data.invocations || !Array.isArray(data.invocations)) { + data.invocations = []; } // Add them to the invocation array to be handled as if they were provided correctly - parsed.data.invocations.push( + data.invocations.push( // eslint-disable-next-line @typescript-eslint/no-explicit-any ...(invocations as any), ); @@ -238,12 +186,11 @@ const _handleModelCall = async ( }); } - const data = parsed.data; const hasInvocations = data.invocations && data.invocations.length > 0; if (state.workflow.debug && hasInvocations) { addAttributes({ - "model.invocations": data.invocations?.map((invoc) => + "model.invocations": data.invocations?.map((invoc: any) => JSON.stringify(invoc), ), }); @@ -323,7 +270,7 @@ const _handleModelCall = async ( id: ulid(), type: "agent", data: { - invocations: data.invocations?.map((invocation) => ({ + invocations: data.invocations?.map((invocation: any) => ({ ...invocation, id: ulid(), reasoning: invocation.reasoning as string | undefined, diff --git a/control-plane/src/modules/workflows/agent/nodes/model-output.test.ts b/control-plane/src/modules/workflows/agent/nodes/model-output.test.ts new file mode 100644 index 00000000..e4dc0ca6 --- /dev/null +++ b/control-plane/src/modules/workflows/agent/nodes/model-output.test.ts @@ -0,0 +1,127 @@ +import { JsonSchema7ObjectType } from "zod-to-json-schema"; +import { WorkflowAgentState } from "../state"; +import { AgentTool } from "../tool"; +import { ulid } from "ulid"; +import { buildModelSchema } from "./model-output"; + +describe("buildModelSchema", () => { + let state: WorkflowAgentState; + let relevantSchemas: AgentTool[]; + let resultSchema: JsonSchema7ObjectType | undefined; + + beforeEach(() => { + state = { + messages: [ + { + id: ulid(), + clusterId: "test-cluster", + runId: "test-run", + data: { + message: "What are your capabilities?", + }, + type: "human", + }, + ], + waitingJobs: [], + allAvailableTools: [], + workflow: { + id: "test-run", + clusterId: "test-cluster", + }, + additionalContext: "", + status: "running", + }; + relevantSchemas = [ + { name: "localTool1"}, + { name: "localTool2"}, + { name: "globalTool1"}, + { name: "globalTool2"}, + ] as AgentTool[], + resultSchema = undefined; + }); + + it("returns a schema with 'message' when resultSchema is not provided", () => { + const schema = buildModelSchema({ state, relevantSchemas, resultSchema }); + + expect(schema.type).toBe("object"); + expect(schema.properties).toHaveProperty("message"); + expect(schema.properties).not.toHaveProperty("result"); + }); + + it("returns a schema with 'result' when resultSchema is provided", () => { + resultSchema = { + type: "object", + properties: { + foo: { type: "string" }, + }, + additionalProperties: false, + }; + + const schema = buildModelSchema({ state, relevantSchemas, resultSchema }) as any; + + expect(schema.type).toBe("object"); + expect(schema.properties).toHaveProperty("result"); + expect(schema.properties).not.toHaveProperty("message"); + expect(schema.properties?.result?.description).toContain("final result"); + }); + + it("includes 'done' and 'issue' fields", () => { + const schema = buildModelSchema({ state, relevantSchemas, resultSchema }) as any; + + expect(schema.properties).toHaveProperty("done"); + expect(schema.properties).toHaveProperty("issue"); + expect(schema.properties.done.type).toBe("boolean"); + expect(schema.properties.issue.type).toBe("string"); + }); + + it("builds the correct toolName enum from available tools", () => { + const schema = buildModelSchema({ state, relevantSchemas, resultSchema }); + const invocations = schema.properties?.invocations as any; + const items = invocations.items as JsonSchema7ObjectType; + const toolName = items.properties?.toolName as any; + + expect(toolName).toBeDefined(); + expect(toolName?.enum).toContain("localTool1"); + expect(toolName?.enum).toContain("localTool2"); + expect(toolName?.enum).toContain("globalTool1"); + expect(toolName?.enum).toContain("globalTool2"); + }); + + it("includes 'invocations' with correct structure", () => { + const schema = buildModelSchema({ state, relevantSchemas, resultSchema }); + const invocations = schema.properties?.invocations as any; + + expect(invocations.type).toBe("array"); + const items = invocations.items as any; + + expect(items.type).toBe("object"); + expect(items.additionalProperties).toBe(false); + expect(items.required).toEqual(["toolName", "input"]); + + expect(items.properties?.input.type).toBe("object"); + expect(items.properties?.input.additionalProperties).toBe(true); + }); + + it("does not include 'reasoning' by default", () => { + const schema = buildModelSchema({ state, relevantSchemas, resultSchema }); + const invocations = schema.properties?.invocations as any; + const items = invocations.items as JsonSchema7ObjectType; + + expect(items.properties).not.toHaveProperty("reasoning"); + }); + + it("includes 'reasoning' when reasoningTraces is true", () => { + state.workflow.reasoningTraces = true; + const schema = buildModelSchema({ state, relevantSchemas, resultSchema }); + const invocations = schema.properties?.invocations as any; + const items = invocations.items as any; + + expect(items.properties).toHaveProperty("reasoning"); + expect(items.properties?.reasoning?.type).toBe("string"); + }); + + it("has additionalProperties set to false at top level", () => { + const schema = buildModelSchema({ state, relevantSchemas, resultSchema }); + expect(schema.additionalProperties).toBe(false); + }); +}); diff --git a/control-plane/src/modules/workflows/agent/nodes/model-output.ts b/control-plane/src/modules/workflows/agent/nodes/model-output.ts new file mode 100644 index 00000000..5827dfec --- /dev/null +++ b/control-plane/src/modules/workflows/agent/nodes/model-output.ts @@ -0,0 +1,101 @@ + +import { JsonSchema7ObjectType } from "zod-to-json-schema"; +import { AgentTool } from "../tool"; +import { workflows } from "../../../data"; +import { InferSelectModel } from "drizzle-orm"; +import { WorkflowAgentState } from "../state"; + +type ModelInvocationOutput = { + toolName: string; + input: unknown; + +} + +export type ModelOutput = { + invocations?: ModelInvocationOutput[]; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + result?: any; + message?: string; + done?: boolean; + issue?: string; +} + +export const buildModelSchema = ({ + state, + relevantSchemas, + resultSchema +}: { + state: WorkflowAgentState; + relevantSchemas: AgentTool[]; + resultSchema?: InferSelectModel["result_schema"]; + }) => { + + // Build the toolName enum + const toolNameEnum = [ + ...relevantSchemas.map((tool) => tool.name), + ...state.allAvailableTools, + ]; + + const schema: JsonSchema7ObjectType = { + type: "object", + additionalProperties: false, + properties: { + done: { + type: "boolean", + description: + "Whether the workflow is done. All tasks have been completed or you can not progress further.", + }, + issue: { + type: "string", + description: + "Describe any issues you have encountered in this step. Specifically related to the tools you are using.", + }, + }, + }; + + if (resultSchema) { + schema.properties.result = { + ...resultSchema, + description: + "Structured object describing the final result of the workflow, only provided once all tasks have been completed.", + }; + } else { + schema.properties.message = { + type: "string", + description: "A message describing the current state or next steps.", + }; + } + + const invocationItemProperties: JsonSchema7ObjectType["properties"] = { + toolName: { + type: "string", + enum: toolNameEnum, + }, + input: { + type: "object", + additionalProperties: true, + description: "Arbitrary input parameters for the tool call.", + }, + }; + + if (state.workflow.reasoningTraces) { + invocationItemProperties.reasoning = { + type: "string", + description: "Reasoning trace for why this tool call is made.", + }; + } + + schema.properties.invocations = { + type: "array", + description: + "Any tool calls you need to make. If multiple are provided, they will be executed in parallel. DO NOT describe previous tool calls.", + items: { + type: "object", + additionalProperties: false, + properties: invocationItemProperties, + required: ["toolName", "input"], + }, + }; + + return schema; +};