[8.16] [Security solution] Fix gemini streaming (#201299) (#201372)
kibanamachine authored Dec 2, 2024
Parent: c4d00b3 · Commit: f8edfb0
Showing 3 changed files with 133 additions and 14 deletions.
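
The fix is the same in both streaming models (ActionsClientChatVertexAI and ActionsClientGeminiChatModel): previously, `_streamResponseChunks` skipped any parsed chunk whose first candidate carried a `finishReason`, and bad finish reasons were only handled in a separate `else if` branch. Gemini sends the final text parts and `finishReason: "STOP"` in the same SSE frame, so the last tokens of a response could be silently dropped. The change checks for a bad finish reason first (throwing if one is found) and otherwise processes the chunk normally. A minimal, self-contained sketch of the guard change (the chunk shape mirrors the mocked SSE payloads in the tests below):

type Chunk = {
  candidates?: Array<{
    content?: { parts?: Array<{ text?: string }> };
    finishReason?: string;
  }>;
};

const stopChunk: Chunk = {
  candidates: [{ content: { parts: [{ text: 'token3' }] }, finishReason: 'STOP' }],
};

// Old guard: skipped every chunk that carried any finishReason at all.
const oldGuard = (c: Chunk | null) => c !== null && !c.candidates?.[0]?.finishReason;

// New guard: process every parsed chunk; bad finish reasons throw instead.
const newGuard = (c: Chunk | null) => c !== null;

console.log(oldGuard(stopChunk)); // false: token3 was dropped before the fix
console.log(newGuard(stopChunk)); // true: token3 is now yielded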
@@ -13,6 +13,7 @@ import { BaseMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';
 import { ActionsClientChatVertexAI } from './chat_vertex';
 import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
 import { GeminiContent } from '@langchain/google-common';
+import { FinishReason } from '@google/generative-ai';
 
 const connectorId = 'mock-connector-id';
 
@@ -55,6 +56,74 @@ const mockStreamExecute = jest.fn().mockImplementation(() => {
   };
 });
 
+const mockStreamExecuteWithGoodStopEvents = jest.fn().mockImplementation(() => {
+  const passThrough = new PassThrough();
+
+  // Write the data chunks to the stream
+  setTimeout(() => {
+    passThrough.write(
+      Buffer.from(
+        `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token1"}]}}],"modelVersion": "gemini-1.5-pro-001"}`
+      )
+    );
+  });
+  setTimeout(() => {
+    passThrough.write(
+      Buffer.from(
+        `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token2"}]}}],"modelVersion": "gemini-1.5-pro-001"}`
+      )
+    );
+  });
+  setTimeout(() => {
+    passThrough.write(
+      Buffer.from(
+        `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token3"}]},"finishReason": "${FinishReason.STOP}","safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE","probabilityScore": 0.060086742,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.17106095},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE","probabilityScore": 0.16776322,"severity": "HARM_SEVERITY_LOW","severityScore": 0.37113687},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE","probabilityScore": 0.124212936,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.17441037},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE","probabilityScore": 0.05419875,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.03461887}]}],"usageMetadata": {"promptTokenCount": 1062,"candidatesTokenCount": 15,"totalTokenCount": 1077},"modelVersion": "gemini-1.5-pro-002"}`
+      )
+    );
+    // End the stream
+    passThrough.end();
+  });
+
+  return {
+    data: passThrough, // PassThrough stream will act as the async iterator
+    status: 'ok',
+  };
+});
+
+const mockStreamExecuteWithBadStopEvents = jest.fn().mockImplementation(() => {
+  const passThrough = new PassThrough();
+
+  // Write the data chunks to the stream
+  setTimeout(() => {
+    passThrough.write(
+      Buffer.from(
+        `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token1"}]}}],"modelVersion": "gemini-1.5-pro-001"}`
+      )
+    );
+  });
+  setTimeout(() => {
+    passThrough.write(
+      Buffer.from(
+        `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token2"}]}}],"modelVersion": "gemini-1.5-pro-001"}`
+      )
+    );
+  });
+  setTimeout(() => {
+    passThrough.write(
+      Buffer.from(
+        `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token3"}]},"finishReason": "${FinishReason.SAFETY}","safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE","probabilityScore": 0.060086742,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.17106095},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "HIGH","probabilityScore": 0.96776322,"severity": "HARM_SEVERITY_HIGH","severityScore": 0.97113687,"blocked":true},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE","probabilityScore": 0.124212936,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.17441037},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE","probabilityScore": 0.05419875,"severity": "HARM_SEVERITY_NEGLIGIBLE","severityScore": 0.03461887}]}],"usageMetadata": {"promptTokenCount": 1062,"candidatesTokenCount": 15,"totalTokenCount": 1077},"modelVersion": "gemini-1.5-pro-002"}`
+      )
+    );
+    // End the stream
+    passThrough.end();
+  });
+
+  return {
+    data: passThrough, // PassThrough stream will act as the async iterator
+    status: 'ok',
+  };
+});
+
 const systemInstruction = 'Answer the following questions truthfully and as best you can.';
 
 const callMessages = [
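
The mocks above emulate the connector's streamed response with a Node PassThrough. Because Node streams are async iterables, the model code can `for await` over the same object the real actions client would return, which is why the mock's `data: passThrough` works as a drop-in. A minimal sketch of that behavior:

import { PassThrough } from 'stream';

async function demo() {
  const passThrough = new PassThrough();
  // Writes are queued on the event loop, just like in the mocks above.
  setTimeout(() => {
    passThrough.write(Buffer.from('data: {"candidates": []}'));
    passThrough.end();
  });
  for await (const chunk of passThrough) {
    console.log(chunk.toString('utf8')); // -> data: {"candidates": []}
  }
}
demo();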
@@ -198,6 +267,59 @@ describe('ActionsClientChatVertexAI', () => {
       expect(handleLLMNewToken).toHaveBeenCalledWith('token2');
       expect(handleLLMNewToken).toHaveBeenCalledWith('token3');
     });
+    it('includes tokens from finishReason: STOP', async () => {
+      actionsClient.execute.mockImplementationOnce(mockStreamExecuteWithGoodStopEvents);
+
+      const actionsClientChatVertexAI = new ActionsClientChatVertexAI({
+        ...defaultArgs,
+        actionsClient,
+        streaming: true,
+      });
+
+      const gen = actionsClientChatVertexAI._streamResponseChunks(
+        callMessages,
+        callOptions,
+        callRunManager
+      );
+
+      const chunks = [];
+
+      for await (const chunk of gen) {
+        chunks.push(chunk);
+      }
+
+      expect(chunks.map((c) => c.text)).toEqual(['token1', 'token2', 'token3']);
+      expect(handleLLMNewToken).toHaveBeenCalledTimes(3);
+      expect(handleLLMNewToken).toHaveBeenCalledWith('token1');
+      expect(handleLLMNewToken).toHaveBeenCalledWith('token2');
+      expect(handleLLMNewToken).toHaveBeenCalledWith('token3');
+    });
+    it('throws an error on bad stop events', async () => {
+      actionsClient.execute.mockImplementationOnce(mockStreamExecuteWithBadStopEvents);
+
+      const actionsClientChatVertexAI = new ActionsClientChatVertexAI({
+        ...defaultArgs,
+        actionsClient,
+        streaming: true,
+      });
+
+      const gen = actionsClientChatVertexAI._streamResponseChunks(
+        callMessages,
+        callOptions,
+        callRunManager
+      );
+
+      const chunks = [];
+      await expect(async () => {
+        for await (const chunk of gen) {
+          chunks.push(chunk);
+        }
+      }).rejects.toEqual(
+        Error(
+          `Gemini Utils: action result status is error. Candidate was blocked due to SAFETY - HARM_CATEGORY_DANGEROUS_CONTENT: HARM_SEVERITY_HIGH`
+        )
+      );
+    });
+  });
 
   describe('message formatting', () => {
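
The error message asserted above comes from `convertResponseBadFinishReasonToErrorMsg`, the util the production hunks below call. A rough sketch of what such a helper plausibly does, inferred from that assertion (the real Kibana util may differ in shape and in exactly which finish reasons it treats as bad):

import { FinishReason } from '@google/generative-ai';

interface SafetyRating {
  category: string;
  severity?: string;
  blocked?: boolean;
}

interface ParsedChunk {
  candidates?: Array<{ finishReason?: string; safetyRatings?: SafetyRating[] }>;
}

// Assumption: SAFETY and RECITATION are treated as "bad" finish reasons.
const badFinishReasons: string[] = [FinishReason.SAFETY, FinishReason.RECITATION];

function convertResponseBadFinishReasonToErrorMsg(chunk: ParsedChunk): string | null {
  const candidate = chunk.candidates?.[0];
  if (candidate?.finishReason && badFinishReasons.includes(candidate.finishReason)) {
    // Surface the safety rating that actually triggered the block, when present.
    const blocked = candidate.safetyRatings?.find((rating) => rating.blocked);
    const detail = blocked ? ` - ${blocked.category}: ${blocked.severity}` : '';
    return `Gemini Utils: action result status is error. Candidate was blocked due to ${candidate.finishReason}${detail}`;
  }
  return null;
}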
@@ -130,7 +130,12 @@ export class ActionsClientChatVertexAI extends ChatVertexAI {
           partialStreamChunk += nextChunk;
         }
 
-        if (parsedStreamChunk !== null && !parsedStreamChunk.candidates?.[0]?.finishReason) {
+        if (parsedStreamChunk !== null) {
+          const errorMessage = convertResponseBadFinishReasonToErrorMsg(parsedStreamChunk);
+          if (errorMessage != null) {
+            throw new Error(errorMessage);
+          }
+
           const response = {
             ...parsedStreamChunk,
             functionCalls: () =>
@@ -178,12 +183,6 @@ export class ActionsClientChatVertexAI extends ChatVertexAI {
             yield chunk;
             await runManager?.handleLLMNewToken(chunk.text ?? '');
           }
-        } else if (parsedStreamChunk) {
-          // handle bad finish reason
-          const errorMessage = convertResponseBadFinishReasonToErrorMsg(parsedStreamChunk);
-          if (errorMessage != null) {
-            throw new Error(errorMessage);
-          }
         }
       }
     }
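
For context on the `partialStreamChunk` / `parsedStreamChunk` pair in the hunks above and below: the surrounding method splits the raw bytes on SSE `data: ` frames and buffers partial JSON until a full frame parses. A simplified reconstruction (an assumption about the surrounding loop, not the verbatim Kibana code):

async function* parseSseFrames(stream: AsyncIterable<Buffer>): AsyncGenerator<unknown> {
  let partialStreamChunk = '';
  for await (const rawStreamChunk of stream) {
    for (const nextChunk of rawStreamChunk.toString('utf8').split('data: ')) {
      if (!nextChunk) continue;
      let parsedStreamChunk: unknown = null;
      try {
        parsedStreamChunk = JSON.parse(partialStreamChunk + nextChunk);
        partialStreamChunk = '';
      } catch {
        partialStreamChunk += nextChunk; // incomplete frame; wait for more bytes
      }
      if (parsedStreamChunk !== null) {
        yield parsedStreamChunk; // the diff's finishReason handling runs here
      }
    }
  }
}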
@@ -199,7 +199,11 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI {
           partialStreamChunk += nextChunk;
         }
 
-        if (parsedStreamChunk !== null && !parsedStreamChunk.candidates?.[0]?.finishReason) {
+        if (parsedStreamChunk !== null) {
+          const errorMessage = convertResponseBadFinishReasonToErrorMsg(parsedStreamChunk);
+          if (errorMessage != null) {
+            throw new Error(errorMessage);
+          }
           const response = {
             ...parsedStreamChunk,
             functionCalls: () =>
@@ -247,12 +251,6 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI {
             yield chunk;
             await runManager?.handleLLMNewToken(chunk.text ?? '');
           }
-        } else if (parsedStreamChunk) {
-          // handle bad finish reason
-          const errorMessage = convertResponseBadFinishReasonToErrorMsg(parsedStreamChunk);
-          if (errorMessage != null) {
-            throw new Error(errorMessage);
-          }
         }
       }
     }
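
Both classes receive the identical change, so a streaming consumer now sees the text carried by the final STOP frame. A usage sketch mirroring the tests above (callMessages, callOptions, and callRunManager are the test fixtures):

const collectTexts = async () => {
  const gen = actionsClientChatVertexAI._streamResponseChunks(
    callMessages,
    callOptions,
    callRunManager
  );
  const texts: string[] = [];
  for await (const chunk of gen) {
    texts.push(chunk.text);
  }
  // ['token1', 'token2', 'token3'] with the fix; token3 was dropped before
  return texts;
};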
