Skip to content

Commit

Permalink
feat: Build model schema with JsonSchema
Browse files Browse the repository at this point in the history
  • Loading branch information
johnjcsmith committed Dec 15, 2024
1 parent e15e4a7 commit 887d38d
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 72 deletions.
91 changes: 19 additions & 72 deletions control-plane/src/modules/workflows/agent/nodes/model-call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@ import {
withSpan,
} from "../../../observability/tracer";
import { AgentError } from "../../../../utilities/errors";
import { z } from "zod";
import { ulid } from "ulid";

import { deserializeFunctionSchema } from "../../../service-definitions";
import { validateFunctionSchema } from "inferable";
import { JsonSchemaInput } from "inferable/bin/types";
import { Model } from "../../../models";
import { ToolUseBlock } from "@anthropic-ai/sdk/resources";

import { zodToJsonSchema } from "zod-to-json-schema";
import { Schema, Validator } from "jsonschema";
import { buildModelSchema, ModelOutput } from "./model-output";

type WorkflowStateUpdate = Partial<WorkflowAgentState>;

export const MODEL_CALL_NODE_NAME = "model";

const validator = new Validator();
export const handleModelCall = (
state: WorkflowAgentState,
model: Model,
Expand Down Expand Up @@ -56,64 +56,11 @@ const _handleModelCall = async (
}
}

const resultSchema = state.workflow.resultSchema
? deserializeFunctionSchema(state.workflow.resultSchema)
: null;

const modelSchema = z
.object({
done: z
.boolean()
.describe(
"Whether the workflow is done. All tasks have been completed or you can not progress further.",
)
.optional(),

// If we have a result schema, specify it as the result output
...(!!resultSchema
? {
result: resultSchema
.optional()
.describe(
"Structrued object describing The final result of the workflow, only provided once all tasks have been completed.",
),
}
: {}),

// Otherwise request a string message
...(!resultSchema
? {
message: z.string().optional(),
}
: {}),

issue: z
.string()
.describe(
"Describe any issues you have encountered in this step. Specifically related to the tools you are using.",
)
.optional(),

invocations: z
.array(
z.object({
// @ts-expect-error: We don't care about the type information here, but we want to constrain the model's `toolName` choices.
toolName: z.enum([
...relevantSchemas.map((tool) => tool.name),
...state.allAvailableTools,
] as string[] as const),
...(state.workflow.reasoningTraces
? { reasoning: z.string() }
: {}),
input: z.object({}).passthrough(),
}),
)
.optional()
.describe(
"Any tools calls you need to make. If multiple are provided, they will be executed in parallel (Do this where possible). DO NOT describe previous tool calls.",
),
})
.strict();
const schema = buildModelSchema({
state,
relevantSchemas,
resultSchema: state.workflow.resultSchema as JsonSchemaInput,
});

