[Security Assistant] Vertex chat model #193032

Merged 32 commits from vertex_chat into main on Oct 4, 2024

Commits (32)
5ba961e
working for non-streaming
stephmilovic Sep 13, 2024
3721107
Merge branch 'main' into vertex_chat
stephmilovic Sep 16, 2024
b879c00
prompts
stephmilovic Sep 16, 2024
af46a8e
rm from integrations assistant
stephmilovic Sep 16, 2024
4f4c673
add streaming
stephmilovic Sep 16, 2024
4d56a8b
Merge branch 'main' into vertex_chat
stephmilovic Sep 16, 2024
2f9e725
import from LC
stephmilovic Sep 16, 2024
b87fed3
revert translation
stephmilovic Sep 16, 2024
857d68f
fix types
stephmilovic Sep 16, 2024
1fe3e7e
add test
stephmilovic Sep 17, 2024
c29aba4
[CI] Auto-commit changed files from 'node scripts/notice'
kibanamachine Sep 17, 2024
7488ea1
Merge branch 'main' into vertex_chat
stephmilovic Sep 17, 2024
a631591
fix
stephmilovic Sep 17, 2024
c79057d
rm elastic-assistant reference
stephmilovic Sep 17, 2024
28b44b0
[CI] Auto-commit changed files from 'node scripts/notice'
kibanamachine Sep 17, 2024
62d0fba
fix circular dep
stephmilovic Sep 17, 2024
5618557
Merge branch 'main' into vertex_chat
stephmilovic Sep 18, 2024
5c3e678
better prompting
stephmilovic Sep 18, 2024
5afa6d9
Merge branch 'main' into vertex_chat
stephmilovic Sep 30, 2024
877c187
Merge branch 'main' into vertex_chat
stephmilovic Sep 30, 2024
a8c5582
prompt wip
stephmilovic Sep 30, 2024
0a05507
Merge branch 'main' into vertex_chat
stephmilovic Oct 1, 2024
df96c81
working pretty well
stephmilovic Oct 1, 2024
aa42f6f
fix finish reason
stephmilovic Oct 2, 2024
04da653
prompt fixes
stephmilovic Oct 2, 2024
359cb4e
wip
stephmilovic Oct 3, 2024
1fc00ac
Merge branch 'main' into vertex_chat
stephmilovic Oct 3, 2024
d7d8094
Revert "wip"
stephmilovic Oct 3, 2024
b35ba41
rm logs
stephmilovic Oct 3, 2024
5c3459f
revert topK
stephmilovic Oct 3, 2024
08f4fb2
no role in system instruction
stephmilovic Oct 3, 2024
c5a684f
rename const
stephmilovic Oct 3, 2024
7 changes: 5 additions & 2 deletions package.json
@@ -82,6 +82,7 @@
"**/@bazel/typescript/protobufjs": "6.11.4",
"**/@hello-pangea/dnd": "16.6.0",
"**/@langchain/core": "^0.2.18",
"**/@langchain/google-common": "^0.1.1",
"**/@types/node": "20.10.5",
"**/@typescript-eslint/utils": "5.62.0",
"**/chokidar": "^3.5.3",
@@ -999,7 +1000,9 @@
"@kbn/zod-helpers": "link:packages/kbn-zod-helpers",
"@langchain/community": "0.2.18",
"@langchain/core": "^0.2.18",
"@langchain/google-genai": "^0.0.23",
"@langchain/google-common": "^0.1.1",
"@langchain/google-genai": "^0.1.0",
"@langchain/google-vertexai": "^0.1.0",
"@langchain/langgraph": "0.0.34",
"@langchain/openai": "^0.1.3",
"@langtrase/trace-attributes": "^3.0.8",
@@ -1148,7 +1151,7 @@
"jsts": "^1.6.2",
"kea": "^2.6.0",
"langchain": "^0.2.11",
"langsmith": "^0.1.39",
"langsmith": "^0.1.55",
"launchdarkly-js-client-sdk": "^3.4.0",
"load-json-file": "^6.2.0",
"lodash": "^4.17.21",
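The dependency changes above pull in LangChain's split Google packages (@langchain/google-common, @langchain/google-vertexai) and move @langchain/google-genai to the matching 0.1.x line. As a hedged sketch (not part of this diff), this is the kind of upstream model those packages provide, which the new Kibana class presumably builds on; the model name is taken from the streaming mocks in the test file below:

```ts
import { ChatVertexAI } from '@langchain/google-vertexai';

// Direct upstream usage, for illustration only; the PR instead routes calls
// through Kibana's actions client rather than hitting Google directly.
const upstreamModel = new ChatVertexAI({
  model: 'gemini-1.5-pro-001', // model version echoed by the test mocks
  temperature: 0,
});
```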
2 changes: 2 additions & 0 deletions x-pack/packages/kbn-langchain/server/index.ts
@@ -10,6 +10,7 @@ import { ActionsClientChatOpenAI } from './language_models/chat_openai';
import { ActionsClientLlm } from './language_models/llm';
import { ActionsClientSimpleChatModel } from './language_models/simple_chat_model';
import { ActionsClientGeminiChatModel } from './language_models/gemini_chat';
+ import { ActionsClientChatVertexAI } from './language_models/chat_vertex';
import { parseBedrockStream } from './utils/bedrock';
import { parseGeminiResponse } from './utils/gemini';
import { getDefaultArguments } from './language_models/constants';
@@ -20,6 +21,7 @@ export {
getDefaultArguments,
ActionsClientBedrockChatModel,
ActionsClientChatOpenAI,
+ ActionsClientChatVertexAI,
ActionsClientGeminiChatModel,
ActionsClientLlm,
ActionsClientSimpleChatModel,
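With the new export in place, plugin code can import the model from the package entry point. A minimal consumption sketch, assuming an '@kbn/langchain/server' import path (inferred from the package layout) and a hypothetical connector id; the constructor arguments mirror the defaultArgs used in the test file below:

```ts
import { ActionsClientChatVertexAI } from '@kbn/langchain/server';
import { loggerMock } from '@kbn/logging-mocks';
import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';

async function invokeExample() {
  const model = new ActionsClientChatVertexAI({
    actionsClient: actionsClientMock.create(), // stands in for a real ActionsClient
    connectorId: 'my-gemini-connector-id', // hypothetical connector id
    logger: loggerMock.create(),
    streaming: false,
    maxRetries: 0,
  });

  // invoke() drives _generate(), which the tests below show calling the
  // connector's invokeAIRaw sub-action.
  const reply = await model.invoke('Question: Do you know my name?');
  return reply.content;
}
```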
199 changes: 199 additions & 0 deletions x-pack/packages/kbn-langchain/server/language_models/chat_vertex.test.ts
@@ -0,0 +1,199 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { PassThrough } from 'stream';
import { loggerMock } from '@kbn/logging-mocks';
import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';

import { BaseMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';
import { ActionsClientChatVertexAI } from './chat_vertex';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';

const connectorId = 'mock-connector-id';

const mockExecute = jest.fn();
const actionsClient = actionsClientMock.create();

const mockLogger = loggerMock.create();

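// Simulates a streaming connector response: a PassThrough stream that emits
// three server-sent-event "data:" chunks shaped like Vertex AI candidate
// payloads, then ends the stream.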
const mockStreamExecute = jest.fn().mockImplementation(() => {
const passThrough = new PassThrough();

// Write the data chunks to the stream
setTimeout(() => {
passThrough.write(
Buffer.from(
`data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token1"}]}}],"modelVersion": "gemini-1.5-pro-001"}`
)
);
});
setTimeout(() => {
passThrough.write(
Buffer.from(
`data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token2"}]}}],"modelVersion": "gemini-1.5-pro-001"}`
)
);
});
setTimeout(() => {
passThrough.write(
Buffer.from(
`data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token3"}]}}],"modelVersion": "gemini-1.5-pro-001"}`
)
);
// End the stream
passThrough.end();
});

return {
data: passThrough, // PassThrough stream will act as the async iterator
status: 'ok',
};
});

const callMessages = [
new SystemMessage('Answer the following questions truthfully and as best you can.'),
new HumanMessage('Question: Do you know my name?\n\n'),
] as unknown as BaseMessage[];

const callOptions = {
stop: ['\n'],
recursionLimit: 0,
/** Maximum number of parallel calls to make. */
maxConcurrency: 0,
};
const handleLLMNewToken = jest.fn();
const callRunManager = {
handleLLMNewToken,
} as unknown as CallbackManagerForLLMRun;
const onFailedAttempt = jest.fn();
const defaultArgs = {
actionsClient,
connectorId,
logger: mockLogger,
streaming: false,
maxRetries: 0,
onFailedAttempt,
};

const testMessage = 'Yes, your name is Andrew. How can I assist you further, Andrew?';

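// Mirrors a non-streaming Vertex AI response: one candidate with model parts,
// a STOP finish reason, and token usage metadata.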
export const mockActionResponse = {
candidates: [
{
content: {
role: 'model',
parts: [
{
text: testMessage,
},
],
},
finishReason: 'STOP',
},
],
usageMetadata: { input_tokens: 4, output_tokens: 10, total_tokens: 14 },
};

describe('ActionsClientChatVertexAI', () => {
beforeEach(() => {
jest.clearAllMocks();
actionsClient.execute.mockImplementation(
jest.fn().mockImplementation(() => ({
data: mockActionResponse,
status: 'ok',
}))
);
mockExecute.mockImplementation(() => ({
data: mockActionResponse,
status: 'ok',
}));
});

describe('_generate streaming: false', () => {
it('returns the expected content when _generate is invoked', async () => {
const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs);

const result = await actionsClientChatVertexAI._generate(
callMessages,
callOptions,
callRunManager
);
const subAction = actionsClient.execute.mock.calls[0][0].params.subAction;
expect(subAction).toEqual('invokeAIRaw');

expect(result.generations[0].text).toEqual(testMessage);
});

it('rejects with the expected error when the action result status is error', async () => {
const hasErrorStatus = jest.fn().mockImplementation(() => {
throw new Error(
'ActionsClientChatVertexAI: action result status is error: action-result-message - action-result-service-message'
);
});

actionsClient.execute.mockRejectedValueOnce(hasErrorStatus);

const actionsClientChatVertexAI = new ActionsClientChatVertexAI({
...defaultArgs,
actionsClient,
});

await expect(
actionsClientChatVertexAI._generate(callMessages, callOptions, callRunManager)
).rejects.toThrowError();
expect(onFailedAttempt).toHaveBeenCalled();
});

it('rejects with the expected error when the message has invalid content', async () => {
actionsClient.execute.mockImplementation(
jest.fn().mockResolvedValue({
data: {
Bad: true,
finishReason: 'badness',
},
status: 'ok',
})
);

const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs);

await expect(
actionsClientChatVertexAI._generate(callMessages, callOptions, callRunManager)
).rejects.toThrowError("Cannot read properties of undefined (reading 'text')");
});
});

describe('*_streamResponseChunks', () => {
it('iterates over gemini chunks', async () => {
actionsClient.execute.mockImplementationOnce(mockStreamExecute);

const actionsClientChatVertexAI = new ActionsClientChatVertexAI({
...defaultArgs,
actionsClient,
streaming: true,
});

const gen = actionsClientChatVertexAI._streamResponseChunks(
callMessages,
callOptions,
callRunManager
);

const chunks = [];

for await (const chunk of gen) {
chunks.push(chunk);
}

expect(chunks.map((c) => c.text)).toEqual(['token1', 'token2', 'token3']);
expect(handleLLMNewToken).toHaveBeenCalledTimes(3);
expect(handleLLMNewToken).toHaveBeenCalledWith('token1');
expect(handleLLMNewToken).toHaveBeenCalledWith('token2');
expect(handleLLMNewToken).toHaveBeenCalledWith('token3');
});
});
});
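For streaming, the class is consumed through LangChain's standard stream() API, which drives the _streamResponseChunks generator exercised above. Another hedged sketch under the same assumptions (entry-point path and connector id):

```ts
import { ActionsClientChatVertexAI } from '@kbn/langchain/server';
import { loggerMock } from '@kbn/logging-mocks';
import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';

async function streamExample() {
  const model = new ActionsClientChatVertexAI({
    actionsClient: actionsClientMock.create(),
    connectorId: 'my-gemini-connector-id', // hypothetical
    logger: loggerMock.create(),
    streaming: true,
    maxRetries: 0,
  });

  // Each chunk corresponds to one SSE "data:" payload, e.g. token1..token3
  // in the mock above.
  for await (const chunk of await model.stream('Question: Do you know my name?')) {
    process.stdout.write(String(chunk.content));
  }
}
```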