From 31e508916bf3ed0b48da33dabc6da5e767269fcc Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Fri, 3 Jan 2025 08:27:45 +1100 Subject: [PATCH 1/3] chore: Refactor message types to use unified schema --- control-plane/src/modules/contract.ts | 69 ++++--- control-plane/src/modules/data.ts | 3 +- control-plane/src/modules/email/index.ts | 81 ++++---- .../src/modules/integrations/slack/index.ts | 189 +++++++++--------- control-plane/src/modules/router.ts | 4 +- .../modules/workflows/agent/agent.ai.test.ts | 122 ++++------- .../workflows/agent/nodes/model-call.test.ts | 85 ++++---- .../workflows/agent/nodes/tool-call.test.ts | 59 +++--- .../workflows/agent/nodes/tool-call.ts | 65 +++--- control-plane/src/modules/workflows/router.ts | 6 +- .../modules/workflows/workflow-messages.ts | 178 +++++++---------- .../src/modules/workflows/workflows.ts | 2 +- 12 files changed, 391 insertions(+), 472 deletions(-) diff --git a/control-plane/src/modules/contract.ts b/control-plane/src/modules/contract.ts index 24d679d0..ff88a603 100644 --- a/control-plane/src/modules/contract.ts +++ b/control-plane/src/modules/contract.ts @@ -91,14 +91,14 @@ export const integrationSchema = z.object({ .nullable(), }); -export const genericMessageDataSchema = z +const genericMessageDataSchema = z .object({ message: z.string(), details: z.object({}).passthrough().optional(), }) .strict(); -export const resultDataSchema = z +const resultDataSchema = z .object({ id: z.string(), result: z.object({}).passthrough(), @@ -124,7 +124,7 @@ export const learningSchema = z.object({ }), }); -export const agentDataSchema = z +const agentDataSchema = z .object({ done: z.boolean().optional(), result: anyObject.optional(), @@ -144,10 +144,31 @@ export const agentDataSchema = z }) .strict(); -export const messageDataSchema = z.union([ - resultDataSchema, - agentDataSchema, - genericMessageDataSchema, +export const unifiedMessageDataSchema = z.discriminatedUnion("type", [ + z.object({ + type: z.literal("agent"), + data: agentDataSchema, + }), + z.object({ + type: z.literal("invocation-result"), + data: resultDataSchema, + }), + z.object({ + type: z.literal("human"), + data: genericMessageDataSchema, + }), + z.object({ + type: z.literal("template"), + data: genericMessageDataSchema, + }), + z.object({ + type: z.literal("supervisor"), + data: genericMessageDataSchema, + }), + z.object({ + type: z.literal("agent-invalid"), + data: genericMessageDataSchema, + }), ]); export const FunctionConfigSchema = z.object({ @@ -623,21 +644,14 @@ export const definition = { }), responses: { 200: z.array( - z.object({ - id: z.string(), - data: messageDataSchema, - type: z.enum([ - "human", - "template", - "invocation-result", - "agent", - "agent-invalid", - "supervisor", - ]), - createdAt: z.date(), - pending: z.boolean().default(false), - displayableContext: z.record(z.string()).nullable(), - }) + z + .object({ + id: z.string(), + createdAt: z.date(), + pending: z.boolean().default(false), + displayableContext: z.record(z.string()).nullable(), + }) + .merge(z.object({ data: unifiedMessageDataSchema })) ), 401: z.undefined(), }, @@ -935,16 +949,7 @@ export const definition = { messages: z.array( z.object({ id: z.string(), - data: messageDataSchema, - type: z.enum([ - // TODO: Remove 'template' type - "template", - "invocation-result", - "human", - "agent", - "agent-invalid", - "supervisor", - ]), + data: unifiedMessageDataSchema, createdAt: z.date(), pending: z.boolean().default(false), displayableContext: z.record(z.string()).nullable(), diff --git a/control-plane/src/modules/data.ts b/control-plane/src/modules/data.ts index 31105613..c31e2db6 100644 --- a/control-plane/src/modules/data.ts +++ b/control-plane/src/modules/data.ts @@ -18,7 +18,6 @@ import { import { Pool } from "pg"; import { env } from "../utilities/env"; import { logger } from "./observability/logger"; -import { MessageData } from "./workflows/workflow-messages"; export const createMutex = advisoryLock(env.DATABASE_URL); @@ -323,7 +322,7 @@ export const workflowMessages = pgTable( withTimezone: true, precision: 6, }), - data: json("data").$type().notNull(), + data: json("data").$type().notNull(), type: text("type", { enum: ["human", "invocation-result", "template", "agent", "agent-invalid", "supervisor"], }).notNull(), diff --git a/control-plane/src/modules/email/index.ts b/control-plane/src/modules/email/index.ts index f0f1bc17..c983c139 100644 --- a/control-plane/src/modules/email/index.ts +++ b/control-plane/src/modules/email/index.ts @@ -14,6 +14,7 @@ import { workflowMessages } from "../data"; import { ses } from "../ses"; import { getMessageByReference, updateMessageReference } from "../workflows/workflow-messages"; import { ulid } from "ulid"; +import { unifiedMessageDataSchema } from "../contract"; const EMAIL_INIT_MESSAGE_ID_META_KEY = "emailInitMessageId"; const EMAIL_SUBJECT_META_KEY = "emailSubject"; @@ -33,7 +34,7 @@ const sesMessageSchema = z.object({ dmarcVerdict: z.object({ status: z.string() }), }), content: z.string(), -}) +}); const snsNotificationSchema = z.object({ Type: z.literal("Notification"), @@ -54,7 +55,7 @@ const emailIngestionConsumer = env.SQS_EMAIL_INGESTION_QUEUE_URL : undefined; export const start = async () => { - emailIngestionConsumer?.start() + emailIngestionConsumer?.start(); }; export const stop = async () => { @@ -71,7 +72,6 @@ export const handleNewRunMessage = async ({ runId: string; type: InferSelectModel["type"]; data: InferSelectModel["data"]; - }; runMetadata?: Record; }) => { @@ -79,17 +79,21 @@ export const handleNewRunMessage = async ({ return; } - if (!runMetadata?.[EMAIL_INIT_MESSAGE_ID_META_KEY] || !runMetadata?.[EMAIL_SUBJECT_META_KEY] || !runMetadata?.[EMAIL_SOURCE_META_KEY]) { + if ( + !runMetadata?.[EMAIL_INIT_MESSAGE_ID_META_KEY] || + !runMetadata?.[EMAIL_SUBJECT_META_KEY] || + !runMetadata?.[EMAIL_SOURCE_META_KEY] + ) { return; } - if ("message" in message.data && message.data.message) { + const messageData = unifiedMessageDataSchema.parse(message.data).data; + + if ("message" in messageData && messageData.message) { const result = await ses.sendEmail({ Source: `"Inferable" <${message.clusterId}@${env.INFERABLE_EMAIL_DOMAIN}>`, Destination: { - ToAddresses: [ - runMetadata[EMAIL_SOURCE_META_KEY] - ], + ToAddresses: [runMetadata[EMAIL_SOURCE_META_KEY]], }, Message: { Subject: { @@ -99,11 +103,11 @@ export const handleNewRunMessage = async ({ Body: { Text: { Charset: "UTF-8", - Data: message.data.message, + Data: messageData.message, }, }, }, - }) + }); if (!result.MessageId) { throw new Error("SES did not return a message ID"); @@ -119,7 +123,6 @@ export const handleNewRunMessage = async ({ clusterId: message.clusterId, messageId: message.id, }); - } else { logger.warn("Email thread message does not have content"); } @@ -130,7 +133,7 @@ export async function parseMessage(message: unknown) { if (!notification.success) { logger.error("Could not parse SNS notification message", { error: notification.error, - }) + }); throw new Error("Could not parse SNS notification message"); } @@ -143,19 +146,22 @@ export async function parseMessage(message: unknown) { if (!sesMessage.success) { logger.error("Could not parse SES message", { error: sesMessage.error, - }) + }); throw new Error("Could not parse SES message"); } - const ingestionAddresses = sesMessage.data.mail.destination.filter( - (email) => email.endsWith(env.INFERABLE_EMAIL_DOMAIN) - ) + const ingestionAddresses = sesMessage.data.mail.destination.filter(email => + email.endsWith(env.INFERABLE_EMAIL_DOMAIN) + ); if (ingestionAddresses.length > 1) { throw new Error("Found multiple Inferable email addresses in destination"); } - const clusterId = ingestionAddresses.pop()?.replace(env.INFERABLE_EMAIL_DOMAIN, "").replace("@", ""); + const clusterId = ingestionAddresses + .pop() + ?.replace(env.INFERABLE_EMAIL_DOMAIN, "") + .replace("@", ""); if (!clusterId) { throw new Error("Could not extract clusterId from email address"); @@ -166,9 +172,9 @@ export async function parseMessage(message: unknown) { throw new Error("Could not parse email content"); } - let body = mail.text + let body = mail.text; if (!body && mail.html) { - body = mail.html + body = mail.html; } return { @@ -179,26 +185,25 @@ export async function parseMessage(message: unknown) { messageId: mail.messageId, source: sesMessage.data.mail.source, inReplyTo: mail.inReplyTo, - references: (typeof mail.references === "string") ? [mail.references] : mail.references ?? [], - } + references: typeof mail.references === "string" ? [mail.references] : (mail.references ?? []), + }; } // Strip trailing email chain quotes ">" export const stripQuoteTail = (message: string) => { const lines = message.split("\n").reverse(); - while (lines[0] && lines[0].startsWith(">") || lines[0].trim() === "") { + while ((lines[0] && lines[0].startsWith(">")) || lines[0].trim() === "") { lines.shift(); } return lines.reverse().join("\n"); -} +}; async function handleEmailIngestion(raw: unknown) { const message = await parseMessage(raw); if (!message.body) { - logger.info("Email had no body. Skipping", { - }); + logger.info("Email had no body. Skipping", {}); return; } @@ -239,7 +244,7 @@ async function handleEmailIngestion(raw: unknown) { body: message.body, clusterId: message.clusterId, messageId: message.messageId, - runId: existing.runId + runId: existing.runId, }); } @@ -249,12 +254,11 @@ async function handleEmailIngestion(raw: unknown) { clusterId: message.clusterId, messageId: message.messageId, subject: message.subject, - source: message.source + source: message.source, }); } - -const parseMailContent = (message: string): Promise => { +const parseMailContent = (message: string): Promise => { return new Promise((resolve, reject) => { simpleParser(message, (error, parsed) => { if (error) { @@ -263,10 +267,9 @@ const parseMailContent = (message: string): Promise => { resolve(parsed); } }); - }) + }); }; - const authenticateUser = async (emailAddress: string, clusterId: string) => { if (!env.CLERK_SECRET_KEY) { throw new Error("CLERK_SECRET_KEY must be set for email authentication"); @@ -293,7 +296,7 @@ const handleNewChain = async ({ clusterId, messageId, subject, - source + source, }: { userId: string; body: string; @@ -302,7 +305,7 @@ const handleNewChain = async ({ subject: string; source: string; }) => { - logger.info("Creating new run from email") + logger.info("Creating new run from email"); await createRunWithMessage({ userId, clusterId, @@ -317,15 +320,15 @@ const handleNewChain = async ({ }, message: body, type: "human", - }) -} + }); +}; const handleExistingChain = async ({ userId, body, clusterId, messageId, - runId + runId, }: { userId: string; body: string; @@ -333,7 +336,7 @@ const handleExistingChain = async ({ messageId: string; runId: string; }) => { - logger.info("Continuing existing run from email") + logger.info("Continuing existing run from email"); await addMessageAndResume({ id: ulid(), clusterId, @@ -345,5 +348,5 @@ const handleExistingChain = async ({ }, message: body, type: "human", - }) -} + }); +}; diff --git a/control-plane/src/modules/integrations/slack/index.ts b/control-plane/src/modules/integrations/slack/index.ts index c7826aa9..a389dd55 100644 --- a/control-plane/src/modules/integrations/slack/index.ts +++ b/control-plane/src/modules/integrations/slack/index.ts @@ -15,6 +15,7 @@ import { integrationSchema } from "../schema"; import { z } from "zod"; import { getUserForCluster } from "../../clerk"; import { submitApproval } from "../../jobs/jobs"; +import { unifiedMessageDataSchema } from "../../contract"; const THREAD_META_KEY = "slackThreadTs"; const CHANNEL_META_KEY = "slackChannel"; @@ -39,12 +40,12 @@ export const slack: InstallableIntegration = { prevConfig: z.infer ) => { logger.info("Deactivating Slack integration", { - clusterId - }) + clusterId, + }); if (!prevConfig.slack) { - logger.warn("Can not deactivate Slack integration with no config") - return + logger.warn("Can not deactivate Slack integration with no config"); + return; } // Cleanup the Nango connection await deleteNangoConnection(prevConfig.slack.nangoConnectionId); @@ -56,25 +57,25 @@ export const slack: InstallableIntegration = { ) => { logger.info("Activating Slack integration", { clusterId, - }) + }); if (!config.slack) { - logger.warn("Can not activate Slack integration with no config") - return + logger.warn("Can not activate Slack integration with no config"); + return; } // It can be possible for the same Nango session token to be used to create multiple connections // e.g, if the "try again" button. // This check will cleanup a previous connection if it is not the same if ( - prevConfig.slack - && config.slack - && prevConfig.slack.nangoConnectionId !== config.slack.nangoConnectionId + prevConfig.slack && + config.slack && + prevConfig.slack.nangoConnectionId !== config.slack.nangoConnectionId ) { logger.warn("Slack integration has been overridden. Cleaning up previous Nango connection", { prevNangoConnectionId: prevConfig.slack.nangoConnectionId, - nangoConnectionId: config.slack.nangoConnectionId - }) + nangoConnectionId: config.slack.nangoConnectionId, + }); await deleteNangoConnection(prevConfig.slack.nangoConnectionId); } @@ -85,7 +86,7 @@ export const slack: InstallableIntegration = { handleCall: async () => { logger.warn("Slack integration does not support calls"); }, -} +}; export const handleNewRunMessage = async ({ message, @@ -115,17 +116,21 @@ export const handleNewRunMessage = async ({ const token = await getAccessToken(integration.slack.nangoConnectionId); if (!token) { - throw new Error(`Could not fetch access token for Slack integration: ${integration.slack.nangoConnectionId}`); + throw new Error( + `Could not fetch access token for Slack integration: ${integration.slack.nangoConnectionId}` + ); } - const client = new webApi.WebClient(token) + const client = new webApi.WebClient(token); + + const messageData = unifiedMessageDataSchema.parse(message.data).data; - if ("message" in message.data && message.data.message) { + if ("message" in messageData && messageData.message) { client?.chat.postMessage({ thread_ts: runMetadata[THREAD_META_KEY], channel: runMetadata[CHANNEL_META_KEY], mrkdwn: true, - text: message.data.message, + text: messageData.message, }); } else { logger.warn("Slack thread message does not have content"); @@ -158,12 +163,14 @@ export const handleApprovalRequest = async ({ const token = await getAccessToken(integration.slack.nangoConnectionId); if (!token) { - throw new Error(`Could not fetch access token for Slack integration: ${integration.slack.nangoConnectionId}`); + throw new Error( + `Could not fetch access token for Slack integration: ${integration.slack.nangoConnectionId}` + ); } - const client = new webApi.WebClient(token) + const client = new webApi.WebClient(token); - const text = `I need your approval to call \`${service}.${targetFn}\` on run <${env.APP_ORIGIN}/clusters/${clusterId}/runs/${runId}|${runId}>` + const text = `I need your approval to call \`${service}.${targetFn}\` on run <${env.APP_ORIGIN}/clusters/${clusterId}/runs/${runId}|${runId}>`; client?.chat.postMessage({ thread_ts: metadata[THREAD_META_KEY], @@ -175,8 +182,8 @@ export const handleApprovalRequest = async ({ type: "section", text: { type: "mrkdwn", - text - } + text, + }, }, { type: "actions", @@ -188,7 +195,7 @@ export const handleApprovalRequest = async ({ text: "Approve", }, value: callId, - action_id: CALL_APPROVE_ACTION_ID + action_id: CALL_APPROVE_ACTION_ID, }, { type: "button", @@ -197,11 +204,11 @@ export const handleApprovalRequest = async ({ text: "Deny", }, value: callId, - action_id: CALL_DENY_ACTION_ID - } - ] - } - ] + action_id: CALL_DENY_ACTION_ID, + }, + ], + }, + ], }); }; @@ -223,14 +230,16 @@ export const start = async (fastify: FastifyInstance) => { if (!integration || !integration.slack) { logger.warn("Could not find Slack integration for teamId", { - teamId + teamId, }); throw new Error("Could not find Slack integration for teamId"); } - const token = await getAccessToken(integration.slack.nangoConnectionId) + const token = await getAccessToken(integration.slack.nangoConnectionId); if (!token) { - throw new Error(`Could not fetch access token for Slack integration: ${integration.slack.nangoConnectionId}`); + throw new Error( + `Could not fetch access token for Slack integration: ${integration.slack.nangoConnectionId}` + ); } return { @@ -238,7 +247,7 @@ export const start = async (fastify: FastifyInstance) => { enterpriseId, botUserId: integration.slack.botUserId, botToken: token, - } + }; }, receiver: new FastifySlackReceiver({ signingSecret: SLACK_SIGNING_SECRET, @@ -247,8 +256,12 @@ export const start = async (fastify: FastifyInstance) => { }), }); - app.action(CALL_APPROVE_ACTION_ID, async (params) => handleCallApprovalAction({ ...params, actionId: CALL_APPROVE_ACTION_ID })); - app.action(CALL_DENY_ACTION_ID, async (params) => handleCallApprovalAction({ ...params, actionId: CALL_DENY_ACTION_ID })); + app.action(CALL_APPROVE_ACTION_ID, async params => + handleCallApprovalAction({ ...params, actionId: CALL_APPROVE_ACTION_ID }) + ); + app.action(CALL_DENY_ACTION_ID, async params => + handleCallApprovalAction({ ...params, actionId: CALL_DENY_ACTION_ID }) + ); // Event listener for mentions app.event("app_mention", async ({ event, client }) => { @@ -281,7 +294,7 @@ export const start = async (fastify: FastifyInstance) => { return; } - const teamId = context.teamId + const teamId = context.teamId; if (!teamId) { logger.warn("Received message without teamId. Ignoring."); @@ -353,11 +366,11 @@ const hasUser = (e: any): e is { user: string } => { const isBlockAction = (e: SlackAction): e is BlockAction => { return typeof e?.type === "string" && e.type === "block_actions"; -} +}; const hasValue = (e: any): e is { value: string } => { - return 'value' in e && typeof e?.value === "string"; -} + return "value" in e && typeof e?.value === "string"; +}; // eslint-disable-next-line @typescript-eslint/no-explicit-any const isBotMessage = (e: any): boolean => { @@ -365,10 +378,11 @@ const isBotMessage = (e: any): boolean => { }; const integrationByTeam = async (teamId: string) => { - const [result] = await db.select({ - cluster_id: integrations.cluster_id, - slack: integrations.slack, - }) + const [result] = await db + .select({ + cluster_id: integrations.cluster_id, + slack: integrations.slack, + }) .from(integrations) .where(sql`slack->>'teamId' = ${teamId}`); @@ -376,19 +390,17 @@ const integrationByTeam = async (teamId: string) => { }; const integrationByCluster = async (clusterId: string) => { - const [result] = await db.select({ - cluster_id: integrations.cluster_id, - slack: integrations.slack, - }) + const [result] = await db + .select({ + cluster_id: integrations.cluster_id, + slack: integrations.slack, + }) .from(integrations) - .where( - eq(integrations.cluster_id, clusterId) - ); + .where(eq(integrations.cluster_id, clusterId)); return result; }; - const getAccessToken = async (connectionId: string) => { if (!nango) { throw new Error("Nango is not configured"); @@ -402,7 +414,10 @@ const getAccessToken = async (connectionId: string) => { return result; }; -const cleanupConflictingIntegrations = async (clusterId: string, config: z.infer) => { +const cleanupConflictingIntegrations = async ( + clusterId: string, + config: z.infer +) => { if (!config.slack) { return; } @@ -414,46 +429,38 @@ const cleanupConflictingIntegrations = async (clusterId: string, config: z.infer }) .from(integrations) .where( - and( - sql`slack->>'teamId' = ${config.slack.teamId}`, - ne(integrations.cluster_id, clusterId) - ) + and(sql`slack->>'teamId' = ${config.slack.teamId}`, ne(integrations.cluster_id, clusterId)) ); if (conflicts.length) { logger.info("Removed conflicting Slack integrations", { - conflicts: conflicts.map((conflict) => conflict.cluster_id) - }) + conflicts: conflicts.map(conflict => conflict.cluster_id), + }); // Cleanup Slack integrations from DB await db - .delete(integrations) - .where( - and( - sql`slack->>'teamId' = ${config.slack.teamId}`, - ne(integrations.cluster_id, clusterId) - ) - ); + .delete(integrations) + .where( + and(sql`slack->>'teamId' = ${config.slack.teamId}`, ne(integrations.cluster_id, clusterId)) + ); // Cleanup Nango connections - await Promise.allSettled(conflicts.map(async (conflict) => { - if (conflict.slack) { - await deleteNangoConnection(conflict.slack.nangoConnectionId); - } - })); - + await Promise.allSettled( + conflicts.map(async conflict => { + if (conflict.slack) { + await deleteNangoConnection(conflict.slack.nangoConnectionId); + } + }) + ); } -} +}; const deleteNangoConnection = async (connectionId: string) => { if (!nango) { throw new Error("Nango is not configured"); } - await nango.deleteConnection( - env.NANGO_SLACK_INTEGRATION_ID, - connectionId - ); + await nango.deleteConnection(env.NANGO_SLACK_INTEGRATION_ID, connectionId); }; const handleNewThread = async ({ event, client, clusterId, userId }: MessageEvent) => { @@ -528,10 +535,14 @@ const handleExistingThread = async ({ event, client, clusterId, userId }: Messag throw new Error("Event had no text"); }; -const authenticateUser = async (userId: string, client: webApi.WebClient, integration: { cluster_id: string }) => { +const authenticateUser = async ( + userId: string, + client: webApi.WebClient, + integration: { cluster_id: string } +) => { if (!env.CLERK_SECRET_KEY) { logger.info("Missing CLERK_SECRET_KEY. Skipping Slack user authentication."); - return + return; } const slackUser = await client.users.info({ @@ -540,8 +551,8 @@ const authenticateUser = async (userId: string, client: webApi.WebClient, integr }); logger.info("Authenticating Slack user", { - slackUser - }) + slackUser, + }); const confirmed = slackUser.user?.is_email_confirmed; const email = slackUser.user?.profile?.email; @@ -549,7 +560,7 @@ const authenticateUser = async (userId: string, client: webApi.WebClient, integr if (!confirmed || !email) { logger.info("Could not authenticate Slack user.", { confirmed, - email + email, }); throw new AuthenticationError("Could not authenticate Slack user"); } @@ -572,21 +583,21 @@ const handleCallApprovalAction = async ({ body, client, context, - actionId + actionId, }: { - ack: () => Promise, - body: SlackAction, - client: webApi.WebClient, - context: { teamId?: string }, - actionId: typeof CALL_APPROVE_ACTION_ID | typeof CALL_DENY_ACTION_ID - }) => { + ack: () => Promise; + body: SlackAction; + client: webApi.WebClient; + context: { teamId?: string }; + actionId: typeof CALL_APPROVE_ACTION_ID | typeof CALL_DENY_ACTION_ID; +}) => { await ack(); if (!isBlockAction(body)) { throw new Error("Slack Action was unexpected type"); } - const approved = actionId === CALL_APPROVE_ACTION_ID; + const approved = actionId === CALL_APPROVE_ACTION_ID; const teamId = context.teamId; const channelId = body.channel?.id; const messageTs = body.message?.ts; @@ -612,7 +623,7 @@ const handleCallApprovalAction = async ({ await submitApproval({ approved, callId: action.value, - clusterId: integration.cluster_id + clusterId: integration.cluster_id, }); logger.info("Call approval received via Slack", { diff --git a/control-plane/src/modules/router.ts b/control-plane/src/modules/router.ts index 8025dbee..e37e73b5 100644 --- a/control-plane/src/modules/router.ts +++ b/control-plane/src/modules/router.ts @@ -9,7 +9,7 @@ import * as data from "./data"; import * as management from "./management"; import * as events from "./observability/events"; import { - assertGenericMessage, + assertMessageOfType, editHumanMessage, getRunMessagesForDisplay, } from "./workflows/workflow-messages"; @@ -300,7 +300,7 @@ export const router = initServer().router(contract, { const auth = request.request.getAuth(); await auth.canManage({ run: { clusterId, runId } }); - assertGenericMessage({ + assertMessageOfType("human", { type: "human", data: { message, diff --git a/control-plane/src/modules/workflows/agent/agent.ai.test.ts b/control-plane/src/modules/workflows/agent/agent.ai.test.ts index bc62596e..f74e42dd 100644 --- a/control-plane/src/modules/workflows/agent/agent.ai.test.ts +++ b/control-plane/src/modules/workflows/agent/agent.ai.test.ts @@ -1,8 +1,8 @@ import { createWorkflowAgent } from "./agent"; import { z } from "zod"; -import { assertResultMessage } from "../workflow-messages"; import { redisClient } from "../../redis"; import { AgentTool } from "./tool"; +import { assertMessageOfType } from "../workflow-messages"; if (process.env.CI) { jest.retryTimes(3); @@ -89,10 +89,7 @@ describe("Agent", () => { expect(outputState.messages[0]).toHaveProperty("type", "human"); expect(outputState.messages[1]).toHaveProperty("type", "agent"); - expect(outputState.messages[2]).toHaveProperty( - "type", - "invocation-result", - ); + expect(outputState.messages[2]).toHaveProperty("type", "invocation-result"); expect(outputState.messages[3]).toHaveProperty("type", "agent"); }); @@ -131,14 +128,8 @@ describe("Agent", () => { expect(outputState.messages).toHaveLength(5); expect(outputState.messages[0]).toHaveProperty("type", "human"); expect(outputState.messages[1]).toHaveProperty("type", "agent"); - expect(outputState.messages[2]).toHaveProperty( - "type", - "invocation-result", - ); - expect(outputState.messages[3]).toHaveProperty( - "type", - "invocation-result", - ); + expect(outputState.messages[2]).toHaveProperty("type", "invocation-result"); + expect(outputState.messages[3]).toHaveProperty("type", "invocation-result"); expect(outputState.messages[4]).toHaveProperty("type", "agent"); }); @@ -176,21 +167,18 @@ describe("Agent", () => { expect(outputState.messages).toHaveLength(4); expect(outputState.messages[0]).toHaveProperty("type", "human"); expect(outputState.messages[1]).toHaveProperty("type", "agent"); - expect(outputState.messages[2]).toHaveProperty( - "type", - "invocation-result", - ); + expect(outputState.messages[2]).toHaveProperty("type", "invocation-result"); expect(outputState.messages[3]).toHaveProperty("type", "agent"); - assertResultMessage(outputState.messages[2]); - const topLevelResult = outputState.messages[2].data.result; - Object.keys(topLevelResult).forEach((key) => { + const resultMessage = assertMessageOfType("invocation-result", outputState.messages[2]); + const topLevelResult = resultMessage.data.result; + Object.keys(topLevelResult).forEach(key => { expect(topLevelResult[key]).toEqual({ result: "Failed to echo the word 'hello'", status: "success", resultType: "rejection", }); - }) + }); }); }); @@ -198,7 +186,6 @@ describe("Agent", () => { jest.setTimeout(120000); it("should result result schema", async () => { - const app = await createWorkflowAgent({ workflow: { ...workflow, @@ -206,10 +193,10 @@ describe("Agent", () => { type: "object", properties: { word: { - type: "string" - } - } - } + type: "string", + }, + }, + }, }, findRelevantTools: async () => tools, getTool: async () => tools[0], @@ -233,13 +220,10 @@ describe("Agent", () => { expect(outputState.messages).toHaveLength(2); expect(outputState.messages[0]).toHaveProperty("type", "human"); expect(outputState.messages[1]).toHaveProperty("type", "agent"); - expect(outputState.messages[1].data.result).toHaveProperty( - "word", - "hello", - ); + expect(outputState.messages[1].data.result).toHaveProperty("word", "hello"); expect(outputState.result).toEqual({ - word: "hello" + word: "hello", }); }); }); @@ -304,15 +288,9 @@ describe("Agent", () => { expect(outputState.messages[0]).toHaveProperty("type", "human"); expect(outputState.messages[1]).toHaveProperty("type", "agent"); - expect(outputState.messages[2]).toHaveProperty( - "type", - "invocation-result", - ); + expect(outputState.messages[2]).toHaveProperty("type", "invocation-result"); expect(outputState.messages[3]).toHaveProperty("type", "agent"); - expect(outputState.messages[4]).toHaveProperty( - "type", - "invocation-result", - ); + expect(outputState.messages[4]).toHaveProperty("type", "invocation-result"); expect(outputState.messages[5]).toHaveProperty("type", "agent"); }); @@ -388,26 +366,18 @@ describe("Agent", () => { expect(outputState.messages[0]).toHaveProperty("type", "human"); expect(outputState.messages[1]).toHaveProperty("type", "agent"); - expect(outputState.messages[2]).toHaveProperty( - "type", - "invocation-result", - ); + expect(outputState.messages[2]).toHaveProperty("type", "invocation-result"); expect(outputState.messages[3]).toHaveProperty("type", "agent"); - expect(outputState.messages[4]).toHaveProperty( - "type", - "invocation-result", - ); + expect(outputState.messages[4]).toHaveProperty("type", "invocation-result"); expect(outputState.messages[5]).toHaveProperty("type", "agent"); }); - it("should respect mock responses", async () => { const tools = [ new AgentTool({ name: "searchHaystack", description: "Search haystack", - schema: z.object({ - }).passthrough(), + schema: z.object({}).passthrough(), func: async (input: any) => { return toolCallback(input.input); }, @@ -421,15 +391,14 @@ describe("Agent", () => { type: "object", properties: { word: { - type: "string" - } - } - } + type: "string", + }, + }, + }, }, allAvailableTools: ["searchHaystack"], findRelevantTools: async () => tools, - getTool: async (input) => - tools.find((tool) => tool.name === input.toolName)!, + getTool: async input => tools.find(tool => tool.name === input.toolName)!, postStepSave: async () => {}, mockModelResponses: [ JSON.stringify({ @@ -437,27 +406,28 @@ describe("Agent", () => { invocations: [ { toolName: "searchHaystack", - input: {} - } - ] + input: {}, + }, + ], }), JSON.stringify({ done: true, result: { - word: "needle" - } - }) - ] + word: "needle", + }, + }), + ], }); - - toolCallback.mockResolvedValue(JSON.stringify({ - result: JSON.stringify({ - word: "needle" - }), - resultType: "resolution", - status: "success", - })); + toolCallback.mockResolvedValue( + JSON.stringify({ + result: JSON.stringify({ + word: "needle", + }), + resultType: "resolution", + status: "success", + }) + ); const outputState = await app.invoke({ messages: [ @@ -475,14 +445,10 @@ describe("Agent", () => { expect(outputState.messages[1]).toHaveProperty("type", "agent"); expect(outputState.messages[2]).toHaveProperty("type", "invocation-result"); expect(outputState.messages[3]).toHaveProperty("type", "agent"); - expect(outputState.messages[3].data).toHaveProperty( - "result", - { - word: "needle" - } - ) + expect(outputState.messages[3].data).toHaveProperty("result", { + word: "needle", + }); expect(toolCallback).toHaveBeenCalledTimes(1); }); }); - }); diff --git a/control-plane/src/modules/workflows/agent/nodes/model-call.test.ts b/control-plane/src/modules/workflows/agent/nodes/model-call.test.ts index 3e1b1919..64f356ef 100644 --- a/control-plane/src/modules/workflows/agent/nodes/model-call.test.ts +++ b/control-plane/src/modules/workflows/agent/nodes/model-call.test.ts @@ -1,14 +1,11 @@ -import { ReleventToolLookup } from "../agent"; -import { handleModelCall } from "./model-call"; -import { z } from "zod"; -import { WorkflowAgentState } from "../state"; import { ulid } from "ulid"; -import { - assertAgentMessage, - assertGenericMessage, -} from "../../workflow-messages"; +import { z } from "zod"; import { Model } from "../../../models"; +import { assertMessageOfType } from "../../workflow-messages"; +import { ReleventToolLookup } from "../agent"; +import { WorkflowAgentState } from "../state"; import { AgentTool } from "../tool"; +import { handleModelCall } from "./model-call"; describe("handleModelCall", () => { const workflow = { @@ -71,7 +68,7 @@ describe("handleModelCall", () => { }, structured: { done: true, - message: "nothing to do" + message: "nothing to do", }, }); @@ -83,8 +80,8 @@ describe("handleModelCall", () => { expect(result.messages![0]).toHaveProperty("type", "agent"); - assertAgentMessage(result.messages![0]); - expect(result.messages![0].data.invocations).not.toBeDefined(); + const agentMessage = assertMessageOfType("agent", result.messages![0]); + expect(agentMessage.data.invocations).not.toBeDefined(); }); it("should ignore done if invocations are provided", async () => { @@ -115,12 +112,12 @@ describe("handleModelCall", () => { expect(result.messages![0]).toHaveProperty("type", "agent"); - assertAgentMessage(result.messages![0]); + const agentMessage = assertMessageOfType("agent", result.messages![0]); // Result should have been striped out - expect(result.messages![0].data.result).not.toBeDefined(); - expect(result.messages![0].data.invocations).toHaveLength(1); - expect(result.messages![0].data.invocations).toContainEqual({ + expect(agentMessage.data.result).not.toBeDefined(); + expect(agentMessage.data.invocations).toHaveLength(1); + expect(agentMessage.data.invocations).toContainEqual({ id: expect.any(String), toolName: "notify", input: { message: "A message" }, @@ -146,20 +143,20 @@ describe("handleModelCall", () => { expect(result.status).toBe("running"); expect(result.messages![0]).toHaveProperty("type", "agent-invalid"); - assertGenericMessage(result.messages![0]); - expect(result.messages![0].data).toHaveProperty( + const invalidMessage = assertMessageOfType("agent-invalid", result.messages![0]); + expect(invalidMessage.data).toHaveProperty( "details", expect.objectContaining({ message: "nothing to do", - }), + }) ); expect(result.messages![1]).toHaveProperty("type", "supervisor"); - assertGenericMessage(result.messages![1]); + const supervisorMessage = assertMessageOfType("supervisor", result.messages![1]); - expect(result.messages![1].data).toHaveProperty( + expect(supervisorMessage.data).toHaveProperty( "message", - "If you are not done, please provide an invocation, otherwise return done.", + "If you are not done, please provide an invocation, otherwise return done." ); }); @@ -181,20 +178,20 @@ describe("handleModelCall", () => { expect(result.status).toBe("running"); expect(result.messages![0]).toHaveProperty("type", "agent-invalid"); - assertGenericMessage(result.messages![0]); - expect(result.messages![0].data).toHaveProperty( + const invalidMessage = assertMessageOfType("agent-invalid", result.messages![0]); + expect(invalidMessage.data).toHaveProperty( "details", expect.objectContaining({ done: true, - }), + }) ); expect(result.messages![1]).toHaveProperty("type", "supervisor"); - assertGenericMessage(result.messages![1]); - expect(result.messages![1].data).toHaveProperty( + const supervisorMessage = assertMessageOfType("supervisor", result.messages![1]); + expect(supervisorMessage.data).toHaveProperty( "message", - "Please provide a final result or a reason for stopping.", + "Please provide a final result or a reason for stopping." ); }); @@ -205,9 +202,7 @@ describe("handleModelCall", () => { throw error; }; - expect( - handleModelCall(state, model, errorFindRelevantTools), - ).rejects.toThrow(error); + expect(handleModelCall(state, model, errorFindRelevantTools)).rejects.toThrow(error); }); it("should abort if a cycle is detected", async () => { @@ -220,9 +215,7 @@ describe("handleModelCall", () => { runId: workflow.id, type: "agent" as const, data: { - invocations: [ - { done: false, learning: "I learnt some stuff" } as any, - ], + invocations: [{ done: false, learning: "I learnt some stuff" } as any], }, }); messages.push({ @@ -236,9 +229,9 @@ describe("handleModelCall", () => { }); } - expect( - handleModelCall({ ...state, messages }, model, findRelevantTools), - ).rejects.toThrow("Detected cycle in workflow."); + expect(handleModelCall({ ...state, messages }, model, findRelevantTools)).rejects.toThrow( + "Detected cycle in workflow." + ); }); it("should trigger supervisor if parsing fails", async () => { @@ -258,17 +251,17 @@ describe("handleModelCall", () => { expect(result.status).toBe("running"); expect(result.messages![0]).toHaveProperty("type", "agent-invalid"); - assertGenericMessage(result.messages![0]); - expect(result.messages![0].data).toHaveProperty("details"); + const invalidMessage = assertMessageOfType("agent-invalid", result.messages![0]); + expect(invalidMessage.data).toHaveProperty("details"); expect(result.messages![1]).toHaveProperty("type", "supervisor"); - assertGenericMessage(result.messages![1]); + const supervisorMessage = assertMessageOfType("supervisor", result.messages![1]); - expect(result.messages![1].data).toHaveProperty( + expect(supervisorMessage.data).toHaveProperty( "message", - expect.stringContaining("Provided object was invalid, check your input"), + expect.stringContaining("Provided object was invalid, check your input") ); - expect(result.messages![1].data.details).toHaveProperty("errors"); + expect(supervisorMessage.data.details).toHaveProperty("errors"); }); // Edge case where the model trys to call a tool (unbound) rather than returning it through `invocations` array. @@ -305,9 +298,9 @@ describe("handleModelCall", () => { expect(result.status).toBe("running"); expect(result.messages![0]).toHaveProperty("type", "agent"); - assertAgentMessage(result.messages![0]); + const agentMessage = assertMessageOfType("agent", result.messages![0]); - expect(result.messages![0].data).toHaveProperty("invocations", [ + expect(agentMessage.data).toHaveProperty("invocations", [ { id: expect.any(String), toolName: "notify", @@ -360,9 +353,9 @@ describe("handleModelCall", () => { expect(result.status).toBe("running"); expect(result.messages![0]).toHaveProperty("type", "agent"); - assertAgentMessage(result.messages![0]); + const agentMessage = assertMessageOfType("agent", result.messages![0]); - expect(result.messages![0].data).toHaveProperty("invocations", [ + expect(agentMessage.data).toHaveProperty("invocations", [ { id: expect.any(String), toolName: "notify", diff --git a/control-plane/src/modules/workflows/agent/nodes/tool-call.test.ts b/control-plane/src/modules/workflows/agent/nodes/tool-call.test.ts index b1f38a5f..9976cd86 100644 --- a/control-plane/src/modules/workflows/agent/nodes/tool-call.test.ts +++ b/control-plane/src/modules/workflows/agent/nodes/tool-call.test.ts @@ -4,9 +4,9 @@ import { SpecialResultTypes } from "../tools/functions"; import { NotFoundError } from "../../../../utilities/errors"; import { ulid } from "ulid"; import { WorkflowAgentState } from "../state"; -import { assertResultMessage } from "../../workflow-messages"; import { redisClient } from "../../../redis"; import { AgentTool } from "../tool"; +import { assertMessageOfType } from "../../workflow-messages"; describe("handleToolCalls", () => { const workflow = { @@ -91,7 +91,7 @@ describe("handleToolCalls", () => { result: JSON.stringify([waitingJobId]), resultType: SpecialResultTypes.jobTimeout, status: "success", - }), + }) ); const stateUpdate = await handleToolCalls(baseState, async () => tool); @@ -117,16 +117,16 @@ describe("handleToolCalls", () => { expect(stateUpdate).toHaveProperty("messages"); expect(stateUpdate.messages).toHaveLength(1); - expect(stateUpdate.messages![0]).toHaveProperty( - "type", - "invocation-result", - ); + expect(stateUpdate.messages![0]).toHaveProperty("type", "invocation-result"); - assertResultMessage(stateUpdate.messages![0]!); + assertMessageOfType("invocation-result", stateUpdate.messages![0]!); - expect(stateUpdate.messages![0]?.data.result).toHaveProperty( - "message", - expect.stringContaining(`Failed to find tool: console_echo`), + expect(stateUpdate.messages![0]!.data as any).toEqual( + expect.objectContaining({ + result: expect.objectContaining({ + message: expect.stringContaining(`Failed to find tool: console_echo`), + }), + }) ); }); @@ -158,7 +158,7 @@ describe("handleToolCalls", () => { ...baseState, messages, }, - async () => tool, + async () => tool ); expect(toolHandler).toHaveBeenCalledTimes(0); @@ -170,24 +170,18 @@ describe("handleToolCalls", () => { expect(stateUpdate.messages).toHaveLength(1); - expect(stateUpdate.messages![0]).toHaveProperty( - "type", - "invocation-result", - ); - assertResultMessage(stateUpdate.messages![0]); + expect(stateUpdate.messages![0]).toHaveProperty("type", "invocation-result"); + assertMessageOfType("invocation-result", stateUpdate.messages![0]!); - expect(stateUpdate.messages![0].data.result).toEqual( + expect((stateUpdate.messages![0].data as any).result).toEqual( expect.objectContaining({ - message: expect.stringContaining( - `Provided input did not match schema for ${tool.name}`, - ), + message: expect.stringContaining(`Provided input did not match schema for ${tool.name}`), parseResult: expect.arrayContaining([ expect.objectContaining({ - message: "is not allowed to have the additional property \"wrongKey\"" - } - ) + message: 'is not allowed to have the additional property "wrongKey"', + }), ]), - }), + }) ); }); @@ -240,7 +234,7 @@ describe("handleToolCalls", () => { ...baseState, messages, }, - async () => tool, + async () => tool ); expect(stateUpdate).toHaveProperty("status", "running"); @@ -256,18 +250,15 @@ describe("handleToolCalls", () => { expect(arg1).toEqual( expect.objectContaining({ input: "hello", - }), + }) ); // We should only receive one new message, a tool call result for tool call `123` expect(stateUpdate.messages).toHaveLength(1); - expect(stateUpdate.messages![0]).toHaveProperty( - "type", - "invocation-result", - ); + expect(stateUpdate.messages![0]).toHaveProperty("type", "invocation-result"); - assertResultMessage(stateUpdate.messages![0]); + assertMessageOfType("invocation-result", stateUpdate.messages![0]!); expect(stateUpdate.messages![0].data).toHaveProperty("id", "123"); }); @@ -308,7 +299,7 @@ describe("handleToolCalls", () => { result: JSON.stringify([waitingJobId]), resultType: SpecialResultTypes.jobTimeout, status: "success", - }), + }) ); toolHandler.mockResolvedValueOnce( @@ -316,7 +307,7 @@ describe("handleToolCalls", () => { result: JSON.stringify({}), resultType: "resolution", status: "success", - }), + }) ); const stateUpdate = await handleToolCalls( @@ -324,7 +315,7 @@ describe("handleToolCalls", () => { ...baseState, messages, }, - async () => tool, + async () => tool ); expect(toolHandler).toHaveBeenCalledTimes(2); diff --git a/control-plane/src/modules/workflows/agent/nodes/tool-call.ts b/control-plane/src/modules/workflows/agent/nodes/tool-call.ts index eee7ca53..b757142c 100644 --- a/control-plane/src/modules/workflows/agent/nodes/tool-call.ts +++ b/control-plane/src/modules/workflows/agent/nodes/tool-call.ts @@ -1,13 +1,10 @@ import { ulid } from "ulid"; -import { - AgentError, - InvalidJobArgumentsError, -} from "../../../../utilities/errors"; +import { AgentError, InvalidJobArgumentsError } from "../../../../utilities/errors"; import * as events from "../../../observability/events"; import { logger } from "../../../observability/logger"; import { addAttributes, withSpan } from "../../../observability/tracer"; import { trackCustomerTelemetry } from "../../../track-customer-telemetry"; -import { AgentMessage, assertAgentMessage } from "../../workflow-messages"; +import { AgentMessage, assertMessageOfType } from "../../workflow-messages"; import { Run } from "../../workflows"; import { ToolFetcher } from "../agent"; import { WorkflowAgentState } from "../state"; @@ -16,14 +13,12 @@ import { AgentTool, AgentToolInputError } from "../tool"; export const TOOL_CALL_NODE_NAME = "action"; -export const handleToolCalls = ( - state: WorkflowAgentState, - getTool: ToolFetcher, -) => withSpan("workflow.toolCalls", () => _handleToolCalls(state, getTool)); +export const handleToolCalls = (state: WorkflowAgentState, getTool: ToolFetcher) => + withSpan("workflow.toolCalls", () => _handleToolCalls(state, getTool)); const _handleToolCalls = async ( state: WorkflowAgentState, - getTool: ToolFetcher, + getTool: ToolFetcher ): Promise> => { // When we recieve parallel tool calls, we will receive a number of ToolMessage's // after the last AIMessage (The actual function call). @@ -33,12 +28,9 @@ const _handleToolCalls = async ( const resolvedToolsCalls = new Set(); while (lastMessage.type === "invocation-result") { - logger.info( - "Found invocation-result message, finding last non-invocation message", - { - toolCallId: lastMessage.data.id, - }, - ); + logger.info("Found invocation-result message, finding last non-invocation message", { + toolCallId: lastMessage.data.id, + }); // Keep track of the tool calls which have already resolved resolvedToolsCalls.add(lastMessage.data.id); @@ -54,21 +46,18 @@ const _handleToolCalls = async ( lastMessage = message; } - assertAgentMessage(lastMessage); + const agentMessage = assertMessageOfType("agent", lastMessage); - if ( - !lastMessage.data.invocations || - lastMessage.data.invocations.length === 0 - ) { + if (!agentMessage.data.invocations || agentMessage.data.invocations.length === 0) { logger.error("Expected a tool call", { lastMessage }); throw new AgentError("Expected a tool call"); } const toolResults = await Promise.all( - lastMessage.data.invocations + agentMessage.data.invocations // Filter out any tool_calls which have already resolvedd - .filter((toolCall) => !resolvedToolsCalls.has(toolCall.id ?? "")) - .map((toolCall) => handleToolCall(toolCall, state.workflow, getTool)), + .filter(toolCall => !resolvedToolsCalls.has(toolCall.id ?? "")) + .map(toolCall => handleToolCall(toolCall, state.workflow, getTool)) ); return toolResults.reduce( @@ -77,10 +66,10 @@ const _handleToolCalls = async ( if (result.waitingJobs) acc.waitingJobs!.push(...result.waitingJobs); if (result.result) { if (!!acc.result && !!result.result && result.result !== acc.result) { - logger.error( - "Multiple tools returned different results. Last one will be used.", - { result, accResult: acc.result }, - ); + logger.error("Multiple tools returned different results. Last one will be used.", { + result, + accResult: acc.result, + }); } acc.result = result.result; @@ -93,30 +82,26 @@ const _handleToolCalls = async ( waitingJobs: [], status: "running", result: undefined, - }, + } ); }; const handleToolCall = ( toolCall: Required["invocations"][number], workflow: Run, - getTool: ToolFetcher, + getTool: ToolFetcher ) => - withSpan( - "workflow.toolCall", - () => _handleToolCall(toolCall, workflow, getTool), - { - attributes: { - "tool.name": toolCall.toolName, - "tool.call.id": toolCall.id, - }, + withSpan("workflow.toolCall", () => _handleToolCall(toolCall, workflow, getTool), { + attributes: { + "tool.name": toolCall.toolName, + "tool.call.id": toolCall.id, }, - ); + }); const _handleToolCall = async ( toolCall: Required["invocations"][number], workflow: Run, - getTool: ToolFetcher, + getTool: ToolFetcher ): Promise> => { logger.info("Executing tool call"); diff --git a/control-plane/src/modules/workflows/router.ts b/control-plane/src/modules/workflows/router.ts index eec2dfd6..2892fd71 100644 --- a/control-plane/src/modules/workflows/router.ts +++ b/control-plane/src/modules/workflows/router.ts @@ -19,7 +19,6 @@ import { timeline } from "../timeline"; import { getRunsByMetadata } from "./metadata"; import { addMessageAndResume, - assertRunReady, createRetry, createRun, deleteRun, @@ -60,7 +59,10 @@ export const runsRouter = initServer().router( return { status: 200, - body: workflow, + body: { + ...workflow, + result: workflow.result ?? null, + }, }; }, createRun: async request => { diff --git a/control-plane/src/modules/workflows/workflow-messages.ts b/control-plane/src/modules/workflows/workflow-messages.ts index f15a4c23..51217454 100644 --- a/control-plane/src/modules/workflows/workflow-messages.ts +++ b/control-plane/src/modules/workflows/workflow-messages.ts @@ -1,55 +1,48 @@ import Anthropic from "@anthropic-ai/sdk"; import { and, desc, eq, gt, InferSelectModel, ne, sql } from "drizzle-orm"; +import { ulid } from "ulid"; import { z } from "zod"; -import { - agentDataSchema, - genericMessageDataSchema, - messageDataSchema, - resultDataSchema, -} from "../contract"; +import { unifiedMessageDataSchema } from "../contract"; import { db, RunMessageMetadata, workflowMessages } from "../data"; import { events } from "../observability/events"; import { logger } from "../observability/logger"; import { resumeRun } from "./workflows"; -import { ulid } from "ulid"; -export type MessageData = z.infer; +export type TypedMessage = z.infer; -export type TypedMessage = AgentMessage | InvocationResultMessage | GenericMessage; +export type TypedMessageWithMeta = { + id: string; + data: TypedMessage; + createdAt: Date; + pending: boolean; + displayableContext: Record | null; +}; /** * A structured response from the agent. */ -export type AgentMessage = { - data: z.infer; - type: "agent"; -}; +export type AgentMessage = Extract; /** * The result of a tool call. */ -export type InvocationResultMessage = { - data: z.infer; - type: "invocation-result"; -}; +export type InvocationResultMessage = Extract; /** * A generic message container. */ -export type GenericMessage = { - data: z.infer; - type: "human" | "template" | "supervisor" | "agent-invalid"; -}; +export type GenericMessage = Extract< + TypedMessage, + { type: "human" | "template" | "supervisor" | "agent-invalid" } +>; -export type RunMessage = { +export type RunMessage = TypedMessage & { id: string; - data: InferSelectModel["data"]; - type: InferSelectModel["type"]; clusterId: string; runId: string; createdAt: Date; updatedAt?: Date | null; -} & TypedMessage; +}; export const insertRunMessage = async ({ clusterId, @@ -138,7 +131,7 @@ export const getRunMessagesForDisplay = async ({ runId: string; last?: number; after?: string; -}) => { +}): Promise => { const messages = await db .select({ id: workflowMessages.id, @@ -181,10 +174,28 @@ export const getRunMessagesForDisplay = async ({ return message; }) .map(message => { - validateMessage(message); + const { success, data, error } = unifiedMessageDataSchema.safeParse({ + type: message.type, + data: message.data, + }); + + const messageWithType = !success + ? { + type: "supervisor" as const, + data: { + message: "Invalid message data", + details: { + error: error?.message, + }, + }, + } + : data; return { - ...message, + id: message.id, + data: messageWithType, + createdAt: message.createdAt, + pending: false, displayableContext: message.metadata?.displayable ?? null, }; }); @@ -332,79 +343,40 @@ export const toAnthropicMessage = (message: TypedMessage): Anthropic.MessagePara } }; -const validateMessage = (message: Pick): TypedMessage => { - switch (message.type) { - case "agent": { - assertAgentMessage(message); - break; - } - case "invocation-result": { - assertResultMessage(message); - break; - } - default: { - assertGenericMessage(message); - } - } - return message; -}; - -export function hasInvocations(message: AgentMessage): boolean { - return (message.data.invocations && message.data.invocations.length > 0) ?? false; -} - -export function assertAgentMessage( - message: Pick -): asserts message is AgentMessage { - if (message.type !== "agent") { - throw new Error("Expected an AgentMessage"); - } - - const result = agentDataSchema.safeParse(message.data); +const validateMessage = (message: unknown) => { + const result = unifiedMessageDataSchema.safeParse(message); if (!result.success) { - logger.error("Invalid AgentMessage data", { - data: message.data, + logger.error(`Invalid message type detected`, { + data: message, result, + error: result.error, }); - throw new Error("Invalid AgentMessage data"); - } -} -export function assertResultMessage( - message: Pick -): asserts message is InvocationResultMessage { - if (message.type !== "invocation-result") { - throw new Error("Expected a InvocationResultMessage"); + throw new Error("Invalid message data"); } - const result = resultDataSchema.safeParse(message.data); + return result.data; +}; - if (!result.success) { - logger.error("Invalid InvocationResultMessage data", { - data: message.data, - result, - }); - throw new Error("Invalid InvocationResultMessage data"); - } +export function hasInvocations(message: AgentMessage): boolean { + return (message.data.invocations && message.data.invocations.length > 0) ?? false; } -export function assertGenericMessage( - message: Pick -): asserts message is GenericMessage { - if (!["human", "template", "supervisor", "agent-invalid"].includes(message.type)) { - throw new Error("Expected a GenericMessage"); - } - - const result = genericMessageDataSchema.safeParse(message.data); +export function assertMessageOfType< + T extends "agent" | "invocation-result" | "human" | "template" | "supervisor" | "agent-invalid", +>(type: T, message: unknown) { + const result = unifiedMessageDataSchema.safeParse(message); if (!result.success) { - logger.error("Invalid GenericMessage data", { - data: message.data, - result, - }); - throw new Error("Invalid GenericMessage data"); + throw new Error("Invalid message data"); } + + if (result.data.type !== type) { + throw new Error(`Expected a ${type} message. Got ${result.data.type}`); + } + + return result.data as Extract; } export const lastAgentMessage = async ({ @@ -435,11 +407,9 @@ export const lastAgentMessage = async ({ return; } - assertAgentMessage(result); - return result; + return assertMessageOfType("agent", result); }; - export const getMessageByReference = async (reference: string, clusterId: string) => { const [result] = await db .select({ @@ -459,28 +429,22 @@ export const getMessageByReference = async (reference: string, clusterId: string ) ); - return result -} + return result; +}; export const updateMessageReference = async ({ externalReference, clusterId, - messageId -}: - { - externalReference: string, - clusterId: string, - messageId: string - }) => { + messageId, +}: { + externalReference: string; + clusterId: string; + messageId: string; +}) => { await db .update(workflowMessages) .set({ metadata: sql`COALESCE(${workflowMessages.metadata}, '{}')::jsonb || ${JSON.stringify({ externalReference })}::jsonb`, }) - .where( - and( - eq(workflowMessages.cluster_id, clusterId), - eq(workflowMessages.id, messageId) - ) - ); -} + .where(and(eq(workflowMessages.cluster_id, clusterId), eq(workflowMessages.id, messageId))); +}; diff --git a/control-plane/src/modules/workflows/workflows.ts b/control-plane/src/modules/workflows/workflows.ts index aa9af408..5bfe8dd4 100644 --- a/control-plane/src/modules/workflows/workflows.ts +++ b/control-plane/src/modules/workflows/workflows.ts @@ -354,7 +354,7 @@ export const getWorkflowDetail = async ({ // Current a workflow can have multiple "results". // For now, we just use the last result. // In the future, we will actually persist the workflow result. - result: agentMessage?.data?.result ?? null, + result: agentMessage?.type === "agent" ? agentMessage.data.result : null, }; }; From e4a6e579a73816b04ce20564f746641cd5ab98b6 Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Fri, 3 Jan 2025 08:33:09 +1100 Subject: [PATCH 2/3] chore: Remove redundant logging in cache utility --- control-plane/src/utilities/cache.ts | 2 -- 1 file changed, 2 deletions(-) diff --git a/control-plane/src/utilities/cache.ts b/control-plane/src/utilities/cache.ts index 0c305a5a..b6e1fea8 100644 --- a/control-plane/src/utilities/cache.ts +++ b/control-plane/src/utilities/cache.ts @@ -13,7 +13,6 @@ export const createCache = (namespace: symbol) => { const localResult = localCache.get(key); if (localResult !== undefined) { - logger.info("Local cache hit", { key }); return localResult; } @@ -26,7 +25,6 @@ export const createCache = (namespace: symbol) => { return undefined; }, set: async (key: string, value: T, stdTTLSeconds: number) => { - logger.info("Local cache set", { key, value, stdTTLSeconds }); localCache.set(key, value, stdTTLSeconds); if (stdTTLSeconds > 1000) { From fe28fe0db940e23f3174032737e46f97514e2a36 Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Fri, 3 Jan 2025 09:01:15 +1100 Subject: [PATCH 3/3] chore: Remove unused logger and validation function in workflow-messages.ts --- .../modules/workflows/workflow-messages.ts | 23 ++----------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/control-plane/src/modules/workflows/workflow-messages.ts b/control-plane/src/modules/workflows/workflow-messages.ts index 51217454..b67fa1f1 100644 --- a/control-plane/src/modules/workflows/workflow-messages.ts +++ b/control-plane/src/modules/workflows/workflow-messages.ts @@ -5,7 +5,6 @@ import { z } from "zod"; import { unifiedMessageDataSchema } from "../contract"; import { db, RunMessageMetadata, workflowMessages } from "../data"; import { events } from "../observability/events"; -import { logger } from "../observability/logger"; import { resumeRun } from "./workflows"; export type TypedMessage = z.infer; @@ -61,7 +60,6 @@ export const insertRunMessage = async ({ data: InferSelectModel["data"]; metadata?: RunMessageMetadata; }) => { - validateMessage({ data, type }); return db .insert(workflowMessages) .values({ @@ -69,9 +67,8 @@ export const insertRunMessage = async ({ user_id: userId ?? "SYSTEM", cluster_id: clusterId, workflow_id: runId, - type, - data, metadata, + ...unifiedMessageDataSchema.parse({ data, type }), }) .returning({ id: workflowMessages.id, @@ -254,7 +251,7 @@ export const getWorkflowMessages = async ({ .map(message => { return { ...message, - ...validateMessage(message), + ...unifiedMessageDataSchema.parse(message), }; }); }; @@ -343,22 +340,6 @@ export const toAnthropicMessage = (message: TypedMessage): Anthropic.MessagePara } }; -const validateMessage = (message: unknown) => { - const result = unifiedMessageDataSchema.safeParse(message); - - if (!result.success) { - logger.error(`Invalid message type detected`, { - data: message, - result, - error: result.error, - }); - - throw new Error("Invalid message data"); - } - - return result.data; -}; - export function hasInvocations(message: AgentMessage): boolean { return (message.data.invocations && message.data.invocations.length > 0) ?? false; }