From 80f454ded76a90505212aa7d6b60df9c1ccc8785 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Tue, 2 Jul 2024 09:57:42 +0200 Subject: [PATCH] [8.14] [Obs AI Assistant] Boost user prompt in recall (#184933) (#187313) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Backport This will backport the following commits from `main` to `8.14`: - [[Obs AI Assistant] Boost user prompt in recall (#184933)](https://github.com/elastic/kibana/pull/184933) ### Questions ? Please refer to the [Backport tool documentation](https://github.com/sorenlouv/backport) --- .../public/service/create_service.ts | 4 -- .../server/functions/context.ts | 69 +++++-------------- .../server/routes/functions/route.ts | 11 ++- .../get_context_function_request_if_needed.ts | 4 -- .../server/service/client/index.test.ts | 2 - .../server/service/client/index.ts | 2 +- .../service/knowledge_base_service/index.ts | 14 ++-- .../public/components/chat/chat_body.test.tsx | 4 +- .../tests/complete/complete.spec.ts | 1 - .../public_complete/public_complete.spec.ts | 14 +++- .../tests/conversations/index.spec.ts | 4 +- 11 files changed, 52 insertions(+), 77 deletions(-) diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_service.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_service.ts index 9e0adc5a94d8f..7232078d2efe8 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_service.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_service.ts @@ -72,10 +72,6 @@ export function createService({ return of( createFunctionRequestMessage({ name: 'context', - args: { - queries: [], - categories: [], - }, }), createFunctionResponseMessage({ name: 'context', diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts index 4a347c2710ef4..bd5f84a7e515d 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts @@ -38,34 +38,10 @@ export function registerContextFunction({ description: 'This function provides context as to what the user is looking at on their screen, and recalled documents from the knowledge base that matches their query', visibility: FunctionVisibility.Internal, - parameters: { - type: 'object', - properties: { - queries: { - type: 'array', - description: 'The query for the semantic search', - items: { - type: 'string', - }, - }, - categories: { - type: 'array', - description: - 'Categories of internal documentation that you want to search for. By default internal documentation will be excluded. Use `apm` to get internal APM documentation, `lens` to get internal Lens documentation, or both.', - items: { - type: 'string', - enum: ['apm', 'lens'], - }, - }, - }, - required: ['queries', 'categories'], - } as const, }, - async ({ arguments: args, messages, screenContexts, chat }, signal) => { + async ({ messages, screenContexts, chat }, signal) => { const { analytics } = (await resources.context.core).coreStart; - const { queries, categories } = args; - async function getContext() { const screenDescription = compact( screenContexts.map((context) => context.screenDescription) @@ -92,30 +68,21 @@ export function registerContextFunction({ messages.filter((message) => message.message.role === MessageRole.User) ); - const nonEmptyQueries = compact(queries); - - const queriesOrUserPrompt = nonEmptyQueries.length - ? nonEmptyQueries - : compact([userMessage?.message.content]); - - queriesOrUserPrompt.push(screenDescription); - - const suggestions = await retrieveSuggestions({ - client, - categories, - queries: queriesOrUserPrompt, - }); + const userPrompt = userMessage?.message.content; + const queries = [{ text: userPrompt, boost: 3 }, { text: screenDescription }].filter( + ({ text }) => text + ) as Array<{ text: string; boost?: number }>; + const suggestions = await retrieveSuggestions({ client, queries }); if (suggestions.length === 0) { - return { - content, - }; + return { content }; } try { const { relevantDocuments, scores } = await scoreSuggestions({ suggestions, - queries: queriesOrUserPrompt, + screenDescription, + userPrompt, messages, chat, signal, @@ -123,7 +90,7 @@ export function registerContextFunction({ }); analytics.reportEvent(RecallRankingEventType, { - prompt: queriesOrUserPrompt.join('|'), + prompt: queries.map((query) => query.text).join('|'), scoredDocuments: suggestions.map((suggestion) => { const llmScore = scores.find((score) => score.id === suggestion.id); return { @@ -176,15 +143,12 @@ export function registerContextFunction({ async function retrieveSuggestions({ queries, client, - categories, }: { - queries: string[]; + queries: Array<{ text: string; boost?: number }>; client: ObservabilityAIAssistantClient; - categories: Array<'apm' | 'lens'>; }) { const recallResponse = await client.recall({ queries, - categories, }); return recallResponse.entries.map((entry) => omit(entry, 'labels', 'is_correction')); @@ -206,14 +170,16 @@ const scoreFunctionArgumentsRt = t.type({ async function scoreSuggestions({ suggestions, messages, - queries, + userPrompt, + screenDescription, chat, signal, logger, }: { suggestions: Awaited>; messages: Message[]; - queries: string[]; + userPrompt: string | undefined; + screenDescription: string; chat: FunctionCallChatFunction; signal: AbortSignal; logger: Logger; @@ -235,7 +201,10 @@ async function scoreSuggestions({ - The document contains new information not mentioned before in the conversation Question: - ${queries.join('\n')} + ${userPrompt} + + Screen description: + ${screenDescription} Documents: ${JSON.stringify(indexedSuggestions, null, 2)}`); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/functions/route.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/functions/route.ts index 58c93737b6617..5922440bc7ae1 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/functions/route.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/functions/route.ts @@ -65,7 +65,16 @@ const functionRecallRoute = createObservabilityAIAssistantServerRoute({ params: t.type({ body: t.intersection([ t.type({ - queries: t.array(nonEmptyStringRt), + queries: t.array( + t.intersection([ + t.type({ + text: t.string, + }), + t.partial({ + boost: t.number, + }), + ]) + ), }), t.partial({ categories: t.array(t.string), diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/get_context_function_request_if_needed.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/get_context_function_request_if_needed.ts index 8f05cf144a33b..aa1dc65576784 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/get_context_function_request_if_needed.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/get_context_function_request_if_needed.ts @@ -27,9 +27,5 @@ export function getContextFunctionRequestIfNeeded( return createFunctionRequestMessage({ name: 'context', - args: { - queries: [], - categories: [], - }, }); } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts index ba6cc9fe8a5cf..ccb34e777e29e 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts @@ -1225,7 +1225,6 @@ describe('Observability AI Assistant client', () => { role: MessageRole.Assistant, function_call: { name: 'context', - arguments: JSON.stringify({ queries: [], categories: [] }), trigger: MessageRole.Assistant, }, }, @@ -1449,7 +1448,6 @@ describe('Observability AI Assistant client', () => { role: MessageRole.Assistant, function_call: { name: 'context', - arguments: JSON.stringify({ queries: [], categories: [] }), trigger: MessageRole.Assistant, }, }, diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts index 0dc38698faa89..46bc16fdf85d1 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts @@ -650,7 +650,7 @@ export class ObservabilityAIAssistantClient { queries, categories, }: { - queries: string[]; + queries: Array<{ text: string; boost?: number }>; categories?: string[]; }): Promise<{ entries: RecalledEntry[] }> => { return this.dependencies.knowledgeBaseService.recall({ diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts index e0ba8bd48d478..273815038b67d 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts @@ -303,7 +303,7 @@ export class KnowledgeBaseService { user, modelId, }: { - queries: string[]; + queries: Array<{ text: string; boost?: number }>; categories?: string[]; namespace: string; user?: { name: string }; @@ -311,11 +311,12 @@ export class KnowledgeBaseService { }): Promise { const query = { bool: { - should: queries.map((text) => ({ + should: queries.map(({ text, boost = 1 }) => ({ text_expansion: { 'ml.tokens': { model_text: text, model_id: modelId, + boost, }, }, })), @@ -352,7 +353,7 @@ export class KnowledgeBaseService { asCurrentUser, modelId, }: { - queries: string[]; + queries: Array<{ text: string; boost?: number }>; asCurrentUser: ElasticsearchClient; modelId: string; }): Promise { @@ -378,15 +379,16 @@ export class KnowledgeBaseService { const vectorField = `${ML_INFERENCE_PREFIX}${field}_expanded.predicted_value`; const modelField = `${ML_INFERENCE_PREFIX}${field}_expanded.model_id`; - return queries.map((query) => { + return queries.map(({ text, boost = 1 }) => { return { bool: { should: [ { text_expansion: { [vectorField]: { - model_text: query, + model_text: text, model_id: modelId, + boost, }, }, }, @@ -431,7 +433,7 @@ export class KnowledgeBaseService { namespace, asCurrentUser, }: { - queries: string[]; + queries: Array<{ text: string; boost?: number }>; categories?: string[]; user?: { name: string }; namespace: string; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/chat/chat_body.test.tsx b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/chat/chat_body.test.tsx index cfb85f7945240..9e36c0d64bb60 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/chat/chat_body.test.tsx +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/chat/chat_body.test.tsx @@ -39,7 +39,7 @@ describe('', () => { role: 'assistant', function_call: { name: 'context', - arguments: '{"queries":[],"categories":[]}', + arguments: '{}', trigger: 'assistant', }, content: '', @@ -87,7 +87,7 @@ describe('', () => { role: 'assistant', function_call: { name: 'context', - arguments: '{"queries":[],"categories":[]}', + arguments: '{}', trigger: 'assistant', }, content: '', diff --git a/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts b/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts index 38303c3a53076..4c1a7cd4df585 100644 --- a/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts +++ b/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts @@ -193,7 +193,6 @@ export default function ApiTest({ getService }: FtrProviderContext) { role: MessageRole.Assistant, function_call: { name: 'context', - arguments: JSON.stringify({ queries: [], categories: [] }), trigger: MessageRole.Assistant, }, }, diff --git a/x-pack/test/observability_ai_assistant_api_integration/tests/public_complete/public_complete.spec.ts b/x-pack/test/observability_ai_assistant_api_integration/tests/public_complete/public_complete.spec.ts index ac2fa36f6b0fd..f496e42868ac8 100644 --- a/x-pack/test/observability_ai_assistant_api_integration/tests/public_complete/public_complete.spec.ts +++ b/x-pack/test/observability_ai_assistant_api_integration/tests/public_complete/public_complete.spec.ts @@ -72,6 +72,7 @@ export default function ApiTest({ getService }: FtrProviderContext) { format, }) .set('kbn-xsrf', 'foo') + .set('elastic-api-version', '2023-10-31') .send({ messages, connectorId, @@ -83,13 +84,20 @@ export default function ApiTest({ getService }: FtrProviderContext) { if (err) { return reject(err); } + if (response.status !== 200) { + return reject(new Error(`${response.status}: ${JSON.stringify(response.body)}`)); + } return resolve(response); }); }); - const [conversationSimulator, titleSimulator] = await Promise.all([ - conversationInterceptor.waitForIntercept(), - titleInterceptor.waitForIntercept(), + const [conversationSimulator, titleSimulator] = await Promise.race([ + Promise.all([ + conversationInterceptor.waitForIntercept(), + titleInterceptor.waitForIntercept(), + ]), + // make sure any request failures (like 400s) are properly propagated + responsePromise.then(() => []), ]); await titleSimulator.status(200); diff --git a/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts b/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts index 670903591287f..a8decc9f3337d 100644 --- a/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts +++ b/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts @@ -94,7 +94,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte content: '', function_call: { name: 'context', - arguments: '{"queries":[],"categories":[]}', + arguments: '{}', trigger: MessageRole.Assistant, }, }, @@ -290,7 +290,6 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte expect(pick(contextRequest.function_call, 'name', 'arguments')).to.eql({ name: 'context', - arguments: JSON.stringify({ queries: [], categories: [] }), }); expect(contextResponse.name).to.eql('context'); @@ -354,7 +353,6 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte expect(pick(contextRequest.function_call, 'name', 'arguments')).to.eql({ name: 'context', - arguments: JSON.stringify({ queries: [], categories: [] }), }); expect(contextResponse.name).to.eql('context');