Skip to content

Commit

Permalink
Revert "feat: Retrieve past function results for reinforcement learni…
Browse files Browse the repository at this point in the history
…ng (#334)"

This reverts commit f825a8f.
  • Loading branch information
johnjcsmith authored Dec 19, 2024
1 parent 71d575b commit b166dcb
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 95 deletions.
34 changes: 1 addition & 33 deletions control-plane/src/modules/jobs/jobs.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { and, desc, eq, gt, isNull, lte, sql } from "drizzle-orm";
import { and, eq, gt, isNull, lte, sql } from "drizzle-orm";
import { env } from "../../utilities/env";
import { JobPollTimeoutError, NotFoundError } from "../../utilities/errors";
import { getBlobsForJobs } from "../blobs";
Expand Down Expand Up @@ -137,38 +137,6 @@ export const getJob = async ({
};
};

export const getLatestJobsResultedByFunctionName = async ({
clusterId,
service,
functionName,
limit,
resultType,
}: {
clusterId: string;
service: string;
functionName: string;
limit: number;
resultType: ResultType;
}) => {
return data.db
.select({
result: data.jobs.result,
resultType: data.jobs.result_type,
targetArgs: data.jobs.target_args,
})
.from(data.jobs)
.where(
and(
eq(data.jobs.cluster_id, clusterId),
eq(data.jobs.target_fn, functionName),
eq(data.jobs.service, service),
eq(data.jobs.result_type, resultType),
),
)
.orderBy(desc(data.jobs.created_at))
.limit(limit);
};

export const getJobsForWorkflow = async ({
clusterId,
runId,
Expand Down
9 changes: 7 additions & 2 deletions control-plane/src/modules/workflows/agent/nodes/model-call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ import { ReleventToolLookup } from "../agent";
import { toAnthropicMessages } from "../../workflow-messages";
import { logger } from "../../../observability/logger";
import { WorkflowAgentState, WorkflowAgentStateMessage } from "../state";
import { addAttributes, withSpan } from "../../../observability/tracer";
import {
addAttributes,
withSpan,
} from "../../../observability/tracer";
import { AgentError } from "../../../../utilities/errors";
import { ulid } from "ulid";

Expand Down Expand Up @@ -71,7 +74,8 @@ const _handleModelCall = async (
"If there is nothing left to do, return 'done' and provide the final result.",
"If you encounter invocation errors (e.g., incorrect tool name, missing input), retry based on the error message.",
"When possible, return multiple invocations to trigger them in parallel.",
"Provide concise and clear responses.",
"Once all tasks have been completed, return the final result as a structured object.",
"Provide concise and clear responses. Use **bold** to highlight important words.",
state.additionalContext,
"<TOOLS_SCHEMAS>",
schemaString,
Expand Down Expand Up @@ -259,6 +263,7 @@ const _handleModelCall = async (
};
}


return {
messages: [
{
Expand Down
13 changes: 8 additions & 5 deletions control-plane/src/modules/workflows/agent/nodes/model-output.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

import { JsonSchema7ObjectType } from "zod-to-json-schema";
import { AgentTool } from "../tool";
import { workflows } from "../../../data";
Expand All @@ -7,7 +8,8 @@ import { WorkflowAgentState } from "../state";
type ModelInvocationOutput = {
toolName: string;
input: unknown;
};

}

export type ModelOutput = {
invocations?: ModelInvocationOutput[];
Expand All @@ -16,17 +18,18 @@ export type ModelOutput = {
message?: string;
done?: boolean;
issue?: string;
};
}

export const buildModelSchema = ({
state,
relevantSchemas,
resultSchema,
resultSchema
}: {
state: WorkflowAgentState;
relevantSchemas: AgentTool[];
resultSchema?: InferSelectModel<typeof workflows>["result_schema"];
}) => {
}) => {

// Build the toolName enum
const toolNameEnum = [
...relevantSchemas.map((tool) => tool.name),
Expand All @@ -45,7 +48,7 @@ export const buildModelSchema = ({
issue: {
type: "string",
description:
"Describe any issues you have encountered in this step. Specifically related to the tools you are using. If none, keep this field empty.",
"Describe any issues you have encountered in this step. Specifically related to the tools you are using.",
},
},
};
Expand Down
60 changes: 5 additions & 55 deletions control-plane/src/modules/workflows/agent/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ import { CURRENT_DATE_TIME_TOOL_NAME } from "./tools/date-time";
import { env } from "../../../utilities/env";
import { events } from "../../observability/events";
import { AgentTool } from "./tool";
import { getLatestJobsResultedByFunctionName } from "../../jobs/jobs";
import { truncate } from "lodash";

/**
* Run a workflow from the most recent saved state
Expand Down Expand Up @@ -261,28 +259,6 @@ export const processRun = async (
}
};

const formatJobsContext = (
jobs: { targetArgs: string; result: string | null }[],
status: "success" | "failed",
) => {
if (jobs.length === 0) return "";

const arbitraryLength = 500;

const jobEntries = jobs
.map(
(job) => `
<input>${truncate(job.targetArgs, { length: arbitraryLength })}</input>
<output>${truncate(job.result ?? "", { length: arbitraryLength })}</output>
`,
)
.join("\n");

return `<jobs status="${status}">
${jobEntries}
</jobs>`;
};

async function findRelatedFunctionTools(workflow: Run, search: string) {
const flags = await flagsmith?.getIdentityFlags(workflow.clusterId, {
clusterId: workflow.clusterId,
Expand All @@ -308,40 +284,14 @@ async function findRelatedFunctionTools(workflow: Run, search: string) {

const toolContexts = await Promise.all(
relatedTools.map(async (toolDetails) => {
const [metadata, resolvedJobs, rejectedJobs] = await Promise.all([
getToolMetadata(
workflow.clusterId,
toolDetails.serviceName,
toolDetails.functionName,
),
getLatestJobsResultedByFunctionName({
clusterId: workflow.clusterId,
service: toolDetails.serviceName,
functionName: toolDetails.functionName,
limit: 3,
resultType: "resolution",
}),
getLatestJobsResultedByFunctionName({
clusterId: workflow.clusterId,
service: toolDetails.serviceName,
functionName: toolDetails.functionName,
limit: 3,
resultType: "rejection",
}),
]);
const metadata = await getToolMetadata(
workflow.clusterId,
toolDetails.serviceName,
toolDetails.functionName,
);

const contextArr = [];

const successJobsContext = formatJobsContext(resolvedJobs, "success");
if (successJobsContext) {
contextArr.push(successJobsContext);
}

const failedJobsContext = formatJobsContext(rejectedJobs, "failed");
if (failedJobsContext) {
contextArr.push(failedJobsContext);
}

if (metadata?.additionalContext) {
contextArr.push(`<context>${metadata.additionalContext}</context>`);
}
Expand Down

0 comments on commit b166dcb

Please sign in to comment.