Skip to content

Commit

Permalink
[Security Assistant] Product documentation tool (#199694)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephmilovic authored Dec 4, 2024
1 parent f8860e9 commit e099b31
Show file tree
Hide file tree
Showing 22 changed files with 346 additions and 10 deletions.
4 changes: 3 additions & 1 deletion x-pack/plugins/elastic_assistant/kibana.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
"ml",
"taskManager",
"licensing",
"llmTasks",
"inference",
"productDocBase",
"spaces",
"security"
]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ export const createMockClients = () => {
getSpaceId: jest.fn(),
getCurrentUser: jest.fn(),
inference: jest.fn(),
llmTasks: jest.fn(),
},
savedObjectsClient: core.savedObjects.client,

Expand Down Expand Up @@ -145,6 +146,7 @@ const createElasticAssistantRequestContextMock = (
getServerBasePath: jest.fn(),
getSpaceId: jest.fn().mockReturnValue('default'),
inference: { getClient: jest.fn() },
llmTasks: { retrieveDocumentationAvailable: jest.fn(), retrieveDocumentation: jest.fn() },
core: clients.core,
telemetry: clients.elasticAssistant.telemetry,
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { ensureProductDocumentationInstalled } from './helpers';
import { loggerMock } from '@kbn/logging-mocks';

const mockLogger = loggerMock.create();
const mockProductDocManager = {
getStatus: jest.fn(),
install: jest.fn(),
uninstall: jest.fn(),
update: jest.fn(),
};

describe('helpers', () => {
describe('ensureProductDocumentationInstalled', () => {
beforeEach(() => {
jest.clearAllMocks();
});

it('should install product documentation if not installed', async () => {
mockProductDocManager.getStatus.mockResolvedValue({ status: 'uninstalled' });
mockProductDocManager.install.mockResolvedValue(null);

await ensureProductDocumentationInstalled(mockProductDocManager, mockLogger);

expect(mockProductDocManager.getStatus).toHaveBeenCalled();
expect(mockLogger.debug).toHaveBeenCalledWith(
'Installing product documentation for AIAssistantService'
);
expect(mockProductDocManager.install).toHaveBeenCalled();
expect(mockLogger.debug).toHaveBeenNthCalledWith(
2,
'Successfully installed product documentation for AIAssistantService'
);
});

it('should not install product documentation if already installed', async () => {
mockProductDocManager.getStatus.mockResolvedValue({ status: 'installed' });

await ensureProductDocumentationInstalled(mockProductDocManager, mockLogger);

expect(mockProductDocManager.getStatus).toHaveBeenCalled();
expect(mockProductDocManager.install).not.toHaveBeenCalled();
expect(mockLogger.debug).not.toHaveBeenCalledWith(
'Installing product documentation for AIAssistantService'
);
});
it('should log a warning if install fails', async () => {
mockProductDocManager.getStatus.mockResolvedValue({ status: 'not_installed' });
mockProductDocManager.install.mockRejectedValue(new Error('Install failed'));

await ensureProductDocumentationInstalled(mockProductDocManager, mockLogger);

expect(mockProductDocManager.getStatus).toHaveBeenCalled();
expect(mockProductDocManager.install).toHaveBeenCalled();

expect(mockLogger.warn).toHaveBeenCalledWith(
'Failed to install product documentation for AIAssistantService: Install failed'
);
});

it('should log a warning if getStatus fails', async () => {
mockProductDocManager.getStatus.mockRejectedValue(new Error('Status check failed'));

await ensureProductDocumentationInstalled(mockProductDocManager, mockLogger);

expect(mockProductDocManager.getStatus).toHaveBeenCalled();
expect(mockLogger.warn).toHaveBeenCalledWith(
'Failed to get status of product documentation installation for AIAssistantService: Status check failed'
);
expect(mockProductDocManager.install).not.toHaveBeenCalled();
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-ser
import type { MlPluginSetup } from '@kbn/ml-plugin/server';
import { DeleteByQueryRequest } from '@elastic/elasticsearch/lib/api/types';
import { i18n } from '@kbn/i18n';
import { ProductDocBaseStartContract } from '@kbn/product-doc-base-plugin/server';
import type { Logger } from '@kbn/logging';
import { getResourceName } from '.';
import { knowledgeBaseIngestPipeline } from '../ai_assistant_data_clients/knowledge_base/ingest_pipeline';
import { GetElser } from '../types';
Expand Down Expand Up @@ -141,3 +143,25 @@ const ESQL_QUERY_GENERATION_TITLE = i18n.translate(
defaultMessage: 'ES|QL Query Generation',
}
);

export const ensureProductDocumentationInstalled = async (
productDocManager: ProductDocBaseStartContract['management'],
logger: Logger
) => {
try {
const { status } = await productDocManager.getStatus();
if (status !== 'installed') {
logger.debug(`Installing product documentation for AIAssistantService`);
try {
await productDocManager.install();
logger.debug(`Successfully installed product documentation for AIAssistantService`);
} catch (e) {
logger.warn(`Failed to install product documentation for AIAssistantService: ${e.message}`);
}
}
} catch (e) {
logger.warn(
`Failed to get status of product documentation installation for AIAssistantService: ${e.message}`
);
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ describe('AI Assistant Service', () => {
kibanaVersion: '8.8.0',
ml,
taskManager: taskManagerMock.createSetup(),
productDocManager: Promise.resolve({
getStatus: jest.fn(),
install: jest.fn(),
update: jest.fn(),
uninstall: jest.fn(),
}),
};
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import type { TaskManagerSetupContract } from '@kbn/task-manager-plugin/server';
import type { MlPluginSetup } from '@kbn/ml-plugin/server';
import { Subject } from 'rxjs';
import { LicensingApiRequestHandlerContext } from '@kbn/licensing-plugin/server';
import { ProductDocBaseStartContract } from '@kbn/product-doc-base-plugin/server';
import { attackDiscoveryFieldMap } from '../lib/attack_discovery/persistence/field_maps_configuration/field_maps_configuration';
import { defendInsightsFieldMap } from '../ai_assistant_data_clients/defend_insights/field_maps_configuration';
import { getDefaultAnonymizationFields } from '../../common/anonymization';
Expand All @@ -35,7 +36,12 @@ import {
} from '../ai_assistant_data_clients/knowledge_base';
import { AttackDiscoveryDataClient } from '../lib/attack_discovery/persistence';
import { DefendInsightsDataClient } from '../ai_assistant_data_clients/defend_insights';
import { createGetElserId, createPipeline, pipelineExists } from './helpers';
import {
createGetElserId,
createPipeline,
ensureProductDocumentationInstalled,
pipelineExists,
} from './helpers';
import { hasAIAssistantLicense } from '../routes/helpers';

const TOTAL_FIELDS_LIMIT = 2500;
Expand All @@ -51,6 +57,7 @@ export interface AIAssistantServiceOpts {
ml: MlPluginSetup;
taskManager: TaskManagerSetupContract;
pluginStop$: Subject<void>;
productDocManager: Promise<ProductDocBaseStartContract['management']>;
}

export interface CreateAIAssistantClientParams {
Expand Down Expand Up @@ -87,6 +94,7 @@ export class AIAssistantService {
private initPromise: Promise<InitializationPromise>;
private isKBSetupInProgress: boolean = false;
private hasInitializedV2KnowledgeBase: boolean = false;
private productDocManager?: ProductDocBaseStartContract['management'];

constructor(private readonly options: AIAssistantServiceOpts) {
this.initialized = false;
Expand Down Expand Up @@ -129,6 +137,13 @@ export class AIAssistantService {
this.initPromise,
this.installAndUpdateSpaceLevelResources.bind(this)
);
options.productDocManager
.then((productDocManager) => {
this.productDocManager = productDocManager;
})
.catch((error) => {
this.options.logger.warn(`Failed to initialize productDocManager: ${error.message}`);
});
}

public isInitialized() {
Expand Down Expand Up @@ -183,6 +198,11 @@ export class AIAssistantService {
this.options.logger.debug(`Initializing resources for AIAssistantService`);
const esClient = await this.options.elasticsearchClientPromise;

if (this.productDocManager) {
// install product documentation without blocking other resources
void ensureProductDocumentationInstalled(this.productDocManager, this.options.logger);
}

await this.conversationsDataStream.install({
esClient,
logger: this.options.logger,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import { PublicMethodsOf } from '@kbn/utility-types';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import { AnalyticsServiceSetup } from '@kbn/core-analytics-server';
import { TelemetryParams } from '@kbn/langchain/server/tracers/telemetry/telemetry_tracer';
import type { LlmTasksPluginStart } from '@kbn/llm-tasks-plugin/server';
import { ResponseBody } from '../types';
import type { AssistantTool } from '../../../types';
import { AIAssistantKnowledgeBaseDataClient } from '../../../ai_assistant_data_clients/knowledge_base';
Expand Down Expand Up @@ -45,10 +46,11 @@ export interface AgentExecutorParams<T extends boolean> {
dataClients?: AssistantDataClients;
esClient: ElasticsearchClient;
langChainMessages: BaseMessage[];
llmTasks?: LlmTasksPluginStart;
llmType?: string;
isOssModel?: boolean;
logger: Logger;
inference: InferenceServerStart;
logger: Logger;
onNewReplacements?: (newReplacements: Replacements) => void;
replacements: Replacements;
isStream?: T;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
esClient,
inference,
langChainMessages,
llmTasks,
llmType,
isOssModel,
logger: parentLogger,
Expand Down Expand Up @@ -106,6 +107,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
inference,
isEnabledKnowledgeBase,
kbDataClient: dataClients?.kbDataClient,
llmTasks,
logger,
onNewReplacements,
replacements,
Expand Down
3 changes: 3 additions & 0 deletions x-pack/plugins/elastic_assistant/server/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ export class ElasticAssistantPlugin
elasticsearchClientPromise: core
.getStartServices()
.then(([{ elasticsearch }]) => elasticsearch.client.asInternalUser),
productDocManager: core
.getStartServices()
.then(([_, { productDocBase }]) => productDocBase.management),
pluginStop$: this.pluginStop$,
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ const mockContext = {
getRegisteredFeatures: jest.fn(() => defaultAssistantFeatures),
logger: loggingSystemMock.createLogger(),
telemetry: { ...coreMock.createSetup().analytics, reportEvent },
llmTasks: { retrieveDocumentationAvailable: jest.fn(), retrieveDocumentation: jest.fn() },
getCurrentUser: () => ({
username: 'user',
email: 'email',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ export const chatCompleteRoute = (
try {
telemetry = ctx.elasticAssistant.telemetry;
const inference = ctx.elasticAssistant.inference;
const productDocsAvailable =
(await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable()) ?? false;

// Perform license and authenticated user checks
const checkResponse = performChecks({
Expand Down Expand Up @@ -217,6 +219,7 @@ export const chatCompleteRoute = (
response,
telemetry,
responseLanguage: request.body.responseLanguage,
...(productDocsAvailable ? { llmTasks: ctx.elasticAssistant.llmTasks } : {}),
});
} catch (err) {
const error = transformError(err as Error);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ export const postEvaluateRoute = (
const esClient = ctx.core.elasticsearch.client.asCurrentUser;

const inference = ctx.elasticAssistant.inference;
const productDocsAvailable =
(await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable()) ?? false;

// Data clients
const anonymizationFieldsDataClient =
Expand Down Expand Up @@ -280,6 +282,7 @@ export const postEvaluateRoute = (
connectorId: connector.id,
size,
telemetry: ctx.elasticAssistant.telemetry,
...(productDocsAvailable ? { llmTasks: ctx.elasticAssistant.llmTasks } : {}),
};

const tools: StructuredTool[] = assistantTools.flatMap(
Expand Down
4 changes: 4 additions & 0 deletions x-pack/plugins/elastic_assistant/server/routes/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import { ActionsClient } from '@kbn/actions-plugin/server';
import { AssistantFeatureKey } from '@kbn/elastic-assistant-common/impl/capabilities';
import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import type { LlmTasksPluginStart } from '@kbn/llm-tasks-plugin/server';
import { INVOKE_ASSISTANT_SUCCESS_EVENT } from '../lib/telemetry/event_based_telemetry';
import { AIAssistantKnowledgeBaseDataClient } from '../ai_assistant_data_clients/knowledge_base';
import { FindResponse } from '../ai_assistant_data_clients/find';
Expand Down Expand Up @@ -215,6 +216,7 @@ export interface LangChainExecuteParams {
telemetry: AnalyticsServiceSetup;
actionTypeId: string;
connectorId: string;
llmTasks?: LlmTasksPluginStart;
inference: InferenceServerStart;
isOssModel?: boolean;
conversationId?: string;
Expand Down Expand Up @@ -246,6 +248,7 @@ export const langChainExecute = async ({
isOssModel,
context,
actionsClient,
llmTasks,
inference,
request,
logger,
Expand Down Expand Up @@ -301,6 +304,7 @@ export const langChainExecute = async ({
conversationId,
connectorId,
esClient,
llmTasks,
inference,
isStream,
llmType: getLlmType(actionTypeId),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ const mockContext = {
actions: {
getActionsClientWithRequest: jest.fn().mockResolvedValue(actionsClient),
},
llmTasks: { retrieveDocumentationAvailable: jest.fn(), retrieveDocumentation: jest.fn() },
getRegisteredTools: jest.fn(() => []),
getRegisteredFeatures: jest.fn(() => defaultAssistantFeatures),
logger: loggingSystemMock.createLogger(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ export const postActionsConnectorExecuteRoute = (
// get the actions plugin start contract from the request context:
const actions = ctx.elasticAssistant.actions;
const inference = ctx.elasticAssistant.inference;
const productDocsAvailable =
(await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable()) ?? false;
const actionsClient = await actions.getActionsClientWithRequest(request);
const connectors = await actionsClient.getBulk({ ids: [connectorId] });
const connector = connectors.length > 0 ? connectors[0] : undefined;
Expand Down Expand Up @@ -150,6 +152,7 @@ export const postActionsConnectorExecuteRoute = (
response,
telemetry,
systemPrompt,
...(productDocsAvailable ? { llmTasks: ctx.elasticAssistant.llmTasks } : {}),
});
} catch (err) {
logger.error(err);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export class RequestContextFactory implements IRequestContextFactory {
getRegisteredFeatures: (pluginName: string) => {
return appContextService.getRegisteredFeatures(pluginName);
},

llmTasks: startPlugins.llmTasks,
inference: startPlugins.inference,

telemetry: core.analytics,
Expand Down
Loading

0 comments on commit e099b31

Please sign in to comment.