Skip to content

Commit

Permalink
realtimeapi code
Browse files Browse the repository at this point in the history
  • Loading branch information
saudsami committed Oct 3, 2024
1 parent bdc3163 commit 70d8e37
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 150 deletions.
267 changes: 134 additions & 133 deletions shared/open-ai-integration/complete-code.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -591,129 +591,6 @@ Your knowledge cutoff is 2023-10. You are a helpful, witty, and friendly AI. Act
</CodeBlock>
</details>

<details>
<summary>`realtimeapi/connection.py`</summary>
<CodeBlock showLineNumbers language="python">
{`import asyncio
import base64
import json
import logging
import os
import aiohttp
from typing import Any, AsyncGenerator
from .struct import InputAudioBufferAppend, ClientToServerMessage, ServerToClientMessage, parse_server_message, to_json
from ..logger import setup_logger
# Set up the logger with color and timestamp support
logger = setup_logger(name=__name__, log_level=logging.INFO)
DEFAULT_VIRTUAL_MODEL = "gpt-4o-realtime-preview"
def smart_str(s: str, max_field_len: int = 128) -> str:
    """Return a log-friendly version of a JSON message with a shortened payload.

    The input is parsed as JSON. When it is an object that has a string
    "delta" field (checked first) or a string "audio" field, that field is
    cut to max_field_len characters with "..." appended if it was longer,
    and the object is serialized again. Anything else is returned untouched.

    Args:
        s: Raw message text, normally a JSON document.
        max_field_len: Maximum number of characters kept from the field.

    Returns:
        The reserialized JSON when a string "delta"/"audio" field is present,
        otherwise the original text unchanged.
    """
    try:
        data = json.loads(s)
    except json.JSONDecodeError:
        return s

    # Valid JSON is not necessarily an object: numbers, strings and lists
    # have no fields to shorten, and indexing them by key would raise.
    if not isinstance(data, dict):
        return s

    key = next((name for name in ("delta", "audio") if name in data), None)
    if key is None:
        return s

    value = data[key]
    # Only text payloads can be sliced and suffixed with "...".
    if not isinstance(value, str):
        return s

    if len(value) > max_field_len:
        data[key] = value[:max_field_len] + "..."
    return json.dumps(data)
class RealtimeApiConnection:
    """WebSocket connection to a realtime API endpoint.

    The instance owns an aiohttp client session and, once connected, one
    websocket. It can be used as an async context manager: entering connects,
    leaving closes the websocket and the session.
    """

    def __init__(
        self,
        base_uri: str,
        api_key: str | None = None,
        path: str = "/v1/realtime",
        verbose: bool = False,
        model: str = DEFAULT_VIRTUAL_MODEL,
    ):
        """Store the connection settings and create the HTTP session.

        Args:
            base_uri: Scheme and host of the endpoint; path is appended to it.
            api_key: API key; falls back to the OPENAI_API_KEY environment
                variable when omitted.
            path: Request path appended to base_uri.
            verbose: When true, sent and received messages are logged.
            model: Model name added as a query parameter unless the URL
                already contains "model=".
        """
        self.url = f"{base_uri}{path}"
        if "model=" not in self.url:
            # Extend an existing query string instead of starting a second one.
            separator = "&" if "?" in self.url else "?"
            self.url += f"{separator}model={model}"
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
        self.websocket: aiohttp.ClientWebSocketResponse | None = None
        self.verbose = verbose
        self.session = aiohttp.ClientSession()

    async def __aenter__(self) -> "RealtimeApiConnection":
        """Connect and return this instance."""
        try:
            await self.connect()
        except BaseException:
            # __aexit__ is not invoked when __aenter__ fails, so release the
            # session here before propagating the failure.
            await self.close()
            raise
        return self

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
        """Close the connection; exceptions from the body are not suppressed."""
        await self.close()
        return False

    async def connect(self):
        """Open the websocket, authenticating with the API key when one is set."""
        if self.session.closed:
            # close() releases the session; recreate it so the same object
            # can connect again afterwards.
            self.session = aiohttp.ClientSession()
        auth = aiohttp.BasicAuth("", self.api_key) if self.api_key else None
        headers = {"OpenAI-Beta": "realtime=v1"}
        self.websocket = await self.session.ws_connect(
            url=self.url,
            auth=auth,
            headers=headers,
        )

    async def send_audio_data(self, audio_data: bytes):
        """Send audio as a base64 encoded InputAudioBufferAppend message.

        audio_data is assumed to be pcm16 24kHz mono little-endian.
        """
        base64_audio_data = base64.b64encode(audio_data).decode("utf-8")
        message = InputAudioBufferAppend(audio=base64_audio_data)
        await self.send_request(message)

    async def send_request(self, message: ClientToServerMessage):
        """Serialize a client message to JSON and send it over the websocket.

        connect() must have been called first.
        """
        assert self.websocket is not None
        message_str = to_json(message)
        if self.verbose:
            logger.info(f"-> {smart_str(message_str)}")
        await self.websocket.send_str(message_str)

    async def listen(self) -> AsyncGenerator[ServerToClientMessage, None]:
        """Yield parsed server messages until the websocket stops.

        Text frames are parsed and yielded. An error frame is logged and ends
        the iteration. Other frame types are ignored. Cancellation is logged
        and ends the generator without re-raising.
        """
        assert self.websocket is not None
        if self.verbose:
            logger.info("Listening for realtimeapi messages")
        try:
            async for msg in self.websocket:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    if self.verbose:
                        logger.info(f"<- {smart_str(msg.data)}")
                    yield self.handle_server_message(msg.data)
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error("Error during receive: %s", self.websocket.exception())
                    break
        except asyncio.CancelledError:
            logger.info("Receive messages task cancelled")

    def handle_server_message(self, message: str) -> ServerToClientMessage:
        """Parse one raw server message, logging and re-raising any failure."""
        try:
            return parse_server_message(message)
        except Exception as e:
            logger.error("Error handling message: %s", e)
            # Bare raise keeps the original traceback intact.
            raise

    async def close(self):
        """Close the websocket if it exists, then the underlying session."""
        if self.websocket:
            await self.websocket.close()
            self.websocket = None
        # The session is created in __init__ and owned by this object; leaving
        # it open leaks its connector.
        if not self.session.closed:
            await self.session.close()
`}
</CodeBlock>
</details>

