Skip to content

Commit

Permalink
Add support for prompt token details in OpenAI usage stats
Browse files Browse the repository at this point in the history
- Add PromptTokensDetails record to track cached tokens in prompt
- Update Usage record to include promptTokensDetails field
- Add getCachedTokens() method to OpenAiUsage
- Add test cases for cached tokens handling

Resolves #1506
  • Loading branch information
dafriz authored and tzolov committed Oct 22, 2024
1 parent 8b1882b commit 2c17577
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
* @author Michael Lavelle
* @author Mariusz Bernacki
* @author Thomas Vitale
* @author David Frizelle
*/
public class OpenAiApi {

Expand Down Expand Up @@ -938,17 +939,29 @@ public record TopLogProbs(// @formatter:off
* @param promptTokens Number of tokens in the prompt.
* @param totalTokens Total number of tokens used in the request (prompt +
* completion).
* @param promptTokensDetails Breakdown of tokens used in the prompt.
* @param completionTokenDetails Breakdown of tokens used in a completion.
*/
@JsonInclude(Include.NON_NULL)
public record Usage(// @formatter:off
@JsonProperty("completion_tokens") Integer completionTokens,
@JsonProperty("prompt_tokens") Integer promptTokens,
@JsonProperty("total_tokens") Integer totalTokens,
@JsonProperty("prompt_tokens_details") PromptTokensDetails promptTokensDetails,
@JsonProperty("completion_tokens_details") CompletionTokenDetails completionTokenDetails) {// @formatter:on

/**
 * Convenience constructor for responses that carry no token-detail breakdowns;
 * both {@code promptTokensDetails} and {@code completionTokenDetails} default to null.
 * @param completionTokens Number of tokens in the generated completion.
 * @param promptTokens Number of tokens in the prompt.
 * @param totalTokens Total number of tokens used in the request.
 */
public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) {
	// Delegate to the canonical 5-component constructor; detail records absent.
	this(completionTokens, promptTokens, totalTokens, null, null);
}

/**
 * Breakdown of tokens used in the prompt.
 *
 * Deserialized from the {@code prompt_tokens_details} object of the OpenAI usage
 * payload; fields the API omits remain null.
 *
 * @param cachedTokens Cached tokens present in the prompt.
 */
@JsonInclude(Include.NON_NULL)
public record PromptTokensDetails(// @formatter:off
	@JsonProperty("cached_tokens") Integer cachedTokens) {// @formatter:on
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
*
* @author John Blum
* @author Thomas Vitale
* @author David Frizelle
* @since 0.7.0
* @see <a href=
* "https://platform.openai.com/docs/api-reference/completions/object">Completion
Expand Down Expand Up @@ -58,6 +59,12 @@ public Long getGenerationTokens() {
return generationTokens != null ? generationTokens.longValue() : 0;
}

/**
 * Returns the number of cached tokens reported in the prompt token details,
 * or {@code 0} when the response carried no such breakdown.
 */
public Long getCachedTokens() {
	OpenAiApi.Usage.PromptTokensDetails details = getUsage().promptTokensDetails();
	if (details == null || details.cachedTokens() == null) {
		// Either the details object or its cached_tokens field was absent.
		return 0L;
	}
	return details.cachedTokens().longValue();
}

public Long getReasoningTokens() {
OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = getUsage().completionTokenDetails();
Integer reasoningTokens = completionTokenDetails != null ? completionTokenDetails.reasoningTokens() : null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,27 +55,44 @@ void whenTotalTokensIsNull() {
}

@Test
void whenCompletionTokenDetailsIsNull() {
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null);
void whenPromptAndCompletionTokensDetailsIsNull() {
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null, null);
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
assertThat(usage.getTotalTokens()).isEqualTo(300);
assertThat(usage.getCachedTokens()).isEqualTo(0);
assertThat(usage.getReasoningTokens()).isEqualTo(0);
}

@Test
void whenReasoningTokensIsNull() {
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300,
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null,
new OpenAiApi.Usage.CompletionTokenDetails(null));
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
assertThat(usage.getReasoningTokens()).isEqualTo(0);
}

@Test
void whenCompletionTokenDetailsIsPresent() {
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300,
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null,
new OpenAiApi.Usage.CompletionTokenDetails(50));
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
assertThat(usage.getReasoningTokens()).isEqualTo(50);
}

@Test
void whenCacheTokensIsNull() {
	// PromptTokensDetails present but its cached_tokens field is null.
	OpenAiApi.Usage.PromptTokensDetails promptDetails = new OpenAiApi.Usage.PromptTokensDetails(null);
	OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, promptDetails, null);
	OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
	assertThat(usage.getCachedTokens()).isEqualTo(0);
}

@Test
void whenCacheTokensIsPresent() {
	// A populated PromptTokensDetails should surface its cached-token count.
	OpenAiApi.Usage.PromptTokensDetails promptDetails = new OpenAiApi.Usage.PromptTokensDetails(15);
	OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, promptDetails, null);
	OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
	assertThat(usage.getCachedTokens()).isEqualTo(15);
}

}

0 comments on commit 2c17577

Please sign in to comment.