feat: add chat endpoint to sync and async client
while this commit supports streaming for the chat endpoint, it only offers a simplified version and does not forward all parameters
moldhouse committed Oct 22, 2024
1 parent 7c07328 commit d46ab1b
Showing 4 changed files with 280 additions and 2 deletions.
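
For orientation, here is a minimal sketch of how the new endpoint might be called once this commit lands. The token value and the exact model name are placeholders; only `Client.chat`, `ChatRequest`, `Message`, and `Role` come from this diff.

from aleph_alpha_client import Client
from aleph_alpha_client.chat import ChatRequest, Message, Role

# Placeholder credentials; configure token/host for your own setup.
client = Client(token="AA_TOKEN")

request = ChatRequest(
    messages=[Message(role=Role.User, content="Hello, how are you?")],
    model="llama-3.1-70b-instruct",
)
response = client.chat(request, model="llama-3.1-70b-instruct")
print(response.message.content)
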
109 changes: 107 additions & 2 deletions aleph_alpha_client/aleph_alpha_client.py
@@ -36,9 +36,9 @@
CompletionRequest,
CompletionResponse,
CompletionResponseStreamItem,
StreamChunk,
stream_item_from_json,
)
from aleph_alpha_client.chat import ChatRequest, ChatResponse, ChatStreamChunk
from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse
from aleph_alpha_client.tokenization import TokenizationRequest, TokenizationResponse
from aleph_alpha_client.detokenization import (
@@ -99,6 +99,7 @@ def _check_api_version(version_str: str):

AnyRequest = Union[
CompletionRequest,
ChatRequest,
EmbeddingRequest,
EvaluationRequest,
TokenizationRequest,
@@ -302,6 +303,34 @@ def complete(
response = self._post_request("complete", request, model)
return CompletionResponse.from_json(response)

def chat(
self,
request: ChatRequest,
model: str,
) -> ChatResponse:
"""Chat with a model.
Parameters:
request (ChatRequest, required):
Parameters for the requested chat.
model (string, required):
Name of the model to use. A model name refers to a model architecture (among other things, the number of parameters).
The latest version of the model is always used.
Examples:
>>> # create a chat request
>>> request = ChatRequest(
messages=[Message(role="user", content="Hello, how are you?")],
model=model_name,
)
>>>
>>> # chat with the model
>>> result = client.chat(request, model=model_name)
"""
response = self._post_request("chat/completions", request, model)
return ChatResponse.from_json(response)

def tokenize(
self,
request: TokenizationRequest,
@@ -797,7 +826,11 @@ async def _post_request_with_streaming(
f"Stream item did not start with `{self.SSE_DATA_PREFIX}`. Was `{stream_item_as_str}"
)

yield json.loads(stream_item_as_str[len(self.SSE_DATA_PREFIX) :])
payload = stream_item_as_str[len(self.SSE_DATA_PREFIX) :]
if payload == "[DONE]":
continue

yield json.loads(payload)

def _build_query_parameters(self) -> Mapping[str, str]:
return {
@@ -864,6 +897,38 @@ async def complete(
)
return CompletionResponse.from_json(response)

async def chat(
self,
request: ChatRequest,
model: str,
) -> ChatResponse:
"""Chat with a model.
Parameters:
request (ChatRequest, required):
Parameters for the requested chat.
model (string, required):
Name of the model to use. A model name refers to a model architecture (among other things, the number of parameters).
The latest version of the model is always used.
Examples:
>>> # create a chat request
>>> request = ChatRequest(
messages=[Message(role="user", content="Hello, how are you?")],
model=model_name,
)
>>>
>>> # chat with the model
>>> result = await client.chat(request, model=model_name)
"""
response = await self._post_request(
"chat/completions",
request,
model,
)
return ChatResponse.from_json(response)

async def complete_with_streaming(
self,
request: CompletionRequest,
@@ -905,6 +970,46 @@ async def complete_with_streaming(
):
yield stream_item_from_json(stream_item_json)

async def chat_with_streaming(
self,
request: ChatRequest,
model: str,
) -> AsyncGenerator[ChatStreamChunk, None]:
"""Generates streamed chat completions.
The first yielded chunk contains the role, while subsequent chunks only contain the content delta.
Parameters:
request (ChatRequest, required):
Parameters for the requested chat.
model (string, required):
Name of the model to use. A model name refers to a model architecture (among other things, the number of parameters).
The latest version of the model is always used.
Examples:
>>> # create a chat request
>>> request = ChatRequest(
messages=[Message(role="user", content="Hello, how are you?")],
model=model_name,
)
>>>
>>> # chat with the model
>>> result = client.chat_with_streaming(request, model=model_name)
>>>
>>> # consume the chat stream
>>> async for stream_item in result:
>>> do_something_with(stream_item)
"""
async for stream_item_json in self._post_request_with_streaming(
"chat/completions",
request,
model,
):
chunk = ChatStreamChunk.from_json(stream_item_json)
if chunk is not None:
yield chunk

async def tokenize(
self,
request: TokenizationRequest,
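
As an aside, the `[DONE]` handling added to `_post_request_with_streaming` above deals with the SSE termination sentinel that OpenAI-compatible chat endpoints send: it is not JSON and must be skipped rather than parsed. A standalone sketch of that behavior follows; the sample payload shapes are illustrative assumptions, not taken from this diff.

import json

SSE_DATA_PREFIX = "data: "

# Illustrative SSE lines; only the prefix handling and the [DONE] sentinel
# mirror the change above, the payload shape is an assumption.
sample_lines = [
    'data: {"choices": [{"delta": {"role": "assistant", "content": ""}}]}',
    'data: {"choices": [{"delta": {"content": "Hello"}}]}',
    "data: [DONE]",
]

for line in sample_lines:
    payload = line[len(SSE_DATA_PREFIX):]
    if payload == "[DONE]":
        continue  # end-of-stream marker, not valid JSON
    print(json.loads(payload))
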
110 changes: 110 additions & 0 deletions aleph_alpha_client/chat.py
@@ -0,0 +1,110 @@
from dataclasses import dataclass, asdict
from typing import List, Optional, Mapping, Any, Dict
from enum import Enum

class Role(str, Enum):
"""A role used for a message in a chat."""
User = "user"
Assistant = "assistant"
System = "system"


@dataclass(frozen=True)
class Message:
"""
Describes a message in a chat.
Parameters:
role (Role, required):
The role of the message.
content (str, required):
The content of the message.
"""
role: Role
content: str

def to_json(self) -> Mapping[str, Any]:
return asdict(self)

@staticmethod
def from_json(json: Dict[str, Any]) -> "Message":
return Message(
role=Role(json["role"]),
content=json["content"],
)


@dataclass(frozen=True)
class ChatRequest:
"""
Describes a chat request.
Only supports a subset of the parameters of `CompletionRequest` for simplicity.
See `CompletionRequest` for documentation on the parameters.
"""
model: str
messages: List[Message]
maximum_tokens: Optional[int] = None
temperature: float = 0.0
top_k: int = 0
top_p: float = 0.0

def to_json(self) -> Mapping[str, Any]:
payload = {k: v for k, v in asdict(self).items() if v is not None}
payload["messages"] = [message.to_json() for message in self.messages]
return payload


@dataclass(frozen=True)
class ChatResponse:
"""
A simplified version of the chat response.
As the `ChatRequest` does not support the `n` parameter (allowing for multiple return values),
the `ChatResponse` assumes there to be only one choice.
"""
finish_reason: str
message: Message

@staticmethod
def from_json(json: Dict[str, Any]) -> "ChatResponse":
first_choice = json["choices"][0]
return ChatResponse(
finish_reason=first_choice["finish_reason"],
message=Message.from_json(first_choice["message"]),
)


@dataclass(frozen=True)
class ChatStreamChunk:
"""
A streamed chat completion chunk.
Parameters:
content (str, required):
The content of the current chat completion. Will be empty for the first chunk of every completion stream and non-empty for the remaining chunks.
role (Role, optional):
The role of the current chat completion. Will be assistant for the first chunk of every completion stream and missing for the remaining chunks.
"""
content: str
role: Optional[Role]

@staticmethod
def from_json(json: Dict[str, Any]) -> Optional["ChatStreamChunk"]:
"""
Returns a ChatStreamChunk if the chunk contains a message, otherwise None.
"""
if not (delta := json["choices"][0]["delta"]):
return None

return ChatStreamChunk(
content=delta["content"],
role=Role(delta.get("role")) if delta.get("role") else None,
)


@dataclass(frozen=True)
class StreamSummary:
pass
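
To make the round trip concrete, here is a small sketch of how these dataclasses serialize and parse. The response JSON below mimics an OpenAI-style `choices` payload; its exact shape is an assumption beyond what `from_json` itself requires.

from aleph_alpha_client.chat import (
    ChatRequest,
    ChatResponse,
    ChatStreamChunk,
    Message,
    Role,
)

request = ChatRequest(
    model="llama-3.1-70b-instruct",
    messages=[Message(role=Role.User, content="Hi")],
    maximum_tokens=64,
)
payload = request.to_json()
assert payload["messages"][0]["content"] == "Hi"

# Assumed non-streaming response shape, matching what ChatResponse.from_json reads.
response = ChatResponse.from_json(
    {
        "choices": [
            {
                "finish_reason": "stop",
                "message": {"role": "assistant", "content": "Hello!"},
            }
        ]
    }
)
assert response.message.role == Role.Assistant

# The first streamed chunk carries the role; a chunk with an empty delta parses to None.
first = ChatStreamChunk.from_json(
    {"choices": [{"delta": {"role": "assistant", "content": ""}}]}
)
assert first is not None and first.role == Role.Assistant
assert ChatStreamChunk.from_json({"choices": [{"delta": {}}]}) is None
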
5 changes: 5 additions & 0 deletions tests/common.py
@@ -32,6 +32,11 @@ def model_name() -> str:
return "luminous-base"


@pytest.fixture(scope="session")
def chat_model_name() -> str:
return "llama-3.1-70b-instruct"


@pytest.fixture(scope="session")
def prompt_image() -> Image:
image_source_path = Path(__file__).parent / "dog-and-cat-cover.jpg"
58 changes: 58 additions & 0 deletions tests/test_chat.py
@@ -0,0 +1,58 @@
import pytest

from aleph_alpha_client import AsyncClient, Client
from aleph_alpha_client.chat import ChatRequest, Message, Role
from tests.common import async_client, sync_client, model_name, chat_model_name


@pytest.mark.system_test
async def test_can_not_chat_with_all_models(async_client: AsyncClient, model_name: str):
request = ChatRequest(
messages=[Message(role=Role.User, content="Hello, how are you?")],
model=model_name,
)

with pytest.raises(ValueError):
await async_client.chat(request, model=model_name)


@pytest.mark.system_test
def test_can_chat_with_chat_model(sync_client: Client, chat_model_name: str):
request = ChatRequest(
messages=[Message(role=Role.User, content="Hello, how are you?")],
model=chat_model_name,
)

response = sync_client.chat(request, model=chat_model_name)
assert response.message.role == Role.Assistant
assert response.message.content is not None


@pytest.mark.system_test
async def test_can_chat_with_async_client(async_client: AsyncClient, chat_model_name: str):
system_msg = Message(role=Role.System, content="You are a helpful assistant.")
user_msg = Message(role=Role.User, content="Hello, how are you?")
request = ChatRequest(
messages=[system_msg, user_msg],
model=chat_model_name,
)

response = await async_client.chat(request, model=chat_model_name)
assert response.message.role == Role.Assistant
assert response.message.content is not None


@pytest.mark.system_test
async def test_can_chat_with_streaming_support(async_client: AsyncClient, chat_model_name: str):
request = ChatRequest(
messages=[Message(role=Role.User, content="Hello, how are you?")],
model=chat_model_name,
)

stream_items = [
stream_item async for stream_item in async_client.chat_with_streaming(request, model=chat_model_name)
]

assert len(stream_items) >= 3
assert stream_items[0].role is not None
assert all(item.content is not None for item in stream_items[1:])
