diff --git a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts
index 41bfa28605f40..f8be1610d1db2 100644
--- a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts
+++ b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts
@@ -19,6 +19,7 @@ import { getTokenCountFromOpenAIStream } from './get_token_count_from_openai_stream';
 import {
   getTokenCountFromInvokeStream,
   InvokeBody,
+  parseBedrockConverseStream,
   parseGeminiStreamForUsageMetadata,
 } from './get_token_count_from_invoke_stream';
 
@@ -84,7 +85,7 @@ export const getGenAiTokenTracking = async ({
     }
   }
 
-  // this is a streamed Gemini response, using the subAction invokeStream to stream the response as a simple string
+  // streamed Gemini response, using the subAction invokeStream to stream the response as a simple string
   if (
     validatedParams.subAction === 'invokeStream' &&
     result.data instanceof Readable &&
@@ -109,7 +110,7 @@ export const getGenAiTokenTracking = async ({
     }
   }
 
-  // this is a streamed OpenAI or Bedrock response, using the subAction invokeStream to stream the response as a simple string
+  // streamed OpenAI or Bedrock response, using the subAction invokeStream to stream the response as a simple string
   if (
     validatedParams.subAction === 'invokeStream' &&
     result.data instanceof Readable &&
@@ -134,7 +135,7 @@ export const getGenAiTokenTracking = async ({
     }
   }
 
-  // this is a streamed OpenAI response, which did not use the subAction invokeStream
+  // streamed OpenAI response, which did not use the subAction invokeStream
   if (actionTypeId === '.gen-ai' && result.data instanceof Readable) {
     try {
       const { total, prompt, completion } = await getTokenCountFromOpenAIStream({
@@ -154,7 +155,7 @@ export const getGenAiTokenTracking = async ({
     }
   }
 
-  // this is a non-streamed OpenAI response, which comes with the usage object
+  // non-streamed OpenAI response, which comes with the usage object
   if (actionTypeId === '.gen-ai') {
     const data = result.data as unknown as {
       usage: { prompt_tokens?: number; completion_tokens?: number; total_tokens?: number };
@@ -170,7 +171,7 @@ export const getGenAiTokenTracking = async ({
     };
   }
 
-  // this is a non-streamed Bedrock response
+  // non-streamed Bedrock response
   if (
     actionTypeId === '.bedrock' &&
     (validatedParams.subAction === 'run' || validatedParams.subAction === 'test')
@@ -226,7 +227,7 @@ export const getGenAiTokenTracking = async ({
     };
   }
 
-  // this is a non-streamed Bedrock response used by security solution
+  // non-streamed Bedrock response used by chat model ActionsClientBedrockChatModel
   if (actionTypeId === '.bedrock' && validatedParams.subAction === 'invokeAI') {
     try {
       const rData = result.data as unknown as {
@@ -264,6 +265,42 @@ export const getGenAiTokenTracking = async ({
       // silently fail and null is returned at bottom of function
     }
   }
+  // non-streamed Bedrock response used by chat model ActionsClientChatBedrockConverse
+  if (actionTypeId === '.bedrock' && validatedParams.subAction === 'converse') {
+    const { usage } = result.data as unknown as {
+      usage: { inputTokens: number; outputTokens: number; totalTokens: number };
+    };
+
+    if (usage) {
+      return {
+        total_tokens: usage.totalTokens,
+        prompt_tokens: usage.inputTokens,
+        completion_tokens: usage.outputTokens,
+      };
+    } else {
+      logger.error('Response from Bedrock converse API did not contain usage object');
+      return {
+        total_tokens: 0,
+        prompt_tokens: 0,
+        completion_tokens: 0,
+      };
+    }
+  }
+  // streamed Bedrock response used by chat model ActionsClientChatBedrockConverse
+  if (
+    actionTypeId === '.bedrock' &&
+    result.data instanceof Readable &&
+    validatedParams.subAction === 'converseStream'
+  ) {
+    try {
+      const converseTokens = await parseBedrockConverseStream(result.data.pipe(new PassThrough()));
+      return converseTokens;
+    } catch (e) {
+      logger.error('Failed to calculate tokens from converseStream subaction streaming response');
+      logger.error(e);
+      // silently fail and null is returned at bottom of function
+    }
+  }
   return null;
 };
 
diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.test.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.test.ts
index 3a2cabbb1b0e4..f484958959c49 100644
--- a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.test.ts
+++ b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.test.ts
@@ -7,6 +7,7 @@
 import { Transform } from 'stream';
 import {
   getTokenCountFromInvokeStream,
+  parseBedrockConverseStream,
   parseGeminiStreamForUsageMetadata,
 } from './get_token_count_from_invoke_stream';
 import { loggerMock } from '@kbn/logging-mocks';
@@ -149,6 +150,43 @@ describe('getTokenCountFromInvokeStream', () => {
     stream.write(encodeBedrockResponse('Simple.'));
   });
 
+  it('parses the Bedrock converse stream and returns token counts', async () => {
+    const usageData = {
+      usage: {
+        totalTokens: 100,
+        inputTokens: 40,
+        outputTokens: 60,
+      },
+    };
+    const encodedUsageData = new EventStreamCodec(toUtf8, fromUtf8).encode({
+      headers: {},
+      body: Uint8Array.from(Buffer.from(JSON.stringify(usageData))),
+    });
+
+    stream.write(encodedUsageData);
+    stream.complete();
+
+    const tokens = await parseBedrockConverseStream(stream.transform);
+    expect(tokens).toEqual({
+      total_tokens: 100,
+      prompt_tokens: 40,
+      completion_tokens: 60,
+    });
+  });
+
+  it('returns null if the converse stream usage object is not present', async () => {
+    const invalidData = new EventStreamCodec(toUtf8, fromUtf8).encode({
+      headers: {},
+      body: Uint8Array.from(Buffer.from(JSON.stringify({}))),
+    });
+
+    stream.write(invalidData);
+    stream.complete();
+
+    const tokens = await parseBedrockConverseStream(stream.transform);
+    expect(tokens).toBeNull();
+  });
+
   it('calculates from the usage object when latest api is used', async () => {
     stream = createStreamMock();
     stream.write(
diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts
index 909b7e09abda0..8ca0035d4ecb0 100644
--- a/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts
+++ b/x-pack/plugins/actions/server/lib/get_token_count_from_invoke_stream.ts
@@ -105,6 +105,43 @@ const parseBedrockStream: StreamParser = async (responseStream, logger) => {
   return parseBedrockBuffer(responseBuffer);
 };
 
+export const parseBedrockConverseStream = async (
+  responseStream: Readable
+): Promise<{
+  total_tokens: number;
+  prompt_tokens: number;
+  completion_tokens: number;
+} | null> => {
+  const responseBuffer: Uint8Array[] = [];
+  // do not destroy response stream on abort for bedrock
+  // Amazon charges the same tokens whether the stream is destroyed or not, so let it finish to calculate
+  responseStream.on('data', (chunk) => {
+    // special encoding for bedrock, do not attempt to convert to string
+    responseBuffer.push(chunk);
+  });
+  await finished(responseStream);
+  const awsDecoder = new EventStreamCodec(toUtf8, fromUtf8);
+
+  // the converse stream reports its usage metadata in the final event
+  const finalChunk = responseBuffer[responseBuffer.length - 1];
+  try {
+    const event = awsDecoder.decode(finalChunk);
+    const decoded = new TextDecoder().decode(event.body);
+    const usage = JSON.parse(decoded).usage;
+    return usage
+      ? {
+          total_tokens: usage.totalTokens,
+          prompt_tokens: usage.inputTokens,
+          completion_tokens: usage.outputTokens,
+        }
+      : null;
+  } catch (e) {
+    // Tool response does not contain usage object, thus parsing throws an error
+    // return null as token from tool usage will be included in the final response of the stream
+    return null;
+  }
+};
+
 const parseOpenAIStream: StreamParser = async (responseStream, logger, signal) => {
   let responseBody: string = '';
   const destroyStream = () => {