[8.x] [inference] add support for openAI native stream token count (#200745) (#201007)

# Backport

This will backport the following commits from `main` to `8.x`:
- [[inference] add support for openAI native stream token count
(#200745)](#200745)

<!--- Backport version: 9.4.3 -->

### Questions ?
Please refer to the [Backport tool
documentation](https://github.com/sqren/backport)

<!--BACKPORT [{"author":{"name":"Pierre
Gayvallet","email":"[email protected]"},"sourceCommit":{"committedDate":"2024-11-20T16:53:44Z","message":"[inference]
add support for openAI native stream token count (#200745)\n\n##
Summary\r\n\r\nFix
https://github.com/elastic/kibana/issues/192962\r\n\r\nAdd support for
native openAI token count for streaming APIs.\r\n\r\nThis is done by
adding the `stream_options: {\"include_usage\": true}`\r\nparameter when
`stream: true` is being
used\r\n([doc](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options)),\r\nand
then using the `usage` entry for the last emitted
chunk.\r\n\r\n**Note**: this was done only for the `OpenAI` and
`AzureAI`\r\n[providers](https://github.com/elastic/kibana/blob/83a701e837a7a84a86dcc8d359154f900f69676a/x-pack/plugins/stack_connectors/common/openai/constants.ts#L27-L31),\r\nand
**not** for the `Other` provider. The reasoning is that not
all\r\nopenAI \"\"\"compatible\"\"\" providers fully support all
options, so I didn't\r\nwant to risk adding a parameter that could cause
some models using an\r\nopenAI adapter to reject the requests. This is
also the reason why I did\r\nnot change the
way\r\n[getTokenCountFromOpenAIStream](https://github.com/elastic/kibana/blob/8bffd618059aacc30d6190a0d143d8b0c7217faf/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.ts#L15)\r\nfunction,
as we want that to work for all
providers.\r\n\r\n---------\r\n\r\nCo-authored-by: Elastic Machine
<[email protected]>","sha":"67171e15c2bd9063059701c4974f76f480ccd538","branchLabelMapping":{"^v9.0.0$":"main","^v8.17.0$":"8.x","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","v9.0.0","backport:prev-minor","Team:AI
Infra"],"title":"[inference] add support for openAI native stream token
count","number":200745,"url":"https://github.com/elastic/kibana/pull/200745","mergeCommit":{"message":"[inference]
add support for openAI native stream token count (#200745)\n\n##
Summary\r\n\r\nFix
https://github.com/elastic/kibana/issues/192962\r\n\r\nAdd support for
native openAI token count for streaming APIs.\r\n\r\nThis is done by
adding the `stream_options: {\"include_usage\": true}`\r\nparameter when
`stream: true` is being
used\r\n([doc](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options)),\r\nand
then using the `usage` entry for the last emitted
chunk.\r\n\r\n**Note**: this was done only for the `OpenAI` and
`AzureAI`\r\n[providers](https://github.com/elastic/kibana/blob/83a701e837a7a84a86dcc8d359154f900f69676a/x-pack/plugins/stack_connectors/common/openai/constants.ts#L27-L31),\r\nand
**not** for the `Other` provider. The reasoning is that not
all\r\nopenAI \"\"\"compatible\"\"\" providers fully support all
options, so I didn't\r\nwant to risk adding a parameter that could cause
some models using an\r\nopenAI adapter to reject the requests. This is
also the reason why I did\r\nnot change the
way\r\n[getTokenCountFromOpenAIStream](https://github.com/elastic/kibana/blob/8bffd618059aacc30d6190a0d143d8b0c7217faf/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.ts#L15)\r\nfunction,
as we want that to work for all
providers.\r\n\r\n---------\r\n\r\nCo-authored-by: Elastic Machine
<[email protected]>","sha":"67171e15c2bd9063059701c4974f76f480ccd538"}},"sourceBranch":"main","suggestedTargetBranches":[],"targetPullRequestStates":[{"branch":"main","label":"v9.0.0","branchLabelMappingKey":"^v9.0.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/200745","number":200745,"mergeCommit":{"message":"[inference]
add support for openAI native stream token count (#200745)\n\n##
Summary\r\n\r\nFix
https://github.com/elastic/kibana/issues/192962\r\n\r\nAdd support for
native openAI token count for streaming APIs.\r\n\r\nThis is done by
adding the `stream_options: {\"include_usage\": true}`\r\nparameter when
`stream: true` is being
used\r\n([doc](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options)),\r\nand
then using the `usage` entry for the last emitted
chunk.\r\n\r\n**Note**: this was done only for the `OpenAI` and
`AzureAI`\r\n[providers](https://github.com/elastic/kibana/blob/83a701e837a7a84a86dcc8d359154f900f69676a/x-pack/plugins/stack_connectors/common/openai/constants.ts#L27-L31),\r\nand
**not** for the `Other` provider. The reasoning is that not
all\r\nopenAI \"\"\"compatible\"\"\" providers fully support all
options, so I didn't\r\nwant to risk adding a parameter that could cause
some models using an\r\nopenAI adapter to reject the requests. This is
also the reason why I did\r\nnot change the
way\r\n[getTokenCountFromOpenAIStream](https://github.com/elastic/kibana/blob/8bffd618059aacc30d6190a0d143d8b0c7217faf/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.ts#L15)\r\nfunction,
as we want that to work for all
providers.\r\n\r\n---------\r\n\r\nCo-authored-by: Elastic Machine
<[email protected]>","sha":"67171e15c2bd9063059701c4974f76f480ccd538"}}]}]
BACKPORT-->

Co-authored-by: Pierre Gayvallet <[email protected]>
kibanamachine and pgayvallet authored Nov 20, 2024
1 parent 43d4730 commit 63f1de7
Showing 9 changed files with 395 additions and 138 deletions.
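The request-side change described in the commit message can be sketched as follows. This is only an illustration under stated assumptions: `withNativeTokenCount` and `ChatRequestBody` are hypothetical names that do not come from the Kibana source, and the provider names are copied from the message (`OpenAI`, `AzureAI`, `Other`) rather than from the actual constants file. It shows `stream_options: { include_usage: true }` being added only when `stream: true` is set and the provider is one of the two the message names.

```typescript
// Provider names as written in the commit message; the real constants may differ.
type Provider = 'OpenAI' | 'AzureAI' | 'Other';

// Hypothetical, trimmed-down request shape used only for this sketch.
interface ChatRequestBody {
  model: string;
  messages: Array<{ role: string; content: string }>;
  stream?: boolean;
  stream_options?: { include_usage: boolean };
}

// Adds the usage option only for streaming requests sent to providers that
// are expected to accept it. Other "compatible" providers get the body
// unchanged, so an unknown parameter cannot cause the request to be rejected.
function withNativeTokenCount(provider: Provider, body: ChatRequestBody): ChatRequestBody {
  const supportsUsage = provider === 'OpenAI' || provider === 'AzureAI';
  if (!body.stream || !supportsUsage) {
    return body;
  }
  return { ...body, stream_options: { include_usage: true } };
}
```

Leaving the `Other` provider untouched matches the reasoning in the message: the manual token count remains the fallback for providers that never emit a usage chunk.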
@@ -61,6 +61,16 @@ describe('getTokenCountFromOpenAIStream', () => {
     ],
   };

+  const usageChunk = {
+    object: 'chat.completion.chunk',
+    choices: [],
+    usage: {
+      prompt_tokens: 50,
+      completion_tokens: 100,
+      total_tokens: 150,
+    },
+  };
+
   const PROMPT_TOKEN_COUNT = 36;
   const COMPLETION_TOKEN_COUNT = 5;

@@ -70,55 +80,79 @@ describe('getTokenCountFromOpenAIStream', () => {
   });

   describe('when a stream completes', () => {
-    beforeEach(async () => {
-      stream.write('data: [DONE]');
-      stream.complete();
-    });
+    describe('with usage chunk', () => {
+      it('returns the counts from the usage chunk', async () => {
+        stream = createStreamMock();
+        stream.write(`data: ${JSON.stringify(chunk)}`);
+        stream.write(`data: ${JSON.stringify(usageChunk)}`);
+        stream.write('data: [DONE]');
+        stream.complete();

-    describe('without function tokens', () => {
-      beforeEach(async () => {
         tokens = await getTokenCountFromOpenAIStream({
           responseStream: stream.transform,
           logger,
           body: JSON.stringify(body),
         });
-      });

-      it('counts the prompt tokens', () => {
-        expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT);
-        expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT);
-        expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT);
+        expect(tokens).toEqual({
+          prompt: usageChunk.usage.prompt_tokens,
+          completion: usageChunk.usage.completion_tokens,
+          total: usageChunk.usage.total_tokens,
+        });
       });
     });

-    describe('with function tokens', () => {
+    describe('without usage chunk', () => {
       beforeEach(async () => {
-        tokens = await getTokenCountFromOpenAIStream({
-          responseStream: stream.transform,
-          logger,
-          body: JSON.stringify({
-            ...body,
-            functions: [
-              {
-                name: 'my_function',
-                description: 'My function description',
-                parameters: {
-                  type: 'object',
-                  properties: {
-                    my_property: {
-                      type: 'boolean',
-                      description: 'My function property',
+        stream.write('data: [DONE]');
+        stream.complete();
+      });
+
+      describe('without function tokens', () => {
+        beforeEach(async () => {
+          tokens = await getTokenCountFromOpenAIStream({
+            responseStream: stream.transform,
+            logger,
+            body: JSON.stringify(body),
+          });
+        });
+
+        it('counts the prompt tokens', () => {
+          expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT);
+          expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT);
+          expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT);
+        });
+      });
+
+      describe('with function tokens', () => {
+        beforeEach(async () => {
+          tokens = await getTokenCountFromOpenAIStream({
+            responseStream: stream.transform,
+            logger,
+            body: JSON.stringify({
+              ...body,
+              functions: [
+                {
+                  name: 'my_function',
+                  description: 'My function description',
+                  parameters: {
+                    type: 'object',
+                    properties: {
+                      my_property: {
+                        type: 'boolean',
+                        description: 'My function property',
+                      },
                     },
                   },
                 },
-              },
-            ],
-          }),
+              ],
+            }),
+          });
         });
-      });

-      it('counts the function tokens', () => {
-        expect(tokens.prompt).toBeGreaterThan(PROMPT_TOKEN_COUNT);
+        it('counts the function tokens', () => {
+          expect(tokens.prompt).toBeGreaterThan(PROMPT_TOKEN_COUNT);
+        });
       });
     });
   });
@@ -25,44 +25,7 @@ export async function getTokenCountFromOpenAIStream({
   prompt: number;
   completion: number;
 }> {
-  const chatCompletionRequest = JSON.parse(
-    body
-  ) as OpenAI.ChatCompletionCreateParams.ChatCompletionCreateParamsStreaming;
-
-  // per https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-  const tokensFromMessages = encode(
-    chatCompletionRequest.messages
-      .map(
-        (msg) =>
-          `<|start|>${msg.role}\n${msg.content}\n${
-            'name' in msg
-              ? msg.name
-              : 'function_call' in msg && msg.function_call
-              ? msg.function_call.name + '\n' + msg.function_call.arguments
-              : ''
-          }<|end|>`
-      )
-      .join('\n')
-  ).length;
-
-  // this is an approximation. OpenAI cuts off a function schema
-  // at a certain level of nesting, so their token count might
-  // be lower than what we are calculating here.
-
-  const tokensFromFunctions = chatCompletionRequest.functions
-    ? encode(
-        chatCompletionRequest.functions
-          ?.map(
-            (fn) =>
-              `<|start|>${fn.name}\n${fn.description}\n${JSON.stringify(fn.parameters)}<|end|>`
-          )
-          .join('\n')
-      ).length
-    : 0;
-
-  const promptTokens = tokensFromMessages + tokensFromFunctions;
-
-  let responseBody: string = '';
+  let responseBody = '';

   responseStream.on('data', (chunk: string) => {
     responseBody += chunk.toString();
@@ -74,39 +37,64 @@ export async function getTokenCountFromOpenAIStream({
     logger.error('An error occurred while calculating streaming response tokens');
   }

-  const response = responseBody
+  let completionUsage: OpenAI.CompletionUsage | undefined;
+
+  const response: ParsedResponse = responseBody
     .split('\n')
     .filter((line) => {
       return line.startsWith('data: ') && !line.endsWith('[DONE]');
     })
     .map((line) => {
       return JSON.parse(line.replace('data: ', ''));
     })
-    .filter(
-      (
-        line
-      ): line is {
-        choices: Array<{
-          delta: { content?: string; function_call?: { name?: string; arguments: string } };
-        }>;
-      } => {
-        return (
-          'object' in line && line.object === 'chat.completion.chunk' && line.choices.length > 0
-        );
-      }
-    )
+    .filter((line): line is OpenAI.ChatCompletionChunk => {
+      return 'object' in line && line.object === 'chat.completion.chunk';
+    })
     .reduce(
       (prev, line) => {
-        const msg = line.choices[0].delta!;
-        prev.content += msg.content || '';
-        prev.function_call.name += msg.function_call?.name || '';
-        prev.function_call.arguments += msg.function_call?.arguments || '';
+        if (line.usage) {
+          completionUsage = line.usage;
+        }
+        if (line.choices?.length) {
+          const msg = line.choices[0].delta!;
+          prev.content += msg.content || '';
+          prev.function_call.name += msg.function_call?.name || '';
+          prev.function_call.arguments += msg.function_call?.arguments || '';
+        }
         return prev;
       },
       { content: '', function_call: { name: '', arguments: '' } }
     );

-  const completionTokens = encode(
+  // not all openAI compatible providers emit completion chunk, so we still have to support
+  // manually counting the tokens
+  if (completionUsage) {
+    return {
+      prompt: completionUsage.prompt_tokens,
+      completion: completionUsage.completion_tokens,
+      total: completionUsage.total_tokens,
+    };
+  } else {
+    const promptTokens = manuallyCountPromptTokens(body);
+    const completionTokens = manuallyCountCompletionTokens(response);
+    return {
+      prompt: promptTokens,
+      completion: completionTokens,
+      total: promptTokens + completionTokens,
+    };
+  }
+}
+
+interface ParsedResponse {
+  content: string;
+  function_call: {
+    name: string;
+    arguments: string;
+  };
+}
+
+const manuallyCountCompletionTokens = (response: ParsedResponse) => {
+  return encode(
     JSON.stringify(
       omitBy(
         {
@@ -117,10 +105,42 @@ export async function getTokenCountFromOpenAIStream({
       )
     )
   ).length;
+};

-  return {
-    prompt: promptTokens,
-    completion: completionTokens,
-    total: promptTokens + completionTokens,
-  };
-}
+const manuallyCountPromptTokens = (requestBody: string) => {
+  const chatCompletionRequest: OpenAI.ChatCompletionCreateParams.ChatCompletionCreateParamsStreaming =
+    JSON.parse(requestBody);
+
+  // per https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+  const tokensFromMessages = encode(
+    chatCompletionRequest.messages
+      .map(
+        (msg) =>
+          `<|start|>${msg.role}\n${msg.content}\n${
+            'name' in msg
+              ? msg.name
+              : 'function_call' in msg && msg.function_call
+              ? msg.function_call.name + '\n' + msg.function_call.arguments
+              : ''
+          }<|end|>`
+      )
+      .join('\n')
+  ).length;
+
+  // this is an approximation. OpenAI cuts off a function schema
+  // at a certain level of nesting, so their token count might
+  // be lower than what we are calculating here.
+
+  const tokensFromFunctions = chatCompletionRequest.functions
+    ? encode(
+        chatCompletionRequest.functions
+          ?.map(
+            (fn) =>
+              `<|start|>${fn.name}\n${fn.description}\n${JSON.stringify(fn.parameters)}<|end|>`
+          )
+          .join('\n')
+      ).length
+    : 0;
+
+  return tokensFromMessages + tokensFromFunctions;
+};
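
The "prefer the usage chunk, otherwise count manually" flow in the diff above can be reduced to a standalone sketch. The names `countTokensFromSse`, `SseChunk`, and `TokenCount` are hypothetical, and the manual counter is passed in as a `fallback` callback instead of calling the tokenizer, so this is not the Kibana implementation, only a minimal illustration of the control flow under those assumptions.

```typescript
interface Usage {
  prompt_tokens: number;
  completion_tokens: number;
  total_tokens: number;
}

// Hypothetical, trimmed-down chunk shape for this sketch.
interface SseChunk {
  object?: string;
  choices?: Array<{ delta?: { content?: string } }>;
  usage?: Usage | null;
}

interface TokenCount {
  prompt: number;
  completion: number;
  total: number;
}

// Walks the buffered server-sent-events body line by line. If any chunk
// carries a `usage` entry, its counts are returned as-is; otherwise the
// accumulated completion text is handed to the caller-provided fallback.
function countTokensFromSse(
  responseBody: string,
  fallback: (completionText: string) => TokenCount
): TokenCount {
  let usage: Usage | undefined;
  let content = '';

  for (const line of responseBody.split('\n')) {
    // Skip non-data lines and the terminating "[DONE]" marker.
    if (!line.startsWith('data: ') || line.endsWith('[DONE]')) continue;

    const chunk = JSON.parse(line.slice('data: '.length)) as SseChunk;
    if (chunk.object !== 'chat.completion.chunk') continue;

    if (chunk.usage) usage = chunk.usage;
    // The usage chunk has an empty `choices` array, so guard before reading it.
    if (chunk.choices && chunk.choices.length > 0) {
      const delta = chunk.choices[0].delta;
      content += (delta && delta.content) || '';
    }
  }

  if (usage) {
    return {
      prompt: usage.prompt_tokens,
      completion: usage.completion_tokens,
      total: usage.total_tokens,
    };
  }
  return fallback(content);
}
```

Passing the fallback as a parameter keeps the sketch free of tokenizer dependencies; in the commit itself that role is played by `manuallyCountPromptTokens` and `manuallyCountCompletionTokens`.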