From 929c4fc1199d3cfd7b036eece940fc4a645ad828 Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Fri, 29 Nov 2024 07:32:44 +1100 Subject: [PATCH] fix: Attach knowledge artifact tool conditionally and async tool build --- .../src/modules/workflows/agent/run.ts | 67 ++++++++++--------- .../agent/tools/cluster-internal-tools.ts | 45 +++++++++++++ 2 files changed, 79 insertions(+), 33 deletions(-) create mode 100644 control-plane/src/modules/workflows/agent/tools/cluster-internal-tools.ts diff --git a/control-plane/src/modules/workflows/agent/run.ts b/control-plane/src/modules/workflows/agent/run.ts index 3c9df5bc..39bb2557 100644 --- a/control-plane/src/modules/workflows/agent/run.ts +++ b/control-plane/src/modules/workflows/agent/run.ts @@ -28,16 +28,7 @@ import { } from "./tools/knowledge-artifacts"; import { buildMockFunctionTool } from "./tools/mock-function"; import { events } from "../../observability/events"; - -const internalToolsMap: Record< - string, - ( - workflow: Run, - toolCallId: string, - ) => DynamicStructuredTool | Promise -> = { - [ACCESS_KNOWLEDGE_ARTIFACTS_TOOL_NAME]: buildAccessKnowledgeArtifacts, -}; +import { getClusterInternalTools } from "./tools/cluster-internal-tools"; /** * Run a workflow from the most recent saved state @@ -46,7 +37,13 @@ export const run = async (run: Run) => { logger.info("Running workflow"); // Parallelize fetching additional context and service definitions - const [additionalContext, serviceDefinitions] = await Promise.all([ + const [ + additionalContext, + serviceDefinitions, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + _updateResult, + internalToolsMap, + ] = await Promise.all([ buildAdditionalContext(run), getServiceDefinitions({ clusterId: run.clusterId, @@ -55,6 +52,7 @@ export const run = async (run: Run) => { ...run, status: "running", }), + getClusterInternalTools(run.clusterId), ]); const allAvailableTools: string[] = []; @@ -73,9 +71,9 @@ export const run = async (run: Run) => { serviceFunctionEmbeddingId({ serviceName: service.name, functionName: f.name, - }), + }) ); - }), + }) ); const mockToolsMap: Record = @@ -105,7 +103,7 @@ export const run = async (run: Run) => { const serviceFunctionDetails = await embeddableServiceFunction.getEntity( run.clusterId, "service-function", - toolCall.toolName, + toolCall.toolName ); if (serviceFunctionDetails) { @@ -130,7 +128,7 @@ export const run = async (run: Run) => { // optimistically embed the next search query // this is not critical to the workflow, so we can do it in the background embedSearchQuery( - state.messages.map((m) => JSON.stringify(m.data)).join(" "), + state.messages.map((m) => JSON.stringify(m.data)).join(" ") ); } @@ -165,7 +163,7 @@ export const run = async (run: Run) => { }, { recursionLimit: 100, - }, + } ); const parsedOutput = z @@ -237,7 +235,7 @@ async function findRelatedFunctionTools(workflow: Run, search: string) { workflow.clusterId, "service-function", search, - 30, + 30 ) : []; @@ -253,7 +251,7 @@ async function findRelatedFunctionTools(workflow: Run, search: string) { const metadata = await getToolMetadata( workflow.clusterId, toolDetails.serviceName, - toolDetails.functionName, + toolDetails.functionName ); const contextArr = []; @@ -264,7 +262,9 @@ async function findRelatedFunctionTools(workflow: Run, search: string) { if (metadata?.resultKeys) { contextArr.push( - `${metadata.resultKeys.slice(0, 10).map((k) => k.key)}`, + `${metadata.resultKeys + .slice(0, 10) + .map((k) => k.key)}` ); } @@ -273,7 +273,7 @@ async function findRelatedFunctionTools(workflow: Run, search: string) { functionName: toolDetails.functionName, toolContext: contextArr.join("\n\n"), }; - }), + }) ); const selectedTools = relatedTools.map((toolDetails) => @@ -284,13 +284,13 @@ async function findRelatedFunctionTools(workflow: Run, search: string) { toolContexts.find( (c) => c?.serviceName === toolDetails.serviceName && - c?.functionName === toolDetails.functionName, + c?.functionName === toolDetails.functionName )?.toolContext, ] .filter(Boolean) .join("\n\n"), schema: toolDetails.schema, - }), + }) ); return selectedTools; @@ -327,12 +327,12 @@ export const findRelevantTools = async (state: WorkflowAgentState) => { const serviceFunctionDetails = await embeddableServiceFunction.getEntity( workflow.clusterId, "service-function", - tool, + tool ); if (!serviceFunctionDetails) { throw new Error( - `Tool ${tool} not found in cluster ${workflow.clusterId}`, + `Tool ${tool} not found in cluster ${workflow.clusterId}` ); } @@ -340,19 +340,20 @@ export const findRelevantTools = async (state: WorkflowAgentState) => { buildAbstractServiceFunctionTool({ ...serviceFunctionDetails, schema: serviceFunctionDetails.schema, - }), + }) ); } } else { const found = await findRelatedFunctionTools( workflow, - state.messages.map((m) => JSON.stringify(m.data)).join(" "), + state.messages.map((m) => JSON.stringify(m.data)).join(" ") ); tools.push(...found); - const accessKnowledgeArtifactsTool = - await buildAccessKnowledgeArtifacts(workflow); + const accessKnowledgeArtifactsTool = await buildAccessKnowledgeArtifacts( + workflow + ); // When tools are not specified, attach all internalTools tools.push(accessKnowledgeArtifactsTool); @@ -369,7 +370,7 @@ export const buildMockTools = async (workflow: Run) => { if (!workflow.test) { logger.warn( - "Workflow is not marked as test enabled but contains mocks. Mocks will be ignored.", + "Workflow is not marked as test enabled but contains mocks. Mocks will be ignored." ); return mocks; } @@ -397,7 +398,7 @@ export const buildMockTools = async (workflow: Run) => { } const serviceDefinition = serviceDefinitions.find( - (sd) => sd.name === serviceName, + (sd) => sd.name === serviceName ); if (!serviceDefinition) { @@ -406,13 +407,13 @@ export const buildMockTools = async (workflow: Run) => { { key, serviceName, - }, + } ); continue; } const functionDefinition = serviceDefinition.functions?.find( - (f) => f.name === functionName, + (f) => f.name === functionName ); if (!functionDefinition) { @@ -422,7 +423,7 @@ export const buildMockTools = async (workflow: Run) => { key, serviceName, functionName, - }, + } ); continue; } diff --git a/control-plane/src/modules/workflows/agent/tools/cluster-internal-tools.ts b/control-plane/src/modules/workflows/agent/tools/cluster-internal-tools.ts new file mode 100644 index 00000000..73f95e92 --- /dev/null +++ b/control-plane/src/modules/workflows/agent/tools/cluster-internal-tools.ts @@ -0,0 +1,45 @@ +import { DynamicStructuredTool } from "@langchain/core/tools"; +import { Run } from "../../workflows"; +import { + buildAccessKnowledgeArtifacts, + ACCESS_KNOWLEDGE_ARTIFACTS_TOOL_NAME, +} from "./knowledge-artifacts"; +import { createCache } from "../../../../utilities/cache"; +import { getClusterDetails } from "../../../management"; + +const clusterSettingsCache = createCache<{ + enableKnowledgebase: boolean; +}>(Symbol("cluster-settings")); + +const CACHE_TTL = 60 * 2; // 2 minutes + +export type InternalToolBuilder = ( + workflow: Run, + toolCallId: string +) => DynamicStructuredTool | Promise; + +export const getClusterInternalTools = async ( + clusterId: string +): Promise> => { + const cacheKey = `cluster:${clusterId}`; + + let settings = clusterSettingsCache.get(cacheKey); + + if (!settings) { + // Get cluster settings + const cluster = await getClusterDetails({ clusterId }); + settings = { + enableKnowledgebase: cluster.enableKnowledgebase, + }; + clusterSettingsCache.set(cacheKey, settings, CACHE_TTL); + } + + const tools: Record = {}; + + // Only include knowledge artifacts tool if enabled for cluster + if (settings.enableKnowledgebase) { + tools[ACCESS_KNOWLEDGE_ARTIFACTS_TOOL_NAME] = buildAccessKnowledgeArtifacts; + } + + return tools; +};