From 0ec05018e084f83efffe894298cd6c531ef4efd3 Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Tue, 22 Oct 2024 10:32:10 +0200 Subject: [PATCH 1/2] feat: add chat endpoint to sync and async client while this commit supports streaming for the chat endpoint, it does only offer a simplified version and does not forward all parameters --- aleph_alpha_client/aleph_alpha_client.py | 109 ++++++++++++++++++++++- aleph_alpha_client/chat.py | 106 ++++++++++++++++++++++ tests/common.py | 5 ++ tests/test_chat.py | 54 +++++++++++ 4 files changed, 272 insertions(+), 2 deletions(-) create mode 100644 aleph_alpha_client/chat.py create mode 100644 tests/test_chat.py diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py index 1aaace9..6c6bd9b 100644 --- a/aleph_alpha_client/aleph_alpha_client.py +++ b/aleph_alpha_client/aleph_alpha_client.py @@ -36,9 +36,9 @@ CompletionRequest, CompletionResponse, CompletionResponseStreamItem, - StreamChunk, stream_item_from_json, ) +from aleph_alpha_client.chat import ChatRequest, ChatResponse, ChatStreamChunk, ChatStreamChunk from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse from aleph_alpha_client.tokenization import TokenizationRequest, TokenizationResponse from aleph_alpha_client.detokenization import ( @@ -99,6 +99,7 @@ def _check_api_version(version_str: str): AnyRequest = Union[ CompletionRequest, + ChatRequest, EmbeddingRequest, EvaluationRequest, TokenizationRequest, @@ -302,6 +303,34 @@ def complete( response = self._post_request("complete", request, model) return CompletionResponse.from_json(response) + def chat( + self, + request: ChatRequest, + model: str, + ) -> ChatResponse: + """Chat with a model. + + Parameters: + request (ChatRequest, required): + Parameters for the requested chat. + + model (string, required): + Name of model to use. A model name refers to a model architecture (number of parameters among others). + Always the latest version of model is used. + + Examples: + >>> # create a chat request + >>> request = ChatRequest( + messages=[Message(role="user", content="Hello, how are you?")], + model=model, + ) + >>> + >>> # chat with the model + >>> result = client.chat(request, model=model_name) + """ + response = self._post_request("chat/completions", request, model) + return ChatResponse.from_json(response) + def tokenize( self, request: TokenizationRequest, @@ -797,7 +826,11 @@ async def _post_request_with_streaming( f"Stream item did not start with `{self.SSE_DATA_PREFIX}`. Was `{stream_item_as_str}" ) - yield json.loads(stream_item_as_str[len(self.SSE_DATA_PREFIX) :]) + payload = stream_item_as_str[len(self.SSE_DATA_PREFIX) :] + if payload == "[DONE]": + continue + + yield json.loads(payload) def _build_query_parameters(self) -> Mapping[str, str]: return { @@ -864,6 +897,38 @@ async def complete( ) return CompletionResponse.from_json(response) + async def chat( + self, + request: ChatRequest, + model: str, + ) -> ChatResponse: + """Chat with a model. + + Parameters: + request (ChatRequest, required): + Parameters for the requested chat. + + model (string, required): + Name of model to use. A model name refers to a model architecture (number of parameters among others). + Always the latest version of model is used. 
+ + Examples: + >>> # create a chat request + >>> request = ChatRequest( + messages=[Message(role="user", content="Hello, how are you?")], + model=model, + ) + >>> + >>> # chat with the model + >>> result = await client.chat(request, model=model_name) + """ + response = await self._post_request( + "chat/completions", + request, + model, + ) + return ChatResponse.from_json(response) + async def complete_with_streaming( self, request: CompletionRequest, @@ -905,6 +970,46 @@ async def complete_with_streaming( ): yield stream_item_from_json(stream_item_json) + async def chat_with_streaming( + self, + request: ChatRequest, + model: str, + ) -> AsyncGenerator[ChatStreamChunk, None]: + """Generates streamed chat completions. + + The first yielded chunk contains the role, while subsequent chunks only contain the content delta. + + Parameters: + request (ChatRequest, required): + Parameters for the requested chat. + + model (string, required): + Name of model to use. A model name refers to a model architecture (number of parameters among others). + Always the latest version of model is used. + + Examples: + >>> # create a chat request + >>> request = ChatRequest( + messages=[Message(role="user", content="Hello, how are you?")], + model=model, + ) + >>> + >>> # chat with the model + >>> result = await client.chat_with_streaming(request, model=model_name) + >>> + >>> # consume the chat stream + >>> async for stream_item in result: + >>> do_something_with(stream_item) + """ + async for stream_item_json in self._post_request_with_streaming( + "chat/completions", + request, + model, + ): + chunk = ChatStreamChunk.from_json(stream_item_json) + if chunk is not None: + yield chunk + async def tokenize( self, request: TokenizationRequest, diff --git a/aleph_alpha_client/chat.py b/aleph_alpha_client/chat.py new file mode 100644 index 0000000..8105140 --- /dev/null +++ b/aleph_alpha_client/chat.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass, asdict +from typing import List, Optional, Mapping, Any, Dict +from enum import Enum + + +class Role(str, Enum): + """A role used for a message in a chat.""" + User = "user" + Assistant = "assistant" + System = "system" + + +@dataclass(frozen=True) +class Message: + """ + Describes a message in a chat. + + Parameters: + role (Role, required): + The role of the message. + + content (str, required): + The content of the message. + """ + role: Role + content: str + + def to_json(self) -> Mapping[str, Any]: + return asdict(self) + + @staticmethod + def from_json(json: Dict[str, Any]) -> "Message": + return Message( + role=Role(json["role"]), + content=json["content"], + ) + + +@dataclass(frozen=True) +class ChatRequest: + """ + Describes a chat request. + + Only supports a subset of the parameters of `CompletionRequest` for simplicity. + See `CompletionRequest` for documentation on the parameters. + """ + model: str + messages: List[Message] + maximum_tokens: Optional[int] = None + temperature: float = 0.0 + top_k: int = 0 + top_p: float = 0.0 + + def to_json(self) -> Mapping[str, Any]: + payload = {k: v for k, v in asdict(self).items() if v is not None} + payload["messages"] = [message.to_json() for message in self.messages] + return payload + + +@dataclass(frozen=True) +class ChatResponse: + """ + A simplified version of the chat response. + + As the `ChatRequest` does not support the `n` parameter (allowing for multiple return values), + the `ChatResponse` assumes there to be only one choice. 
+ """ + finish_reason: str + message: Message + + @staticmethod + def from_json(json: Dict[str, Any]) -> "ChatResponse": + first_choice = json["choices"][0] + return ChatResponse( + finish_reason=first_choice["finish_reason"], + message=Message.from_json(first_choice["message"]), + ) + + +@dataclass(frozen=True) +class ChatStreamChunk: + """ + A streamed chat completion chunk. + + Parameters: + content (str, required): + The content of the current chat completion. Will be empty for the first chunk of every completion stream and non-empty for the remaining chunks. + + role (Role, optional): + The role of the current chat completion. Will be assistant for the first chunk of every completion stream and missing for the remaining chunks. + """ + content: str + role: Optional[Role] + + @staticmethod + def from_json(json: Dict[str, Any]) -> Optional["ChatStreamChunk"]: + """ + Returns a ChatStreamChunk if the chunk contains a message, otherwise None. + """ + if not (delta := json["choices"][0]["delta"]): + return None + + return ChatStreamChunk( + content=delta["content"], + role=Role(delta.get("role")) if delta.get("role") else None, + ) \ No newline at end of file diff --git a/tests/common.py b/tests/common.py index 6a30287..9390638 100644 --- a/tests/common.py +++ b/tests/common.py @@ -32,6 +32,11 @@ def model_name() -> str: return "luminous-base" +@pytest.fixture(scope="session") +def chat_model_name() -> str: + return "llama-3.1-70b-instruct" + + @pytest.fixture(scope="session") def prompt_image() -> Image: image_source_path = Path(__file__).parent / "dog-and-cat-cover.jpg" diff --git a/tests/test_chat.py b/tests/test_chat.py new file mode 100644 index 0000000..a216b55 --- /dev/null +++ b/tests/test_chat.py @@ -0,0 +1,54 @@ +import pytest + +from aleph_alpha_client import AsyncClient, Client +from aleph_alpha_client.chat import ChatRequest, Message, Role +from tests.common import async_client, sync_client, model_name, chat_model_name + + +@pytest.mark.system_test +async def test_can_not_chat_with_all_models(async_client: AsyncClient, model_name: str): + request = ChatRequest( + messages=[Message(role=Role.User, content="Hello, how are you?")], + model=model_name, + ) + + with pytest.raises(ValueError): + await async_client.chat(request, model=model_name) + + +def test_can_chat_with_chat_model(sync_client: Client, chat_model_name: str): + request = ChatRequest( + messages=[Message(role=Role.User, content="Hello, how are you?")], + model=chat_model_name, + ) + + response = sync_client.chat(request, model=chat_model_name) + assert response.message.role == Role.Assistant + assert response.message.content is not None + + +async def test_can_chat_with_async_client(async_client: AsyncClient, chat_model_name: str): + system_msg = Message(role=Role.System, content="You are a helpful assistant.") + user_msg = Message(role=Role.User, content="Hello, how are you?") + request = ChatRequest( + messages=[system_msg, user_msg], + model=chat_model_name, + ) + + response = await async_client.chat(request, model=chat_model_name) + assert response.message.role == Role.Assistant + assert response.message.content is not None + + +async def test_can_chat_with_streaming_support(async_client: AsyncClient, chat_model_name: str): + request = ChatRequest( + messages=[Message(role=Role.User, content="Hello, how are you?")], + model=chat_model_name, + ) + + stream_items = [ + stream_item async for stream_item in async_client.chat_with_streaming(request, model=chat_model_name) + ] + + assert stream_items[0].role is not 
None + assert all(item.content is not None for item in stream_items[1:]) From 8dafe2246c2f0a5dbfdf1610feeb6fdf4d324dfc Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Mon, 28 Oct 2024 17:56:29 +0100 Subject: [PATCH 2/2] feat: enable include_usage flag for chat stream endpoint --- aleph_alpha_client/aleph_alpha_client.py | 6 +- aleph_alpha_client/chat.py | 51 ++++++++++++++- tests/test_chat.py | 82 +++++++++++++++++++++++- 3 files changed, 131 insertions(+), 8 deletions(-) diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py index 6c6bd9b..2a433cf 100644 --- a/aleph_alpha_client/aleph_alpha_client.py +++ b/aleph_alpha_client/aleph_alpha_client.py @@ -38,7 +38,7 @@ CompletionResponseStreamItem, stream_item_from_json, ) -from aleph_alpha_client.chat import ChatRequest, ChatResponse, ChatStreamChunk, ChatStreamChunk +from aleph_alpha_client.chat import ChatRequest, ChatResponse, ChatStreamChunk, ChatStreamChunk, Usage, stream_chat_item_from_json from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse from aleph_alpha_client.tokenization import TokenizationRequest, TokenizationResponse from aleph_alpha_client.detokenization import ( @@ -974,7 +974,7 @@ async def chat_with_streaming( self, request: ChatRequest, model: str, - ) -> AsyncGenerator[ChatStreamChunk, None]: + ) -> AsyncGenerator[Union[ChatStreamChunk, Usage], None]: """Generates streamed chat completions. The first yielded chunk contains the role, while subsequent chunks only contain the content delta. @@ -1006,7 +1006,7 @@ async def chat_with_streaming( request, model, ): - chunk = ChatStreamChunk.from_json(stream_item_json) + chunk = stream_chat_item_from_json(stream_item_json) if chunk is not None: yield chunk diff --git a/aleph_alpha_client/chat.py b/aleph_alpha_client/chat.py index 8105140..1e22a12 100644 --- a/aleph_alpha_client/chat.py +++ b/aleph_alpha_client/chat.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, asdict -from typing import List, Optional, Mapping, Any, Dict +from typing import List, Optional, Mapping, Any, Dict, Union from enum import Enum @@ -36,6 +36,17 @@ def from_json(json: Dict[str, Any]) -> "Message": ) +@dataclass(frozen=True) +class StreamOptions: + """ + Additional options to affect the streaming behavior. + """ + # If set, an additional chunk will be streamed before the data: [DONE] message. + # The usage field on this chunk shows the token usage statistics for the entire + # request, and the choices field will always be an empty array. + include_usage: bool + + @dataclass(frozen=True) class ChatRequest: """ @@ -50,6 +61,7 @@ class ChatRequest: temperature: float = 0.0 top_k: int = 0 top_p: float = 0.0 + stream_options: Optional[StreamOptions] = None def to_json(self) -> Mapping[str, Any]: payload = {k: v for k, v in asdict(self).items() if v is not None} @@ -77,6 +89,34 @@ def from_json(json: Dict[str, Any]) -> "ChatResponse": ) + +@dataclass(frozen=True) +class Usage: + """ + Usage statistics for the completion request. + + When streaming is enabled, this field will be null by default. + To include an additional usage-only message in the response stream, set stream_options.include_usage to true. + """ + # Number of tokens in the generated completion. + completion_tokens: int + + # Number of tokens in the prompt. + prompt_tokens: int + + # Total number of tokens used in the request (prompt + completion). 
+ total_tokens: int + + @staticmethod + def from_json(json: Dict[str, Any]) -> "Usage": + return Usage( + completion_tokens=json["completion_tokens"], + prompt_tokens=json["prompt_tokens"], + total_tokens=json["total_tokens"] + ) + + + @dataclass(frozen=True) class ChatStreamChunk: """ @@ -103,4 +143,11 @@ def from_json(json: Dict[str, Any]) -> Optional["ChatStreamChunk"]: return ChatStreamChunk( content=delta["content"], role=Role(delta.get("role")) if delta.get("role") else None, - ) \ No newline at end of file + ) + + +def stream_chat_item_from_json(json: Dict[str, Any]) -> Union[Usage, ChatStreamChunk, None]: + if (usage := json.get("usage")) is not None: + return Usage.from_json(usage) + + return ChatStreamChunk.from_json(json) \ No newline at end of file diff --git a/tests/test_chat.py b/tests/test_chat.py index a216b55..13c7598 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -1,7 +1,7 @@ import pytest from aleph_alpha_client import AsyncClient, Client -from aleph_alpha_client.chat import ChatRequest, Message, Role +from aleph_alpha_client.chat import ChatRequest, Message, Role, StreamOptions, stream_chat_item_from_json, Usage, ChatStreamChunk from tests.common import async_client, sync_client, model_name, chat_model_name @@ -50,5 +50,81 @@ async def test_can_chat_with_streaming_support(async_client: AsyncClient, chat_m stream_item async for stream_item in async_client.chat_with_streaming(request, model=chat_model_name) ] - assert stream_items[0].role is not None - assert all(item.content is not None for item in stream_items[1:]) + first = stream_items[0] + assert isinstance(first, ChatStreamChunk) and first.role is not None + assert all(isinstance(item, ChatStreamChunk) and item.content is not None for item in stream_items[1:]) + + +async def test_usage_response_is_parsed(): + # Given an API response with usage data and no choice + data = { + "choices": [], + "created": 1730133402, + "model": "llama-3.1-70b-instruct", + "system_fingerprint": ".unknown.", + "object": "chat.completion.chunk", + "usage": { + "prompt_tokens": 31, + "completion_tokens": 88, + "total_tokens": 119 + } + } + + # When parsing it + result = stream_chat_item_from_json(data) + + # Then a usage instance is returned + assert isinstance(result, Usage) + assert result.prompt_tokens == 31 + + +def test_chunk_response_is_parsed(): + # Given an API response without usage data + data = { + "choices": [ + { + "finish_reason": None, + "index": 0, + "delta": { + "content": " way, those clothes you're wearing" + }, + "logprobs": None + } + ], + "created": 1730133401, + "model": "llama-3.1-70b-instruct", + "system_fingerprint": None, + "object": "chat.completion.chunk", + "usage": None, + } + + # When parsing it + result = stream_chat_item_from_json(data) + + # Then a ChatStreamChunk instance is returned + assert isinstance(result, ChatStreamChunk) + assert result.content == " way, those clothes you're wearing" + + + +async def test_stream_options(async_client: AsyncClient, chat_model_name: str): + # Given a request with include usage options set + stream_options = StreamOptions(include_usage=True) + request = ChatRequest( + messages=[Message(role=Role.User, content="Hello, how are you?")], + model=chat_model_name, + stream_options=stream_options + + ) + + # When receiving the chunks + stream_items = [ + stream_item async for stream_item in async_client.chat_with_streaming(request, model=chat_model_name) + ] + + # Then the last chunks has information about usage + assert all(isinstance(item, ChatStreamChunk) for 
item in stream_items[:-1]) + assert isinstance(stream_items[-1], Usage) + + + \ No newline at end of file
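
For reviewers, a minimal end-to-end sketch of the synchronous API added by the first patch. ChatRequest, Message, Role, ChatResponse, and Client.chat are the names introduced in this diff; the Client constructor arguments (API token) and the chosen model name are assumptions about the existing client and test fixtures, not something these patches change.

    from aleph_alpha_client import Client
    from aleph_alpha_client.chat import ChatRequest, Message, Role

    # Client construction is unchanged by these patches; the token argument is an
    # assumption about the existing constructor, not part of this diff.
    client = Client(token="YOUR_API_TOKEN")

    model = "llama-3.1-70b-instruct"  # chat-capable model, as used by the chat_model_name fixture
    request = ChatRequest(
        model=model,
        messages=[Message(role=Role.User, content="Hello, how are you?")],
        maximum_tokens=64,
    )

    response = client.chat(request, model=model)
    print(response.message.role)     # Role.Assistant
    print(response.message.content)  # the assistant's reply
    print(response.finish_reason)

Note that the model name is currently passed both in the request body and as the `model` argument of `chat`, mirroring the tests in this diff.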
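
A second sketch covers the streaming path with the include_usage flag from the second patch. chat_with_streaming, StreamOptions, ChatStreamChunk, and Usage come from this diff; using AsyncClient as an async context manager with a token argument is an assumption carried over from the existing client.

    import asyncio

    from aleph_alpha_client import AsyncClient
    from aleph_alpha_client.chat import (
        ChatRequest,
        ChatStreamChunk,
        Message,
        Role,
        StreamOptions,
        Usage,
    )


    async def main() -> None:
        # Async context-manager usage and the token argument are assumed from the
        # existing AsyncClient; only the chat methods are added by these patches.
        async with AsyncClient(token="YOUR_API_TOKEN") as client:
            model = "llama-3.1-70b-instruct"
            request = ChatRequest(
                model=model,
                messages=[Message(role=Role.User, content="Hello, how are you?")],
                # Request an extra usage-only chunk, streamed before the final [DONE] message.
                stream_options=StreamOptions(include_usage=True),
            )
            async for item in client.chat_with_streaming(request, model=model):
                if isinstance(item, ChatStreamChunk):
                    # The first chunk carries the assistant role; later chunks carry content deltas.
                    print(item.content, end="", flush=True)
                elif isinstance(item, Usage):
                    # Emitted last when include_usage is set; its choices array is empty.
                    print(
                        f"\n[tokens: prompt={item.prompt_tokens}, "
                        f"completion={item.completion_tokens}, total={item.total_tokens}]"
                    )


    asyncio.run(main())

Without stream_options, the generator yields only ChatStreamChunk items, so existing consumers of chat_with_streaming are unaffected by the second patch.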