From bcc6c271e8eca16fff5cd56dc953741eecb992f0 Mon Sep 17 00:00:00 2001
From: Moritz Althaus
Date: Mon, 28 Oct 2024 17:56:29 +0100
Subject: [PATCH] feat: enable include_usage flag for chat stream endpoint

---
 aleph_alpha_client/aleph_alpha_client.py |  6 +-
 aleph_alpha_client/chat.py               | 50 ++++++++++++++-
 tests/test_chat.py                       | 77 +++++++++++++++++++++++-
 3 files changed, 127 insertions(+), 6 deletions(-)

diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py
index 6c6bd9b..2a433cf 100644
--- a/aleph_alpha_client/aleph_alpha_client.py
+++ b/aleph_alpha_client/aleph_alpha_client.py
@@ -38,7 +38,7 @@
     CompletionResponseStreamItem,
     stream_item_from_json,
 )
-from aleph_alpha_client.chat import ChatRequest, ChatResponse, ChatStreamChunk, ChatStreamChunk
+from aleph_alpha_client.chat import ChatRequest, ChatResponse, ChatStreamChunk, Usage, stream_chat_item_from_json
 from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse
 from aleph_alpha_client.tokenization import TokenizationRequest, TokenizationResponse
 from aleph_alpha_client.detokenization import (
@@ -974,7 +974,7 @@ async def chat_with_streaming(
         self,
         request: ChatRequest,
         model: str,
-    ) -> AsyncGenerator[ChatStreamChunk, None]:
+    ) -> AsyncGenerator[Union[ChatStreamChunk, Usage], None]:
         """Generates streamed chat completions.
 
         The first yielded chunk contains the role, while subsequent chunks only contain the content delta.
@@ -1006,7 +1006,7 @@ async def chat_with_streaming(
             request,
             model,
         ):
-            chunk = ChatStreamChunk.from_json(stream_item_json)
+            chunk = stream_chat_item_from_json(stream_item_json)
             if chunk is not None:
                 yield chunk
 
diff --git a/aleph_alpha_client/chat.py b/aleph_alpha_client/chat.py
index 8105140..edaeb22 100644
--- a/aleph_alpha_client/chat.py
+++ b/aleph_alpha_client/chat.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, asdict
-from typing import List, Optional, Mapping, Any, Dict
+from typing import List, Optional, Mapping, Any, Dict, Union
 from enum import Enum
 
 
@@ -36,6 +36,17 @@ def from_json(json: Dict[str, Any]) -> "Message":
         )
 
 
+@dataclass(frozen=True)
+class StreamOptions:
+    """
+    Additional options to affect the streaming behavior.
+    """
+    # If set, an additional chunk will be streamed before the data: [DONE] message.
+    # The usage field on this chunk shows the token usage statistics for the entire
+    # request, and the choices field will always be an empty array.
+    include_usage: bool
+
+
 @dataclass(frozen=True)
 class ChatRequest:
     """
@@ -50,6 +61,7 @@ class ChatRequest:
     temperature: float = 0.0
     top_k: int = 0
     top_p: float = 0.0
+    stream_options: Optional[StreamOptions] = None
 
     def to_json(self) -> Mapping[str, Any]:
         payload = {k: v for k, v in asdict(self).items() if v is not None}
@@ -77,6 +89,33 @@ def from_json(json: Dict[str, Any]) -> "ChatResponse":
         )
 
 
+
+@dataclass(frozen=True)
+class Usage:
+    """
+    Usage statistics for the completion request.
+
+    When streaming is enabled, this field will be null by default.
+    To include an additional usage-only message in the response stream, set stream_options.include_usage to true.
+    """
+    # Number of tokens in the generated completion.
+    completion_tokens: int
+
+    # Number of tokens in the prompt.
+    prompt_tokens: int
+
+    # Total number of tokens used in the request (prompt + completion).
+    total_tokens: int
+
+    def from_json(json: Dict[str, Any]) -> "Usage":
+        return Usage(
+            completion_tokens=json["completion_tokens"],
+            prompt_tokens=json["prompt_tokens"],
+            total_tokens=json["total_tokens"]
+        )
+
+
+
 @dataclass(frozen=True)
 class ChatStreamChunk:
     """
@@ -103,4 +142,11 @@ def from_json(json: Dict[str, Any]) -> Optional["ChatStreamChunk"]:
         return ChatStreamChunk(
             content=delta["content"],
             role=Role(delta.get("role")) if delta.get("role") else None,
-        )
\ No newline at end of file
+        )
+
+
+def stream_chat_item_from_json(json: Dict[str, Any]) -> Union[Usage, ChatStreamChunk, None]:
+    if (usage := json.get("usage")) is not None:
+        return Usage.from_json(usage)
+
+    return ChatStreamChunk.from_json(json)
\ No newline at end of file
diff --git a/tests/test_chat.py b/tests/test_chat.py
index a216b55..5015c42 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -1,7 +1,7 @@
 import pytest
 
 from aleph_alpha_client import AsyncClient, Client
-from aleph_alpha_client.chat import ChatRequest, Message, Role
+from aleph_alpha_client.chat import ChatRequest, Message, Role, StreamOptions, stream_chat_item_from_json, Usage, ChatStreamChunk
 from tests.common import async_client, sync_client, model_name, chat_model_name
 
 
@@ -52,3 +52,78 @@ async def test_can_chat_with_streaming_support(async_client: AsyncClient, chat_m
 
     assert stream_items[0].role is not None
     assert all(item.content is not None for item in stream_items[1:])
+
+
+def test_usage_response_is_parsed():
+    # Given an API response with usage data and no choice
+    data = {
+        "choices": [],
+        "created": 1730133402,
+        "model": "llama-3.1-70b-instruct",
+        "system_fingerprint": ".unknown.",
+        "object": "chat.completion.chunk",
+        "usage": {
+            "prompt_tokens": 31,
+            "completion_tokens": 88,
+            "total_tokens": 119
+        }
+    }
+
+    # When parsing it
+    result = stream_chat_item_from_json(data)
+
+    # Then a Usage instance is returned
+    assert isinstance(result, Usage)
+    assert result.prompt_tokens == 31
+
+
+def test_chunk_response_is_parsed():
+    # Given an API response without usage data
+    data = {
+        "choices": [
+            {
+                "finish_reason": None,
+                "index": 0,
+                "delta": {
+                    "content": " way, those clothes you're wearing"
+                },
+                "logprobs": None
+            }
+        ],
+        "created": 1730133401,
+        "model": "llama-3.1-70b-instruct",
+        "system_fingerprint": None,
+        "object": "chat.completion.chunk",
+        "usage": None,
+    }
+
+    # When parsing it
+    result = stream_chat_item_from_json(data)
+
+    # Then a ChatStreamChunk instance is returned
+    assert isinstance(result, ChatStreamChunk)
+    assert result.content == " way, those clothes you're wearing"
+
+
+
+async def test_stream_options(async_client: AsyncClient, chat_model_name: str):
+    # Given a request with include_usage enabled in the stream options
+    stream_options = StreamOptions(include_usage=True)
+    request = ChatRequest(
+        messages=[Message(role=Role.User, content="Hello, how are you?")],
+        model=chat_model_name,
+        stream_options=stream_options,
+    )
+
+
+    # When receiving the chunks
+    stream_items = [
+        stream_item async for stream_item in async_client.chat_with_streaming(request, model=chat_model_name)
+    ]
+
+    # Then the final item carries the usage statistics
+    assert all(isinstance(item, ChatStreamChunk) for item in stream_items[:-1])
+    assert isinstance(stream_items[-1], Usage)
+
+
+
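
For reviewers, a minimal sketch of how a caller might consume the new stream_options flag once this patch is applied. It is not part of the diff: the token environment variable and model name are placeholders, and it assumes AsyncClient can be used as an async context manager, as in the library's existing examples.

import asyncio
import os

from aleph_alpha_client import AsyncClient
from aleph_alpha_client.chat import (
    ChatRequest,
    ChatStreamChunk,
    Message,
    Role,
    StreamOptions,
    Usage,
)


async def main() -> None:
    # Placeholder token and model; substitute values for your environment.
    async with AsyncClient(token=os.environ["AA_TOKEN"]) as client:
        request = ChatRequest(
            model="llama-3.1-70b-instruct",
            messages=[Message(role=Role.User, content="Hello, how are you?")],
            # Ask the API for a final usage-only chunk before the data: [DONE] message.
            stream_options=StreamOptions(include_usage=True),
        )
        async for item in client.chat_with_streaming(request, model=request.model):
            if isinstance(item, ChatStreamChunk):
                # Regular chunks carry the content delta.
                print(item.content or "", end="", flush=True)
            elif isinstance(item, Usage):
                # The usage-only chunk arrives last and has empty choices.
                print(
                    f"\nprompt={item.prompt_tokens}"
                    f" completion={item.completion_tokens}"
                    f" total={item.total_tokens}"
                )


if __name__ == "__main__":
    asyncio.run(main())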