
Commit

Merge pull request #187 from Aleph-Alpha/chat
feat: add chat endpoint to sync and async client
ahartel authored Oct 30, 2024
2 parents 9b84b47 + 8dafe22 commit b628e63
Showing 4 changed files with 395 additions and 2 deletions.
109 changes: 107 additions & 2 deletions aleph_alpha_client/aleph_alpha_client.py
@@ -36,9 +36,9 @@
CompletionRequest,
CompletionResponse,
CompletionResponseStreamItem,
StreamChunk,
stream_item_from_json,
)
from aleph_alpha_client.chat import ChatRequest, ChatResponse, ChatStreamChunk, Usage, stream_chat_item_from_json
from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse
from aleph_alpha_client.tokenization import TokenizationRequest, TokenizationResponse
from aleph_alpha_client.detokenization import (
@@ -99,6 +99,7 @@ def _check_api_version(version_str: str):

AnyRequest = Union[
CompletionRequest,
ChatRequest,
EmbeddingRequest,
EvaluationRequest,
TokenizationRequest,
@@ -302,6 +303,34 @@ def complete(
response = self._post_request("complete", request, model)
return CompletionResponse.from_json(response)

def chat(
self,
request: ChatRequest,
model: str,
) -> ChatResponse:
"""Chat with a model.
Parameters:
request (ChatRequest, required):
Parameters for the requested chat.
model (string, required):
Name of the model to use. A model name refers to a model architecture (among other things, the number of parameters).
The latest version of the model is always used.
Examples:
>>> # create a chat request
>>> request = ChatRequest(
messages=[Message(role="user", content="Hello, how are you?")],
model=model_name,
)
>>>
>>> # chat with the model
>>> result = client.chat(request, model=model_name)
"""
response = self._post_request("chat/completions", request, model)
return ChatResponse.from_json(response)
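A minimal usage sketch for the new synchronous endpoint; the token, host, and model name below are placeholders, and `Client`, `ChatRequest`, `Message`, and `Role` are the classes touched by this PR:

from aleph_alpha_client import Client
from aleph_alpha_client.chat import ChatRequest, Message, Role

# Placeholder credentials and model name -- substitute your own.
client = Client(token="AA_TOKEN", host="https://api.aleph-alpha.com")
request = ChatRequest(
    model="llama-3.1-70b-instruct",
    messages=[Message(role=Role.User, content="Hello, how are you?")],
    maximum_tokens=64,
)
response = client.chat(request, model="llama-3.1-70b-instruct")
print(response.message.content)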

def tokenize(
self,
request: TokenizationRequest,
@@ -797,7 +826,11 @@ async def _post_request_with_streaming(
f"Stream item did not start with `{self.SSE_DATA_PREFIX}`. Was `{stream_item_as_str}"
)

yield json.loads(stream_item_as_str[len(self.SSE_DATA_PREFIX) :])
payload = stream_item_as_str[len(self.SSE_DATA_PREFIX) :]
if payload == "[DONE]":
continue

yield json.loads(payload)
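For reference, a small self-contained sketch of the server-sent-event lines this parser now tolerates; the JSON payloads are illustrative, but the `data: ` prefix and the `[DONE]` sentinel match the handling above:

import json

SSE_DATA_PREFIX = "data: "
# Illustrative SSE lines: JSON chunks followed by the end-of-stream sentinel.
lines = [
    'data: {"choices": [{"delta": {"role": "assistant", "content": ""}}]}',
    'data: {"choices": [{"delta": {"content": "Hi!"}}]}',
    "data: [DONE]",
]
for line in lines:
    payload = line[len(SSE_DATA_PREFIX):]
    if payload == "[DONE]":
        continue  # sentinel terminating the stream, not JSON
    print(json.loads(payload))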

def _build_query_parameters(self) -> Mapping[str, str]:
return {
@@ -864,6 +897,38 @@ async def complete(
)
return CompletionResponse.from_json(response)

async def chat(
self,
request: ChatRequest,
model: str,
) -> ChatResponse:
"""Chat with a model.
Parameters:
request (ChatRequest, required):
Parameters for the requested chat.
model (string, required):
Name of the model to use. A model name refers to a model architecture (among other things, the number of parameters).
The latest version of the model is always used.
Examples:
>>> # create a chat request
>>> request = ChatRequest(
messages=[Message(role="user", content="Hello, how are you?")],
model=model_name,
)
>>>
>>> # chat with the model
>>> result = await client.chat(request, model=model_name)
"""
response = await self._post_request(
"chat/completions",
request,
model,
)
return ChatResponse.from_json(response)
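A sketch of the asynchronous variant; token, host, and model name are placeholders, and `AsyncClient` is assumed to be used as an async context manager, as elsewhere in this client:

import asyncio

from aleph_alpha_client import AsyncClient
from aleph_alpha_client.chat import ChatRequest, Message, Role


async def main() -> None:
    # Placeholder credentials -- substitute your own.
    async with AsyncClient(token="AA_TOKEN", host="https://api.aleph-alpha.com") as client:
        request = ChatRequest(
            model="llama-3.1-70b-instruct",
            messages=[Message(role=Role.User, content="Hello, how are you?")],
        )
        response = await client.chat(request, model="llama-3.1-70b-instruct")
        print(response.message.content)


asyncio.run(main())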

async def complete_with_streaming(
self,
request: CompletionRequest,
@@ -905,6 +970,46 @@ async def complete_with_streaming(
):
yield stream_item_from_json(stream_item_json)

async def chat_with_streaming(
self,
request: ChatRequest,
model: str,
) -> AsyncGenerator[Union[ChatStreamChunk, Usage], None]:
"""Generates streamed chat completions.
The first yielded chunk contains the role, while subsequent chunks only contain the content delta.
Parameters:
request (ChatRequest, required):
Parameters for the requested chat.
model (string, required):
Name of the model to use. A model name refers to a model architecture (among other things, the number of parameters).
The latest version of the model is always used.
Examples:
>>> # create a chat request
>>> request = ChatRequest(
messages=[Message(role="user", content="Hello, how are you?")],
model=model_name,
)
>>>
>>> # chat with the model
>>> result = await client.chat_with_streaming(request, model=model_name)
>>>
>>> # consume the chat stream
>>> async for stream_item in result:
>>> do_something_with(stream_item)
"""
async for stream_item_json in self._post_request_with_streaming(
"chat/completions",
request,
model,
):
chunk = stream_chat_item_from_json(stream_item_json)
if chunk is not None:
yield chunk
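A sketch of consuming the stream, including the optional usage-only item enabled via `stream_options`; credentials and model name are placeholders:

import asyncio

from aleph_alpha_client import AsyncClient
from aleph_alpha_client.chat import (
    ChatRequest,
    ChatStreamChunk,
    Message,
    Role,
    StreamOptions,
    Usage,
)


async def stream_chat() -> None:
    async with AsyncClient(token="AA_TOKEN") as client:  # placeholder token
        request = ChatRequest(
            model="llama-3.1-70b-instruct",
            messages=[Message(role=Role.User, content="Tell me a short joke.")],
            stream_options=StreamOptions(include_usage=True),
        )
        async for item in client.chat_with_streaming(request, model="llama-3.1-70b-instruct"):
            if isinstance(item, ChatStreamChunk):
                print(item.content, end="", flush=True)
            elif isinstance(item, Usage):
                print(f"\nTotal tokens: {item.total_tokens}")


asyncio.run(stream_chat())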

async def tokenize(
self,
request: TokenizationRequest,
153 changes: 153 additions & 0 deletions aleph_alpha_client/chat.py
@@ -0,0 +1,153 @@
from dataclasses import dataclass, asdict
from typing import List, Optional, Mapping, Any, Dict, Union
from enum import Enum


class Role(str, Enum):
"""A role used for a message in a chat."""
User = "user"
Assistant = "assistant"
System = "system"


@dataclass(frozen=True)
class Message:
"""
Describes a message in a chat.
Parameters:
role (Role, required):
The role of the message.
content (str, required):
The content of the message.
"""
role: Role
content: str

def to_json(self) -> Mapping[str, Any]:
return asdict(self)

@staticmethod
def from_json(json: Dict[str, Any]) -> "Message":
return Message(
role=Role(json["role"]),
content=json["content"],
)


@dataclass(frozen=True)
class StreamOptions:
"""
Additional options to affect the streaming behavior.
"""
# If set, an additional chunk will be streamed before the data: [DONE] message.
# The usage field on this chunk shows the token usage statistics for the entire
# request, and the choices field will always be an empty array.
include_usage: bool


@dataclass(frozen=True)
class ChatRequest:
"""
Describes a chat request.
Only supports a subset of the parameters of `CompletionRequest` for simplicity.
See `CompletionRequest` for documentation on the parameters.
"""
model: str
messages: List[Message]
maximum_tokens: Optional[int] = None
temperature: float = 0.0
top_k: int = 0
top_p: float = 0.0
stream_options: Optional[StreamOptions] = None

def to_json(self) -> Mapping[str, Any]:
payload = {k: v for k, v in asdict(self).items() if v is not None}
payload["messages"] = [message.to_json() for message in self.messages]
return payload
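For illustration, the payload produced by `to_json` for a small request, using the classes defined above (values are examples only; `None`-valued fields such as an unset `stream_options` are dropped, and `Role` is a `str` enum, so `"user"` is what goes over the wire):

request = ChatRequest(
    model="llama-3.1-70b-instruct",
    messages=[Message(role=Role.User, content="Hello")],
    maximum_tokens=64,
)
request.to_json()
# Roughly: {"model": "llama-3.1-70b-instruct",
#           "messages": [{"role": "user", "content": "Hello"}],
#           "maximum_tokens": 64, "temperature": 0.0, "top_k": 0, "top_p": 0.0}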


@dataclass(frozen=True)
class ChatResponse:
"""
A simplified version of the chat response.
As the `ChatRequest` does not support the `n` parameter (allowing for multiple return values),
the `ChatResponse` assumes there to be only one choice.
"""
finish_reason: str
message: Message

@staticmethod
def from_json(json: Dict[str, Any]) -> "ChatResponse":
first_choice = json["choices"][0]
return ChatResponse(
finish_reason=first_choice["finish_reason"],
message=Message.from_json(first_choice["message"]),
)



@dataclass(frozen=True)
class Usage:
"""
Usage statistics for the completion request.
When streaming is enabled, this field will be null by default.
To include an additional usage-only message in the response stream, set stream_options.include_usage to true.
"""
# Number of tokens in the generated completion.
completion_tokens: int

# Number of tokens in the prompt.
prompt_tokens: int

# Total number of tokens used in the request (prompt + completion).
total_tokens: int

@staticmethod
def from_json(json: Dict[str, Any]) -> "Usage":
return Usage(
completion_tokens=json["completion_tokens"],
prompt_tokens=json["prompt_tokens"],
total_tokens=json["total_tokens"]
)



@dataclass(frozen=True)
class ChatStreamChunk:
"""
A streamed chat completion chunk.
Parameters:
content (str, required):
The content of the current chat completion. Will be empty for the first chunk of every completion stream and non-empty for the remaining chunks.
role (Role, optional):
The role of the current chat completion. Will be assistant for the first chunk of every completion stream and missing for the remaining chunks.
"""
content: str
role: Optional[Role]

@staticmethod
def from_json(json: Dict[str, Any]) -> Optional["ChatStreamChunk"]:
"""
Returns a ChatStreamChunk if the chunk contains a message, otherwise None.
"""
if not (delta := json["choices"][0]["delta"]):
return None

return ChatStreamChunk(
content=delta["content"],
role=Role(delta.get("role")) if delta.get("role") else None,
)


def stream_chat_item_from_json(json: Dict[str, Any]) -> Union[Usage, ChatStreamChunk, None]:
if (usage := json.get("usage")) is not None:
return Usage.from_json(usage)

return ChatStreamChunk.from_json(json)
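A sketch of how raw stream items map through `stream_chat_item_from_json`: the first chunk carries the role, later chunks carry content deltas, an empty delta yields `None`, and a usage-only item (sent when `stream_options.include_usage` is set) yields `Usage`. The JSON values are illustrative:

from aleph_alpha_client.chat import ChatStreamChunk, Role, Usage, stream_chat_item_from_json

first = {"choices": [{"delta": {"role": "assistant", "content": ""}}]}
delta = {"choices": [{"delta": {"content": "Hello!"}}]}
empty = {"choices": [{"delta": {}}]}
usage = {"choices": [], "usage": {"completion_tokens": 9, "prompt_tokens": 5, "total_tokens": 14}}

assert stream_chat_item_from_json(first) == ChatStreamChunk(content="", role=Role.Assistant)
assert stream_chat_item_from_json(delta) == ChatStreamChunk(content="Hello!", role=None)
assert stream_chat_item_from_json(empty) is None
assert stream_chat_item_from_json(usage) == Usage(completion_tokens=9, prompt_tokens=5, total_tokens=14)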
5 changes: 5 additions & 0 deletions tests/common.py
@@ -32,6 +32,11 @@ def model_name() -> str:
return "luminous-base"


@pytest.fixture(scope="session")
def chat_model_name() -> str:
return "llama-3.1-70b-instruct"


@pytest.fixture(scope="session")
def prompt_image() -> Image:
image_source_path = Path(__file__).parent / "dog-and-cat-cover.jpg"