[Security solution] Bedrock streaming and token tracking #170815

Merged Nov 16, 2023 · 38 commits (changes shown from 22 commits)

Commits (all by stephmilovic)
a1929d0  bedrock magic (Nov 7, 2023)
f49cf23  cleanup (Nov 7, 2023)
8738b15  cleanup and openai wip (Nov 8, 2023)
427f6d2  better (Nov 8, 2023)
e32ffdb  wip (Nov 8, 2023)
237f6fc  token tracking for bedrock (Nov 9, 2023)
4604df4  cleanup (Nov 9, 2023)
9e69031  rm (Nov 9, 2023)
1a396f5  cleanup (Nov 9, 2023)
8696210  fix api tests (Nov 9, 2023)
2eabb1b  token tests (Nov 10, 2023)
b1154db  security solution tests (Nov 10, 2023)
a375387  stack connector tests (Nov 10, 2023)
886ad47  Merge branch 'main' into bedrock_streaming (Nov 14, 2023)
ab85ac4  WIP (Nov 14, 2023)
f5c8a85  update package.json whitespace? (Nov 14, 2023)
a57fae7  Merge branch 'bedrock_streaming' into bedrock_streaming_integration_t… (Nov 14, 2023)
fe858d7  cleanup (Nov 14, 2023)
b1e70a7  fix (Nov 14, 2023)
ef6fc8e  make streamApi private (Nov 14, 2023)
4f95837  Merge branch 'bedrock_streaming' into bedrock_streaming_integration_t… (Nov 14, 2023)
b16913f  comment the code better (Nov 14, 2023)
6875c63  fix comment (Nov 14, 2023)
9ab90b4  Merge branch 'bedrock_streaming' into bedrock_streaming_integration_t… (Nov 15, 2023)
b291fe3  Sergi PR changes (Nov 15, 2023)
c8957c5  Sergi was right (Nov 15, 2023)
ae4e85d  one more! (Nov 15, 2023)
40bd1c9  Merge branch 'bedrock_streaming' into bedrock_streaming_integration_t… (Nov 15, 2023)
3f70494  fix whoops (Nov 15, 2023)
902e7ba  add tests for shouldTrackGenAiToken (Nov 15, 2023)
10c188b  Merge branch 'bedrock_streaming' into bedrock_streaming_integration_t… (Nov 15, 2023)
325b0ff  commit (Nov 15, 2023)
b5e8e83  done? (Nov 15, 2023)
ce6668f  really (Nov 15, 2023)
33d1f92  better error handling (Nov 15, 2023)
5b7cf70  tests for line buffer (Nov 16, 2023)
cd6cb1e  fixed (Nov 16, 2023)
25ffb7e  Merge branch 'main' into bedrock_streaming (Nov 16, 2023)
2 changes: 2 additions & 0 deletions package.json
@@ -833,6 +833,8 @@
"@opentelemetry/semantic-conventions": "^1.4.0",
"@reduxjs/toolkit": "1.7.2",
"@slack/webhook": "^5.0.4",
"@smithy/eventstream-codec": "^2.0.12",
"@smithy/util-utf8": "^2.0.0",
"@tanstack/react-query": "^4.29.12",
"@tanstack/react-query-devtools": "^4.29.12",
"@turf/along": "6.0.1",
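
The two @smithy packages added here handle AWS's binary event-stream framing, which Bedrock uses for its streaming responses. As a rough sketch of how they fit together (not this PR's actual connector code; the base64 `bytes` envelope and the `completion` field are assumptions based on the Bedrock/Anthropic response format):

```ts
import { EventStreamCodec } from '@smithy/eventstream-codec';
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';

const codec = new EventStreamCodec(toUtf8, fromUtf8);

// Decode one complete event-stream frame from a Bedrock streaming response.
// The frame body is JSON whose `bytes` field holds the base64-encoded model
// output; for Anthropic models the decoded payload carries a `completion`.
export const decodeBedrockFrame = (frame: Uint8Array): string => {
  const event = codec.decode(frame);
  const body = JSON.parse(new TextDecoder().decode(event.body));
  const payload = JSON.parse(Buffer.from(body.bytes, 'base64').toString());
  return payload.completion ?? '';
};
```

Real handling also has to buffer partial frames across network chunks before handing complete frames to the codec.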
51 changes: 46 additions & 5 deletions x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx
@@ -75,19 +75,20 @@ describe('API tests', () => {
expect(mockHttp.fetch).toHaveBeenCalledWith(
'/internal/elastic_assistant/actions/connector/foo/_execute',
{
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"},"assistantLangChain":false}',
headers: { 'Content-Type': 'application/json' },
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeStream"},"assistantLangChain":false}',
method: 'POST',
asResponse: true,
rawResponse: true,
signal: undefined,
}
);
});

it('returns API_ERROR when the response status is not ok', async () => {
it('returns API_ERROR when the response status is error and langchain is on', async () => {
(mockHttp.fetch as jest.Mock).mockResolvedValue({ status: 'error' });

const testProps: FetchConnectorExecuteAction = {
assistantLangChain: false,
assistantLangChain: true,
http: mockHttp,
messages,
apiConfig,
@@ -98,10 +99,50 @@
expect(result).toEqual({ response: API_ERROR, isStream: false, isError: true });
});

it('returns API_ERROR when the response status is error, langchain is off, and response is not a reader', async () => {
(mockHttp.fetch as jest.Mock).mockResolvedValue({ status: 'error' });

const testProps: FetchConnectorExecuteAction = {
assistantLangChain: false,
http: mockHttp,
messages,
apiConfig,
};

const result = await fetchConnectorExecuteAction(testProps);

expect(result).toEqual({
response: `${API_ERROR}\n\nCould not get reader from response`,
isStream: false,
isError: true,
});
});

it('returns API_ERROR when the response is error, langchain is off, and response is a reader', async () => {
const mockReader = jest.fn();
(mockHttp.fetch as jest.Mock).mockRejectedValue({
response: { body: { getReader: jest.fn().mockImplementation(() => mockReader) } },
});
const testProps: FetchConnectorExecuteAction = {
assistantLangChain: false,
http: mockHttp,
messages,
apiConfig,
};

const result = await fetchConnectorExecuteAction(testProps);

expect(result).toEqual({
response: mockReader,
isStream: true,
isError: true,
});
});

it('returns API_ERROR when there are no choices', async () => {
(mockHttp.fetch as jest.Mock).mockResolvedValue({ status: 'ok', data: '' });
const testProps: FetchConnectorExecuteAction = {
assistantLangChain: false,
assistantLangChain: true,
http: mockHttp,
messages,
apiConfig,
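
Taken together, the assertions above pin down the result contract of `fetchConnectorExecuteAction`. Inferred as a sketch from the expectations (the interface name is illustrative, not necessarily what the package exports):

```ts
// Result shape implied by the tests: plain text or a stream reader, plus
// flags telling the caller how to render it.
interface FetchConnectorExecuteResponse {
  response: string | ReadableStreamDefaultReader<Uint8Array>;
  isStream: boolean; // true when `response` is a reader to consume incrementally
  isError: boolean; // true both for thrown errors and for 'error' statuses
}
```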
26 changes: 17 additions & 9 deletions x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx
@@ -54,17 +54,16 @@ export const fetchConnectorExecuteAction = async ({
messages: outboundMessages,
};

// TODO: Remove in part 2 of streaming work for security solution
// TODO: Remove in part 3 of streaming work for security solution
// tracked here: https://github.com/elastic/security-team/issues/7363
// My "Feature Flag", turn to false before merging
// In part 2 I will make enhancements to invokeAI to make it work with both OpenAI and Bedrock, but to keep it to a Security Solution only review on this PR,
// I'm calling the stream action directly
const isStream = !assistantLangChain && false;
// In part 3 I will make enhancements to langchain to introduce streaming
// Once implemented, invokeAI can be removed
const isStream = !assistantLangChain;
const requestBody = isStream
? {
params: {
subActionParams: body,
subAction: 'stream',
subAction: 'invokeStream',
},
assistantLangChain,
}
@@ -105,7 +104,7 @@
};
}

// TODO: Remove in part 2 of streaming work for security solution
// TODO: Remove in part 3 of streaming work for security solution
// tracked here: https://github.com/elastic/security-team/issues/7363
// This is a temporary code to support the non-streaming API
const response = await http.fetch<{
@@ -140,10 +139,19 @@
isStream: false,
};
} catch (error) {
const reader = error?.response?.body?.getReader();

if (!reader) {
return {
response: `${API_ERROR}\n\n${error?.body?.message ?? error?.message}`,
isError: true,
isStream: false,
};
}
return {
response: `${API_ERROR}\n\n${error?.body?.message ?? error?.message}`,
response: reader,
isStream: true,
isError: true,
isStream: false,
};
}
};
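
Note the new catch block: when the failed response still has a readable body, the function now hands back the reader itself with `isStream: true`, so the UI can render streamed error output the same way it renders a normal completion. A minimal sketch of draining such a reader (standard Web Streams API, not code from this PR):

```ts
// Accumulate a UTF-8 string from a reader returned with isStream: true.
async function readStreamToString(
  reader: ReadableStreamDefaultReader<Uint8Array>
): Promise<string> {
  const decoder = new TextDecoder();
  let text = '';
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    text += decoder.decode(value, { stream: true });
  }
  return text + decoder.decode(); // flush any bytes still buffered
}
```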
76 changes: 27 additions & 49 deletions x-pack/plugins/actions/server/lib/action_executor.ts
@@ -8,12 +8,13 @@
import type { PublicMethodsOf } from '@kbn/utility-types';
import { Logger, KibanaRequest } from '@kbn/core/server';
import { cloneDeep } from 'lodash';
import { set } from '@kbn/safer-lodash-set';
import { withSpan } from '@kbn/apm-utils';
import { EncryptedSavedObjectsClient } from '@kbn/encrypted-saved-objects-plugin/server';
import { SpacesServiceStart } from '@kbn/spaces-plugin/server';
import { IEventLogger, SAVED_OBJECT_REL_PRIMARY } from '@kbn/event-log-plugin/server';
import { SecurityPluginStart } from '@kbn/security-plugin/server';
import { PassThrough, Readable } from 'stream';
import { getGenAiTokenTracking, shouldTrackGenAiToken } from './gen_ai_token_tracking';
import {
validateParams,
validateConfig,
@@ -38,7 +39,6 @@ import { RelatedSavedObjects } from './related_saved_objects';
import { createActionEventLogRecordObject } from './create_action_event_log_record_object';
import { ActionExecutionError, ActionExecutionErrorReason } from './errors/action_execution_error';
import type { ActionsAuthorization } from '../authorization/actions_authorization';
import { getTokenCountFromOpenAIStream } from './get_token_count_from_openai_stream';

// 1,000,000 nanoseconds in 1 millisecond
const Millis2Nanos = 1000 * 1000;
@@ -328,55 +328,33 @@
eventLogger.logEvent(event);
}

// start openai extension
// add event.kibana.action.execution.openai to event log when OpenAI Connector is executed
if (result.status === 'ok' && actionTypeId === '.gen-ai') {
const data = result.data as unknown as {
usage: { prompt_tokens?: number; completion_tokens?: number; total_tokens?: number };
};
event.kibana = event.kibana || {};
event.kibana.action = event.kibana.action || {};
event.kibana = {
...event.kibana,
action: {
...event.kibana.action,
execution: {
...event.kibana.action.execution,
gen_ai: {
usage: {
total_tokens: data.usage?.total_tokens,
prompt_tokens: data.usage?.prompt_tokens,
completion_tokens: data.usage?.completion_tokens,
},
},
},
},
};

if (result.data instanceof Readable) {
getTokenCountFromOpenAIStream({
responseStream: result.data.pipe(new PassThrough()),
body: (validatedParams as { subActionParams: { body: string } }).subActionParams.body,
// start genai extension
if (result.status === 'ok' && shouldTrackGenAiToken(actionTypeId)) {
getGenAiTokenTracking({
actionTypeId,
logger,
result,
validatedParams,
})
.then((tokenTracking) => {
if (tokenTracking != null) {
set(event, 'kibana.action.execution.gen_ai.usage', {
total_tokens: tokenTracking.total_tokens,
prompt_tokens: tokenTracking.prompt_tokens,
completion_tokens: tokenTracking.completion_tokens,
});
}
})
.then(({ total, prompt, completion }) => {
event.kibana!.action!.execution!.gen_ai!.usage = {
total_tokens: total,
prompt_tokens: prompt,
completion_tokens: completion,
};
})
.catch((err) => {
logger.error('Failed to calculate tokens from streaming response');
logger.error(err);
})
.finally(() => {
completeEventLogging();
});

return resultWithoutError;
}
.catch((err) => {
logger.error('Failed to calculate tokens from streaming response');
logger.error(err);
})
.finally(() => {
completeEventLogging();
});
return resultWithoutError;
}
// end openai extension
// end genai extension

completeEventLogging();

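
The executor now delegates provider-specific accounting to `shouldTrackGenAiToken` and `getGenAiTokenTracking` from `./gen_ai_token_tracking`, whose bodies are outside this diff. A plausible sketch of that contract (the `.bedrock` type id and the simplified parameters are assumptions, not the module's actual implementation):

```ts
import { Readable } from 'stream';

interface TokenCounts {
  total_tokens: number;
  prompt_tokens: number;
  completion_tokens: number;
}

// Gate: only generative-AI connector types get token accounting.
export const shouldTrackGenAiToken = (actionTypeId: string): boolean =>
  actionTypeId === '.gen-ai' || actionTypeId === '.bedrock';

// Resolve token counts from a connector result, streamed or plain.
export const getGenAiTokenTracking = async ({
  result,
}: {
  result: { data?: unknown };
}): Promise<TokenCounts | null> => {
  if (result.data instanceof Readable) {
    // Streamed responses need provider-specific counting as chunks arrive
    // (OpenAI SSE vs. Bedrock event stream); omitted in this sketch.
    return null;
  }
  // Non-streamed responses report usage directly (OpenAI-style payload).
  const data = result.data as {
    usage?: { total_tokens?: number; prompt_tokens?: number; completion_tokens?: number };
  };
  return data?.usage
    ? {
        total_tokens: data.usage.total_tokens ?? 0,
        prompt_tokens: data.usage.prompt_tokens ?? 0,
        completion_tokens: data.usage.completion_tokens ?? 0,
      }
    : null;
};
```

Whatever the real implementation, the `.then/.catch/.finally` chain above guarantees `completeEventLogging()` runs exactly once, after any token counts have been merged into the event.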