Skip to content

Commit

Permalink
[Security solution] ChatBedrockConverse (#200042)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephmilovic authored Nov 19, 2024
1 parent 8e7799a commit 755ef31
Show file tree
Hide file tree
Showing 19 changed files with 2,482 additions and 156 deletions.
8 changes: 6 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
"@appland/sql-parser": "^1.5.1",
"@aws-crypto/sha256-js": "^5.2.0",
"@aws-crypto/util": "^5.2.0",
"@aws-sdk/client-bedrock-runtime": "^3.687.0",
"@babel/runtime": "^7.24.7",
"@dagrejs/dagre": "^1.1.4",
"@dnd-kit/core": "^6.1.0",
Expand Down Expand Up @@ -1019,7 +1020,8 @@
"@kbn/xstate-utils": "link:packages/kbn-xstate-utils",
"@kbn/zod": "link:packages/kbn-zod",
"@kbn/zod-helpers": "link:packages/kbn-zod-helpers",
"@langchain/community": "0.3.11",
"@langchain/aws": "^0.1.2",
"@langchain/community": "0.3.14",
"@langchain/core": "^0.3.16",
"@langchain/google-common": "^0.1.1",
"@langchain/google-genai": "^0.1.2",
Expand Down Expand Up @@ -1054,7 +1056,9 @@
"@slack/webhook": "^7.0.1",
"@smithy/eventstream-codec": "^3.1.1",
"@smithy/eventstream-serde-node": "^3.0.3",
"@smithy/protocol-http": "^4.0.2",
"@smithy/middleware-stack": "^3.0.10",
"@smithy/node-http-handler": "^3.3.1",
"@smithy/protocol-http": "^4.1.7",
"@smithy/signature-v4": "^3.1.1",
"@smithy/types": "^3.2.0",
"@smithy/util-utf8": "^3.0.0",
Expand Down
2 changes: 2 additions & 0 deletions x-pack/packages/kbn-langchain/server/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { ActionsClientLlm } from './language_models/llm';
import { ActionsClientSimpleChatModel } from './language_models/simple_chat_model';
import { ActionsClientGeminiChatModel } from './language_models/gemini_chat';
import { ActionsClientChatVertexAI } from './language_models/chat_vertex';
import { ActionsClientChatBedrockConverse } from './language_models/chat_bedrock_converse';
import { parseBedrockStream } from './utils/bedrock';
import { parseGeminiResponse } from './utils/gemini';
import { getDefaultArguments } from './language_models/constants';
Expand All @@ -25,4 +26,5 @@ export {
ActionsClientGeminiChatModel,
ActionsClientLlm,
ActionsClientSimpleChatModel,
ActionsClientChatBedrockConverse,
};
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@ import { BedrockChat as _BedrockChat } from '@langchain/community/chat_models/be
import type { ActionsClient } from '@kbn/actions-plugin/server';
import { BaseChatModelParams } from '@langchain/core/language_models/chat_models';
import { Logger } from '@kbn/logging';
import { Readable } from 'stream';
import { PublicMethodsOf } from '@kbn/utility-types';

export const DEFAULT_BEDROCK_MODEL = 'anthropic.claude-3-5-sonnet-20240620-v1:0';
export const DEFAULT_BEDROCK_REGION = 'us-east-1';
import { prepareMessages, DEFAULT_BEDROCK_MODEL, DEFAULT_BEDROCK_REGION } from '../utils/bedrock';

export interface CustomChatModelInput extends BaseChatModelParams {
actionsClient: PublicMethodsOf<ActionsClient>;
Expand All @@ -25,6 +22,11 @@ export interface CustomChatModelInput extends BaseChatModelParams {
maxTokens?: number;
}

/**
* @deprecated Use the ActionsClientChatBedrockConverse chat model instead.
* ActionsClientBedrockChatModel chat model supports non-streaming only the Bedrock Invoke API.
* The LangChain team will support only the Bedrock Converse API in the future.
*/
export class ActionsClientBedrockChatModel extends _BedrockChat {
constructor({ actionsClient, connectorId, logger, ...params }: CustomChatModelInput) {
super({
Expand All @@ -36,32 +38,10 @@ export class ActionsClientBedrockChatModel extends _BedrockChat {
fetchFn: async (url, options) => {
const inputBody = JSON.parse(options?.body as string);

if (this.streaming && !inputBody.tools?.length) {
const data = (await actionsClient.execute({
actionId: connectorId,
params: {
subAction: 'invokeStream',
subActionParams: {
messages: inputBody.messages,
temperature: params.temperature ?? inputBody.temperature,
stopSequences: inputBody.stop_sequences,
system: inputBody.system,
maxTokens: params.maxTokens ?? inputBody.max_tokens,
tools: inputBody.tools,
anthropicVersion: inputBody.anthropic_version,
},
},
})) as { data: Readable; status: string; message?: string; serviceMessage?: string };

if (data.status === 'error') {
throw new Error(
`ActionsClientBedrockChat: action result status is error: ${data?.message} - ${data?.serviceMessage}`
);
}

return {
body: Readable.toWeb(data.data),
} as unknown as Response;
if (this.streaming) {
throw new Error(
`ActionsClientBedrockChat does not support streaming, use ActionsClientChatBedrockConverse instead`
);
}

const data = (await actionsClient.execute({
Expand All @@ -84,7 +64,6 @@ export class ActionsClientBedrockChatModel extends _BedrockChat {
message?: string;
serviceMessage?: string;
};

if (data.status === 'error') {
throw new Error(
`ActionsClientBedrockChat: action result status is error: ${data?.message} - ${data?.serviceMessage}`
Expand All @@ -99,20 +78,3 @@ export class ActionsClientBedrockChatModel extends _BedrockChat {
});
}
}

const prepareMessages = (messages: Array<{ role: string; content: string[] }>) =>
messages.reduce((acc, { role, content }) => {
const lastMessage = acc[acc.length - 1];

if (!lastMessage || lastMessage.role !== role) {
acc.push({ role, content });
return acc;
}

if (lastMessage.role === role) {
acc[acc.length - 1].content = lastMessage.content.concat(content);
return acc;
}

return acc;
}, [] as Array<{ role: string; content: string[] }>);
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import {
BedrockRuntimeClient as _BedrockRuntimeClient,
BedrockRuntimeClientConfig,
} from '@aws-sdk/client-bedrock-runtime';
import { constructStack } from '@smithy/middleware-stack';
import { PublicMethodsOf } from '@kbn/utility-types';
import type { ActionsClient } from '@kbn/actions-plugin/server';

import { NodeHttpHandler } from './node_http_handler';

export interface CustomChatModelInput extends BedrockRuntimeClientConfig {
actionsClient: PublicMethodsOf<ActionsClient>;
connectorId: string;
streaming?: boolean;
}

export class BedrockRuntimeClient extends _BedrockRuntimeClient {
middlewareStack: _BedrockRuntimeClient['middlewareStack'];

constructor({ actionsClient, connectorId, ...fields }: CustomChatModelInput) {
super(fields ?? {});
this.config.requestHandler = new NodeHttpHandler({
streaming: fields.streaming ?? true,
actionsClient,
connectorId,
});
// eliminate middleware steps that handle auth as Kibana connector handles auth
this.middlewareStack = constructStack() as _BedrockRuntimeClient['middlewareStack'];
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { ChatBedrockConverse } from '@langchain/aws';
import type { ActionsClient } from '@kbn/actions-plugin/server';
import { BaseChatModelParams } from '@langchain/core/language_models/chat_models';
import { Logger } from '@kbn/logging';
import { PublicMethodsOf } from '@kbn/utility-types';
import { BedrockRuntimeClient } from './bedrock_runtime_client';
import { DEFAULT_BEDROCK_MODEL, DEFAULT_BEDROCK_REGION } from '../../utils/bedrock';

export interface CustomChatModelInput extends BaseChatModelParams {
actionsClient: PublicMethodsOf<ActionsClient>;
connectorId: string;
logger: Logger;
signal?: AbortSignal;
model?: string;
}

/**
* Custom chat model class for Bedrock Converse API.
* The ActionsClientChatBedrockConverse chat model supports streaming and
* non-streaming via the Bedrock Converse and ConverseStream APIs.
*
* @param {Object} params - The parameters for the chat model.
* @param {ActionsClient} params.actionsClient - The actions client.
* @param {string} params.connectorId - The connector ID.
* @param {Logger} params.logger - The logger instance.
* @param {AbortSignal} [params.signal] - Optional abort signal.
* @param {string} [params.model] - Optional model name.
*/
export class ActionsClientChatBedrockConverse extends ChatBedrockConverse {
constructor({ actionsClient, connectorId, logger, ...fields }: CustomChatModelInput) {
super({
...(fields ?? {}),
credentials: { accessKeyId: '', secretAccessKey: '' },
model: fields?.model ?? DEFAULT_BEDROCK_MODEL,
region: DEFAULT_BEDROCK_REGION,
});
this.client = new BedrockRuntimeClient({
actionsClient,
connectorId,
streaming: this.streaming,
region: DEFAULT_BEDROCK_REGION,
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { ActionsClientChatBedrockConverse } from './chat_bedrock_converse';

export { ActionsClientChatBedrockConverse };
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { NodeHttpHandler } from './node_http_handler';
import { HttpRequest } from '@smithy/protocol-http';
import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';
import { Readable } from 'stream';
import { fromUtf8 } from '@smithy/util-utf8';

const mockActionsClient = actionsClientMock.create();
const connectorId = 'mock-connector-id';
const mockOutput = {
output: {
message: {
role: 'assistant',
content: [{ text: 'This is a response from the assistant.' }],
},
},
stopReason: 'end_turn',
usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 },
metrics: { latencyMs: 123 },
additionalModelResponseFields: {},
trace: { guardrail: { modelOutput: ['Output text'] } },
};
describe('NodeHttpHandler', () => {
let handler: NodeHttpHandler;

beforeEach(() => {
jest.clearAllMocks();
handler = new NodeHttpHandler({
streaming: false,
actionsClient: mockActionsClient,
connectorId,
});

mockActionsClient.execute.mockResolvedValue({
data: mockOutput,
actionId: 'mock-action-id',
status: 'ok',
});
});

it('handles non-streaming requests successfully', async () => {
const request = new HttpRequest({
body: JSON.stringify({ messages: [] }),
});

const result = await handler.handle(request);

expect(result.response.statusCode).toBe(200);
expect(result.response.headers['content-type']).toBe('application/json');
expect(result.response.body).toStrictEqual(fromUtf8(JSON.stringify(mockOutput)));
});

it('handles streaming requests successfully', async () => {
handler = new NodeHttpHandler({
streaming: true,
actionsClient: mockActionsClient,
connectorId,
});

const request = new HttpRequest({
body: JSON.stringify({ messages: [] }),
});

const readable = new Readable();
readable.push('streaming data');
readable.push(null);

mockActionsClient.execute.mockResolvedValue({
data: readable,
status: 'ok',
actionId: 'mock-action-id',
});

const result = await handler.handle(request);

expect(result.response.statusCode).toBe(200);
expect(result.response.body).toBe(readable);
});

it('throws an error for non-streaming requests with error status', async () => {
const request = new HttpRequest({
body: JSON.stringify({ messages: [] }),
});

mockActionsClient.execute.mockResolvedValue({
status: 'error',
message: 'error message',
serviceMessage: 'service error message',
actionId: 'mock-action-id',
});

await expect(handler.handle(request)).rejects.toThrow(
'ActionsClientBedrockChat: action result status is error: error message - service error message'
);
});

it('throws an error for streaming requests with error status', async () => {
handler = new NodeHttpHandler({
streaming: true,
actionsClient: mockActionsClient,
connectorId,
});

const request = new HttpRequest({
body: JSON.stringify({ messages: [] }),
});

mockActionsClient.execute.mockResolvedValue({
status: 'error',
message: 'error message',
serviceMessage: 'service error message',
actionId: 'mock-action-id',
});

await expect(handler.handle(request)).rejects.toThrow(
'ActionsClientBedrockChat: action result status is error: error message - service error message'
);
});
});
Loading

0 comments on commit 755ef31

Please sign in to comment.