diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.test.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.test.ts index 07fe252bd5074..69086ffcd108b 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.test.ts @@ -13,6 +13,7 @@ import { BaseMessage, HumanMessage, SystemMessage } from '@langchain/core/messag import { ActionsClientChatVertexAI } from './chat_vertex'; import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'; import { GeminiContent } from '@langchain/google-common'; +import { FinishReason } from '@google/generative-ai'; const connectorId = 'mock-connector-id'; @@ -55,6 +56,74 @@ const mockStreamExecute = jest.fn().mockImplementation(() => { }; }); +const mockStreamExecuteWithGoodStopEvents = jest.fn().mockImplementation(() => { + const passThrough = new PassThrough(); + + // Write the data chunks to the stream + setTimeout(() => { + passThrough.write( + Buffer.from( + `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token1"}]}}],"modelVersion": "gemini-1.5-pro-001"}` + ) + ); + }); + setTimeout(() => { + passThrough.write( + Buffer.from( + `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token2"}]}}],"modelVersion": "gemini-1.5-pro-001"}` + ) + ); + }); + setTimeout(() => { + passThrough.write( + Buffer.from( + `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token3"}]},"finishReason": "${FinishReason.STOP}","safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE","probabilityScore": 0.060086742,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.17106095},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE","probabilityScore": 0.16776322,"severity": "HARM_SEVERITY_LOW","severityScore": 0.37113687},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE","probabilityScore": 0.124212936,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.17441037},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE","probabilityScore": 0.05419875,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.03461887}]}],"usageMetadata": {"promptTokenCount": 1062,"candidatesTokenCount": 15,"totalTokenCount": 1077},"modelVersion": "gemini-1.5-pro-002"}` + ) + ); + // End the stream + passThrough.end(); + }); + + return { + data: passThrough, // PassThrough stream will act as the async iterator + status: 'ok', + }; +}); + +const mockStreamExecuteWithBadStopEvents = jest.fn().mockImplementation(() => { + const passThrough = new PassThrough(); + + // Write the data chunks to the stream + setTimeout(() => { + passThrough.write( + Buffer.from( + `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token1"}]}}],"modelVersion": "gemini-1.5-pro-001"}` + ) + ); + }); + setTimeout(() => { + passThrough.write( + Buffer.from( + `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token2"}]}}],"modelVersion": "gemini-1.5-pro-001"}` + ) + ); + }); + setTimeout(() => { + passThrough.write( + Buffer.from( + `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token3"}]},"finishReason": "${FinishReason.SAFETY}","safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE","probabilityScore": 0.060086742,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.17106095},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "HIGH","probabilityScore": 0.96776322,"severity": "HARM_SEVERITY_HIGH","severityScore": 0.97113687,"blocked":true},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE","probabilityScore": 0.124212936,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.17441037},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE","probabilityScore": 0.05419875,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.03461887}]}],"usageMetadata": {"promptTokenCount": 1062,"candidatesTokenCount": 15,"totalTokenCount": 1077},"modelVersion": "gemini-1.5-pro-002"}` + ) + ); + // End the stream + passThrough.end(); + }); + + return { + data: passThrough, // PassThrough stream will act as the async iterator + status: 'ok', + }; +}); + const systemInstruction = 'Answer the following questions truthfully and as best you can.'; const callMessages = [ @@ -198,6 +267,59 @@ describe('ActionsClientChatVertexAI', () => { expect(handleLLMNewToken).toHaveBeenCalledWith('token2'); expect(handleLLMNewToken).toHaveBeenCalledWith('token3'); }); + it('includes tokens from finishReason: STOP', async () => { + actionsClient.execute.mockImplementationOnce(mockStreamExecuteWithGoodStopEvents); + + const actionsClientChatVertexAI = new ActionsClientChatVertexAI({ + ...defaultArgs, + actionsClient, + streaming: true, + }); + + const gen = actionsClientChatVertexAI._streamResponseChunks( + callMessages, + callOptions, + callRunManager + ); + + const chunks = []; + + for await (const chunk of gen) { + chunks.push(chunk); + } + + expect(chunks.map((c) => c.text)).toEqual(['token1', 'token2', 'token3']); + expect(handleLLMNewToken).toHaveBeenCalledTimes(3); + expect(handleLLMNewToken).toHaveBeenCalledWith('token1'); + expect(handleLLMNewToken).toHaveBeenCalledWith('token2'); + expect(handleLLMNewToken).toHaveBeenCalledWith('token3'); + }); + it('throws an error on bad stop events', async () => { + actionsClient.execute.mockImplementationOnce(mockStreamExecuteWithBadStopEvents); + + const actionsClientChatVertexAI = new ActionsClientChatVertexAI({ + ...defaultArgs, + actionsClient, + streaming: true, + }); + + const gen = actionsClientChatVertexAI._streamResponseChunks( + callMessages, + callOptions, + callRunManager + ); + + const chunks = []; + await expect(async () => { + for await (const chunk of gen) { + chunks.push(chunk); + } + }).rejects.toEqual( + Error( + `Gemini Utils: action result status is error. Candidate was blocked due to SAFETY - HARM_CATEGORY_DANGEROUS_CONTENT: HARM_SEVERITY_HIGH` + ) + ); + }); }); describe('message formatting', () => { diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.ts index 745c273c79583..7cea2d421a9da 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.ts @@ -130,7 +130,12 @@ export class ActionsClientChatVertexAI extends ChatVertexAI { partialStreamChunk += nextChunk; } - if (parsedStreamChunk !== null && !parsedStreamChunk.candidates?.[0]?.finishReason) { + if (parsedStreamChunk !== null) { + const errorMessage = convertResponseBadFinishReasonToErrorMsg(parsedStreamChunk); + if (errorMessage != null) { + throw new Error(errorMessage); + } + const response = { ...parsedStreamChunk, functionCalls: () => @@ -178,12 +183,6 @@ export class ActionsClientChatVertexAI extends ChatVertexAI { yield chunk; await runManager?.handleLLMNewToken(chunk.text ?? ''); } - } else if (parsedStreamChunk) { - // handle bad finish reason - const errorMessage = convertResponseBadFinishReasonToErrorMsg(parsedStreamChunk); - if (errorMessage != null) { - throw new Error(errorMessage); - } } } }