diff --git a/control-plane/src/modules/jobs/jobs.ts b/control-plane/src/modules/jobs/jobs.ts index 1ae528b..97395d6 100644 --- a/control-plane/src/modules/jobs/jobs.ts +++ b/control-plane/src/modules/jobs/jobs.ts @@ -1,4 +1,4 @@ -import { and, eq, gt, isNull, lte, sql } from "drizzle-orm"; +import { and, desc, eq, gt, isNull, lte, sql } from "drizzle-orm"; import { env } from "../../utilities/env"; import { JobPollTimeoutError, NotFoundError } from "../../utilities/errors"; import { getBlobsForJobs } from "../blobs"; @@ -137,6 +137,38 @@ export const getJob = async ({ }; }; +export const getLatestJobsResultedByFunctionName = async ({ + clusterId, + service, + functionName, + limit, + resultType, +}: { + clusterId: string; + service: string; + functionName: string; + limit: number; + resultType: ResultType; +}) => { + return data.db + .select({ + result: data.jobs.result, + resultType: data.jobs.result_type, + targetArgs: data.jobs.target_args, + }) + .from(data.jobs) + .where( + and( + eq(data.jobs.cluster_id, clusterId), + eq(data.jobs.target_fn, functionName), + eq(data.jobs.service, service), + eq(data.jobs.result_type, resultType), + ), + ) + .orderBy(desc(data.jobs.created_at)) + .limit(limit); +}; + export const getJobsForWorkflow = async ({ clusterId, runId, diff --git a/control-plane/src/modules/workflows/agent/nodes/model-call.ts b/control-plane/src/modules/workflows/agent/nodes/model-call.ts index a98e73a..5d9adc3 100644 --- a/control-plane/src/modules/workflows/agent/nodes/model-call.ts +++ b/control-plane/src/modules/workflows/agent/nodes/model-call.ts @@ -2,10 +2,7 @@ import { ReleventToolLookup } from "../agent"; import { toAnthropicMessages } from "../../workflow-messages"; import { logger } from "../../../observability/logger"; import { WorkflowAgentState, WorkflowAgentStateMessage } from "../state"; -import { - addAttributes, - withSpan, -} from "../../../observability/tracer"; +import { addAttributes, withSpan } from "../../../observability/tracer"; import { AgentError } from "../../../../utilities/errors"; import { ulid } from "ulid"; @@ -74,8 +71,7 @@ const _handleModelCall = async ( "If there is nothing left to do, return 'done' and provide the final result.", "If you encounter invocation errors (e.g., incorrect tool name, missing input), retry based on the error message.", "When possible, return multiple invocations to trigger them in parallel.", - "Once all tasks have been completed, return the final result as a structured object.", - "Provide concise and clear responses. Use **bold** to highlight important words.", + "Provide concise and clear responses.", state.additionalContext, "", schemaString, @@ -263,7 +259,6 @@ const _handleModelCall = async ( }; } - return { messages: [ { diff --git a/control-plane/src/modules/workflows/agent/nodes/model-output.ts b/control-plane/src/modules/workflows/agent/nodes/model-output.ts index 5827dfe..8bff6b7 100644 --- a/control-plane/src/modules/workflows/agent/nodes/model-output.ts +++ b/control-plane/src/modules/workflows/agent/nodes/model-output.ts @@ -1,4 +1,3 @@ - import { JsonSchema7ObjectType } from "zod-to-json-schema"; import { AgentTool } from "../tool"; import { workflows } from "../../../data"; @@ -8,8 +7,7 @@ import { WorkflowAgentState } from "../state"; type ModelInvocationOutput = { toolName: string; input: unknown; - -} +}; export type ModelOutput = { invocations?: ModelInvocationOutput[]; @@ -18,18 +16,17 @@ export type ModelOutput = { message?: string; done?: boolean; issue?: string; -} +}; export const buildModelSchema = ({ state, relevantSchemas, - resultSchema + resultSchema, }: { state: WorkflowAgentState; relevantSchemas: AgentTool[]; resultSchema?: InferSelectModel["result_schema"]; - }) => { - +}) => { // Build the toolName enum const toolNameEnum = [ ...relevantSchemas.map((tool) => tool.name), @@ -48,7 +45,7 @@ export const buildModelSchema = ({ issue: { type: "string", description: - "Describe any issues you have encountered in this step. Specifically related to the tools you are using.", + "Describe any issues you have encountered in this step. Specifically related to the tools you are using. If none, keep this field empty.", }, }, }; diff --git a/control-plane/src/modules/workflows/agent/run.ts b/control-plane/src/modules/workflows/agent/run.ts index fd69694..07c2556 100644 --- a/control-plane/src/modules/workflows/agent/run.ts +++ b/control-plane/src/modules/workflows/agent/run.ts @@ -32,6 +32,8 @@ import { CURRENT_DATE_TIME_TOOL_NAME } from "./tools/date-time"; import { env } from "../../../utilities/env"; import { events } from "../../observability/events"; import { AgentTool } from "./tool"; +import { getLatestJobsResultedByFunctionName } from "../../jobs/jobs"; +import { truncate } from "lodash"; /** * Run a workflow from the most recent saved state @@ -259,6 +261,28 @@ export const processRun = async ( } }; +const formatJobsContext = ( + jobs: { targetArgs: string; result: string | null }[], + status: "success" | "failed", +) => { + if (jobs.length === 0) return ""; + + const arbitraryLength = 500; + + const jobEntries = jobs + .map( + (job) => ` + ${truncate(job.targetArgs, { length: arbitraryLength })} + ${truncate(job.result ?? "", { length: arbitraryLength })} + `, + ) + .join("\n"); + + return ` + ${jobEntries} + `; +}; + async function findRelatedFunctionTools(workflow: Run, search: string) { const flags = await flagsmith?.getIdentityFlags(workflow.clusterId, { clusterId: workflow.clusterId, @@ -284,14 +308,40 @@ async function findRelatedFunctionTools(workflow: Run, search: string) { const toolContexts = await Promise.all( relatedTools.map(async (toolDetails) => { - const metadata = await getToolMetadata( - workflow.clusterId, - toolDetails.serviceName, - toolDetails.functionName, - ); + const [metadata, resolvedJobs, rejectedJobs] = await Promise.all([ + getToolMetadata( + workflow.clusterId, + toolDetails.serviceName, + toolDetails.functionName, + ), + getLatestJobsResultedByFunctionName({ + clusterId: workflow.clusterId, + service: toolDetails.serviceName, + functionName: toolDetails.functionName, + limit: 3, + resultType: "resolution", + }), + getLatestJobsResultedByFunctionName({ + clusterId: workflow.clusterId, + service: toolDetails.serviceName, + functionName: toolDetails.functionName, + limit: 3, + resultType: "rejection", + }), + ]); const contextArr = []; + const successJobsContext = formatJobsContext(resolvedJobs, "success"); + if (successJobsContext) { + contextArr.push(successJobsContext); + } + + const failedJobsContext = formatJobsContext(rejectedJobs, "failed"); + if (failedJobsContext) { + contextArr.push(failedJobsContext); + } + if (metadata?.additionalContext) { contextArr.push(`${metadata.additionalContext}`); }