const schemaString = relevantSchemas.map((tool) => {
return `${tool.name} - ${tool.description} ${tool.schema}`;
Expand Down Expand Up @@ -156,7 +103,7 @@ const _handleModelCall = async (
const response = await model.structured({
messages: renderedMessages,
system: systemPrompt,
schema: zodToJsonSchema(modelSchema),
schema,
});

if (!response) {
Expand All @@ -167,11 +114,12 @@ const _handleModelCall = async (
.filter((m) => m.type === "tool_use" && m.name !== "extract")
.map((m) => m as ToolUseBlock);

const parsed = modelSchema.safeParse(response.structured);
const validation = validator.validate(response.structured, schema as Schema);
const data = response.structured as ModelOutput;

if (!parsed.success) {
if (!validation.valid) {
logger.info("Model provided invalid response object", {
error: parsed.error,
errors: validation.errors,
});
return {
messages: [
Expand All @@ -191,7 +139,7 @@ const _handleModelCall = async (
type: "supervisor",
data: {
message: "Provided object was invalid, check your input",
details: { errors: parsed.error.errors },
details: { errors: validation.errors },
},
runId: state.workflow.id,
clusterId: state.workflow.clusterId,
Expand All @@ -216,12 +164,12 @@ const _handleModelCall = async (
.filter(Boolean);

if (invocations && invocations.length > 0) {
if (!parsed.data.invocations || !Array.isArray(parsed.data.invocations)) {
parsed.data.invocations = [];
if (!data.invocations || !Array.isArray(data.invocations)) {
data.invocations = [];
}

// Add them to the invocation array to be handled as if they were provided correctly
parsed.data.invocations.push(
data.invocations.push(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
...(invocations as any),
);
Expand All @@ -238,12 +186,11 @@ const _handleModelCall = async (
});
}

const data = parsed.data;
const hasInvocations = data.invocations && data.invocations.length > 0;

if (state.workflow.debug && hasInvocations) {
addAttributes({
"model.invocations": data.invocations?.map((invoc) =>
"model.invocations": data.invocations?.map((invoc: any) =>
JSON.stringify(invoc),
),
});
Expand Down Expand Up @@ -323,7 +270,7 @@ const _handleModelCall = async (
id: ulid(),
type: "agent",
data: {
invocations: data.invocations?.map((invocation) => ({
invocations: data.invocations?.map((invocation: any) => ({
...invocation,
id: ulid(),
reasoning: invocation.reasoning as string | undefined,
Expand Down
127 changes: 127 additions & 0 deletions control-plane/src/modules/workflows/agent/nodes/model-output.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import { JsonSchema7ObjectType } from "zod-to-json-schema";
import { WorkflowAgentState } from "../state";
import { AgentTool } from "../tool";
import { ulid } from "ulid";
import { buildModelSchema } from "./model-output";

// Unit tests for buildModelSchema: they inspect the shape of the generated
// JSON Schema (top-level fields, invocation items, toolName enum) for the
// different workflow configurations.
describe("buildModelSchema", () => {
  let state: WorkflowAgentState;
  let relevantSchemas: AgentTool[];
  let resultSchema: JsonSchema7ObjectType | undefined;

  // Rebuild the fixtures before every test so that per-test mutations
  // (e.g. setting `reasoningTraces` or `resultSchema`) do not leak.
  beforeEach(() => {
    state = {
      messages: [
        {
          id: ulid(),
          clusterId: "test-cluster",
          runId: "test-run",
          data: {
            message: "What are your capabilities?",
          },
          type: "human",
        },
      ],
      waitingJobs: [],
      allAvailableTools: [],
      workflow: {
        id: "test-run",
        clusterId: "test-cluster",
      },
      additionalContext: "",
      status: "running",
    };
    // Only `name` is read by buildModelSchema, hence the cast from partial objects.
    relevantSchemas = [
      { name: "localTool1" },
      { name: "localTool2" },
      { name: "globalTool1" },
      { name: "globalTool2" },
    ] as AgentTool[];
    resultSchema = undefined;
  });

  it("returns a schema with 'message' when resultSchema is not provided", () => {
    const schema = buildModelSchema({ state, relevantSchemas, resultSchema });

    expect(schema.type).toBe("object");
    expect(schema.properties).toHaveProperty("message");
    expect(schema.properties).not.toHaveProperty("result");
  });

  it("returns a schema with 'result' when resultSchema is provided", () => {
    resultSchema = {
      type: "object",
      properties: {
        foo: { type: "string" },
      },
      additionalProperties: false,
    };

    const schema = buildModelSchema({ state, relevantSchemas, resultSchema }) as any;

    expect(schema.type).toBe("object");
    expect(schema.properties).toHaveProperty("result");
    expect(schema.properties).not.toHaveProperty("message");
    expect(schema.properties?.result?.description).toContain("final result");
  });

  it("includes 'done' and 'issue' fields", () => {
    const schema = buildModelSchema({ state, relevantSchemas, resultSchema }) as any;

    expect(schema.properties).toHaveProperty("done");
    expect(schema.properties).toHaveProperty("issue");
    expect(schema.properties.done.type).toBe("boolean");
    expect(schema.properties.issue.type).toBe("string");
  });

  it("builds the correct toolName enum from available tools", () => {
    const schema = buildModelSchema({ state, relevantSchemas, resultSchema });
    const invocations = schema.properties?.invocations as any;
    const items = invocations.items as JsonSchema7ObjectType;
    const toolName = items.properties?.toolName as any;

    expect(toolName).toBeDefined();
    expect(toolName?.enum).toContain("localTool1");
    expect(toolName?.enum).toContain("localTool2");
    expect(toolName?.enum).toContain("globalTool1");
    expect(toolName?.enum).toContain("globalTool2");
  });

  it("includes 'invocations' with correct structure", () => {
    const schema = buildModelSchema({ state, relevantSchemas, resultSchema });
    const invocations = schema.properties?.invocations as any;

    expect(invocations.type).toBe("array");
    const items = invocations.items as any;

    expect(items.type).toBe("object");
    expect(items.additionalProperties).toBe(false);
    expect(items.required).toEqual(["toolName", "input"]);

    expect(items.properties?.input.type).toBe("object");
    expect(items.properties?.input.additionalProperties).toBe(true);
  });

  it("does not include 'reasoning' by default", () => {
    const schema = buildModelSchema({ state, relevantSchemas, resultSchema });
    const invocations = schema.properties?.invocations as any;
    const items = invocations.items as JsonSchema7ObjectType;

    expect(items.properties).not.toHaveProperty("reasoning");
  });

  it("includes 'reasoning' when reasoningTraces is true", () => {
    state.workflow.reasoningTraces = true;
    const schema = buildModelSchema({ state, relevantSchemas, resultSchema });
    const invocations = schema.properties?.invocations as any;
    const items = invocations.items as any;

    expect(items.properties).toHaveProperty("reasoning");
    expect(items.properties?.reasoning?.type).toBe("string");
  });

  it("has additionalProperties set to false at top level", () => {
    const schema = buildModelSchema({ state, relevantSchemas, resultSchema });
    expect(schema.additionalProperties).toBe(false);
  });
});
101 changes: 101 additions & 0 deletions control-plane/src/modules/workflows/agent/nodes/model-output.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@

import { JsonSchema7ObjectType } from "zod-to-json-schema";
import { AgentTool } from "../tool";
import { workflows } from "../../../data";
import { InferSelectModel } from "drizzle-orm";
import { WorkflowAgentState } from "../state";

/**
 * A single tool call requested by the model, i.e. one element of the
 * `invocations` array described by the schema from `buildModelSchema`.
 */
type ModelInvocationOutput = {
  /** Name of the tool to call; the schema constrains it to the `toolName` enum. */
  toolName: string;
  /** Input object for the tool call; the schema allows arbitrary properties. */
  input: unknown;
  /**
   * Reasoning trace for the call. The schema only includes this property
   * when `state.workflow.reasoningTraces` is set, so it is optional here.
   */
  reasoning?: string;
};

/**
 * Shape of the structured object the model is asked to produce.
 * Mirrors the properties of the schema returned by `buildModelSchema`;
 * every field is optional.
 */
export type ModelOutput = {
  /** Whether the workflow is done. */
  done?: boolean;
  /** Description of any issues encountered in this step. */
  issue?: string;
  /** Free-form message; the schema includes it when no result schema is supplied. */
  message?: string;
  /** Structured final result; the schema includes it when a result schema is supplied. */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  result?: any;
  /** Tool calls requested by the model. */
  invocations?: ModelInvocationOutput[];
};

/**
 * Builds the JSON Schema describing the structured output requested from the model.
 *
 * The returned object schema always contains `done`, `issue` and `invocations`.
 * It additionally contains `result` (when `resultSchema` is supplied) or
 * `message` (when it is not) — never both. No top-level property is required,
 * and additional top-level properties are disallowed.
 *
 * @param state - Agent state; `allAvailableTools` contributes tool names and
 *   `workflow.reasoningTraces` toggles the per-invocation `reasoning` field.
 * @param relevantSchemas - Tools whose `name` is offered to the model.
 * @param resultSchema - Optional schema for the workflow's final result.
 * @returns The JSON Schema object for the model's structured output.
 */
export const buildModelSchema = ({
  state,
  relevantSchemas,
  resultSchema,
}: {
  state: WorkflowAgentState;
  relevantSchemas: AgentTool[];
  resultSchema?: InferSelectModel<typeof workflows>["result_schema"];
}) => {
  // Build the toolName enum. The same tool can be present in both lists, so
  // de-duplicate (keeping first-seen order): JSON Schema enum members should
  // be unique.
  const toolNameEnum = [
    ...new Set([
      ...relevantSchemas.map((tool) => tool.name),
      ...state.allAvailableTools,
    ]),
  ];

  const schema: JsonSchema7ObjectType = {
    type: "object",
    additionalProperties: false,
    properties: {
      done: {
        type: "boolean",
        description:
          "Whether the workflow is done. All tasks have been completed or you can not progress further.",
      },
      issue: {
        type: "string",
        description:
          "Describe any issues you have encountered in this step. Specifically related to the tools you are using.",
      },
    },
  };

  if (resultSchema) {
    // A result schema was supplied: request a structured `result`.
    // The description is spread last so it replaces any description
    // carried by the supplied schema.
    schema.properties.result = {
      ...resultSchema,
      description:
        "Structured object describing the final result of the workflow, only provided once all tasks have been completed.",
    };
  } else {
    // Otherwise request a plain string message.
    schema.properties.message = {
      type: "string",
      description: "A message describing the current state or next steps.",
    };
  }

  const invocationItemProperties: JsonSchema7ObjectType["properties"] = {
    toolName: {
      type: "string",
      enum: toolNameEnum,
    },
    input: {
      type: "object",
      additionalProperties: true,
      description: "Arbitrary input parameters for the tool call.",
    },
  };

  // `reasoning` is only offered when the workflow has reasoning traces
  // enabled; it is never listed as required.
  if (state.workflow.reasoningTraces) {
    invocationItemProperties.reasoning = {
      type: "string",
      description: "Reasoning trace for why this tool call is made.",
    };
  }

  schema.properties.invocations = {
    type: "array",
    description:
      "Any tool calls you need to make. If multiple are provided, they will be executed in parallel. DO NOT describe previous tool calls.",
    items: {
      type: "object",
      additionalProperties: false,
      properties: invocationItemProperties,
      required: ["toolName", "input"],
    },
  };

  return schema;
};

0 comments on commit 887d38d

Please sign in to comment.