Skip to content

Commit

Permalink
polishing
Browse files Browse the repository at this point in the history
  • Loading branch information
davorrunje committed Jan 8, 2025
1 parent 739ac00 commit 06da9de
Showing 1 changed file with 50 additions and 62 deletions.
112 changes: 50 additions & 62 deletions autogen/agentchat/realtime_agent/oai_realtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional

import httpx
from asyncer import TaskGroup, create_task_group
from fastapi import WebSocket
from openai import DEFAULT_MAX_RETRIES, NOT_GIVEN, AsyncOpenAI
from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection

from .realtime_client import Role

if TYPE_CHECKING:
from fastapi.websockets import WebSocket

from .realtime_client import RealtimeClientProtocol

__all__ = ["OpenAIRealtimeClient", "Role"]
Expand Down Expand Up @@ -172,13 +172,6 @@ async def read_events(self) -> AsyncGenerator[dict[str, Any], None]:
self._connection = None


# needed for mypy to check if OpenAIRealtimeClient implements RealtimeClientProtocol
if TYPE_CHECKING:
_client: RealtimeClientProtocol = OpenAIRealtimeClient(
llm_config={}, voice="alloy", system_message="You are a helpful AI voice assistant."
)


class OpenAIRealtimeWebRTCClient:
"""(Experimental) Client for OpenAI Realtime API that uses WebRTC protocol."""

Expand All @@ -188,7 +181,7 @@ def __init__(
llm_config: dict[str, Any],
voice: str,
system_message: str,
websocket: Optional[WebSocket] = None,
websocket: "WebSocket",
logger: Optional[Logger] = None,
) -> None:
"""(Experimental) Client for OpenAI Realtime API.
Expand All @@ -200,7 +193,7 @@ def __init__(
self._voice = voice
self._system_message = system_message
self._logger = logger
self._websocket: Optional[WebSocket] = websocket
self._websocket = websocket

config = llm_config["config_list"][0]
self._model: str = config["model"]
Expand All @@ -219,18 +212,17 @@ async def send_function_result(self, call_id: str, result: str) -> None:
call_id (str): The ID of the function call.
result (str): The result of the function call.
"""
if self._websocket is not None:
await self._websocket.send_json(
{
"type": "conversation.item.create",
"item": {
"type": "function_call_output",
"call_id": call_id,
"output": result,
},
}
)
await self._websocket.send_json({"type": "response.create"})
await self._websocket.send_json(
{
"type": "conversation.item.create",
"item": {
"type": "function_call_output",
"call_id": call_id,
"output": result,
},
}
)
await self._websocket.send_json({"type": "response.create"})

async def send_text(self, *, role: Role, text: str) -> None:
"""Send a text message to the OpenAI Realtime API.
Expand All @@ -240,24 +232,21 @@ async def send_text(self, *, role: Role, text: str) -> None:
text (str): The text of the message.
"""
# await self.connection.response.cancel() #why is this here?
if self._websocket is not None:
await self._websocket.send_json(
{
"type": "connection.conversation.item.create",
"item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]},
}
)
# await self.connection.response.create()
await self._websocket.send_json(
{
"type": "connection.conversation.item.create",
"item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]},
}
)
# await self.connection.response.create()

async def send_audio(self, audio: str) -> None:
"""Send audio to the OpenAI Realtime API.
Args:
audio (str): The audio to send.
"""
...
if self._websocket is not None:
await self._websocket.send_json({"type": "input_audio_buffer.append", "audio": audio})
await self._websocket.send_json({"type": "input_audio_buffer.append", "audio": audio})

async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
"""Truncate audio in the OpenAI Realtime API.
Expand All @@ -267,15 +256,14 @@ async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: s
content_index (int): The index of the content to truncate.
item_id (str): The ID of the item to truncate.
"""
if self._websocket is not None:
await self._websocket.send_json(
{
"type": "conversation.item.truncate",
"content_index": content_index,
"item_id": item_id,
"audio_end_ms": audio_end_ms,
}
)
await self._websocket.send_json(
{
"type": "conversation.item.truncate",
"content_index": content_index,
"item_id": item_id,
"audio_end_ms": audio_end_ms,
}
)

async def session_update(self, session_options: dict[str, Any]) -> None:
"""Send a session update to the OpenAI Realtime API.
Expand All @@ -290,9 +278,8 @@ async def session_update(self, session_options: dict[str, Any]) -> None:
logger = self.logger
logger.info(f"Sending session update: {session_options}")
# await self.connection.session.update(session=session_options) # type: ignore[arg-type]
if self._websocket is not None:
await self._websocket.send_json({"type": "session.update", "session": session_options})
logger.info("Sending session update finished")
await self._websocket.send_json({"type": "session.update", "session": session_options})
logger.info("Sending session update finished")

async def _initialize_session(self) -> None:
"""Control initial session with OpenAI."""
Expand Down Expand Up @@ -344,24 +331,25 @@ async def read_events(self) -> AsyncGenerator[dict[str, Any], None]:
do not hold connection to OpenAI. Instead we read messages from the websocket, and javascript
client on the other side of the websocket that is connected to OpenAI is relaying events to us.
"""
try:
logger = self.logger
if self._websocket is not None:
while True:
try:
messageJSON = await self._websocket.receive_text()
message = json.loads(messageJSON)
if "function" in message["type"]:
logger.info("Received function message", message)
yield message
except Exception:
break
finally:
self._websocket = None
logger = self.logger
while True:
try:
messageJSON = await self._websocket.receive_text()
message = json.loads(messageJSON)
if "function" in message["type"]:
logger.info("Received function message", message)
yield message
except Exception:
break


# needed for mypy to check if OpenAIRealtimeWebRTCClient implements RealtimeClientProtocol
if TYPE_CHECKING:
_rtc_client: RealtimeClientProtocol = OpenAIRealtimeWebRTCClient(
llm_config={}, voice="alloy", system_message="You are a helpful AI voice assistant.", websocket=None
_client: RealtimeClientProtocol = OpenAIRealtimeClient(
llm_config={}, voice="alloy", system_message="You are a helpful AI voice assistant."
)

def _rtc_client(websocket: "WebSocket") -> RealtimeClientProtocol:
return OpenAIRealtimeWebRTCClient(
llm_config={}, voice="alloy", system_message="You are a helpful AI voice assistant.", websocket=websocket
)

0 comments on commit 06da9de

Please sign in to comment.