-
Notifications
You must be signed in to change notification settings - Fork 8.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Security solution]
ChatBedrockConverse
(#200042)
- Loading branch information
1 parent
8e7799a
commit 755ef31
Showing
19 changed files
with
2,482 additions
and
156 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
37 changes: 37 additions & 0 deletions
37
...ages/kbn-langchain/server/language_models/chat_bedrock_converse/bedrock_runtime_client.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
import { | ||
BedrockRuntimeClient as _BedrockRuntimeClient, | ||
BedrockRuntimeClientConfig, | ||
} from '@aws-sdk/client-bedrock-runtime'; | ||
import { constructStack } from '@smithy/middleware-stack'; | ||
import { PublicMethodsOf } from '@kbn/utility-types'; | ||
import type { ActionsClient } from '@kbn/actions-plugin/server'; | ||
|
||
import { NodeHttpHandler } from './node_http_handler'; | ||
|
||
export interface CustomChatModelInput extends BedrockRuntimeClientConfig { | ||
actionsClient: PublicMethodsOf<ActionsClient>; | ||
connectorId: string; | ||
streaming?: boolean; | ||
} | ||
|
||
export class BedrockRuntimeClient extends _BedrockRuntimeClient { | ||
middlewareStack: _BedrockRuntimeClient['middlewareStack']; | ||
|
||
constructor({ actionsClient, connectorId, ...fields }: CustomChatModelInput) { | ||
super(fields ?? {}); | ||
this.config.requestHandler = new NodeHttpHandler({ | ||
streaming: fields.streaming ?? true, | ||
actionsClient, | ||
connectorId, | ||
}); | ||
// eliminate middleware steps that handle auth as Kibana connector handles auth | ||
this.middlewareStack = constructStack() as _BedrockRuntimeClient['middlewareStack']; | ||
} | ||
} |
50 changes: 50 additions & 0 deletions
50
...kages/kbn-langchain/server/language_models/chat_bedrock_converse/chat_bedrock_converse.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
import { ChatBedrockConverse } from '@langchain/aws'; | ||
import type { ActionsClient } from '@kbn/actions-plugin/server'; | ||
import { BaseChatModelParams } from '@langchain/core/language_models/chat_models'; | ||
import { Logger } from '@kbn/logging'; | ||
import { PublicMethodsOf } from '@kbn/utility-types'; | ||
import { BedrockRuntimeClient } from './bedrock_runtime_client'; | ||
import { DEFAULT_BEDROCK_MODEL, DEFAULT_BEDROCK_REGION } from '../../utils/bedrock'; | ||
|
||
export interface CustomChatModelInput extends BaseChatModelParams { | ||
actionsClient: PublicMethodsOf<ActionsClient>; | ||
connectorId: string; | ||
logger: Logger; | ||
signal?: AbortSignal; | ||
model?: string; | ||
} | ||
|
||
/** | ||
* Custom chat model class for Bedrock Converse API. | ||
* The ActionsClientChatBedrockConverse chat model supports streaming and | ||
* non-streaming via the Bedrock Converse and ConverseStream APIs. | ||
* | ||
* @param {Object} params - The parameters for the chat model. | ||
* @param {ActionsClient} params.actionsClient - The actions client. | ||
* @param {string} params.connectorId - The connector ID. | ||
* @param {Logger} params.logger - The logger instance. | ||
* @param {AbortSignal} [params.signal] - Optional abort signal. | ||
* @param {string} [params.model] - Optional model name. | ||
*/ | ||
export class ActionsClientChatBedrockConverse extends ChatBedrockConverse { | ||
constructor({ actionsClient, connectorId, logger, ...fields }: CustomChatModelInput) { | ||
super({ | ||
...(fields ?? {}), | ||
credentials: { accessKeyId: '', secretAccessKey: '' }, | ||
model: fields?.model ?? DEFAULT_BEDROCK_MODEL, | ||
region: DEFAULT_BEDROCK_REGION, | ||
}); | ||
this.client = new BedrockRuntimeClient({ | ||
actionsClient, | ||
connectorId, | ||
streaming: this.streaming, | ||
region: DEFAULT_BEDROCK_REGION, | ||
}); | ||
} | ||
} |
10 changes: 10 additions & 0 deletions
10
x-pack/packages/kbn-langchain/server/language_models/chat_bedrock_converse/index.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
import { ActionsClientChatBedrockConverse } from './chat_bedrock_converse'; | ||
|
||
export { ActionsClientChatBedrockConverse }; |
125 changes: 125 additions & 0 deletions
125
...ages/kbn-langchain/server/language_models/chat_bedrock_converse/node_http_handler.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
import { NodeHttpHandler } from './node_http_handler'; | ||
import { HttpRequest } from '@smithy/protocol-http'; | ||
import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; | ||
import { Readable } from 'stream'; | ||
import { fromUtf8 } from '@smithy/util-utf8'; | ||
|
||
const mockActionsClient = actionsClientMock.create(); | ||
const connectorId = 'mock-connector-id'; | ||
const mockOutput = { | ||
output: { | ||
message: { | ||
role: 'assistant', | ||
content: [{ text: 'This is a response from the assistant.' }], | ||
}, | ||
}, | ||
stopReason: 'end_turn', | ||
usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 }, | ||
metrics: { latencyMs: 123 }, | ||
additionalModelResponseFields: {}, | ||
trace: { guardrail: { modelOutput: ['Output text'] } }, | ||
}; | ||
describe('NodeHttpHandler', () => { | ||
let handler: NodeHttpHandler; | ||
|
||
beforeEach(() => { | ||
jest.clearAllMocks(); | ||
handler = new NodeHttpHandler({ | ||
streaming: false, | ||
actionsClient: mockActionsClient, | ||
connectorId, | ||
}); | ||
|
||
mockActionsClient.execute.mockResolvedValue({ | ||
data: mockOutput, | ||
actionId: 'mock-action-id', | ||
status: 'ok', | ||
}); | ||
}); | ||
|
||
it('handles non-streaming requests successfully', async () => { | ||
const request = new HttpRequest({ | ||
body: JSON.stringify({ messages: [] }), | ||
}); | ||
|
||
const result = await handler.handle(request); | ||
|
||
expect(result.response.statusCode).toBe(200); | ||
expect(result.response.headers['content-type']).toBe('application/json'); | ||
expect(result.response.body).toStrictEqual(fromUtf8(JSON.stringify(mockOutput))); | ||
}); | ||
|
||
it('handles streaming requests successfully', async () => { | ||
handler = new NodeHttpHandler({ | ||
streaming: true, | ||
actionsClient: mockActionsClient, | ||
connectorId, | ||
}); | ||
|
||
const request = new HttpRequest({ | ||
body: JSON.stringify({ messages: [] }), | ||
}); | ||
|
||
const readable = new Readable(); | ||
readable.push('streaming data'); | ||
readable.push(null); | ||
|
||
mockActionsClient.execute.mockResolvedValue({ | ||
data: readable, | ||
status: 'ok', | ||
actionId: 'mock-action-id', | ||
}); | ||
|
||
const result = await handler.handle(request); | ||
|
||
expect(result.response.statusCode).toBe(200); | ||
expect(result.response.body).toBe(readable); | ||
}); | ||
|
||
it('throws an error for non-streaming requests with error status', async () => { | ||
const request = new HttpRequest({ | ||
body: JSON.stringify({ messages: [] }), | ||
}); | ||
|
||
mockActionsClient.execute.mockResolvedValue({ | ||
status: 'error', | ||
message: 'error message', | ||
serviceMessage: 'service error message', | ||
actionId: 'mock-action-id', | ||
}); | ||
|
||
await expect(handler.handle(request)).rejects.toThrow( | ||
'ActionsClientBedrockChat: action result status is error: error message - service error message' | ||
); | ||
}); | ||
|
||
it('throws an error for streaming requests with error status', async () => { | ||
handler = new NodeHttpHandler({ | ||
streaming: true, | ||
actionsClient: mockActionsClient, | ||
connectorId, | ||
}); | ||
|
||
const request = new HttpRequest({ | ||
body: JSON.stringify({ messages: [] }), | ||
}); | ||
|
||
mockActionsClient.execute.mockResolvedValue({ | ||
status: 'error', | ||
message: 'error message', | ||
serviceMessage: 'service error message', | ||
actionId: 'mock-action-id', | ||
}); | ||
|
||
await expect(handler.handle(request)).rejects.toThrow( | ||
'ActionsClientBedrockChat: action result status is error: error message - service error message' | ||
); | ||
}); | ||
}); |
Oops, something went wrong.