From d20161030d81e7bde30af63d921a95a300c6b3ca Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Thu, 16 Nov 2023 15:49:53 -0700 Subject: [PATCH] [Security solution] Bedrock streaming and token tracking (#170815) --- package.json | 2 + .../impl/assistant/api.test.tsx | 51 ++++- .../impl/assistant/api.tsx | 26 ++- .../actions/server/lib/action_executor.ts | 76 +++---- .../server/lib/gen_ai_token_tracking.test.ts | 183 +++++++++++++++++ .../server/lib/gen_ai_token_tracking.ts | 123 ++++++++++++ ...et_token_count_from_bedrock_invoke.test.ts | 27 +++ .../get_token_count_from_bedrock_invoke.ts | 46 +++++ ...get_token_count_from_invoke_stream.test.ts | 93 +++++++++ .../lib/get_token_count_from_invoke_stream.ts | 68 +++++++ ...get_token_count_from_openai_stream.test.ts | 6 + .../lib/get_token_count_from_openai_stream.ts | 7 +- .../elastic_assistant/server/lib/executor.ts | 10 +- .../public/assistant/get_comments/index.tsx | 1 + .../assistant/get_comments/stream/index.tsx | 3 + .../stream/stream_observable.test.ts | 107 ++++++---- .../get_comments/stream/stream_observable.ts | 48 ++--- .../assistant/get_comments/stream/types.ts | 14 +- .../get_comments/stream/use_stream.test.tsx | 30 +-- .../get_comments/stream/use_stream.tsx | 13 +- .../assistant/get_comments/translations.ts | 4 + .../common/bedrock/constants.ts | 1 + .../stack_connectors/common/bedrock/schema.ts | 7 + .../stack_connectors/common/bedrock/types.ts | 4 + .../common/openai/constants.ts | 1 + .../connector_types/bedrock/bedrock.test.ts | 185 +++++++++++++++++- .../server/connector_types/bedrock/bedrock.ts | 149 +++++++++++--- .../connector_types/openai/openai.test.ts | 115 +++++++++++ .../server/connector_types/openai/openai.ts | 70 ++++++- .../server/bedrock_simulation.ts | 31 +++ .../tests/actions/connector_types/bedrock.ts | 38 ++++ yarn.lock | 88 ++++++++- 32 files changed, 1403 insertions(+), 224 deletions(-) create mode 100644 x-pack/plugins/actions/server/lib/gen_ai_token_tracking.test.ts create mode 100644 x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts create mode 100644 x-pack/plugins/actions/server/lib/get_token_count_from_bedrock_invoke.test.ts create mode 100644 x-pack/plugins/actions/server/lib/get_token_count_from_bedrock_invoke.ts create mode 100644 x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.test.ts create mode 100644 x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts diff --git a/package.json b/package.json index 2bc3b7c7e1fb9..fd9c91a03d08a 100644 --- a/package.json +++ b/package.json @@ -842,6 +842,8 @@ "@opentelemetry/semantic-conventions": "^1.4.0", "@reduxjs/toolkit": "1.7.2", "@slack/webhook": "^5.0.4", + "@smithy/eventstream-codec": "^2.0.12", + "@smithy/util-utf8": "^2.0.0", "@tanstack/react-query": "^4.29.12", "@tanstack/react-query-devtools": "^4.29.12", "@turf/along": "6.0.1", diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx index b1d9145a9e612..b8ee12525c68a 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx @@ -75,19 +75,20 @@ describe('API tests', () => { expect(mockHttp.fetch).toHaveBeenCalledWith( '/internal/elastic_assistant/actions/connector/foo/_execute', { - body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a 
test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"},"assistantLangChain":false}', - headers: { 'Content-Type': 'application/json' }, + body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeStream"},"assistantLangChain":false}', method: 'POST', + asResponse: true, + rawResponse: true, signal: undefined, } ); }); - it('returns API_ERROR when the response status is not ok', async () => { + it('returns API_ERROR when the response status is error and langchain is on', async () => { (mockHttp.fetch as jest.Mock).mockResolvedValue({ status: 'error' }); const testProps: FetchConnectorExecuteAction = { - assistantLangChain: false, + assistantLangChain: true, http: mockHttp, messages, apiConfig, @@ -98,10 +99,50 @@ describe('API tests', () => { expect(result).toEqual({ response: API_ERROR, isStream: false, isError: true }); }); + it('returns API_ERROR when the response status is error, langchain is off, and response is not a reader', async () => { + (mockHttp.fetch as jest.Mock).mockResolvedValue({ status: 'error' }); + + const testProps: FetchConnectorExecuteAction = { + assistantLangChain: false, + http: mockHttp, + messages, + apiConfig, + }; + + const result = await fetchConnectorExecuteAction(testProps); + + expect(result).toEqual({ + response: `${API_ERROR}\n\nCould not get reader from response`, + isStream: false, + isError: true, + }); + }); + + it('returns API_ERROR when the response is error, langchain is off, and response is a reader', async () => { + const mockReader = jest.fn(); + (mockHttp.fetch as jest.Mock).mockRejectedValue({ + response: { body: { getReader: jest.fn().mockImplementation(() => mockReader) } }, + }); + const testProps: FetchConnectorExecuteAction = { + assistantLangChain: false, + http: mockHttp, + messages, + apiConfig, + }; + + const result = await fetchConnectorExecuteAction(testProps); + + expect(result).toEqual({ + response: mockReader, + isStream: true, + isError: true, + }); + }); + it('returns API_ERROR when there are no choices', async () => { (mockHttp.fetch as jest.Mock).mockResolvedValue({ status: 'ok', data: '' }); const testProps: FetchConnectorExecuteAction = { - assistantLangChain: false, + assistantLangChain: true, http: mockHttp, messages, apiConfig, diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx index 69e6d39d85e11..f92585cbdd011 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx @@ -54,17 +54,16 @@ export const fetchConnectorExecuteAction = async ({ messages: outboundMessages, }; - // TODO: Remove in part 2 of streaming work for security solution + // TODO: Remove in part 3 of streaming work for security solution // tracked here: https://github.com/elastic/security-team/issues/7363 - // My "Feature Flag", turn to false before merging - // In part 2 I will make enhancements to invokeAI to make it work with both openA, but to keep it to a Security Soltuion only review on this PR, - // I'm calling the stream action directly - const isStream = !assistantLangChain && false; + // In part 3 I will make enhancements to langchain to introduce streaming + // Once implemented, invokeAI can be removed + const isStream = !assistantLangChain; const requestBody = isStream ? 
{ params: { subActionParams: body, - subAction: 'stream', + subAction: 'invokeStream', }, assistantLangChain, } @@ -105,7 +104,7 @@ export const fetchConnectorExecuteAction = async ({ }; } - // TODO: Remove in part 2 of streaming work for security solution + // TODO: Remove in part 3 of streaming work for security solution // tracked here: https://github.com/elastic/security-team/issues/7363 // This is a temporary code to support the non-streaming API const response = await http.fetch<{ @@ -140,10 +139,19 @@ export const fetchConnectorExecuteAction = async ({ isStream: false, }; } catch (error) { + const reader = error?.response?.body?.getReader(); + + if (!reader) { + return { + response: `${API_ERROR}\n\n${error?.body?.message ?? error?.message}`, + isError: true, + isStream: false, + }; + } return { - response: `${API_ERROR}\n\n${error?.body?.message ?? error?.message}`, + response: reader, + isStream: true, isError: true, - isStream: false, }; } }; diff --git a/x-pack/plugins/actions/server/lib/action_executor.ts b/x-pack/plugins/actions/server/lib/action_executor.ts index 9cd70d4c7bf91..b27e2c1ee79c7 100644 --- a/x-pack/plugins/actions/server/lib/action_executor.ts +++ b/x-pack/plugins/actions/server/lib/action_executor.ts @@ -8,12 +8,13 @@ import type { PublicMethodsOf } from '@kbn/utility-types'; import { Logger, KibanaRequest } from '@kbn/core/server'; import { cloneDeep } from 'lodash'; +import { set } from '@kbn/safer-lodash-set'; import { withSpan } from '@kbn/apm-utils'; import { EncryptedSavedObjectsClient } from '@kbn/encrypted-saved-objects-plugin/server'; import { SpacesServiceStart } from '@kbn/spaces-plugin/server'; import { IEventLogger, SAVED_OBJECT_REL_PRIMARY } from '@kbn/event-log-plugin/server'; import { SecurityPluginStart } from '@kbn/security-plugin/server'; -import { PassThrough, Readable } from 'stream'; +import { getGenAiTokenTracking, shouldTrackGenAiToken } from './gen_ai_token_tracking'; import { validateParams, validateConfig, @@ -38,7 +39,6 @@ import { RelatedSavedObjects } from './related_saved_objects'; import { createActionEventLogRecordObject } from './create_action_event_log_record_object'; import { ActionExecutionError, ActionExecutionErrorReason } from './errors/action_execution_error'; import type { ActionsAuthorization } from '../authorization/actions_authorization'; -import { getTokenCountFromOpenAIStream } from './get_token_count_from_openai_stream'; // 1,000,000 nanoseconds in 1 millisecond const Millis2Nanos = 1000 * 1000; @@ -328,55 +328,33 @@ export class ActionExecutor { eventLogger.logEvent(event); } - // start openai extension - // add event.kibana.action.execution.openai to event log when OpenAI Connector is executed - if (result.status === 'ok' && actionTypeId === '.gen-ai') { - const data = result.data as unknown as { - usage: { prompt_tokens?: number; completion_tokens?: number; total_tokens?: number }; - }; - event.kibana = event.kibana || {}; - event.kibana.action = event.kibana.action || {}; - event.kibana = { - ...event.kibana, - action: { - ...event.kibana.action, - execution: { - ...event.kibana.action.execution, - gen_ai: { - usage: { - total_tokens: data.usage?.total_tokens, - prompt_tokens: data.usage?.prompt_tokens, - completion_tokens: data.usage?.completion_tokens, - }, - }, - }, - }, - }; - - if (result.data instanceof Readable) { - getTokenCountFromOpenAIStream({ - responseStream: result.data.pipe(new PassThrough()), - body: (validatedParams as { subActionParams: { body: string } }).subActionParams.body, + // start 
genai extension + if (result.status === 'ok' && shouldTrackGenAiToken(actionTypeId)) { + getGenAiTokenTracking({ + actionTypeId, + logger, + result, + validatedParams, + }) + .then((tokenTracking) => { + if (tokenTracking != null) { + set(event, 'kibana.action.execution.gen_ai.usage', { + total_tokens: tokenTracking.total_tokens, + prompt_tokens: tokenTracking.prompt_tokens, + completion_tokens: tokenTracking.completion_tokens, + }); + } }) - .then(({ total, prompt, completion }) => { - event.kibana!.action!.execution!.gen_ai!.usage = { - total_tokens: total, - prompt_tokens: prompt, - completion_tokens: completion, - }; - }) - .catch((err) => { - logger.error('Failed to calculate tokens from streaming response'); - logger.error(err); - }) - .finally(() => { - completeEventLogging(); - }); - - return resultWithoutError; - } + .catch((err) => { + logger.error('Failed to calculate tokens from streaming response'); + logger.error(err); + }) + .finally(() => { + completeEventLogging(); + }); + return resultWithoutError; } - // end openai extension + // end genai extension completeEventLogging(); diff --git a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.test.ts b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.test.ts new file mode 100644 index 0000000000000..22f91b71d4492 --- /dev/null +++ b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.test.ts @@ -0,0 +1,183 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { getGenAiTokenTracking, shouldTrackGenAiToken } from './gen_ai_token_tracking'; +import { loggerMock } from '@kbn/logging-mocks'; +import { getTokenCountFromBedrockInvoke } from './get_token_count_from_bedrock_invoke'; +import { getTokenCountFromInvokeStream } from './get_token_count_from_invoke_stream'; +import { IncomingMessage } from 'http'; +import { Socket } from 'net'; + +jest.mock('./get_token_count_from_bedrock_invoke'); +jest.mock('./get_token_count_from_invoke_stream'); + +const logger = loggerMock.create(); + +describe('getGenAiTokenTracking', () => { + let mockGetTokenCountFromBedrockInvoke: jest.Mock; + let mockGetTokenCountFromInvokeStream: jest.Mock; + beforeEach(() => { + mockGetTokenCountFromBedrockInvoke = ( + getTokenCountFromBedrockInvoke as jest.Mock + ).mockResolvedValueOnce({ + total: 100, + prompt: 50, + completion: 50, + }); + mockGetTokenCountFromInvokeStream = ( + getTokenCountFromInvokeStream as jest.Mock + ).mockResolvedValueOnce({ + total: 100, + prompt: 50, + completion: 50, + }); + }); + it('should return the total, prompt, and completion token counts when given a valid OpenAI response', async () => { + const actionTypeId = '.gen-ai'; + + const result = { + actionId: '123', + status: 'ok' as const, + data: { + usage: { + total_tokens: 100, + prompt_tokens: 50, + completion_tokens: 50, + }, + }, + }; + const validatedParams = {}; + + const tokenTracking = await getGenAiTokenTracking({ + actionTypeId, + logger, + result, + validatedParams, + }); + + expect(tokenTracking).toEqual({ + total_tokens: 100, + prompt_tokens: 50, + completion_tokens: 50, + }); + expect(logger.error).not.toHaveBeenCalled(); + }); + + it('should return the total, prompt, and completion token counts when given a valid Bedrock response', async () => { + const actionTypeId = '.bedrock'; + + const result = { + actionId: '123', + status: 'ok' 
as const, + data: { + completion: 'Sample completion', + }, + }; + const validatedParams = { + subAction: 'run', + subActionParams: { + body: 'Sample body', + }, + }; + + const tokenTracking = await getGenAiTokenTracking({ + actionTypeId, + logger, + result, + validatedParams, + }); + + expect(tokenTracking).toEqual({ + total_tokens: 100, + prompt_tokens: 50, + completion_tokens: 50, + }); + expect(logger.error).not.toHaveBeenCalled(); + expect(mockGetTokenCountFromBedrockInvoke).toHaveBeenCalledWith({ + response: 'Sample completion', + body: 'Sample body', + }); + }); + + it('should return the total, prompt, and completion token counts when given a valid OpenAI streamed response', async () => { + const mockReader = new IncomingMessage(new Socket()); + const actionTypeId = '.gen-ai'; + const result = { + actionId: '123', + status: 'ok' as const, + data: mockReader, + }; + const validatedParams = { + subAction: 'invokeStream', + subActionParams: { + messages: [ + { + role: 'user', + content: 'Sample message', + }, + ], + }, + }; + + const tokenTracking = await getGenAiTokenTracking({ + actionTypeId, + logger, + result, + validatedParams, + }); + + expect(tokenTracking).toEqual({ + total_tokens: 100, + prompt_tokens: 50, + completion_tokens: 50, + }); + expect(logger.error).not.toHaveBeenCalled(); + + expect(JSON.stringify(mockGetTokenCountFromInvokeStream.mock.calls[0][0].body)).toStrictEqual( + JSON.stringify({ + messages: [ + { + role: 'user', + content: 'Sample message', + }, + ], + }) + ); + }); + + it('should return null when given an invalid OpenAI response', async () => { + const actionTypeId = '.gen-ai'; + const result = { + actionId: '123', + status: 'ok' as const, + data: {}, + }; + const validatedParams = {}; + + const tokenTracking = await getGenAiTokenTracking({ + actionTypeId, + logger, + result, + validatedParams, + }); + + expect(tokenTracking).toBeNull(); + expect(logger.error).toHaveBeenCalled(); + }); + + describe('shouldTrackGenAiToken', () => { + it('should be true with OpenAI action', () => { + expect(shouldTrackGenAiToken('.gen-ai')).toEqual(true); + }); + it('should be true with bedrock action', () => { + expect(shouldTrackGenAiToken('.bedrock')).toEqual(true); + }); + it('should be false with any other action', () => { + expect(shouldTrackGenAiToken('.jira')).toEqual(false); + }); + }); +}); diff --git a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts new file mode 100644 index 0000000000000..7c104177ea36e --- /dev/null +++ b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts @@ -0,0 +1,123 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { PassThrough, Readable } from 'stream'; +import { Logger } from '@kbn/logging'; +import { getTokenCountFromBedrockInvoke } from './get_token_count_from_bedrock_invoke'; +import { ActionTypeExecutorRawResult } from '../../common'; +import { getTokenCountFromOpenAIStream } from './get_token_count_from_openai_stream'; +import { getTokenCountFromInvokeStream, InvokeBody } from './get_token_count_from_invoke_stream'; + +interface OwnProps { + actionTypeId: string; + logger: Logger; + result: ActionTypeExecutorRawResult; + validatedParams: Record; +} +/* + * Calculates the total, prompt, and completion token counts from different types of responses. + * It handles both streamed and non-streamed responses from OpenAI and Bedrock. + * It returns null if it cannot calculate the token counts. + * @param actionTypeId the action type id + * @param logger the logger + * @param result the result from the action executor + * @param validatedParams the validated params from the action executor + */ +export const getGenAiTokenTracking = async ({ + actionTypeId, + logger, + result, + validatedParams, +}: OwnProps): Promise<{ + total_tokens: number; + prompt_tokens: number; + completion_tokens: number; +} | null> => { + // this is a streamed OpenAI or Bedrock response, using the subAction invokeStream to stream the response as a simple string + if (validatedParams.subAction === 'invokeStream' && result.data instanceof Readable) { + try { + const { total, prompt, completion } = await getTokenCountFromInvokeStream({ + responseStream: result.data.pipe(new PassThrough()), + body: (validatedParams as { subActionParams: InvokeBody }).subActionParams, + logger, + }); + return { + total_tokens: total, + prompt_tokens: prompt, + completion_tokens: completion, + }; + } catch (e) { + logger.error('Failed to calculate tokens from Invoke Stream subaction streaming response'); + logger.error(e); + } + } + + // this is a streamed OpenAI response, which did not use the subAction invokeStream + if (actionTypeId === '.gen-ai' && result.data instanceof Readable) { + try { + const { total, prompt, completion } = await getTokenCountFromOpenAIStream({ + responseStream: result.data.pipe(new PassThrough()), + body: (validatedParams as { subActionParams: { body: string } }).subActionParams.body, + logger, + }); + return { + total_tokens: total, + prompt_tokens: prompt, + completion_tokens: completion, + }; + } catch (e) { + logger.error('Failed to calculate tokens from streaming response'); + logger.error(e); + } + } + + // this is a non-streamed OpenAI response, which comes with the usage object + if (actionTypeId === '.gen-ai') { + const data = result.data as unknown as { + usage: { prompt_tokens?: number; completion_tokens?: number; total_tokens?: number }; + }; + if (data.usage == null) { + logger.error('Response did not contain usage object'); + return null; + } + return { + total_tokens: data.usage?.total_tokens ?? 0, + prompt_tokens: data.usage?.prompt_tokens ?? 0, + completion_tokens: data.usage?.completion_tokens ?? 
0, + }; + } + + // this is a non-streamed Bedrock response + if ( + actionTypeId === '.bedrock' && + (validatedParams.subAction === 'run' || validatedParams.subAction === 'test') + ) { + try { + const { total, prompt, completion } = await getTokenCountFromBedrockInvoke({ + response: ( + result.data as unknown as { + completion: string; + } + ).completion, + body: (validatedParams as { subActionParams: { body: string } }).subActionParams.body, + }); + + return { + total_tokens: total, + prompt_tokens: prompt, + completion_tokens: completion, + }; + } catch (e) { + logger.error('Failed to calculate tokens from Bedrock invoke response'); + logger.error(e); + } + } + return null; +}; + +export const shouldTrackGenAiToken = (actionTypeId: string) => + actionTypeId === '.gen-ai' || actionTypeId === '.bedrock'; diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_bedrock_invoke.test.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_bedrock_invoke.test.ts new file mode 100644 index 0000000000000..efb0c2cd0eaff --- /dev/null +++ b/x-pack/plugins/actions/server/lib/get_token_count_from_bedrock_invoke.test.ts @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { getTokenCountFromBedrockInvoke } from './get_token_count_from_bedrock_invoke'; + +describe('getTokenCountFromBedrockInvoke', () => { + const body = JSON.stringify({ + prompt: `\n\nAssistant: This is a system message\n\nHuman: This is a user message\n\nAssistant:`, + }); + + const PROMPT_TOKEN_COUNT = 27; + const COMPLETION_TOKEN_COUNT = 4; + + it('counts the prompt tokens', async () => { + const tokens = await getTokenCountFromBedrockInvoke({ + response: 'This is a response', + body, + }); + expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT); + expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT); + expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT); + }); +}); diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_bedrock_invoke.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_bedrock_invoke.ts new file mode 100644 index 0000000000000..26e320200830b --- /dev/null +++ b/x-pack/plugins/actions/server/lib/get_token_count_from_bedrock_invoke.ts @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { encode } from 'gpt-tokenizer'; + +export interface InvokeBody { + prompt: string; +} + +/** + * Takes the Bedrock `run` and `test` sub action response and the request prompt as inputs. + * Uses gpt-tokenizer encoding to calculate the number of tokens in the prompt and completion. + * Returns an object containing the total, prompt, and completion token counts. 
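+ * As a rough illustration: for body '{"prompt":"\n\nHuman: Hi\n\nAssistant:"}' and response
+ * 'Hi there!', prompt tokens = encode('<|start|>\n\nHuman: Hi\n\nAssistant:<|end|>').length,
+ * completion tokens = encode('Hi there!').length, and total is their sum
+ * (the exact numbers depend on the gpt-tokenizer encoding).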
+ * @param response (string) - the response completion from the `run` or `test` sub action + * @param body - the stringified request prompt + */ +export async function getTokenCountFromBedrockInvoke({ + response, + body, +}: { + response: string; + body: string; +}): Promise<{ + total: number; + prompt: number; + completion: number; +}> { + const chatCompletionRequest = JSON.parse(body) as InvokeBody; + + // per https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + const tokensFromMessages = encode(`<|start|>${chatCompletionRequest.prompt}<|end|>`).length; + + const promptTokens = tokensFromMessages; + + const completionTokens = encode(response).length; + + return { + prompt: promptTokens, + completion: completionTokens, + total: promptTokens + completionTokens, + }; +} diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.test.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.test.ts new file mode 100644 index 0000000000000..3c0dd66130f3a --- /dev/null +++ b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.test.ts @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import { Transform } from 'stream'; +import { getTokenCountFromInvokeStream } from './get_token_count_from_invoke_stream'; +import { loggerMock } from '@kbn/logging-mocks'; + +interface StreamMock { + write: (data: string) => void; + fail: () => void; + complete: () => void; + transform: Transform; +} + +function createStreamMock(): StreamMock { + const transform: Transform = new Transform({}); + + return { + write: (data: string) => { + transform.push(`${data}\n`); + }, + fail: () => { + transform.emit('error', new Error('Stream failed')); + transform.end(); + }, + transform, + complete: () => { + transform.end(); + }, + }; +} +const logger = loggerMock.create(); +describe('getTokenCountFromInvokeStream', () => { + let stream: StreamMock; + const body = { + messages: [ + { + role: 'system', + content: 'This is a system message', + }, + { + role: 'user', + content: 'This is a user message', + }, + ], + }; + + const PROMPT_TOKEN_COUNT = 34; + const COMPLETION_TOKEN_COUNT = 2; + + beforeEach(() => { + stream = createStreamMock(); + stream.write('Single'); + }); + + describe('when a stream completes', () => { + beforeEach(async () => { + stream.complete(); + }); + it('counts the prompt tokens', async () => { + const tokens = await getTokenCountFromInvokeStream({ + responseStream: stream.transform, + body, + logger, + }); + expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT); + expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT); + expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT); + }); + }); + + describe('when a stream fails', () => { + it('resolves the promise with the correct prompt tokens', async () => { + const tokenPromise = getTokenCountFromInvokeStream({ + responseStream: stream.transform, + body, + logger, + }); + + stream.fail(); + + await expect(tokenPromise).resolves.toEqual({ + prompt: PROMPT_TOKEN_COUNT, + total: PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT, + completion: COMPLETION_TOKEN_COUNT, + }); + expect(logger.error).toHaveBeenCalled(); + }); + }); +}); diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts 
b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts new file mode 100644 index 0000000000000..594fec89d93c0 --- /dev/null +++ b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { Logger } from '@kbn/logging'; +import { encode } from 'gpt-tokenizer'; +import { Readable } from 'stream'; +import { finished } from 'stream/promises'; + +export interface InvokeBody { + messages: Array<{ + role: string; + content: string; + }>; +} + +/** + * Takes the OpenAI and Bedrock `invokeStream` sub action response stream and the request messages array as inputs. + * Uses gpt-tokenizer encoding to calculate the number of tokens in the prompt and completion parts of the response stream + * Returns an object containing the total, prompt, and completion token counts. + * @param responseStream the response stream from the `invokeStream` sub action + * @param body the request messages array + * @param logger the logger + */ +export async function getTokenCountFromInvokeStream({ + responseStream, + body, + logger, +}: { + responseStream: Readable; + body: InvokeBody; + logger: Logger; +}): Promise<{ + total: number; + prompt: number; + completion: number; +}> { + const chatCompletionRequest = body; + + // per https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + const promptTokens = encode( + chatCompletionRequest.messages + .map((msg) => `<|start|>${msg.role}\n${msg.content}<|end|>`) + .join('\n') + ).length; + + let responseBody: string = ''; + + responseStream.on('data', (chunk: string) => { + responseBody += chunk.toString(); + }); + try { + await finished(responseStream); + } catch (e) { + logger.error('An error occurred while calculating streaming response tokens'); + } + + const completionTokens = encode(responseBody).length; + + return { + prompt: promptTokens, + completion: completionTokens, + total: promptTokens + completionTokens, + }; +} diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.test.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.test.ts index 080b7cb5f972f..cc81706fc257c 100644 --- a/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.test.ts +++ b/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.test.ts @@ -6,6 +6,7 @@ */ import { Transform } from 'stream'; import { getTokenCountFromOpenAIStream } from './get_token_count_from_openai_stream'; +import { loggerMock } from '@kbn/logging-mocks'; interface StreamMock { write: (data: string) => void; @@ -32,6 +33,7 @@ function createStreamMock(): StreamMock { }; } +const logger = loggerMock.create(); describe('getTokenCountFromOpenAIStream', () => { let tokens: Awaited>; let stream: StreamMock; @@ -77,6 +79,7 @@ describe('getTokenCountFromOpenAIStream', () => { beforeEach(async () => { tokens = await getTokenCountFromOpenAIStream({ responseStream: stream.transform, + logger, body: JSON.stringify(body), }); }); @@ -92,6 +95,7 @@ describe('getTokenCountFromOpenAIStream', () => { beforeEach(async () => { tokens = await getTokenCountFromOpenAIStream({ responseStream: stream.transform, + logger, body: JSON.stringify({ ...body, functions: [ @@ -123,6 +127,7 @@ 
describe('getTokenCountFromOpenAIStream', () => { it('resolves the promise with the correct prompt tokens', async () => { const tokenPromise = getTokenCountFromOpenAIStream({ responseStream: stream.transform, + logger, body: JSON.stringify(body), }); @@ -133,6 +138,7 @@ describe('getTokenCountFromOpenAIStream', () => { total: PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT, completion: COMPLETION_TOKEN_COUNT, }); + expect(logger.error).toHaveBeenCalled(); }); }); }); diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.ts index 74c89f716171e..0091faca468e3 100644 --- a/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.ts +++ b/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.ts @@ -10,13 +10,16 @@ import { isEmpty, omitBy } from 'lodash'; import { Readable } from 'stream'; import { finished } from 'stream/promises'; import { CreateChatCompletionRequest } from 'openai'; +import { Logger } from '@kbn/logging'; export async function getTokenCountFromOpenAIStream({ responseStream, body, + logger, }: { responseStream: Readable; body: string; + logger: Logger; }): Promise<{ total: number; prompt: number; @@ -65,8 +68,8 @@ export async function getTokenCountFromOpenAIStream({ try { await finished(responseStream); - } catch { - // no need to handle this explicitly + } catch (e) { + logger.error('An error occurred while calculating streaming response tokens'); } const response = responseBody diff --git a/x-pack/plugins/elastic_assistant/server/lib/executor.ts b/x-pack/plugins/elastic_assistant/server/lib/executor.ts index 88266914f36ed..27064f3fb1961 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/executor.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/executor.ts @@ -31,15 +31,7 @@ export const executeAction = async ({ const actionResult = await actionsClient.execute({ actionId: connectorId, - params: { - ...request.body.params, - subActionParams: - // TODO: Remove in part 2 of streaming work for security solution - // tracked here: https://github.com/elastic/security-team/issues/7363 - request.body.params.subAction === 'invokeAI' - ? 
request.body.params.subActionParams - : { body: JSON.stringify(request.body.params.subActionParams), stream: true }, - }, + params: request.body.params, }); if (actionResult.status === 'error') { diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/index.tsx b/x-pack/plugins/security_solution/public/assistant/get_comments/index.tsx index 9c547b3033112..3b778013a42d1 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/index.tsx +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/index.tsx @@ -124,6 +124,7 @@ export const getComments = ({ amendMessage={amendMessageOfConversation} index={index} isLastComment={isLastComment} + isError={message.isError} reader={message.reader} regenerateMessage={regenerateMessageOfConversation} transformMessage={transformMessage} diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/index.tsx b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/index.tsx index db394f39bfa32..219a8565481cf 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/index.tsx +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/index.tsx @@ -17,6 +17,7 @@ import { MessageText } from './message_text'; interface Props { amendMessage: (message: string) => void; content?: string; + isError?: boolean; isFetching?: boolean; isLastComment: boolean; index: number; @@ -29,6 +30,7 @@ export const StreamComment = ({ amendMessage, content, index, + isError = false, isFetching = false, isLastComment, reader, @@ -39,6 +41,7 @@ export const StreamComment = ({ amendMessage, content, reader, + isError, }); const currentState = useRef({ isStreaming, pendingMessage, amendMessage }); diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.test.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.test.ts index 9a63621021cc3..764db1b3990ae 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.test.ts +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.test.ts @@ -5,7 +5,8 @@ * 2.0. 
*/ import { getStreamObservable } from './stream_observable'; -// import { getReaderValue, mockUint8Arrays } from './mock'; +import { API_ERROR } from '../translations'; + import type { PromptObservableState } from './types'; import { Subject } from 'rxjs'; describe('getStreamObservable', () => { @@ -27,41 +28,23 @@ describe('getStreamObservable', () => { const expectedStates: PromptObservableState[] = [ { chunks: [], loading: true }, { - chunks: [ - { - id: '1', - object: 'chunk', - created: 1635633600000, - model: 'model-1', - choices: [ - { - index: 0, - delta: { role: 'role-1', content: 'content-1' }, - finish_reason: null, - }, - ], - }, - ], - message: 'content-1', + chunks: ['one chunk ', 'another chunk', ''], + message: 'one chunk ', loading: true, }, { - chunks: [ - { - id: '1', - object: 'chunk', - created: 1635633600000, - model: 'model-1', - choices: [ - { - index: 0, - delta: { role: 'role-1', content: 'content-1' }, - finish_reason: null, - }, - ], - }, - ], - message: 'content-1', + chunks: ['one chunk ', 'another chunk', ''], + message: 'one chunk another chunk', + loading: true, + }, + { + chunks: ['one chunk ', 'another chunk', ''], + message: 'one chunk another chunk', + loading: true, + }, + { + chunks: ['one chunk ', 'another chunk', ''], + message: 'one chunk another chunk', loading: false, }, ]; @@ -69,23 +52,67 @@ describe('getStreamObservable', () => { mockReader.read .mockResolvedValueOnce({ done: false, - value: new Uint8Array( - new TextEncoder().encode(`data: ${JSON.stringify(expectedStates[1].chunks[0])}`) - ), + value: new Uint8Array(new TextEncoder().encode(`one chunk `)), + }) + .mockResolvedValueOnce({ + done: false, + value: new Uint8Array(new TextEncoder().encode(`another chunk`)), }) .mockResolvedValueOnce({ done: false, - value: new Uint8Array(new TextEncoder().encode(``)), + value: new Uint8Array(new TextEncoder().encode('')), }) + .mockResolvedValue({ + done: true, + }); + + const source = getStreamObservable(typedReader, setLoading, false); + const emittedStates: PromptObservableState[] = []; + + source.subscribe({ + next: (state) => emittedStates.push(state), + complete: () => { + expect(emittedStates).toEqual(expectedStates); + done(); + + completeSubject.subscribe({ + next: () => { + expect(setLoading).toHaveBeenCalledWith(false); + expect(typedReader.cancel).toHaveBeenCalled(); + done(); + }, + }); + }, + error: (err) => done(err), + }); + }); + + it('should stream errors when reader contains errors', (done) => { + const completeSubject = new Subject(); + const expectedStates: PromptObservableState[] = [ + { chunks: [], loading: true }, + { + chunks: [`${API_ERROR}\n\nis an error`], + message: `${API_ERROR}\n\nis an error`, + loading: true, + }, + { + chunks: [`${API_ERROR}\n\nis an error`], + message: `${API_ERROR}\n\nis an error`, + loading: false, + }, + ]; + + mockReader.read .mockResolvedValueOnce({ done: false, - value: new Uint8Array(new TextEncoder().encode('data: [DONE]\n')), + value: new Uint8Array(new TextEncoder().encode(JSON.stringify({ message: 'is an error' }))), }) .mockResolvedValue({ done: true, }); - const source = getStreamObservable(typedReader, setLoading); + const source = getStreamObservable(typedReader, setLoading, true); const emittedStates: PromptObservableState[] = []; source.subscribe({ @@ -111,7 +138,7 @@ describe('getStreamObservable', () => { const error = new Error('Test Error'); // Simulate an error mockReader.read.mockRejectedValue(error); - const source = getStreamObservable(typedReader, setLoading); + const 
source = getStreamObservable(typedReader, setLoading, false);
 
     source.subscribe({
       next: (state) => {},
diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts
index 83f9b4cf8ead3..b30be69b82cae 100644
--- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts
+++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts
@@ -5,26 +5,29 @@
  * 2.0.
  */
 
-import { concatMap, delay, finalize, Observable, of, scan, shareReplay, timestamp } from 'rxjs';
+import { concatMap, delay, finalize, Observable, of, scan, timestamp } from 'rxjs';
 import type { Dispatch, SetStateAction } from 'react';
-import type { PromptObservableState, Chunk } from './types';
-
+import { API_ERROR } from '../translations';
+import type { PromptObservableState } from './types';
 const MIN_DELAY = 35;
+
 /**
  * Returns an Observable that reads data from a ReadableStream and emits values representing the state of the data processing.
  *
  * @param reader - The ReadableStreamDefaultReader used to read data from the stream.
  * @param setLoading - A function to update the loading state.
+ * @param isError - indicates whether the reader response is an error message or not
  * @returns {Observable<PromptObservableState>} An Observable that emits PromptObservableState
  */
 export const getStreamObservable = (
   reader: ReadableStreamDefaultReader<Uint8Array>,
-  setLoading: Dispatch<SetStateAction<boolean>>
+  setLoading: Dispatch<SetStateAction<boolean>>,
+  isError: boolean
 ): Observable<PromptObservableState> =>
   new Observable<PromptObservableState>((observer) => {
     observer.next({ chunks: [], loading: true });
     const decoder = new TextDecoder();
-    const chunks: Chunk[] = [];
+    const chunks: string[] = [];
     function read() {
       reader
         .read()
@@ -39,23 +42,17 @@ export const getStreamObservable = (
             observer.complete();
             return;
           }
-
-          const nextChunks: Chunk[] = decoder
-            .decode(value)
-            .split('\n')
-            // every line starts with "data: ", we remove it and are left with stringified JSON or the string "[DONE]"
-            .map((str) => str.substring(6))
-            // filter out empty lines and the "[DONE]" string
-            .filter((str) => !!str && str !== '[DONE]')
-            .map((line) => JSON.parse(line));
-
-          nextChunks.forEach((chunk) => {
-            chunks.push(chunk);
-            observer.next({
-              chunks,
-              message: getMessageFromChunks(chunks),
-              loading: true,
-            });
+          const decoded = decoder.decode(value);
+          const content = isError
+            ? // we format errors as {message: string; status_code: number}
+              `${API_ERROR}\n\n${JSON.parse(decoded).message}`
+            : // all other responses are just strings (handled by subaction invokeStream)
+              decoded;
+          chunks.push(content);
+          observer.next({
+            chunks,
+            message: getMessageFromChunks(chunks),
+            loading: true,
           });
         } catch (err) {
           observer.error(err);
@@ -72,9 +69,6 @@ export const getStreamObservable = (
       reader.cancel();
     };
   }).pipe(
-    // make sure the request is only triggered once,
-    // even with multiple subscribers
-    shareReplay(1),
    // append a timestamp of when each value was emitted
     timestamp(),
     // use the previous timestamp to calculate a target
@@ -105,8 +99,8 @@ export const getStreamObservable = (
   finalize(() => setLoading(false))
 );
 
-function getMessageFromChunks(chunks: Chunk[]) {
-  return chunks.map((chunk) => chunk.choices[0]?.delta.content ?? 
'').join(''); +function getMessageFromChunks(chunks: string[]) { + return chunks.join(''); } export const getPlaceholderObservable = () => new Observable(); diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/types.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/types.ts index 3cf45852ddb11..80ef5e4ae6eda 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/types.ts +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/types.ts @@ -6,20 +6,8 @@ */ export interface PromptObservableState { - chunks: Chunk[]; + chunks: string[]; message?: string; error?: string; loading: boolean; } -export interface ChunkChoice { - index: 0; - delta: { role: string; content: string }; - finish_reason: null | string; -} -export interface Chunk { - id: string; - object: string; - created: number; - model: string; - choices: ChunkChoice[]; -} diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.test.tsx b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.test.tsx index 4fbecfac870e1..efbc61999f2cc 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.test.tsx +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.test.tsx @@ -11,32 +11,20 @@ import { useStream } from './use_stream'; const amendMessage = jest.fn(); const reader = jest.fn(); const cancel = jest.fn(); -const exampleChunk = { - id: '1', - object: 'chunk', - created: 1635633600000, - model: 'model-1', - choices: [ - { - index: 0, - delta: { role: 'role-1', content: 'content-1' }, - finish_reason: null, - }, - ], -}; + const readerComplete = { read: reader .mockResolvedValueOnce({ done: false, - value: new Uint8Array(new TextEncoder().encode(`data: ${JSON.stringify(exampleChunk)}`)), + value: new Uint8Array(new TextEncoder().encode('one chunk ')), }) .mockResolvedValueOnce({ done: false, - value: new Uint8Array(new TextEncoder().encode(``)), + value: new Uint8Array(new TextEncoder().encode(`another chunk`)), }) .mockResolvedValueOnce({ done: false, - value: new Uint8Array(new TextEncoder().encode('data: [DONE]\n')), + value: new Uint8Array(new TextEncoder().encode(``)), }) .mockResolvedValue({ done: true, @@ -46,7 +34,7 @@ const readerComplete = { closed: jest.fn().mockResolvedValue(true), } as unknown as ReadableStreamDefaultReader; -const defaultProps = { amendMessage, reader: readerComplete }; +const defaultProps = { amendMessage, reader: readerComplete, isError: false }; describe('useStream', () => { beforeEach(() => { jest.clearAllMocks(); @@ -69,7 +57,7 @@ describe('useStream', () => { error: undefined, isLoading: true, isStreaming: true, - pendingMessage: 'content-1', + pendingMessage: 'one chunk ', setComplete: expect.any(Function), }); }); @@ -79,7 +67,7 @@ describe('useStream', () => { error: undefined, isLoading: false, isStreaming: false, - pendingMessage: 'content-1', + pendingMessage: 'one chunk another chunk', setComplete: expect.any(Function), }); }); @@ -104,7 +92,7 @@ describe('useStream', () => { .fn() .mockResolvedValueOnce({ done: false, - value: new Uint8Array(new TextEncoder().encode(`data: ${JSON.stringify(exampleChunk)}`)), + value: new Uint8Array(new TextEncoder().encode(`one chunk`)), }) .mockRejectedValue(new Error(errorMessage)), cancel, @@ -113,7 +101,7 @@ describe('useStream', () => { } as unknown as ReadableStreamDefaultReader; const { result, waitForNextUpdate } = 
renderHook(() => useStream({ - amendMessage, + ...defaultProps, reader: errorReader, }) ); diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.tsx b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.tsx index 148338f2afafa..7de06589f87c7 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.tsx +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/use_stream.tsx @@ -12,6 +12,7 @@ import { getPlaceholderObservable, getStreamObservable } from './stream_observab interface UseStreamProps { amendMessage: (message: string) => void; + isError: boolean; content?: string; reader?: ReadableStreamDefaultReader; } @@ -33,8 +34,14 @@ interface UseStream { * @param amendMessage - handles the amended message * @param content - the content of the message. If provided, the function will not use the reader to stream data. * @param reader - The readable stream reader used to stream data. If provided, the function will use this reader to stream data. + * @param isError - indicates whether the reader response is an error message or not */ -export const useStream = ({ amendMessage, content, reader }: UseStreamProps): UseStream => { +export const useStream = ({ + amendMessage, + content, + reader, + isError, +}: UseStreamProps): UseStream => { const [pendingMessage, setPendingMessage] = useState(); const [loading, setLoading] = useState(false); const [error, setError] = useState(); @@ -42,9 +49,9 @@ export const useStream = ({ amendMessage, content, reader }: UseStreamProps): Us const observer$ = useMemo( () => content == null && reader != null - ? getStreamObservable(reader, setLoading) + ? getStreamObservable(reader, setLoading, isError) : getPlaceholderObservable(), - [content, reader] + [content, isError, reader] ); const onCompleteStream = useCallback(() => { subscription?.unsubscribe(); diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/translations.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/translations.ts index 2b83d580ef2cd..fbccef68f7398 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/translations.ts +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/translations.ts @@ -20,3 +20,7 @@ export const AT = (timestamp: string) => export const YOU = i18n.translate('xpack.securitySolution.assistant.getComments.you', { defaultMessage: 'You', }); + +export const API_ERROR = i18n.translate('xpack.securitySolution.assistant.apiErrorTitle', { + defaultMessage: 'An error occurred sending your message.', +}); diff --git a/x-pack/plugins/stack_connectors/common/bedrock/constants.ts b/x-pack/plugins/stack_connectors/common/bedrock/constants.ts index 8071091194049..ff165f6678db9 100644 --- a/x-pack/plugins/stack_connectors/common/bedrock/constants.ts +++ b/x-pack/plugins/stack_connectors/common/bedrock/constants.ts @@ -17,6 +17,7 @@ export const BEDROCK_CONNECTOR_ID = '.bedrock'; export enum SUB_ACTION { RUN = 'run', INVOKE_AI = 'invokeAI', + INVOKE_STREAM = 'invokeStream', TEST = 'test', } diff --git a/x-pack/plugins/stack_connectors/common/bedrock/schema.ts b/x-pack/plugins/stack_connectors/common/bedrock/schema.ts index 64699253c709f..6fbc0252eb61b 100644 --- a/x-pack/plugins/stack_connectors/common/bedrock/schema.ts +++ b/x-pack/plugins/stack_connectors/common/bedrock/schema.ts @@ -38,6 +38,11 @@ export const InvokeAIActionResponseSchema = schema.object({ message: schema.string(), }); 
+export const StreamActionParamsSchema = schema.object({
+  body: schema.string(),
+  model: schema.maybe(schema.string()),
+});
+
 export const RunActionResponseSchema = schema.object(
   {
     completion: schema.string(),
   },
   { unknowns: 'ignore' }
 );
+
+export const StreamingResponseSchema = schema.any();
diff --git a/x-pack/plugins/stack_connectors/common/bedrock/types.ts b/x-pack/plugins/stack_connectors/common/bedrock/types.ts
index c6fad07cdba37..3d9fada237987 100644
--- a/x-pack/plugins/stack_connectors/common/bedrock/types.ts
+++ b/x-pack/plugins/stack_connectors/common/bedrock/types.ts
@@ -13,6 +13,8 @@ import {
   RunActionResponseSchema,
   InvokeAIActionParamsSchema,
   InvokeAIActionResponseSchema,
+  StreamActionParamsSchema,
+  StreamingResponseSchema,
 } from './schema';
 
 export type Config = TypeOf<typeof ConfigSchema>;
@@ -21,3 +23,5 @@ export type RunActionParams = TypeOf<typeof RunActionParamsSchema>;
 export type InvokeAIActionParams = TypeOf<typeof InvokeAIActionParamsSchema>;
 export type InvokeAIActionResponse = TypeOf<typeof InvokeAIActionResponseSchema>;
 export type RunActionResponse = TypeOf<typeof RunActionResponseSchema>;
+export type StreamActionParams = TypeOf<typeof StreamActionParamsSchema>;
+export type StreamingResponse = TypeOf<typeof StreamingResponseSchema>;
diff --git a/x-pack/plugins/stack_connectors/common/openai/constants.ts b/x-pack/plugins/stack_connectors/common/openai/constants.ts
index db01f52d762cf..5cf6ecb659f8a 100644
--- a/x-pack/plugins/stack_connectors/common/openai/constants.ts
+++ b/x-pack/plugins/stack_connectors/common/openai/constants.ts
@@ -17,6 +17,7 @@ export const OPENAI_CONNECTOR_ID = '.gen-ai';
 export enum SUB_ACTION {
   RUN = 'run',
   INVOKE_AI = 'invokeAI',
+  INVOKE_STREAM = 'invokeStream',
   STREAM = 'stream',
   DASHBOARD = 'getDashboard',
   TEST = 'test',
diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts
index dcd3d70f9b4ff..708e8cd4e0364 100644
--- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts
+++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts
@@ -4,12 +4,16 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0. 
*/ - +import aws from 'aws4'; +import { Transform } from 'stream'; import { BedrockConnector } from './bedrock'; +import { waitFor } from '@testing-library/react'; import { actionsConfigMock } from '@kbn/actions-plugin/server/actions_config.mock'; import { loggingSystemMock } from '@kbn/core-logging-server-mocks'; +import { EventStreamCodec } from '@smithy/eventstream-codec'; +import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; import { actionsMock } from '@kbn/actions-plugin/server/mocks'; -import { RunActionResponseSchema } from '../../../common/bedrock/schema'; +import { RunActionResponseSchema, StreamingResponseSchema } from '../../../common/bedrock/schema'; import { BEDROCK_CONNECTOR_ID, DEFAULT_BEDROCK_MODEL, @@ -19,10 +23,8 @@ import { import { DEFAULT_BODY } from '../../../public/connector_types/bedrock/constants'; import { AxiosError } from 'axios'; -jest.mock('aws4', () => ({ - sign: () => ({ signed: true }), -})); - +// @ts-ignore +const mockSigner = jest.spyOn(aws, 'sign').mockReturnValue({ signed: true }); describe('BedrockConnector', () => { let mockRequest: jest.Mock; let mockError: jest.Mock; @@ -35,6 +37,7 @@ describe('BedrockConnector', () => { }, }; beforeEach(() => { + jest.clearAllMocks(); mockRequest = jest.fn().mockResolvedValue(mockResponse); mockError = jest.fn().mockImplementation(() => { throw new Error('API Error'); @@ -53,14 +56,29 @@ describe('BedrockConnector', () => { logger: loggingSystemMock.createLogger(), services: actionsMock.createServices(), }); - beforeEach(() => { // @ts-ignore connector.request = mockRequest; - jest.clearAllMocks(); }); describe('runApi', () => { + it('the aws signature has non-streaming headers', async () => { + await connector.runApi({ body: DEFAULT_BODY }); + + expect(mockSigner).toHaveBeenCalledWith( + { + body: '{"prompt":"\\n\\nHuman: Hello world! 
\\n\\nAssistant:","max_tokens_to_sample":8191,"stop_sequences":["\\n\\nHuman:"]}', + headers: { + Accept: '*/*', + 'Content-Type': 'application/json', + }, + host: 'bedrock.us-east-1.amazonaws.com', + path: '/model/anthropic.claude-v2/invoke', + service: 'bedrock', + }, + { accessKeyId: '123', secretAccessKey: 'secret' } + ); + }); it('the Bedrock API call is successful with correct parameters', async () => { const response = await connector.runApi({ body: DEFAULT_BODY }); expect(mockRequest).toBeCalledTimes(1); @@ -83,6 +101,124 @@ describe('BedrockConnector', () => { }); }); + describe('invokeStream', () => { + let stream; + beforeEach(() => { + stream = createStreamMock(); + stream.write(encodeBedrockResponse(mockResponseString)); + mockRequest = jest.fn().mockResolvedValue({ ...mockResponse, data: stream.transform }); + // @ts-ignore + connector.request = mockRequest; + }); + + const aiAssistantBody = { + messages: [ + { + role: 'user', + content: 'Hello world', + }, + ], + }; + + it('the aws signature has streaming headers', async () => { + await connector.invokeStream(aiAssistantBody); + + expect(mockSigner).toHaveBeenCalledWith( + { + body: JSON.stringify({ + prompt: '\n\nHuman:Hello world \n\nAssistant:', + max_tokens_to_sample: DEFAULT_TOKEN_LIMIT, + temperature: 0.5, + stop_sequences: ['\n\nHuman:'], + }), + headers: { + accept: 'application/vnd.amazon.eventstream', + 'Content-Type': 'application/json', + 'x-amzn-bedrock-accept': '*/*', + }, + host: 'bedrock.us-east-1.amazonaws.com', + path: '/model/anthropic.claude-v2/invoke-with-response-stream', + service: 'bedrock', + }, + { accessKeyId: '123', secretAccessKey: 'secret' } + ); + }); + + it('the API call is successful with correct request parameters', async () => { + await connector.invokeStream(aiAssistantBody); + expect(mockRequest).toBeCalledTimes(1); + expect(mockRequest).toHaveBeenCalledWith({ + signed: true, + url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke-with-response-stream`, + method: 'post', + responseSchema: StreamingResponseSchema, + responseType: 'stream', + data: JSON.stringify({ + prompt: '\n\nHuman:Hello world \n\nAssistant:', + max_tokens_to_sample: DEFAULT_TOKEN_LIMIT, + temperature: 0.5, + stop_sequences: ['\n\nHuman:'], + }), + }); + }); + + it('formats messages from user, assistant, and system', async () => { + await connector.invokeStream({ + messages: [ + { + role: 'user', + content: 'Hello world', + }, + { + role: 'system', + content: 'Be a good chatbot', + }, + { + role: 'assistant', + content: 'Hi, I am a good chatbot', + }, + { + role: 'user', + content: 'What is 2+2?', + }, + ], + }); + expect(mockRequest).toHaveBeenCalledWith({ + signed: true, + responseType: 'stream', + url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke-with-response-stream`, + method: 'post', + responseSchema: StreamingResponseSchema, + data: JSON.stringify({ + prompt: + '\n\nHuman:Hello world\n\nHuman:Be a good chatbot\n\nAssistant:Hi, I am a good chatbot\n\nHuman:What is 2+2? 
\n\nAssistant:', + max_tokens_to_sample: DEFAULT_TOKEN_LIMIT, + temperature: 0.5, + stop_sequences: ['\n\nHuman:'], + }), + }); + }); + + it('transforms the response into a string', async () => { + const response = await connector.invokeStream(aiAssistantBody); + + let responseBody: string = ''; + response.on('data', (data: string) => { + responseBody += data.toString(); + }); + await waitFor(() => { + expect(responseBody).toEqual(mockResponseString); + }); + }); + + it('errors during API calls are properly handled', async () => { + // @ts-ignore + connector.request = mockError; + + await expect(connector.invokeStream(aiAssistantBody)).rejects.toThrow('API Error'); + }); + }); + describe('invokeAI', () => { const aiAssistantBody = { messages: [ @@ -112,7 +248,7 @@ describe('BedrockConnector', () => { expect(response.message).toEqual(mockResponseString); }); - it('Properly formats messages from user, assistant, and system', async () => { + it('formats messages from user, assistant, and system', async () => { const response = await connector.invokeAI({ messages: [ { @@ -210,3 +346,34 @@ describe('BedrockConnector', () => { }); }); }); + +function createStreamMock() { + const transform: Transform = new Transform({}); + + return { + write: (data: Uint8Array) => { + transform.push(data); + }, + fail: () => { + transform.emit('error', new Error('Stream failed')); + transform.end(); + }, + transform, + complete: () => { + transform.end(); + }, + }; +} + +function encodeBedrockResponse(completion: string) { + return new EventStreamCodec(toUtf8, fromUtf8).encode({ + headers: {}, + body: Uint8Array.from( + Buffer.from( + JSON.stringify({ + bytes: Buffer.from(JSON.stringify({ completion })).toString('base64'), + }) + ) + ), + }); +} diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts index ea0c72420b41e..70f8e121e1519 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts @@ -8,10 +8,15 @@ import { ServiceParams, SubActionConnector } from '@kbn/actions-plugin/server'; import aws from 'aws4'; import type { AxiosError } from 'axios'; +import { IncomingMessage } from 'http'; +import { PassThrough, Transform } from 'stream'; +import { EventStreamCodec } from '@smithy/eventstream-codec'; +import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; import { RunActionParamsSchema, RunActionResponseSchema, InvokeAIActionParamsSchema, + StreamingResponseSchema, } from '../../../common/bedrock/schema'; import type { Config, @@ -20,8 +25,10 @@ import type { RunActionResponse, InvokeAIActionParams, InvokeAIActionResponse, + StreamActionParams, } from '../../../common/bedrock/types'; import { SUB_ACTION, DEFAULT_TOKEN_LIMIT } from '../../../common/bedrock/constants'; +import { StreamingResponse } from '../../../common/bedrock/types'; interface SignedRequest { host: string; @@ -55,12 +62,17 @@ export class BedrockConnector extends SubActionConnector { method: 'runApi', schema: RunActionParamsSchema, }); - this.registerSubAction({ name: SUB_ACTION.INVOKE_AI, method: 'invokeAI', schema: InvokeAIActionParamsSchema, }); + + this.registerSubAction({ + name: SUB_ACTION.INVOKE_STREAM, + method: 'invokeStream', + schema: InvokeAIActionParamsSchema, + }); } protected getResponseErrorMessage(error: AxiosError<{ message?: string }>): string { @@ -82,15 +94,21 @@ export class BedrockConnector extends 
SubActionConnector { * @param body The request body to be signed. * @param path The path of the request URL. */ - private signRequest(body: string, path: string) { + private signRequest(body: string, path: string, stream: boolean) { const { host } = new URL(this.url); return aws.sign( { host, - headers: { - 'Content-Type': 'application/json', - Accept: '*/*', - }, + headers: stream + ? { + accept: 'application/vnd.amazon.eventstream', + 'Content-Type': 'application/json', + 'x-amzn-bedrock-accept': '*/*', + } + : { + 'Content-Type': 'application/json', + Accept: '*/*', + }, body, path, // Despite AWS docs, this value does not always get inferred. We need to always send it @@ -110,11 +128,11 @@ export class BedrockConnector extends SubActionConnector { */ public async runApi({ body, model: reqModel }: RunActionParams): Promise { // set model on per request basis - const model = reqModel ? reqModel : this.model; - const signed = this.signRequest(body, `/model/${model}/invoke`); + const path = `/model/${reqModel ?? this.model}/invoke`; + const signed = this.signRequest(body, path, false); const response = await this.request({ ...signed, - url: `${this.url}/model/${model}/invoke`, + url: `${this.url}${path}`, method: 'post', responseSchema: RunActionResponseSchema, data: body, @@ -125,33 +143,104 @@ export class BedrockConnector extends SubActionConnector { } /** - * takes in an array of messages and a model as input, and returns a promise that resolves to a string. - * The method combines the messages into a single prompt formatted for bedrock,sends a request to the - * runApi method with the prompt and model, and returns the trimmed completion from the response. - * @param messages An array of message objects, where each object has a role (string) and content (string) property. + * NOT INTENDED TO BE CALLED DIRECTLY + * call invokeStream instead + * responsible for making a POST request to a specified URL with a given request body. + * The response is then processed based on whether it is a streaming response or a regular response. + * @param body The stringified request body to be sent in the POST request. * @param model Optional model to be used for the API request. If not provided, the default model from the connector will be used. */ + private async streamApi({ + body, + model: reqModel, + }: StreamActionParams): Promise { + // set model on per request basis + const path = `/model/${reqModel ?? this.model}/invoke-with-response-stream`; + const signed = this.signRequest(body, path, true); + + const response = await this.request({ + ...signed, + url: `${this.url}${path}`, + method: 'post', + responseSchema: StreamingResponseSchema, + data: body, + responseType: 'stream', + }); + + return response.data.pipe(new PassThrough()); + } + + /** + * takes in an array of messages and a model as inputs. It calls the streamApi method to make a + * request to the Bedrock API with the formatted messages and model. It then returns a Transform stream + * that pipes the response from the API through the transformToString function, + * which parses the proprietary response into a string of the response text alone + * @param messages An array of messages to be sent to the API + * @param model Optional model to be used for the API request. If not provided, the default model from the connector will be used. 
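+   *
+   * @example A minimal consumption sketch (illustrative only; `connector` is a
+   * configured BedrockConnector instance and is not defined in this patch):
+   *   const stream = await connector.invokeStream({
+   *     messages: [{ role: 'user', content: 'Hello world' }],
+   *   });
+   *   let text = '';
+   *   stream.on('data', (chunk) => {
+   *     // chunks arrive as plain response text once transformToString has run
+   *     text += chunk.toString();
+   *   });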
+ */ + public async invokeStream({ messages, model }: InvokeAIActionParams): Promise { + const res = (await this.streamApi({ + body: JSON.stringify(formatBedrockBody({ messages })), + model, + })) as unknown as IncomingMessage; + return res.pipe(transformToString()); + } + + /** + * Deprecated. Use invokeStream instead. + * TODO: remove before 8.12 FF in part 3 of streaming work for security solution + * tracked here: https://github.com/elastic/security-team/issues/7363 + * No token tracking implemented for this method + */ public async invokeAI({ messages, model, }: InvokeAIActionParams): Promise { - const combinedMessages = messages.reduce((acc: string, message) => { - const { role, content } = message; - // Bedrock only has Assistant and Human, so 'system' and 'user' will be converted to Human - const bedrockRole = role === 'assistant' ? '\n\nAssistant:' : '\n\nHuman:'; - return `${acc}${bedrockRole}${content}`; - }, ''); - - const req = { - // end prompt in "Assistant:" to avoid the model starting its message with "Assistant:" - prompt: `${combinedMessages} \n\nAssistant:`, - max_tokens_to_sample: DEFAULT_TOKEN_LIMIT, - temperature: 0.5, - // prevent model from talking to itself - stop_sequences: ['\n\nHuman:'], - }; - - const res = await this.runApi({ body: JSON.stringify(req), model }); + const res = await this.runApi({ body: JSON.stringify(formatBedrockBody({ messages })), model }); return { message: res.completion.trim() }; } } + +const formatBedrockBody = ({ + messages, +}: { + messages: Array<{ role: string; content: string }>; +}) => { + const combinedMessages = messages.reduce((acc: string, message) => { + const { role, content } = message; + // Bedrock only has Assistant and Human, so 'system' and 'user' will be converted to Human + const bedrockRole = role === 'assistant' ? 
'\n\nAssistant:' : '\n\nHuman:'; + return `${acc}${bedrockRole}${content}`; + }, ''); + + return { + // end prompt in "Assistant:" to avoid the model starting its message with "Assistant:" + prompt: `${combinedMessages} \n\nAssistant:`, + max_tokens_to_sample: DEFAULT_TOKEN_LIMIT, + temperature: 0.5, + // prevent model from talking to itself + stop_sequences: ['\n\nHuman:'], + }; +}; + +/** + * Takes in a readable stream of data and returns a Transform stream that + * uses the AWS proprietary codec to parse the proprietary bedrock response into + * a string of the response text alone, returning the response string to the stream + */ +const transformToString = () => + new Transform({ + transform(chunk, encoding, callback) { + const encoder = new TextEncoder(); + const decoder = new EventStreamCodec(toUtf8, fromUtf8); + const event = decoder.decode(chunk); + const body = JSON.parse( + Buffer.from( + JSON.parse(new TextDecoder('utf-8').decode(event.body)).bytes, + 'base64' + ).toString() + ); + const newChunk = encoder.encode(body.completion); + callback(null, newChunk); + }, + }); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts index e214047a5c6d5..7769dd8592faf 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts @@ -17,6 +17,8 @@ import { loggingSystemMock } from '@kbn/core-logging-server-mocks'; import { actionsMock } from '@kbn/actions-plugin/server/mocks'; import { RunActionResponseSchema, StreamingResponseSchema } from '../../../common/openai/schema'; import { initDashboard } from './create_dashboard'; +import { Transform } from 'stream'; +import { waitFor } from '@testing-library/react'; jest.mock('./create_dashboard'); describe('OpenAIConnector', () => { @@ -33,6 +35,9 @@ describe('OpenAIConnector', () => { role: 'assistant', content: mockResponseString, }, + delta: { + content: mockResponseString, + }, finish_reason: 'stop', index: 0, }, @@ -268,6 +273,98 @@ describe('OpenAIConnector', () => { }); }); + describe('invokeStream', () => { + const mockStream = ( + dataToStream: string[] = [ + 'data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"}}]}', + ] + ) => { + const streamMock = createStreamMock(); + dataToStream.forEach((chunk) => { + streamMock.write(chunk); + }); + streamMock.complete(); + mockRequest = jest.fn().mockResolvedValue({ ...mockResponse, data: streamMock.transform }); + return mockRequest; + }; + beforeEach(() => { + // @ts-ignore + connector.request = mockStream(); + }); + + it('the API call is successful with correct request parameters', async () => { + await connector.invokeStream(sampleOpenAiBody); + expect(mockRequest).toBeCalledTimes(1); + expect(mockRequest).toHaveBeenCalledWith({ + url: 'https://api.openai.com/v1/chat/completions', + method: 'post', + responseSchema: StreamingResponseSchema, + responseType: 'stream', + data: JSON.stringify({ ...sampleOpenAiBody, stream: true, model: DEFAULT_OPENAI_MODEL }), + headers: { + Authorization: 'Bearer 123', + 'content-type': 'application/json', + }, + }); + }); + + it('errors during API calls are properly handled', async () => { + // @ts-ignore + connector.request = mockError; + + await expect(connector.invokeStream(sampleOpenAiBody)).rejects.toThrow('API Error'); + }); + + 
it('transforms the response into a string', async () => { + // @ts-ignore + connector.request = mockStream(); + const response = await connector.invokeStream(sampleOpenAiBody); + + let responseBody: string = ''; + response.on('data', (data: string) => { + responseBody += data.toString(); + }); + await waitFor(() => { + expect(responseBody).toEqual('My new'); + }); + }); + it('correctly buffers stream of json lines', async () => { + const chunk1 = `data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"}}]}`; + const chunk2 = `\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" message"}}]}\ndata: [DONE]`; + + // @ts-ignore + connector.request = mockStream([chunk1, chunk2]); + + const response = await connector.invokeStream(sampleOpenAiBody); + + let responseBody: string = ''; + response.on('data', (data: string) => { + responseBody += data.toString(); + }); + await waitFor(() => { + expect(responseBody).toEqual('My new message'); + }); + }); + it('correctly buffers partial lines', async () => { + const chunk1 = `data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"`; + + const chunk2 = `}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" message"}}]}\ndata: [DONE]`; + + // @ts-ignore + connector.request = mockStream([chunk1, chunk2]); + + const response = await connector.invokeStream(sampleOpenAiBody); + + let responseBody: string = ''; + response.on('data', (data: string) => { + responseBody += data.toString(); + }); + await waitFor(() => { + expect(responseBody).toEqual('My new message'); + }); + }); + }); + describe('invokeAI', () => { it('the API call is successful with correct parameters', async () => { const response = await connector.invokeAI(sampleOpenAiBody); @@ -598,3 +695,21 @@ describe('OpenAIConnector', () => { }); }); }); + +function createStreamMock() { + const transform: Transform = new Transform({}); + + return { + write: (data: string) => { + transform.push(data); + }, + fail: () => { + transform.emit('error', new Error('Stream failed')); + transform.end(); + }, + transform, + complete: () => { + transform.end(); + }, + }; +} diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts index 7680cae94db9a..78fca4bd84198 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts @@ -7,6 +7,8 @@ import { ServiceParams, SubActionConnector } from '@kbn/actions-plugin/server'; import type { AxiosError } from 'axios'; +import { PassThrough, Transform } from 'stream'; +import { IncomingMessage } from 'http'; import { RunActionParamsSchema, RunActionResponseSchema, @@ -82,6 +84,12 @@ export class OpenAIConnector extends SubActionConnector { method: 'invokeAI', schema: InvokeAIActionParamsSchema, }); + + this.registerSubAction({ + name: SUB_ACTION.INVOKE_STREAM, + method: 'invokeStream', + schema: InvokeAIActionParamsSchema, + }); } protected getResponseErrorMessage(error: AxiosError<{ error?: { message?: string } }>): string { @@ -185,9 +193,24 @@ export class OpenAIConnector extends SubActionConnector { } /** - * takes an array of messages and a model as input and returns a promise that resolves to a string. 
- * Sends the stringified input to the runApi method. Returns the trimmed completion from the response. - * @param body An object containing array of message objects, and possible other OpenAI properties + * Responsible for invoking the streamApi method with the provided body and + * stream parameters set to true. It then returns a Transform stream that processes + * the response from the streamApi method and returns the response string alone. + * @param body - the OpenAI Invoke request body + */ + public async invokeStream(body: InvokeAIActionParams): Promise { + const res = (await this.streamApi({ + body: JSON.stringify(body), + stream: true, + })) as unknown as IncomingMessage; + + return res.pipe(new PassThrough()).pipe(transformToString()); + } + + /** + * Deprecated. Use invokeStream instead. + * TODO: remove before 8.12 FF in part 3 of streaming work for security solution + * tracked here: https://github.com/elastic/security-team/issues/7363 */ public async invokeAI(body: InvokeAIActionParams): Promise { const res = await this.runApi({ body: JSON.stringify(body) }); @@ -206,3 +229,44 @@ export class OpenAIConnector extends SubActionConnector { }; } } + +/** + * Takes in a readable stream of data and returns a Transform stream that + * parses the proprietary OpenAI response into a string of the response text alone, + * returning the response string to the stream + */ +const transformToString = () => { + let lineBuffer: string = ''; + const decoder = new TextDecoder(); + + return new Transform({ + transform(chunk, encoding, callback) { + const chunks = decoder.decode(chunk); + const lines = chunks.split('\n'); + lines[0] = lineBuffer + lines[0]; + lineBuffer = lines.pop() || ''; + callback(null, getNextChunk(lines)); + }, + flush(callback) { + // Emit an additional chunk with the content of lineBuffer if it has length + if (lineBuffer.length > 0) { + callback(null, getNextChunk([lineBuffer])); + } else { + callback(); + } + }, + }); +}; + +const getNextChunk = (lines: string[]) => { + const encoder = new TextEncoder(); + const nextChunk = lines + .map((str) => str.substring(6)) + .filter((str) => !!str && str !== '[DONE]') + .map((line) => { + const openaiResponse = JSON.parse(line); + return openaiResponse.choices[0]?.delta.content ?? 
''; + }) + .join(''); + return encoder.encode(nextChunk); +}; diff --git a/x-pack/test/alerting_api_integration/common/plugins/actions_simulators/server/bedrock_simulation.ts b/x-pack/test/alerting_api_integration/common/plugins/actions_simulators/server/bedrock_simulation.ts index bfa8c5cb0736f..29e77feb5edaf 100644 --- a/x-pack/test/alerting_api_integration/common/plugins/actions_simulators/server/bedrock_simulation.ts +++ b/x-pack/test/alerting_api_integration/common/plugins/actions_simulators/server/bedrock_simulation.ts @@ -7,6 +7,8 @@ import http from 'http'; +import { EventStreamCodec } from '@smithy/eventstream-codec'; +import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; import { ProxyArgs, Simulator } from './simulator'; export class BedrockSimulator extends Simulator { @@ -27,6 +29,10 @@ export class BedrockSimulator extends Simulator { return BedrockSimulator.sendErrorResponse(response); } + if (request.url === '/model/anthropic.claude-v2/invoke-with-response-stream') { + return BedrockSimulator.sendStreamResponse(response); + } + return BedrockSimulator.sendResponse(response); } @@ -36,6 +42,14 @@ export class BedrockSimulator extends Simulator { response.end(JSON.stringify(bedrockSuccessResponse, null, 4)); } + private static sendStreamResponse(response: http.ServerResponse) { + response.statusCode = 200; + response.setHeader('Content-Type', 'application/octet-stream'); + response.setHeader('Transfer-Encoding', 'chunked'); + response.write(encodeBedrockResponse('Hello world, what a unique string!')); + response.end(); + } + private static sendErrorResponse(response: http.ServerResponse) { response.statusCode = 422; response.setHeader('Content-Type', 'application/json;charset=UTF-8'); @@ -52,3 +66,20 @@ export const bedrockFailedResponse = { message: 'Malformed input request: extraneous key [ooooo] is not permitted, please reformat your input and try again.', }; + +function encodeBedrockResponse(completion: string) { + return new EventStreamCodec(toUtf8, fromUtf8).encode({ + headers: { + ':event-type': { type: 'string', value: 'chunk' }, + ':content-type': { type: 'string', value: 'application/json' }, + ':message-type': { type: 'string', value: 'event' }, + }, + body: Uint8Array.from( + Buffer.from( + JSON.stringify({ + bytes: Buffer.from(JSON.stringify({ completion })).toString('base64'), + }) + ) + ), + }); +} diff --git a/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts b/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts index 67053bef7801b..70cdc0f96dfdd 100644 --- a/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts +++ b/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts @@ -12,6 +12,7 @@ import { bedrockSuccessResponse, } from '@kbn/actions-simulators-plugin/server/bedrock_simulation'; import { DEFAULT_TOKEN_LIMIT } from '@kbn/stack-connectors-plugin/common/bedrock/constants'; +import { PassThrough } from 'stream'; import { FtrProviderContext } from '../../../../../common/ftr_provider_context'; import { getUrlPrefix, ObjectRemover } from '../../../../../common/lib'; @@ -407,6 +408,43 @@ export default function bedrockTest({ getService }: FtrProviderContext) { data: { message: bedrockSuccessResponse.completion }, }); }); + + it('should invoke stream with assistant AI body argument formatted to bedrock expectations', async () => { + await new 
Promise((resolve, reject) => { + let responseBody: string = ''; + + const passThrough = new PassThrough(); + + supertest + .post(`/internal/elastic_assistant/actions/connector/${bedrockActionId}/_execute`) + .set('kbn-xsrf', 'foo') + .on('error', reject) + .send({ + params: { + subAction: 'invokeStream', + subActionParams: { + messages: [ + { + role: 'user', + content: 'Hello world', + }, + ], + }, + }, + assistantLangChain: false, + }) + .pipe(passThrough); + + passThrough.on('data', (chunk) => { + responseBody += chunk.toString(); + }); + + passThrough.on('end', () => { + expect(responseBody).to.eql('Hello world, what a unique string!'); + resolve(); + }); + }); + }); }); }); diff --git a/yarn.lock b/yarn.lock index 6bb6ba84d79b7..27dfd72978aae 100644 --- a/yarn.lock +++ b/yarn.lock @@ -91,6 +91,39 @@ resolved "https://registry.yarnpkg.com/@assemblyscript/loader/-/loader-0.10.1.tgz#70e45678f06c72fa2e350e8553ec4a4d72b92e06" integrity sha512-H71nDOOL8Y7kWRLqf6Sums+01Q5msqBW2KhDUTemh1tvY04eSkSXrK0uj/4mmY0Xr16/3zyZmsrxN7CKuRbNRg== +"@aws-crypto/crc32@3.0.0": + version "3.0.0" + resolved "https://registry.yarnpkg.com/@aws-crypto/crc32/-/crc32-3.0.0.tgz#07300eca214409c33e3ff769cd5697b57fdd38fa" + integrity sha512-IzSgsrxUcsrejQbPVilIKy16kAT52EwB6zSaI+M3xxIhKh5+aldEyvI+z6erM7TCLB2BJsFrtHjp6/4/sr+3dA== + dependencies: + "@aws-crypto/util" "^3.0.0" + "@aws-sdk/types" "^3.222.0" + tslib "^1.11.1" + +"@aws-crypto/util@^3.0.0": + version "3.0.0" + resolved "https://registry.yarnpkg.com/@aws-crypto/util/-/util-3.0.0.tgz#1c7ca90c29293f0883468ad48117937f0fe5bfb0" + integrity sha512-2OJlpeJpCR48CC8r+uKVChzs9Iungj9wkZrl8Z041DWEWvyIHILYKCPNzJghKsivj+S3mLo6BVc7mBNzdxA46w== + dependencies: + "@aws-sdk/types" "^3.222.0" + "@aws-sdk/util-utf8-browser" "^3.0.0" + tslib "^1.11.1" + +"@aws-sdk/types@^3.222.0": + version "3.433.0" + resolved "https://registry.yarnpkg.com/@aws-sdk/types/-/types-3.433.0.tgz#0f94eae2a4a3525ca872c9ab04e143c01806d755" + integrity sha512-0jEE2mSrNDd8VGFjTc1otYrwYPIkzZJEIK90ZxisKvQ/EURGBhNzWn7ejWB9XCMFT6XumYLBR0V9qq5UPisWtA== + dependencies: + "@smithy/types" "^2.4.0" + tslib "^2.5.0" + +"@aws-sdk/util-utf8-browser@^3.0.0": + version "3.259.0" + resolved "https://registry.yarnpkg.com/@aws-sdk/util-utf8-browser/-/util-utf8-browser-3.259.0.tgz#3275a6f5eb334f96ca76635b961d3c50259fd9ff" + integrity sha512-UvFa/vR+e19XookZF8RzFZBrw2EUkQWxiBW0yYQAhvk3C+QVGl0H3ouca8LDBlBfQKXwmW3huo/59H8rwb1wJw== + dependencies: + tslib "^2.3.1" + "@babel/cli@^7.21.0": version "7.21.0" resolved "https://registry.yarnpkg.com/@babel/cli/-/cli-7.21.0.tgz#1868eb70e9824b427fc607610cce8e9e7889e7e1" @@ -7362,6 +7395,53 @@ "@types/node" ">=8.9.0" axios "^0.21.1" +"@smithy/eventstream-codec@^2.0.12": + version "2.0.12" + resolved "https://registry.yarnpkg.com/@smithy/eventstream-codec/-/eventstream-codec-2.0.12.tgz#99fab750d0ac3941f341d912d3c3a1ab985e1a7a" + integrity sha512-ZZQLzHBJkbiAAdj2C5K+lBlYp/XJ+eH2uy+jgJgYIFW/o5AM59Hlj7zyI44/ZTDIQWmBxb3EFv/c5t44V8/g8A== + dependencies: + "@aws-crypto/crc32" "3.0.0" + "@smithy/types" "^2.4.0" + "@smithy/util-hex-encoding" "^2.0.0" + tslib "^2.5.0" + +"@smithy/is-array-buffer@^2.0.0": + version "2.0.0" + resolved "https://registry.yarnpkg.com/@smithy/is-array-buffer/-/is-array-buffer-2.0.0.tgz#8fa9b8040651e7ba0b2f6106e636a91354ff7d34" + integrity sha512-z3PjFjMyZNI98JFRJi/U0nGoLWMSJlDjAW4QUX2WNZLas5C0CmVV6LJ01JI0k90l7FvpmixjWxPFmENSClQ7ug== + dependencies: + tslib "^2.5.0" + +"@smithy/types@^2.4.0": + version "2.4.0" + resolved 
"https://registry.yarnpkg.com/@smithy/types/-/types-2.4.0.tgz#ed35e429e3ea3d089c68ed1bf951d0ccbdf2692e" + integrity sha512-iH1Xz68FWlmBJ9vvYeHifVMWJf82ONx+OybPW8ZGf5wnEv2S0UXcU4zwlwJkRXuLKpcSLHrraHbn2ucdVXLb4g== + dependencies: + tslib "^2.5.0" + +"@smithy/util-buffer-from@^2.0.0": + version "2.0.0" + resolved "https://registry.yarnpkg.com/@smithy/util-buffer-from/-/util-buffer-from-2.0.0.tgz#7eb75d72288b6b3001bc5f75b48b711513091deb" + integrity sha512-/YNnLoHsR+4W4Vf2wL5lGv0ksg8Bmk3GEGxn2vEQt52AQaPSCuaO5PM5VM7lP1K9qHRKHwrPGktqVoAHKWHxzw== + dependencies: + "@smithy/is-array-buffer" "^2.0.0" + tslib "^2.5.0" + +"@smithy/util-hex-encoding@^2.0.0": + version "2.0.0" + resolved "https://registry.yarnpkg.com/@smithy/util-hex-encoding/-/util-hex-encoding-2.0.0.tgz#0aa3515acd2b005c6d55675e377080a7c513b59e" + integrity sha512-c5xY+NUnFqG6d7HFh1IFfrm3mGl29lC+vF+geHv4ToiuJCBmIfzx6IeHLg+OgRdPFKDXIw6pvi+p3CsscaMcMA== + dependencies: + tslib "^2.5.0" + +"@smithy/util-utf8@^2.0.0": + version "2.0.0" + resolved "https://registry.yarnpkg.com/@smithy/util-utf8/-/util-utf8-2.0.0.tgz#b4da87566ea7757435e153799df9da717262ad42" + integrity sha512-rctU1VkziY84n5OXe3bPNpKR001ZCME2JCaBBFgtiM2hfKbHFudc/BkMuPab8hRbLd0j3vbnBTTZ1igBf0wgiQ== + dependencies: + "@smithy/util-buffer-from" "^2.0.0" + tslib "^2.5.0" + "@storybook/addon-a11y@^6.5.16": version "6.5.16" resolved "https://registry.yarnpkg.com/@storybook/addon-a11y/-/addon-a11y-6.5.16.tgz#9288a6c1d111fa4ec501d213100ffff91757d3fc" @@ -29060,10 +29140,10 @@ tslib@2.3.1: resolved "https://registry.yarnpkg.com/tslib/-/tslib-2.3.1.tgz#e8a335add5ceae51aa261d32a490158ef042ef01" integrity sha512-77EbyPPpMz+FRFRuAFlWMtmgUWGe9UOG2Z25NqCwiIjRhOf5iKGuzSe5P2w1laq+FkRy4p+PCuVkJSGkzTEKVw== -tslib@^1.10.0, tslib@^1.8.1, tslib@^1.9.0, tslib@^1.9.3: - version "1.13.0" - resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.13.0.tgz#c881e13cc7015894ed914862d276436fa9a47043" - integrity sha512-i/6DQjL8Xf3be4K/E6Wgpekn5Qasl1usyw++dAA35Ue5orEn65VIxOA+YvNNl9HV3qv70T7CNwjODHZrLwvd1Q== +tslib@^1.10.0, tslib@^1.11.1, tslib@^1.8.1, tslib@^1.9.0, tslib@^1.9.3: + version "1.14.1" + resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.14.1.tgz#cf2d38bdc34a134bcaf1091c41f6619e2f672d00" + integrity sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg== tslib@^2.0.0, tslib@^2.0.1, tslib@^2.0.3, tslib@^2.1.0, tslib@^2.3.1, tslib@^2.4.0, tslib@^2.5.0, tslib@^2.5.2: version "2.6.2"