<details>
<summary>`tools.py`</summary>
<CodeBlock showLineNumbers language="python">
Expand Down Expand Up @@ -1019,6 +896,130 @@ def parse_args_realtimekit() -> RealtimeKitOptions:
</CodeBlock>
</details>


<details>
<summary>`realtimeapi/connection.py`</summary>
<CodeBlock showLineNumbers language="python">
{`import asyncio
import base64
import json
import logging
import os
import aiohttp
from typing import Any, AsyncGenerator
from .struct import InputAudioBufferAppend, ClientToServerMessage, ServerToClientMessage, parse_server_message, to_json
from ..logger import setup_logger
# Set up the logger with color and timestamp support
logger = setup_logger(name=__name__, log_level=logging.INFO)
DEFAULT_VIRTUAL_MODEL = "gpt-4o-realtime-preview"
def smart_str(s: str, max_field_len: int = 128) -> str:
    """Return a log-friendly version of a JSON message with a shortened payload.

    The input is parsed as JSON. When it is an object that has a string
    "delta" field (checked first) or a string "audio" field, that field is
    cut to max_field_len characters with "..." appended if it was longer,
    and the object is serialized again. Anything else is returned untouched.

    Args:
        s: Raw message text, normally a JSON document.
        max_field_len: Maximum number of characters kept from the field.

    Returns:
        The reserialized JSON when a string "delta"/"audio" field is present,
        otherwise the original text unchanged.
    """
    try:
        data = json.loads(s)
    except json.JSONDecodeError:
        return s

    # Valid JSON is not necessarily an object: numbers, strings and lists
    # have no fields to shorten, and indexing them by key would raise.
    if not isinstance(data, dict):
        return s

    key = next((name for name in ("delta", "audio") if name in data), None)
    if key is None:
        return s

    value = data[key]
    # Only text payloads can be sliced and suffixed with "...".
    if not isinstance(value, str):
        return s

    if len(value) > max_field_len:
        data[key] = value[:max_field_len] + "..."
    return json.dumps(data)
class RealtimeApiConnection:
    """WebSocket connection to a realtime API endpoint.

    The instance owns an aiohttp client session and, once connected, one
    websocket. It can be used as an async context manager: entering connects,
    leaving closes the websocket and the session.
    """

    def __init__(
        self,
        base_uri: str,
        api_key: str | None = None,
        path: str = "/v1/realtime",
        verbose: bool = False,
        model: str = DEFAULT_VIRTUAL_MODEL,
    ):
        """Store the connection settings and create the HTTP session.

        Args:
            base_uri: Scheme and host of the endpoint; path is appended to it.
            api_key: API key; falls back to the OPENAI_API_KEY environment
                variable when omitted.
            path: Request path appended to base_uri.
            verbose: When true, sent and received messages are logged.
            model: Model name added as a query parameter unless the URL
                already contains "model=".
        """
        self.url = f"{base_uri}{path}"
        if "model=" not in self.url:
            # Extend an existing query string instead of starting a second one.
            separator = "&" if "?" in self.url else "?"
            self.url += f"{separator}model={model}"
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
        self.websocket: aiohttp.ClientWebSocketResponse | None = None
        self.verbose = verbose
        self.session = aiohttp.ClientSession()

    async def __aenter__(self) -> "RealtimeApiConnection":
        """Connect and return this instance."""
        try:
            await self.connect()
        except BaseException:
            # __aexit__ is not invoked when __aenter__ fails, so release the
            # session here before propagating the failure.
            await self.close()
            raise
        return self

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
        """Close the connection; exceptions from the body are not suppressed."""
        await self.close()
        return False

    async def connect(self):
        """Open the websocket, authenticating with the API key when one is set."""
        if self.session.closed:
            # close() releases the session; recreate it so the same object
            # can connect again afterwards.
            self.session = aiohttp.ClientSession()
        auth = aiohttp.BasicAuth("", self.api_key) if self.api_key else None
        headers = {"OpenAI-Beta": "realtime=v1"}
        self.websocket = await self.session.ws_connect(
            url=self.url,
            auth=auth,
            headers=headers,
        )

    async def send_audio_data(self, audio_data: bytes):
        """Send audio as a base64 encoded InputAudioBufferAppend message.

        audio_data is assumed to be pcm16 24kHz mono little-endian.
        """
        base64_audio_data = base64.b64encode(audio_data).decode("utf-8")
        message = InputAudioBufferAppend(audio=base64_audio_data)
        await self.send_request(message)

    async def send_request(self, message: ClientToServerMessage):
        """Serialize a client message to JSON and send it over the websocket.

        connect() must have been called first.
        """
        assert self.websocket is not None
        message_str = to_json(message)
        if self.verbose:
            logger.info(f"-> {smart_str(message_str)}")
        await self.websocket.send_str(message_str)

    async def listen(self) -> AsyncGenerator[ServerToClientMessage, None]:
        """Yield parsed server messages until the websocket stops.

        Text frames are parsed and yielded. An error frame is logged and ends
        the iteration. Other frame types are ignored. Cancellation is logged
        and ends the generator without re-raising.
        """
        assert self.websocket is not None
        if self.verbose:
            logger.info("Listening for realtimeapi messages")
        try:
            async for msg in self.websocket:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    if self.verbose:
                        logger.info(f"<- {smart_str(msg.data)}")
                    yield self.handle_server_message(msg.data)
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error("Error during receive: %s", self.websocket.exception())
                    break
        except asyncio.CancelledError:
            logger.info("Receive messages task cancelled")

    def handle_server_message(self, message: str) -> ServerToClientMessage:
        """Parse one raw server message, logging and re-raising any failure."""
        try:
            return parse_server_message(message)
        except Exception as e:
            logger.error("Error handling message: %s", e)
            # Bare raise keeps the original traceback intact.
            raise

    async def close(self):
        """Close the websocket if it exists, then the underlying session."""
        if self.websocket:
            await self.websocket.close()
            self.websocket = None
        # The session is created in __init__ and owned by this object; leaving
        # it open leaks its connector.
        if not self.session.closed:
            await self.session.close()
`}
</CodeBlock>
</details>

