From 51f9eedda5b6257280d9c679145f980e2da23f6a Mon Sep 17 00:00:00 2001 From: Joe McElroy Date: Wed, 22 May 2024 11:37:02 +0100 Subject: [PATCH] [Search] [Chat Playground] handle when the ActionLLM is not a ChatModel (#183931) ## Summary There are two action-based LLMs: `ActionsClientChatOpenAI` and `ActionsClientLlm`. `ActionsClientChatOpenAI` is based on the ChatModel LLM, while `ActionsClientLlm` is a prompt-based model. The callbacks are different when using a ChatModel vs an LLMModel. Token counting is done in the ChatModelStart callback. This meant the token count didn't happen for LLMModel-based actions (Bedrock). To fix this, I listen on both callbacks. ### Checklist Delete any items that are not applicable to this PR. - [ ] Any text added follows [EUI's writing guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses sentence case text and includes [i18n support](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md) - [ ] [Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html) was added for features that require explanation or tutorials - [x] [Unit or functional tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html) were updated or added to match the most common scenarios - [ ] [Flaky Test Runner](https://ci-stats.kibana.dev/trigger_flaky_test_runner/1) was used on any tests changed - [ ] Any UI touched in this PR is usable by keyboard only (learn more about [keyboard accessibility](https://webaim.org/techniques/keyboard/)) - [ ] Any UI touched in this PR does not create any new axe failures (run axe in browser: [FF](https://addons.mozilla.org/en-US/firefox/addon/axe-devtools/), [Chrome](https://chrome.google.com/webstore/detail/axe-web-accessibility-tes/lhdoppojpmngadmnindnejefpokejbdd?hl=en-US)) - [ ] If a plugin configuration key changed, check if it needs to be allowlisted in the cloud and added to the [docker 
list](https://github.com/elastic/kibana/blob/main/src/dev/build/tasks/os_packages/docker_generator/resources/base/bin/kibana-docker) - [ ] This renders correctly on smaller devices using a responsive layout. (You can test this [in your browser](https://www.browserstack.com/guide/responsive-testing-on-local-server)) - [ ] This was checked for [cross-browser compatibility](https://www.elastic.co/support/matrix#matrix_browsers) --- .../server/lib/conversational_chain.test.ts | 154 ++++++++++++------ .../server/lib/conversational_chain.ts | 10 ++ 2 files changed, 116 insertions(+), 48 deletions(-) diff --git a/x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts b/x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts index 8d4460f137736..8d67f6f03b8d1 100644 --- a/x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts +++ b/x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts @@ -9,19 +9,30 @@ import type { Client } from '@elastic/elasticsearch'; import { createAssist as Assist } from '../utils/assist'; import { ConversationalChain } from './conversational_chain'; import { FakeListChatModel } from '@langchain/core/utils/testing'; +import { FakeListLLM } from 'langchain/llms/fake'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { Message } from 'ai'; describe('conversational chain', () => { - const createTestChain = async ( - responses: string[], - chat: Message[], - expectedFinalAnswer: string, - expectedDocs: any, - expectedTokens: any, - expectedSearchRequest: any, - contentField: Record = { index: 'field', website: 'body_content' } - ) => { + const createTestChain = async ({ + responses, + chat, + expectedFinalAnswer, + expectedDocs, + expectedTokens, + expectedSearchRequest, + contentField = { index: 'field', website: 'body_content' }, + isChatModel = true, + }: { + responses: string[]; + chat: Message[]; + expectedFinalAnswer: string; + expectedDocs: 
any; + expectedTokens: any; + expectedSearchRequest: any; + contentField?: Record; + isChatModel?: boolean; + }) => { const searchMock = jest.fn().mockImplementation(() => { return { hits: { @@ -54,9 +65,11 @@ describe('conversational chain', () => { }, }; - const llm = new FakeListChatModel({ - responses, - }); + const llm = isChatModel + ? new FakeListChatModel({ + responses, + }) + : new FakeListLLM({ responses }); const aiClient = Assist({ es_client: mockElasticsearchClient as unknown as Client, @@ -118,17 +131,17 @@ describe('conversational chain', () => { }; it('should be able to create a conversational chain', async () => { - await createTestChain( - ['the final answer'], - [ + await createTestChain({ + responses: ['the final answer'], + chat: [ { id: '1', role: 'user', content: 'what is the work from home policy?', }, ], - 'the final answer', - [ + expectedFinalAnswer: 'the final answer', + expectedDocs: [ { documents: [ { metadata: { _id: '1', _index: 'index' }, pageContent: 'value' }, @@ -137,32 +150,32 @@ describe('conversational chain', () => { type: 'retrieved_docs', }, ], - [ + expectedTokens: [ { type: 'context_token_count', count: 15 }, { type: 'prompt_token_count', count: 5 }, ], - [ + expectedSearchRequest: [ { method: 'POST', path: '/index,website/_search', body: { query: { match: { field: 'what is the work from home policy?' 
} }, size: 3 }, }, - ] - ); + ], + }); }); it('should be able to create a conversational chain with nested field', async () => { - await createTestChain( - ['the final answer'], - [ + await createTestChain({ + responses: ['the final answer'], + chat: [ { id: '1', role: 'user', content: 'what is the work from home policy?', }, ], - 'the final answer', - [ + expectedFinalAnswer: 'the final answer', + expectedDocs: [ { documents: [ { metadata: { _id: '1', _index: 'index' }, pageContent: 'value' }, @@ -171,25 +184,25 @@ describe('conversational chain', () => { type: 'retrieved_docs', }, ], - [ + expectedTokens: [ { type: 'context_token_count', count: 15 }, { type: 'prompt_token_count', count: 5 }, ], - [ + expectedSearchRequest: [ { method: 'POST', path: '/index,website/_search', body: { query: { match: { field: 'what is the work from home policy?' } }, size: 3 }, }, ], - { index: 'field', website: 'metadata.source' } - ); + contentField: { index: 'field', website: 'metadata.source' }, + }); }); it('asking with chat history should re-write the question', async () => { - await createTestChain( - ['rewrite the question', 'the final answer'], - [ + await createTestChain({ + responses: ['rewrite the question', 'the final answer'], + chat: [ { id: '1', role: 'user', @@ -206,8 +219,8 @@ describe('conversational chain', () => { content: 'what is the work from home policy?', }, ], - 'the final answer', - [ + expectedFinalAnswer: 'the final answer', + expectedDocs: [ { documents: [ { metadata: { _id: '1', _index: 'index' }, pageContent: 'value' }, @@ -216,24 +229,24 @@ describe('conversational chain', () => { type: 'retrieved_docs', }, ], - [ + expectedTokens: [ { type: 'context_token_count', count: 15 }, { type: 'prompt_token_count', count: 5 }, ], - [ + expectedSearchRequest: [ { method: 'POST', path: '/index,website/_search', body: { query: { match: { field: 'rewrite the question' } }, size: 3 }, }, - ] - ); + ], + }); }); it('should cope with quotes in the query', async () 
=> { - await createTestChain( - ['rewrite "the" question', 'the final answer'], - [ + await createTestChain({ + responses: ['rewrite "the" question', 'the final answer'], + chat: [ { id: '1', role: 'user', @@ -250,8 +263,8 @@ describe('conversational chain', () => { content: 'what is the work from home policy?', }, ], - 'the final answer', - [ + expectedFinalAnswer: 'the final answer', + expectedDocs: [ { documents: [ { metadata: { _id: '1', _index: 'index' }, pageContent: 'value' }, @@ -260,17 +273,62 @@ describe('conversational chain', () => { type: 'retrieved_docs', }, ], - [ + expectedTokens: [ { type: 'context_token_count', count: 15 }, { type: 'prompt_token_count', count: 5 }, ], - [ + expectedSearchRequest: [ { method: 'POST', path: '/index,website/_search', body: { query: { match: { field: 'rewrite "the" question' } }, size: 3 }, }, - ] - ); + ], + }); + }); + + it('should work with an LLM based model', async () => { + await createTestChain({ + responses: ['rewrite "the" question', 'the final answer'], + chat: [ + { + id: '1', + role: 'user', + content: 'what is the work from home policy?', + }, + { + id: '2', + role: 'assistant', + content: 'the final answer', + }, + { + id: '3', + role: 'user', + content: 'what is the work from home policy?', + }, + ], + expectedFinalAnswer: 'the final answer', + expectedDocs: [ + { + documents: [ + { metadata: { _id: '1', _index: 'index' }, pageContent: 'value' }, + { metadata: { _id: '1', _index: 'website' }, pageContent: 'value2' }, + ], + type: 'retrieved_docs', + }, + ], + expectedTokens: [ + { type: 'context_token_count', count: 15 }, + { type: 'prompt_token_count', count: 7 }, + ], + expectedSearchRequest: [ + { + method: 'POST', + path: '/index,website/_search', + body: { query: { match: { field: 'rewrite "the" question' } }, size: 3 }, + }, + ], + isChatModel: false, + }); }); }); diff --git a/x-pack/plugins/search_playground/server/lib/conversational_chain.ts 
b/x-pack/plugins/search_playground/server/lib/conversational_chain.ts index 1ec7bfb20c017..6080557e4c68f 100644 --- a/x-pack/plugins/search_playground/server/lib/conversational_chain.ts +++ b/x-pack/plugins/search_playground/server/lib/conversational_chain.ts @@ -150,6 +150,7 @@ class ConversationalChainFn { { callbacks: [ { + // callback for chat based models (OpenAI) handleChatModelStart( llm, msg: BaseMessage[][], @@ -166,6 +167,15 @@ class ConversationalChainFn { }); } }, + // callback for prompt based models (Bedrock uses ActionsClientLlm) + handleLLMStart(llm, input, runId, parentRunId, extraParams, tags, metadata) { + if (metadata?.type === 'question_answer_qa') { + data.appendMessageAnnotation({ + type: 'prompt_token_count', + count: getTokenEstimate(input[0]), + }); + } + }, handleRetrieverEnd(documents) { retrievedDocs.push(...documents); data.appendMessageAnnotation({