[Security Assistant] Vertex chat model (#193032)
stephmilovic authored Oct 4, 2024
1 parent ef3bc96 commit aae8c50
Showing 19 changed files with 941 additions and 320 deletions.
7 changes: 5 additions & 2 deletions package.json
@@ -82,6 +82,7 @@
     "**/@bazel/typescript/protobufjs": "6.11.4",
     "**/@hello-pangea/dnd": "16.6.0",
     "**/@langchain/core": "^0.2.18",
+    "**/@langchain/google-common": "^0.1.1",
     "**/@types/node": "20.10.5",
     "**/@typescript-eslint/utils": "5.62.0",
     "**/chokidar": "^3.5.3",
@@ -999,7 +1000,9 @@
     "@kbn/zod-helpers": "link:packages/kbn-zod-helpers",
     "@langchain/community": "0.2.18",
     "@langchain/core": "^0.2.18",
-    "@langchain/google-genai": "^0.0.23",
+    "@langchain/google-common": "^0.1.1",
+    "@langchain/google-genai": "^0.1.0",
+    "@langchain/google-vertexai": "^0.1.0",
     "@langchain/langgraph": "0.0.34",
     "@langchain/openai": "^0.1.3",
     "@langtrase/trace-attributes": "^3.0.8",
@@ -1148,7 +1151,7 @@
     "jsts": "^1.6.2",
     "kea": "^2.6.0",
     "langchain": "^0.2.11",
-    "langsmith": "^0.1.39",
+    "langsmith": "^0.1.55",
     "launchdarkly-js-client-sdk": "^3.4.0",
     "load-json-file": "^6.2.0",
     "lodash": "^4.17.21",
2 changes: 2 additions & 0 deletions x-pack/packages/kbn-langchain/server/index.ts
@@ -10,6 +10,7 @@ import { ActionsClientChatOpenAI } from './language_models/chat_openai';
 import { ActionsClientLlm } from './language_models/llm';
 import { ActionsClientSimpleChatModel } from './language_models/simple_chat_model';
 import { ActionsClientGeminiChatModel } from './language_models/gemini_chat';
+import { ActionsClientChatVertexAI } from './language_models/chat_vertex';
 import { parseBedrockStream } from './utils/bedrock';
 import { parseGeminiResponse } from './utils/gemini';
 import { getDefaultArguments } from './language_models/constants';
@@ -20,6 +21,7 @@ export {
   getDefaultArguments,
   ActionsClientBedrockChatModel,
   ActionsClientChatOpenAI,
+  ActionsClientChatVertexAI,
   ActionsClientGeminiChatModel,
   ActionsClientLlm,
   ActionsClientSimpleChatModel,
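
For orientation, here is a minimal, hypothetical sketch of how a plugin might use the newly exported model. The import paths, the ActionsClient type, the helper name askGemini, and the connector id are assumptions for illustration; the constructor options mirror defaultArgs in the test file below.

import type { ActionsClient } from '@kbn/actions-plugin/server';
import type { Logger } from '@kbn/logging';
import { ActionsClientChatVertexAI } from '@kbn/langchain/server';

// Hypothetical helper; `actionsClient` and `logger` would come from the
// calling plugin's request context, and the connector id is a placeholder.
export async function askGemini(actionsClient: ActionsClient, logger: Logger) {
  const chatModel = new ActionsClientChatVertexAI({
    actionsClient,
    connectorId: 'my-gemini-connector-id', // placeholder
    logger,
    streaming: false,
    maxRetries: 0,
  });

  // Standard LangChain chat-model call; per the test below, the model routes
  // the request through actionsClient.execute with subAction 'invokeAIRaw'.
  const response = await chatModel.invoke('Do you know my name?');
  return response.content;
}
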
199 changes: 199 additions & 0 deletions x-pack/packages/kbn-langchain/server/language_models/chat_vertex.test.ts
@@ -0,0 +1,199 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { PassThrough } from 'stream';
import { loggerMock } from '@kbn/logging-mocks';
import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';

import { BaseMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';
import { ActionsClientChatVertexAI } from './chat_vertex';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';

const connectorId = 'mock-connector-id';

const mockExecute = jest.fn();
const actionsClient = actionsClientMock.create();

const mockLogger = loggerMock.create();

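// Simulates the connector's streamed response: Gemini streaming replies arrive
// as server-sent events whose `data:` payloads each carry one content part.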
const mockStreamExecute = jest.fn().mockImplementation(() => {
  const passThrough = new PassThrough();

  // Write the data chunks to the stream
  setTimeout(() => {
    passThrough.write(
      Buffer.from(
        `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token1"}]}}],"modelVersion": "gemini-1.5-pro-001"}`
      )
    );
  });
  setTimeout(() => {
    passThrough.write(
      Buffer.from(
        `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token2"}]}}],"modelVersion": "gemini-1.5-pro-001"}`
      )
    );
  });
  setTimeout(() => {
    passThrough.write(
      Buffer.from(
        `data: {"candidates": [{"content": {"role": "model","parts": [{"text": "token3"}]}}],"modelVersion": "gemini-1.5-pro-001"}`
      )
    );
    // End the stream
    passThrough.end();
  });

  return {
    data: passThrough, // PassThrough stream will act as the async iterator
    status: 'ok',
  };
});

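// Conversation fixture and call options shared by all of the tests below.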
const callMessages = [
  new SystemMessage('Answer the following questions truthfully and as best you can.'),
  new HumanMessage('Question: Do you know my name?\n\n'),
] as unknown as BaseMessage[];

const callOptions = {
  stop: ['\n'],
  recursionLimit: 0,
  /** Maximum number of parallel calls to make. */
  maxConcurrency: 0,
};
const handleLLMNewToken = jest.fn();
const callRunManager = {
  handleLLMNewToken,
} as unknown as CallbackManagerForLLMRun;
const onFailedAttempt = jest.fn();
const defaultArgs = {
  actionsClient,
  connectorId,
  logger: mockLogger,
  streaming: false,
  maxRetries: 0,
  onFailedAttempt,
};

const testMessage = 'Yes, your name is Andrew. How can I assist you further, Andrew?';

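// Mocked shape of a successful non-streaming `invokeAIRaw` result, used as the
// default `actionsClient.execute` response in the tests below.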
export const mockActionResponse = {
  candidates: [
    {
      content: {
        role: 'model',
        parts: [
          {
            text: testMessage,
          },
        ],
      },
      finishReason: 'STOP',
    },
  ],
  usageMetadata: { input_tokens: 4, output_tokens: 10, total_tokens: 14 },
};

describe('ActionsClientChatVertexAI', () => {
  beforeEach(() => {
    jest.clearAllMocks();
    actionsClient.execute.mockImplementation(
      jest.fn().mockImplementation(() => ({
        data: mockActionResponse,
        status: 'ok',
      }))
    );
    mockExecute.mockImplementation(() => ({
      data: mockActionResponse,
      status: 'ok',
    }));
  });

  describe('_generate streaming: false', () => {
    it('returns the expected content when _generate is invoked', async () => {
      const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs);

      const result = await actionsClientChatVertexAI._generate(
        callMessages,
        callOptions,
        callRunManager
      );
      const subAction = actionsClient.execute.mock.calls[0][0].params.subAction;
      expect(subAction).toEqual('invokeAIRaw');

      expect(result.generations[0].text).toEqual(testMessage);
    });

    it('rejects with the expected error when the action result status is error', async () => {
      const hasErrorStatus = jest.fn().mockImplementation(() => {
        throw new Error(
          'ActionsClientChatVertexAI: action result status is error: action-result-message - action-result-service-message'
        );
      });

      actionsClient.execute.mockRejectedValueOnce(hasErrorStatus);

      const actionsClientChatVertexAI = new ActionsClientChatVertexAI({
        ...defaultArgs,
        actionsClient,
      });

      await expect(
        actionsClientChatVertexAI._generate(callMessages, callOptions, callRunManager)
      ).rejects.toThrowError();
      expect(onFailedAttempt).toHaveBeenCalled();
    });

    it('rejects with the expected error when the message has invalid content', async () => {
      actionsClient.execute.mockImplementation(
        jest.fn().mockResolvedValue({
          data: {
            Bad: true,
            finishReason: 'badness',
          },
          status: 'ok',
        })
      );

      const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs);

      await expect(
        actionsClientChatVertexAI._generate(callMessages, callOptions, callRunManager)
      ).rejects.toThrowError("Cannot read properties of undefined (reading 'text')");
    });
  });

  describe('*_streamResponseChunks', () => {
    it('iterates over gemini chunks', async () => {
      actionsClient.execute.mockImplementationOnce(mockStreamExecute);

      const actionsClientChatVertexAI = new ActionsClientChatVertexAI({
        ...defaultArgs,
        actionsClient,
        streaming: true,
      });

      const gen = actionsClientChatVertexAI._streamResponseChunks(
        callMessages,
        callOptions,
        callRunManager
      );

      const chunks = [];

      for await (const chunk of gen) {
        chunks.push(chunk);
      }

      expect(chunks.map((c) => c.text)).toEqual(['token1', 'token2', 'token3']);
      expect(handleLLMNewToken).toHaveBeenCalledTimes(3);
      expect(handleLLMNewToken).toHaveBeenCalledWith('token1');
      expect(handleLLMNewToken).toHaveBeenCalledWith('token2');
      expect(handleLLMNewToken).toHaveBeenCalledWith('token3');
    });
  });
});
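
And a similarly hedged sketch of the streaming path the last test exercises. Consumers would normally go through LangChain's public .stream() API, which drives the internal *_streamResponseChunks generator; the helper name and connector id are again placeholders.

import type { ActionsClient } from '@kbn/actions-plugin/server';
import type { Logger } from '@kbn/logging';
import { ActionsClientChatVertexAI } from '@kbn/langchain/server';

// Hypothetical streaming helper; the connector id is a placeholder.
export async function streamGemini(actionsClient: ActionsClient, logger: Logger) {
  const chatModel = new ActionsClientChatVertexAI({
    actionsClient,
    connectorId: 'my-gemini-connector-id', // placeholder
    logger,
    streaming: true,
  });

  let text = '';
  for await (const chunk of await chatModel.stream('Do you know my name?')) {
    // Each chunk maps to one parsed `data:` SSE event ('token1', 'token2', ...);
    // content is a plain string for these text-only chunks.
    text += chunk.content as string;
  }
  return text;
}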