<details>
<summary>`realtimeapi/struct.py`</summary>
<CodeBlock showLineNumbers language="python">
Expand Down Expand Up @@ -1119,13 +1120,13 @@ class SessionUpdateParams:
model: Optional[str] = None # Optional string to specify the model
modalities: Optional[Set[str]] = None # Set of allowed modalities (e.g., "text", "audio")
instructions: Optional[str] = None # Optional instructions string
voice: Optional[Voices] = None # Voice selection, can be `None` or from `Voices` Enum
voice: Optional[Voices] = None # Voice selection, can be \`None\` or from \`Voices\` Enum
turn_detection: Optional[ServerVADUpdateParams] = None # Server VAD update params
input_audio_format: Optional[AudioFormats] = None # Input audio format from `AudioFormats` Enum
output_audio_format: Optional[AudioFormats] = None # Output audio format from `AudioFormats` Enum
input_audio_format: Optional[AudioFormats] = None # Input audio format from \`AudioFormats\` Enum
output_audio_format: Optional[AudioFormats] = None # Output audio format from \`AudioFormats\` Enum
input_audio_transcription: Optional[InputAudioTranscription] = None # Optional transcription model
tools: Optional[List[Dict[str, Union[str, any]]]] = None # List of tools (e.g., dictionaries)
tool_choice: Optional[ToolChoice] = None # ToolChoice, either string or `FunctionToolChoice`
tool_choice: Optional[ToolChoice] = None # ToolChoice, either string or \`FunctionToolChoice\`
temperature: Optional[float] = None # Optional temperature for response generation
max_response_output_tokens: Optional[Union[int, str]] = None # Max response tokens, "inf" for infinite
Expand Down Expand Up @@ -1568,7 +1569,7 @@ class InputAudioBufferClear(ClientToServerMessage):
@dataclass
class ItemCreate(ClientToServerMessage):
item: Optional[ItemParam] = field(default=None) # Assuming `ItemParam` is already defined
item: Optional[ItemParam] = field(default=None) # Assuming \`ItemParam\` is already defined
type: str = EventType.ITEM_CREATE
previous_item_id: Optional[str] = None
Expand Down Expand Up @@ -1605,7 +1606,7 @@ class ResponseCreateParams:
@dataclass
class ResponseCreate(ClientToServerMessage):
type: str = EventType.RESPONSE_CREATE
response: Optional[ResponseCreateParams] = None # Assuming `ResponseCreateParams` is defined
response: Optional[ResponseCreateParams] = None # Assuming \`ResponseCreateParams\` is defined
@dataclass
Expand All @@ -1631,7 +1632,7 @@ class UpdateConversationConfig(ClientToServerMessage):
@dataclass
class SessionUpdate(ClientToServerMessage):
session: Optional[SessionUpdateParams] = field(default=None) # Assuming `SessionUpdateParams` is defined
session: Optional[SessionUpdateParams] = field(default=None) # Assuming \`SessionUpdateParams\` is defined
type: str = EventType.SESSION_UPDATE
Expand Down Expand Up @@ -1662,7 +1663,7 @@ def from_dict(data_class, data):
def parse_client_message(unparsed_string: str) -> ClientToServerMessage:
data = json.loads(unparsed_string)
# Dynamically select the correct message class based on the `type` field, using from_dict
# Dynamically select the correct message class based on the \`type\` field, using from_dict
if data["type"] == EventType.INPUT_AUDIO_BUFFER_APPEND:
return from_dict(InputAudioBufferAppend, data)
elif data["type"] == EventType.INPUT_AUDIO_BUFFER_COMMIT:
Expand All @@ -1688,12 +1689,12 @@ def parse_client_message(unparsed_string: str) -> ClientToServerMessage:
# Assuming all necessary classes and enums (EventType, ServerToClientMessages, etc.) are imported
# Here’s how you can dynamically parse a server-to-client message based on the `type` field:
# Here’s how you can dynamically parse a server-to-client message based on the \`type\` field:
def parse_server_message(unparsed_string: str) -> ServerToClientMessage:
data = json.loads(unparsed_string)
# Dynamically select the correct message class based on the `type` field, using from_dict
# Dynamically select the correct message class based on the \`type\` field, using from_dict
if data["type"] == EventType.ERROR:
return from_dict(ErrorMessage, data)
elif data["type"] == EventType.SESSION_CREATED:
Expand Down
Loading

0 comments on commit 70d8e37

Please sign in to comment.