From 67171e15c2bd9063059701c4974f76f480ccd538 Mon Sep 17 00:00:00 2001
From: Pierre Gayvallet
Date: Wed, 20 Nov 2024 17:53:44 +0100
Subject: [PATCH] [inference] add support for openAI native stream token count (#200745)

## Summary

Fix https://github.com/elastic/kibana/issues/192962

Add support for native openAI token count for streaming APIs.

This is done by adding the `stream_options: {"include_usage": true}` parameter when `stream: true` is being used ([doc](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options)), and then using the `usage` entry from the last emitted chunk.

**Note**: this was done only for the `OpenAI` and `AzureAI` [providers](https://github.com/elastic/kibana/blob/83a701e837a7a84a86dcc8d359154f900f69676a/x-pack/plugins/stack_connectors/common/openai/constants.ts#L27-L31), and **not** for the `Other` provider. The reasoning is that not all openAI """compatible""" providers fully support all options, so I didn't want to risk adding a parameter that could cause some models using an openAI adapter to reject the requests. This is also the reason why I did not change the way the [getTokenCountFromOpenAIStream](https://github.com/elastic/kibana/blob/8bffd618059aacc30d6190a0d143d8b0c7217faf/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.ts#L15) function works, as we want it to keep working for all providers.

---------

Co-authored-by: Elastic Machine
---
 ...get_token_count_from_openai_stream.test.ts | 100 ++++++++----
 .../lib/get_token_count_from_openai_stream.ts | 146 ++++++++++--------
 .../adapters/openai/openai_adapter.test.ts    |  70 +++++++--
 .../adapters/openai/openai_adapter.ts         |  88 +++++++----
 .../openai/lib/azure_openai_utils.test.ts     |  47 +++++-
 .../openai/lib/azure_openai_utils.ts          |   5 +
 .../openai/lib/openai_utils.test.ts           |  60 ++++++-
 .../openai/lib/openai_utils.ts                |   5 +
 .../connector_types/openai/openai.test.ts     |  12 +-
 9 files changed, 395 insertions(+), 138 deletions(-)

diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.test.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.test.ts
index cc81706fc257c..a1bc118066b9d 100644
--- a/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.test.ts
+++ b/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.test.ts
@@ -61,6 +61,16 @@ describe('getTokenCountFromOpenAIStream', () => {
     ],
   };
 
+  const usageChunk = {
+    object: 'chat.completion.chunk',
+    choices: [],
+    usage: {
+      prompt_tokens: 50,
+      completion_tokens: 100,
+      total_tokens: 150,
+    },
+  };
+
   const PROMPT_TOKEN_COUNT = 36;
   const COMPLETION_TOKEN_COUNT = 5;
 
@@ -70,55 +80,79 @@ describe('getTokenCountFromOpenAIStream', () => {
   beforeEach(() => {
   });
 
   describe('when a stream completes', () => {
-    beforeEach(async () => {
-      stream.write('data: [DONE]');
-      stream.complete();
-    });
+    describe('with usage chunk', () => {
+      it('returns the counts from the usage chunk', async () => {
+        stream = createStreamMock();
+        stream.write(`data: ${JSON.stringify(chunk)}`);
+        stream.write(`data: ${JSON.stringify(usageChunk)}`);
+        stream.write('data: [DONE]');
+        stream.complete();
 
-    describe('without function tokens', () => {
-      beforeEach(async () => {
         tokens = await getTokenCountFromOpenAIStream({
           responseStream: stream.transform,
           logger,
           body: JSON.stringify(body),
         });
-      });
 
-      it('counts the prompt tokens', () => {
-        expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT);
-        expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT);
-        expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT);
+        expect(tokens).toEqual({
+          prompt: usageChunk.usage.prompt_tokens,
+          completion: usageChunk.usage.completion_tokens,
+          total: usageChunk.usage.total_tokens,
+        });
       });
     });
 
-    describe('with function tokens', () => {
+    describe('without usage chunk', () => {
       beforeEach(async () => {
-        tokens = await getTokenCountFromOpenAIStream({
-          responseStream: stream.transform,
-          logger,
-          body: JSON.stringify({
-            ...body,
-            functions: [
-              {
-                name: 'my_function',
-                description: 'My function description',
-                parameters: {
-                  type: 'object',
-                  properties: {
-                    my_property: {
-                      type: 'boolean',
-                      description: 'My function property',
+        stream.write('data: [DONE]');
+        stream.complete();
+      });
+
+      describe('without function tokens', () => {
+        beforeEach(async () => {
+          tokens = await getTokenCountFromOpenAIStream({
+            responseStream: stream.transform,
+            logger,
+            body: JSON.stringify(body),
+          });
+        });
+
+        it('counts the prompt tokens', () => {
+          expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT);
+          expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT);
+          expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT);
+        });
+      });
+
+      describe('with function tokens', () => {
+        beforeEach(async () => {
+          tokens = await getTokenCountFromOpenAIStream({
+            responseStream: stream.transform,
+            logger,
+            body: JSON.stringify({
+              ...body,
+              functions: [
+                {
+                  name: 'my_function',
+                  description: 'My function description',
+                  parameters: {
+                    type: 'object',
+                    properties: {
+                      my_property: {
+                        type: 'boolean',
+                        description: 'My function property',
+                      },
                     },
                   },
                 },
-              },
-            ],
-          }),
+              ],
+            }),
+          });
         });
-      });
 
-      it('counts the function tokens', () => {
-        expect(tokens.prompt).toBeGreaterThan(PROMPT_TOKEN_COUNT);
+        it('counts the function tokens', () => {
+          expect(tokens.prompt).toBeGreaterThan(PROMPT_TOKEN_COUNT);
+        });
       });
     });
   });
diff --git a/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.ts b/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.ts
index 790a59fe6097a..5c19a23e6d230 100644
--- a/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.ts
+++ b/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.ts
@@ -25,44 +25,7 @@ export async function getTokenCountFromOpenAIStream({
   prompt: number;
   completion: number;
 }> {
-  const chatCompletionRequest = JSON.parse(
-    body
-  ) as OpenAI.ChatCompletionCreateParams.ChatCompletionCreateParamsStreaming;
-
-  // per https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-  const tokensFromMessages = encode(
-    chatCompletionRequest.messages
-      .map(
-        (msg) =>
-          `<|start|>${msg.role}\n${msg.content}\n${
-            'name' in msg
-              ? msg.name
-              : 'function_call' in msg && msg.function_call
-              ? msg.function_call.name + '\n' + msg.function_call.arguments
-              : ''
-          }<|end|>`
-      )
-      .join('\n')
-  ).length;
-
-  // this is an approximation. OpenAI cuts off a function schema
-  // at a certain level of nesting, so their token count might
-  // be lower than what we are calculating here.
-
-  const tokensFromFunctions = chatCompletionRequest.functions
-    ? encode(
-        chatCompletionRequest.functions
-          ?.map(
-            (fn) =>
-              `<|start|>${fn.name}\n${fn.description}\n${JSON.stringify(fn.parameters)}<|end|>`
-          )
-          .join('\n')
-      ).length
-    : 0;
-
-  const promptTokens = tokensFromMessages + tokensFromFunctions;
-
-  let responseBody: string = '';
+  let responseBody = '';
 
   responseStream.on('data', (chunk: string) => {
     responseBody += chunk.toString();
@@ -74,7 +37,9 @@ export async function getTokenCountFromOpenAIStream({
     logger.error('An error occurred while calculating streaming response tokens');
   }
 
-  const response = responseBody
+  let completionUsage: OpenAI.CompletionUsage | undefined;
+
+  const response: ParsedResponse = responseBody
     .split('\n')
     .filter((line) => {
       return line.startsWith('data: ') && !line.endsWith('[DONE]');
@@ -82,31 +47,54 @@ export async function getTokenCountFromOpenAIStream({
    .map((line) => {
      return JSON.parse(line.replace('data: ', ''));
    })
-    .filter(
-      (
-        line
-      ): line is {
-        choices: Array<{
-          delta: { content?: string; function_call?: { name?: string; arguments: string } };
-        }>;
-      } => {
-        return (
-          'object' in line && line.object === 'chat.completion.chunk' && line.choices.length > 0
-        );
-      }
-    )
+    .filter((line): line is OpenAI.ChatCompletionChunk => {
+      return 'object' in line && line.object === 'chat.completion.chunk';
+    })
     .reduce(
       (prev, line) => {
-        const msg = line.choices[0].delta!;
-        prev.content += msg.content || '';
-        prev.function_call.name += msg.function_call?.name || '';
-        prev.function_call.arguments += msg.function_call?.arguments || '';
+        if (line.usage) {
+          completionUsage = line.usage;
+        }
+        if (line.choices?.length) {
+          const msg = line.choices[0].delta!;
+          prev.content += msg.content || '';
+          prev.function_call.name += msg.function_call?.name || '';
+          prev.function_call.arguments += msg.function_call?.arguments || '';
+        }
         return prev;
       },
       { content: '', function_call: { name: '', arguments: '' } }
     );
 
-  const completionTokens = encode(
+  // not all openAI compatible providers emit the completion usage chunk, so we still have to support
+  // manually counting the tokens
+  if (completionUsage) {
+    return {
+      prompt: completionUsage.prompt_tokens,
+      completion: completionUsage.completion_tokens,
+      total: completionUsage.total_tokens,
+    };
+  } else {
+    const promptTokens = manuallyCountPromptTokens(body);
+    const completionTokens = manuallyCountCompletionTokens(response);
+    return {
+      prompt: promptTokens,
+      completion: completionTokens,
+      total: promptTokens + completionTokens,
+    };
+  }
+}
+
+interface ParsedResponse {
+  content: string;
+  function_call: {
+    name: string;
+    arguments: string;
+  };
+}
+
+const manuallyCountCompletionTokens = (response: ParsedResponse) => {
+  return encode(
     JSON.stringify(
       omitBy(
        {
@@ -117,10 +105,42 @@ export async function getTokenCountFromOpenAIStream({
       )
     )
   ).length;
+};
 
-  return {
-    prompt: promptTokens,
-    completion: completionTokens,
-    total: promptTokens + completionTokens,
-  };
-}
+const manuallyCountPromptTokens = (requestBody: string) => {
+  const chatCompletionRequest: OpenAI.ChatCompletionCreateParams.ChatCompletionCreateParamsStreaming =
+    JSON.parse(requestBody);
+
+  // per https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+  const tokensFromMessages = encode(
+    chatCompletionRequest.messages
+      .map(
+        (msg) =>
+          `<|start|>${msg.role}\n${msg.content}\n${
+            'name' in msg
+              ? msg.name
+              : 'function_call' in msg && msg.function_call
+              ? msg.function_call.name + '\n' + msg.function_call.arguments
+              : ''
+          }<|end|>`
+      )
+      .join('\n')
+  ).length;
+
+  // this is an approximation. OpenAI cuts off a function schema
+  // at a certain level of nesting, so their token count might
+  // be lower than what we are calculating here.
+
+  const tokensFromFunctions = chatCompletionRequest.functions
+    ? encode(
+        chatCompletionRequest.functions
+          ?.map(
+            (fn) =>
+              `<|start|>${fn.name}\n${fn.description}\n${JSON.stringify(fn.parameters)}<|end|>`
+          )
+          .join('\n')
+      ).length
+    : 0;
+
+  return tokensFromMessages + tokensFromFunctions;
+};
diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts b/x-pack/plugins/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts
index ff1bbc71a876d..2d0154313b632 100644
--- a/x-pack/plugins/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts
+++ b/x-pack/plugins/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts
@@ -21,17 +21,19 @@ function createOpenAIChunk({
   delta,
   usage,
 }: {
-  delta: OpenAI.ChatCompletionChunk['choices'][number]['delta'];
+  delta?: OpenAI.ChatCompletionChunk['choices'][number]['delta'];
   usage?: OpenAI.ChatCompletionChunk['usage'];
 }): OpenAI.ChatCompletionChunk {
   return {
-    choices: [
-      {
-        finish_reason: null,
-        index: 0,
-        delta,
-      },
-    ],
+    choices: delta
+      ? [
+          {
+            finish_reason: null,
+            index: 0,
+            delta,
+          },
+        ]
+      : [],
     created: new Date().getTime(),
     id: v4(),
     model: 'gpt-4o',
@@ -313,7 +315,7 @@ describe('openAIAdapter', () => {
       ]);
     });
 
-    it('emits token events', async () => {
+    it('emits chunk events with tool calls', async () => {
       const response$ = openAIAdapter.chatComplete({
         ...defaultArgs,
         messages: [
@@ -375,5 +377,55 @@ describe('openAIAdapter', () => {
         },
       ]);
     });
+
+    it('emits token count events', async () => {
+      const response$ = openAIAdapter.chatComplete({
+        ...defaultArgs,
+        messages: [
+          {
+            role: MessageRole.User,
+            content: 'Hello',
+          },
+        ],
+      });
+
+      source$.next(
+        createOpenAIChunk({
+          delta: {
+            content: 'chunk',
+          },
+        })
+      );
+
+      source$.next(
+        createOpenAIChunk({
+          usage: {
+            prompt_tokens: 50,
+            completion_tokens: 100,
+            total_tokens: 150,
+          },
+        })
+      );
+
+      source$.complete();
+
+      const allChunks = await lastValueFrom(response$.pipe(toArray()));
+
+      expect(allChunks).toEqual([
+        {
+          type: ChatCompletionEventType.ChatCompletionChunk,
+          content: 'chunk',
+          tool_calls: [],
+        },
+        {
+          type: ChatCompletionEventType.ChatCompletionTokenCount,
+          tokens: {
+            prompt: 50,
+            completion: 100,
+            total: 150,
+          },
+        },
+      ]);
+    });
   });
 });
diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/openai/openai_adapter.ts b/x-pack/plugins/inference/server/chat_complete/adapters/openai/openai_adapter.ts
index 121ba96ab115a..fa412f335800d 100644
--- a/x-pack/plugins/inference/server/chat_complete/adapters/openai/openai_adapter.ts
+++ b/x-pack/plugins/inference/server/chat_complete/adapters/openai/openai_adapter.ts
@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-import OpenAI from 'openai';
+import type OpenAI from 'openai';
 import type {
   ChatCompletionAssistantMessageParam,
   ChatCompletionMessageParam,
@@ -13,22 +13,33 @@ import type {
   ChatCompletionToolMessageParam,
   ChatCompletionUserMessageParam,
 } from 'openai/resources';
-import { filter, from, map, switchMap, tap, throwError, identity } from 'rxjs';
-import { Readable, isReadable } from 'stream';
+import {
+  filter,
+  from,
+  identity,
+  map,
+  mergeMap,
+  Observable,
+  switchMap,
+  tap,
+  throwError,
+} from 'rxjs';
+import { isReadable, Readable } from 'stream';
 import {
   ChatCompletionChunkEvent,
   ChatCompletionEventType,
+  ChatCompletionTokenCountEvent,
+  createInferenceInternalError,
   Message,
   MessageRole,
   ToolOptions,
-  createInferenceInternalError,
 } from '@kbn/inference-common';
 import { createTokenLimitReachedError } from '../../errors';
 import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
 import type { InferenceConnectorAdapter } from '../../types';
 import {
-  wrapWithSimulatedFunctionCalling,
   parseInlineFunctionCalls,
+  wrapWithSimulatedFunctionCalling,
 } from '../../simulated_function_calling';
 
 export const openAIAdapter: InferenceConnectorAdapter = {
@@ -92,34 +103,57 @@ export const openAIAdapter: InferenceConnectorAdapter = {
           throw createTokenLimitReachedError();
         }
       }),
-      filter(
-        (line): line is OpenAI.ChatCompletionChunk =>
-          'object' in line && line.object === 'chat.completion.chunk' && line.choices.length > 0
-      ),
-      map((chunk): ChatCompletionChunkEvent => {
-        const delta = chunk.choices[0].delta;
-
-        return {
-          type: ChatCompletionEventType.ChatCompletionChunk,
-          content: delta.content ?? '',
-          tool_calls:
-            delta.tool_calls?.map((toolCall) => {
-              return {
-                function: {
-                  name: toolCall.function?.name ?? '',
-                  arguments: toolCall.function?.arguments ?? '',
-                },
-                toolCallId: toolCall.id ?? '',
-                index: toolCall.index,
-              };
-            }) ?? [],
-        };
+      filter((line): line is OpenAI.ChatCompletionChunk => {
+        return 'object' in line && line.object === 'chat.completion.chunk';
+      }),
+      mergeMap((chunk): Observable<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent> => {
+        const events: Array<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent> = [];
+        if (chunk.usage) {
+          events.push(tokenCountFromOpenAI(chunk.usage));
+        }
+        if (chunk.choices?.length) {
+          events.push(chunkFromOpenAI(chunk));
+        }
+        return from(events);
       }),
       simulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity
     );
   },
 };
 
+function chunkFromOpenAI(chunk: OpenAI.ChatCompletionChunk): ChatCompletionChunkEvent {
+  const delta = chunk.choices[0].delta;
+
+  return {
+    type: ChatCompletionEventType.ChatCompletionChunk,
+    content: delta.content ?? '',
+    tool_calls:
+      delta.tool_calls?.map((toolCall) => {
+        return {
+          function: {
+            name: toolCall.function?.name ?? '',
+            arguments: toolCall.function?.arguments ?? '',
+          },
+          toolCallId: toolCall.id ?? '',
+          index: toolCall.index,
+        };
+      }) ?? [],
+  };
+}
+
+function tokenCountFromOpenAI(
+  completionUsage: OpenAI.CompletionUsage
+): ChatCompletionTokenCountEvent {
+  return {
+    type: ChatCompletionEventType.ChatCompletionTokenCount,
+    tokens: {
+      completion: completionUsage.completion_tokens,
+      prompt: completionUsage.prompt_tokens,
+      total: completionUsage.total_tokens,
+    },
+  };
+}
+
 function toolsToOpenAI(tools: ToolOptions['tools']): OpenAI.ChatCompletionCreateParams['tools'] {
   return tools
     ? Object.entries(tools).map(([toolName, { description, schema }]) => {
diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.test.ts
index 6023d7715f4ed..628ab7adcd363 100644
--- a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.test.ts
+++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.test.ts
@@ -101,9 +101,50 @@ describe('Azure Open AI Utils', () => {
     };
     [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => {
       const sanitizedBodyString = getRequestWithStreamOption(url, JSON.stringify(body), true);
-      expect(sanitizedBodyString).toEqual(
-        `{\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}],\"stream\":true}`
-      );
+      expect(JSON.parse(sanitizedBodyString)).toEqual({
+        messages: [{ content: 'This is a test', role: 'user' }],
+        stream: true,
+        stream_options: {
+          include_usage: true,
+        },
+      });
+    });
+  });
+  it('sets stream_options when stream is true', () => {
+    const body = {
+      messages: [
+        {
+          role: 'user',
+          content: 'This is a test',
+        },
+      ],
+    };
+    [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => {
+      const sanitizedBodyString = getRequestWithStreamOption(url, JSON.stringify(body), true);
+      expect(JSON.parse(sanitizedBodyString)).toEqual({
+        messages: [{ content: 'This is a test', role: 'user' }],
+        stream: true,
+        stream_options: {
+          include_usage: true,
+        },
+      });
+    });
+  });
+  it('does not set stream_options when stream is false', () => {
+    const body = {
+      messages: [
+        {
+          role: 'user',
+          content: 'This is a test',
+        },
+      ],
+    };
+    [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => {
+      const sanitizedBodyString = getRequestWithStreamOption(url, JSON.stringify(body), false);
+      expect(JSON.parse(sanitizedBodyString)).toEqual({
+        messages: [{ content: 'This is a test', role: 'user' }],
+        stream: false,
+      });
     });
   });
   it('overrides stream parameter if defined in body', () => {
diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.ts
index 02bff6ea2f63a..8825e719f0105 100644
--- a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.ts
+++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.ts
@@ -48,6 +48,11 @@ export const getRequestWithStreamOption = (url: string, body: string, stream: bo
   const jsonBody = JSON.parse(body);
   if (jsonBody) {
     jsonBody.stream = stream;
+    if (stream) {
+      jsonBody.stream_options = {
+        include_usage: true,
+      };
+    }
   }
 
   return JSON.stringify(jsonBody);
diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/openai_utils.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/openai_utils.test.ts
index b480b72859183..cd65084badc92 100644
--- a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/openai_utils.test.ts
+++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/openai_utils.test.ts
@@ -118,6 +118,31 @@ describe('Open AI Utils', () => {
       ],
     };
 
+    [OPENAI_CHAT_URL, OPENAI_LEGACY_COMPLETION_URL].forEach((url: string) => {
+      const sanitizedBodyString = getRequestWithStreamOption(
+        url,
+        JSON.stringify(body),
+        false,
+        DEFAULT_OPENAI_MODEL
+      );
+      expect(JSON.parse(sanitizedBodyString)).toEqual({
+        messages: [{ content: 'This is a test', role: 'user' }],
+        model: 'gpt-4',
+        stream: false,
+      });
+    });
+  });
+  it('sets stream_options when stream is true', () => {
+    const body = {
+      model: 'gpt-4',
+      messages: [
+        {
+          role: 'user',
+          content: 'This is a test',
+        },
+      ],
+    };
+
     [OPENAI_CHAT_URL, OPENAI_LEGACY_COMPLETION_URL].forEach((url: string) => {
       const sanitizedBodyString = getRequestWithStreamOption(
         url,
         JSON.stringify(body),
         true,
         DEFAULT_OPENAI_MODEL
       );
-      expect(sanitizedBodyString).toEqual(
-        `{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}],\"stream\":true}`
+      expect(JSON.parse(sanitizedBodyString)).toEqual({
+        messages: [{ content: 'This is a test', role: 'user' }],
+        model: 'gpt-4',
+        stream: true,
+        stream_options: {
+          include_usage: true,
+        },
+      });
+    });
+  });
+  it('does not set stream_options when stream is false', () => {
+    const body = {
+      model: 'gpt-4',
+      messages: [
+        {
+          role: 'user',
+          content: 'This is a test',
+        },
+      ],
+    };
+
+    [OPENAI_CHAT_URL, OPENAI_LEGACY_COMPLETION_URL].forEach((url: string) => {
+      const sanitizedBodyString = getRequestWithStreamOption(
+        url,
+        JSON.stringify(body),
+        false,
+        DEFAULT_OPENAI_MODEL
       );
+      expect(JSON.parse(sanitizedBodyString)).toEqual({
+        messages: [{ content: 'This is a test', role: 'user' }],
+        model: 'gpt-4',
+        stream: false,
+      });
     });
   });
@@ -182,6 +237,7 @@ describe('Open AI Utils', () => {
       expect(sanitizedBodyString).toEqual(bodyString);
     });
   });
+
   describe('removeEndpointFromUrl', () => {
     test('removes "/chat/completions" from the end of the URL', () => {
       const originalUrl = 'https://api.openai.com/v1/chat/completions';
diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/openai_utils.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/openai_utils.ts
index 7dac5f4692bda..89a29105cd0ca 100644
--- a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/openai_utils.ts
+++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/openai_utils.ts
@@ -38,6 +38,11 @@ export const getRequestWithStreamOption = (
   if (jsonBody) {
     if (APIS_ALLOWING_STREAMING.has(url)) {
       jsonBody.stream = stream;
+      if (stream) {
+        jsonBody.stream_options = {
+          include_usage: true,
+        };
+      }
     }
     jsonBody.model = jsonBody.model || defaultModel;
   }
diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts
index 1362b7610e2cd..33d96451054f4 100644
--- a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts
+++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts
@@ -292,6 +292,7 @@ describe('OpenAIConnector', () => {
         data: JSON.stringify({
           ...sampleOpenAiBody,
           stream: true,
+          stream_options: { include_usage: true },
           model: DEFAULT_OPENAI_MODEL,
         }),
         headers: {
@@ -338,6 +339,7 @@ describe('OpenAIConnector', () => {
         data: JSON.stringify({
           ...body,
           stream: true,
+          stream_options: { include_usage: true },
         }),
         headers: {
           Authorization: 'Bearer 123',
@@ -397,6 +399,7 @@ describe('OpenAIConnector', () => {
         data: JSON.stringify({
           ...sampleOpenAiBody,
           stream: true,
+          stream_options: { include_usage: true },
           model: DEFAULT_OPENAI_MODEL,
         }),
         headers: {
@@ -422,6 +425,7 @@ describe('OpenAIConnector', () => {
         data: JSON.stringify({
           ...sampleOpenAiBody,
           stream: true,
+          stream_options: { include_usage: true },
           model: DEFAULT_OPENAI_MODEL,
         }),
         headers: {
@@ -448,6 +452,7 @@ describe('OpenAIConnector', () => {
         data: JSON.stringify({
           ...sampleOpenAiBody,
           stream: true,
+          stream_options: { include_usage: true },
           model: DEFAULT_OPENAI_MODEL,
         }),
         headers: {
@@ -1274,7 +1279,11 @@ describe('OpenAIConnector', () => {
         url: 'https://My-test-resource-123.openai.azure.com/openai/deployments/NEW-DEPLOYMENT-321/chat/completions?api-version=2023-05-15',
         method: 'post',
         responseSchema: StreamingResponseSchema,
-        data: JSON.stringify({ ...sampleAzureAiBody, stream: true }),
+        data: JSON.stringify({
+          ...sampleAzureAiBody,
+          stream: true,
+          stream_options: { include_usage: true },
+        }),
         headers: {
           'api-key': '123',
           'content-type': 'application/json',
@@ -1314,6 +1323,7 @@ describe('OpenAIConnector', () => {
         data: JSON.stringify({
           ...body,
           stream: true,
+          stream_options: { include_usage: true },
         }),
         headers: {
           'api-key': '123',