-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Build model schema with JsonSchema
- Loading branch information
1 parent
e15e4a7
commit 887d38d
Showing
3 changed files
with
247 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
127 changes: 127 additions & 0 deletions
127
control-plane/src/modules/workflows/agent/nodes/model-output.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import { JsonSchema7ObjectType } from "zod-to-json-schema"; | ||
import { WorkflowAgentState } from "../state"; | ||
import { AgentTool } from "../tool"; | ||
import { ulid } from "ulid"; | ||
import { buildModelSchema } from "./model-output"; | ||
|
||
describe("buildModelSchema", () => { | ||
let state: WorkflowAgentState; | ||
let relevantSchemas: AgentTool[]; | ||
let resultSchema: JsonSchema7ObjectType | undefined; | ||
|
||
beforeEach(() => { | ||
state = { | ||
messages: [ | ||
{ | ||
id: ulid(), | ||
clusterId: "test-cluster", | ||
runId: "test-run", | ||
data: { | ||
message: "What are your capabilities?", | ||
}, | ||
type: "human", | ||
}, | ||
], | ||
waitingJobs: [], | ||
allAvailableTools: [], | ||
workflow: { | ||
id: "test-run", | ||
clusterId: "test-cluster", | ||
}, | ||
additionalContext: "", | ||
status: "running", | ||
}; | ||
relevantSchemas = [ | ||
{ name: "localTool1"}, | ||
{ name: "localTool2"}, | ||
{ name: "globalTool1"}, | ||
{ name: "globalTool2"}, | ||
] as AgentTool[], | ||
resultSchema = undefined; | ||
}); | ||
|
||
it("returns a schema with 'message' when resultSchema is not provided", () => { | ||
const schema = buildModelSchema({ state, relevantSchemas, resultSchema }); | ||
|
||
expect(schema.type).toBe("object"); | ||
expect(schema.properties).toHaveProperty("message"); | ||
expect(schema.properties).not.toHaveProperty("result"); | ||
}); | ||
|
||
it("returns a schema with 'result' when resultSchema is provided", () => { | ||
resultSchema = { | ||
type: "object", | ||
properties: { | ||
foo: { type: "string" }, | ||
}, | ||
additionalProperties: false, | ||
}; | ||
|
||
const schema = buildModelSchema({ state, relevantSchemas, resultSchema }) as any; | ||
|
||
expect(schema.type).toBe("object"); | ||
expect(schema.properties).toHaveProperty("result"); | ||
expect(schema.properties).not.toHaveProperty("message"); | ||
expect(schema.properties?.result?.description).toContain("final result"); | ||
}); | ||
|
||
it("includes 'done' and 'issue' fields", () => { | ||
const schema = buildModelSchema({ state, relevantSchemas, resultSchema }) as any; | ||
|
||
expect(schema.properties).toHaveProperty("done"); | ||
expect(schema.properties).toHaveProperty("issue"); | ||
expect(schema.properties.done.type).toBe("boolean"); | ||
expect(schema.properties.issue.type).toBe("string"); | ||
}); | ||
|
||
it("builds the correct toolName enum from available tools", () => { | ||
const schema = buildModelSchema({ state, relevantSchemas, resultSchema }); | ||
const invocations = schema.properties?.invocations as any; | ||
const items = invocations.items as JsonSchema7ObjectType; | ||
const toolName = items.properties?.toolName as any; | ||
|
||
expect(toolName).toBeDefined(); | ||
expect(toolName?.enum).toContain("localTool1"); | ||
expect(toolName?.enum).toContain("localTool2"); | ||
expect(toolName?.enum).toContain("globalTool1"); | ||
expect(toolName?.enum).toContain("globalTool2"); | ||
}); | ||
|
||
it("includes 'invocations' with correct structure", () => { | ||
const schema = buildModelSchema({ state, relevantSchemas, resultSchema }); | ||
const invocations = schema.properties?.invocations as any; | ||
|
||
expect(invocations.type).toBe("array"); | ||
const items = invocations.items as any; | ||
|
||
expect(items.type).toBe("object"); | ||
expect(items.additionalProperties).toBe(false); | ||
expect(items.required).toEqual(["toolName", "input"]); | ||
|
||
expect(items.properties?.input.type).toBe("object"); | ||
expect(items.properties?.input.additionalProperties).toBe(true); | ||
}); | ||
|
||
it("does not include 'reasoning' by default", () => { | ||
const schema = buildModelSchema({ state, relevantSchemas, resultSchema }); | ||
const invocations = schema.properties?.invocations as any; | ||
const items = invocations.items as JsonSchema7ObjectType; | ||
|
||
expect(items.properties).not.toHaveProperty("reasoning"); | ||
}); | ||
|
||
it("includes 'reasoning' when reasoningTraces is true", () => { | ||
state.workflow.reasoningTraces = true; | ||
const schema = buildModelSchema({ state, relevantSchemas, resultSchema }); | ||
const invocations = schema.properties?.invocations as any; | ||
const items = invocations.items as any; | ||
|
||
expect(items.properties).toHaveProperty("reasoning"); | ||
expect(items.properties?.reasoning?.type).toBe("string"); | ||
}); | ||
|
||
it("has additionalProperties set to false at top level", () => { | ||
const schema = buildModelSchema({ state, relevantSchemas, resultSchema }); | ||
expect(schema.additionalProperties).toBe(false); | ||
}); | ||
}); |
101 changes: 101 additions & 0 deletions
101
control-plane/src/modules/workflows/agent/nodes/model-output.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
|
||
import { JsonSchema7ObjectType } from "zod-to-json-schema"; | ||
import { AgentTool } from "../tool"; | ||
import { workflows } from "../../../data"; | ||
import { InferSelectModel } from "drizzle-orm"; | ||
import { WorkflowAgentState } from "../state"; | ||
|
||
type ModelInvocationOutput = { | ||
toolName: string; | ||
input: unknown; | ||
|
||
} | ||
|
||
export type ModelOutput = { | ||
invocations?: ModelInvocationOutput[]; | ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
result?: any; | ||
message?: string; | ||
done?: boolean; | ||
issue?: string; | ||
} | ||
|
||
export const buildModelSchema = ({ | ||
state, | ||
relevantSchemas, | ||
resultSchema | ||
}: { | ||
state: WorkflowAgentState; | ||
relevantSchemas: AgentTool[]; | ||
resultSchema?: InferSelectModel<typeof workflows>["result_schema"]; | ||
}) => { | ||
|
||
// Build the toolName enum | ||
const toolNameEnum = [ | ||
...relevantSchemas.map((tool) => tool.name), | ||
...state.allAvailableTools, | ||
]; | ||
|
||
const schema: JsonSchema7ObjectType = { | ||
type: "object", | ||
additionalProperties: false, | ||
properties: { | ||
done: { | ||
type: "boolean", | ||
description: | ||
"Whether the workflow is done. All tasks have been completed or you can not progress further.", | ||
}, | ||
issue: { | ||
type: "string", | ||
description: | ||
"Describe any issues you have encountered in this step. Specifically related to the tools you are using.", | ||
}, | ||
}, | ||
}; | ||
|
||
if (resultSchema) { | ||
schema.properties.result = { | ||
...resultSchema, | ||
description: | ||
"Structured object describing the final result of the workflow, only provided once all tasks have been completed.", | ||
}; | ||
} else { | ||
schema.properties.message = { | ||
type: "string", | ||
description: "A message describing the current state or next steps.", | ||
}; | ||
} | ||
|
||
const invocationItemProperties: JsonSchema7ObjectType["properties"] = { | ||
toolName: { | ||
type: "string", | ||
enum: toolNameEnum, | ||
}, | ||
input: { | ||
type: "object", | ||
additionalProperties: true, | ||
description: "Arbitrary input parameters for the tool call.", | ||
}, | ||
}; | ||
|
||
if (state.workflow.reasoningTraces) { | ||
invocationItemProperties.reasoning = { | ||
type: "string", | ||
description: "Reasoning trace for why this tool call is made.", | ||
}; | ||
} | ||
|
||
schema.properties.invocations = { | ||
type: "array", | ||
description: | ||
"Any tool calls you need to make. If multiple are provided, they will be executed in parallel. DO NOT describe previous tool calls.", | ||
items: { | ||
type: "object", | ||
additionalProperties: false, | ||
properties: invocationItemProperties, | ||
required: ["toolName", "input"], | ||
}, | ||
}; | ||
|
||
return schema; | ||
}; |