From 30ca2b30a00cd9c9c91254d05d39e245c0692b74 Mon Sep 17 00:00:00 2001 From: John Smith Date: Mon, 2 Dec 2024 14:07:29 +1030 Subject: [PATCH 1/2] feat: Add the ability to mock model responses --- control-plane/src/modules/models/index.ts | 55 ++++++++ .../modules/workflows/agent/agent.ai.test.ts | 128 ++++++++---------- .../src/modules/workflows/agent/agent.ts | 11 +- 3 files changed, 125 insertions(+), 69 deletions(-) diff --git a/control-plane/src/modules/models/index.ts b/control-plane/src/modules/models/index.ts index 1dadf47c..f2e28bb2 100644 --- a/control-plane/src/modules/models/index.ts +++ b/control-plane/src/modules/models/index.ts @@ -17,6 +17,7 @@ import * as events from "../observability/events"; import { rateLimiter } from "../rate-limiter"; import { addAttributes } from "../observability/tracer"; import { customerTelemetry } from "../customer-telemetry"; +import { ulid } from "ulid"; type TrackingOptions = { clusterId?: string; runId?: string; @@ -336,6 +337,60 @@ const parseStructuredResponse = ({ }; }; +export const buildMockModel = ({ + mockResponses, + responseCount, +}: { + mockResponses: string[]; + responseCount: number; +}): Model => { + return { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + identifier: "mock" as any, + embedQuery: async () => { + throw new Error("Not implemented"); + }, + call: async () => { + throw new Error("Not implemented"); + }, + structured: async (options) => { + if (responseCount >= mockResponses.length) { + throw new Error("Mock model ran out of responses"); + } + + const parsed = options.schema.safeParse( + JSON.parse( + mockResponses[responseCount] + ) + ); + + // Sleep for between 500 and 1500 ms + await new Promise((resolve) => + setTimeout(resolve, Math.random() * 1000 + 500) + ); + + if (!parsed.success) { + return { + raw: { content: [] } as unknown as Anthropic.Message, + parsed: { + success: false, + error: parsed.error, + }, + }; + } + + return { + raw: { content: [] } as unknown as Anthropic.Message, + parsed: { + success: true, + data: parsed.data, + }, + }; + + }, + }; +} + const trackModelUsage = async ({ runId, clusterId, diff --git a/control-plane/src/modules/workflows/agent/agent.ai.test.ts b/control-plane/src/modules/workflows/agent/agent.ai.test.ts index 08ac6b39..d8c2923f 100644 --- a/control-plane/src/modules/workflows/agent/agent.ai.test.ts +++ b/control-plane/src/modules/workflows/agent/agent.ai.test.ts @@ -335,98 +335,90 @@ describe("Agent", () => { ); expect(outputState.messages[5]).toHaveProperty("type", "agent"); }); - }); - describe.skip("learning", () => { - const orderCallback = jest.fn(); - const authCallback = jest.fn(); - const tools = [ - new DynamicStructuredTool({ - name: "orderSearch", - description: "Searches for an order by customer name", - schema: z.object({ - query: z.string(), + it("should respect mock responses", async () => { + const tools = [ + new DynamicStructuredTool({ + name: "searchHaystack", + description: "Search haystack", + schema: z.object({ + }).passthrough(), + func: async (input: any) => { + return toolCallback(input.input); + }, }), - func: async (input: any) => { - return orderCallback(input.input); - }, - }), - new DynamicStructuredTool({ - name: "authenticate", - description: "Refreshes the authentication token", - schema: z.object({}), - func: async (input: any) => { - return authCallback(input.input); - }, - }), - ]; + ]; - it("should record tool learnings", async () => { const app = await createWorkflowAgent({ - workflow, - allAvailableTools: ["orderSearch", "authenticate"], + workflow: { + ...workflow, + resultSchema: { + type: "object", + properties: { + word: { + type: "string" + } + } + } + }, + allAvailableTools: ["searchHaystack"], findRelevantTools: async () => tools, getTool: async (input) => tools.find((tool) => tool.name === input.toolName)!, postStepSave: async () => {}, - }); - - orderCallback.mockResolvedValueOnce( - JSON.stringify({ - result: JSON.stringify({ error: "unauthenticated" }), - resultType: "resolution", - status: "failure", - }), - ); - - orderCallback.mockResolvedValueOnce( - JSON.stringify({ - result: JSON.stringify({ - orders: [ + mockModelResponses: [ + JSON.stringify({ + done: false, + invocations: [ { - customerId: "cus-223", - orderId: "ord-313", - details: "Reuben sandwich", - }, - ], + toolName: "searchHaystack", + input: {} + } + ] }), - resultType: "resolution", - status: "success", - }), - ); + JSON.stringify({ + done: true, + result: { + word: "needle" + } + }) + ] + }); - authCallback.mockResolvedValueOnce( - JSON.stringify({ - result: "done", - resultType: "resolution", - status: "success", + + toolCallback.mockResolvedValue(JSON.stringify({ + result: JSON.stringify({ + word: "needle" }), - ); + resultType: "resolution", + status: "success", + })); const outputState = await app.invoke({ messages: [ { data: { - message: "Get details for John's orders", + message: "What is the special word?", }, type: "human", }, ], }); - const learnings = outputState.messages.reduce( - (acc: any, m: WorkflowAgentStateMessage) => { - if (m.type === "agent" && m.data.learnings) { - acc.push(...m.data.learnings); - } - return acc; - }, - [], - ) as any[]; - - expect(learnings).toBeInstanceOf(Array); - expect(learnings.length).toBeGreaterThan(0); + expect(outputState.messages).toHaveLength(4); + expect(outputState.messages[0]).toHaveProperty("type", "human"); + expect(outputState.messages[1]).toHaveProperty("type", "agent"); + expect(outputState.messages[2]).toHaveProperty("type", "invocation-result"); + expect(outputState.messages[3]).toHaveProperty("type", "agent"); + expect(outputState.messages[3].data).toHaveProperty( + "result", + { + word: "needle" + } + ) + expect(toolCallback).toHaveBeenCalledTimes(1); }); }); + }); diff --git a/control-plane/src/modules/workflows/agent/agent.ts b/control-plane/src/modules/workflows/agent/agent.ts index 4e53ec0f..cdb126fc 100644 --- a/control-plane/src/modules/workflows/agent/agent.ts +++ b/control-plane/src/modules/workflows/agent/agent.ts @@ -11,7 +11,7 @@ import { postToolEdge, } from "./nodes/edges"; import { AgentMessage } from "../workflow-messages"; -import { buildModel } from "../../models"; +import { buildMockModel, buildModel } from "../../models"; export type ReleventToolLookup = ( state: WorkflowAgentState, @@ -28,6 +28,7 @@ export const createWorkflowAgent = async ({ postStepSave, findRelevantTools, getTool, + mockModelResponses, }: { workflow: Run; additionalContext?: string; @@ -35,6 +36,7 @@ export const createWorkflowAgent = async ({ postStepSave: PostStepSave; findRelevantTools: ReleventToolLookup; getTool: ToolFetcher; + mockModelResponses?: string[]; }) => { const workflowGraph = new StateGraph({ channels: createStateGraphChannels({ @@ -46,6 +48,13 @@ export const createWorkflowAgent = async ({ .addNode(MODEL_CALL_NODE_NAME, (state) => handleModelCall( state, + mockModelResponses ? + // If mock responses are provided, use the mock model + buildMockModel({ + mockResponses: mockModelResponses, + responseCount: state.messages.filter((m) => m.type === "agent").length + }) : + // Otherwise, use the real model buildModel({ identifier: workflow.modelIdentifier ?? "claude-3-5-sonnet", purpose: "agent_loop.reasoning", From cedd539dc24ffea77e200b23566eb4709be51b2c Mon Sep 17 00:00:00 2001 From: John Smith Date: Mon, 2 Dec 2024 14:24:02 +1030 Subject: [PATCH 2/2] chore: Mock model responses for load test cluster --- .../src/modules/workflows/agent/run.ts | 26 +++++++++++++++++++ control-plane/src/utilities/env.ts | 2 ++ 2 files changed, 28 insertions(+) diff --git a/control-plane/src/modules/workflows/agent/run.ts b/control-plane/src/modules/workflows/agent/run.ts index 74117e4f..ede2b266 100644 --- a/control-plane/src/modules/workflows/agent/run.ts +++ b/control-plane/src/modules/workflows/agent/run.ts @@ -30,6 +30,7 @@ import { buildMockFunctionTool } from "./tools/mock-function"; import { getClusterInternalTools } from "./tools/cluster-internal-tools"; import { buildCurrentDateTimeTool } from "./tools/date-time"; import { CURRENT_DATE_TIME_TOOL_NAME } from "./tools/date-time"; +import { env } from "../../../utilities/env"; /** * Run a workflow from the most recent saved state @@ -80,8 +81,33 @@ export const run = async (run: Run) => { const mockToolsMap: Record = await buildMockTools(run); + let mockModelResponses; + if (!!env.LOAD_TEST_CLUSTER_ID && run.clusterId === env.LOAD_TEST_CLUSTER_ID) { + logger.info("Mocking model responses for load test"); + + //https://github.com/inferablehq/inferable/blob/main/load-tests/script.js + mockModelResponses = [ + JSON.stringify({ + done: false, + invocations: [ + { + toolName: "searchHaystack", + input: {} + } + ] + }), + JSON.stringify({ + done: true, + result: { + word: "needle" + } + }) + ] + } + const app = await createWorkflowAgent({ workflow: run, + mockModelResponses, allAvailableTools, additionalContext, getTool: async (toolCall) => { diff --git a/control-plane/src/utilities/env.ts b/control-plane/src/utilities/env.ts index 5286d62d..30fd1c20 100644 --- a/control-plane/src/utilities/env.ts +++ b/control-plane/src/utilities/env.ts @@ -48,6 +48,8 @@ const envSchema = z SQS_BASE_QUEUE_URL: z.string().optional(), + LOAD_TEST_CLUSTER_ID: z.string().optional(), + // Required in EE (Disabled by default) EE_DEPLOYMENT: truthy.default(false),