From 90f7e12c2af2edf2abf2880be6e5bed71a8befd7 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Fri, 22 Nov 2024 05:52:40 -0700 Subject: [PATCH] [Security solution] Fix gemini streaming (#201299) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Fixes streaming for Gemini in Security Assistant. Content can appear in the `finishReason` block. I'm not sure when this started happening. Updates our streaming logic to support content being in the `finishReason` block. Example of `finishReason` block with content: ``` `data: {"candidates": [{"content": {"role": "model","parts": [{"text": " are 170 critical and 20 high open alerts."}]},"finishReason": "STOP","safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE","probabilityScore": 0.060086742,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.17106095},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE","probabilityScore": 0.16776322,"severity": "HARM_SEVERITY_LOW","severityScore": 0.37113687},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE","probabilityScore": 0.124212936,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.17441037},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE","probabilityScore": 0.05419875,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.03461887}]}],"usageMetadata": {"promptTokenCount": 1062,"candidatesTokenCount": 15,"totalTokenCount": 1077},"modelVersion": "gemini-1.5-pro-002"}` ``` ## To test 1. Have alerts 2. Have a Gemini connector (`gemini-1.5-pro-002`) 3. Have streaming on in the assistant with the Gemini connector selected 4. Ask the assistant: "How many open alerts do I have?" ### Previously A response begin to streams and then the response gets cut off. Screenshot 2024-11-21 at 4 18 06 PM ### Now The response streams in full as expected. Screenshot 2024-11-21 at 4 25 13 PM --------- Co-authored-by: Elastic Machine --- .../chat_vertex/chat_vertex.test.ts | 122 ++++++++++++++++++ .../chat_vertex/chat_vertex.ts | 13 +- .../server/language_models/gemini_chat.ts | 12 +- 3 files changed, 133 insertions(+), 14 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.test.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.test.ts index 07fe252bd5074..69086ffcd108b 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.test.ts @@ -13,6 +13,7 @@ import { BaseMessage, HumanMessage, SystemMessage } from '@langchain/core/messag import { ActionsClientChatVertexAI } from './chat_vertex'; import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'; import { GeminiContent } from '@langchain/google-common'; +import { FinishReason } from '@google/generative-ai'; const connectorId = 'mock-connector-id'; @@ -55,6 +56,74 @@ const mockStreamExecute = jest.fn().mockImplementation(() => { }; }); +const mockStreamExecuteWithGoodStopEvents = jest.fn().mockImplementation(() => { + const passThrough = new PassThrough(); + + // Write the data chunks to the stream + setTimeout(() => { + passThrough.write( + Buffer.from( + `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token1"}]}}],"modelVersion": "gemini-1.5-pro-001"}` + ) + ); + }); + setTimeout(() => { + passThrough.write( + Buffer.from( + `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token2"}]}}],"modelVersion": "gemini-1.5-pro-001"}` + ) + ); + }); + setTimeout(() => { + passThrough.write( + Buffer.from( + `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token3"}]},"finishReason": "${FinishReason.STOP}","safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE","probabilityScore": 0.060086742,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.17106095},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE","probabilityScore": 0.16776322,"severity": "HARM_SEVERITY_LOW","severityScore": 0.37113687},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE","probabilityScore": 0.124212936,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.17441037},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE","probabilityScore": 0.05419875,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.03461887}]}],"usageMetadata": {"promptTokenCount": 1062,"candidatesTokenCount": 15,"totalTokenCount": 1077},"modelVersion": "gemini-1.5-pro-002"}` + ) + ); + // End the stream + passThrough.end(); + }); + + return { + data: passThrough, // PassThrough stream will act as the async iterator + status: 'ok', + }; +}); + +const mockStreamExecuteWithBadStopEvents = jest.fn().mockImplementation(() => { + const passThrough = new PassThrough(); + + // Write the data chunks to the stream + setTimeout(() => { + passThrough.write( + Buffer.from( + `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token1"}]}}],"modelVersion": "gemini-1.5-pro-001"}` + ) + ); + }); + setTimeout(() => { + passThrough.write( + Buffer.from( + `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token2"}]}}],"modelVersion": "gemini-1.5-pro-001"}` + ) + ); + }); + setTimeout(() => { + passThrough.write( + Buffer.from( + `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token3"}]},"finishReason": "${FinishReason.SAFETY}","safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE","probabilityScore": 0.060086742,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.17106095},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "HIGH","probabilityScore": 0.96776322,"severity": "HARM_SEVERITY_HIGH","severityScore": 0.97113687,"blocked":true},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE","probabilityScore": 0.124212936,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.17441037},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE","probabilityScore": 0.05419875,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.03461887}]}],"usageMetadata": {"promptTokenCount": 1062,"candidatesTokenCount": 15,"totalTokenCount": 1077},"modelVersion": "gemini-1.5-pro-002"}` + ) + ); + // End the stream + passThrough.end(); + }); + + return { + data: passThrough, // PassThrough stream will act as the async iterator + status: 'ok', + }; +}); + const systemInstruction = 'Answer the following questions truthfully and as best you can.'; const callMessages = [ @@ -198,6 +267,59 @@ describe('ActionsClientChatVertexAI', () => { expect(handleLLMNewToken).toHaveBeenCalledWith('token2'); expect(handleLLMNewToken).toHaveBeenCalledWith('token3'); }); + it('includes tokens from finishReason: STOP', async () => { + actionsClient.execute.mockImplementationOnce(mockStreamExecuteWithGoodStopEvents); + + const actionsClientChatVertexAI = new ActionsClientChatVertexAI({ + ...defaultArgs, + actionsClient, + streaming: true, + }); + + const gen = actionsClientChatVertexAI._streamResponseChunks( + callMessages, + callOptions, + callRunManager + ); + + const chunks = []; + + for await (const chunk of gen) { + chunks.push(chunk); + } + + expect(chunks.map((c) => c.text)).toEqual(['token1', 'token2', 'token3']); + expect(handleLLMNewToken).toHaveBeenCalledTimes(3); + expect(handleLLMNewToken).toHaveBeenCalledWith('token1'); + expect(handleLLMNewToken).toHaveBeenCalledWith('token2'); + expect(handleLLMNewToken).toHaveBeenCalledWith('token3'); + }); + it('throws an error on bad stop events', async () => { + actionsClient.execute.mockImplementationOnce(mockStreamExecuteWithBadStopEvents); + + const actionsClientChatVertexAI = new ActionsClientChatVertexAI({ + ...defaultArgs, + actionsClient, + streaming: true, + }); + + const gen = actionsClientChatVertexAI._streamResponseChunks( + callMessages, + callOptions, + callRunManager + ); + + const chunks = []; + await expect(async () => { + for await (const chunk of gen) { + chunks.push(chunk); + } + }).rejects.toEqual( + Error( + `Gemini Utils: action result status is error. Candidate was blocked due to SAFETY - HARM_CATEGORY_DANGEROUS_CONTENT: HARM_SEVERITY_HIGH` + ) + ); + }); }); describe('message formatting', () => { diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.ts index 745c273c79583..7cea2d421a9da 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.ts @@ -130,7 +130,12 @@ export class ActionsClientChatVertexAI extends ChatVertexAI { partialStreamChunk += nextChunk; } - if (parsedStreamChunk !== null && !parsedStreamChunk.candidates?.[0]?.finishReason) { + if (parsedStreamChunk !== null) { + const errorMessage = convertResponseBadFinishReasonToErrorMsg(parsedStreamChunk); + if (errorMessage != null) { + throw new Error(errorMessage); + } + const response = { ...parsedStreamChunk, functionCalls: () => @@ -178,12 +183,6 @@ export class ActionsClientChatVertexAI extends ChatVertexAI { yield chunk; await runManager?.handleLLMNewToken(chunk.text ?? ''); } - } else if (parsedStreamChunk) { - // handle bad finish reason - const errorMessage = convertResponseBadFinishReasonToErrorMsg(parsedStreamChunk); - if (errorMessage != null) { - throw new Error(errorMessage); - } } } } diff --git a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts index 700e26d5a0a14..f8755af19d78c 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/gemini_chat.ts @@ -199,7 +199,11 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { partialStreamChunk += nextChunk; } - if (parsedStreamChunk !== null && !parsedStreamChunk.candidates?.[0]?.finishReason) { + if (parsedStreamChunk !== null) { + const errorMessage = convertResponseBadFinishReasonToErrorMsg(parsedStreamChunk); + if (errorMessage != null) { + throw new Error(errorMessage); + } const response = { ...parsedStreamChunk, functionCalls: () => @@ -247,12 +251,6 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI { yield chunk; await runManager?.handleLLMNewToken(chunk.text ?? ''); } - } else if (parsedStreamChunk) { - // handle bad finish reason - const errorMessage = convertResponseBadFinishReasonToErrorMsg(parsedStreamChunk); - if (errorMessage != null) { - throw new Error(errorMessage); - } } } }