Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Security solution] Token tracking for Converse APIs #200993

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 45 additions & 6 deletions x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import { getTokenCountFromOpenAIStream } from './get_token_count_from_openai_str
import {
getTokenCountFromInvokeStream,
InvokeBody,
parseBedrockConverseStream,
parseGeminiStreamForUsageMetadata,
} from './get_token_count_from_invoke_stream';

Expand Down Expand Up @@ -84,7 +85,7 @@ export const getGenAiTokenTracking = async ({
}
}

// this is a streamed Gemini response, using the subAction invokeStream to stream the response as a simple string
// streamed Gemini response, using the subAction invokeStream to stream the response as a simple string
if (
validatedParams.subAction === 'invokeStream' &&
result.data instanceof Readable &&
Expand All @@ -109,7 +110,7 @@ export const getGenAiTokenTracking = async ({
}
}

// this is a streamed OpenAI or Bedrock response, using the subAction invokeStream to stream the response as a simple string
// streamed OpenAI or Bedrock response, using the subAction invokeStream to stream the response as a simple string
if (
validatedParams.subAction === 'invokeStream' &&
result.data instanceof Readable &&
Expand All @@ -134,7 +135,7 @@ export const getGenAiTokenTracking = async ({
}
}

// this is a streamed OpenAI response, which did not use the subAction invokeStream
// streamed OpenAI response, which did not use the subAction invokeStream
if (actionTypeId === '.gen-ai' && result.data instanceof Readable) {
try {
const { total, prompt, completion } = await getTokenCountFromOpenAIStream({
Expand All @@ -154,7 +155,7 @@ export const getGenAiTokenTracking = async ({
}
}

// this is a non-streamed OpenAI response, which comes with the usage object
// non-streamed OpenAI response, which comes with the usage object
if (actionTypeId === '.gen-ai') {
const data = result.data as unknown as {
usage: { prompt_tokens?: number; completion_tokens?: number; total_tokens?: number };
Expand All @@ -170,7 +171,7 @@ export const getGenAiTokenTracking = async ({
};
}

// this is a non-streamed Bedrock response
// non-streamed Bedrock response
if (
actionTypeId === '.bedrock' &&
(validatedParams.subAction === 'run' || validatedParams.subAction === 'test')
Expand Down Expand Up @@ -226,7 +227,7 @@ export const getGenAiTokenTracking = async ({
};
}

// this is a non-streamed Bedrock response used by security solution
// non-streamed Bedrock response used by chat model ActionsClientBedrockChatModel
if (actionTypeId === '.bedrock' && validatedParams.subAction === 'invokeAI') {
try {
const rData = result.data as unknown as {
Expand Down Expand Up @@ -264,6 +265,44 @@ export const getGenAiTokenTracking = async ({
// silently fail and null is returned at bottom of function
}
}
// non-streamed Bedrock response used by chat model ActionsClientChatBedrockConverse
if (actionTypeId === '.bedrock' && validatedParams.subAction === 'converse') {
  // The Converse API reports token usage directly on the response body as
  // { usage: { inputTokens, outputTokens, totalTokens } }, so no stream parsing is needed.
  const { usage } = result.data as unknown as {
    usage: { inputTokens: number; outputTokens: number; totalTokens: number };
  };

  if (usage) {
    return {
      total_tokens: usage.totalTokens,
      prompt_tokens: usage.inputTokens,
      completion_tokens: usage.outputTokens,
    };
  } else {
    // Surface the anomaly but still record a zeroed event so tracking stays consistent.
    logger.error('Response from Bedrock converse API did not contain usage object');
    return {
      total_tokens: 0,
      prompt_tokens: 0,
      completion_tokens: 0,
    };
  }
}
// streamed Bedrock response used by chat model ActionsClientChatBedrockConverse
if (
  actionTypeId === '.bedrock' &&
  result.data instanceof Readable &&
  validatedParams.subAction === 'converseStream'
) {
  try {
    // Pipe through a PassThrough so token parsing consumes a copy of the stream,
    // leaving the original readable for the caller.
    return await parseBedrockConverseStream(result.data.pipe(new PassThrough()));
  } catch (e) {
    logger.error('Failed to calculate tokens from converseStream subaction streaming response');
    logger.error(e);
    // silently fail and null is returned at bottom of function
  }
}
return null;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import { Transform } from 'stream';
import {
getTokenCountFromInvokeStream,
parseBedrockConverseStream,
parseGeminiStreamForUsageMetadata,
} from './get_token_count_from_invoke_stream';
import { loggerMock } from '@kbn/logging-mocks';
Expand Down Expand Up @@ -149,6 +150,43 @@ describe('getTokenCountFromInvokeStream', () => {
stream.write(encodeBedrockResponse('Simple.'));
});

// Happy path: the final event of a ConverseStream carries the metadata payload
// with the usage object; the parser should map it to the snake_case token fields.
it('parses the Bedrock converse stream and returns token counts', async () => {
const usageData = {
usage: {
totalTokens: 100,
inputTokens: 40,
outputTokens: 60,
},
};
// Encode the payload in the AWS event-stream binary framing, as Bedrock would.
const encodedUsageData = new EventStreamCodec(toUtf8, fromUtf8).encode({
headers: {},
body: Uint8Array.from(Buffer.from(JSON.stringify(usageData))),
});

stream.write(encodedUsageData);
stream.complete();

const tokens = await parseBedrockConverseStream(stream.transform);
expect(tokens).toEqual({
total_tokens: 100,
prompt_tokens: 40,
completion_tokens: 60,
});
});

// Degenerate path: a well-framed event whose JSON body has no usage object
// (e.g. a tool-use event) should yield null rather than throw.
it('returns null if the converse stream usage object is not present', async () => {
const invalidData = new EventStreamCodec(toUtf8, fromUtf8).encode({
headers: {},
body: Uint8Array.from(Buffer.from(JSON.stringify({}))),
});

stream.write(invalidData);
stream.complete();

const tokens = await parseBedrockConverseStream(stream.transform);
expect(tokens).toBeNull();
});

it('calculates from the usage object when latest api is used', async () => {
stream = createStreamMock();
stream.write(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,49 @@ const parseBedrockStream: StreamParser = async (responseStream, logger) => {
return parseBedrockBuffer(responseBuffer);
};

/**
 * Consumes a Bedrock ConverseStream response and extracts token usage.
 *
 * The stream is AWS event-stream encoded; the metadata event carrying the
 * `usage` object is the final event, so only the last buffered chunk is decoded.
 *
 * @param responseStream the raw (binary) ConverseStream response
 * @returns token counts in the gen-ai tracking shape, or `null` when no usage
 *   object is present (e.g. tool-use responses, whose tokens are counted in the
 *   final response of the stream) or the stream is empty.
 */
export const parseBedrockConverseStream = async (
  responseStream: Readable
): Promise<{
  total_tokens: number;
  prompt_tokens: number;
  completion_tokens: number;
} | null> => {
  const responseBuffer: Uint8Array[] = [];
  // do not destroy response stream on abort for bedrock
  // Amazon charges the same tokens whether the stream is destroyed or not, so let it finish to calculate
  responseStream.on('data', (chunk) => {
    // special encoding for bedrock, do not attempt to convert to string
    responseBuffer.push(chunk);
  });
  await finished(responseStream);

  const finalChunk = responseBuffer[responseBuffer.length - 1];
  if (!finalChunk) {
    // Empty stream: nothing to decode.
    return null;
  }
  try {
    const event = new EventStreamCodec(toUtf8, fromUtf8).decode(finalChunk);
    const { usage } = JSON.parse(new TextDecoder().decode(event.body));
    return usage
      ? {
          total_tokens: usage.totalTokens,
          prompt_tokens: usage.inputTokens,
          completion_tokens: usage.outputTokens,
        }
      : null;
  } catch {
    // Tool responses do not contain a usage object, so decoding/parsing throws.
    // Return null — tokens from tool usage are included in the stream's final response.
    return null;
  }
};

const parseOpenAIStream: StreamParser = async (responseStream, logger, signal) => {
let responseBody: string = '';
const destroyStream = () => {
Expand Down