From 2c17577f2fb9dfd84c17da7801b0ed9d47c9939c Mon Sep 17 00:00:00 2001 From: dafriz Date: Wed, 9 Oct 2024 22:15:15 +1100 Subject: [PATCH] Add support for prompt token details in OpenAI usage stats - Add PromptTokensDetails record to track cached tokens in prompt - Update Usage record to include promptTokensDetails field - Add getCachedTokens() method to OpenAiUsage - Add test cases for cached tokens handling Resolves #1506 --- .../ai/openai/api/OpenAiApi.java | 17 +++++++++++-- .../ai/openai/metadata/OpenAiUsage.java | 7 ++++++ .../ai/openai/metadata/OpenAiUsageTests.java | 25 ++++++++++++++++--- 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 3141826826..657f37a136 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -56,6 +56,7 @@ * @author Michael Lavelle * @author Mariusz Bernacki * @author Thomas Vitale + * @author David Frizelle */ public class OpenAiApi { @@ -938,17 +939,29 @@ public record TopLogProbs(// @formatter:off * @param promptTokens Number of tokens in the prompt. * @param totalTokens Total number of tokens used in the request (prompt + * completion). - * @param completionTokenDetails Breakdown of tokens used in a completion + * @param promptTokensDetails Breakdown of tokens used in the prompt. + * @param completionTokenDetails Breakdown of tokens used in a completion. */ @JsonInclude(Include.NON_NULL) public record Usage(// @formatter:off @JsonProperty("completion_tokens") Integer completionTokens, @JsonProperty("prompt_tokens") Integer promptTokens, @JsonProperty("total_tokens") Integer totalTokens, + @JsonProperty("prompt_tokens_details") PromptTokensDetails promptTokensDetails, @JsonProperty("completion_tokens_details") CompletionTokenDetails completionTokenDetails) {// @formatter:on public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) { - this(completionTokens, promptTokens, totalTokens, null); + this(completionTokens, promptTokens, totalTokens, null, null); + } + + /** + * Breakdown of tokens used in the prompt + * + * @param cachedTokens Cached tokens present in the prompt. + */ + @JsonInclude(Include.NON_NULL) + public record PromptTokensDetails(// @formatter:off + @JsonProperty("cached_tokens") Integer cachedTokens) {// @formatter:on } /** diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java index add5d896b5..46ec6ffb78 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java @@ -24,6 +24,7 @@ * * @author John Blum * @author Thomas Vitale + * @author David Frizelle * @since 0.7.0 * @see Completion @@ -58,6 +59,12 @@ public Long getGenerationTokens() { return generationTokens != null ? generationTokens.longValue() : 0; } + public Long getCachedTokens() { + OpenAiApi.Usage.PromptTokensDetails promptTokenDetails = getUsage().promptTokensDetails(); + Integer cachedTokens = promptTokenDetails != null ? promptTokenDetails.cachedTokens() : null; + return cachedTokens != null ? cachedTokens.longValue() : 0; + } + public Long getReasoningTokens() { OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = getUsage().completionTokenDetails(); Integer reasoningTokens = completionTokenDetails != null ? completionTokenDetails.reasoningTokens() : null; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java index b9215b4c3d..1c7c53e0d2 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java @@ -55,16 +55,17 @@ void whenTotalTokensIsNull() { } @Test - void whenCompletionTokenDetailsIsNull() { - OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null); + void whenPromptAndCompletionTokensDetailsIsNull() { + OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null, null); OpenAiUsage usage = OpenAiUsage.from(openAiUsage); assertThat(usage.getTotalTokens()).isEqualTo(300); + assertThat(usage.getCachedTokens()).isEqualTo(0); assertThat(usage.getReasoningTokens()).isEqualTo(0); } @Test void whenReasoningTokensIsNull() { - OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, + OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null, new OpenAiApi.Usage.CompletionTokenDetails(null)); OpenAiUsage usage = OpenAiUsage.from(openAiUsage); assertThat(usage.getReasoningTokens()).isEqualTo(0); @@ -72,10 +73,26 @@ void whenReasoningTokensIsNull() { @Test void whenCompletionTokenDetailsIsPresent() { - OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, + OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null, new OpenAiApi.Usage.CompletionTokenDetails(50)); OpenAiUsage usage = OpenAiUsage.from(openAiUsage); assertThat(usage.getReasoningTokens()).isEqualTo(50); } + @Test + void whenCacheTokensIsNull() { + OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, new OpenAiApi.Usage.PromptTokensDetails(null), + null); + OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + assertThat(usage.getCachedTokens()).isEqualTo(0); + } + + @Test + void whenCacheTokensIsPresent() { + OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, new OpenAiApi.Usage.PromptTokensDetails(15), + null); + OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + assertThat(usage.getCachedTokens()).isEqualTo(15); + } + }