Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Security solution] naturalLanguageToEsql Tool added to default assistant graph #192042

Merged
merged 13 commits into from
Sep 18, 2024
1 change: 1 addition & 0 deletions x-pack/plugins/elastic_assistant/kibana.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"ml",
"taskManager",
"licensing",
"inference",
"spaces",
"security"
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export const createMockClients = () => {
getAIAssistantAnonymizationFieldsDataClient: dataClientMock.create(),
getSpaceId: jest.fn(),
getCurrentUser: jest.fn(),
inference: jest.fn(),
},
savedObjectsClient: core.savedObjects.client,

Expand Down Expand Up @@ -130,6 +131,7 @@ const createElasticAssistantRequestContextMock = (
getCurrentUser: jest.fn(),
getServerBasePath: jest.fn(),
getSpaceId: jest.fn(),
inference: { getClient: jest.fn() },
core: clients.core,
telemetry: clients.elasticAssistant.telemetry,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain';
import { ExecuteConnectorRequestBody, Message, Replacements } from '@kbn/elastic-assistant-common';
import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server';
import { PublicMethodsOf } from '@kbn/utility-types';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import { ResponseBody } from '../types';
import type { AssistantTool } from '../../../types';
import { ElasticsearchStore } from '../elasticsearch_store/elasticsearch_store';
Expand Down Expand Up @@ -47,6 +48,7 @@ export interface AgentExecutorParams<T extends boolean> {
langChainMessages: BaseMessage[];
llmType?: string;
logger: Logger;
inference: InferenceServerStart;
onNewReplacements?: (newReplacements: Replacements) => void;
replacements: Replacements;
isStream?: T;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
dataClients,
esClient,
esStore,
inference,
langChainMessages,
llmType,
logger: parentLogger,
Expand Down Expand Up @@ -107,7 +108,9 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
alertsIndexPattern,
anonymizationFields,
chain,
connectorId,
esClient,
inference,
isEnabledKnowledgeBase,
kbDataClient: dataClients?.kbDataClient,
logger,
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugins/elastic_assistant/server/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ export class ElasticAssistantPlugin

return {
actions: plugins.actions,
inference: plugins.inference,
getRegisteredFeatures: (pluginName: string) => {
return appContextService.getRegisteredFeatures(pluginName);
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ const formatAssistantToolParams = ({
ExecuteConnectorRequestBody | AttackDiscoveryPostRequestBody
>;
size: number;
}): AssistantToolParams => ({
}): Omit<AssistantToolParams, 'connectorId' | 'inference'> => ({
alertsIndexPattern,
anonymizationFields: [...(anonymizationFields ?? []), ...REQUIRED_FOR_ATTACK_DISCOVERY],
isEnabledKnowledgeBase: false, // not required for attack discovery
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ export const chatCompleteRoute = (
const ctx = await context.resolve(['core', 'elasticAssistant', 'licensing']);
const logger: Logger = ctx.elasticAssistant.logger;
telemetry = ctx.elasticAssistant.telemetry;
const inference = ctx.elasticAssistant.inference;

// Perform license and authenticated user checks
const checkResponse = performChecks({
Expand Down Expand Up @@ -195,6 +196,7 @@ export const chatCompleteRoute = (
context: ctx,
getElser,
logger,
inference,
messages: messages ?? [],
onLlmResponse,
onNewReplacements,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ export const postEvaluateRoute = (
// Default ELSER model
const elserId = await getElser();

const inference = ctx.elasticAssistant.inference;

// Data clients
const anonymizationFieldsDataClient =
(await assistantContext.getAIAssistantAnonymizationFieldsDataClient()) ?? undefined;
Expand Down Expand Up @@ -260,6 +262,8 @@ export const postEvaluateRoute = (
alertsIndexPattern,
// onNewReplacements,
replacements,
inference,
connectorId: connector.id,
size,
};

Expand Down
4 changes: 4 additions & 0 deletions x-pack/plugins/elastic_assistant/server/routes/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import { AwaitedProperties, PublicMethodsOf } from '@kbn/utility-types';
import { ActionsClient } from '@kbn/actions-plugin/server';
import { AssistantFeatureKey } from '@kbn/elastic-assistant-common/impl/capabilities';
import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import { AIAssistantKnowledgeBaseDataClient } from '../ai_assistant_data_clients/knowledge_base';
import { FindResponse } from '../ai_assistant_data_clients/find';
import { EsPromptsSchema } from '../ai_assistant_data_clients/prompts/types';
Expand Down Expand Up @@ -321,6 +322,7 @@ export interface LangChainExecuteParams {
telemetry: AnalyticsServiceSetup;
actionTypeId: string;
connectorId: string;
inference: InferenceServerStart;
conversationId?: string;
context: AwaitedProperties<
Pick<ElasticAssistantRequestHandlerContext, 'elasticAssistant' | 'licensing' | 'core'>
Expand Down Expand Up @@ -349,6 +351,7 @@ export const langChainExecute = async ({
connectorId,
context,
actionsClient,
inference,
request,
logger,
conversationId,
Expand Down Expand Up @@ -418,6 +421,7 @@ export const langChainExecute = async ({
connectorId,
esClient,
esStore,
inference,
isStream,
llmType: getLlmType(actionTypeId),
langChainMessages,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ export const postActionsConnectorExecuteRoute = (

// get the actions plugin start contract from the request context:
const actions = ctx.elasticAssistant.actions;
const inference = ctx.elasticAssistant.inference;
const actionsClient = await actions.getActionsClientWithRequest(request);

const conversationsDataClient =
Expand Down Expand Up @@ -132,6 +133,7 @@ export const postActionsConnectorExecuteRoute = (
context: ctx,
getElser,
logger,
inference,
messages: (newMessage ? [newMessage] : messages) ?? [],
onLlmResponse,
onNewReplacements,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ export class RequestContextFactory implements IRequestContextFactory {
return appContextService.getRegisteredFeatures(pluginName);
},

inference: startPlugins.inference,

telemetry: core.analytics,

// Note: Due to plugin lifecycle and feature flag registration timing, we need to pass in the feature flag here
Expand Down
9 changes: 9 additions & 0 deletions x-pack/plugins/elastic_assistant/server/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server';

import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import { AttackDiscoveryDataClient } from './ai_assistant_data_clients/attack_discovery';
import { AIAssistantConversationsDataClient } from './ai_assistant_data_clients/conversations';
import type { GetRegisteredFeatures, GetRegisteredTools } from './services/app_context';
Expand All @@ -64,6 +65,10 @@ export interface ElasticAssistantPluginStart {
* Actions plugin start contract.
*/
actions: ActionsPluginStart;
/**
* Inference plugin start contract.
*/
inference: InferenceServerStart;
/**
* Register features to be used by the elastic assistant.
*
Expand Down Expand Up @@ -104,6 +109,7 @@ export interface ElasticAssistantPluginSetupDependencies {
}
export interface ElasticAssistantPluginStartDependencies {
actions: ActionsPluginStart;
inference: InferenceServerStart;
spaces?: SpacesPluginStart;
security: SecurityServiceStart;
licensing: LicensingPluginStart;
Expand All @@ -125,6 +131,7 @@ export interface ElasticAssistantApiRequestHandlerContext {
getAttackDiscoveryDataClient: () => Promise<AttackDiscoveryDataClient | null>;
getAIAssistantPromptsDataClient: () => Promise<AIAssistantDataClient | null>;
getAIAssistantAnonymizationFieldsDataClient: () => Promise<AIAssistantDataClient | null>;
inference: InferenceServerStart;
telemetry: AnalyticsServiceSetup;
}
/**
Expand Down Expand Up @@ -228,7 +235,9 @@ export type AssistantToolLlm =
export interface AssistantToolParams {
alertsIndexPattern?: string;
anonymizationFields?: AnonymizationFieldResponse[];
inference?: InferenceServerStart;
isEnabledKnowledgeBase: boolean;
connectorId?: string;
chain?: RetrievalQAChain;
esClient: ElasticsearchClient;
kbDataClient?: AIAssistantKnowledgeBaseDataClient;
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugins/elastic_assistant/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
"@kbn/apm-utils",
"@kbn/std",
"@kbn/zod",
"@kbn/inference-plugin"
],
"exclude": [
"target/**/*",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ export const allowedExperimentalValues = Object.freeze({
*/
assistantBedrockChat: true,

/**
* Enables the NaturalLanguageESQLTool and disables the ESQLKnowledgeBaseTool, introduced in `8.16.0`.
*/
assistantNaturalLanguageESQLTool: false,

/**
* Enables the Managed User section inside the new user details flyout.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { DynamicStructuredTool } from '@langchain/core/tools';
import { z } from '@kbn/zod';
import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
import { lastValueFrom } from 'rxjs';
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
import { APP_UI_ID } from '../../../../common';

export type ESQLToolParams = AssistantToolParams;

const TOOL_NAME = 'NaturalLanguageESQLTool';

const toolDetails = {
id: 'nl-to-esql-tool',
name: TOOL_NAME,
description: `You MUST use the "${TOOL_NAME}" function when the user wants to:
- run any arbitrary query
- breakdown or filter ES|QL queries that are displayed on the current page
- convert queries from another language to ES|QL
- asks general questions about ES|QL

DO NOT UNDER ANY CIRCUMSTANCES generate ES|QL queries or explain anything about the ES|QL query language yourself.
DO NOT UNDER ANY CIRCUMSTANCES try to correct an ES|QL query yourself - always use the "${TOOL_NAME}" function for this.

Even if the "${TOOL_NAME}" function was used before that, follow it up with the "${TOOL_NAME}" function. If a query fails, do not attempt to correct it yourself. Again you should call the "${TOOL_NAME}" function,
even if it has been called before.`,
};

export const NL_TO_ESQL_TOOL: AssistantTool = {
...toolDetails,
sourceRegister: APP_UI_ID,
isSupported: (params: ESQLToolParams): params is ESQLToolParams => {
const { chain, isEnabledKnowledgeBase, modelExists } = params;
return isEnabledKnowledgeBase && modelExists && chain != null;
},
getTool(params: ESQLToolParams) {
if (!this.isSupported(params)) return null;

const { connectorId, inference, logger, request } = params as ESQLToolParams;
if (inference == null || connectorId == null) return null;

const callNaturalLanguageToEsql = async (question: string) => {
return lastValueFrom(
naturalLanguageToEsql({
client: inference.getClient({ request }),
connectorId,
input: question,
logger: {
debug: (source) => {
logger.debug(typeof source === 'function' ? source() : source);
},
},
})
);
};

return new DynamicStructuredTool({
name: toolDetails.name,
description: toolDetails.description,
schema: z.object({
question: z.string().describe(`The user's exact question about ESQL`),
}),
func: async (input) => {
const generateEvent = await callNaturalLanguageToEsql(input.question);
const answer = generateEvent.content ?? 'An error occurred in the tool';

logger.debug(`Received response from NL to ESQL tool: ${answer}`);
return answer;
},
tags: ['esql', 'query-generation', 'knowledge-base'],
// TODO: Remove after ZodAny is fixed https://github.com/langchain-ai/langchainjs/blob/main/langchain-core/src/tools.ts
}) as unknown as DynamicStructuredTool;
},
};
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ describe('getAssistantTools', () => {
});

it('should return an array of applicable tools', () => {
const tools = getAssistantTools();
const tools = getAssistantTools(true);

const minExpectedTools = 3; // 3 tools are currently implemented

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@

import type { AssistantTool } from '@kbn/elastic-assistant-plugin/server';

import { ALERT_COUNTS_TOOL } from './alert_counts/alert_counts_tool';
import { ESQL_KNOWLEDGE_BASE_TOOL } from './esql_language_knowledge_base/esql_language_knowledge_base_tool';
import { NL_TO_ESQL_TOOL } from './esql_language_knowledge_base/nl_to_esql_tool';
import { ALERT_COUNTS_TOOL } from './alert_counts/alert_counts_tool';
import { OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL } from './open_and_acknowledged_alerts/open_and_acknowledged_alerts_tool';
import { ATTACK_DISCOVERY_TOOL } from './attack_discovery/attack_discovery_tool';
import { KNOWLEDGE_BASE_RETRIEVAL_TOOL } from './knowledge_base/knowledge_base_retrieval_tool';
import { KNOWLEDGE_BASE_WRITE_TOOL } from './knowledge_base/knowledge_base_write_tool';

export const getAssistantTools = (): AssistantTool[] => [
export const getAssistantTools = (naturalLanguageESQLToolEnabled: boolean): AssistantTool[] => [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: consider making naturalLanguageESQLToolEnabled optional and default to false

ALERT_COUNTS_TOOL,
ATTACK_DISCOVERY_TOOL,
ESQL_KNOWLEDGE_BASE_TOOL,
naturalLanguageESQLToolEnabled ? NL_TO_ESQL_TOOL : ESQL_KNOWLEDGE_BASE_TOOL,
KNOWLEDGE_BASE_RETRIEVAL_TOOL,
KNOWLEDGE_BASE_WRITE_TOOL,
OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL,
Expand Down
5 changes: 4 additions & 1 deletion x-pack/plugins/security_solution/server/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,10 @@ export class Plugin implements ISecuritySolutionPlugin {
this.licensing$ = plugins.licensing.license$;

// Assistant Tool and Feature Registration
plugins.elasticAssistant.registerTools(APP_UI_ID, getAssistantTools());
plugins.elasticAssistant.registerTools(
APP_UI_ID,
getAssistantTools(config.experimentalFeatures.assistantNaturalLanguageESQLTool)
);
plugins.elasticAssistant.registerFeatures(APP_UI_ID, {
assistantBedrockChat: config.experimentalFeatures.assistantBedrockChat,
assistantKnowledgeBaseByDefault: config.experimentalFeatures.assistantKnowledgeBaseByDefault,
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugins/security_solution/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -225,5 +225,6 @@
"@kbn/cloud-security-posture-common",
"@kbn/entityManager-plugin",
"@kbn/entities-schema",
"@kbn/inference-plugin",
]
}