Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Retrieve past function results for reinforcement learning #334

Merged
merged 2 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion control-plane/src/modules/jobs/jobs.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { and, eq, gt, isNull, lte, sql } from "drizzle-orm";
import { and, desc, eq, gt, isNull, lte, sql } from "drizzle-orm";
import { env } from "../../utilities/env";
import { JobPollTimeoutError, NotFoundError } from "../../utilities/errors";
import { getBlobsForJobs } from "../blobs";
Expand Down Expand Up @@ -137,6 +137,38 @@ export const getJob = async ({
};
};

export const getLatestJobsResultedByFunctionName = async ({
clusterId,
service,
functionName,
limit,
resultType,
}: {
clusterId: string;
service: string;
functionName: string;
limit: number;
resultType: ResultType;
}) => {
return data.db
.select({
result: data.jobs.result,
resultType: data.jobs.result_type,
targetArgs: data.jobs.target_args,
})
.from(data.jobs)
.where(
and(
eq(data.jobs.cluster_id, clusterId),
eq(data.jobs.target_fn, functionName),
eq(data.jobs.service, service),
eq(data.jobs.result_type, resultType),
),
)
.orderBy(desc(data.jobs.created_at))
.limit(limit);
};

export const getJobsForWorkflow = async ({
clusterId,
runId,
Expand Down
9 changes: 2 additions & 7 deletions control-plane/src/modules/workflows/agent/nodes/model-call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ import { ReleventToolLookup } from "../agent";
import { toAnthropicMessages } from "../../workflow-messages";
import { logger } from "../../../observability/logger";
import { WorkflowAgentState, WorkflowAgentStateMessage } from "../state";
import {
addAttributes,
withSpan,
} from "../../../observability/tracer";
import { addAttributes, withSpan } from "../../../observability/tracer";
import { AgentError } from "../../../../utilities/errors";
import { ulid } from "ulid";

Expand Down Expand Up @@ -74,8 +71,7 @@ const _handleModelCall = async (
"If there is nothing left to do, return 'done' and provide the final result.",
"If you encounter invocation errors (e.g., incorrect tool name, missing input), retry based on the error message.",
"When possible, return multiple invocations to trigger them in parallel.",
"Once all tasks have been completed, return the final result as a structured object.",
"Provide concise and clear responses. Use **bold** to highlight important words.",
"Provide concise and clear responses.",
state.additionalContext,
"<TOOLS_SCHEMAS>",
schemaString,
Expand Down Expand Up @@ -263,7 +259,6 @@ const _handleModelCall = async (
};
}


return {
messages: [
{
Expand Down
13 changes: 5 additions & 8 deletions control-plane/src/modules/workflows/agent/nodes/model-output.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import { JsonSchema7ObjectType } from "zod-to-json-schema";
import { AgentTool } from "../tool";
import { workflows } from "../../../data";
Expand All @@ -8,8 +7,7 @@ import { WorkflowAgentState } from "../state";
type ModelInvocationOutput = {
toolName: string;
input: unknown;

}
};

export type ModelOutput = {
invocations?: ModelInvocationOutput[];
Expand All @@ -18,18 +16,17 @@ export type ModelOutput = {
message?: string;
done?: boolean;
issue?: string;
}
};

export const buildModelSchema = ({
state,
relevantSchemas,
resultSchema
resultSchema,
}: {
state: WorkflowAgentState;
relevantSchemas: AgentTool[];
resultSchema?: InferSelectModel<typeof workflows>["result_schema"];
}) => {

}) => {
// Build the toolName enum
const toolNameEnum = [
...relevantSchemas.map((tool) => tool.name),
Expand All @@ -48,7 +45,7 @@ export const buildModelSchema = ({
issue: {
type: "string",
description:
"Describe any issues you have encountered in this step. Specifically related to the tools you are using.",
"Describe any issues you have encountered in this step. Specifically related to the tools you are using. If none, keep this field empty.",
},
},
};
Expand Down
60 changes: 55 additions & 5 deletions control-plane/src/modules/workflows/agent/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import { CURRENT_DATE_TIME_TOOL_NAME } from "./tools/date-time";
import { env } from "../../../utilities/env";
import { events } from "../../observability/events";
import { AgentTool } from "./tool";
import { getLatestJobsResultedByFunctionName } from "../../jobs/jobs";
import { truncate } from "lodash";

/**
* Run a workflow from the most recent saved state
Expand Down Expand Up @@ -259,6 +261,28 @@ export const processRun = async (
}
};

const formatJobsContext = (
jobs: { targetArgs: string; result: string | null }[],
status: "success" | "failed",
) => {
if (jobs.length === 0) return "";

const arbitraryLength = 500;

const jobEntries = jobs
.map(
(job) => `
<input>${truncate(job.targetArgs, { length: arbitraryLength })}</input>
<output>${truncate(job.result ?? "", { length: arbitraryLength })}</output>
`,
)
.join("\n");

return `<jobs status="${status}">
${jobEntries}
</jobs>`;
};

async function findRelatedFunctionTools(workflow: Run, search: string) {
const flags = await flagsmith?.getIdentityFlags(workflow.clusterId, {
clusterId: workflow.clusterId,
Expand All @@ -284,14 +308,40 @@ async function findRelatedFunctionTools(workflow: Run, search: string) {

const toolContexts = await Promise.all(
relatedTools.map(async (toolDetails) => {
const metadata = await getToolMetadata(
workflow.clusterId,
toolDetails.serviceName,
toolDetails.functionName,
);
const [metadata, resolvedJobs, rejectedJobs] = await Promise.all([
getToolMetadata(
workflow.clusterId,
toolDetails.serviceName,
toolDetails.functionName,
),
getLatestJobsResultedByFunctionName({
clusterId: workflow.clusterId,
service: toolDetails.serviceName,
functionName: toolDetails.functionName,
limit: 3,
resultType: "resolution",
}),
getLatestJobsResultedByFunctionName({
clusterId: workflow.clusterId,
service: toolDetails.serviceName,
functionName: toolDetails.functionName,
limit: 3,
resultType: "rejection",
}),
]);

const contextArr = [];

const successJobsContext = formatJobsContext(resolvedJobs, "success");
if (successJobsContext) {
contextArr.push(successJobsContext);
}

const failedJobsContext = formatJobsContext(rejectedJobs, "failed");
if (failedJobsContext) {
contextArr.push(failedJobsContext);
}

if (metadata?.additionalContext) {
contextArr.push(`<context>${metadata.additionalContext}</context>`);
}
Expand Down
Loading