From 1d9c74cd20984d24a6e220941e2abc3bfc9b6dc1 Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Wed, 18 Dec 2024 23:51:48 +1100 Subject: [PATCH 1/2] feat: Retrieve past function results for reinforcement learning - Introduced `getLatestJobsResultedByFunctionName` to fetch job results based on function name, service, and result type. - By default, job results will be ingested into the context. --- control-plane/src/modules/jobs/jobs.ts | 34 ++++++++++++- .../workflows/agent/nodes/model-call.ts | 9 +--- .../workflows/agent/nodes/model-output.ts | 13 ++--- .../src/modules/workflows/agent/run.ts | 49 +++++++++++++++++-- 4 files changed, 84 insertions(+), 21 deletions(-) diff --git a/control-plane/src/modules/jobs/jobs.ts b/control-plane/src/modules/jobs/jobs.ts index 1ae528b..97395d6 100644 --- a/control-plane/src/modules/jobs/jobs.ts +++ b/control-plane/src/modules/jobs/jobs.ts @@ -1,4 +1,4 @@ -import { and, eq, gt, isNull, lte, sql } from "drizzle-orm"; +import { and, desc, eq, gt, isNull, lte, sql } from "drizzle-orm"; import { env } from "../../utilities/env"; import { JobPollTimeoutError, NotFoundError } from "../../utilities/errors"; import { getBlobsForJobs } from "../blobs"; @@ -137,6 +137,38 @@ export const getJob = async ({ }; }; +export const getLatestJobsResultedByFunctionName = async ({ + clusterId, + service, + functionName, + limit, + resultType, +}: { + clusterId: string; + service: string; + functionName: string; + limit: number; + resultType: ResultType; +}) => { + return data.db + .select({ + result: data.jobs.result, + resultType: data.jobs.result_type, + targetArgs: data.jobs.target_args, + }) + .from(data.jobs) + .where( + and( + eq(data.jobs.cluster_id, clusterId), + eq(data.jobs.target_fn, functionName), + eq(data.jobs.service, service), + eq(data.jobs.result_type, resultType), + ), + ) + .orderBy(desc(data.jobs.created_at)) + .limit(limit); +}; + export const getJobsForWorkflow = async ({ clusterId, runId, diff --git a/control-plane/src/modules/workflows/agent/nodes/model-call.ts b/control-plane/src/modules/workflows/agent/nodes/model-call.ts index a98e73a..5d9adc3 100644 --- a/control-plane/src/modules/workflows/agent/nodes/model-call.ts +++ b/control-plane/src/modules/workflows/agent/nodes/model-call.ts @@ -2,10 +2,7 @@ import { ReleventToolLookup } from "../agent"; import { toAnthropicMessages } from "../../workflow-messages"; import { logger } from "../../../observability/logger"; import { WorkflowAgentState, WorkflowAgentStateMessage } from "../state"; -import { - addAttributes, - withSpan, -} from "../../../observability/tracer"; +import { addAttributes, withSpan } from "../../../observability/tracer"; import { AgentError } from "../../../../utilities/errors"; import { ulid } from "ulid"; @@ -74,8 +71,7 @@ const _handleModelCall = async ( "If there is nothing left to do, return 'done' and provide the final result.", "If you encounter invocation errors (e.g., incorrect tool name, missing input), retry based on the error message.", "When possible, return multiple invocations to trigger them in parallel.", - "Once all tasks have been completed, return the final result as a structured object.", - "Provide concise and clear responses. Use **bold** to highlight important words.", + "Provide concise and clear responses.", state.additionalContext, "", schemaString, @@ -263,7 +259,6 @@ const _handleModelCall = async ( }; } - return { messages: [ { diff --git a/control-plane/src/modules/workflows/agent/nodes/model-output.ts b/control-plane/src/modules/workflows/agent/nodes/model-output.ts index 5827dfe..8bff6b7 100644 --- a/control-plane/src/modules/workflows/agent/nodes/model-output.ts +++ b/control-plane/src/modules/workflows/agent/nodes/model-output.ts @@ -1,4 +1,3 @@ - import { JsonSchema7ObjectType } from "zod-to-json-schema"; import { AgentTool } from "../tool"; import { workflows } from "../../../data"; @@ -8,8 +7,7 @@ import { WorkflowAgentState } from "../state"; type ModelInvocationOutput = { toolName: string; input: unknown; - -} +}; export type ModelOutput = { invocations?: ModelInvocationOutput[]; @@ -18,18 +16,17 @@ export type ModelOutput = { message?: string; done?: boolean; issue?: string; -} +}; export const buildModelSchema = ({ state, relevantSchemas, - resultSchema + resultSchema, }: { state: WorkflowAgentState; relevantSchemas: AgentTool[]; resultSchema?: InferSelectModel["result_schema"]; - }) => { - +}) => { // Build the toolName enum const toolNameEnum = [ ...relevantSchemas.map((tool) => tool.name), @@ -48,7 +45,7 @@ export const buildModelSchema = ({ issue: { type: "string", description: - "Describe any issues you have encountered in this step. Specifically related to the tools you are using.", + "Describe any issues you have encountered in this step. Specifically related to the tools you are using. If none, keep this field empty.", }, }, }; diff --git a/control-plane/src/modules/workflows/agent/run.ts b/control-plane/src/modules/workflows/agent/run.ts index fd69694..93a17f5 100644 --- a/control-plane/src/modules/workflows/agent/run.ts +++ b/control-plane/src/modules/workflows/agent/run.ts @@ -32,6 +32,7 @@ import { CURRENT_DATE_TIME_TOOL_NAME } from "./tools/date-time"; import { env } from "../../../utilities/env"; import { events } from "../../observability/events"; import { AgentTool } from "./tool"; +import { getLatestJobsResultedByFunctionName } from "../../jobs/jobs"; /** * Run a workflow from the most recent saved state @@ -284,14 +285,52 @@ async function findRelatedFunctionTools(workflow: Run, search: string) { const toolContexts = await Promise.all( relatedTools.map(async (toolDetails) => { - const metadata = await getToolMetadata( - workflow.clusterId, - toolDetails.serviceName, - toolDetails.functionName, - ); + const [metadata, resolvedJobs, rejectedJobs] = await Promise.all([ + getToolMetadata( + workflow.clusterId, + toolDetails.serviceName, + toolDetails.functionName, + ), + getLatestJobsResultedByFunctionName({ + clusterId: workflow.clusterId, + service: toolDetails.serviceName, + functionName: toolDetails.functionName, + limit: 3, + resultType: "resolution", + }), + getLatestJobsResultedByFunctionName({ + clusterId: workflow.clusterId, + service: toolDetails.serviceName, + functionName: toolDetails.functionName, + limit: 3, + resultType: "rejection", + }), + ]); const contextArr = []; + if (resolvedJobs.length > 0) { + contextArr.push( + `${resolvedJobs + .map( + (j) => + `${j.targetArgs}${j.result}`, + ) + .join("\n")}`, + ); + } + + if (rejectedJobs.length > 0) { + contextArr.push( + `${rejectedJobs + .map( + (j) => + `${j.targetArgs}${j.result}`, + ) + .join("\n")}`, + ); + } + if (metadata?.additionalContext) { contextArr.push(`${metadata.additionalContext}`); } From f6759791dc9ab054492ca80d6ca9a6638eb98ed0 Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Wed, 18 Dec 2024 23:58:05 +1100 Subject: [PATCH 2/2] update --- .../src/modules/workflows/agent/run.ts | 47 ++++++++++++------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/control-plane/src/modules/workflows/agent/run.ts b/control-plane/src/modules/workflows/agent/run.ts index 93a17f5..07c2556 100644 --- a/control-plane/src/modules/workflows/agent/run.ts +++ b/control-plane/src/modules/workflows/agent/run.ts @@ -33,6 +33,7 @@ import { env } from "../../../utilities/env"; import { events } from "../../observability/events"; import { AgentTool } from "./tool"; import { getLatestJobsResultedByFunctionName } from "../../jobs/jobs"; +import { truncate } from "lodash"; /** * Run a workflow from the most recent saved state @@ -260,6 +261,28 @@ export const processRun = async ( } }; +const formatJobsContext = ( + jobs: { targetArgs: string; result: string | null }[], + status: "success" | "failed", +) => { + if (jobs.length === 0) return ""; + + const arbitraryLength = 500; + + const jobEntries = jobs + .map( + (job) => ` + ${truncate(job.targetArgs, { length: arbitraryLength })} + ${truncate(job.result ?? "", { length: arbitraryLength })} + `, + ) + .join("\n"); + + return ` + ${jobEntries} + `; +}; + async function findRelatedFunctionTools(workflow: Run, search: string) { const flags = await flagsmith?.getIdentityFlags(workflow.clusterId, { clusterId: workflow.clusterId, @@ -309,26 +332,14 @@ async function findRelatedFunctionTools(workflow: Run, search: string) { const contextArr = []; - if (resolvedJobs.length > 0) { - contextArr.push( - `${resolvedJobs - .map( - (j) => - `${j.targetArgs}${j.result}`, - ) - .join("\n")}`, - ); + const successJobsContext = formatJobsContext(resolvedJobs, "success"); + if (successJobsContext) { + contextArr.push(successJobsContext); } - if (rejectedJobs.length > 0) { - contextArr.push( - `${rejectedJobs - .map( - (j) => - `${j.targetArgs}${j.result}`, - ) - .join("\n")}`, - ); + const failedJobsContext = formatJobsContext(rejectedJobs, "failed"); + if (failedJobsContext) { + contextArr.push(failedJobsContext); } if (metadata?.additionalContext) {