From b166dcb2cd46f197af0efa409efb6eb5322bdee7 Mon Sep 17 00:00:00 2001 From: John Smith Date: Fri, 20 Dec 2024 08:42:55 +1030 Subject: [PATCH] Revert "feat: Retrieve past function results for reinforcement learning (#334)" This reverts commit f825a8f474b7b9f19ea2418195ef9d0da27f7fae. --- control-plane/src/modules/jobs/jobs.ts | 34 +---------- .../workflows/agent/nodes/model-call.ts | 9 ++- .../workflows/agent/nodes/model-output.ts | 13 ++-- .../src/modules/workflows/agent/run.ts | 60 ++----------------- 4 files changed, 21 insertions(+), 95 deletions(-) diff --git a/control-plane/src/modules/jobs/jobs.ts b/control-plane/src/modules/jobs/jobs.ts index 97395d64..1ae528b2 100644 --- a/control-plane/src/modules/jobs/jobs.ts +++ b/control-plane/src/modules/jobs/jobs.ts @@ -1,4 +1,4 @@ -import { and, desc, eq, gt, isNull, lte, sql } from "drizzle-orm"; +import { and, eq, gt, isNull, lte, sql } from "drizzle-orm"; import { env } from "../../utilities/env"; import { JobPollTimeoutError, NotFoundError } from "../../utilities/errors"; import { getBlobsForJobs } from "../blobs"; @@ -137,38 +137,6 @@ export const getJob = async ({ }; }; -export const getLatestJobsResultedByFunctionName = async ({ - clusterId, - service, - functionName, - limit, - resultType, -}: { - clusterId: string; - service: string; - functionName: string; - limit: number; - resultType: ResultType; -}) => { - return data.db - .select({ - result: data.jobs.result, - resultType: data.jobs.result_type, - targetArgs: data.jobs.target_args, - }) - .from(data.jobs) - .where( - and( - eq(data.jobs.cluster_id, clusterId), - eq(data.jobs.target_fn, functionName), - eq(data.jobs.service, service), - eq(data.jobs.result_type, resultType), - ), - ) - .orderBy(desc(data.jobs.created_at)) - .limit(limit); -}; - export const getJobsForWorkflow = async ({ clusterId, runId, diff --git a/control-plane/src/modules/workflows/agent/nodes/model-call.ts b/control-plane/src/modules/workflows/agent/nodes/model-call.ts index 5d9adc31..a98e73a8 100644 --- a/control-plane/src/modules/workflows/agent/nodes/model-call.ts +++ b/control-plane/src/modules/workflows/agent/nodes/model-call.ts @@ -2,7 +2,10 @@ import { ReleventToolLookup } from "../agent"; import { toAnthropicMessages } from "../../workflow-messages"; import { logger } from "../../../observability/logger"; import { WorkflowAgentState, WorkflowAgentStateMessage } from "../state"; -import { addAttributes, withSpan } from "../../../observability/tracer"; +import { + addAttributes, + withSpan, +} from "../../../observability/tracer"; import { AgentError } from "../../../../utilities/errors"; import { ulid } from "ulid"; @@ -71,7 +74,8 @@ const _handleModelCall = async ( "If there is nothing left to do, return 'done' and provide the final result.", "If you encounter invocation errors (e.g., incorrect tool name, missing input), retry based on the error message.", "When possible, return multiple invocations to trigger them in parallel.", - "Provide concise and clear responses.", + "Once all tasks have been completed, return the final result as a structured object.", + "Provide concise and clear responses. Use **bold** to highlight important words.", state.additionalContext, "", schemaString, @@ -259,6 +263,7 @@ const _handleModelCall = async ( }; } + return { messages: [ { diff --git a/control-plane/src/modules/workflows/agent/nodes/model-output.ts b/control-plane/src/modules/workflows/agent/nodes/model-output.ts index 8bff6b7c..5827dfec 100644 --- a/control-plane/src/modules/workflows/agent/nodes/model-output.ts +++ b/control-plane/src/modules/workflows/agent/nodes/model-output.ts @@ -1,3 +1,4 @@ + import { JsonSchema7ObjectType } from "zod-to-json-schema"; import { AgentTool } from "../tool"; import { workflows } from "../../../data"; @@ -7,7 +8,8 @@ import { WorkflowAgentState } from "../state"; type ModelInvocationOutput = { toolName: string; input: unknown; -}; + +} export type ModelOutput = { invocations?: ModelInvocationOutput[]; @@ -16,17 +18,18 @@ export type ModelOutput = { message?: string; done?: boolean; issue?: string; -}; +} export const buildModelSchema = ({ state, relevantSchemas, - resultSchema, + resultSchema }: { state: WorkflowAgentState; relevantSchemas: AgentTool[]; resultSchema?: InferSelectModel["result_schema"]; -}) => { + }) => { + // Build the toolName enum const toolNameEnum = [ ...relevantSchemas.map((tool) => tool.name), @@ -45,7 +48,7 @@ export const buildModelSchema = ({ issue: { type: "string", description: - "Describe any issues you have encountered in this step. Specifically related to the tools you are using. If none, keep this field empty.", + "Describe any issues you have encountered in this step. Specifically related to the tools you are using.", }, }, }; diff --git a/control-plane/src/modules/workflows/agent/run.ts b/control-plane/src/modules/workflows/agent/run.ts index 07c2556b..fd69694f 100644 --- a/control-plane/src/modules/workflows/agent/run.ts +++ b/control-plane/src/modules/workflows/agent/run.ts @@ -32,8 +32,6 @@ import { CURRENT_DATE_TIME_TOOL_NAME } from "./tools/date-time"; import { env } from "../../../utilities/env"; import { events } from "../../observability/events"; import { AgentTool } from "./tool"; -import { getLatestJobsResultedByFunctionName } from "../../jobs/jobs"; -import { truncate } from "lodash"; /** * Run a workflow from the most recent saved state @@ -261,28 +259,6 @@ export const processRun = async ( } }; -const formatJobsContext = ( - jobs: { targetArgs: string; result: string | null }[], - status: "success" | "failed", -) => { - if (jobs.length === 0) return ""; - - const arbitraryLength = 500; - - const jobEntries = jobs - .map( - (job) => ` - ${truncate(job.targetArgs, { length: arbitraryLength })} - ${truncate(job.result ?? "", { length: arbitraryLength })} - `, - ) - .join("\n"); - - return ` - ${jobEntries} - `; -}; - async function findRelatedFunctionTools(workflow: Run, search: string) { const flags = await flagsmith?.getIdentityFlags(workflow.clusterId, { clusterId: workflow.clusterId, @@ -308,40 +284,14 @@ async function findRelatedFunctionTools(workflow: Run, search: string) { const toolContexts = await Promise.all( relatedTools.map(async (toolDetails) => { - const [metadata, resolvedJobs, rejectedJobs] = await Promise.all([ - getToolMetadata( - workflow.clusterId, - toolDetails.serviceName, - toolDetails.functionName, - ), - getLatestJobsResultedByFunctionName({ - clusterId: workflow.clusterId, - service: toolDetails.serviceName, - functionName: toolDetails.functionName, - limit: 3, - resultType: "resolution", - }), - getLatestJobsResultedByFunctionName({ - clusterId: workflow.clusterId, - service: toolDetails.serviceName, - functionName: toolDetails.functionName, - limit: 3, - resultType: "rejection", - }), - ]); + const metadata = await getToolMetadata( + workflow.clusterId, + toolDetails.serviceName, + toolDetails.functionName, + ); const contextArr = []; - const successJobsContext = formatJobsContext(resolvedJobs, "success"); - if (successJobsContext) { - contextArr.push(successJobsContext); - } - - const failedJobsContext = formatJobsContext(rejectedJobs, "failed"); - if (failedJobsContext) { - contextArr.push(failedJobsContext); - } - if (metadata?.additionalContext) { contextArr.push(`${metadata.additionalContext}`); }