Skip to content

Commit

Permalink
fix: Attach knowledge artifact tool conditionally and async tool build (
Browse files Browse the repository at this point in the history
  • Loading branch information
nadeesha authored Nov 28, 2024
1 parent 4f53047 commit 4ad91a7
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 33 deletions.
67 changes: 34 additions & 33 deletions control-plane/src/modules/workflows/agent/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,7 @@ import {
} from "./tools/knowledge-artifacts";
import { buildMockFunctionTool } from "./tools/mock-function";
import { events } from "../../observability/events";

const internalToolsMap: Record<
string,
(
workflow: Run,
toolCallId: string,
) => DynamicStructuredTool | Promise<DynamicStructuredTool>
> = {
[ACCESS_KNOWLEDGE_ARTIFACTS_TOOL_NAME]: buildAccessKnowledgeArtifacts,
};
import { getClusterInternalTools } from "./tools/cluster-internal-tools";

/**
* Run a workflow from the most recent saved state
Expand All @@ -46,7 +37,13 @@ export const run = async (run: Run) => {
logger.info("Running workflow");

// Parallelize fetching additional context and service definitions
const [additionalContext, serviceDefinitions] = await Promise.all([
const [
additionalContext,
serviceDefinitions,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
_updateResult,
internalToolsMap,
] = await Promise.all([
buildAdditionalContext(run),
getServiceDefinitions({
clusterId: run.clusterId,
Expand All @@ -55,6 +52,7 @@ export const run = async (run: Run) => {
...run,
status: "running",
}),
getClusterInternalTools(run.clusterId),
]);

const allAvailableTools: string[] = [];
Expand All @@ -73,9 +71,9 @@ export const run = async (run: Run) => {
serviceFunctionEmbeddingId({
serviceName: service.name,
functionName: f.name,
}),
})
);
}),
})
);

const mockToolsMap: Record<string, DynamicStructuredTool> =
Expand Down Expand Up @@ -105,7 +103,7 @@ export const run = async (run: Run) => {
const serviceFunctionDetails = await embeddableServiceFunction.getEntity(
run.clusterId,
"service-function",
toolCall.toolName,
toolCall.toolName
);

if (serviceFunctionDetails) {
Expand All @@ -130,7 +128,7 @@ export const run = async (run: Run) => {
// optimistically embed the next search query
// this is not critical to the workflow, so we can do it in the background
embedSearchQuery(
state.messages.map((m) => JSON.stringify(m.data)).join(" "),
state.messages.map((m) => JSON.stringify(m.data)).join(" ")
);
}

Expand Down Expand Up @@ -165,7 +163,7 @@ export const run = async (run: Run) => {
},
{
recursionLimit: 100,
},
}
);

const parsedOutput = z
Expand Down Expand Up @@ -237,7 +235,7 @@ async function findRelatedFunctionTools(workflow: Run, search: string) {
workflow.clusterId,
"service-function",
search,
30,
30
)
: [];

Expand All @@ -253,7 +251,7 @@ async function findRelatedFunctionTools(workflow: Run, search: string) {
const metadata = await getToolMetadata(
workflow.clusterId,
toolDetails.serviceName,
toolDetails.functionName,
toolDetails.functionName
);

const contextArr = [];
Expand All @@ -264,7 +262,9 @@ async function findRelatedFunctionTools(workflow: Run, search: string) {

if (metadata?.resultKeys) {
contextArr.push(
`<result_keys>${metadata.resultKeys.slice(0, 10).map((k) => k.key)}</result_keys>`,
`<result_keys>${metadata.resultKeys
.slice(0, 10)
.map((k) => k.key)}</result_keys>`
);
}

Expand All @@ -273,7 +273,7 @@ async function findRelatedFunctionTools(workflow: Run, search: string) {
functionName: toolDetails.functionName,
toolContext: contextArr.join("\n\n"),
};
}),
})
);

const selectedTools = relatedTools.map((toolDetails) =>
Expand All @@ -284,13 +284,13 @@ async function findRelatedFunctionTools(workflow: Run, search: string) {
toolContexts.find(
(c) =>
c?.serviceName === toolDetails.serviceName &&
c?.functionName === toolDetails.functionName,
c?.functionName === toolDetails.functionName
)?.toolContext,
]
.filter(Boolean)
.join("\n\n"),
schema: toolDetails.schema,
}),
})
);

return selectedTools;
Expand Down Expand Up @@ -327,32 +327,33 @@ export const findRelevantTools = async (state: WorkflowAgentState) => {
const serviceFunctionDetails = await embeddableServiceFunction.getEntity(
workflow.clusterId,
"service-function",
tool,
tool
);

if (!serviceFunctionDetails) {
throw new Error(
`Tool ${tool} not found in cluster ${workflow.clusterId}`,
`Tool ${tool} not found in cluster ${workflow.clusterId}`
);
}

tools.push(
buildAbstractServiceFunctionTool({
...serviceFunctionDetails,
schema: serviceFunctionDetails.schema,
}),
})
);
}
} else {
const found = await findRelatedFunctionTools(
workflow,
state.messages.map((m) => JSON.stringify(m.data)).join(" "),
state.messages.map((m) => JSON.stringify(m.data)).join(" ")
);

tools.push(...found);

const accessKnowledgeArtifactsTool =
await buildAccessKnowledgeArtifacts(workflow);
const accessKnowledgeArtifactsTool = await buildAccessKnowledgeArtifacts(
workflow
);

// When tools are not specified, attach all internalTools
tools.push(accessKnowledgeArtifactsTool);
Expand All @@ -369,7 +370,7 @@ export const buildMockTools = async (workflow: Run) => {

if (!workflow.test) {
logger.warn(
"Workflow is not marked as test enabled but contains mocks. Mocks will be ignored.",
"Workflow is not marked as test enabled but contains mocks. Mocks will be ignored."
);
return mocks;
}
Expand Down Expand Up @@ -397,7 +398,7 @@ export const buildMockTools = async (workflow: Run) => {
}

const serviceDefinition = serviceDefinitions.find(
(sd) => sd.name === serviceName,
(sd) => sd.name === serviceName
);

if (!serviceDefinition) {
Expand All @@ -406,13 +407,13 @@ export const buildMockTools = async (workflow: Run) => {
{
key,
serviceName,
},
}
);
continue;
}

const functionDefinition = serviceDefinition.functions?.find(
(f) => f.name === functionName,
(f) => f.name === functionName
);

if (!functionDefinition) {
Expand All @@ -422,7 +423,7 @@ export const buildMockTools = async (workflow: Run) => {
key,
serviceName,
functionName,
},
}
);
continue;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import { DynamicStructuredTool } from "@langchain/core/tools";
import { Run } from "../../workflows";
import {
buildAccessKnowledgeArtifacts,
ACCESS_KNOWLEDGE_ARTIFACTS_TOOL_NAME,
} from "./knowledge-artifacts";
import { createCache } from "../../../../utilities/cache";
import { getClusterDetails } from "../../../management";

const clusterSettingsCache = createCache<{
enableKnowledgebase: boolean;
}>(Symbol("cluster-settings"));

const CACHE_TTL = 60 * 2; // 2 minutes

export type InternalToolBuilder = (
workflow: Run,
toolCallId: string
) => DynamicStructuredTool | Promise<DynamicStructuredTool>;

export const getClusterInternalTools = async (
clusterId: string
): Promise<Record<string, InternalToolBuilder>> => {
const cacheKey = `cluster:${clusterId}`;

let settings = clusterSettingsCache.get(cacheKey);

if (!settings) {
// Get cluster settings
const cluster = await getClusterDetails({ clusterId });
settings = {
enableKnowledgebase: cluster.enableKnowledgebase,
};
clusterSettingsCache.set(cacheKey, settings, CACHE_TTL);
}

const tools: Record<string, InternalToolBuilder> = {};

// Only include knowledge artifacts tool if enabled for cluster
if (settings.enableKnowledgebase) {
tools[ACCESS_KNOWLEDGE_ARTIFACTS_TOOL_NAME] = buildAccessKnowledgeArtifacts;
}

return tools;
};

0 comments on commit 4ad91a7

Please sign in to comment.