Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AI Assistant] Use semantic_text for internal knowledge base #186499

Merged
merged 34 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
7f5d750
[AI Assistant] Use semantic_text for internal knowledge base
sorenlouv Jun 20, 2024
7e318df
Fix status endpoint
sorenlouv Jun 20, 2024
c23e54d
Merge branch 'main' of github.com:elastic/kibana into use-semantic-te…
sorenlouv Aug 5, 2024
6d709c3
Improve logging
sorenlouv Aug 5, 2024
6987fa5
Merge branch 'main' of github.com:elastic/kibana into use-semantic-te…
sorenlouv Aug 6, 2024
2552585
Rename `getModelId` to `getSearchConnectorModelId`
sorenlouv Aug 6, 2024
3e242ab
Support configurable model
sorenlouv Aug 7, 2024
847de5f
WIP
sorenlouv Aug 7, 2024
743d6fb
Merge branch 'main' of github.com:elastic/kibana into use-semantic-te…
sorenlouv Aug 20, 2024
ef3d51e
WIP
sorenlouv Aug 21, 2024
19d1ad9
Merge branch 'main' of github.com:elastic/kibana into use-semantic-te…
sorenlouv Aug 21, 2024
422024e
Add esarchive
sorenlouv Aug 21, 2024
c9d008c
Merge branch 'main' of github.com:elastic/kibana into use-semantic-te…
sorenlouv Nov 7, 2024
035f911
Fix api test
sorenlouv Nov 7, 2024
e567b66
Merge branch 'main' of github.com:elastic/kibana into use-semantic-te…
sorenlouv Nov 7, 2024
d77a411
Fix i18n
sorenlouv Nov 7, 2024
dac5d15
catch and log error
sorenlouv Nov 8, 2024
66dfb6f
fix serverless test
sorenlouv Nov 8, 2024
493b1e3
Add modelId as query param
sorenlouv Nov 12, 2024
36b6e3e
Fix request timeout
sorenlouv Nov 12, 2024
dddeb7e
Merge branch 'main' of github.com:elastic/kibana into use-semantic-te…
sorenlouv Nov 12, 2024
59c61f5
Merge branch 'main' into use-semantic-text-internal-kb
sorenlouv Nov 12, 2024
c6d48d4
Fix imports
sorenlouv Nov 12, 2024
60c0884
Re-add error handling when installing KB
sorenlouv Nov 12, 2024
2501c86
change modelName to modelId
sorenlouv Nov 12, 2024
a24ce1f
Address PR feedback
sorenlouv Nov 12, 2024
ac998fa
Fix type
sorenlouv Nov 12, 2024
17aaaa8
Fix functional test
sorenlouv Nov 13, 2024
36ba7f6
call deleteInferenceEndpoint consistently
sorenlouv Nov 13, 2024
6c23e5d
Merge branch 'main' into use-semantic-text-internal-kb
sorenlouv Nov 13, 2024
657c82a
Fix task manager test
sorenlouv Nov 13, 2024
3919ccd
Address feedback
sorenlouv Nov 13, 2024
97341e7
Remove `payload` timeout
sorenlouv Nov 13, 2024
0753b88
Delete previous endpoint during setup
sorenlouv Nov 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ export function registerContextFunction({
client,
functions,
resources,
isKnowledgeBaseAvailable,
}: FunctionRegistrationParameters & { isKnowledgeBaseAvailable: boolean }) {
isKnowledgeBaseReady,
}: FunctionRegistrationParameters & { isKnowledgeBaseReady: boolean }) {
functions.registerFunction(
{
name: CONTEXT_FUNCTION_NAME,
Expand Down Expand Up @@ -54,7 +54,7 @@ export function registerContextFunction({
...(dataWithinTokenLimit.length ? { data_on_screen: dataWithinTokenLimit } : {}),
};

if (!isKnowledgeBaseAvailable) {
if (!isKnowledgeBaseReady) {
return { content };
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ export const registerFunctions: RegistrationCallback = async ({
}.
If the user asks how to change the language, reply in the same language the user asked in.`);

const { ready: isReady } = await client.getKnowledgeBaseStatus();
const { ready: isKnowledgeBaseReady } = await client.getKnowledgeBaseStatus();

functions.registerInstruction(({ availableFunctionNames }) => {
const instructions: string[] = [];
Expand All @@ -83,7 +83,7 @@ export const registerFunctions: RegistrationCallback = async ({
Data that is compact enough automatically gets included in the response for the "${CONTEXT_FUNCTION_NAME}" function.`);
}

if (isReady) {
if (isKnowledgeBaseReady) {
if (availableFunctionNames.includes(SUMMARIZE_FUNCTION_NAME)) {
instructions.push(`You can use the "${SUMMARIZE_FUNCTION_NAME}" function to store new information you have learned in a knowledge database.
Only use this function when the user asks for it.
Expand All @@ -103,11 +103,11 @@ export const registerFunctions: RegistrationCallback = async ({
return instructions.map((instruction) => dedent(instruction));
});

if (isReady) {
if (isKnowledgeBaseReady) {
registerSummarizationFunction(registrationParameters);
}

registerContextFunction({ ...registrationParameters, isKnowledgeBaseAvailable: isReady });
registerContextFunction({ ...registrationParameters, isKnowledgeBaseReady });

registerElasticsearchFunction(registrationParameters);
const request = registrationParameters.resources.request;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ export class ObservabilityAIAssistantPlugin
}) as ObservabilityAIAssistantRouteHandlerResources['plugins'];

// Using once to make sure the same model ID is used during service init and Knowledge base setup
const getModelId = once(async () => {
const getSearchConnectorModelId = once(async () => {
const configModelId = this.config.modelId;
if (configModelId) {
return configModelId;
Expand Down Expand Up @@ -156,7 +156,7 @@ export class ObservabilityAIAssistantPlugin
logger: this.logger.get('service'),
core,
taskManager: plugins.taskManager,
getModelId,
getSearchConnectorModelId,
}));

service.register(registerFunctions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@
* 2.0.
*/

import type {
MlDeploymentAllocationState,
MlDeploymentState,
} from '@elastic/elasticsearch/lib/api/types';
import { notImplemented } from '@hapi/boom';
import { nonEmptyStringRt, toBooleanRt } from '@kbn/io-ts-utils';
import * as t from 'io-ts';
import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route';
import { InferenceEndpointResponse } from '../../service/create_inference_endpoint';
import {
Instruction,
KnowledgeBaseEntry,
Expand All @@ -25,22 +22,21 @@ const getKnowledgeBaseStatus = createObservabilityAIAssistantServerRoute({
options: {
tags: ['access:ai_assistant'],
},
handler: async (
resources
): Promise<{
ready: boolean;
error?: any;
deployment_state?: MlDeploymentState;
allocation_state?: MlDeploymentAllocationState;
model_name?: string;
}> => {
const client = await resources.service.getClient({ request: resources.request });
handler: async ({
service,
request,
}): Promise<
Partial<InferenceEndpointResponse['endpoints'][0]> & {
ready: boolean;
}
> => {
const client = await service.getClient({ request });

if (!client) {
throw notImplemented();
}

return await client.getKnowledgeBaseStatus();
return client.getKnowledgeBaseStatus();
},
});

Expand All @@ -52,16 +48,48 @@ const setupKnowledgeBase = createObservabilityAIAssistantServerRoute({
idleSocket: 20 * 60 * 1000, // 20 minutes
},
},
handler: async (resources): Promise<{}> => {
handler: async (resources): Promise<unknown> => {
const client = await resources.service.getClient({ request: resources.request });

if (!client) {
throw notImplemented();
}

await client.setupKnowledgeBase();
return await client.setupKnowledgeBase();
},
});

// Wipes the knowledge base: delegates to the client's resetKnowledgeBase and
// reports a simple success result once the reset completes.
const resetKnowledgeBase = createObservabilityAIAssistantServerRoute({
  endpoint: 'POST /internal/observability_ai_assistant/kb/reset',
  options: {
    tags: ['access:ai_assistant'],
  },
  handler: async ({ service, request }): Promise<{ result: string }> => {
    const client = await service.getClient({ request });

    if (!client) {
      throw notImplemented();
    }

    await client.resetKnowledgeBase();

    return { result: 'success' };
  },
});

// Kicks off migration of existing knowledge-base entries to semantic_text.
const semanticTextMigrationKnowledgeBase = createObservabilityAIAssistantServerRoute({
  endpoint: 'POST /internal/observability_ai_assistant/kb/semantic_text_migration',
  options: {
    tags: ['access:ai_assistant'],
  },
  handler: async (resources): Promise<void> => {
    const client = await resources.service.getClient({ request: resources.request });

    if (!client) {
      throw notImplemented();
    }

    // Fix: a stray `return {};` (diff residue) preceded this call, making the
    // migration unreachable and violating the declared Promise<void> return.
    return client.migrateKnowledgeBaseToSemanticText();
  },
});

Expand Down Expand Up @@ -268,7 +296,9 @@ const importKnowledgeBaseEntries = createObservabilityAIAssistantServerRoute({
});

export const knowledgeBaseRoutes = {
...semanticTextMigrationKnowledgeBase,
...setupKnowledgeBase,
...resetKnowledgeBase,
...getKnowledgeBaseStatus,
...getKnowledgeBaseEntries,
...saveKnowledgeBaseUserInstruction,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ import {
LangtraceServiceProvider,
withLangtraceChatCompleteSpan,
} from './operators/with_langtrace_chat_complete_span';
import { runSemanticTextKnowledgeBaseMigration } from '../task_manager_definitions/register_migrate_knowledge_base_entries_task';

const MAX_FUNCTION_CALLS = 8;

Expand Down Expand Up @@ -724,11 +725,25 @@ export class ObservabilityAIAssistantClient {
};

// Reports the knowledge base's readiness via the ELSER model status.
getKnowledgeBaseStatus = () => {
  // Fix: an unreachable duplicate `return ...status()` (leftover from the
  // rename to getElserModelStatus) preceded this call; only the renamed
  // call is kept.
  return this.dependencies.knowledgeBaseService.getElserModelStatus();
};

// Installs/sets up the knowledge base using this client's ES client.
setupKnowledgeBase = () => {
  // Fix: a stale zero-argument `return ...setup()` (diff residue from before
  // setup took an esClient) preceded this call and short-circuited it; the
  // esClient-aware call is the intended behavior.
  const { esClient } = this.dependencies;
  return this.dependencies.knowledgeBaseService.setup(esClient);
};

// Clears the knowledge base through the KB service, passing this client's
// ES client so the service can delete the backing resources.
resetKnowledgeBase = () => {
  return this.dependencies.knowledgeBaseService.reset(this.dependencies.esClient);
};

// Runs the semantic_text migration over existing KB entries, wiring in the
// internal-user ES client, logger, and KB service from this client's deps.
migrateKnowledgeBaseToSemanticText = () => {
  const { esClient, logger, knowledgeBaseService } = this.dependencies;
  return runSemanticTextKnowledgeBaseMigration({
    esClient: esClient.asInternalUser,
    logger,
    kbService: knowledgeBaseService,
  });
};

addKnowledgeBaseEntry = async ({
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { ElasticsearchClient } from '@kbn/core-elasticsearch-server';
import { Logger } from '@kbn/logging';

export const AI_ASSISTANT_KB_INFERENCE_ID = 'ai_assistant_kb_inference';

export async function createInferenceEndpoint({
esClient,
logger,
}: {
esClient: {
asCurrentUser: ElasticsearchClient;
};
logger: Logger;
}) {
try {
logger.debug(`Creating inference endpoint "${AI_ASSISTANT_KB_INFERENCE_ID}"`);
const response = await esClient.asCurrentUser.transport.request({
method: 'PUT',
path: `_inference/sparse_embedding/${AI_ASSISTANT_KB_INFERENCE_ID}`,
body: {
service: 'elser',
service_settings: {
num_allocations: 1,
num_threads: 1,
},
},
sorenlouv marked this conversation as resolved.
Show resolved Hide resolved
});

return response;
} catch (e) {
logger.error(
`Failed to create inference endpoint "${AI_ASSISTANT_KB_INFERENCE_ID}": ${e.message}`
);
throw e;
}
}

/**
 * Deletes the AI Assistant KB inference endpoint. Uses `force: true` so the
 * endpoint is removed even while referenced by an inference pipeline or a
 * semantic_text field.
 *
 * Logs and rethrows on failure.
 */
export async function deleteInferenceEndpoint({
  esClient,
  logger,
}: {
  esClient: {
    asCurrentUser: ElasticsearchClient;
  };
  logger: Logger;
}) {
  try {
    return await esClient.asCurrentUser.transport.request({
      method: 'DELETE',
      path: `_inference/sparse_embedding/${AI_ASSISTANT_KB_INFERENCE_ID}`,
      querystring: {
        force: true, // Deletes the endpoint regardless of whether it’s used in an inference pipeline or in a semantic_text field.
      },
    });
  } catch (e) {
    // Consistency fix: include the endpoint id in the message, matching the
    // logging style of createInferenceEndpoint.
    logger.error(
      `Failed to delete inference endpoint "${AI_ASSISTANT_KB_INFERENCE_ID}": ${e.message}`
    );
    throw e;
  }
}

/**
 * Shape of the Elasticsearch inference API response as consumed here
 * (`GET _inference/sparse_embedding/<inference_id>` returns an `endpoints`
 * array; see getInferenceEndpoint which takes the first element).
 *
 * NOTE(review): `task_settings: {}` is the permissive "any non-nullish"
 * type; a narrower type (e.g. Record<string, unknown>) may be intended —
 * confirm against the ES inference API before tightening.
 */
export interface InferenceEndpointResponse {
  endpoints: Array<{
    model_id: string;
    task_type: string;
    service: string;
    service_settings: {
      num_allocations: number;
      num_threads: number;
      // model_id also appears nested under service_settings in the response.
      model_id: string;
    };
    task_settings: {};
  }>;
}

/**
 * Fetches the AI Assistant KB inference endpoint. Returns the first endpoint
 * in the response, or undefined when the response lists none.
 *
 * Logs and rethrows on request failure.
 */
export async function getInferenceEndpoint({
  esClient,
  logger,
}: {
  esClient: { asInternalUser: ElasticsearchClient };
  logger: Logger;
}) {
  try {
    const { endpoints } = await esClient.asInternalUser.transport.request<InferenceEndpointResponse>(
      {
        method: 'GET',
        path: `_inference/sparse_embedding/${AI_ASSISTANT_KB_INFERENCE_ID}`,
      }
    );
    return endpoints.length > 0 ? endpoints[0] : undefined;
  } catch (e) {
    logger.error(`Failed to fetch inference endpoint: ${e.message}`);
    throw e;
  }
}
Loading