diff --git a/autogen/agentchat/realtime_agent/oai_realtime_client.py b/autogen/agentchat/realtime_agent/oai_realtime_client.py index 902b5d543..de4945011 100644 --- a/autogen/agentchat/realtime_agent/oai_realtime_client.py +++ b/autogen/agentchat/realtime_agent/oai_realtime_client.py @@ -9,14 +9,14 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional import httpx -from asyncer import TaskGroup, create_task_group -from fastapi import WebSocket from openai import DEFAULT_MAX_RETRIES, NOT_GIVEN, AsyncOpenAI from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection from .realtime_client import Role if TYPE_CHECKING: + from fastapi.websockets import WebSocket + from .realtime_client import RealtimeClientProtocol __all__ = ["OpenAIRealtimeClient", "Role"] @@ -172,13 +172,6 @@ async def read_events(self) -> AsyncGenerator[dict[str, Any], None]: self._connection = None -# needed for mypy to check if OpenAIRealtimeClient implements RealtimeClientProtocol -if TYPE_CHECKING: - _client: RealtimeClientProtocol = OpenAIRealtimeClient( - llm_config={}, voice="alloy", system_message="You are a helpful AI voice assistant." - ) - - class OpenAIRealtimeWebRTCClient: """(Experimental) Client for OpenAI Realtime API that uses WebRTC protocol.""" @@ -188,7 +181,7 @@ def __init__( llm_config: dict[str, Any], voice: str, system_message: str, - websocket: Optional[WebSocket] = None, + websocket: "WebSocket", logger: Optional[Logger] = None, ) -> None: """(Experimental) Client for OpenAI Realtime API. @@ -200,7 +193,7 @@ def __init__( self._voice = voice self._system_message = system_message self._logger = logger - self._websocket: Optional[WebSocket] = websocket + self._websocket = websocket config = llm_config["config_list"][0] self._model: str = config["model"] @@ -219,18 +212,17 @@ async def send_function_result(self, call_id: str, result: str) -> None: call_id (str): The ID of the function call. result (str): The result of the function call. """ - if self._websocket is not None: - await self._websocket.send_json( - { - "type": "conversation.item.create", - "item": { - "type": "function_call_output", - "call_id": call_id, - "output": result, - }, - } - ) - await self._websocket.send_json({"type": "response.create"}) + await self._websocket.send_json( + { + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": call_id, + "output": result, + }, + } + ) + await self._websocket.send_json({"type": "response.create"}) async def send_text(self, *, role: Role, text: str) -> None: """Send a text message to the OpenAI Realtime API. @@ -240,14 +232,13 @@ async def send_text(self, *, role: Role, text: str) -> None: text (str): The text of the message. """ # await self.connection.response.cancel() #why is this here? - if self._websocket is not None: - await self._websocket.send_json( - { - "type": "connection.conversation.item.create", - "item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]}, - } - ) - # await self.connection.response.create() + await self._websocket.send_json( + { + "type": "connection.conversation.item.create", + "item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]}, + } + ) + # await self.connection.response.create() async def send_audio(self, audio: str) -> None: """Send audio to the OpenAI Realtime API. @@ -255,9 +246,7 @@ async def send_audio(self, audio: str) -> None: Args: audio (str): The audio to send. """ - ... - if self._websocket is not None: - await self._websocket.send_json({"type": "input_audio_buffer.append", "audio": audio}) + await self._websocket.send_json({"type": "input_audio_buffer.append", "audio": audio}) async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None: """Truncate audio in the OpenAI Realtime API. @@ -267,15 +256,14 @@ async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: s content_index (int): The index of the content to truncate. item_id (str): The ID of the item to truncate. """ - if self._websocket is not None: - await self._websocket.send_json( - { - "type": "conversation.item.truncate", - "content_index": content_index, - "item_id": item_id, - "audio_end_ms": audio_end_ms, - } - ) + await self._websocket.send_json( + { + "type": "conversation.item.truncate", + "content_index": content_index, + "item_id": item_id, + "audio_end_ms": audio_end_ms, + } + ) async def session_update(self, session_options: dict[str, Any]) -> None: """Send a session update to the OpenAI Realtime API. @@ -290,9 +278,8 @@ async def session_update(self, session_options: dict[str, Any]) -> None: logger = self.logger logger.info(f"Sending session update: {session_options}") # await self.connection.session.update(session=session_options) # type: ignore[arg-type] - if self._websocket is not None: - await self._websocket.send_json({"type": "session.update", "session": session_options}) - logger.info("Sending session update finished") + await self._websocket.send_json({"type": "session.update", "session": session_options}) + logger.info("Sending session update finished") async def _initialize_session(self) -> None: """Control initial session with OpenAI.""" @@ -344,24 +331,25 @@ async def read_events(self) -> AsyncGenerator[dict[str, Any], None]: do not hold connection to OpenAI. Instead we read messages from the websocket, and javascript client on the other side of the websocket that is connected to OpenAI is relaying events to us. """ - try: - logger = self.logger - if self._websocket is not None: - while True: - try: - messageJSON = await self._websocket.receive_text() - message = json.loads(messageJSON) - if "function" in message["type"]: - logger.info("Received function message", message) - yield message - except Exception: - break - finally: - self._websocket = None + logger = self.logger + while True: + try: + messageJSON = await self._websocket.receive_text() + message = json.loads(messageJSON) + if "function" in message["type"]: + logger.info("Received function message", message) + yield message + except Exception: + break # needed for mypy to check if OpenAIRealtimeWebRTCClient implements RealtimeClientProtocol if TYPE_CHECKING: - _rtc_client: RealtimeClientProtocol = OpenAIRealtimeWebRTCClient( - llm_config={}, voice="alloy", system_message="You are a helpful AI voice assistant.", websocket=None + _client: RealtimeClientProtocol = OpenAIRealtimeClient( + llm_config={}, voice="alloy", system_message="You are a helpful AI voice assistant." ) + + def _rtc_client(websocket: "WebSocket") -> RealtimeClientProtocol: + return OpenAIRealtimeWebRTCClient( + llm_config={}, voice="alloy", system_message="You are a helpful AI voice assistant.", websocket=websocket + )