diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_bedrock_converse/bedrock_runtime_client.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_bedrock_converse/bedrock_runtime_client.ts index 359342870a8b9..7f20591bd51a4 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_bedrock_converse/bedrock_runtime_client.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_bedrock_converse/bedrock_runtime_client.ts @@ -8,12 +8,17 @@ import { BedrockRuntimeClient as _BedrockRuntimeClient, BedrockRuntimeClientConfig, + ConverseCommand, + ConverseResponse, + ConverseStreamCommand, + ConverseStreamResponse, } from '@aws-sdk/client-bedrock-runtime'; import { constructStack } from '@smithy/middleware-stack'; +import { HttpHandlerOptions } from '@smithy/types'; import { PublicMethodsOf } from '@kbn/utility-types'; import type { ActionsClient } from '@kbn/actions-plugin/server'; -import { NodeHttpHandler } from './node_http_handler'; +import { prepareMessages } from '../../utils/bedrock'; export interface CustomChatModelInput extends BedrockRuntimeClientConfig { actionsClient: PublicMethodsOf; @@ -23,15 +28,51 @@ export interface CustomChatModelInput extends BedrockRuntimeClientConfig { export class BedrockRuntimeClient extends _BedrockRuntimeClient { middlewareStack: _BedrockRuntimeClient['middlewareStack']; + streaming: boolean; + actionsClient: PublicMethodsOf; + connectorId: string; constructor({ actionsClient, connectorId, ...fields }: CustomChatModelInput) { super(fields ?? {}); - this.config.requestHandler = new NodeHttpHandler({ - streaming: fields.streaming ?? true, - actionsClient, - connectorId, - }); + this.streaming = fields.streaming ?? true; + this.actionsClient = actionsClient; + this.connectorId = connectorId; // eliminate middleware steps that handle auth as Kibana connector handles auth this.middlewareStack = constructStack() as _BedrockRuntimeClient['middlewareStack']; } + + public async send( + command: ConverseCommand | ConverseStreamCommand, + optionsOrCb?: HttpHandlerOptions | ((err: unknown, data: unknown) => void) + ) { + const options = typeof optionsOrCb !== 'function' ? optionsOrCb : {}; + if (command.input.messages) { + // without this, our human + human messages do not work and result in error: + // A conversation must alternate between user and assistant roles. + command.input.messages = prepareMessages(command.input.messages); + } + const data = (await this.actionsClient.execute({ + actionId: this.connectorId, + params: { + subAction: 'bedrockClientSend', + subActionParams: { + command, + signal: options?.abortSignal, + }, + }, + })) as { + data: ConverseResponse | ConverseStreamResponse; + status: string; + message?: string; + serviceMessage?: string; + }; + + if (data.status === 'error') { + throw new Error( + `ActionsClient BedrockRuntimeClient: action result status is error: ${data?.message} - ${data?.serviceMessage}` + ); + } + + return data.data; + } } diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_bedrock_converse/node_http_handler.test.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_bedrock_converse/node_http_handler.test.ts deleted file mode 100644 index ba8a1db1fbb00..0000000000000 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_bedrock_converse/node_http_handler.test.ts +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import { NodeHttpHandler } from './node_http_handler'; -import { HttpRequest } from '@smithy/protocol-http'; -import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; -import { Readable } from 'stream'; -import { fromUtf8 } from '@smithy/util-utf8'; - -const mockActionsClient = actionsClientMock.create(); -const connectorId = 'mock-connector-id'; -const mockOutput = { - output: { - message: { - role: 'assistant', - content: [{ text: 'This is a response from the assistant.' }], - }, - }, - stopReason: 'end_turn', - usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 }, - metrics: { latencyMs: 123 }, - additionalModelResponseFields: {}, - trace: { guardrail: { modelOutput: ['Output text'] } }, -}; -describe('NodeHttpHandler', () => { - let handler: NodeHttpHandler; - - beforeEach(() => { - jest.clearAllMocks(); - handler = new NodeHttpHandler({ - streaming: false, - actionsClient: mockActionsClient, - connectorId, - }); - - mockActionsClient.execute.mockResolvedValue({ - data: mockOutput, - actionId: 'mock-action-id', - status: 'ok', - }); - }); - - it('handles non-streaming requests successfully', async () => { - const request = new HttpRequest({ - body: JSON.stringify({ messages: [] }), - }); - - const result = await handler.handle(request); - - expect(result.response.statusCode).toBe(200); - expect(result.response.headers['content-type']).toBe('application/json'); - expect(result.response.body).toStrictEqual(fromUtf8(JSON.stringify(mockOutput))); - }); - - it('handles streaming requests successfully', async () => { - handler = new NodeHttpHandler({ - streaming: true, - actionsClient: mockActionsClient, - connectorId, - }); - - const request = new HttpRequest({ - body: JSON.stringify({ messages: [] }), - }); - - const readable = new Readable(); - readable.push('streaming data'); - readable.push(null); - - mockActionsClient.execute.mockResolvedValue({ - data: readable, - status: 'ok', - actionId: 'mock-action-id', - }); - - const result = await handler.handle(request); - - expect(result.response.statusCode).toBe(200); - expect(result.response.body).toBe(readable); - }); - - it('throws an error for non-streaming requests with error status', async () => { - const request = new HttpRequest({ - body: JSON.stringify({ messages: [] }), - }); - - mockActionsClient.execute.mockResolvedValue({ - status: 'error', - message: 'error message', - serviceMessage: 'service error message', - actionId: 'mock-action-id', - }); - - await expect(handler.handle(request)).rejects.toThrow( - 'ActionsClientBedrockChat: action result status is error: error message - service error message' - ); - }); - - it('throws an error for streaming requests with error status', async () => { - handler = new NodeHttpHandler({ - streaming: true, - actionsClient: mockActionsClient, - connectorId, - }); - - const request = new HttpRequest({ - body: JSON.stringify({ messages: [] }), - }); - - mockActionsClient.execute.mockResolvedValue({ - status: 'error', - message: 'error message', - serviceMessage: 'service error message', - actionId: 'mock-action-id', - }); - - await expect(handler.handle(request)).rejects.toThrow( - 'ActionsClientBedrockChat: action result status is error: error message - service error message' - ); - }); -}); diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_bedrock_converse/node_http_handler.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_bedrock_converse/node_http_handler.ts deleted file mode 100644 index bd5143ef45d4a..0000000000000 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_bedrock_converse/node_http_handler.ts +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import { NodeHttpHandler as _NodeHttpHandler } from '@smithy/node-http-handler'; -import { HttpRequest, HttpResponse } from '@smithy/protocol-http'; -import { HttpHandlerOptions, NodeHttpHandlerOptions } from '@smithy/types'; -import { PublicMethodsOf } from '@kbn/utility-types'; -import type { ActionsClient } from '@kbn/actions-plugin/server'; -import { Readable } from 'stream'; -import { fromUtf8 } from '@smithy/util-utf8'; -import { ConverseResponse } from '@aws-sdk/client-bedrock-runtime'; -import { prepareMessages } from '../../utils/bedrock'; - -interface NodeHandlerOptions extends NodeHttpHandlerOptions { - streaming: boolean; - actionsClient: PublicMethodsOf; - connectorId: string; -} - -export class NodeHttpHandler extends _NodeHttpHandler { - streaming: boolean; - actionsClient: PublicMethodsOf; - connectorId: string; - constructor(options: NodeHandlerOptions) { - super(options); - this.streaming = options.streaming; - this.actionsClient = options.actionsClient; - this.connectorId = options.connectorId; - } - - async handle( - request: HttpRequest, - options: HttpHandlerOptions = {} - ): Promise<{ response: HttpResponse }> { - const body = JSON.parse(request.body); - const messages = prepareMessages(body.messages); - - if (this.streaming) { - const data = (await this.actionsClient.execute({ - actionId: this.connectorId, - params: { - subAction: 'converseStream', - subActionParams: { ...body, messages, signal: options.abortSignal }, - }, - })) as { data: Readable; status: string; message?: string; serviceMessage?: string }; - - if (data.status === 'error') { - throw new Error( - `ActionsClientBedrockChat: action result status is error: ${data?.message} - ${data?.serviceMessage}` - ); - } - - return { - response: { - statusCode: 200, - headers: {}, - body: data.data, - }, - }; - } - - const data = (await this.actionsClient.execute({ - actionId: this.connectorId, - params: { - subAction: 'converse', - subActionParams: { ...body, messages, signal: options.abortSignal }, - }, - })) as { data: ConverseResponse; status: string; message?: string; serviceMessage?: string }; - - if (data.status === 'error') { - throw new Error( - `ActionsClientBedrockChat: action result status is error: ${data?.message} - ${data?.serviceMessage}` - ); - } - - return { - response: { - statusCode: 200, - headers: { 'content-type': 'application/json' }, - body: fromUtf8(JSON.stringify(data.data)), - }, - }; - } -} diff --git a/x-pack/packages/kbn-langchain/server/utils/bedrock.ts b/x-pack/packages/kbn-langchain/server/utils/bedrock.ts index 7c8c069e5eb5a..b61144a5f9ad1 100644 --- a/x-pack/packages/kbn-langchain/server/utils/bedrock.ts +++ b/x-pack/packages/kbn-langchain/server/utils/bedrock.ts @@ -10,6 +10,7 @@ import { finished } from 'stream/promises'; import { Logger } from '@kbn/core/server'; import { EventStreamCodec } from '@smithy/eventstream-codec'; import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; +import { Message } from '@aws-sdk/client-bedrock-runtime'; import { StreamParser } from './types'; export const parseBedrockStreamAsAsyncIterator = async function* ( @@ -227,7 +228,7 @@ function parseContent(content: Array<{ text?: string; type: string }>): string { * Prepare messages for the bedrock API by combining messages from the same role * @param messages */ -export const prepareMessages = (messages: Array<{ role: string; content: string[] }>) => +export const prepareMessages = (messages: Message[]) => messages.reduce((acc, { role, content }) => { const lastMessage = acc[acc.length - 1]; @@ -236,13 +237,13 @@ export const prepareMessages = (messages: Array<{ role: string; content: string[ return acc; } - if (lastMessage.role === role) { - acc[acc.length - 1].content = lastMessage.content.concat(content); + if (lastMessage.role === role && lastMessage.content) { + acc[acc.length - 1].content = lastMessage.content.concat(content || []); return acc; } return acc; - }, [] as Array<{ role: string; content: string[] }>); + }, [] as Message[]); export const DEFAULT_BEDROCK_MODEL = 'anthropic.claude-3-5-sonnet-20240620-v1:0'; export const DEFAULT_BEDROCK_REGION = 'us-east-1'; diff --git a/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap b/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap index 936c5ba61b701..b9e4d8efe9f45 100644 --- a/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap +++ b/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap @@ -10,45 +10,6 @@ Object { "presence": "optional", }, "keys": Object { - "apiType": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "matches": Array [ - Object { - "schema": Object { - "allow": Array [ - "converse", - ], - "flags": Object { - "error": [Function], - "only": true, - }, - "type": "any", - }, - }, - Object { - "schema": Object { - "allow": Array [ - "invoke", - ], - "flags": Object { - "error": [Function], - "only": true, - }, - "type": "any", - }, - }, - ], - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "type": "alternatives", - }, "body": Object { "flags": Object { "error": [Function], @@ -170,45 +131,6 @@ Object { "presence": "optional", }, "keys": Object { - "apiType": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "matches": Array [ - Object { - "schema": Object { - "allow": Array [ - "converse", - ], - "flags": Object { - "error": [Function], - "only": true, - }, - "type": "any", - }, - }, - Object { - "schema": Object { - "allow": Array [ - "invoke", - ], - "flags": Object { - "error": [Function], - "only": true, - }, - "type": "any", - }, - }, - ], - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "type": "alternatives", - }, "body": Object { "flags": Object { "error": [Function], @@ -1471,23 +1393,18 @@ Object { "presence": "optional", }, "keys": Object { - "additionalModelRequestFields": Object { + "command": Object { "flags": Object { - "default": [Function], "error": [Function], - "presence": "optional", }, "metas": Array [ Object { "x-oas-any-type": true, }, - Object { - "x-oas-optional": true, - }, ], "type": "any", }, - "additionalModelResponseFieldPaths": Object { + "signal": Object { "flags": Object { "default": [Function], "error": [Function], @@ -1503,927 +1420,83 @@ Object { ], "type": "any", }, - "guardrailConfig": Object { + }, + "type": "object", +} +`; + +exports[`Connector type config checks detect connector type changes for: .bedrock 8`] = ` +Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "apiUrl": Object { "flags": Object { - "default": [Function], "error": [Function], - "presence": "optional", }, - "metas": Array [ - Object { - "x-oas-any-type": true, - }, + "rules": Array [ Object { - "x-oas-optional": true, + "args": Object { + "method": [Function], + }, + "name": "custom", }, ], - "type": "any", + "type": "string", }, - "inferenceConfig": Object { + "defaultModel": Object { "flags": Object { - "default": Object { - "special": "deep", - }, + "default": "anthropic.claude-3-5-sonnet-20240620-v1:0", "error": [Function], "presence": "optional", }, - "keys": Object { - "maxTokens": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "type": "number", - }, - "stopSequences": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "items": Array [ - Object { - "flags": Object { - "error": [Function], - "presence": "optional", - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - ], - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "type": "array", - }, - "temperature": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "type": "number", - }, - "topP": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", + "rules": Array [ + Object { + "args": Object { + "method": [Function], }, - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "type": "number", + "name": "custom", }, - }, - "type": "object", + ], + "type": "string", }, - "messages": Object { + }, + "type": "object", +} +`; + +exports[`Connector type config checks detect connector type changes for: .bedrock 9`] = ` +Object { + "flags": Object { + "default": Object { + "special": "deep", + }, + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "accessKey": Object { "flags": Object { "error": [Function], }, - "items": Array [ + "rules": Array [ Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "content": Object { - "flags": Object { - "error": [Function], - }, - "metas": Array [ - Object { - "x-oas-any-type": true, - }, - ], - "type": "any", - }, - "role": Object { - "flags": Object { - "error": [Function], - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, + "args": Object { + "method": [Function], }, - "type": "object", + "name": "custom", }, ], - "type": "array", + "type": "string", }, - "modelId": Object { + "secret": Object { "flags": Object { - "default": [Function], "error": [Function], - "presence": "optional", - }, - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - "signal": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "metas": Array [ - Object { - "x-oas-any-type": true, - }, - Object { - "x-oas-optional": true, - }, - ], - "type": "any", - }, - "system": Object { - "flags": Object { - "error": [Function], - }, - "items": Array [ - Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "text": Object { - "flags": Object { - "error": [Function], - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - }, - "type": "object", - }, - ], - "type": "array", - }, - "toolConfig": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "toolChoice": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - "unknown": true, - }, - "keys": Object {}, - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "preferences": Object { - "stripUnknown": Object { - "objects": false, - }, - }, - "type": "object", - }, - "tools": Object { - "flags": Object { - "error": [Function], - }, - "items": Array [ - Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "toolSpec": Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "description": Object { - "flags": Object { - "error": [Function], - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - "inputSchema": Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "json": Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "$schema": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - "additionalProperties": Object { - "flags": Object { - "error": [Function], - }, - "type": "boolean", - }, - "properties": Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - "unknown": true, - }, - "keys": Object {}, - "preferences": Object { - "stripUnknown": Object { - "objects": false, - }, - }, - "type": "object", - }, - "required": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "items": Array [ - Object { - "flags": Object { - "error": [Function], - "presence": "optional", - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - ], - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "type": "array", - }, - "type": Object { - "flags": Object { - "error": [Function], - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - }, - "type": "object", - }, - }, - "type": "object", - }, - "name": Object { - "flags": Object { - "error": [Function], - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - }, - "type": "object", - }, - }, - "type": "object", - }, - ], - "type": "array", - }, - }, - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "type": "object", - }, - }, - "type": "object", -} -`; - -exports[`Connector type config checks detect connector type changes for: .bedrock 8`] = ` -Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "additionalModelRequestFields": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "metas": Array [ - Object { - "x-oas-any-type": true, - }, - Object { - "x-oas-optional": true, - }, - ], - "type": "any", - }, - "additionalModelResponseFieldPaths": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "metas": Array [ - Object { - "x-oas-any-type": true, - }, - Object { - "x-oas-optional": true, - }, - ], - "type": "any", - }, - "guardrailConfig": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "metas": Array [ - Object { - "x-oas-any-type": true, - }, - Object { - "x-oas-optional": true, - }, - ], - "type": "any", - }, - "inferenceConfig": Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "maxTokens": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "type": "number", - }, - "stopSequences": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "items": Array [ - Object { - "flags": Object { - "error": [Function], - "presence": "optional", - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - ], - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "type": "array", - }, - "temperature": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "type": "number", - }, - "topP": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "type": "number", - }, - }, - "type": "object", - }, - "messages": Object { - "flags": Object { - "error": [Function], - }, - "items": Array [ - Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "content": Object { - "flags": Object { - "error": [Function], - }, - "metas": Array [ - Object { - "x-oas-any-type": true, - }, - ], - "type": "any", - }, - "role": Object { - "flags": Object { - "error": [Function], - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - }, - "type": "object", - }, - ], - "type": "array", - }, - "modelId": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - "signal": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "metas": Array [ - Object { - "x-oas-any-type": true, - }, - Object { - "x-oas-optional": true, - }, - ], - "type": "any", - }, - "system": Object { - "flags": Object { - "error": [Function], - }, - "items": Array [ - Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "text": Object { - "flags": Object { - "error": [Function], - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - }, - "type": "object", - }, - ], - "type": "array", - }, - "toolConfig": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "toolChoice": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - "unknown": true, - }, - "keys": Object {}, - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "preferences": Object { - "stripUnknown": Object { - "objects": false, - }, - }, - "type": "object", - }, - "tools": Object { - "flags": Object { - "error": [Function], - }, - "items": Array [ - Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "toolSpec": Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "description": Object { - "flags": Object { - "error": [Function], - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - "inputSchema": Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "json": Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "$schema": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - "additionalProperties": Object { - "flags": Object { - "error": [Function], - }, - "type": "boolean", - }, - "properties": Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - "unknown": true, - }, - "keys": Object {}, - "preferences": Object { - "stripUnknown": Object { - "objects": false, - }, - }, - "type": "object", - }, - "required": Object { - "flags": Object { - "default": [Function], - "error": [Function], - "presence": "optional", - }, - "items": Array [ - Object { - "flags": Object { - "error": [Function], - "presence": "optional", - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - ], - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "type": "array", - }, - "type": Object { - "flags": Object { - "error": [Function], - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - }, - "type": "object", - }, - }, - "type": "object", - }, - "name": Object { - "flags": Object { - "error": [Function], - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - }, - "type": "object", - }, - }, - "type": "object", - }, - ], - "type": "array", - }, - }, - "metas": Array [ - Object { - "x-oas-optional": true, - }, - ], - "type": "object", - }, - }, - "type": "object", -} -`; - -exports[`Connector type config checks detect connector type changes for: .bedrock 9`] = ` -Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "apiUrl": Object { - "flags": Object { - "error": [Function], - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - "defaultModel": Object { - "flags": Object { - "default": "anthropic.claude-3-5-sonnet-20240620-v1:0", - "error": [Function], - "presence": "optional", }, "rules": Array [ Object { @@ -2441,49 +1514,6 @@ Object { `; exports[`Connector type config checks detect connector type changes for: .bedrock 10`] = ` -Object { - "flags": Object { - "default": Object { - "special": "deep", - }, - "error": [Function], - "presence": "optional", - }, - "keys": Object { - "accessKey": Object { - "flags": Object { - "error": [Function], - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - "secret": Object { - "flags": Object { - "error": [Function], - }, - "rules": Array [ - Object { - "args": Object { - "method": [Function], - }, - "name": "custom", - }, - ], - "type": "string", - }, - }, - "type": "object", -} -`; - -exports[`Connector type config checks detect connector type changes for: .bedrock 11`] = ` Object { "flags": Object { "default": Object { diff --git a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.test.ts b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.test.ts index d6c89a35ba7ff..f9cc6ff5a5ba0 100644 --- a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.test.ts +++ b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.test.ts @@ -24,6 +24,7 @@ describe('getGenAiTokenTracking', () => { let mockGetTokenCountFromInvokeStream: jest.Mock; let mockGetTokenCountFromInvokeAsyncIterator: jest.Mock; beforeEach(() => { + jest.clearAllMocks(); mockGetTokenCountFromBedrockInvoke = ( getTokenCountFromBedrockInvoke as jest.Mock ).mockResolvedValueOnce({ @@ -163,6 +164,103 @@ describe('getGenAiTokenTracking', () => { }); }); + it('should return the total, prompt, and completion token counts when given a valid ConverseResponse for bedrockClientSend subaction', async () => { + const actionTypeId = '.bedrock'; + + const result = { + actionId: '123', + status: 'ok' as const, + data: { + usage: { + inputTokens: 50, + outputTokens: 50, + totalTokens: 100, + }, + }, + }; + const validatedParams = { + subAction: 'bedrockClientSend', + }; + + const tokenTracking = await getGenAiTokenTracking({ + actionTypeId, + logger, + result, + validatedParams, + }); + + expect(tokenTracking).toEqual({ + total_tokens: 100, + prompt_tokens: 50, + completion_tokens: 50, + }); + expect(logger.error).not.toHaveBeenCalled(); + }); + + it('should return the total, prompt, and completion token counts when given a valid ConverseStreamResponse for bedrockClientSend subaction', async () => { + const chunkIterable = { + async *[Symbol.asyncIterator]() { + await new Promise((resolve) => setTimeout(resolve, 100)); + yield { + metadata: { + usage: { + totalTokens: 100, + inputTokens: 40, + outputTokens: 60, + }, + }, + }; + }, + }; + const actionTypeId = '.bedrock'; + + const result = { + actionId: '123', + status: 'ok' as const, + data: { + tokenStream: chunkIterable, + }, + }; + const validatedParams = { + subAction: 'bedrockClientSend', + }; + + const tokenTracking = await getGenAiTokenTracking({ + actionTypeId, + logger, + result, + validatedParams, + }); + + expect(tokenTracking).toEqual({ + total_tokens: 100, + prompt_tokens: 40, + completion_tokens: 60, + }); + expect(logger.error).not.toHaveBeenCalled(); + }); + + it('should return null when given an invalid Bedrock response for bedrockClientSend subaction', async () => { + const actionTypeId = '.bedrock'; + const result = { + actionId: '123', + status: 'ok' as const, + data: {}, + }; + const validatedParams = { + subAction: 'bedrockClientSend', + }; + + const tokenTracking = await getGenAiTokenTracking({ + actionTypeId, + logger, + result, + validatedParams, + }); + + expect(tokenTracking).toBeNull(); + expect(logger.error).toHaveBeenCalled(); + }); it('should return the total, prompt, and completion token counts when given a valid OpenAI streamed response', async () => { const mockReader = new IncomingMessage(new Socket()); const actionTypeId = '.gen-ai'; diff --git a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts index 41bfa28605f40..d73610892098d 100644 --- a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts +++ b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts @@ -9,6 +9,10 @@ import { PassThrough, Readable } from 'stream'; import { Logger } from '@kbn/logging'; import { Stream } from 'openai/streaming'; import { ChatCompletionChunk } from 'openai/resources/chat/completions'; +import { + getTokensFromBedrockConverseStream, + SmithyStream, +} from './get_token_count_from_bedrock_converse'; import { InvokeAsyncIteratorBody, getTokenCountFromInvokeAsyncIterator, @@ -264,6 +268,29 @@ export const getGenAiTokenTracking = async ({ // silently fail and null is returned at bottom of function } } + + // BedrockRuntimeClient.send response used by chat model ActionsClientChatBedrockConverse + if (actionTypeId === '.bedrock' && validatedParams.subAction === 'bedrockClientSend') { + const { tokenStream, usage } = result.data as unknown as { + tokenStream?: SmithyStream; + usage?: { inputTokens: number; outputTokens: number; totalTokens: number }; + }; + if (tokenStream) { + const res = await getTokensFromBedrockConverseStream(tokenStream, logger); + return res; + } + if (usage) { + return { + total_tokens: usage.totalTokens, + prompt_tokens: usage.inputTokens, + completion_tokens: usage.outputTokens, + }; + } else { + logger.error('Response from Bedrock converse API did not contain usage object'); + return null; + } + } + return null; }; diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_bedrock_converse.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_bedrock_converse.ts new file mode 100644 index 0000000000000..55bd9e6582e00 --- /dev/null +++ b/x-pack/plugins/actions/server/lib/get_token_count_from_bedrock_converse.ts @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { SmithyMessageDecoderStream } from '@smithy/eventstream-codec'; +import { Logger } from '@kbn/logging'; + +export type SmithyStream = SmithyMessageDecoderStream<{ + metadata?: { + usage: { inputTokens: number; outputTokens: number; totalTokens: number }; + }; +}>; + +export const getTokensFromBedrockConverseStream = async function ( + responseStream: SmithyStream, + logger: Logger +): Promise<{ total_tokens: number; prompt_tokens: number; completion_tokens: number } | null> { + try { + for await (const { metadata } of responseStream) { + if (metadata) { + return { + total_tokens: metadata.usage.totalTokens, + prompt_tokens: metadata.usage.inputTokens, + completion_tokens: metadata.usage.outputTokens, + }; + } + } + return null; // Return the final tokens once the generator finishes + } catch (e) { + logger.error('Response from Bedrock converse API did not contain usage object'); + return null; + } +}; diff --git a/x-pack/plugins/stack_connectors/common/bedrock/constants.ts b/x-pack/plugins/stack_connectors/common/bedrock/constants.ts index d2ffa0b116bda..de6c10246298a 100644 --- a/x-pack/plugins/stack_connectors/common/bedrock/constants.ts +++ b/x-pack/plugins/stack_connectors/common/bedrock/constants.ts @@ -21,8 +21,7 @@ export enum SUB_ACTION { INVOKE_STREAM = 'invokeStream', DASHBOARD = 'getDashboard', TEST = 'test', - CONVERSE = 'converse', - CONVERSE_STREAM = 'converseStream', + BEDROCK_CLIENT_SEND = 'bedrockClientSend', } export const DEFAULT_TIMEOUT_MS = 120000; diff --git a/x-pack/plugins/stack_connectors/common/bedrock/schema.ts b/x-pack/plugins/stack_connectors/common/bedrock/schema.ts index c444159c010b2..e9194a752300c 100644 --- a/x-pack/plugins/stack_connectors/common/bedrock/schema.ts +++ b/x-pack/plugins/stack_connectors/common/bedrock/schema.ts @@ -26,11 +26,6 @@ export const RunActionParamsSchema = schema.object({ signal: schema.maybe(schema.any()), timeout: schema.maybe(schema.number()), raw: schema.maybe(schema.boolean()), - apiType: schema.maybe( - schema.oneOf([schema.literal('converse'), schema.literal('invoke')], { - defaultValue: 'invoke', - }) - ), }); export const BedrockMessageSchema = schema.object( @@ -154,53 +149,11 @@ export const DashboardActionResponseSchema = schema.object({ available: schema.boolean(), }); -export const ConverseActionParamsSchema = schema.object({ - // Bedrock API Properties - modelId: schema.maybe(schema.string()), - messages: schema.arrayOf( - schema.object({ - role: schema.string(), - content: schema.any(), - }) - ), - system: schema.arrayOf( - schema.object({ - text: schema.string(), - }) - ), - inferenceConfig: schema.object({ - temperature: schema.maybe(schema.number()), - maxTokens: schema.maybe(schema.number()), - stopSequences: schema.maybe(schema.arrayOf(schema.string())), - topP: schema.maybe(schema.number()), - }), - toolConfig: schema.maybe( - schema.object({ - tools: schema.arrayOf( - schema.object({ - toolSpec: schema.object({ - name: schema.string(), - description: schema.string(), - inputSchema: schema.object({ - json: schema.object({ - type: schema.string(), - properties: schema.object({}, { unknowns: 'allow' }), - required: schema.maybe(schema.arrayOf(schema.string())), - additionalProperties: schema.boolean(), - $schema: schema.maybe(schema.string()), - }), - }), - }), - }) - ), - toolChoice: schema.maybe(schema.object({}, { unknowns: 'allow' })), - }) - ), - additionalModelRequestFields: schema.maybe(schema.any()), - additionalModelResponseFieldPaths: schema.maybe(schema.any()), - guardrailConfig: schema.maybe(schema.any()), +export const BedrockClientSendParamsSchema = schema.object({ + // ConverseCommand | ConverseStreamCommand from @aws-sdk/client-bedrock-runtime + command: schema.any(), // Kibana related properties signal: schema.maybe(schema.any()), }); -export const ConverseActionResponseSchema = schema.object({}, { unknowns: 'allow' }); +export const BedrockClientSendResponseSchema = schema.object({}, { unknowns: 'allow' }); diff --git a/x-pack/plugins/stack_connectors/common/bedrock/types.ts b/x-pack/plugins/stack_connectors/common/bedrock/types.ts index e3dd49538176f..2e716a52547cd 100644 --- a/x-pack/plugins/stack_connectors/common/bedrock/types.ts +++ b/x-pack/plugins/stack_connectors/common/bedrock/types.ts @@ -21,8 +21,8 @@ import { RunApiLatestResponseSchema, BedrockMessageSchema, BedrockToolChoiceSchema, - ConverseActionParamsSchema, - ConverseActionResponseSchema, + BedrockClientSendParamsSchema, + BedrockClientSendResponseSchema, } from './schema'; export type Config = TypeOf; @@ -39,5 +39,5 @@ export type DashboardActionParams = TypeOf; export type DashboardActionResponse = TypeOf; export type BedrockMessage = TypeOf; export type BedrockToolChoice = TypeOf; -export type ConverseActionParams = TypeOf; -export type ConverseActionResponse = TypeOf; +export type ConverseActionParams = TypeOf; +export type ConverseActionResponse = TypeOf; diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts index 2a4d91a07f1d3..ce3dd90942cf5 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.test.ts @@ -30,6 +30,7 @@ jest.mock('../lib/gen_ai/create_gen_ai_dashboard'); // @ts-ignore const mockSigner = jest.spyOn(aws, 'sign').mockReturnValue({ signed: true }); +const mockSend = jest.fn(); describe('BedrockConnector', () => { let mockRequest: jest.Mock; let mockError: jest.Mock; @@ -89,6 +90,8 @@ describe('BedrockConnector', () => { beforeEach(() => { // @ts-ignore connector.request = mockRequest; + // @ts-ignore + connector.bedrockClient.send = mockSend; }); describe('runApi', () => { @@ -630,6 +633,57 @@ describe('BedrockConnector', () => { ); }); }); + + describe('bedrockClientSend', () => { + it('should send the command and return the response', async () => { + const command = { input: 'test' }; + const response = { result: 'success' }; + mockSend.mockResolvedValue(response); + + const result = await connector.bedrockClientSend( + { signal: undefined, command }, + connectorUsageCollector + ); + + expect(mockSend).toHaveBeenCalledWith(command, { abortSignal: undefined }); + expect(result).toEqual(response); + }); + + it('should handle and split streaming response', async () => { + const command = { input: 'test' }; + const stream = new PassThrough(); + const response = { stream }; + mockSend.mockResolvedValue(response); + + const result = (await connector.bedrockClientSend( + { signal: undefined, command }, + connectorUsageCollector + )) as unknown as { + stream?: unknown; + tokenStream?: unknown; + }; + + expect(mockSend).toHaveBeenCalledWith(command, { abortSignal: undefined }); + expect(result.stream).toBeDefined(); + expect(result.tokenStream).toBeDefined(); + }); + + it('should handle non-streaming response', async () => { + const command = { input: 'test' }; + const usage = { stats: 0 }; + const response = { usage }; + mockSend.mockResolvedValue(response); + + const result = (await connector.bedrockClientSend( + { signal: undefined, command }, + connectorUsageCollector + )) as unknown as { + usage?: unknown; + }; + expect(result.usage).toBeDefined(); + }); + }); + describe('getResponseErrorMessage', () => { it('returns an unknown error message', () => { // @ts-expect-error expects an axios error as the parameter diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts index 55b631ba9441c..339efa49f69bf 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts @@ -7,6 +7,8 @@ import { ServiceParams, SubActionConnector } from '@kbn/actions-plugin/server'; import aws from 'aws4'; +import { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime'; +import { SmithyMessageDecoderStream } from '@smithy/eventstream-codec'; import { AxiosError, Method } from 'axios'; import { IncomingMessage } from 'http'; import { PassThrough } from 'stream'; @@ -21,7 +23,7 @@ import { StreamingResponseSchema, RunActionResponseSchema, RunApiLatestResponseSchema, - ConverseActionParamsSchema, + BedrockClientSendParamsSchema, } from '../../../common/bedrock/schema'; import { Config, @@ -60,13 +62,20 @@ interface SignedRequest { export class BedrockConnector extends SubActionConnector { private url; private model; + private bedrockClient; constructor(params: ServiceParams) { super(params); this.url = this.config.apiUrl; this.model = this.config.defaultModel; - + this.bedrockClient = new BedrockRuntimeClient({ + region: extractRegionId(this.config.apiUrl), + credentials: { + accessKeyId: this.secrets.accessKey, + secretAccessKey: this.secrets.secret, + }, + }); this.registerSubActions(); } @@ -108,15 +117,9 @@ export class BedrockConnector extends SubActionConnector { }); this.registerSubAction({ - name: SUB_ACTION.CONVERSE, - method: 'converse', - schema: ConverseActionParamsSchema, - }); - - this.registerSubAction({ - name: SUB_ACTION.CONVERSE_STREAM, - method: 'converseStream', - schema: ConverseActionParamsSchema, + name: SUB_ACTION.BEDROCK_CLIENT_SEND, + method: 'bedrockClientSend', + schema: BedrockClientSendParamsSchema, }); } @@ -240,15 +243,14 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B * @param signal Optional signal to cancel the request. * @param timeout Optional timeout for the request. * @param raw Optional flag to indicate if the response should be returned as raw data. - * @param apiType Optional type of API to be called. Defaults to 'invoke', . */ public async runApi( - { body, model: reqModel, signal, timeout, raw, apiType = 'invoke' }: RunActionParams, + { body, model: reqModel, signal, timeout, raw }: RunActionParams, connectorUsageCollector: ConnectorUsageCollector ): Promise { // set model on per request basis const currentModel = reqModel ?? this.model; - const path = `/model/${currentModel}/${apiType}`; + const path = `/model/${currentModel}/invoke`; const signed = this.signRequest(body, path, false); const requestArgs = { ...signed, @@ -281,22 +283,18 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B /** * NOT INTENDED TO BE CALLED DIRECTLY - * call invokeStream or converseStream instead + * call invokeStream instead * responsible for making a POST request to a specified URL with a given request body. * The response is then processed based on whether it is a streaming response or a regular response. * @param body The stringified request body to be sent in the POST request. * @param model Optional model to be used for the API request. If not provided, the default model from the connector will be used. */ private async streamApi( - { body, model: reqModel, signal, timeout, apiType = 'invoke' }: RunActionParams, + { body, model: reqModel, signal, timeout }: RunActionParams, connectorUsageCollector: ConnectorUsageCollector ): Promise { - const streamingApiRoute = { - invoke: 'invoke-with-response-stream', - converse: 'converse-stream', - }; // set model on per request basis - const path = `/model/${reqModel ?? this.model}/${streamingApiRoute[apiType]}`; + const path = `/model/${reqModel ?? this.model}/invoke-with-response-stream`; const signed = this.signRequest(body, path, true); const response = await this.request( @@ -436,45 +434,28 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B } /** - * Sends a request to the Bedrock API to perform a conversation action. - * @param input - The parameters for the conversation action. + * Sends a request via the BedrockRuntimeClient to perform a conversation action. + * @param params - The parameters for the conversation action. + * @param params.signal - The signal to cancel the request. + * @param params.command - The command class to be sent to the API. (ConverseCommand | ConverseStreamCommand) * @param connectorUsageCollector - The usage collector for the connector. * @returns A promise that resolves to the response of the conversation action. */ - public async converse( - { signal, ...converseApiInput }: ConverseActionParams, + public async bedrockClientSend( + { signal, command }: ConverseActionParams, connectorUsageCollector: ConnectorUsageCollector ): Promise { - const res = await this.runApi( - { - body: JSON.stringify(converseApiInput), - raw: true, - apiType: 'converse', - signal, - }, - connectorUsageCollector - ); - return res; - } + connectorUsageCollector.addRequestBodyBytes(undefined, command); + const res = await this.bedrockClient.send(command, { + abortSignal: signal, + }); - /** - * Sends a request to the Bedrock API to perform a streaming conversation action. - * @param input - The parameters for the streaming conversation action. - * @param connectorUsageCollector - The usage collector for the connector. - * @returns A promise that resolves to the streaming response of the conversation action. - */ - public async converseStream( - { signal, ...converseApiInput }: ConverseActionParams, - connectorUsageCollector: ConnectorUsageCollector - ): Promise { - const res = await this.streamApi( - { - body: JSON.stringify(converseApiInput), - apiType: 'converse', - signal, - }, - connectorUsageCollector - ); + if ('stream' in res) { + const resultStream = res.stream as SmithyMessageDecoderStream; + // splits the stream in two, [stream = consumer, tokenStream = token tracking] + const [stream, tokenStream] = tee(resultStream); + return { ...res, stream, tokenStream }; + } return res; } @@ -571,3 +552,91 @@ function parseContent(content: Array<{ text?: string; type: string }>): string { } const usesDeprecatedArguments = (body: string): boolean => JSON.parse(body)?.prompt != null; + +function extractRegionId(url: string) { + const match = (url ?? '').match(/bedrock\.(.*?)\.amazonaws\./); + if (match) { + return match[1]; + } else { + // fallback to us-east-1 + return 'us-east-1'; + } +} + +/** + * Splits an async iterator into two independent async iterators which can be independently read from at different speeds. + * @param asyncIterator The async iterator returned from Bedrock to split + */ +function tee( + asyncIterator: SmithyMessageDecoderStream +): [SmithyMessageDecoderStream, SmithyMessageDecoderStream] { + // @ts-ignore options is private, but we need it to create the new streams + const streamOptions = asyncIterator.options; + + const streamLeft = new SmithyMessageDecoderStream(streamOptions); + const streamRight = new SmithyMessageDecoderStream(streamOptions); + + // Queues to store chunks for each stream + const leftQueue: T[] = []; + const rightQueue: T[] = []; + + // Promises for managing when a chunk is available + let leftPending: ((chunk: T | null) => void) | null = null; + let rightPending: ((chunk: T | null) => void) | null = null; + + const distribute = async () => { + for await (const chunk of asyncIterator) { + // Push the chunk into both queues + if (leftPending) { + leftPending(chunk); + leftPending = null; + } else { + leftQueue.push(chunk); + } + + if (rightPending) { + rightPending(chunk); + rightPending = null; + } else { + rightQueue.push(chunk); + } + } + + // Signal the end of the iterator + if (leftPending) { + leftPending(null); + } + if (rightPending) { + rightPending(null); + } + }; + + // Start distributing chunks from the iterator + distribute().catch(() => { + // swallow errors + }); + + // Helper to create an async iterator for each stream + const createIterator = ( + queue: T[], + setPending: (fn: ((chunk: T | null) => void) | null) => void + ) => { + return async function* () { + while (true) { + if (queue.length > 0) { + yield queue.shift()!; + } else { + const chunk = await new Promise((resolve) => setPending(resolve)); + if (chunk === null) break; // End of the stream + yield chunk; + } + } + }; + }; + + // Assign independent async iterators to each stream + streamLeft[Symbol.asyncIterator] = createIterator(leftQueue, (fn) => (leftPending = fn)); + streamRight[Symbol.asyncIterator] = createIterator(rightQueue, (fn) => (rightPending = fn)); + + return [streamLeft, streamRight]; +}