Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Revert "feat: Retrieve past function results for reinforcement learning" #343

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 1 addition & 33 deletions control-plane/src/modules/jobs/jobs.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { and, desc, eq, gt, isNull, lte, sql } from "drizzle-orm";
import { and, eq, gt, isNull, lte, sql } from "drizzle-orm";
import { env } from "../../utilities/env";
import { JobPollTimeoutError, NotFoundError } from "../../utilities/errors";
import { getBlobsForJobs } from "../blobs";
Expand Down Expand Up @@ -137,38 +137,6 @@ export const getJob = async ({
};
};

export const getLatestJobsResultedByFunctionName = async ({
clusterId,
service,
functionName,
limit,
resultType,
}: {
clusterId: string;
service: string;
functionName: string;
limit: number;
resultType: ResultType;
}) => {
return data.db
.select({
result: data.jobs.result,
resultType: data.jobs.result_type,
targetArgs: data.jobs.target_args,
})
.from(data.jobs)
.where(
and(
eq(data.jobs.cluster_id, clusterId),
eq(data.jobs.target_fn, functionName),
eq(data.jobs.service, service),
eq(data.jobs.result_type, resultType),
),
)
.orderBy(desc(data.jobs.created_at))
.limit(limit);
};

export const getJobsForWorkflow = async ({
clusterId,
runId,
Expand Down
9 changes: 7 additions & 2 deletions control-plane/src/modules/workflows/agent/nodes/model-call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ import { ReleventToolLookup } from "../agent";
import { toAnthropicMessages } from "../../workflow-messages";
import { logger } from "../../../observability/logger";
import { WorkflowAgentState, WorkflowAgentStateMessage } from "../state";
import { addAttributes, withSpan } from "../../../observability/tracer";
import {
addAttributes,
withSpan,
} from "../../../observability/tracer";
import { AgentError } from "../../../../utilities/errors";
import { ulid } from "ulid";

Expand Down Expand Up @@ -71,7 +74,8 @@ const _handleModelCall = async (
"If there is nothing left to do, return 'done' and provide the final result.",
"If you encounter invocation errors (e.g., incorrect tool name, missing input), retry based on the error message.",
"When possible, return multiple invocations to trigger them in parallel.",
"Provide concise and clear responses.",
"Once all tasks have been completed, return the final result as a structured object.",
"Provide concise and clear responses. Use **bold** to highlight important words.",
state.additionalContext,
"<TOOLS_SCHEMAS>",
schemaString,
Expand Down Expand Up @@ -259,6 +263,7 @@ const _handleModelCall = async (
};
}


return {
messages: [
{
Expand Down
13 changes: 8 additions & 5 deletions control-plane/src/modules/workflows/agent/nodes/model-output.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

import { JsonSchema7ObjectType } from "zod-to-json-schema";
import { AgentTool } from "../tool";
import { workflows } from "../../../data";
Expand All @@ -7,7 +8,8 @@ import { WorkflowAgentState } from "../state";
type ModelInvocationOutput = {
toolName: string;
input: unknown;
};

}

export type ModelOutput = {
invocations?: ModelInvocationOutput[];
Expand All @@ -16,17 +18,18 @@ export type ModelOutput = {
message?: string;
done?: boolean;
issue?: string;
};
}

export const buildModelSchema = ({
state,
relevantSchemas,
resultSchema,
resultSchema
}: {
state: WorkflowAgentState;
relevantSchemas: AgentTool[];
resultSchema?: InferSelectModel<typeof workflows>["result_schema"];
}) => {
}) => {

// Build the toolName enum
const toolNameEnum = [
...relevantSchemas.map((tool) => tool.name),
Expand All @@ -45,7 +48,7 @@ export const buildModelSchema = ({
issue: {
type: "string",
description:
"Describe any issues you have encountered in this step. Specifically related to the tools you are using. If none, keep this field empty.",
"Describe any issues you have encountered in this step. Specifically related to the tools you are using.",
},
},
};
Expand Down
60 changes: 5 additions & 55 deletions control-plane/src/modules/workflows/agent/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ import { CURRENT_DATE_TIME_TOOL_NAME } from "./tools/date-time";
import { env } from "../../../utilities/env";
import { events } from "../../observability/events";
import { AgentTool } from "./tool";
import { getLatestJobsResultedByFunctionName } from "../../jobs/jobs";
import { truncate } from "lodash";

/**
* Run a workflow from the most recent saved state
Expand Down Expand Up @@ -261,28 +259,6 @@ export const processRun = async (
}
};

const formatJobsContext = (
jobs: { targetArgs: string; result: string | null }[],
status: "success" | "failed",
) => {
if (jobs.length === 0) return "";

const arbitraryLength = 500;

const jobEntries = jobs
.map(
(job) => `
<input>${truncate(job.targetArgs, { length: arbitraryLength })}</input>
<output>${truncate(job.result ?? "", { length: arbitraryLength })}</output>
`,
)
.join("\n");

return `<jobs status="${status}">
${jobEntries}
</jobs>`;
};

async function findRelatedFunctionTools(workflow: Run, search: string) {
const flags = await flagsmith?.getIdentityFlags(workflow.clusterId, {
clusterId: workflow.clusterId,
Expand All @@ -308,40 +284,14 @@ async function findRelatedFunctionTools(workflow: Run, search: string) {

const toolContexts = await Promise.all(
relatedTools.map(async (toolDetails) => {
const [metadata, resolvedJobs, rejectedJobs] = await Promise.all([
getToolMetadata(
workflow.clusterId,
toolDetails.serviceName,
toolDetails.functionName,
),
getLatestJobsResultedByFunctionName({
clusterId: workflow.clusterId,
service: toolDetails.serviceName,
functionName: toolDetails.functionName,
limit: 3,
resultType: "resolution",
}),
getLatestJobsResultedByFunctionName({
clusterId: workflow.clusterId,
service: toolDetails.serviceName,
functionName: toolDetails.functionName,
limit: 3,
resultType: "rejection",
}),
]);
const metadata = await getToolMetadata(
workflow.clusterId,
toolDetails.serviceName,
toolDetails.functionName,
);

const contextArr = [];

const successJobsContext = formatJobsContext(resolvedJobs, "success");
if (successJobsContext) {
contextArr.push(successJobsContext);
}

const failedJobsContext = formatJobsContext(rejectedJobs, "failed");
if (failedJobsContext) {
contextArr.push(failedJobsContext);
}

if (metadata?.additionalContext) {
contextArr.push(`<context>${metadata.additionalContext}</context>`);
}
Expand Down
Loading