From 41f81d0092a5e1e71b5857bc0002582a24dc2d12 Mon Sep 17 00:00:00 2001 From: mmikita95 Date: Fri, 8 Nov 2024 17:25:22 +0400 Subject: [PATCH] chore: reintroducing `writerai.types.chat_chat_params.Message` --- src/writer/ai.py | 28 +++++----------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/src/writer/ai.py b/src/writer/ai.py index bb260a9ad..6ed7d86fe 100644 --- a/src/writer/ai.py +++ b/src/writer/ai.py @@ -38,6 +38,7 @@ from writerai.types.application_generate_content_params import Input from writerai.types.chat import ChoiceMessage, ChoiceMessageGraphData, ChoiceMessageToolCall from writerai.types.chat_chat_params import Message as WriterAIMessage +from writerai.types.chat_chat_params import MessageGraphData from writerai.types.chat_chat_params import ToolFunctionTool as SDKFunctionTool from writerai.types.chat_chat_params import ToolGraphTool as SDKGraphTool @@ -94,22 +95,6 @@ class FunctionTool(Tool): parameters: Dict[str, Dict[str, str]] -class PreparedAPIMessage(TypedDict, total=False): - role: Literal["user", "assistant", "system", "tool"] - - content: Union[str, None] - - name: Optional[str] - - tool_call_id: Optional[str] - - tool_calls: Optional[List[ChoiceMessageToolCall]] - - graph_data: Optional[ChoiceMessageGraphData] - - refusal: Optional[str] - - def create_function_tool( callable: Callable, name: str, @@ -1044,7 +1029,7 @@ def _clear_chunk_flag(chunk): updated_last_message |= clear_chunk @staticmethod - def _prepare_message(message: 'Conversation.Message') -> PreparedAPIMessage: + def _prepare_message(message: 'Conversation.Message') -> WriterAIMessage: """ Converts a message object stored in Conversation to a Writer AI SDK `Message` model, suitable for calls to API. @@ -1055,7 +1040,7 @@ def _prepare_message(message: 'Conversation.Message') -> PreparedAPIMessage: """ if not ("role" in message and "content" in message): raise ValueError("Improper message format") - sdk_message = PreparedAPIMessage( + sdk_message = WriterAIMessage( content=message["content"] or None, role=message["role"] ) @@ -1067,7 +1052,7 @@ def _prepare_message(message: 'Conversation.Message') -> PreparedAPIMessage: sdk_message["tool_calls"] = cast(list, msg_tool_calls) if msg_graph_data := message.get("graph_data"): sdk_message["graph_data"] = cast( - ChoiceMessageGraphData, + MessageGraphData, msg_graph_data ) if msg_refusal := message.get("refusal"): @@ -1350,13 +1335,10 @@ def _send_chat_request( a Stream or a Chat object. """ client = WriterAIManager.acquire_client() - prepared_messages = cast( - Iterable[WriterAIMessage], - [ + prepared_messages = [ self._prepare_message(message) for message in self.messages ] - ) logging.debug( "Attempting to request a message from LLM: " + f"prepared messages – {prepared_messages}, " +