diff --git a/control-plane/src/modules/jobs/jobs.ts b/control-plane/src/modules/jobs/jobs.ts index 1ae528b2..97395d64 100644 --- a/control-plane/src/modules/jobs/jobs.ts +++ b/control-plane/src/modules/jobs/jobs.ts @@ -1,4 +1,4 @@ -import { and, eq, gt, isNull, lte, sql } from "drizzle-orm"; +import { and, desc, eq, gt, isNull, lte, sql } from "drizzle-orm"; import { env } from "../../utilities/env"; import { JobPollTimeoutError, NotFoundError } from "../../utilities/errors"; import { getBlobsForJobs } from "../blobs"; @@ -137,6 +137,38 @@ export const getJob = async ({ }; }; +export const getLatestJobsResultedByFunctionName = async ({ + clusterId, + service, + functionName, + limit, + resultType, +}: { + clusterId: string; + service: string; + functionName: string; + limit: number; + resultType: ResultType; +}) => { + return data.db + .select({ + result: data.jobs.result, + resultType: data.jobs.result_type, + targetArgs: data.jobs.target_args, + }) + .from(data.jobs) + .where( + and( + eq(data.jobs.cluster_id, clusterId), + eq(data.jobs.target_fn, functionName), + eq(data.jobs.service, service), + eq(data.jobs.result_type, resultType), + ), + ) + .orderBy(desc(data.jobs.created_at)) + .limit(limit); +}; + export const getJobsForWorkflow = async ({ clusterId, runId, diff --git a/control-plane/src/modules/workflows/agent/nodes/model-call.ts b/control-plane/src/modules/workflows/agent/nodes/model-call.ts index a98e73a8..5d9adc31 100644 --- a/control-plane/src/modules/workflows/agent/nodes/model-call.ts +++ b/control-plane/src/modules/workflows/agent/nodes/model-call.ts @@ -2,10 +2,7 @@ import { ReleventToolLookup } from "../agent"; import { toAnthropicMessages } from "../../workflow-messages"; import { logger } from "../../../observability/logger"; import { WorkflowAgentState, WorkflowAgentStateMessage } from "../state"; -import { - addAttributes, - withSpan, -} from "../../../observability/tracer"; +import { addAttributes, withSpan } from "../../../observability/tracer"; import { AgentError } from "../../../../utilities/errors"; import { ulid } from "ulid"; @@ -74,8 +71,7 @@ const _handleModelCall = async ( "If there is nothing left to do, return 'done' and provide the final result.", "If you encounter invocation errors (e.g., incorrect tool name, missing input), retry based on the error message.", "When possible, return multiple invocations to trigger them in parallel.", - "Once all tasks have been completed, return the final result as a structured object.", - "Provide concise and clear responses. Use **bold** to highlight important words.", + "Provide concise and clear responses.", state.additionalContext, "", schemaString, @@ -263,7 +259,6 @@ const _handleModelCall = async ( }; } - return { messages: [ { diff --git a/control-plane/src/modules/workflows/agent/nodes/model-output.ts b/control-plane/src/modules/workflows/agent/nodes/model-output.ts index 5827dfec..8bff6b7c 100644 --- a/control-plane/src/modules/workflows/agent/nodes/model-output.ts +++ b/control-plane/src/modules/workflows/agent/nodes/model-output.ts @@ -1,4 +1,3 @@ - import { JsonSchema7ObjectType } from "zod-to-json-schema"; import { AgentTool } from "../tool"; import { workflows } from "../../../data"; @@ -8,8 +7,7 @@ import { WorkflowAgentState } from "../state"; type ModelInvocationOutput = { toolName: string; input: unknown; - -} +}; export type ModelOutput = { invocations?: ModelInvocationOutput[]; @@ -18,18 +16,17 @@ export type ModelOutput = { message?: string; done?: boolean; issue?: string; -} +}; export const buildModelSchema = ({ state, relevantSchemas, - resultSchema + resultSchema, }: { state: WorkflowAgentState; relevantSchemas: AgentTool[]; resultSchema?: InferSelectModel["result_schema"]; - }) => { - +}) => { // Build the toolName enum const toolNameEnum = [ ...relevantSchemas.map((tool) => tool.name), @@ -48,7 +45,7 @@ export const buildModelSchema = ({ issue: { type: "string", description: - "Describe any issues you have encountered in this step. Specifically related to the tools you are using.", + "Describe any issues you have encountered in this step. Specifically related to the tools you are using. If none, keep this field empty.", }, }, }; diff --git a/control-plane/src/modules/workflows/agent/run.ts b/control-plane/src/modules/workflows/agent/run.ts index fd69694f..93a17f50 100644 --- a/control-plane/src/modules/workflows/agent/run.ts +++ b/control-plane/src/modules/workflows/agent/run.ts @@ -32,6 +32,7 @@ import { CURRENT_DATE_TIME_TOOL_NAME } from "./tools/date-time"; import { env } from "../../../utilities/env"; import { events } from "../../observability/events"; import { AgentTool } from "./tool"; +import { getLatestJobsResultedByFunctionName } from "../../jobs/jobs"; /** * Run a workflow from the most recent saved state @@ -284,14 +285,52 @@ async function findRelatedFunctionTools(workflow: Run, search: string) { const toolContexts = await Promise.all( relatedTools.map(async (toolDetails) => { - const metadata = await getToolMetadata( - workflow.clusterId, - toolDetails.serviceName, - toolDetails.functionName, - ); + const [metadata, resolvedJobs, rejectedJobs] = await Promise.all([ + getToolMetadata( + workflow.clusterId, + toolDetails.serviceName, + toolDetails.functionName, + ), + getLatestJobsResultedByFunctionName({ + clusterId: workflow.clusterId, + service: toolDetails.serviceName, + functionName: toolDetails.functionName, + limit: 3, + resultType: "resolution", + }), + getLatestJobsResultedByFunctionName({ + clusterId: workflow.clusterId, + service: toolDetails.serviceName, + functionName: toolDetails.functionName, + limit: 3, + resultType: "rejection", + }), + ]); const contextArr = []; + if (resolvedJobs.length > 0) { + contextArr.push( + `${resolvedJobs + .map( + (j) => + `${j.targetArgs}${j.result}`, + ) + .join("\n")}`, + ); + } + + if (rejectedJobs.length > 0) { + contextArr.push( + `${rejectedJobs + .map( + (j) => + `${j.targetArgs}${j.result}`, + ) + .join("\n")}`, + ); + } + if (metadata?.additionalContext) { contextArr.push(`${metadata.additionalContext}`); }