[inference] add support for openAI native stream token count (#200745)
## Summary

Fix #192962

Add support for OpenAI's native token counts for streaming APIs.

This is done by adding the `stream_options: {"include_usage": true}`
parameter when `stream: true` is used
([doc](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options)),
and then reading the `usage` entry from the last emitted chunk.
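For reference, a minimal sketch of what this looks like on the wire, assuming the standard OpenAI chat completions API (the values are only examples, not code from this PR):

```ts
// Illustrative sketch only — not the connector implementation.
// Streaming request body with native usage reporting enabled.
const requestBody = {
  model: 'gpt-4o',
  messages: [{ role: 'user', content: 'Hello!' }],
  stream: true,
  // Asks the API to append one final chunk carrying the token usage.
  stream_options: { include_usage: true },
};

// Shape of that final chunk: `choices` is empty and `usage` holds the
// aggregated token counts for the whole exchange.
const finalChunk = {
  object: 'chat.completion.chunk',
  choices: [],
  usage: { prompt_tokens: 50, completion_tokens: 100, total_tokens: 150 },
};
```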

**Note**: this was done only for the `OpenAI` and `AzureAI`
[providers](https://github.com/elastic/kibana/blob/83a701e837a7a84a86dcc8d359154f900f69676a/x-pack/plugins/stack_connectors/common/openai/constants.ts#L27-L31),
and **not** for the `Other` provider. The reasoning is that not all
OpenAI """compatible""" providers fully support all options, so I didn't
want to risk adding a parameter that could cause some models behind an
OpenAI adapter to reject the request. This is also why I did not change
the way the
[getTokenCountFromOpenAIStream](https://github.com/elastic/kibana/blob/8bffd618059aacc30d6190a0d143d8b0c7217faf/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.ts#L15)
function works, as we want it to keep working for all providers.
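As a rough sketch of that provider gating, assuming a hypothetical helper and provider-type union (the identifiers are illustrative, not the actual connector code):

```ts
// Hypothetical sketch — the provider union and helper name are illustrative,
// not the actual connector code from this PR.
type OpenAiProviderType = 'OpenAI' | 'AzureAI' | 'Other';

interface StreamingRequestBody {
  stream?: boolean;
  stream_options?: { include_usage: boolean };
  [key: string]: unknown;
}

function withNativeUsageReporting(
  body: StreamingRequestBody,
  provider: OpenAiProviderType
): StreamingRequestBody {
  // Only the official OpenAI and Azure OpenAI providers are assumed to accept
  // `stream_options`; the `Other` (openAI-compatible) provider is left untouched
  // so an unsupported parameter cannot cause its requests to be rejected.
  if (body.stream && (provider === 'OpenAI' || provider === 'AzureAI')) {
    return { ...body, stream_options: { include_usage: true } };
  }
  return body;
}
```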

---------

Co-authored-by: Elastic Machine <[email protected]>
pgayvallet and elasticmachine authored Nov 20, 2024
1 parent 7c02b32 commit 67171e1
Showing 9 changed files with 395 additions and 138 deletions.
@@ -61,6 +61,16 @@ describe('getTokenCountFromOpenAIStream', () => {
],
};

  const usageChunk = {
    object: 'chat.completion.chunk',
    choices: [],
    usage: {
      prompt_tokens: 50,
      completion_tokens: 100,
      total_tokens: 150,
    },
  };

const PROMPT_TOKEN_COUNT = 36;
const COMPLETION_TOKEN_COUNT = 5;

@@ -70,55 +80,79 @@ describe('getTokenCountFromOpenAIStream', () => {
});

describe('when a stream completes', () => {
beforeEach(async () => {
stream.write('data: [DONE]');
stream.complete();
});
describe('with usage chunk', () => {
it('returns the counts from the usage chunk', async () => {
stream = createStreamMock();
stream.write(`data: ${JSON.stringify(chunk)}`);
stream.write(`data: ${JSON.stringify(usageChunk)}`);
stream.write('data: [DONE]');
stream.complete();

describe('without function tokens', () => {
beforeEach(async () => {
tokens = await getTokenCountFromOpenAIStream({
responseStream: stream.transform,
logger,
body: JSON.stringify(body),
});
});

it('counts the prompt tokens', () => {
expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT);
expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT);
expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT);
expect(tokens).toEqual({
prompt: usageChunk.usage.prompt_tokens,
completion: usageChunk.usage.completion_tokens,
total: usageChunk.usage.total_tokens,
});
});
});

describe('with function tokens', () => {
describe('without usage chunk', () => {
beforeEach(async () => {
tokens = await getTokenCountFromOpenAIStream({
responseStream: stream.transform,
logger,
body: JSON.stringify({
...body,
functions: [
{
name: 'my_function',
description: 'My function description',
parameters: {
type: 'object',
properties: {
my_property: {
type: 'boolean',
description: 'My function property',
stream.write('data: [DONE]');
stream.complete();
});

describe('without function tokens', () => {
beforeEach(async () => {
tokens = await getTokenCountFromOpenAIStream({
responseStream: stream.transform,
logger,
body: JSON.stringify(body),
});
});

it('counts the prompt tokens', () => {
expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT);
expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT);
expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT);
});
});

describe('with function tokens', () => {
beforeEach(async () => {
tokens = await getTokenCountFromOpenAIStream({
responseStream: stream.transform,
logger,
body: JSON.stringify({
...body,
functions: [
{
name: 'my_function',
description: 'My function description',
parameters: {
type: 'object',
properties: {
my_property: {
type: 'boolean',
description: 'My function property',
},
},
},
},
},
],
}),
],
}),
});
});
});

it('counts the function tokens', () => {
expect(tokens.prompt).toBeGreaterThan(PROMPT_TOKEN_COUNT);
it('counts the function tokens', () => {
expect(tokens.prompt).toBeGreaterThan(PROMPT_TOKEN_COUNT);
});
});
});
});
@@ -25,44 +25,7 @@ export async function getTokenCountFromOpenAIStream({
prompt: number;
completion: number;
}> {
const chatCompletionRequest = JSON.parse(
body
) as OpenAI.ChatCompletionCreateParams.ChatCompletionCreateParamsStreaming;

// per https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
const tokensFromMessages = encode(
chatCompletionRequest.messages
.map(
(msg) =>
`<|start|>${msg.role}\n${msg.content}\n${
'name' in msg
? msg.name
: 'function_call' in msg && msg.function_call
? msg.function_call.name + '\n' + msg.function_call.arguments
: ''
}<|end|>`
)
.join('\n')
).length;

// this is an approximation. OpenAI cuts off a function schema
// at a certain level of nesting, so their token count might
// be lower than what we are calculating here.

const tokensFromFunctions = chatCompletionRequest.functions
? encode(
chatCompletionRequest.functions
?.map(
(fn) =>
`<|start|>${fn.name}\n${fn.description}\n${JSON.stringify(fn.parameters)}<|end|>`
)
.join('\n')
).length
: 0;

const promptTokens = tokensFromMessages + tokensFromFunctions;

let responseBody: string = '';
let responseBody = '';

responseStream.on('data', (chunk: string) => {
responseBody += chunk.toString();
@@ -74,39 +37,64 @@
logger.error('An error occurred while calculating streaming response tokens');
}

const response = responseBody
let completionUsage: OpenAI.CompletionUsage | undefined;

const response: ParsedResponse = responseBody
.split('\n')
.filter((line) => {
return line.startsWith('data: ') && !line.endsWith('[DONE]');
})
.map((line) => {
return JSON.parse(line.replace('data: ', ''));
})
.filter(
(
line
): line is {
choices: Array<{
delta: { content?: string; function_call?: { name?: string; arguments: string } };
}>;
} => {
return (
'object' in line && line.object === 'chat.completion.chunk' && line.choices.length > 0
);
}
)
.filter((line): line is OpenAI.ChatCompletionChunk => {
return 'object' in line && line.object === 'chat.completion.chunk';
})
.reduce(
(prev, line) => {
const msg = line.choices[0].delta!;
prev.content += msg.content || '';
prev.function_call.name += msg.function_call?.name || '';
prev.function_call.arguments += msg.function_call?.arguments || '';
if (line.usage) {
completionUsage = line.usage;
}
if (line.choices?.length) {
const msg = line.choices[0].delta!;
prev.content += msg.content || '';
prev.function_call.name += msg.function_call?.name || '';
prev.function_call.arguments += msg.function_call?.arguments || '';
}
return prev;
},
{ content: '', function_call: { name: '', arguments: '' } }
);

const completionTokens = encode(
// not all openAI compatible providers emit completion chunk, so we still have to support
// manually counting the tokens
if (completionUsage) {
return {
prompt: completionUsage.prompt_tokens,
completion: completionUsage.completion_tokens,
total: completionUsage.total_tokens,
};
} else {
const promptTokens = manuallyCountPromptTokens(body);
const completionTokens = manuallyCountCompletionTokens(response);
return {
prompt: promptTokens,
completion: completionTokens,
total: promptTokens + completionTokens,
};
}
}

interface ParsedResponse {
content: string;
function_call: {
name: string;
arguments: string;
};
}

const manuallyCountCompletionTokens = (response: ParsedResponse) => {
return encode(
JSON.stringify(
omitBy(
{
@@ -117,10 +105,42 @@
)
)
).length;
};

return {
prompt: promptTokens,
completion: completionTokens,
total: promptTokens + completionTokens,
};
}
const manuallyCountPromptTokens = (requestBody: string) => {
const chatCompletionRequest: OpenAI.ChatCompletionCreateParams.ChatCompletionCreateParamsStreaming =
JSON.parse(requestBody);

// per https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
const tokensFromMessages = encode(
chatCompletionRequest.messages
.map(
(msg) =>
`<|start|>${msg.role}\n${msg.content}\n${
'name' in msg
? msg.name
: 'function_call' in msg && msg.function_call
? msg.function_call.name + '\n' + msg.function_call.arguments
: ''
}<|end|>`
)
.join('\n')
).length;

// this is an approximation. OpenAI cuts off a function schema
// at a certain level of nesting, so their token count might
// be lower than what we are calculating here.

const tokensFromFunctions = chatCompletionRequest.functions
? encode(
chatCompletionRequest.functions
?.map(
(fn) =>
`<|start|>${fn.name}\n${fn.description}\n${JSON.stringify(fn.parameters)}<|end|>`
)
.join('\n')
).length
: 0;

return tokensFromMessages + tokensFromFunctions;
};
