Skip to content

Commit

Permalink
Revert "Revert "feat: Retrieve past function results for reinforcemen…
Browse files Browse the repository at this point in the history
…t learning (#334)" (#343)"

This reverts commit 5557f13.
  • Loading branch information
nadeesha committed Dec 20, 2024
1 parent 8fde5b1 commit 32051fe
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 21 deletions.
34 changes: 33 additions & 1 deletion control-plane/src/modules/jobs/jobs.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { and, eq, gt, isNull, lte, sql } from "drizzle-orm";
import { and, desc, eq, gt, isNull, lte, sql } from "drizzle-orm";
import { env } from "../../utilities/env";
import { JobPollTimeoutError, NotFoundError } from "../../utilities/errors";
import { getBlobsForJobs } from "../blobs";
Expand Down Expand Up @@ -137,6 +137,38 @@ export const getJob = async ({
};
};

export const getLatestJobsResultedByFunctionName = async ({
clusterId,
service,
functionName,
limit,
resultType,
}: {
clusterId: string;
service: string;
functionName: string;
limit: number;
resultType: ResultType;
}) => {
return data.db
.select({
result: data.jobs.result,
resultType: data.jobs.result_type,
targetArgs: data.jobs.target_args,
})
.from(data.jobs)
.where(
and(
eq(data.jobs.cluster_id, clusterId),
eq(data.jobs.target_fn, functionName),
eq(data.jobs.service, service),
eq(data.jobs.result_type, resultType),
),
)
.orderBy(desc(data.jobs.created_at))
.limit(limit);
};

export const getJobsForWorkflow = async ({
clusterId,
runId,
Expand Down
9 changes: 2 additions & 7 deletions control-plane/src/modules/workflows/agent/nodes/model-call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ import { ReleventToolLookup } from "../agent";
import { toAnthropicMessages } from "../../workflow-messages";
import { logger } from "../../../observability/logger";
import { WorkflowAgentState, WorkflowAgentStateMessage } from "../state";
import {
addAttributes,
withSpan,
} from "../../../observability/tracer";
import { addAttributes, withSpan } from "../../../observability/tracer";
import { AgentError } from "../../../../utilities/errors";
import { ulid } from "ulid";

Expand Down Expand Up @@ -74,8 +71,7 @@ const _handleModelCall = async (
"If there is nothing left to do, return 'done' and provide the final result.",
"If you encounter invocation errors (e.g., incorrect tool name, missing input), retry based on the error message.",
"When possible, return multiple invocations to trigger them in parallel.",
"Once all tasks have been completed, return the final result as a structured object.",
"Provide concise and clear responses. Use **bold** to highlight important words.",
"Provide concise and clear responses.",
state.additionalContext,
"<TOOLS_SCHEMAS>",
schemaString,
Expand Down Expand Up @@ -263,7 +259,6 @@ const _handleModelCall = async (
};
}


return {
messages: [
{
Expand Down
13 changes: 5 additions & 8 deletions control-plane/src/modules/workflows/agent/nodes/model-output.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import { JsonSchema7ObjectType } from "zod-to-json-schema";
import { AgentTool } from "../tool";
import { workflows } from "../../../data";
Expand All @@ -8,8 +7,7 @@ import { WorkflowAgentState } from "../state";
type ModelInvocationOutput = {
toolName: string;
input: unknown;

}
};

export type ModelOutput = {
invocations?: ModelInvocationOutput[];
Expand All @@ -18,18 +16,17 @@ export type ModelOutput = {
message?: string;
done?: boolean;
issue?: string;
}
};

export const buildModelSchema = ({
state,
relevantSchemas,
resultSchema
resultSchema,
}: {
state: WorkflowAgentState;
relevantSchemas: AgentTool[];
resultSchema?: InferSelectModel<typeof workflows>["result_schema"];
}) => {

}) => {
// Build the toolName enum
const toolNameEnum = [
...relevantSchemas.map((tool) => tool.name),
Expand All @@ -48,7 +45,7 @@ export const buildModelSchema = ({
issue: {
type: "string",
description:
"Describe any issues you have encountered in this step. Specifically related to the tools you are using.",
"Describe any issues you have encountered in this step. Specifically related to the tools you are using. If none, keep this field empty.",
},
},
};
Expand Down
60 changes: 55 additions & 5 deletions control-plane/src/modules/workflows/agent/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import { CURRENT_DATE_TIME_TOOL_NAME } from "./tools/date-time";
import { env } from "../../../utilities/env";
import { events } from "../../observability/events";
import { AgentTool } from "./tool";
import { getLatestJobsResultedByFunctionName } from "../../jobs/jobs";
import { truncate } from "lodash";

/**
* Run a workflow from the most recent saved state
Expand Down Expand Up @@ -259,6 +261,28 @@ export const processRun = async (
}
};

const formatJobsContext = (
jobs: { targetArgs: string; result: string | null }[],
status: "success" | "failed",
) => {
if (jobs.length === 0) return "";

const arbitraryLength = 500;

const jobEntries = jobs
.map(
(job) => `
<input>${truncate(job.targetArgs, { length: arbitraryLength })}</input>
<output>${truncate(job.result ?? "", { length: arbitraryLength })}</output>
`,
)
.join("\n");

return `<jobs status="${status}">
${jobEntries}
</jobs>`;
};

async function findRelatedFunctionTools(workflow: Run, search: string) {
const flags = await flagsmith?.getIdentityFlags(workflow.clusterId, {
clusterId: workflow.clusterId,
Expand All @@ -284,14 +308,40 @@ async function findRelatedFunctionTools(workflow: Run, search: string) {

const toolContexts = await Promise.all(
relatedTools.map(async (toolDetails) => {
const metadata = await getToolMetadata(
workflow.clusterId,
toolDetails.serviceName,
toolDetails.functionName,
);
const [metadata, resolvedJobs, rejectedJobs] = await Promise.all([
getToolMetadata(
workflow.clusterId,
toolDetails.serviceName,
toolDetails.functionName,
),
getLatestJobsResultedByFunctionName({
clusterId: workflow.clusterId,
service: toolDetails.serviceName,
functionName: toolDetails.functionName,
limit: 3,
resultType: "resolution",
}),
getLatestJobsResultedByFunctionName({
clusterId: workflow.clusterId,
service: toolDetails.serviceName,
functionName: toolDetails.functionName,
limit: 3,
resultType: "rejection",
}),
]);

const contextArr = [];

const successJobsContext = formatJobsContext(resolvedJobs, "success");
if (successJobsContext) {
contextArr.push(successJobsContext);
}

const failedJobsContext = formatJobsContext(rejectedJobs, "failed");
if (failedJobsContext) {
contextArr.push(failedJobsContext);
}

if (metadata?.additionalContext) {
contextArr.push(`<context>${metadata.additionalContext}</context>`);
}
Expand Down

0 comments on commit 32051fe

Please sign in to comment.