From 3e0a431144d6afbc35c5bab5cf09735d73cb879f Mon Sep 17 00:00:00 2001 From: Pierre Gayvallet Date: Fri, 30 Aug 2024 12:47:06 +0200 Subject: [PATCH] Inference plugin: Add Bedrock model adapter (#191434) ## Summary Add the `bedrock` (well, bedrock-claude) model adapter for the inference plugin. Also had to perform minor changes on the associated connector to add support for new capabilities. --------- Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com> --- .../connector_types.test.ts.snap | 373 +++++++++++++++++- .../bedrock/bedrock_claude_adapter.test.ts | 310 +++++++++++++++ .../bedrock/bedrock_claude_adapter.ts | 140 +++++++ .../chat_complete/adapters/bedrock/index.ts | 8 + .../bedrock/process_completion_chunks.test.ts | 336 ++++++++++++++++ .../bedrock/process_completion_chunks.ts | 113 ++++++ .../chat_complete/adapters/bedrock/prompts.ts | 15 + .../serde_eventstream_into_observable.test.ts | 87 ++++ .../serde_eventstream_into_observable.ts | 76 ++++ .../adapters/bedrock/serde_utils.test.ts | 20 + .../adapters/bedrock/serde_utils.ts | 33 ++ .../chat_complete/adapters/bedrock/types.ts | 98 +++++ .../adapters/get_inference_adapter.test.ts | 5 +- .../adapters/get_inference_adapter.ts | 4 +- x-pack/plugins/inference/tsconfig.json | 3 +- .../stack_connectors/common/bedrock/schema.ts | 38 +- .../stack_connectors/common/bedrock/types.ts | 4 + .../server/connector_types/bedrock/bedrock.ts | 60 ++- 18 files changed, 1691 insertions(+), 32 deletions(-) create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/bedrock/index.ts create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/bedrock/process_completion_chunks.test.ts create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/bedrock/process_completion_chunks.ts create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/bedrock/prompts.ts create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/bedrock/serde_eventstream_into_observable.test.ts create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/bedrock/serde_eventstream_into_observable.ts create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/bedrock/serde_utils.test.ts create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/bedrock/serde_utils.ts create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/bedrock/types.ts diff --git a/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap b/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap index ae1289b2a9e2e..6e4392e3f737e 100644 --- a/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap +++ b/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap @@ -273,8 +273,15 @@ Object { "keys": Object { "content": Object { "flags": Object { + "default": [Function], "error": [Function], + "presence": "optional", }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], "rules": Array [ Object { "args": Object { @@ -285,6 +292,33 @@ Object { ], "type": "string", }, + "rawContent": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "items": Array [ + Object { + "flags": Object { + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-any-type": true, + }, + ], + "type": "any", + }, + ], + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "array", + }, "role": Object { "flags": Object { "error": [Function], @@ -300,6 +334,14 @@ Object { "type": "string", }, }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], "type": "object", }, ], @@ -419,6 +461,86 @@ Object { ], "type": "number", }, + "toolChoice": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "name": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + "type": Object { + "flags": Object { + "error": [Function], + }, + "matches": Array [ + Object { + "schema": Object { + "allow": Array [ + "auto", + ], + "flags": Object { + "error": [Function], + "only": true, + }, + "type": "any", + }, + }, + Object { + "schema": Object { + "allow": Array [ + "any", + ], + "flags": Object { + "error": [Function], + "only": true, + }, + "type": "any", + }, + }, + Object { + "schema": Object { + "allow": Array [ + "tool", + ], + "flags": Object { + "error": [Function], + "only": true, + }, + "type": "any", + }, + }, + ], + "type": "alternatives", + }, + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "object", + }, "tools": Object { "flags": Object { "default": [Function], @@ -556,8 +678,15 @@ Object { "keys": Object { "content": Object { "flags": Object { + "default": [Function], "error": [Function], + "presence": "optional", }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], "rules": Array [ Object { "args": Object { @@ -568,6 +697,33 @@ Object { ], "type": "string", }, + "rawContent": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "items": Array [ + Object { + "flags": Object { + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-any-type": true, + }, + ], + "type": "any", + }, + ], + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "array", + }, "role": Object { "flags": Object { "error": [Function], @@ -583,6 +739,14 @@ Object { "type": "string", }, }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], "type": "object", }, ], @@ -702,6 +866,86 @@ Object { ], "type": "number", }, + "toolChoice": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "name": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + "type": Object { + "flags": Object { + "error": [Function], + }, + "matches": Array [ + Object { + "schema": Object { + "allow": Array [ + "auto", + ], + "flags": Object { + "error": [Function], + "only": true, + }, + "type": "any", + }, + }, + Object { + "schema": Object { + "allow": Array [ + "any", + ], + "flags": Object { + "error": [Function], + "only": true, + }, + "type": "any", + }, + }, + Object { + "schema": Object { + "allow": Array [ + "tool", + ], + "flags": Object { + "error": [Function], + "only": true, + }, + "type": "any", + }, + }, + ], + "type": "alternatives", + }, + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "object", + }, "tools": Object { "flags": Object { "default": [Function], @@ -839,14 +1083,51 @@ Object { "keys": Object { "content": Object { "flags": Object { + "default": [Function], "error": [Function], + "presence": "optional", }, "metas": Array [ Object { - "x-oas-any-type": true, + "x-oas-optional": true, }, ], - "type": "any", + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + "rawContent": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "items": Array [ + Object { + "flags": Object { + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-any-type": true, + }, + ], + "type": "any", + }, + ], + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "array", }, "role": Object { "flags": Object { @@ -863,6 +1144,14 @@ Object { "type": "string", }, }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], "type": "object", }, ], @@ -982,6 +1271,86 @@ Object { ], "type": "number", }, + "toolChoice": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "name": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + "type": Object { + "flags": Object { + "error": [Function], + }, + "matches": Array [ + Object { + "schema": Object { + "allow": Array [ + "auto", + ], + "flags": Object { + "error": [Function], + "only": true, + }, + "type": "any", + }, + }, + Object { + "schema": Object { + "allow": Array [ + "any", + ], + "flags": Object { + "error": [Function], + "only": true, + }, + "type": "any", + }, + }, + Object { + "schema": Object { + "allow": Array [ + "tool", + ], + "flags": Object { + "error": [Function], + "only": true, + }, + "type": "any", + }, + }, + ], + "type": "alternatives", + }, + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "object", + }, "tools": Object { "flags": Object { "default": [Function], diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts new file mode 100644 index 0000000000000..1d25f09dce3bc --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts @@ -0,0 +1,310 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { PassThrough } from 'stream'; +import type { InferenceExecutor } from '../../utils/inference_executor'; +import { MessageRole } from '../../../../common/chat_complete'; +import { ToolChoiceType } from '../../../../common/chat_complete/tools'; +import { bedrockClaudeAdapter } from './bedrock_claude_adapter'; +import { addNoToolUsageDirective } from './prompts'; + +describe('bedrockClaudeAdapter', () => { + const executorMock = { + invoke: jest.fn(), + } as InferenceExecutor & { invoke: jest.MockedFn }; + + beforeEach(() => { + executorMock.invoke.mockReset(); + executorMock.invoke.mockImplementation(async () => { + return { + actionId: '', + status: 'ok', + data: new PassThrough(), + }; + }); + }); + + function getCallParams() { + const params = executorMock.invoke.mock.calls[0][0].subActionParams as Record; + return { + system: params.system, + messages: params.messages, + tools: params.tools, + toolChoice: params.toolChoice, + }; + } + + describe('#chatComplete()', () => { + it('calls `executor.invoke` with the right fixed parameters', () => { + bedrockClaudeAdapter.chatComplete({ + executor: executorMock, + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + ], + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + expect(executorMock.invoke).toHaveBeenCalledWith({ + subAction: 'invokeStream', + subActionParams: { + messages: [ + { + role: 'user', + rawContent: [{ type: 'text', text: 'question' }], + }, + ], + temperature: 0, + stopSequences: ['\n\nHuman:'], + }, + }); + }); + + it('correctly format tools', () => { + bedrockClaudeAdapter.chatComplete({ + executor: executorMock, + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + ], + tools: { + myFunction: { + description: 'myFunction', + }, + myFunctionWithArgs: { + description: 'myFunctionWithArgs', + schema: { + type: 'object', + properties: { + foo: { + type: 'string', + description: 'foo', + }, + }, + required: ['foo'], + }, + }, + }, + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + + const { tools } = getCallParams(); + expect(tools).toEqual([ + { + name: 'myFunction', + description: 'myFunction', + input_schema: { + properties: {}, + type: 'object', + }, + }, + { + name: 'myFunctionWithArgs', + description: 'myFunctionWithArgs', + input_schema: { + properties: { + foo: { + description: 'foo', + type: 'string', + }, + }, + required: ['foo'], + type: 'object', + }, + }, + ]); + }); + + it('correctly format messages', () => { + bedrockClaudeAdapter.chatComplete({ + executor: executorMock, + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + { + role: MessageRole.Assistant, + content: 'answer', + }, + { + role: MessageRole.User, + content: 'another question', + }, + { + role: MessageRole.Assistant, + content: null, + toolCalls: [ + { + function: { + name: 'my_function', + arguments: { + foo: 'bar', + }, + }, + toolCallId: '0', + }, + ], + }, + { + role: MessageRole.Tool, + toolCallId: '0', + response: { + bar: 'foo', + }, + }, + ], + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + + const { messages } = getCallParams(); + expect(messages).toEqual([ + { + rawContent: [ + { + text: 'question', + type: 'text', + }, + ], + role: 'user', + }, + { + rawContent: [ + { + text: 'answer', + type: 'text', + }, + ], + role: 'assistant', + }, + { + rawContent: [ + { + text: 'another question', + type: 'text', + }, + ], + role: 'user', + }, + { + rawContent: [ + { + id: '0', + input: { + foo: 'bar', + }, + name: 'my_function', + type: 'tool_use', + }, + ], + role: 'assistant', + }, + { + rawContent: [ + { + content: '{"bar":"foo"}', + tool_use_id: '0', + type: 'tool_result', + }, + ], + role: 'user', + }, + ]); + }); + + it('correctly format system message', () => { + bedrockClaudeAdapter.chatComplete({ + executor: executorMock, + system: 'Some system message', + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + ], + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + + const { system } = getCallParams(); + expect(system).toEqual('Some system message'); + }); + + it('correctly format tool choice', () => { + bedrockClaudeAdapter.chatComplete({ + executor: executorMock, + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + ], + toolChoice: ToolChoiceType.required, + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + + const { toolChoice } = getCallParams(); + expect(toolChoice).toEqual({ + type: 'any', + }); + }); + + it('correctly format tool choice for named function', () => { + bedrockClaudeAdapter.chatComplete({ + executor: executorMock, + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + ], + toolChoice: { function: 'foobar' }, + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + + const { toolChoice } = getCallParams(); + expect(toolChoice).toEqual({ + type: 'tool', + name: 'foobar', + }); + }); + + it('correctly adapt the request for ToolChoiceType.None', () => { + bedrockClaudeAdapter.chatComplete({ + executor: executorMock, + system: 'some system instruction', + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + ], + tools: { + myFunction: { + description: 'myFunction', + }, + }, + toolChoice: ToolChoiceType.none, + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + + const { toolChoice, tools, system } = getCallParams(); + expect(toolChoice).toBeUndefined(); + expect(tools).toEqual([]); + expect(system).toEqual(addNoToolUsageDirective('some system instruction')); + }); + }); +}); diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts new file mode 100644 index 0000000000000..5a03dc04347b1 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { filter, from, map, switchMap, tap } from 'rxjs'; +import { Readable } from 'stream'; +import type { InvokeAIActionParams } from '@kbn/stack-connectors-plugin/common/bedrock/types'; +import { parseSerdeChunkMessage } from './serde_utils'; +import { Message, MessageRole } from '../../../../common/chat_complete'; +import { createInferenceInternalError } from '../../../../common/errors'; +import { ToolChoiceType, type ToolOptions } from '../../../../common/chat_complete/tools'; +import { InferenceConnectorAdapter } from '../../types'; +import type { BedRockMessage, BedrockToolChoice } from './types'; +import { + BedrockChunkMember, + serdeEventstreamIntoObservable, +} from './serde_eventstream_into_observable'; +import { processCompletionChunks } from './process_completion_chunks'; +import { addNoToolUsageDirective } from './prompts'; + +export const bedrockClaudeAdapter: InferenceConnectorAdapter = { + chatComplete: ({ executor, system, messages, toolChoice, tools }) => { + const noToolUsage = toolChoice === ToolChoiceType.none; + + const connectorInvokeRequest: InvokeAIActionParams = { + system: noToolUsage ? addNoToolUsageDirective(system) : system, + messages: messagesToBedrock(messages), + tools: noToolUsage ? [] : toolsToBedrock(tools), + toolChoice: toolChoiceToBedrock(toolChoice), + temperature: 0, + stopSequences: ['\n\nHuman:'], + }; + + return from( + executor.invoke({ + subAction: 'invokeStream', + subActionParams: connectorInvokeRequest, + }) + ).pipe( + switchMap((response) => { + const readable = response.data as Readable; + return serdeEventstreamIntoObservable(readable); + }), + tap((eventData) => { + if ('modelStreamErrorException' in eventData) { + throw createInferenceInternalError(eventData.modelStreamErrorException.originalMessage); + } + }), + filter((value): value is BedrockChunkMember => { + return 'chunk' in value && value.chunk?.headers?.[':event-type']?.value === 'chunk'; + }), + map((message) => { + return parseSerdeChunkMessage(message.chunk); + }), + processCompletionChunks() + ); + }, +}; + +const toolChoiceToBedrock = ( + toolChoice: ToolOptions['toolChoice'] +): BedrockToolChoice | undefined => { + if (toolChoice === ToolChoiceType.required) { + return { + type: 'any', + }; + } else if (toolChoice === ToolChoiceType.auto) { + return { + type: 'auto', + }; + } else if (typeof toolChoice === 'object') { + return { + type: 'tool', + name: toolChoice.function, + }; + } + // ToolChoiceType.none is not supported by claude + // we are adding a directive to the system instructions instead in that case. + return undefined; +}; + +const toolsToBedrock = (tools: ToolOptions['tools']) => { + return tools + ? Object.entries(tools).map(([toolName, toolDef]) => { + return { + name: toolName, + description: toolDef.description, + input_schema: toolDef.schema ?? { + type: 'object' as const, + properties: {}, + }, + }; + }) + : undefined; +}; + +const messagesToBedrock = (messages: Message[]): BedRockMessage[] => { + return messages.map((message) => { + switch (message.role) { + case MessageRole.User: + return { + role: 'user' as const, + rawContent: [{ type: 'text' as const, text: message.content }], + }; + case MessageRole.Assistant: + return { + role: 'assistant' as const, + rawContent: [ + ...(message.content ? [{ type: 'text' as const, text: message.content }] : []), + ...(message.toolCalls + ? message.toolCalls.map((toolCall) => { + return { + type: 'tool_use' as const, + id: toolCall.toolCallId, + name: toolCall.function.name, + input: ('arguments' in toolCall.function + ? toolCall.function.arguments + : {}) as Record, + }; + }) + : []), + ], + }; + case MessageRole.Tool: + return { + role: 'user' as const, + rawContent: [ + { + type: 'tool_result' as const, + tool_use_id: message.toolCallId, + content: JSON.stringify(message.response), + }, + ], + }; + } + }); +}; diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/index.ts b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/index.ts new file mode 100644 index 0000000000000..01d849e1ea9af --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/index.ts @@ -0,0 +1,8 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export { bedrockClaudeAdapter } from './bedrock_claude_adapter'; diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/process_completion_chunks.test.ts b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/process_completion_chunks.test.ts new file mode 100644 index 0000000000000..6307aecaeefc4 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/process_completion_chunks.test.ts @@ -0,0 +1,336 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { lastValueFrom, of, toArray } from 'rxjs'; +import { processCompletionChunks } from './process_completion_chunks'; +import type { CompletionChunk } from './types'; + +describe('processCompletionChunks', () => { + it('does not emit for a message_start event', async () => { + const chunks: CompletionChunk[] = [ + { + type: 'message_start', + message: 'foo', + }, + ]; + + expect(await lastValueFrom(of(...chunks).pipe(processCompletionChunks(), toArray()))).toEqual( + [] + ); + }); + + it('emits the correct value for a content_block_start event with text content ', async () => { + const chunks: CompletionChunk[] = [ + { + type: 'content_block_start', + index: 0, + content_block: { type: 'text', text: 'foo' }, + }, + ]; + + expect(await lastValueFrom(of(...chunks).pipe(processCompletionChunks(), toArray()))).toEqual([ + { + type: 'chatCompletionChunk', + content: 'foo', + tool_calls: [], + }, + ]); + }); + + it('emits the correct value for a content_block_start event with tool_use content ', async () => { + const chunks: CompletionChunk[] = [ + { + type: 'content_block_start', + index: 0, + content_block: { type: 'tool_use', id: 'id', name: 'name', input: '{}' }, + }, + ]; + + expect(await lastValueFrom(of(...chunks).pipe(processCompletionChunks(), toArray()))).toEqual([ + { + type: 'chatCompletionChunk', + content: '', + tool_calls: [ + { + toolCallId: 'id', + index: 0, + function: { + arguments: '', + name: 'name', + }, + }, + ], + }, + ]); + }); + + it('emits the correct value for a content_block_delta event with text content ', async () => { + const chunks: CompletionChunk[] = [ + { + type: 'content_block_delta', + index: 0, + delta: { type: 'text_delta', text: 'delta' }, + }, + ]; + + expect(await lastValueFrom(of(...chunks).pipe(processCompletionChunks(), toArray()))).toEqual([ + { + type: 'chatCompletionChunk', + content: 'delta', + tool_calls: [], + }, + ]); + }); + + it('emits the correct value for a content_block_delta event with tool_use content ', async () => { + const chunks: CompletionChunk[] = [ + { + type: 'content_block_delta', + index: 0, + delta: { type: 'input_json_delta', partial_json: '{ "param' }, + }, + ]; + + expect(await lastValueFrom(of(...chunks).pipe(processCompletionChunks(), toArray()))).toEqual([ + { + type: 'chatCompletionChunk', + content: '', + tool_calls: [ + { + index: 0, + toolCallId: '', + function: { + arguments: '{ "param', + name: '', + }, + }, + ], + }, + ]); + }); + + it('does not emit for a content_block_stop event', async () => { + const chunks: CompletionChunk[] = [ + { + type: 'content_block_stop', + index: 0, + }, + ]; + + expect(await lastValueFrom(of(...chunks).pipe(processCompletionChunks(), toArray()))).toEqual( + [] + ); + }); + + it('emits the correct value for a message_delta event with tool_use content ', async () => { + const chunks: CompletionChunk[] = [ + { + type: 'message_delta', + delta: { stop_reason: 'end_turn', stop_sequence: 'stop_seq', usage: { output_tokens: 42 } }, + }, + ]; + + expect(await lastValueFrom(of(...chunks).pipe(processCompletionChunks(), toArray()))).toEqual([ + { + type: 'chatCompletionChunk', + content: 'stop_seq', + tool_calls: [], + }, + ]); + }); + + it('emits a token count for a message_stop event ', async () => { + const chunks: CompletionChunk[] = [ + { + type: 'message_stop', + 'amazon-bedrock-invocationMetrics': { + inputTokenCount: 1, + outputTokenCount: 2, + invocationLatency: 3, + firstByteLatency: 4, + }, + }, + ]; + + expect(await lastValueFrom(of(...chunks).pipe(processCompletionChunks(), toArray()))).toEqual([ + { + type: 'chatCompletionTokenCount', + tokens: { + completion: 2, + prompt: 1, + total: 3, + }, + }, + ]); + }); + + it('emits the correct values for a text response scenario', async () => { + const chunks: CompletionChunk[] = [ + { + type: 'message_start', + message: 'foo', + }, + { + type: 'content_block_start', + index: 0, + content_block: { type: 'text', text: 'foo' }, + }, + { + type: 'content_block_delta', + index: 0, + delta: { type: 'text_delta', text: 'delta1' }, + }, + { + type: 'content_block_delta', + index: 0, + delta: { type: 'text_delta', text: 'delta2' }, + }, + { + type: 'content_block_stop', + index: 0, + }, + { + type: 'message_delta', + delta: { stop_reason: 'end_turn', stop_sequence: 'stop_seq', usage: { output_tokens: 42 } }, + }, + { + type: 'message_stop', + 'amazon-bedrock-invocationMetrics': { + inputTokenCount: 1, + outputTokenCount: 2, + invocationLatency: 3, + firstByteLatency: 4, + }, + }, + ]; + + expect(await lastValueFrom(of(...chunks).pipe(processCompletionChunks(), toArray()))).toEqual([ + { + content: 'foo', + tool_calls: [], + type: 'chatCompletionChunk', + }, + { + content: 'delta1', + tool_calls: [], + type: 'chatCompletionChunk', + }, + { + content: 'delta2', + tool_calls: [], + type: 'chatCompletionChunk', + }, + { + content: 'stop_seq', + tool_calls: [], + type: 'chatCompletionChunk', + }, + { + tokens: { + completion: 2, + prompt: 1, + total: 3, + }, + type: 'chatCompletionTokenCount', + }, + ]); + }); + + it('emits the correct values for a tool_use response scenario', async () => { + const chunks: CompletionChunk[] = [ + { + type: 'message_start', + message: 'foo', + }, + { + type: 'content_block_start', + index: 0, + content_block: { type: 'tool_use', id: 'id', name: 'name', input: '{}' }, + }, + { + type: 'content_block_delta', + index: 0, + delta: { type: 'input_json_delta', partial_json: '{ "param' }, + }, + { + type: 'content_block_delta', + index: 0, + delta: { type: 'input_json_delta', partial_json: '": 12 }' }, + }, + { + type: 'content_block_stop', + index: 0, + }, + { + type: 'message_delta', + delta: { stop_reason: 'tool_use', stop_sequence: null, usage: { output_tokens: 42 } }, + }, + { + type: 'message_stop', + 'amazon-bedrock-invocationMetrics': { + inputTokenCount: 1, + outputTokenCount: 2, + invocationLatency: 3, + firstByteLatency: 4, + }, + }, + ]; + + expect(await lastValueFrom(of(...chunks).pipe(processCompletionChunks(), toArray()))).toEqual([ + { + content: '', + tool_calls: [ + { + function: { + arguments: '', + name: 'name', + }, + index: 0, + toolCallId: 'id', + }, + ], + type: 'chatCompletionChunk', + }, + { + content: '', + tool_calls: [ + { + function: { + arguments: '{ "param', + name: '', + }, + index: 0, + toolCallId: '', + }, + ], + type: 'chatCompletionChunk', + }, + { + content: '', + tool_calls: [ + { + function: { + arguments: '": 12 }', + name: '', + }, + index: 0, + toolCallId: '', + }, + ], + type: 'chatCompletionChunk', + }, + { + tokens: { + completion: 2, + prompt: 1, + total: 3, + }, + type: 'chatCompletionTokenCount', + }, + ]); + }); +}); diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/process_completion_chunks.ts b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/process_completion_chunks.ts new file mode 100644 index 0000000000000..5513cc9028ac9 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/process_completion_chunks.ts @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { Observable, Subscriber } from 'rxjs'; +import { + ChatCompletionChunkEvent, + ChatCompletionTokenCountEvent, + ChatCompletionChunkToolCall, + ChatCompletionEventType, +} from '../../../../common/chat_complete'; +import type { CompletionChunk, MessageStopChunk } from './types'; + +export function processCompletionChunks() { + return (source: Observable) => + new Observable((subscriber) => { + function handleNext(chunkBody: CompletionChunk) { + if (isTokenCountCompletionChunk(chunkBody)) { + return emitTokenCountEvent(subscriber, chunkBody); + } + + let completionChunk = ''; + let toolCallChunk: ChatCompletionChunkToolCall | undefined; + + switch (chunkBody.type) { + case 'content_block_start': + if (chunkBody.content_block.type === 'text') { + completionChunk = chunkBody.content_block.text || ''; + } else if (chunkBody.content_block.type === 'tool_use') { + toolCallChunk = { + index: chunkBody.index, + toolCallId: chunkBody.content_block.id, + function: { + name: chunkBody.content_block.name, + // the API returns '{}' here, which can't be merged with the deltas... + arguments: '', + }, + }; + } + break; + + case 'content_block_delta': + if (chunkBody.delta.type === 'text_delta') { + completionChunk = chunkBody.delta.text || ''; + } else if (chunkBody.delta.type === 'input_json_delta') { + toolCallChunk = { + index: chunkBody.index, + toolCallId: '', + function: { + name: '', + arguments: chunkBody.delta.partial_json, + }, + }; + } + break; + + case 'message_delta': + completionChunk = chunkBody.delta.stop_sequence || ''; + break; + + default: + break; + } + + if (completionChunk || toolCallChunk) { + subscriber.next({ + type: ChatCompletionEventType.ChatCompletionChunk, + content: completionChunk, + tool_calls: toolCallChunk ? [toolCallChunk] : [], + }); + } + } + + source.subscribe({ + next: (value) => { + try { + handleNext(value); + } catch (error) { + subscriber.error(error); + } + }, + error: (err) => { + subscriber.error(err); + }, + complete: () => { + subscriber.complete(); + }, + }); + }); +} + +function isTokenCountCompletionChunk(value: CompletionChunk): value is MessageStopChunk { + return value.type === 'message_stop' && 'amazon-bedrock-invocationMetrics' in value; +} + +function emitTokenCountEvent( + subscriber: Subscriber, + chunk: MessageStopChunk +) { + const { inputTokenCount, outputTokenCount } = chunk['amazon-bedrock-invocationMetrics']; + + subscriber.next({ + type: ChatCompletionEventType.ChatCompletionTokenCount, + tokens: { + completion: outputTokenCount, + prompt: inputTokenCount, + total: inputTokenCount + outputTokenCount, + }, + }); +} diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/prompts.ts b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/prompts.ts new file mode 100644 index 0000000000000..ed8387bf75252 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/prompts.ts @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +const noToolUsageDirective = ` +Please answer with text. You should NOT call or use a tool, even if tools might be available and even if +the user explicitly asks for it. DO NOT UNDER ANY CIRCUMSTANCES call a tool. Instead, ALWAYS reply with text. +`; + +export const addNoToolUsageDirective = (systemMessage: string | undefined): string => { + return systemMessage ? `${systemMessage}\n\n${noToolUsageDirective}` : noToolUsageDirective; +}; diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/serde_eventstream_into_observable.test.ts b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/serde_eventstream_into_observable.test.ts new file mode 100644 index 0000000000000..bed6458a94dc7 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/serde_eventstream_into_observable.test.ts @@ -0,0 +1,87 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { Readable } from 'stream'; +import { Observable, toArray, firstValueFrom, map, filter } from 'rxjs'; +import { + BedrockChunkMember, + BedrockStreamMember, + serdeEventstreamIntoObservable, +} from './serde_eventstream_into_observable'; +import { EventStreamMarshaller } from '@smithy/eventstream-serde-node'; +import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; +import type { CompletionChunk } from './types'; +import { parseSerdeChunkMessage, serializeSerdeChunkMessage } from './serde_utils'; + +describe('serdeEventstreamIntoObservable', () => { + const marshaller = new EventStreamMarshaller({ + utf8Encoder: toUtf8, + utf8Decoder: fromUtf8, + }); + + const getSerdeEventStream = (chunks: CompletionChunk[]) => { + const input = Readable.from(chunks); + return marshaller.serialize(input, serializeSerdeChunkMessage); + }; + + const getChunks = async (serde$: Observable) => { + return await firstValueFrom( + serde$.pipe( + filter((value): value is BedrockChunkMember => { + return 'chunk' in value && value.chunk?.headers?.[':event-type']?.value === 'chunk'; + }), + map((message) => { + return parseSerdeChunkMessage(message.chunk); + }), + toArray() + ) + ); + }; + + it('converts a single chunk', async () => { + const chunks: CompletionChunk[] = [ + { + type: 'content_block_delta', + index: 0, + delta: { type: 'text_delta', text: 'Hello' }, + }, + ]; + + const inputStream = getSerdeEventStream(chunks); + const serde$ = serdeEventstreamIntoObservable(inputStream); + + const result = await getChunks(serde$); + + expect(result).toEqual(chunks); + }); + + it('converts multiple chunks', async () => { + const chunks: CompletionChunk[] = [ + { + type: 'content_block_start', + index: 0, + content_block: { type: 'text', text: 'start' }, + }, + { + type: 'content_block_delta', + index: 0, + delta: { type: 'text_delta', text: 'Hello' }, + }, + { + type: 'content_block_stop', + index: 0, + }, + ]; + + const inputStream = getSerdeEventStream(chunks); + const serde$ = serdeEventstreamIntoObservable(inputStream); + + const result = await getChunks(serde$); + + expect(result).toEqual(chunks); + }); +}); diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/serde_eventstream_into_observable.ts b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/serde_eventstream_into_observable.ts new file mode 100644 index 0000000000000..24a245ab2efcc --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/serde_eventstream_into_observable.ts @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { EventStreamMarshaller } from '@smithy/eventstream-serde-node'; +import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; +import { identity } from 'lodash'; +import { Observable } from 'rxjs'; +import { Readable } from 'stream'; +import { Message } from '@smithy/types'; +import { createInferenceInternalError } from '../../../../common/errors'; + +interface ModelStreamErrorException { + name: 'ModelStreamErrorException'; + originalStatusCode?: number; + originalMessage?: string; +} + +export interface BedrockChunkMember { + chunk: Message; +} + +export interface ModelStreamErrorExceptionMember { + modelStreamErrorException: ModelStreamErrorException; +} + +export type BedrockStreamMember = BedrockChunkMember | ModelStreamErrorExceptionMember; + +// AWS uses SerDe to send over serialized data, so we use their +// @smithy library to parse the stream data + +export function serdeEventstreamIntoObservable( + readable: Readable +): Observable { + return new Observable((subscriber) => { + const marshaller = new EventStreamMarshaller({ + utf8Encoder: toUtf8, + utf8Decoder: fromUtf8, + }); + + async function processStream() { + for await (const chunk of marshaller.deserialize(readable, identity)) { + if (chunk) { + subscriber.next(chunk); + } + } + } + + processStream().then( + () => { + subscriber.complete(); + }, + (error) => { + if (!(error instanceof Error)) { + try { + const exceptionType = error.headers[':exception-type'].value; + const body = toUtf8(error.body); + let message = `Encountered error in Bedrock stream of type ${exceptionType}`; + try { + message += '\n' + JSON.parse(body).message; + } catch (parseError) { + // trap + } + error = createInferenceInternalError(message); + } catch (decodeError) { + error = createInferenceInternalError(decodeError.message); + } + } + subscriber.error(error); + } + ); + }); +} diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/serde_utils.test.ts b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/serde_utils.test.ts new file mode 100644 index 0000000000000..c763fd8c9daf3 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/serde_utils.test.ts @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { CompletionChunk } from './types'; +import { serializeSerdeChunkMessage, parseSerdeChunkMessage } from './serde_utils'; + +describe('parseSerdeChunkMessage', () => { + it('parses a serde chunk message', () => { + const chunk: CompletionChunk = { + type: 'content_block_stop', + index: 0, + }; + + expect(parseSerdeChunkMessage(serializeSerdeChunkMessage(chunk))).toEqual(chunk); + }); +}); diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/serde_utils.ts b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/serde_utils.ts new file mode 100644 index 0000000000000..d7050b7744940 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/serde_utils.ts @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { toUtf8, fromUtf8 } from '@smithy/util-utf8'; +import type { Message } from '@smithy/types'; +import type { CompletionChunk } from './types'; + +/** + * Extract the completion chunk from a chunk message + */ +export function parseSerdeChunkMessage(chunk: Message): CompletionChunk { + return JSON.parse(Buffer.from(JSON.parse(toUtf8(chunk.body)).bytes, 'base64').toString('utf-8')); +} + +/** + * Reverse `parseSerdeChunkMessage` + */ +export const serializeSerdeChunkMessage = (input: CompletionChunk): Message => { + const b64 = Buffer.from(JSON.stringify(input), 'utf-8').toString('base64'); + const body = fromUtf8(JSON.stringify({ bytes: b64 })); + return { + headers: { + ':event-type': { type: 'string', value: 'chunk' }, + ':content-type': { type: 'string', value: 'application/json' }, + ':message-type': { type: 'string', value: 'event' }, + }, + body, + }; +}; diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/types.ts b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/types.ts new file mode 100644 index 0000000000000..f0937a8d8ec18 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/bedrock/types.ts @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +/** + * BedRock message as expected by the bedrock connector + */ +export interface BedRockMessage { + role: 'user' | 'assistant'; + content?: string; + rawContent?: BedRockMessagePart[]; +} + +/** + * Bedrock message parts + */ +export type BedRockMessagePart = + | { type: 'text'; text: string } + | { + type: 'tool_use'; + id: string; + name: string; + input: Record; + } + | { type: 'tool_result'; tool_use_id: string; content: string }; + +export type BedrockToolChoice = { type: 'auto' } | { type: 'any' } | { type: 'tool'; name: string }; + +interface CompletionChunkBase { + type: string; +} + +export interface MessageStartChunk extends CompletionChunkBase { + type: 'message_start'; + message: unknown; +} + +export interface ContentBlockStartChunk extends CompletionChunkBase { + type: 'content_block_start'; + index: number; + content_block: + | { + type: 'text'; + text: string; + } + | { type: 'tool_use'; id: string; name: string; input: string }; +} + +export interface ContentBlockDeltaChunk extends CompletionChunkBase { + type: 'content_block_delta'; + index: number; + delta: + | { + type: 'text_delta'; + text: string; + } + | { + type: 'input_json_delta'; + partial_json: string; + }; +} + +export interface ContentBlockStopChunk extends CompletionChunkBase { + type: 'content_block_stop'; + index: number; +} + +export interface MessageDeltaChunk extends CompletionChunkBase { + type: 'message_delta'; + delta: { + stop_reason: string; + stop_sequence: null | string; + usage: { + output_tokens: number; + }; + }; +} + +export interface MessageStopChunk extends CompletionChunkBase { + type: 'message_stop'; + 'amazon-bedrock-invocationMetrics': { + inputTokenCount: number; + outputTokenCount: number; + invocationLatency: number; + firstByteLatency: number; + }; +} + +export type CompletionChunk = + | MessageStartChunk + | ContentBlockStartChunk + | ContentBlockDeltaChunk + | ContentBlockStopChunk + | MessageDeltaChunk + | MessageStopChunk; diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/get_inference_adapter.test.ts b/x-pack/plugins/inference/server/chat_complete/adapters/get_inference_adapter.test.ts index 9e0b0da6d5894..558e0cd06ef91 100644 --- a/x-pack/plugins/inference/server/chat_complete/adapters/get_inference_adapter.test.ts +++ b/x-pack/plugins/inference/server/chat_complete/adapters/get_inference_adapter.test.ts @@ -9,6 +9,7 @@ import { InferenceConnectorType } from '../../../common/connectors'; import { getInferenceAdapter } from './get_inference_adapter'; import { openAIAdapter } from './openai'; import { geminiAdapter } from './gemini'; +import { bedrockClaudeAdapter } from './bedrock'; describe('getInferenceAdapter', () => { it('returns the openAI adapter for OpenAI type', () => { @@ -19,7 +20,7 @@ describe('getInferenceAdapter', () => { expect(getInferenceAdapter(InferenceConnectorType.Gemini)).toBe(geminiAdapter); }); - it('returns undefined for Bedrock type', () => { - expect(getInferenceAdapter(InferenceConnectorType.Bedrock)).toBe(undefined); + it('returns the bedrock adapter for Bedrock type', () => { + expect(getInferenceAdapter(InferenceConnectorType.Bedrock)).toBe(bedrockClaudeAdapter); }); }); diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/get_inference_adapter.ts b/x-pack/plugins/inference/server/chat_complete/adapters/get_inference_adapter.ts index 0538d828a473a..f34b0c27a339f 100644 --- a/x-pack/plugins/inference/server/chat_complete/adapters/get_inference_adapter.ts +++ b/x-pack/plugins/inference/server/chat_complete/adapters/get_inference_adapter.ts @@ -9,6 +9,7 @@ import { InferenceConnectorType } from '../../../common/connectors'; import type { InferenceConnectorAdapter } from '../types'; import { openAIAdapter } from './openai'; import { geminiAdapter } from './gemini'; +import { bedrockClaudeAdapter } from './bedrock'; export const getInferenceAdapter = ( connectorType: InferenceConnectorType @@ -21,8 +22,7 @@ export const getInferenceAdapter = ( return geminiAdapter; case InferenceConnectorType.Bedrock: - // not implemented yet - break; + return bedrockClaudeAdapter; } return undefined; diff --git a/x-pack/plugins/inference/tsconfig.json b/x-pack/plugins/inference/tsconfig.json index 16d7ca041582c..593556c8f39c8 100644 --- a/x-pack/plugins/inference/tsconfig.json +++ b/x-pack/plugins/inference/tsconfig.json @@ -22,6 +22,7 @@ "@kbn/logging", "@kbn/core-http-server", "@kbn/actions-plugin", - "@kbn/config-schema" + "@kbn/config-schema", + "@kbn/stack-connectors-plugin" ] } diff --git a/x-pack/plugins/stack_connectors/common/bedrock/schema.ts b/x-pack/plugins/stack_connectors/common/bedrock/schema.ts index 093a5f9b11518..03f4f5cc01735 100644 --- a/x-pack/plugins/stack_connectors/common/bedrock/schema.ts +++ b/x-pack/plugins/stack_connectors/common/bedrock/schema.ts @@ -28,13 +28,30 @@ export const RunActionParamsSchema = schema.object({ raw: schema.maybe(schema.boolean()), }); +export const BedrockMessageSchema = schema.object( + { + role: schema.string(), + content: schema.maybe(schema.string()), + rawContent: schema.maybe(schema.arrayOf(schema.any())), + }, + { + validate: (value) => { + if (value.content === undefined && value.rawContent === undefined) { + return 'Must specify either content or rawContent'; + } else if (value.content !== undefined && value.rawContent !== undefined) { + return 'content and rawContent can not be used at the same time'; + } + }, + } +); + +export const BedrockToolChoiceSchema = schema.object({ + type: schema.oneOf([schema.literal('auto'), schema.literal('any'), schema.literal('tool')]), + name: schema.maybe(schema.string()), +}); + export const InvokeAIActionParamsSchema = schema.object({ - messages: schema.arrayOf( - schema.object({ - role: schema.string(), - content: schema.string(), - }) - ), + messages: schema.arrayOf(BedrockMessageSchema), model: schema.maybe(schema.string()), temperature: schema.maybe(schema.number()), stopSequences: schema.maybe(schema.arrayOf(schema.string())), @@ -53,6 +70,7 @@ export const InvokeAIActionParamsSchema = schema.object({ }) ) ), + toolChoice: schema.maybe(BedrockToolChoiceSchema), }); export const InvokeAIActionResponseSchema = schema.object({ @@ -60,12 +78,7 @@ export const InvokeAIActionResponseSchema = schema.object({ }); export const InvokeAIRawActionParamsSchema = schema.object({ - messages: schema.arrayOf( - schema.object({ - role: schema.string(), - content: schema.any(), - }) - ), + messages: schema.arrayOf(BedrockMessageSchema), model: schema.maybe(schema.string()), temperature: schema.maybe(schema.number()), stopSequences: schema.maybe(schema.arrayOf(schema.string())), @@ -84,6 +97,7 @@ export const InvokeAIRawActionParamsSchema = schema.object({ }) ) ), + toolChoice: schema.maybe(BedrockToolChoiceSchema), }); export const InvokeAIRawActionResponseSchema = schema.object({}, { unknowns: 'allow' }); diff --git a/x-pack/plugins/stack_connectors/common/bedrock/types.ts b/x-pack/plugins/stack_connectors/common/bedrock/types.ts index b144f78b91edd..3b02f40d2de62 100644 --- a/x-pack/plugins/stack_connectors/common/bedrock/types.ts +++ b/x-pack/plugins/stack_connectors/common/bedrock/types.ts @@ -19,6 +19,8 @@ import { InvokeAIRawActionResponseSchema, StreamingResponseSchema, RunApiLatestResponseSchema, + BedrockMessageSchema, + BedrockToolChoiceSchema, } from './schema'; export type Config = TypeOf; @@ -33,3 +35,5 @@ export type RunActionResponse = TypeOf; export type StreamingResponse = TypeOf; export type DashboardActionParams = TypeOf; export type DashboardActionResponse = TypeOf; +export type BedRockMessage = TypeOf; +export type BedrockToolChoice = TypeOf; diff --git a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts index b5ec114a9c456..c2c773bdeaf87 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/bedrock/bedrock.ts @@ -22,7 +22,7 @@ import { RunActionResponseSchema, RunApiLatestResponseSchema, } from '../../../common/bedrock/schema'; -import { +import type { Config, Secrets, RunActionParams, @@ -32,6 +32,8 @@ import { InvokeAIRawActionParams, InvokeAIRawActionResponse, RunApiLatestResponse, + BedRockMessage, + BedrockToolChoice, } from '../../../common/bedrock/types'; import { SUB_ACTION, @@ -309,13 +311,14 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B signal, timeout, tools, + toolChoice, }: InvokeAIActionParams | InvokeAIRawActionParams, connectorUsageCollector: ConnectorUsageCollector ): Promise { const res = (await this.streamApi( { body: JSON.stringify( - formatBedrockBody({ messages, stopSequences, system, temperature, tools }) + formatBedrockBody({ messages, stopSequences, system, temperature, tools, toolChoice }) ), model, signal, @@ -344,13 +347,23 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B maxTokens, signal, timeout, + tools, + toolChoice, }: InvokeAIActionParams, connectorUsageCollector: ConnectorUsageCollector ): Promise { const res = (await this.runApi( { body: JSON.stringify( - formatBedrockBody({ messages, stopSequences, system, temperature, maxTokens }) + formatBedrockBody({ + messages, + stopSequences, + system, + temperature, + maxTokens, + tools, + toolChoice, + }) ), model, signal, @@ -372,6 +385,7 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B signal, timeout, tools, + toolChoice, anthropicVersion, }: InvokeAIRawActionParams, connectorUsageCollector: ConnectorUsageCollector @@ -385,6 +399,7 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B temperature, max_tokens: maxTokens, tools, + tool_choice: toolChoice, anthropic_version: anthropicVersion, }), model, @@ -405,14 +420,16 @@ const formatBedrockBody = ({ system, maxTokens = DEFAULT_TOKEN_LIMIT, tools, + toolChoice, }: { - messages: Array<{ role: string; content?: string }>; + messages: BedRockMessage[]; stopSequences?: string[]; temperature?: number; maxTokens?: number; // optional system message to be sent to the API system?: string; tools?: Array<{ name: string; description: string }>; + toolChoice?: BedrockToolChoice; }) => ({ anthropic_version: 'bedrock-2023-05-31', ...ensureMessageFormat(messages, system), @@ -420,8 +437,14 @@ const formatBedrockBody = ({ stop_sequences: stopSequences, temperature, tools, + tool_choice: toolChoice, }); +interface FormattedBedRockMessage { + role: string; + content: string | BedRockMessage['rawContent']; +} + /** * Ensures that the messages are in the correct format for the Bedrock API * If 2 user or 2 assistant messages are sent in a row, Bedrock throws an error @@ -429,19 +452,32 @@ const formatBedrockBody = ({ * @param messages */ const ensureMessageFormat = ( - messages: Array<{ role: string; content?: string }>, + messages: BedRockMessage[], systemPrompt?: string -): { messages: Array<{ role: string; content?: string }>; system?: string } => { +): { + messages: FormattedBedRockMessage[]; + system?: string; +} => { let system = systemPrompt ? systemPrompt : ''; - const newMessages = messages.reduce((acc: Array<{ role: string; content?: string }>, m) => { - const lastMessage = acc[acc.length - 1]; + const newMessages = messages.reduce((acc, m) => { if (m.role === 'system') { system = `${system.length ? `${system}\n` : ''}${m.content}`; return acc; } - if (lastMessage && lastMessage.role === m.role) { + const messageRole = () => (['assistant', 'ai'].includes(m.role) ? 'assistant' : 'user'); + + if (m.rawContent) { + acc.push({ + role: messageRole(), + content: m.rawContent, + }); + return acc; + } + + const lastMessage = acc[acc.length - 1]; + if (lastMessage && lastMessage.role === m.role && typeof lastMessage.content === 'string') { // Bedrock only accepts assistant and user roles. // If 2 user or 2 assistant messages are sent in a row, combine the messages into a single message return [ @@ -451,11 +487,9 @@ const ensureMessageFormat = ( } // force role outside of system to ensure it is either assistant or user - return [ - ...acc, - { content: m.content, role: ['assistant', 'ai'].includes(m.role) ? 'assistant' : 'user' }, - ]; + return [...acc, { content: m.content, role: messageRole() }]; }, []); + return system.length ? { system, messages: newMessages } : { messages: newMessages }; };