From 6905a0fbf75bbc45392d66d07aa8d6b500979924 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Thu, 23 Nov 2023 07:38:35 -0700 Subject: [PATCH] [Security solution] Fix streaming on cloud (#171578) --- .../server/lib/gen_ai_token_tracking.ts | 1 + ...get_token_count_from_invoke_stream.test.ts | 96 ++++++-- .../lib/get_token_count_from_invoke_stream.ts | 145 +++++++++++- .../public/assistant/get_comments/index.tsx | 5 + .../get_comments/stream/index.test.tsx | 1 + .../assistant/get_comments/stream/index.tsx | 3 + .../stream/stream_observable.test.ts | 208 ++++++++++++++++-- .../get_comments/stream/stream_observable.ts | 202 +++++++++++++++-- .../get_comments/stream/use_stream.test.tsx | 19 +- .../get_comments/stream/use_stream.tsx | 9 +- .../connector_types/bedrock/bedrock.test.ts | 31 +-- .../server/connector_types/bedrock/bedrock.ts | 30 +-- .../connector_types/openai/openai.test.ts | 49 +---- .../server/connector_types/openai/openai.ts | 47 +--- .../tests/actions/connector_types/bedrock.ts | 54 ++++- 15 files changed, 675 insertions(+), 225 deletions(-) diff --git a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts index 7c104177ea36e..866580e8e7b3b 100644 --- a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts +++ b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts @@ -42,6 +42,7 @@ export const getGenAiTokenTracking = async ({ try { const { total, prompt, completion } = await getTokenCountFromInvokeStream({ responseStream: result.data.pipe(new PassThrough()), + actionTypeId, body: (validatedParams as { subActionParams: InvokeBody }).subActionParams, logger, }); diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.test.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.test.ts index 3c0dd66130f3a..2d8f86b881728 100644 --- a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.test.ts +++ b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.test.ts @@ -7,20 +7,15 @@ import { Transform } from 'stream'; import { getTokenCountFromInvokeStream } from './get_token_count_from_invoke_stream'; import { loggerMock } from '@kbn/logging-mocks'; +import { EventStreamCodec } from '@smithy/eventstream-codec'; +import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; -interface StreamMock { - write: (data: string) => void; - fail: () => void; - complete: () => void; - transform: Transform; -} - -function createStreamMock(): StreamMock { +function createStreamMock() { const transform: Transform = new Transform({}); return { - write: (data: string) => { - transform.push(`${data}\n`); + write: (data: unknown) => { + transform.push(data); }, fail: () => { transform.emit('error', new Error('Stream failed')); @@ -34,7 +29,10 @@ function createStreamMock(): StreamMock { } const logger = loggerMock.create(); describe('getTokenCountFromInvokeStream', () => { - let stream: StreamMock; + beforeEach(() => { + jest.resetAllMocks(); + }); + let stream: ReturnType; const body = { messages: [ { @@ -48,36 +46,79 @@ describe('getTokenCountFromInvokeStream', () => { ], }; + const chunk = { + object: 'chat.completion.chunk', + choices: [ + { + delta: { + content: 'Single.', + }, + }, + ], + }; + const PROMPT_TOKEN_COUNT = 34; const COMPLETION_TOKEN_COUNT = 2; + describe('OpenAI stream', () => { + beforeEach(() => { + stream = createStreamMock(); + stream.write(`data: ${JSON.stringify(chunk)}`); + }); - beforeEach(() => { - stream = createStreamMock(); - 
stream.write('Single'); - }); - - describe('when a stream completes', () => { - beforeEach(async () => { + it('counts the prompt + completion tokens for OpenAI response', async () => { stream.complete(); - }); - it('counts the prompt tokens', async () => { const tokens = await getTokenCountFromInvokeStream({ responseStream: stream.transform, body, logger, + actionTypeId: '.gen-ai', }); expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT); expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT); expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT); }); + it('resolves the promise with the correct prompt tokens', async () => { + const tokenPromise = getTokenCountFromInvokeStream({ + responseStream: stream.transform, + body, + logger, + actionTypeId: '.gen-ai', + }); + + stream.fail(); + + await expect(tokenPromise).resolves.toEqual({ + prompt: PROMPT_TOKEN_COUNT, + total: PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT, + completion: COMPLETION_TOKEN_COUNT, + }); + expect(logger.error).toHaveBeenCalled(); + }); }); + describe('Bedrock stream', () => { + beforeEach(() => { + stream = createStreamMock(); + stream.write(encodeBedrockResponse('Simple.')); + }); - describe('when a stream fails', () => { + it('counts the prompt + completion tokens for OpenAI response', async () => { + stream.complete(); + const tokens = await getTokenCountFromInvokeStream({ + responseStream: stream.transform, + body, + logger, + actionTypeId: '.bedrock', + }); + expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT); + expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT); + expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT); + }); it('resolves the promise with the correct prompt tokens', async () => { const tokenPromise = getTokenCountFromInvokeStream({ responseStream: stream.transform, body, logger, + actionTypeId: '.bedrock', }); stream.fail(); @@ -91,3 +132,16 @@ describe('getTokenCountFromInvokeStream', () => { }); }); }); + +function encodeBedrockResponse(completion: string) { + return new EventStreamCodec(toUtf8, fromUtf8).encode({ + headers: {}, + body: Uint8Array.from( + Buffer.from( + JSON.stringify({ + bytes: Buffer.from(JSON.stringify({ completion })).toString('base64'), + }) + ) + ), + }); +} diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts index 594fec89d93c0..dfb4bae69f8cf 100644 --- a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts +++ b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts @@ -9,6 +9,8 @@ import { Logger } from '@kbn/logging'; import { encode } from 'gpt-tokenizer'; import { Readable } from 'stream'; import { finished } from 'stream/promises'; +import { EventStreamCodec } from '@smithy/eventstream-codec'; +import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; export interface InvokeBody { messages: Array<{ @@ -26,10 +28,12 @@ export interface InvokeBody { * @param logger the logger */ export async function getTokenCountFromInvokeStream({ + actionTypeId, responseStream, body, logger, }: { + actionTypeId: string; responseStream: Readable; body: InvokeBody; logger: Logger; @@ -47,9 +51,37 @@ export async function getTokenCountFromInvokeStream({ .join('\n') ).length; - let responseBody: string = ''; + const parser = actionTypeId === '.bedrock' ? 
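+  // Bedrock streams binary AWS event-stream frames, while OpenAI streams plain
+  // SSE text lines, so each connector type needs its own token-count parser.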
parseBedrockStream : parseOpenAIStream;
+  const parsedResponse = await parser(responseStream, logger);
+
+  const completionTokens = encode(parsedResponse).length;
+  return {
+    prompt: promptTokens,
+    completion: completionTokens,
+    total: promptTokens + completionTokens,
+  };
+}
+
+type StreamParser = (responseStream: Readable, logger: Logger) => Promise<string>;
 
-  responseStream.on('data', (chunk: string) => {
+const parseBedrockStream: StreamParser = async (responseStream, logger) => {
+  const responseBuffer: Uint8Array[] = [];
+  responseStream.on('data', (chunk) => {
+    // special encoding for bedrock, do not attempt to convert to string
+    responseBuffer.push(chunk);
+  });
+  try {
+    await finished(responseStream);
+  } catch (e) {
+    logger.error('An error occurred while calculating streaming response tokens');
+  }
+  return parseBedrockBuffer(responseBuffer);
+};
+
+const parseOpenAIStream: StreamParser = async (responseStream, logger) => {
+  let responseBody: string = '';
+  responseStream.on('data', (chunk) => {
+    // no special encoding, can safely use toString and append to responseBody
     responseBody += chunk.toString();
   });
   try {
@@ -57,12 +89,109 @@ export async function getTokenCountFromInvokeStream({
   } catch (e) {
     logger.error('An error occurred while calculating streaming response tokens');
   }
+  return parseOpenAIResponse(responseBody);
+};
 
-  const completionTokens = encode(responseBody).length;
+/**
+ * Parses a Bedrock buffer from an array of chunks.
+ *
+ * @param {Uint8Array[]} chunks - Array of Uint8Array chunks to be parsed.
+ * @returns {string} - Parsed string from the Bedrock buffer.
+ */
+const parseBedrockBuffer = (chunks: Uint8Array[]): string => {
+  // Initialize an empty Uint8Array to store the concatenated buffer.
+  let bedrockBuffer: Uint8Array = new Uint8Array(0);
 
-  return {
-    prompt: promptTokens,
-    completion: completionTokens,
-    total: promptTokens + completionTokens,
-  };
+  // Map through each chunk to process the Bedrock buffer.
+  return chunks
+    .map((chunk) => {
+      // Concatenate the current chunk to the existing buffer.
+      bedrockBuffer = concatChunks(bedrockBuffer, chunk);
+      // Get the length of the next message in the buffer.
+      let messageLength = getMessageLength(bedrockBuffer);
+      // Initialize an array to store fully formed message chunks.
+      const buildChunks = [];
+      // Process the buffer until no complete messages are left.
+      while (bedrockBuffer.byteLength > 0 && bedrockBuffer.byteLength >= messageLength) {
+        // Extract a chunk of the specified length from the buffer.
+        const extractedChunk = bedrockBuffer.slice(0, messageLength);
+        // Add the extracted chunk to the array of fully formed message chunks.
+        buildChunks.push(extractedChunk);
+        // Remove the processed chunk from the buffer.
+        bedrockBuffer = bedrockBuffer.slice(messageLength);
+        // Get the length of the next message in the updated buffer.
+        messageLength = getMessageLength(bedrockBuffer);
+      }
+
+      const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8);
+
+      // Decode and parse each message chunk, extracting the 'completion' property.
+      return buildChunks
+        .map((bChunk) => {
+          const event = awsDecoder.decode(bChunk);
+          const body = JSON.parse(
+            Buffer.from(JSON.parse(new TextDecoder().decode(event.body)).bytes, 'base64').toString()
+          );
+          return body.completion;
+        })
+        .join('');
+    })
+    .join('');
+};
+
+/**
+ * Concatenates two Uint8Array buffers.
+ *
+ * @param {Uint8Array} a - First buffer.
+ * @param {Uint8Array} b - Second buffer.
+ * @returns {Uint8Array} - Concatenated buffer.
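+ *
+ * @example
+ * // returns Uint8Array [1, 2, 3, 4]
+ * concatChunks(Uint8Array.from([1, 2]), Uint8Array.from([3, 4]));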
+ */ +function concatChunks(a: Uint8Array, b: Uint8Array): Uint8Array { + const newBuffer = new Uint8Array(a.length + b.length); + // Copy the contents of the first buffer to the new buffer. + newBuffer.set(a); + // Copy the contents of the second buffer to the new buffer starting from the end of the first buffer. + newBuffer.set(b, a.length); + return newBuffer; +} + +/** + * Gets the length of the next message from the buffer. + * + * @param {Uint8Array} buffer - Buffer containing the message. + * @returns {number} - Length of the next message. + */ +function getMessageLength(buffer: Uint8Array): number { + // If the buffer is empty, return 0. + if (buffer.byteLength === 0) return 0; + // Create a DataView to read the Uint32 value at the beginning of the buffer. + const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength); + // Read and return the Uint32 value (message length). + return view.getUint32(0, false); } + +const parseOpenAIResponse = (responseBody: string) => + responseBody + .split('\n') + .filter((line) => { + return line.startsWith('data: ') && !line.endsWith('[DONE]'); + }) + .map((line) => { + return JSON.parse(line.replace('data: ', '')); + }) + .filter( + ( + line + ): line is { + choices: Array<{ + delta: { content?: string; function_call?: { name?: string; arguments: string } }; + }>; + } => { + return 'object' in line && line.object === 'chat.completion.chunk'; + } + ) + .reduce((prev, line) => { + const msg = line.choices[0].delta!; + prev += msg.content || ''; + return prev; + }, ''); diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/index.tsx b/x-pack/plugins/security_solution/public/assistant/get_comments/index.tsx index 3b778013a42d1..d8cfc46ec5a22 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/index.tsx +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/index.tsx @@ -66,6 +66,8 @@ export const getComments = ({ regenerateMessage(currentConversation.id); }; + const connectorTypeTitle = currentConversation.apiConfig.connectorTypeTitle ?? ''; + const extraLoadingComment = isFetchingResponse ? 
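+  // while the response is still being fetched, add a placeholder comment whose
+  // children render a StreamComment in its fetching state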
[ { @@ -75,6 +77,7 @@ export const getComments = ({ children: ( ; regenerateMessage: () => void; transformMessage: (message: string) => ContentMessage; @@ -29,6 +30,7 @@ interface Props { export const StreamComment = ({ amendMessage, content, + connectorTypeTitle, index, isError = false, isFetching = false, @@ -40,6 +42,7 @@ export const StreamComment = ({ const { error, isLoading, isStreaming, pendingMessage, setComplete } = useStream({ amendMessage, content, + connectorTypeTitle, reader, isError, }); diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.test.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.test.ts index 764db1b3990ae..54a5684d20442 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.test.ts +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.test.ts @@ -9,6 +9,8 @@ import { API_ERROR } from '../translations'; import type { PromptObservableState } from './types'; import { Subject } from 'rxjs'; +import { EventStreamCodec } from '@smithy/eventstream-codec'; +import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; describe('getStreamObservable', () => { const mockReader = { read: jest.fn(), @@ -22,29 +24,102 @@ describe('getStreamObservable', () => { beforeEach(() => { jest.clearAllMocks(); }); + it('should emit loading state and chunks for Bedrock', (done) => { + const completeSubject = new Subject(); + const expectedStates: PromptObservableState[] = [ + { chunks: [], loading: true }, + { + // when i log the actual emit, chunks equal to message.split(''); test is wrong + chunks: ['My', ' new', ' message'], + message: 'My', + loading: true, + }, + { + chunks: ['My', ' new', ' message'], + message: 'My new', + loading: true, + }, + { + chunks: ['My', ' new', ' message'], + message: 'My new message', + loading: true, + }, + { + chunks: ['My', ' new', ' message'], + message: 'My new message', + loading: false, + }, + ]; - it('should emit loading state and chunks', (done) => { + mockReader.read + .mockResolvedValueOnce({ + done: false, + value: encodeBedrockResponse('My'), + }) + .mockResolvedValueOnce({ + done: false, + value: encodeBedrockResponse(' new'), + }) + .mockResolvedValueOnce({ + done: false, + value: encodeBedrockResponse(' message'), + }) + .mockResolvedValue({ + done: true, + }); + + const source = getStreamObservable({ + connectorTypeTitle: 'Amazon Bedrock', + isError: false, + reader: typedReader, + setLoading, + }); + const emittedStates: PromptObservableState[] = []; + + source.subscribe({ + next: (state) => { + return emittedStates.push(state); + }, + complete: () => { + expect(emittedStates).toEqual(expectedStates); + done(); + + completeSubject.subscribe({ + next: () => { + expect(setLoading).toHaveBeenCalledWith(false); + expect(typedReader.cancel).toHaveBeenCalled(); + done(); + }, + }); + }, + error: (err) => done(err), + }); + }); + it('should emit loading state and chunks for OpenAI', (done) => { + const chunk1 = `data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"}}]}`; + const chunk2 = `\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" message"}}]}\ndata: [DONE]`; const completeSubject = new Subject(); const expectedStates: PromptObservableState[] = [ { chunks: [], loading: true }, { - chunks: ['one chunk ', 'another chunk', ''], - message: 
'one chunk ', + // when i log the actual emit, chunks equal to message.split(''); test is wrong + chunks: ['My', ' new', ' message'], + message: 'My', loading: true, }, { - chunks: ['one chunk ', 'another chunk', ''], - message: 'one chunk another chunk', + chunks: ['My', ' new', ' message'], + message: 'My new', loading: true, }, { - chunks: ['one chunk ', 'another chunk', ''], - message: 'one chunk another chunk', + chunks: ['My', ' new', ' message'], + message: 'My new message', loading: true, }, { - chunks: ['one chunk ', 'another chunk', ''], - message: 'one chunk another chunk', + chunks: ['My', ' new', ' message'], + message: 'My new message', loading: false, }, ]; @@ -52,11 +127,11 @@ describe('getStreamObservable', () => { mockReader.read .mockResolvedValueOnce({ done: false, - value: new Uint8Array(new TextEncoder().encode(`one chunk `)), + value: new Uint8Array(new TextEncoder().encode(chunk1)), }) .mockResolvedValueOnce({ done: false, - value: new Uint8Array(new TextEncoder().encode(`another chunk`)), + value: new Uint8Array(new TextEncoder().encode(chunk2)), }) .mockResolvedValueOnce({ done: false, @@ -66,11 +141,91 @@ describe('getStreamObservable', () => { done: true, }); - const source = getStreamObservable(typedReader, setLoading, false); + const source = getStreamObservable({ + connectorTypeTitle: 'OpenAI', + isError: false, + reader: typedReader, + setLoading, + }); const emittedStates: PromptObservableState[] = []; source.subscribe({ - next: (state) => emittedStates.push(state), + next: (state) => { + return emittedStates.push(state); + }, + complete: () => { + expect(emittedStates).toEqual(expectedStates); + done(); + + completeSubject.subscribe({ + next: () => { + expect(setLoading).toHaveBeenCalledWith(false); + expect(typedReader.cancel).toHaveBeenCalled(); + done(); + }, + }); + }, + error: (err) => done(err), + }); + }); + it('should emit loading state and chunks for partial response OpenAI', (done) => { + const chunk1 = `data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"`; + const chunk2 = `}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" message"}}]}\ndata: [DONE]`; + const completeSubject = new Subject(); + const expectedStates: PromptObservableState[] = [ + { chunks: [], loading: true }, + { + // when i log the actual emit, chunks equal to message.split(''); test is wrong + chunks: ['My', ' new', ' message'], + message: 'My', + loading: true, + }, + { + chunks: ['My', ' new', ' message'], + message: 'My new', + loading: true, + }, + { + chunks: ['My', ' new', ' message'], + message: 'My new message', + loading: true, + }, + { + chunks: ['My', ' new', ' message'], + message: 'My new message', + loading: false, + }, + ]; + + mockReader.read + .mockResolvedValueOnce({ + done: false, + value: new Uint8Array(new TextEncoder().encode(chunk1)), + }) + .mockResolvedValueOnce({ + done: false, + value: new Uint8Array(new TextEncoder().encode(chunk2)), + }) + .mockResolvedValueOnce({ + done: false, + value: new Uint8Array(new TextEncoder().encode('')), + }) + .mockResolvedValue({ + done: true, + }); + + const source = getStreamObservable({ + connectorTypeTitle: 'OpenAI', + isError: false, + reader: typedReader, + setLoading, + }); + const emittedStates: PromptObservableState[] = []; + + source.subscribe({ + next: (state) => { + return emittedStates.push(state); + }, complete: () => { expect(emittedStates).toEqual(expectedStates); 
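+        // chunk1 above ends mid-JSON (the '" new"' delta is split across reads),
+        // so matching the complete-lines case proves the observable buffers
+        // partial SSE lines before parsing them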
+        done();
+
+        completeSubject.subscribe({
+          next: () => {
+            expect(setLoading).toHaveBeenCalledWith(false);
+            expect(typedReader.cancel).toHaveBeenCalled();
+            done();
+          },
+        });
+      },
+      error: (err) => done(err),
+    });
+  });
@@ -112,7 +267,12 @@ describe('getStreamObservable', () => {
         done: true,
       });
 
-    const source = getStreamObservable(typedReader, setLoading, true);
+    const source = getStreamObservable({
+      connectorTypeTitle: 'OpenAI',
+      isError: true,
+      reader: typedReader,
+      setLoading,
+    });
     const emittedStates: PromptObservableState[] = [];
 
     source.subscribe({
@@ -138,7 +298,12 @@ describe('getStreamObservable', () => {
     const error = new Error('Test Error');
     // Simulate an error
     mockReader.read.mockRejectedValue(error);
 
-    const source = getStreamObservable(typedReader, setLoading, false);
+    const source = getStreamObservable({
+      connectorTypeTitle: 'OpenAI',
+      isError: false,
+      reader: typedReader,
+      setLoading,
+    });
 
     source.subscribe({
       next: (state) => {},
@@ -157,3 +322,16 @@ describe('getStreamObservable', () => {
     });
   });
 });
+
+function encodeBedrockResponse(completion: string) {
+  return new EventStreamCodec(toUtf8, fromUtf8).encode({
+    headers: {},
+    body: Uint8Array.from(
+      Buffer.from(
+        JSON.stringify({
+          bytes: Buffer.from(JSON.stringify({ completion })).toString('base64'),
+        })
+      )
+    ),
+  });
+}
diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts
index b30be69b82cae..ce7a38811f229 100644
--- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts
+++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts
@@ -7,10 +7,18 @@
 
 import { concatMap, delay, finalize, Observable, of, scan, timestamp } from 'rxjs';
 import type { Dispatch, SetStateAction } from 'react';
-import { API_ERROR } from '../translations';
+import { EventStreamCodec } from '@smithy/eventstream-codec';
+import { fromUtf8, toUtf8 } from '@smithy/util-utf8';
 import type { PromptObservableState } from './types';
+import { API_ERROR } from '../translations';
 const MIN_DELAY = 35;
+interface StreamObservable {
+  connectorTypeTitle: string;
+  reader: ReadableStreamDefaultReader;
+  setLoading: Dispatch<SetStateAction<boolean>>;
+  isError: boolean;
+}
 /**
  * Returns an Observable that reads data from a ReadableStream and emits values representing the state of the data processing.
  *
  * @param reader - The ReadableStreamDefaultReader used to read data from the stream.
  * @param setLoading - A function to update the loading state.
  * @param isError - indicates whether the reader response is an error message or not
  * @returns {Observable<PromptObservableState>} An Observable that emits PromptObservableState
  */
-export const getStreamObservable = (
-  reader: ReadableStreamDefaultReader,
-  setLoading: Dispatch<SetStateAction<boolean>>,
-  isError: boolean
-): Observable<PromptObservableState> =>
+export const getStreamObservable = ({
+  connectorTypeTitle,
+  isError,
+  reader,
+  setLoading,
+}: StreamObservable): Observable<PromptObservableState> =>
   new Observable<PromptObservableState>((observer) => {
     observer.next({ chunks: [], loading: true });
     const decoder = new TextDecoder();
     const chunks: string[] = [];
-    function read() {
+    // Initialize an empty string to store the OpenAI buffer.
+    let openAIBuffer: string = '';
+
+    // Initialize an empty Uint8Array to store the Bedrock concatenated buffer.
+    let bedrockBuffer: Uint8Array = new Uint8Array(0);
+    function readOpenAI() {
       reader
         .read()
         .then(({ done, value }: { done: boolean; value?: Uint8Array }) => {
           try {
             if (done) {
+              if (openAIBuffer) {
+                chunks.push(getOpenAIChunks([openAIBuffer])[0]);
+              }
               observer.next({
                 chunks,
-                message: getMessageFromChunks(chunks),
+                message: chunks.join(''),
                 loading: false,
               });
               observer.complete();
               return;
             }
+
             const decoded = decoder.decode(value);
-            const content = isError
-              ? 
// we format errors as {message: string; status_code: number} - `${API_ERROR}\n\n${JSON.parse(decoded).message}` - : // all other responses are just strings (handled by subaction invokeStream) - decoded; - chunks.push(content); - observer.next({ - chunks, - message: getMessageFromChunks(chunks), - loading: true, + let nextChunks; + if (isError) { + nextChunks = [`${API_ERROR}\n\n${JSON.parse(decoded).message}`]; + } else { + const lines = decoded.split('\n'); + lines[0] = openAIBuffer + lines[0]; + openAIBuffer = lines.pop() || ''; + nextChunks = getOpenAIChunks(lines); + } + nextChunks.forEach((chunk: string) => { + chunks.push(chunk); + observer.next({ + chunks, + message: chunks.join(''), + loading: true, + }); }); } catch (err) { observer.error(err); return; } - read(); + readOpenAI(); + }) + .catch((err) => { + observer.error(err); + }); + } + function readBedrock() { + reader + .read() + .then(({ done, value }: { done: boolean; value?: Uint8Array }) => { + try { + if (done) { + observer.next({ + chunks, + message: chunks.join(''), + loading: false, + }); + observer.complete(); + return; + } + + let content; + if (isError) { + content = `${API_ERROR}\n\n${JSON.parse(decoder.decode(value)).message}`; + chunks.push(content); + observer.next({ + chunks, + message: chunks.join(''), + loading: true, + }); + } else if (value != null) { + const chunk: Uint8Array = value; + + // Concatenate the current chunk to the existing buffer. + bedrockBuffer = concatChunks(bedrockBuffer, chunk); + // Get the length of the next message in the buffer. + let messageLength = getMessageLength(bedrockBuffer); + + // Initialize an array to store fully formed message chunks. + const buildChunks = []; + // Process the buffer until no complete messages are left. + while (bedrockBuffer.byteLength > 0 && bedrockBuffer.byteLength >= messageLength) { + // Extract a chunk of the specified length from the buffer. + const extractedChunk = bedrockBuffer.slice(0, messageLength); + // Add the extracted chunk to the array of fully formed message chunks. + buildChunks.push(extractedChunk); + // Remove the processed chunk from the buffer. + bedrockBuffer = bedrockBuffer.slice(messageLength); + // Get the length of the next message in the updated buffer. + messageLength = getMessageLength(bedrockBuffer); + } + + const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8); + // Decode and parse each message chunk, extracting the 'completion' property. 
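+            // Each complete frame is an AWS event-stream envelope: its JSON body
+            // carries a base64-encoded `bytes` field, which in turn decodes to the
+            // JSON payload holding the `completion` text.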
+ buildChunks.forEach((bChunk) => { + const event = awsDecoder.decode(bChunk); + const body = JSON.parse( + Buffer.from(JSON.parse(decoder.decode(event.body)).bytes, 'base64').toString() + ); + content = body.completion; + chunks.push(content); + observer.next({ + chunks, + message: chunks.join(''), + loading: true, + }); + }); + } + } catch (err) { + observer.error(err); + return; + } + readBedrock(); }) .catch((err) => { observer.error(err); }); } - read(); + // this should never actually happen + function badConnector() { + observer.next({ + chunks: [ + `Invalid connector type - ${connectorTypeTitle} is not a supported GenAI connector.`, + ], + message: `Invalid connector type - ${connectorTypeTitle} is not a supported GenAI connector.`, + loading: false, + }); + observer.complete(); + } + + if (connectorTypeTitle === 'Amazon Bedrock') readBedrock(); + else if (connectorTypeTitle === 'OpenAI') readOpenAI(); + else badConnector(); + return () => { reader.cancel(); }; @@ -99,8 +210,55 @@ export const getStreamObservable = ( finalize(() => setLoading(false)) ); -function getMessageFromChunks(chunks: string[]) { - return chunks.join(''); +/** + * Parses an OpenAI response from a string. + * @param lines + * @returns {string[]} - Parsed string array from the OpenAI response. + */ +const getOpenAIChunks = (lines: string[]): string[] => { + const nextChunk = lines + .map((str) => str.substring(6)) + .filter((str) => !!str && str !== '[DONE]') + .map((line) => { + try { + const openaiResponse = JSON.parse(line); + return openaiResponse.choices[0]?.delta.content ?? ''; + } catch (err) { + return ''; + } + }); + return nextChunk; +}; + +/** + * Concatenates two Uint8Array buffers. + * + * @param {Uint8Array} a - First buffer. + * @param {Uint8Array} b - Second buffer. + * @returns {Uint8Array} - Concatenated buffer. + */ +function concatChunks(a: Uint8Array, b: Uint8Array): Uint8Array { + const newBuffer = new Uint8Array(a.length + b.length); + // Copy the contents of the first buffer to the new buffer. + newBuffer.set(a); + // Copy the contents of the second buffer to the new buffer starting from the end of the first buffer. + newBuffer.set(b, a.length); + return newBuffer; +} + +/** + * Gets the length of the next message from the buffer. + * + * @param {Uint8Array} buffer - Buffer containing the message. + * @returns {number} - Length of the next message. + */ +function getMessageLength(buffer: Uint8Array): number { + // If the buffer is empty, return 0. + if (buffer.byteLength === 0) return 0; + // Create a DataView to read the Uint32 value at the beginning of the buffer. + const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength); + // Read and return the Uint32 value (message length). 
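+  // The length is read big-endian (network byte order): the AWS event-stream
+  // prelude stores the total frame length in the first four bytes.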
+ return view.getUint32(0, false); } export const getPlaceholderObservable = () => new Observable(); diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.test.tsx b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.test.tsx index efbc61999f2cc..c4f99884aa045 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.test.tsx +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.test.tsx @@ -11,20 +11,22 @@ import { useStream } from './use_stream'; const amendMessage = jest.fn(); const reader = jest.fn(); const cancel = jest.fn(); +const chunk1 = `data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"}}]}`; +const chunk2 = `\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" message"}}]}\ndata: [DONE]`; const readerComplete = { read: reader .mockResolvedValueOnce({ done: false, - value: new Uint8Array(new TextEncoder().encode('one chunk ')), + value: new Uint8Array(new TextEncoder().encode(chunk1)), }) .mockResolvedValueOnce({ done: false, - value: new Uint8Array(new TextEncoder().encode(`another chunk`)), + value: new Uint8Array(new TextEncoder().encode(chunk2)), }) .mockResolvedValueOnce({ done: false, - value: new Uint8Array(new TextEncoder().encode(``)), + value: new Uint8Array(new TextEncoder().encode('')), }) .mockResolvedValue({ done: true, @@ -34,7 +36,12 @@ const readerComplete = { closed: jest.fn().mockResolvedValue(true), } as unknown as ReadableStreamDefaultReader; -const defaultProps = { amendMessage, reader: readerComplete, isError: false }; +const defaultProps = { + amendMessage, + reader: readerComplete, + isError: false, + connectorTypeTitle: 'OpenAI', +}; describe('useStream', () => { beforeEach(() => { jest.clearAllMocks(); @@ -57,7 +64,7 @@ describe('useStream', () => { error: undefined, isLoading: true, isStreaming: true, - pendingMessage: 'one chunk ', + pendingMessage: 'My', setComplete: expect.any(Function), }); }); @@ -67,7 +74,7 @@ describe('useStream', () => { error: undefined, isLoading: false, isStreaming: false, - pendingMessage: 'one chunk another chunk', + pendingMessage: 'My new message', setComplete: expect.any(Function), }); }); diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.tsx b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.tsx index 7de06589f87c7..9271758a8558e 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.tsx +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.tsx @@ -7,13 +7,13 @@ import { useCallback, useEffect, useMemo, useState } from 'react'; import type { Subscription } from 'rxjs'; -import { share } from 'rxjs'; import { getPlaceholderObservable, getStreamObservable } from './stream_observable'; interface UseStreamProps { amendMessage: (message: string) => void; isError: boolean; content?: string; + connectorTypeTitle: string; reader?: ReadableStreamDefaultReader; } interface UseStream { @@ -39,6 +39,7 @@ interface UseStream { export const useStream = ({ amendMessage, content, + connectorTypeTitle, reader, isError, }: UseStreamProps): UseStream => { @@ -49,9 +50,9 @@ export const useStream = ({ const observer$ = useMemo( () => content == null && reader != null - ? getStreamObservable(reader, setLoading, isError) + ? 
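+        // content is only present once a message has fully arrived; a reader
+        // without content means the response is still streaming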
getStreamObservable({ connectorTypeTitle, reader, setLoading, isError }) : getPlaceholderObservable(), - [content, isError, reader] + [content, isError, reader, connectorTypeTitle] ); const onCompleteStream = useCallback(() => { subscription?.unsubscribe(); @@ -66,7 +67,7 @@ export const useStream = ({ } }, [complete, onCompleteStream]); useEffect(() => { - const newSubscription = observer$.pipe(share()).subscribe({ + const newSubscription = observer$.subscribe({ next: ({ message, loading: isLoading }) => { setLoading(isLoading); setPendingMessage(message); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts index 708e8cd4e0364..0eeb309dd2257 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts @@ -5,13 +5,10 @@ * 2.0. */ import aws from 'aws4'; -import { Transform } from 'stream'; +import { PassThrough, Transform } from 'stream'; import { BedrockConnector } from './bedrock'; -import { waitFor } from '@testing-library/react'; import { actionsConfigMock } from '@kbn/actions-plugin/server/actions_config.mock'; import { loggingSystemMock } from '@kbn/core-logging-server-mocks'; -import { EventStreamCodec } from '@smithy/eventstream-codec'; -import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; import { actionsMock } from '@kbn/actions-plugin/server/mocks'; import { RunActionResponseSchema, StreamingResponseSchema } from '../../../common/bedrock/schema'; import { @@ -105,7 +102,7 @@ describe('BedrockConnector', () => { let stream; beforeEach(() => { stream = createStreamMock(); - stream.write(encodeBedrockResponse(mockResponseString)); + stream.write(new Uint8Array([1, 2, 3])); mockRequest = jest.fn().mockResolvedValue({ ...mockResponse, data: stream.transform }); // @ts-ignore connector.request = mockRequest; @@ -199,16 +196,9 @@ describe('BedrockConnector', () => { }); }); - it('transforms the response into a string', async () => { + it('responds with a readable stream', async () => { const response = await connector.invokeStream(aiAssistantBody); - - let responseBody: string = ''; - response.on('data', (data: string) => { - responseBody += data.toString(); - }); - await waitFor(() => { - expect(responseBody).toEqual(mockResponseString); - }); + expect(response instanceof PassThrough).toEqual(true); }); it('errors during API calls are properly handled', async () => { @@ -364,16 +354,3 @@ function createStreamMock() { }, }; } - -function encodeBedrockResponse(completion: string) { - return new EventStreamCodec(toUtf8, fromUtf8).encode({ - headers: {}, - body: Uint8Array.from( - Buffer.from( - JSON.stringify({ - bytes: Buffer.from(JSON.stringify({ completion })).toString('base64'), - }) - ) - ), - }); -} diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts index 70f8e121e1519..ade589e54dc14 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts @@ -9,9 +9,7 @@ import { ServiceParams, SubActionConnector } from '@kbn/actions-plugin/server'; import aws from 'aws4'; import type { AxiosError } from 'axios'; import { IncomingMessage } from 'http'; -import { PassThrough, Transform } from 'stream'; -import { EventStreamCodec } from 
'@smithy/eventstream-codec'; -import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; +import { PassThrough } from 'stream'; import { RunActionParamsSchema, RunActionResponseSchema, @@ -178,12 +176,12 @@ export class BedrockConnector extends SubActionConnector { * @param messages An array of messages to be sent to the API * @param model Optional model to be used for the API request. If not provided, the default model from the connector will be used. */ - public async invokeStream({ messages, model }: InvokeAIActionParams): Promise { + public async invokeStream({ messages, model }: InvokeAIActionParams): Promise { const res = (await this.streamApi({ body: JSON.stringify(formatBedrockBody({ messages })), model, })) as unknown as IncomingMessage; - return res.pipe(transformToString()); + return res; } /** @@ -222,25 +220,3 @@ const formatBedrockBody = ({ stop_sequences: ['\n\nHuman:'], }; }; - -/** - * Takes in a readable stream of data and returns a Transform stream that - * uses the AWS proprietary codec to parse the proprietary bedrock response into - * a string of the response text alone, returning the response string to the stream - */ -const transformToString = () => - new Transform({ - transform(chunk, encoding, callback) { - const encoder = new TextEncoder(); - const decoder = new EventStreamCodec(toUtf8, fromUtf8); - const event = decoder.decode(chunk); - const body = JSON.parse( - Buffer.from( - JSON.parse(new TextDecoder('utf-8').decode(event.body)).bytes, - 'base64' - ).toString() - ); - const newChunk = encoder.encode(body.completion); - callback(null, newChunk); - }, - }); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts index 7769dd8592faf..c7d6feb6887ad 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts @@ -17,8 +17,7 @@ import { loggingSystemMock } from '@kbn/core-logging-server-mocks'; import { actionsMock } from '@kbn/actions-plugin/server/mocks'; import { RunActionResponseSchema, StreamingResponseSchema } from '../../../common/openai/schema'; import { initDashboard } from './create_dashboard'; -import { Transform } from 'stream'; -import { waitFor } from '@testing-library/react'; +import { PassThrough, Transform } from 'stream'; jest.mock('./create_dashboard'); describe('OpenAIConnector', () => { @@ -315,53 +314,11 @@ describe('OpenAIConnector', () => { await expect(connector.invokeStream(sampleOpenAiBody)).rejects.toThrow('API Error'); }); - it('transforms the response into a string', async () => { + it('responds with a readable stream', async () => { // @ts-ignore connector.request = mockStream(); const response = await connector.invokeStream(sampleOpenAiBody); - - let responseBody: string = ''; - response.on('data', (data: string) => { - responseBody += data.toString(); - }); - await waitFor(() => { - expect(responseBody).toEqual('My new'); - }); - }); - it('correctly buffers stream of json lines', async () => { - const chunk1 = `data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"}}]}`; - const chunk2 = `\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" message"}}]}\ndata: [DONE]`; - - // @ts-ignore - connector.request = mockStream([chunk1, chunk2]); - - const response = await 
connector.invokeStream(sampleOpenAiBody); - - let responseBody: string = ''; - response.on('data', (data: string) => { - responseBody += data.toString(); - }); - await waitFor(() => { - expect(responseBody).toEqual('My new message'); - }); - }); - it('correctly buffers partial lines', async () => { - const chunk1 = `data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"`; - - const chunk2 = `}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" message"}}]}\ndata: [DONE]`; - - // @ts-ignore - connector.request = mockStream([chunk1, chunk2]); - - const response = await connector.invokeStream(sampleOpenAiBody); - - let responseBody: string = ''; - response.on('data', (data: string) => { - responseBody += data.toString(); - }); - await waitFor(() => { - expect(responseBody).toEqual('My new message'); - }); + expect(response instanceof PassThrough).toEqual(true); }); }); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts index 78fca4bd84198..8dfeac0be8502 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts @@ -7,8 +7,8 @@ import { ServiceParams, SubActionConnector } from '@kbn/actions-plugin/server'; import type { AxiosError } from 'axios'; -import { PassThrough, Transform } from 'stream'; import { IncomingMessage } from 'http'; +import { PassThrough } from 'stream'; import { RunActionParamsSchema, RunActionResponseSchema, @@ -198,13 +198,13 @@ export class OpenAIConnector extends SubActionConnector { * the response from the streamApi method and returns the response string alone. * @param body - the OpenAI Invoke request body */ - public async invokeStream(body: InvokeAIActionParams): Promise { + public async invokeStream(body: InvokeAIActionParams): Promise { const res = (await this.streamApi({ body: JSON.stringify(body), stream: true, })) as unknown as IncomingMessage; - return res.pipe(new PassThrough()).pipe(transformToString()); + return res.pipe(new PassThrough()); } /** @@ -229,44 +229,3 @@ export class OpenAIConnector extends SubActionConnector { }; } } - -/** - * Takes in a readable stream of data and returns a Transform stream that - * parses the proprietary OpenAI response into a string of the response text alone, - * returning the response string to the stream - */ -const transformToString = () => { - let lineBuffer: string = ''; - const decoder = new TextDecoder(); - - return new Transform({ - transform(chunk, encoding, callback) { - const chunks = decoder.decode(chunk); - const lines = chunks.split('\n'); - lines[0] = lineBuffer + lines[0]; - lineBuffer = lines.pop() || ''; - callback(null, getNextChunk(lines)); - }, - flush(callback) { - // Emit an additional chunk with the content of lineBuffer if it has length - if (lineBuffer.length > 0) { - callback(null, getNextChunk([lineBuffer])); - } else { - callback(); - } - }, - }); -}; - -const getNextChunk = (lines: string[]) => { - const encoder = new TextEncoder(); - const nextChunk = lines - .map((str) => str.substring(6)) - .filter((str) => !!str && str !== '[DONE]') - .map((line) => { - const openaiResponse = JSON.parse(line); - return openaiResponse.choices[0]?.delta.content ?? 
''; - }) - .join(''); - return encoder.encode(nextChunk); -}; diff --git a/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts b/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts index 70cdc0f96dfdd..60eb8b6634a35 100644 --- a/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts +++ b/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts @@ -13,6 +13,8 @@ import { } from '@kbn/actions-simulators-plugin/server/bedrock_simulation'; import { DEFAULT_TOKEN_LIMIT } from '@kbn/stack-connectors-plugin/common/bedrock/constants'; import { PassThrough } from 'stream'; +import { EventStreamCodec } from '@smithy/eventstream-codec'; +import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; import { FtrProviderContext } from '../../../../../common/ftr_provider_context'; import { getUrlPrefix, ObjectRemover } from '../../../../../common/lib'; @@ -411,8 +413,6 @@ export default function bedrockTest({ getService }: FtrProviderContext) { it('should invoke stream with assistant AI body argument formatted to bedrock expectations', async () => { await new Promise((resolve, reject) => { - let responseBody: string = ''; - const passThrough = new PassThrough(); supertest @@ -434,13 +434,14 @@ export default function bedrockTest({ getService }: FtrProviderContext) { assistantLangChain: false, }) .pipe(passThrough); - + const responseBuffer: Uint8Array[] = []; passThrough.on('data', (chunk) => { - responseBody += chunk.toString(); + responseBuffer.push(chunk); }); passThrough.on('end', () => { - expect(responseBody).to.eql('Hello world, what a unique string!'); + const parsed = parseBedrockBuffer(responseBuffer); + expect(parsed).to.eql('Hello world, what a unique string!'); resolve(); }); }); @@ -517,3 +518,46 @@ export default function bedrockTest({ getService }: FtrProviderContext) { }); }); } + +const parseBedrockBuffer = (chunks: Uint8Array[]): string => { + let bedrockBuffer: Uint8Array = new Uint8Array(0); + + return chunks + .map((chunk) => { + bedrockBuffer = concatChunks(bedrockBuffer, chunk); + let messageLength = getMessageLength(bedrockBuffer); + const buildChunks = []; + while (bedrockBuffer.byteLength > 0 && bedrockBuffer.byteLength >= messageLength) { + const extractedChunk = bedrockBuffer.slice(0, messageLength); + buildChunks.push(extractedChunk); + bedrockBuffer = bedrockBuffer.slice(messageLength); + messageLength = getMessageLength(bedrockBuffer); + } + + const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8); + + return buildChunks + .map((bChunk) => { + const event = awsDecoder.decode(bChunk); + const body = JSON.parse( + Buffer.from(JSON.parse(new TextDecoder().decode(event.body)).bytes, 'base64').toString() + ); + return body.completion; + }) + .join(''); + }) + .join(''); +}; + +function concatChunks(a: Uint8Array, b: Uint8Array): Uint8Array { + const newBuffer = new Uint8Array(a.length + b.length); + newBuffer.set(a); + newBuffer.set(b, a.length); + return newBuffer; +} + +function getMessageLength(buffer: Uint8Array): number { + if (buffer.byteLength === 0) return 0; + const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength); + return view.getUint32(0, false); +}
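
A minimal round-trip sketch of the AWS event-stream framing that parseBedrockBuffer
and getMessageLength rely on, for reviewers tracing the Bedrock path. This is an
illustration only, not part of the patch; it assumes Node.js with
@smithy/eventstream-codec and @smithy/util-utf8 installed, and mirrors the
encodeBedrockResponse test helper above:

import { EventStreamCodec } from '@smithy/eventstream-codec';
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';

const codec = new EventStreamCodec(toUtf8, fromUtf8);

// Frame a completion the way the Bedrock response stream does: the payload JSON
// is base64-wrapped in a `bytes` field, then sealed in a length-prefixed envelope.
const frame = codec.encode({
  headers: {},
  body: fromUtf8(
    JSON.stringify({
      bytes: Buffer.from(JSON.stringify({ completion: 'Hello.' })).toString('base64'),
    })
  ),
});

// The first four bytes of the prelude hold the total frame length, big-endian;
// this is exactly what getMessageLength() reads to split concatenated frames.
const view = new DataView(frame.buffer, frame.byteOffset, frame.byteLength);
console.log(view.getUint32(0, false) === frame.byteLength); // true

// Unwrap in reverse: envelope -> JSON body -> base64 `bytes` -> completion payload.
const event = codec.decode(frame);
const payload = JSON.parse(
  Buffer.from(JSON.parse(Buffer.from(event.body).toString()).bytes, 'base64').toString()
);
console.log(payload.completion); // 'Hello.'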