Skip to content

Commit

Permalink
feat: Retrieve past function results for reinforcement learning
Browse files Browse the repository at this point in the history
- Introduced `getLatestJobsResultedByFunctionName` to fetch job results based on function name, service, and result type.
- By default, job results will be ingested into the context.
  • Loading branch information
nadeesha committed Dec 18, 2024
1 parent 558e4b3 commit 1d9c74c
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 21 deletions.
34 changes: 33 additions & 1 deletion control-plane/src/modules/jobs/jobs.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { and, eq, gt, isNull, lte, sql } from "drizzle-orm";
import { and, desc, eq, gt, isNull, lte, sql } from "drizzle-orm";
import { env } from "../../utilities/env";
import { JobPollTimeoutError, NotFoundError } from "../../utilities/errors";
import { getBlobsForJobs } from "../blobs";
Expand Down Expand Up @@ -137,6 +137,38 @@ export const getJob = async ({
};
};

export const getLatestJobsResultedByFunctionName = async ({
clusterId,
service,
functionName,
limit,
resultType,
}: {
clusterId: string;
service: string;
functionName: string;
limit: number;
resultType: ResultType;
}) => {
return data.db
.select({
result: data.jobs.result,
resultType: data.jobs.result_type,
targetArgs: data.jobs.target_args,
})
.from(data.jobs)
.where(
and(
eq(data.jobs.cluster_id, clusterId),
eq(data.jobs.target_fn, functionName),
eq(data.jobs.service, service),
eq(data.jobs.result_type, resultType),
),
)
.orderBy(desc(data.jobs.created_at))
.limit(limit);
};

export const getJobsForWorkflow = async ({
clusterId,
runId,
Expand Down
9 changes: 2 additions & 7 deletions control-plane/src/modules/workflows/agent/nodes/model-call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ import { ReleventToolLookup } from "../agent";
import { toAnthropicMessages } from "../../workflow-messages";
import { logger } from "../../../observability/logger";
import { WorkflowAgentState, WorkflowAgentStateMessage } from "../state";
import {
addAttributes,
withSpan,
} from "../../../observability/tracer";
import { addAttributes, withSpan } from "../../../observability/tracer";
import { AgentError } from "../../../../utilities/errors";
import { ulid } from "ulid";

Expand Down Expand Up @@ -74,8 +71,7 @@ const _handleModelCall = async (
"If there is nothing left to do, return 'done' and provide the final result.",
"If you encounter invocation errors (e.g., incorrect tool name, missing input), retry based on the error message.",
"When possible, return multiple invocations to trigger them in parallel.",
"Once all tasks have been completed, return the final result as a structured object.",
"Provide concise and clear responses. Use **bold** to highlight important words.",
"Provide concise and clear responses.",
state.additionalContext,
"<TOOLS_SCHEMAS>",
schemaString,
Expand Down Expand Up @@ -263,7 +259,6 @@ const _handleModelCall = async (
};
}


return {
messages: [
{
Expand Down
13 changes: 5 additions & 8 deletions control-plane/src/modules/workflows/agent/nodes/model-output.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import { JsonSchema7ObjectType } from "zod-to-json-schema";
import { AgentTool } from "../tool";
import { workflows } from "../../../data";
Expand All @@ -8,8 +7,7 @@ import { WorkflowAgentState } from "../state";
type ModelInvocationOutput = {
toolName: string;
input: unknown;

}
};

export type ModelOutput = {
invocations?: ModelInvocationOutput[];
Expand All @@ -18,18 +16,17 @@ export type ModelOutput = {
message?: string;
done?: boolean;
issue?: string;
}
};

export const buildModelSchema = ({
state,
relevantSchemas,
resultSchema
resultSchema,
}: {
state: WorkflowAgentState;
relevantSchemas: AgentTool[];
resultSchema?: InferSelectModel<typeof workflows>["result_schema"];
}) => {

}) => {
// Build the toolName enum
const toolNameEnum = [
...relevantSchemas.map((tool) => tool.name),
Expand All @@ -48,7 +45,7 @@ export const buildModelSchema = ({
issue: {
type: "string",
description:
"Describe any issues you have encountered in this step. Specifically related to the tools you are using.",
"Describe any issues you have encountered in this step. Specifically related to the tools you are using. If none, keep this field empty.",
},
},
};
Expand Down
49 changes: 44 additions & 5 deletions control-plane/src/modules/workflows/agent/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import { CURRENT_DATE_TIME_TOOL_NAME } from "./tools/date-time";
import { env } from "../../../utilities/env";
import { events } from "../../observability/events";
import { AgentTool } from "./tool";
import { getLatestJobsResultedByFunctionName } from "../../jobs/jobs";

/**
* Run a workflow from the most recent saved state
Expand Down Expand Up @@ -284,14 +285,52 @@ async function findRelatedFunctionTools(workflow: Run, search: string) {

const toolContexts = await Promise.all(
relatedTools.map(async (toolDetails) => {
const metadata = await getToolMetadata(
workflow.clusterId,
toolDetails.serviceName,
toolDetails.functionName,
);
const [metadata, resolvedJobs, rejectedJobs] = await Promise.all([
getToolMetadata(
workflow.clusterId,
toolDetails.serviceName,
toolDetails.functionName,
),
getLatestJobsResultedByFunctionName({
clusterId: workflow.clusterId,
service: toolDetails.serviceName,
functionName: toolDetails.functionName,
limit: 3,
resultType: "resolution",
}),
getLatestJobsResultedByFunctionName({
clusterId: workflow.clusterId,
service: toolDetails.serviceName,
functionName: toolDetails.functionName,
limit: 3,
resultType: "rejection",
}),
]);

const contextArr = [];

if (resolvedJobs.length > 0) {
contextArr.push(
`<jobs status="success">${resolvedJobs
.map(
(j) =>
`<input>${j.targetArgs}</input><output>${j.result}</output>`,
)
.join("\n")}</jobs>`,
);
}

if (rejectedJobs.length > 0) {
contextArr.push(
`<jobs status="failed">${rejectedJobs
.map(
(j) =>
`<input>${j.targetArgs}</input><output>${j.result}</output>`,
)
.join("\n")}</jobs>`,
);
}

if (metadata?.additionalContext) {
contextArr.push(`<context>${metadata.additionalContext}</context>`);
}
Expand Down

0 comments on commit 1d9c74c

Please sign in to comment.