diff --git a/autogen/agentchat/realtime_agent/oai_realtime_client.py b/autogen/agentchat/realtime_agent/oai_realtime_client.py
index d074c20eb8..6a104a5670 100644
--- a/autogen/agentchat/realtime_agent/oai_realtime_client.py
+++ b/autogen/agentchat/realtime_agent/oai_realtime_client.py
@@ -2,17 +2,21 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+import asyncio
+import json
 from contextlib import asynccontextmanager
 from logging import Logger, getLogger
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional

-from asyncer import TaskGroup, create_task_group
+import httpx
 from openai import DEFAULT_MAX_RETRIES, NOT_GIVEN, AsyncOpenAI
 from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection

 from .realtime_client import Role

 if TYPE_CHECKING:
+    from fastapi.websockets import WebSocket
+
     from .realtime_client import RealtimeClientProtocol

-__all__ = ["OpenAIRealtimeClient", "Role"]
+__all__ = ["OpenAIRealtimeClient", "OpenAIRealtimeWebRTCClient", "Role"]
@@ -168,8 +172,184 @@ async def read_events(self) -> AsyncGenerator[dict[str, Any], None]:
             self._connection = None


-# needed for mypy to check if OpenAIRealtimeClient implements RealtimeClientProtocol
+class OpenAIRealtimeWebRTCClient:
+    """(Experimental) Client for the OpenAI Realtime API that uses the WebRTC protocol."""
+
+    def __init__(
+        self,
+        *,
+        llm_config: dict[str, Any],
+        voice: str,
+        system_message: str,
+        websocket: "WebSocket",
+        logger: Optional[Logger] = None,
+    ) -> None:
+        """(Experimental) Client for the OpenAI Realtime API.
+
+        Args:
+            llm_config (dict[str, Any]): The config for the client.
+            voice (str): The voice to use for the session.
+            system_message (str): The system message to initialize the session with.
+            websocket (WebSocket): The websocket connected to the WebRTC JavaScript client.
+            logger (Optional[Logger]): The logger to use.
+        """
+        self._llm_config = llm_config
+        self._voice = voice
+        self._system_message = system_message
+        self._logger = logger
+        self._websocket = websocket
+
+        config = llm_config["config_list"][0]
+        self._model: str = config["model"]
+        self._temperature: float = llm_config.get("temperature", 0.8)  # type: ignore[union-attr]
+        self._config = config
+
+    @property
+    def logger(self) -> Logger:
+        """Get the logger for the OpenAI Realtime API."""
+        return self._logger or global_logger
+
+    async def send_function_result(self, call_id: str, result: str) -> None:
+        """Send the result of a function call to the OpenAI Realtime API.
+
+        Args:
+            call_id (str): The ID of the function call.
+            result (str): The result of the function call.
+        """
+        await self._websocket.send_json(
+            {
+                "type": "conversation.item.create",
+                "item": {
+                    "type": "function_call_output",
+                    "call_id": call_id,
+                    "output": result,
+                },
+            }
+        )
+        await self._websocket.send_json({"type": "response.create"})
+
+    async def send_text(self, *, role: Role, text: str) -> None:
+        """Send a text message to the OpenAI Realtime API.
+
+        Args:
+            role (Role): The role of the message.
+            text (str): The text of the message.
+        """
+        await self._websocket.send_json(
+            {
+                "type": "conversation.item.create",
+                "item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]},
+            }
+        )
+
+    async def send_audio(self, audio: str) -> None:
+        """Send audio to the OpenAI Realtime API.
+
+        Args:
+            audio (str): The audio to send.
+        """
+        await self._websocket.send_json({"type": "input_audio_buffer.append", "audio": audio})
+
+    async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
+        """Truncate audio in the OpenAI Realtime API.
+
+        Args:
+            audio_end_ms (int): The point, in milliseconds, at which to truncate the audio.
+            content_index (int): The index of the content to truncate.
+            item_id (str): The ID of the item to truncate.
+        """
+        await self._websocket.send_json(
+            {
+                "type": "conversation.item.truncate",
+                "content_index": content_index,
+                "item_id": item_id,
+                "audio_end_ms": audio_end_ms,
+            }
+        )
+
+    async def session_update(self, session_options: dict[str, Any]) -> None:
+        """Send a session update to the OpenAI Realtime API.
+
+        In the case of WebRTC we cannot send the update directly. Instead we send it
+        to the JavaScript client over the websocket and rely on it to forward the
+        session update to OpenAI.
+
+        Args:
+            session_options (dict[str, Any]): The session options to update.
+        """
+        logger = self.logger
+        logger.info(f"Sending session update: {session_options}")
+        await self._websocket.send_json({"type": "session.update", "session": session_options})
+        logger.info("Sending session update finished")
+
+    async def _initialize_session(self) -> None:
+        """Control the initial session with OpenAI."""
+        session_update = {
+            "turn_detection": {"type": "server_vad"},
+            "voice": self._voice,
+            "instructions": self._system_message,
+            "modalities": ["audio", "text"],
+            "temperature": self._temperature,
+        }
+        await self.session_update(session_options=session_update)
+
+    @asynccontextmanager
+    async def connect(self) -> AsyncGenerator[None, None]:
+        """Connect to the OpenAI Realtime API.
+
+        In the case of WebRTC we pass the connection information over the
+        websocket, so that the JavaScript client on the other end of the
+        websocket opens the actual connection to OpenAI.
+        """
+        try:
+            url = "https://api.openai.com/v1/realtime/sessions"
+            api_key = self._config.get("api_key", None)
+            headers = {
+                "Authorization": f"Bearer {api_key}",
+                "Content-Type": "application/json",
+            }
+            data = {
+                "model": self._model,
+                "voice": self._voice,
+            }
+            async with httpx.AsyncClient() as client:
+                response = await client.post(url, headers=headers, json=data)
+                response.raise_for_status()
+                json_data = response.json()
+                json_data["model"] = self._model
+            if self._websocket is not None:
+                await self._websocket.send_json({"type": "ag2.init", "config": json_data})
+                # Crude wait for the JavaScript client to establish the WebRTC data
+                # channel; otherwise the session update below would be dropped.
+                await asyncio.sleep(10)
+                await self._initialize_session()
+            yield
+        finally:
+            pass
+
+    async def read_events(self) -> AsyncGenerator[dict[str, Any], None]:
+        """Read messages from the OpenAI Realtime API.
+
+        In the case of WebRTC we do not read events from OpenAI directly, since we
+        do not hold the connection to OpenAI ourselves. Instead we read messages from
+        the websocket; the JavaScript client on the other side, which is connected to
+        OpenAI, relays the events to us.
+        """
+        logger = self.logger
+        while True:
+            try:
+                message_json = await self._websocket.receive_text()
+                message = json.loads(message_json)
+                if "function" in message["type"]:
+                    logger.info(f"Received function message {message}")
+                yield message
+            except Exception:
+                break
+
+
+# needed for mypy to check if the clients implement RealtimeClientProtocol
 if TYPE_CHECKING:
     _client: RealtimeClientProtocol = OpenAIRealtimeClient(
         llm_config={}, voice="alloy", system_message="You are a helpful AI voice assistant."
     )
+
+    def _rtc_client(websocket: "WebSocket") -> RealtimeClientProtocol:
+        return OpenAIRealtimeWebRTCClient(
+            llm_config={}, voice="alloy", system_message="You are a helpful AI voice assistant.", websocket=websocket
+        )
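For orientation, the `ag2.init` message that `connect()` sends to the browser wraps the raw response of the `POST /v1/realtime/sessions` call. Only `config["model"]` and `config["client_secret"]["value"]` are actually consumed by the JavaScript client; the sketch below shows a plausible payload, with illustrative values:

```python
# Hypothetical "ag2.init" payload, as forwarded by OpenAIRealtimeWebRTCClient.connect().
# The browser uses config["client_secret"]["value"] as an ephemeral bearer token for its
# SDP offer and config["model"] to select the realtime endpoint; the remaining fields
# follow the session-creation response and are illustrative.
example_ag2_init = {
    "type": "ag2.init",
    "config": {
        "model": "gpt-4o-mini-realtime-preview",  # overwritten from llm_config by connect()
        "voice": "alloy",
        "client_secret": {
            "value": "ek_...",  # truncated placeholder for the ephemeral key
            "expires_at": 1735689600,  # illustrative unix timestamp
        },
    },
}
```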
diff --git a/autogen/agentchat/realtime_agent/realtime_agent.py b/autogen/agentchat/realtime_agent/realtime_agent.py
index 04bee015d1..6f45bdea38 100644
--- a/autogen/agentchat/realtime_agent/realtime_agent.py
+++ b/autogen/agentchat/realtime_agent/realtime_agent.py
@@ -7,13 +7,16 @@
 import anyio
 from asyncer import create_task_group, syncify
+from fastapi import WebSocket
+
+from autogen.agentchat.realtime_agent.realtime_client import RealtimeClientProtocol

 from ... import SwarmAgent
 from ...tools import Tool, get_function_schema
 from ..agent import Agent
 from ..conversable_agent import ConversableAgent
 from .function_observer import FunctionObserver
-from .oai_realtime_client import OpenAIRealtimeClient, Role
+from .oai_realtime_client import OpenAIRealtimeClient, OpenAIRealtimeWebRTCClient, Role
 from .realtime_observer import RealtimeObserver

 F = TypeVar("F", bound=Callable[..., Any])
@@ -42,20 +45,22 @@ def __init__(
         self,
         *,
         name: str,
-        audio_adapter: RealtimeObserver,
+        audio_adapter: Optional[RealtimeObserver] = None,
         system_message: str = "You are a helpful AI Assistant.",
         llm_config: dict[str, Any],
         voice: str = "alloy",
         logger: Optional[Logger] = None,
+        websocket: Optional[WebSocket] = None,
     ):
         """(Experimental) Agent for interacting with the Realtime Clients.

         Args:
             name (str): The name of the agent.
-            audio_adapter (RealtimeObserver): The audio adapter for the agent.
+            audio_adapter (Optional[RealtimeObserver]): The audio adapter for the agent; not needed when a websocket is provided (WebRTC).
             system_message (str): The system message for the agent.
             llm_config (dict[str, Any], bool): The config for the agent.
             voice (str): The voice for the agent.
+            websocket (Optional[WebSocket]): The websocket connected to the WebRTC JavaScript client; when provided, the agent communicates over WebRTC.
         """
         super().__init__(
             name=name,
@@ -75,12 +80,20 @@ def __init__(
         self._logger = logger
         self._function_observer = FunctionObserver(logger=logger)
         self._audio_adapter = audio_adapter
-        self._realtime_client = OpenAIRealtimeClient(
+        self._realtime_client: RealtimeClientProtocol = OpenAIRealtimeClient(
             llm_config=llm_config, voice=voice, system_message=system_message, logger=logger
         )
+        if websocket is not None:
+            self._realtime_client = OpenAIRealtimeWebRTCClient(
+                llm_config=llm_config, voice=voice, system_message=system_message, websocket=websocket, logger=logger
+            )
+
         self._voice = voice
-        self._observers: list[RealtimeObserver] = [self._function_observer, self._audio_adapter]
+        self._observers: list[RealtimeObserver] = [self._function_observer]
+        if self._audio_adapter:
+            # audio adapter is not needed for WebRTC
+            self._observers.append(self._audio_adapter)

         self._registred_realtime_tools: dict[str, Tool] = {}

@@ -102,7 +115,7 @@ def logger(self) -> Logger:
         return self._logger or global_logger

     @property
-    def realtime_client(self) -> OpenAIRealtimeClient:
+    def realtime_client(self) -> RealtimeClientProtocol:
         """Get the OpenAI Realtime Client."""
         return self._realtime_client
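A minimal usage sketch of the new constructor path (the complete, working setup is in the notebook below): passing `websocket` switches the agent from the default `OpenAIRealtimeClient` to `OpenAIRealtimeWebRTCClient`, and no `audio_adapter` is needed since audio flows directly between the browser and OpenAI.

```python
# Hypothetical FastAPI endpoint; realtime_llm_config is assumed to be defined
# as in the notebook below.
from fastapi import FastAPI, WebSocket

from autogen.agentchat.realtime_agent import RealtimeAgent

app = FastAPI()


@app.websocket("/session")
async def session(websocket: WebSocket) -> None:
    await websocket.accept()
    agent = RealtimeAgent(
        name="assistant",
        llm_config=realtime_llm_config,
        websocket=websocket,  # enables the WebRTC client; no audio_adapter needed
    )
    await agent.run()
```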
diff --git a/notebook/agentchat_realtime_webrtc.ipynb b/notebook/agentchat_realtime_webrtc.ipynb
new file mode 100644
index 0000000000..0163198332
--- /dev/null
+++ b/notebook/agentchat_realtime_webrtc.ipynb
@@ -0,0 +1,339 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# RealtimeAgent with WebRTC connection\n",
+    "\n",
+    "\n",
+    "AG2 supports **RealtimeAgent**, a powerful agent type that connects seamlessly to OpenAI's [Realtime API](https://openai.com/index/introducing-the-realtime-api). In this example we will start a local RealtimeAgent and register a mock get_weather function that the agent will be able to call.\n",
+    "\n",
+    "**Note**: This notebook cannot be run in Google Colab because it depends on local JavaScript files and HTML templates. To execute the notebook successfully, run it locally within the cloned project so that the `notebook/agentchat_realtime_webrtc/static` and `notebook/agentchat_realtime_webrtc/templates` folders are available at the correct relative paths.\n",
+    "\n",
+    "````{=mdx}\n",
+    ":::info Requirements\n",
+    "Clone the AG2 repository:\n",
+    "```bash\n",
+    "git clone https://github.com/ag2ai/ag2.git\n",
+    "cd ag2\n",
+    "```\n",
+    ":::\n",
+    "````\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "\n",
+    "## Install AG2 and dependencies\n",
+    "\n",
+    "To use the realtime agent we will connect it to the browser through a local websocket.\n",
+    "\n",
+    "To be able to run this notebook, you will need to install ag2, fastapi and uvicorn.\n",
+    "````{=mdx}\n",
+    ":::info Requirements\n",
+    "Install `ag2`:\n",
+    "```bash\n",
+    "pip install \"ag2\" \"fastapi>=0.115.0,<1\" \"uvicorn>=0.30.6,<1\" \"flaml[automl]\"\n",
+    "```\n",
+    "For more information, please refer to the [installation guide](/docs/installation/Installation).\n",
+    ":::\n",
+    "````"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install \"ag2\" \"fastapi>=0.115.0,<1\" \"uvicorn>=0.30.6,<1\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Import the dependencies\n",
+    "\n",
+    "After installing the necessary requirements, we can import the dependencies for the example"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "from logging import getLogger\n",
+    "from pathlib import Path\n",
+    "from typing import Annotated\n",
+    "\n",
+    "import uvicorn\n",
+    "from fastapi import FastAPI, Request, WebSocket\n",
+    "from fastapi.responses import HTMLResponse, JSONResponse\n",
+    "from fastapi.staticfiles import StaticFiles\n",
+    "from fastapi.templating import Jinja2Templates\n",
+    "\n",
+    "import autogen\n",
+    "from autogen.agentchat.realtime_agent import RealtimeAgent"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Prepare your `llm_config` and `realtime_llm_config`\n",
+    "\n",
+    "The [`config_list_from_json`](https://docs.ag2.ai/docs/reference/oai/openai_utils#config-list-from-json) function loads a list of configurations from an environment variable or a json file.\n",
+    "\n",
+    "## Important note\n",
+    "\n",
+    "Currently WebRTC can be used only with API keys that begin with:\n",
+    "\n",
+    "```\n",
+    "sk-proj\n",
+    "```\n",
+    "\n",
+    "Other keys may result in an internal server error (500) on the OpenAI side. For more details see:\n",
+    "https://community.openai.com/t/realtime-api-create-sessions-results-in-500-internal-server-error/1060964/5\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "realtime_config_list = autogen.config_list_from_json(\n",
+    "    \"OAI_CONFIG_LIST\",\n",
+    "    filter_dict={\n",
+    "        \"tags\": [\"gpt-4o-mini-realtime\"],\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "realtime_llm_config = {\n",
+    "    \"timeout\": 600,\n",
+    "    \"config_list\": realtime_config_list,\n",
+    "    \"temperature\": 0.8,\n",
+    "}\n",
+    "\n",
+    "assert realtime_config_list, (\n",
+    "    \"No LLM found for the given model, please add the following lines to the OAI_CONFIG_LIST file:\"\n",
+    "    \"\"\"\n",
+    "    {\n",
+    "        \"model\": \"gpt-4o-mini-realtime-preview\",\n",
+    "        \"api_key\": \"sk-proj*********************...*\",\n",
+    "        \"tags\": [\"gpt-4o-mini-realtime\", \"realtime\"]\n",
+    "    }\"\"\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Before you start the server\n",
+    "\n",
+    "To run the uvicorn server inside the notebook, you will need to use nest_asyncio. This is because Jupyter uses the asyncio event loop, and uvicorn uses its own event loop. nest_asyncio will allow uvicorn to run in Jupyter.\n",
+    "\n",
+    "Please install nest_asyncio by running the following cell."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install nest_asyncio"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import nest_asyncio\n",
+    "\n",
+    "nest_asyncio.apply()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Implementing and Running a Basic App\n",
+    "\n",
+    "Let us set up and execute a FastAPI application that integrates real-time agent interactions."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Define basic FastAPI app\n",
+    "\n",
+    "1. **Define Port**: Sets the `PORT` variable to `5050`, which will be used for the server.\n",
+    "2. **Initialize FastAPI App**: Creates a `FastAPI` instance named `app`, which serves as the main application.\n",
+    "3. **Define Root Endpoint**: Adds a `GET` endpoint at the root URL (`/`). When accessed, it returns a JSON response with the message `\"WebRTC AG2 Server is running!\"`.\n",
+    "\n",
+    "This sets up a basic FastAPI server and provides a simple health-check endpoint to confirm that the server is operational."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "PORT = 5050\n",
+    "\n",
+    "app = FastAPI()\n",
+    "\n",
+    "\n",
+    "@app.get(\"/\", response_class=JSONResponse)\n",
+    "async def index_page():\n",
+    "    return {\"message\": \"WebRTC AG2 Server is running!\"}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Prepare `start-chat` endpoint\n",
+    "\n",
+    "1. **Set the Working Directory**: Define `notebook_path` as the current working directory using `os.getcwd()`.\n",
+    "2. **Mount Static Files**: Mount the `static` directory (inside `agentchat_realtime_webrtc`) to serve JavaScript, CSS, and other static assets under the `/static` path.\n",
+    "3. **Set Up Templates**: Configure Jinja2 to render HTML templates from the `templates` directory within `agentchat_realtime_webrtc`.\n",
+    "4. **Create the `/start-chat/` Endpoint**: Define a `GET` route that serves the `chat.html` template. Pass the client's `request` and the `port` variable to the template to render a dynamic page for the audio chat interface.\n",
+    "\n",
+    "This code sets up static file handling, template rendering, and a dedicated endpoint to deliver the chat interface.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "notebook_path = os.getcwd()\n",
+    "\n",
+    "app.mount(\"/static\", StaticFiles(directory=Path(notebook_path) / \"agentchat_realtime_webrtc\" / \"static\"), name=\"static\")\n",
+    "\n",
+    "# Templates for HTML responses\n",
+    "\n",
+    "templates = Jinja2Templates(directory=Path(notebook_path) / \"agentchat_realtime_webrtc\" / \"templates\")\n",
+    "\n",
+    "\n",
+    "@app.get(\"/start-chat/\", response_class=HTMLResponse)\n",
+    "async def start_chat(request: Request):\n",
+    "    \"\"\"Endpoint to return the HTML page for audio chat.\"\"\"\n",
+    "    port = PORT  # the port the server is listening on, rendered into the template\n",
+    "    return templates.TemplateResponse(\"chat.html\", {\"request\": request, \"port\": port})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Prepare endpoint for AG2 backend websocket\n",
+    "\n",
+    "1. **Set Up the WebSocket Endpoint**: Define the `/session` WebSocket route to handle audio streaming.\n",
+    "2. **Accept WebSocket Connections**: Accept incoming WebSocket connections from clients.\n",
+    "3. **Initialize Logger**: Retrieve a logger instance for logging purposes.\n",
+    "4. **Set Up Realtime Agent**: Create a `RealtimeAgent` with the following:\n",
+    "   - **Name**: `Weather Bot`.\n",
+    "   - **System Message**: Introduces the AI assistant and its capabilities.\n",
+    "   - **LLM Configuration**: Uses `realtime_llm_config` for language model settings.\n",
+    "   - **Websocket**: Used by the RealtimeAgent backend to receive messages from the WebRTC application.\n",
+    "   - **Logger**: Logs activities for debugging and monitoring.\n",
+    "5. **Register a Realtime Function**: Add a function `get_weather` to the agent, allowing it to respond with basic weather information based on the provided `location`.\n",
+    "6. **Run the Agent**: Start the `realtime_agent` to handle interactions in real time.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@app.websocket(\"/session\")\n",
+    "async def handle_media_stream(websocket: WebSocket):\n",
+    "    \"\"\"Handle the WebSocket connection relaying messages between the WebRTC client and the agent.\"\"\"\n",
+    "    await websocket.accept()\n",
+    "\n",
+    "    logger = getLogger(\"uvicorn.error\")\n",
+    "\n",
+    "    realtime_agent = RealtimeAgent(\n",
+    "        name=\"Weather Bot\",\n",
+    "        system_message=\"Hello there! I am an AI voice assistant powered by Autogen and the OpenAI Realtime API. You can ask me about weather, jokes, or anything you can imagine. Start by saying 'How can I help you'?\",\n",
+    "        llm_config=realtime_llm_config,\n",
+    "        websocket=websocket,\n",
+    "        logger=logger,\n",
+    "    )\n",
+    "\n",
+    "    @realtime_agent.register_realtime_function(name=\"get_weather\", description=\"Get the current weather\")\n",
+    "    def get_weather(location: Annotated[str, \"city\"]) -> str:\n",
+    "        logger.info(f\"Checking the weather: {location}\")\n",
+    "        return \"The weather is cloudy.\" if location == \"Rome\" else \"The weather is sunny.\"\n",
+    "\n",
+    "    await realtime_agent.run()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Run the app using uvicorn"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "uvicorn.run(app, host=\"0.0.0.0\", port=PORT)"
+   ]
+  }
+ ],
+ "metadata": {
+  "front_matter": {
+   "description": "RealtimeAgent using WebRTC",
+   "tags": [
+    "realtime",
+    "webrtc"
+   ]
+  },
+  "kernelspec": {
+   "display_name": ".venv-3.9",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
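For reference, the messages that the JavaScript client below relays to the backend are raw Realtime API events; it forwards only those whose `type` contains `"function"`, which `read_events()` then yields to the function observer. A sketch of such an event, with an illustrative field set:

```python
# Hypothetical function-call event as relayed by WebRTC.js; additional fields
# (event_id, item_id, output_index, ...) are omitted here.
example_function_event = {
    "type": "response.function_call_arguments.done",
    "call_id": "call_abc123",
    "arguments": '{"location": "Rome"}',
}
```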
diff --git a/notebook/agentchat_realtime_webrtc/static/WebRTC.js b/notebook/agentchat_realtime_webrtc/static/WebRTC.js
new file mode 100644
index 0000000000..419bd461dd
--- /dev/null
+++ b/notebook/agentchat_realtime_webrtc/static/WebRTC.js
@@ -0,0 +1,76 @@
+export async function init(webSocketUrl) {
+
+    let ws
+    const pc = new RTCPeerConnection();
+    let dc = null; // data channel for realtime events
+
+    async function openRTC(data) {
+        const EPHEMERAL_KEY = data.client_secret.value;
+
+        // Set up to play remote audio from the model
+        const audioEl = document.createElement("audio");
+        audioEl.autoplay = true;
+        pc.ontrack = e => audioEl.srcObject = e.streams[0];
+
+        // Add local audio track for microphone input in the browser
+        const ms = await navigator.mediaDevices.getUserMedia({
+            audio: true
+        });
+        pc.addTrack(ms.getTracks()[0]);
+
+        // Set up data channel for sending and receiving events
+        dc = pc.createDataChannel("oai-events");
+        dc.addEventListener("message", (e) => {
+            // Realtime server events appear here!
+            // Only function-related events are relayed to the AG2 backend.
+            const message = JSON.parse(e.data)
+            if (message.type.includes("function")) {
+                console.log("WebRTC function message", message)
+                ws.send(e.data)
+            }
+        });
+
+        // Start the session using the Session Description Protocol (SDP)
+        const offer = await pc.createOffer();
+        await pc.setLocalDescription(offer);
+
+        const baseUrl = "https://api.openai.com/v1/realtime";
+        const model = data.model;
+        const sdpResponse = await fetch(`${baseUrl}?model=${model}`, {
+            method: "POST",
+            body: offer.sdp,
+            headers: {
+                Authorization: `Bearer ${EPHEMERAL_KEY}`,
+                "Content-Type": "application/sdp"
+            },
+        });
+
+        const answer = {
+            type: "answer",
+            sdp: await sdpResponse.text(),
+        };
+        await pc.setRemoteDescription(answer);
+        console.log("Connected to OpenAI WebRTC")
+    }
+
+    ws = new WebSocket(webSocketUrl);
+
+    ws.onopen = event => {
+        console.log("WebSocket opened")
+    }
+
+    ws.onmessage = async event => {
+        const message = JSON.parse(event.data)
+        console.info("Received Message from AG2 backend", message)
+        const type = message.type
+        if (type == "ag2.init") {
+            await openRTC(message.config)
+            return
+        }
+        // All other backend events are forwarded to OpenAI over the data channel.
+        const messageJSON = JSON.stringify(message)
+        if (dc) {
+            dc.send(messageJSON)
+        } else {
+            console.log("DC not ready yet", message)
+        }
+    }
+}
diff --git a/notebook/agentchat_realtime_webrtc/static/main.js b/notebook/agentchat_realtime_webrtc/static/main.js
new file mode 100644
index 0000000000..0b44401d98
--- /dev/null
+++ b/notebook/agentchat_realtime_webrtc/static/main.js
@@ -0,0 +1,3 @@
+import { init } from './WebRTC.js';
+
+init(socketUrl) // socketUrl is defined in chat.html, rendered by the server with the correct port
diff --git a/notebook/agentchat_realtime_webrtc/templates/chat.html b/notebook/agentchat_realtime_webrtc/templates/chat.html
new file mode 100644
index 0000000000..aee1ee6abc
--- /dev/null
+++ b/notebook/agentchat_realtime_webrtc/templates/chat.html
@@ -0,0 +1,20 @@
+<!DOCTYPE html>
+<html lang="en">
+
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Ag2 WebRTC Chat</title>
+    <script>
+        const socketUrl = "ws://" + location.hostname + ":{{ port }}/session";
+    </script>
+    <script src="/static/main.js" type="module"></script>
+</head>
+
+<body>
+    <h1>Ag2 WebRTC Chat</h1>
+    <p>Ensure microphone and speaker access is enabled.</p>
+    <p>You may try asking about weather in some cities.</p>
+</body>
+
+</html>