Skip to content

Commit

Permalink
[Security solution] Bedrock streaming and token tracking (elastic#170815
Browse files Browse the repository at this point in the history
)
  • Loading branch information
stephmilovic authored Nov 16, 2023
1 parent 18d65c4 commit d201610
Show file tree
Hide file tree
Showing 32 changed files with 1,403 additions and 224 deletions.
2 changes: 2 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,8 @@
"@opentelemetry/semantic-conventions": "^1.4.0",
"@reduxjs/toolkit": "1.7.2",
"@slack/webhook": "^5.0.4",
"@smithy/eventstream-codec": "^2.0.12",
"@smithy/util-utf8": "^2.0.0",
"@tanstack/react-query": "^4.29.12",
"@tanstack/react-query-devtools": "^4.29.12",
"@turf/along": "6.0.1",
Expand Down
51 changes: 46 additions & 5 deletions x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -75,19 +75,20 @@ describe('API tests', () => {
expect(mockHttp.fetch).toHaveBeenCalledWith(
'/internal/elastic_assistant/actions/connector/foo/_execute',
{
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"},"assistantLangChain":false}',
headers: { 'Content-Type': 'application/json' },
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeStream"},"assistantLangChain":false}',
method: 'POST',
asResponse: true,
rawResponse: true,
signal: undefined,
}
);
});

