Skip to content

Commit

Permalink
[Security solution] Fix streaming on cloud (elastic#171578)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephmilovic authored Nov 23, 2023
1 parent 5f9e70c commit 6905a0f
Show file tree
Hide file tree
Showing 15 changed files with 675 additions and 225 deletions.
1 change: 1 addition & 0 deletions x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export const getGenAiTokenTracking = async ({
try {
const { total, prompt, completion } = await getTokenCountFromInvokeStream({
responseStream: result.data.pipe(new PassThrough()),
actionTypeId,
body: (validatedParams as { subActionParams: InvokeBody }).subActionParams,
logger,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,15 @@
import { Transform } from 'stream';
import { getTokenCountFromInvokeStream } from './get_token_count_from_invoke_stream';
import { loggerMock } from '@kbn/logging-mocks';
import { EventStreamCodec } from '@smithy/eventstream-codec';
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';

interface StreamMock {
write: (data: string) => void;
fail: () => void;
complete: () => void;
transform: Transform;
}

function createStreamMock(): StreamMock {
function createStreamMock() {
const transform: Transform = new Transform({});

return {
write: (data: string) => {
transform.push(`${data}\n`);
write: (data: unknown) => {
transform.push(data);
},
fail: () => {
transform.emit('error', new Error('Stream failed'));
Expand All @@ -34,7 +29,10 @@ function createStreamMock(): StreamMock {
}
const logger = loggerMock.create();
describe('getTokenCountFromInvokeStream', () => {
let stream: StreamMock;
beforeEach(() => {
jest.resetAllMocks();
});
let stream: ReturnType<typeof createStreamMock>;
const body = {
messages: [
{
Expand All @@ -48,36 +46,79 @@ describe('getTokenCountFromInvokeStream', () => {
],
};

const chunk = {
object: 'chat.completion.chunk',
choices: [
{
delta: {
content: 'Single.',
},
},
],
};

const PROMPT_TOKEN_COUNT = 34;
const COMPLETION_TOKEN_COUNT = 2;
describe('OpenAI stream', () => {
beforeEach(() => {
stream = createStreamMock();
stream.write(`data: ${JSON.stringify(chunk)}`);
});

beforeEach(() => {
stream = createStreamMock();
stream.write('Single');
});

describe('when a stream completes', () => {
beforeEach(async () => {
it('counts the prompt + completion tokens for OpenAI response', async () => {
stream.complete();
});
it('counts the prompt tokens', async () => {
const tokens = await getTokenCountFromInvokeStream({
responseStream: stream.transform,
body,
logger,
actionTypeId: '.gen-ai',
});
expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT);
expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT);
expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT);
});
it('resolves the promise with the correct prompt tokens', async () => {
const tokenPromise = getTokenCountFromInvokeStream({
responseStream: stream.transform,
body,
logger,
actionTypeId: '.gen-ai',
});

stream.fail();

await expect(tokenPromise).resolves.toEqual({
prompt: PROMPT_TOKEN_COUNT,
total: PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT,
completion: COMPLETION_TOKEN_COUNT,
});
expect(logger.error).toHaveBeenCalled();
});
});
describe('Bedrock stream', () => {
beforeEach(() => {
stream = createStreamMock();
stream.write(encodeBedrockResponse('Simple.'));
});

describe('when a stream fails', () => {
it('counts the prompt + completion tokens for OpenAI response', async () => {
stream.complete();
const tokens = await getTokenCountFromInvokeStream({
responseStream: stream.transform,
body,
logger,
actionTypeId: '.bedrock',
});
expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT);
expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT);
expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT);
});
it('resolves the promise with the correct prompt tokens', async () => {
const tokenPromise = getTokenCountFromInvokeStream({
responseStream: stream.transform,
body,
logger,
actionTypeId: '.bedrock',
});

stream.fail();
Expand All @@ -91,3 +132,16 @@ describe('getTokenCountFromInvokeStream', () => {
});
});
});

function encodeBedrockResponse(completion: string) {
return new EventStreamCodec(toUtf8, fromUtf8).encode({
headers: {},
body: Uint8Array.from(
Buffer.from(
JSON.stringify({
bytes: Buffer.from(JSON.stringify({ completion })).toString('base64'),
})
)
),
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import { Logger } from '@kbn/logging';
import { encode } from 'gpt-tokenizer';
import { Readable } from 'stream';
import { finished } from 'stream/promises';
import { EventStreamCodec } from '@smithy/eventstream-codec';
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';

export interface InvokeBody {
messages: Array<{
Expand All @@ -26,10 +28,12 @@ export interface InvokeBody {
* @param logger the logger
*/
export async function getTokenCountFromInvokeStream({
actionTypeId,
responseStream,
body,
logger,
}: {
actionTypeId: string;
responseStream: Readable;
body: InvokeBody;
logger: Logger;
Expand All @@ -47,22 +51,147 @@ export async function getTokenCountFromInvokeStream({
.join('\n')
).length;

let responseBody: string = '';
const parser = actionTypeId === '.bedrock' ? parseBedrockStream : parseOpenAIStream;
const parsedResponse = await parser(responseStream, logger);

const completionTokens = encode(parsedResponse).length;
return {
prompt: promptTokens,
completion: completionTokens,
total: promptTokens + completionTokens,
};
}

type StreamParser = (responseStream: Readable, logger: Logger) => Promise<string>;

responseStream.on('data', (chunk: string) => {
const parseBedrockStream: StreamParser = async (responseStream, logger) => {
const responseBuffer: Uint8Array[] = [];
responseStream.on('data', (chunk) => {
// special encoding for bedrock, do not attempt to convert to string
responseBuffer.push(chunk);
});
try {
await finished(responseStream);
} catch (e) {
logger.error('An error occurred while calculating streaming response tokens');
}
return parseBedrockBuffer(responseBuffer);
};

const parseOpenAIStream: StreamParser = async (responseStream, logger) => {
let responseBody: string = '';
responseStream.on('data', (chunk) => {
// no special encoding, can safely use toString and append to responseBody
responseBody += chunk.toString();
});
try {
await finished(responseStream);
} catch (e) {
logger.error('An error occurred while calculating streaming response tokens');
}
return parseOpenAIResponse(responseBody);
};

const completionTokens = encode(responseBody).length;
/**
* Parses a Bedrock buffer from an array of chunks.
*
* @param {Uint8Array[]} chunks - Array of Uint8Array chunks to be parsed.
* @returns {string} - Parsed string from the Bedrock buffer.
*/
const parseBedrockBuffer = (chunks: Uint8Array[]): string => {
// Initialize an empty Uint8Array to store the concatenated buffer.
let bedrockBuffer: Uint8Array = new Uint8Array(0);

return {
prompt: promptTokens,
completion: completionTokens,
total: promptTokens + completionTokens,
};
// Map through each chunk to process the Bedrock buffer.
return chunks
.map((chunk) => {
// Concatenate the current chunk to the existing buffer.
bedrockBuffer = concatChunks(bedrockBuffer, chunk);
// Get the length of the next message in the buffer.
let messageLength = getMessageLength(bedrockBuffer);
// Initialize an array to store fully formed message chunks.
const buildChunks = [];
// Process the buffer until no complete messages are left.
while (bedrockBuffer.byteLength > 0 && bedrockBuffer.byteLength >= messageLength) {
// Extract a chunk of the specified length from the buffer.
const extractedChunk = bedrockBuffer.slice(0, messageLength);
// Add the extracted chunk to the array of fully formed message chunks.
buildChunks.push(extractedChunk);
// Remove the processed chunk from the buffer.
bedrockBuffer = bedrockBuffer.slice(messageLength);
// Get the length of the next message in the updated buffer.
messageLength = getMessageLength(bedrockBuffer);
}

const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8);

// Decode and parse each message chunk, extracting the 'completion' property.
return buildChunks
.map((bChunk) => {
const event = awsDecoder.decode(bChunk);
const body = JSON.parse(
Buffer.from(JSON.parse(new TextDecoder().decode(event.body)).bytes, 'base64').toString()
);
return body.completion;
})
.join('');
})
.join('');
};

/**
* Concatenates two Uint8Array buffers.
*
* @param {Uint8Array} a - First buffer.
* @param {Uint8Array} b - Second buffer.
* @returns {Uint8Array} - Concatenated buffer.
*/
function concatChunks(a: Uint8Array, b: Uint8Array): Uint8Array {
const newBuffer = new Uint8Array(a.length + b.length);
// Copy the contents of the first buffer to the new buffer.
newBuffer.set(a);
// Copy the contents of the second buffer to the new buffer starting from the end of the first buffer.
newBuffer.set(b, a.length);
return newBuffer;
}

/**
* Gets the length of the next message from the buffer.
*
* @param {Uint8Array} buffer - Buffer containing the message.
* @returns {number} - Length of the next message.
*/
function getMessageLength(buffer: Uint8Array): number {
// If the buffer is empty, return 0.
if (buffer.byteLength === 0) return 0;
// Create a DataView to read the Uint32 value at the beginning of the buffer.
const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
// Read and return the Uint32 value (message length).
return view.getUint32(0, false);
}

const parseOpenAIResponse = (responseBody: string) =>
responseBody
.split('\n')
.filter((line) => {
return line.startsWith('data: ') && !line.endsWith('[DONE]');
})
.map((line) => {
return JSON.parse(line.replace('data: ', ''));
})
.filter(
(
line
): line is {
choices: Array<{
delta: { content?: string; function_call?: { name?: string; arguments: string } };
}>;
} => {
return 'object' in line && line.object === 'chat.completion.chunk';
}
)
.reduce((prev, line) => {
const msg = line.choices[0].delta!;
prev += msg.content || '';
return prev;
}, '');
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ export const getComments = ({
regenerateMessage(currentConversation.id);
};

const connectorTypeTitle = currentConversation.apiConfig.connectorTypeTitle ?? '';

const extraLoadingComment = isFetchingResponse
? [
{
Expand All @@ -75,6 +77,7 @@ export const getComments = ({
children: (
<StreamComment
amendMessage={amendMessageOfConversation}
connectorTypeTitle={connectorTypeTitle}
content=""
regenerateMessage={regenerateMessageOfConversation}
isLastComment
Expand Down Expand Up @@ -122,6 +125,7 @@ export const getComments = ({
children: (
<StreamComment
amendMessage={amendMessageOfConversation}
connectorTypeTitle={connectorTypeTitle}
index={index}
isLastComment={isLastComment}
isError={message.isError}
Expand All @@ -142,6 +146,7 @@ export const getComments = ({
children: (
<StreamComment
amendMessage={amendMessageOfConversation}
connectorTypeTitle={connectorTypeTitle}
content={transformedMessage.content}
index={index}
isLastComment={isLastComment}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const testProps = {
content,
index: 1,
isLastComment: true,
connectorTypeTitle: 'OpenAI',
regenerateMessage: jest.fn(),
transformMessage: jest.fn(),
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ interface Props {
isFetching?: boolean;
isLastComment: boolean;
index: number;
connectorTypeTitle: string;
reader?: ReadableStreamDefaultReader<Uint8Array>;
regenerateMessage: () => void;
transformMessage: (message: string) => ContentMessage;
Expand All @@ -29,6 +30,7 @@ interface Props {
export const StreamComment = ({
amendMessage,
content,
connectorTypeTitle,
index,
isError = false,
isFetching = false,
Expand All @@ -40,6 +42,7 @@ export const StreamComment = ({
const { error, isLoading, isStreaming, pendingMessage, setComplete } = useStream({
amendMessage,
content,
connectorTypeTitle,
reader,
isError,
});
Expand Down
Loading

0 comments on commit 6905a0f

Please sign in to comment.