Skip to content

Commit

Permalink
[Security solution] Fix LangGraph stream with SimpleChatModel (#187994
Browse files Browse the repository at this point in the history
)
  • Loading branch information
stephmilovic authored Jul 15, 2024
1 parent 3c338a8 commit d5843b3
Show file tree
Hide file tree
Showing 15 changed files with 549 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,4 @@ export const getOptionalRequestParams = ({
};
};

export const hasParsableResponse = ({
isEnabledRAGAlerts,
isEnabledKnowledgeBase,
}: {
isEnabledRAGAlerts: boolean;
isEnabledKnowledgeBase: boolean;
}): boolean => isEnabledKnowledgeBase || isEnabledRAGAlerts;
export const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
24 changes: 22 additions & 2 deletions x-pack/packages/kbn-elastic-assistant/impl/assistant/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import {
getDefaultConnector,
getBlockBotConversation,
mergeBaseWithPersistedConversations,
sleep,
} from './helpers';

import { useAssistantContext, UserAvatar } from '../assistant_context';
Expand Down Expand Up @@ -197,15 +198,34 @@ const AssistantComponent: React.FC<Props> = ({
}, [currentConversation?.title, setConversationTitle]);

const refetchCurrentConversation = useCallback(
async ({ cId, cTitle }: { cId?: string; cTitle?: string } = {}) => {
async ({
cId,
cTitle,
isStreamRefetch = false,
}: { cId?: string; cTitle?: string; isStreamRefetch?: boolean } = {}) => {
if (cId === '' || (cTitle && !conversations[cTitle])) {
return;
}

const conversationId = cId ?? (cTitle && conversations[cTitle].id) ?? currentConversation?.id;

if (conversationId) {
const updatedConversation = await getConversation(conversationId);
let updatedConversation = await getConversation(conversationId);
let retries = 0;
const maxRetries = 5;

// this retry is a workaround for the stream not YET being persisted to the stored conversation
while (
isStreamRefetch &&
updatedConversation &&
updatedConversation.messages[updatedConversation.messages.length - 1].role !==
'assistant' &&
retries < maxRetries
) {
retries++;
await sleep(2000);
updatedConversation = await getConversation(conversationId);
}

if (updatedConversation) {
setCurrentConversation(updatedConversation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ export interface AssistantProviderProps {
currentConversation?: Conversation;
isEnabledLangChain: boolean;
isFetchingResponse: boolean;
refetchCurrentConversation: () => void;
refetchCurrentConversation: ({ isStreamRefetch }: { isStreamRefetch?: boolean }) => void;
regenerateMessage: (conversationId: string) => void;
showAnonymizedValues: boolean;
setIsStreaming: (isStreaming: boolean) => void;
Expand Down Expand Up @@ -108,7 +108,7 @@ export interface UseAssistantContext {
currentConversation?: Conversation;
isEnabledLangChain: boolean;
isFetchingResponse: boolean;
refetchCurrentConversation: () => void;
refetchCurrentConversation: ({ isStreamRefetch }: { isStreamRefetch?: boolean }) => void;
regenerateMessage: () => void;
showAnonymizedValues: boolean;
currentUserAvatar?: UserAvatar;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import { ActionsClientSimpleChatModel } from './simple_chat_model';
import { mockActionResponse } from './mocks';
import { BaseMessage } from '@langchain/core/messages';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import { parseBedrockStream } from '../utils/bedrock';
import { parseGeminiStream } from '../utils/gemini';
import { parseBedrockStream, parseBedrockStreamAsAsyncIterator } from '../utils/bedrock';
import { parseGeminiStream, parseGeminiStreamAsAsyncIterator } from '../utils/gemini';

const connectorId = 'mock-connector-id';

Expand Down Expand Up @@ -301,5 +301,119 @@ describe('ActionsClientSimpleChatModel', () => {
expect(handleLLMNewToken).toHaveBeenCalledTimes(1);
expect(handleLLMNewToken).toHaveBeenCalledWith('token6');
});
it('extra tokens in the final answer start chunk get pushed to handleLLMNewToken', async () => {
(parseBedrockStream as jest.Mock).mockImplementation((_1, _2, _3, handleToken) => {
handleToken('token1');
handleToken(`"action":`);
handleToken(`"Final Answer"`);
handleToken(`, "action_input": "token5 `);
handleToken('token6');
});
actionsClient.execute.mockImplementationOnce(mockStreamExecute);

const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({
...defaultArgs,
actionsClient,
llmType: 'bedrock',
streaming: true,
});
await actionsClientSimpleChatModel._call(callMessages, callOptions, callRunManager);
expect(handleLLMNewToken).toHaveBeenCalledTimes(2);
expect(handleLLMNewToken).toHaveBeenCalledWith('token5 ');
expect(handleLLMNewToken).toHaveBeenCalledWith('token6');
});
it('extra tokens in the final answer end chunk get pushed to handleLLMNewToken', async () => {
(parseBedrockStream as jest.Mock).mockImplementation((_1, _2, _3, handleToken) => {
handleToken('token5');
handleToken(`"action":`);
handleToken(`"Final Answer"`);
handleToken(`, "action_input": "`);
handleToken('token6');
handleToken('token7"');
handleToken('token8');
});
actionsClient.execute.mockImplementationOnce(mockStreamExecute);
const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({
...defaultArgs,
actionsClient,
llmType: 'bedrock',
streaming: true,
});
await actionsClientSimpleChatModel._call(callMessages, callOptions, callRunManager);
expect(handleLLMNewToken).toHaveBeenCalledTimes(2);
expect(handleLLMNewToken).toHaveBeenCalledWith('token6');
expect(handleLLMNewToken).toHaveBeenCalledWith('token7');
});
});

describe('*_streamResponseChunks', () => {
it('iterates over bedrock chunks', async () => {
function* mockFetchData() {
yield 'token1';
yield 'token2';
yield 'token3';
}
(parseBedrockStreamAsAsyncIterator as jest.Mock).mockImplementation(mockFetchData);
actionsClient.execute.mockImplementationOnce(mockStreamExecute);

const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({
...defaultArgs,
actionsClient,
llmType: 'bedrock',
streaming: true,
});

const gen = actionsClientSimpleChatModel._streamResponseChunks(
callMessages,
callOptions,
callRunManager
);

const chunks = [];

for await (const chunk of gen) {
chunks.push(chunk);
}

expect(chunks.map((c) => c.text)).toEqual(['token1', 'token2', 'token3']);
expect(handleLLMNewToken).toHaveBeenCalledTimes(3);
expect(handleLLMNewToken).toHaveBeenCalledWith('token1');
expect(handleLLMNewToken).toHaveBeenCalledWith('token2');
expect(handleLLMNewToken).toHaveBeenCalledWith('token3');
});
it('iterates over gemini chunks', async () => {
function* mockFetchData() {
yield 'token1';
yield 'token2';
yield 'token3';
}
(parseGeminiStreamAsAsyncIterator as jest.Mock).mockImplementation(mockFetchData);
actionsClient.execute.mockImplementationOnce(mockStreamExecute);

const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({
...defaultArgs,
actionsClient,
llmType: 'gemini',
streaming: true,
});

const gen = actionsClientSimpleChatModel._streamResponseChunks(
callMessages,
callOptions,
callRunManager
);

const chunks = [];

for await (const chunk of gen) {
chunks.push(chunk);
}

expect(chunks.map((c) => c.text)).toEqual(['token1', 'token2', 'token3']);
expect(handleLLMNewToken).toHaveBeenCalledTimes(3);
expect(handleLLMNewToken).toHaveBeenCalledWith('token1');
expect(handleLLMNewToken).toHaveBeenCalledWith('token2');
expect(handleLLMNewToken).toHaveBeenCalledWith('token3');
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@ import {
SimpleChatModel,
type BaseChatModelParams,
} from '@langchain/core/language_models/chat_models';
import { type BaseMessage } from '@langchain/core/messages';
import { AIMessageChunk, type BaseMessage } from '@langchain/core/messages';
import type { ActionsClient } from '@kbn/actions-plugin/server';
import { Logger } from '@kbn/logging';
import { v4 as uuidv4 } from 'uuid';
import { get } from 'lodash/fp';
import { ChatGenerationChunk } from '@langchain/core/outputs';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import { PublicMethodsOf } from '@kbn/utility-types';
import { parseGeminiStream } from '../utils/gemini';
import { parseBedrockStream } from '../utils/bedrock';
import { parseGeminiStreamAsAsyncIterator, parseGeminiStream } from '../utils/gemini';
import { parseBedrockStreamAsAsyncIterator, parseBedrockStream } from '../utils/bedrock';
import { getDefaultArguments } from './constants';

export const getMessageContentAndRole = (prompt: string, role = 'user') => ({
Expand All @@ -38,6 +39,18 @@ export interface CustomChatModelInput extends BaseChatModelParams {
maxTokens?: number;
}

function _formatMessages(messages: BaseMessage[]) {
if (!messages.length) {
throw new Error('No messages provided.');
}
return messages.map((message, i) => {
if (typeof message.content !== 'string') {
throw new Error('Multimodal messages are not supported.');
}
return getMessageContentAndRole(message.content, message._getType());
});
}

export class ActionsClientSimpleChatModel extends SimpleChatModel {
#actionsClient: PublicMethodsOf<ActionsClient>;
#connectorId: string;
Expand Down Expand Up @@ -91,16 +104,7 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
options: this['ParsedCallOptions'],
runManager?: CallbackManagerForLLMRun
): Promise<string> {
if (!messages.length) {
throw new Error('No messages provided.');
}
const formattedMessages: Array<{ content: string; role: string }> = [];
messages.forEach((message, i) => {
if (typeof message.content !== 'string') {
throw new Error('Multimodal messages are not supported.');
}
formattedMessages.push(getMessageContentAndRole(message.content, message._getType()));
});
const formattedMessages = _formatMessages(messages);
this.#logger.debug(
() =>
`ActionsClientSimpleChatModel#_call\ntraceId: ${
Expand Down Expand Up @@ -150,18 +154,30 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
let finalOutputIndex = -1;
const finalOutputStartToken = '"action":"FinalAnswer","action_input":"';
let streamingFinished = false;
const finalOutputStopRegex = /(?<!\\)\"/;
const finalOutputStopRegex = /(?<!\\)"/;
const handleLLMNewToken = async (token: string) => {
if (finalOutputIndex === -1) {
currentOutput += token;
// Remove whitespace to simplify parsing
currentOutput += token.replace(/\s/g, '');
if (currentOutput.includes(finalOutputStartToken)) {
finalOutputIndex = currentOutput.indexOf(finalOutputStartToken);
const noWhitespaceOutput = currentOutput.replace(/\s/g, '');
if (noWhitespaceOutput.includes(finalOutputStartToken)) {
const nonStrippedToken = '"action_input": "';
finalOutputIndex = currentOutput.indexOf(nonStrippedToken);
const contentStartIndex = finalOutputIndex + nonStrippedToken.length;
const extraOutput = currentOutput.substring(contentStartIndex);
if (extraOutput.length > 0) {
await runManager?.handleLLMNewToken(extraOutput);
}
}
} else if (!streamingFinished) {
const finalOutputEndIndex = token.search(finalOutputStopRegex);
if (finalOutputEndIndex !== -1) {
streamingFinished = true;
const extraOutput = token.substring(0, finalOutputEndIndex);
streamingFinished = true;
if (extraOutput.length > 0) {
await runManager?.handleLLMNewToken(extraOutput);
}
} else {
await runManager?.handleLLMNewToken(token);
}
Expand All @@ -173,4 +189,55 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {

return parsed; // per the contact of _call, return a string
}

async *_streamResponseChunks(
messages: BaseMessage[],
options: this['ParsedCallOptions'],
runManager?: CallbackManagerForLLMRun | undefined
): AsyncGenerator<ChatGenerationChunk> {
const formattedMessages = _formatMessages(messages);
this.#logger.debug(
() =>
`ActionsClientSimpleChatModel#stream\ntraceId: ${
this.#traceId
}\nassistantMessage:\n${JSON.stringify(formattedMessages)} `
);
// create a new connector request body with the assistant message:
const requestBody = {
actionId: this.#connectorId,
params: {
subAction: 'invokeStream',
subActionParams: {
model: this.model,
messages: formattedMessages,
...getDefaultArguments(this.llmType, this.temperature, options.stop, this.#maxTokens),
},
},
};
const actionResult = await this.#actionsClient.execute(requestBody);

if (actionResult.status === 'error') {
throw new Error(
`ActionsClientSimpleChatModel: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}`
);
}

const readable = get('data', actionResult) as Readable;

if (typeof readable?.read !== 'function') {
throw new Error('Action result status is error: result is not streamable');
}

const streamParser =
this.llmType === 'bedrock'
? parseBedrockStreamAsAsyncIterator
: parseGeminiStreamAsAsyncIterator;
for await (const token of streamParser(readable, this.#logger, this.#signal)) {
yield new ChatGenerationChunk({
message: new AIMessageChunk({ content: token }),
text: token,
});
await runManager?.handleLLMNewToken(token);
}
}
}
25 changes: 25 additions & 0 deletions x-pack/packages/kbn-langchain/server/utils/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,37 @@
* 2.0.
*/

import { Readable } from 'stream';
import { finished } from 'stream/promises';
import { Logger } from '@kbn/core/server';
import { EventStreamCodec } from '@smithy/eventstream-codec';
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';
import { StreamParser } from './types';

export const parseBedrockStreamAsAsyncIterator = async function* (
responseStream: Readable,
logger: Logger,
abortSignal?: AbortSignal
) {
if (abortSignal) {
abortSignal.addEventListener('abort', () => {
responseStream.destroy(new Error('Aborted'));
});
}
try {
for await (const chunk of responseStream) {
const bedrockChunk = handleBedrockChunk({ chunk, bedrockBuffer: new Uint8Array(0), logger });
yield bedrockChunk.decodedChunk;
}
} catch (err) {
if (abortSignal?.aborted) {
logger.info('Bedrock stream parsing was aborted.');
} else {
throw err;
}
}
};

export const parseBedrockStream: StreamParser = async (
responseStream,
logger,
Expand Down
Loading

0 comments on commit d5843b3

Please sign in to comment.