it('returns API_ERROR when the response status is not ok', async () => {
it('returns API_ERROR when the response status is error and langchain is on', async () => {
(mockHttp.fetch as jest.Mock).mockResolvedValue({ status: 'error' });

const testProps: FetchConnectorExecuteAction = {
assistantLangChain: false,
assistantLangChain: true,
http: mockHttp,
messages,
apiConfig,
Expand All @@ -98,10 +99,50 @@ describe('API tests', () => {
expect(result).toEqual({ response: API_ERROR, isStream: false, isError: true });
});

it('returns API_ERROR when the response status is error, langchain is off, and response is not a reader', async () => {
(mockHttp.fetch as jest.Mock).mockResolvedValue({ status: 'error' });

const testProps: FetchConnectorExecuteAction = {
assistantLangChain: false,
http: mockHttp,
messages,
apiConfig,
};

const result = await fetchConnectorExecuteAction(testProps);

expect(result).toEqual({
response: `${API_ERROR}\n\nCould not get reader from response`,
isStream: false,
isError: true,
});
});

it('returns API_ERROR when the response is error, langchain is off, and response is a reader', async () => {
const mockReader = jest.fn();
(mockHttp.fetch as jest.Mock).mockRejectedValue({
response: { body: { getReader: jest.fn().mockImplementation(() => mockReader) } },
});
const testProps: FetchConnectorExecuteAction = {
assistantLangChain: false,
http: mockHttp,
messages,
apiConfig,
};

const result = await fetchConnectorExecuteAction(testProps);

expect(result).toEqual({
response: mockReader,
isStream: true,
isError: true,
});
});

it('returns API_ERROR when there are no choices', async () => {
(mockHttp.fetch as jest.Mock).mockResolvedValue({ status: 'ok', data: '' });
const testProps: FetchConnectorExecuteAction = {
assistantLangChain: false,
assistantLangChain: true,
http: mockHttp,
messages,
apiConfig,
Expand Down
26 changes: 17 additions & 9 deletions x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,16 @@ export const fetchConnectorExecuteAction = async ({
messages: outboundMessages,
};

// TODO: Remove in part 2 of streaming work for security solution
// TODO: Remove in part 3 of streaming work for security solution
// tracked here: https://github.com/elastic/security-team/issues/7363
// My "Feature Flag", turn to false before merging
// In part 2 I will make enhancements to invokeAI to make it work with both openA, but to keep it to a Security Soltuion only review on this PR,
// I'm calling the stream action directly
const isStream = !assistantLangChain && false;
// In part 3 I will make enhancements to langchain to introduce streaming
// Once implemented, invokeAI can be removed
const isStream = !assistantLangChain;
const requestBody = isStream
? {
params: {
subActionParams: body,
subAction: 'stream',
subAction: 'invokeStream',
},
assistantLangChain,
}
Expand Down Expand Up @@ -105,7 +104,7 @@ export const fetchConnectorExecuteAction = async ({
};
}

// TODO: Remove in part 2 of streaming work for security solution
// TODO: Remove in part 3 of streaming work for security solution
// tracked here: https://github.com/elastic/security-team/issues/7363
// This is a temporary code to support the non-streaming API
const response = await http.fetch<{
Expand Down Expand Up @@ -140,10 +139,19 @@ export const fetchConnectorExecuteAction = async ({
isStream: false,
};
} catch (error) {
const reader = error?.response?.body?.getReader();

if (!reader) {
return {
response: `${API_ERROR}\n\n${error?.body?.message ?? error?.message}`,
isError: true,
isStream: false,
};
}
return {
response: `${API_ERROR}\n\n${error?.body?.message ?? error?.message}`,
response: reader,
isStream: true,
isError: true,
isStream: false,
};
}
};
Expand Down
76 changes: 27 additions & 49 deletions x-pack/plugins/actions/server/lib/action_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
import type { PublicMethodsOf } from '@kbn/utility-types';
import { Logger, KibanaRequest } from '@kbn/core/server';
import { cloneDeep } from 'lodash';
import { set } from '@kbn/safer-lodash-set';
import { withSpan } from '@kbn/apm-utils';
import { EncryptedSavedObjectsClient } from '@kbn/encrypted-saved-objects-plugin/server';
import { SpacesServiceStart } from '@kbn/spaces-plugin/server';
import { IEventLogger, SAVED_OBJECT_REL_PRIMARY } from '@kbn/event-log-plugin/server';
import { SecurityPluginStart } from '@kbn/security-plugin/server';
import { PassThrough, Readable } from 'stream';
import { getGenAiTokenTracking, shouldTrackGenAiToken } from './gen_ai_token_tracking';
import {
validateParams,
validateConfig,
Expand All @@ -38,7 +39,6 @@ import { RelatedSavedObjects } from './related_saved_objects';
import { createActionEventLogRecordObject } from './create_action_event_log_record_object';
import { ActionExecutionError, ActionExecutionErrorReason } from './errors/action_execution_error';
import type { ActionsAuthorization } from '../authorization/actions_authorization';
import { getTokenCountFromOpenAIStream } from './get_token_count_from_openai_stream';

// 1,000,000 nanoseconds in 1 millisecond
const Millis2Nanos = 1000 * 1000;
Expand Down Expand Up @@ -328,55 +328,33 @@ export class ActionExecutor {
eventLogger.logEvent(event);
}

// start openai extension
// add event.kibana.action.execution.openai to event log when OpenAI Connector is executed
if (result.status === 'ok' && actionTypeId === '.gen-ai') {
const data = result.data as unknown as {
usage: { prompt_tokens?: number; completion_tokens?: number; total_tokens?: number };
};
event.kibana = event.kibana || {};
event.kibana.action = event.kibana.action || {};
event.kibana = {
...event.kibana,
action: {
...event.kibana.action,
execution: {
...event.kibana.action.execution,
gen_ai: {
usage: {
total_tokens: data.usage?.total_tokens,
prompt_tokens: data.usage?.prompt_tokens,
completion_tokens: data.usage?.completion_tokens,
},
},
},
},
};

if (result.data instanceof Readable) {
getTokenCountFromOpenAIStream({
responseStream: result.data.pipe(new PassThrough()),
body: (validatedParams as { subActionParams: { body: string } }).subActionParams.body,
// start genai extension
if (result.status === 'ok' && shouldTrackGenAiToken(actionTypeId)) {
getGenAiTokenTracking({
actionTypeId,
logger,
result,
validatedParams,
})
.then((tokenTracking) => {
if (tokenTracking != null) {
set(event, 'kibana.action.execution.gen_ai.usage', {
total_tokens: tokenTracking.total_tokens,
prompt_tokens: tokenTracking.prompt_tokens,
completion_tokens: tokenTracking.completion_tokens,
});
}
})
.then(({ total, prompt, completion }) => {
event.kibana!.action!.execution!.gen_ai!.usage = {
total_tokens: total,
prompt_tokens: prompt,
completion_tokens: completion,
};
})
.catch((err) => {
logger.error('Failed to calculate tokens from streaming response');
logger.error(err);
})
.finally(() => {
completeEventLogging();
});

return resultWithoutError;
}
.catch((err) => {
logger.error('Failed to calculate tokens from streaming response');
logger.error(err);
})
.finally(() => {
completeEventLogging();
});
return resultWithoutError;
}
// end openai extension
// end genai extension

completeEventLogging();

Expand Down
Loading

0 comments on commit d201610

Please sign in to comment.