Skip to content

Commit

Permalink
fix: conversation item create and delete return a Future
Browse files Browse the repository at this point in the history
  • Loading branch information
longcw committed Nov 14, 2024
1 parent 49c0663 commit 3c35c37
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 53 deletions.
5 changes: 5 additions & 0 deletions .changeset/itchy-lions-protect.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-plugins-openai": patch
---

make ConversationItem.create and delete return a Future in Realtime model
3 changes: 3 additions & 0 deletions examples/multimodal_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def _on_agent_speech_created(msg: llm.ChatMessage):
chat_ctx = agent.chat_ctx_copy()
if len(chat_ctx.messages) > max_ctx_len:
chat_ctx.messages = chat_ctx.messages[-max_ctx_len:]
# NOTE: The `set_chat_ctx` function will attempt to synchronize changes made
# to the local chat context with the server instead of completely replacing it,
# provided that the message IDs are consistent.
asyncio.create_task(agent.set_chat_ctx(chat_ctx))


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
from copy import deepcopy
from dataclasses import dataclass
from typing import AsyncIterable, Callable, Literal, Union, cast, overload
from typing import AsyncIterable, Literal, Union, cast, overload
from urllib.parse import urlencode

import aiohttp
Expand Down Expand Up @@ -485,13 +485,15 @@ def __init__(self, sess: RealtimeSession) -> None:

def create(
self, message: llm.ChatMessage, previous_item_id: str | None = None
) -> None:
) -> asyncio.Future[bool]:
fut = asyncio.Future[bool]()

message_content = message.content
tool_call_id = message.tool_call_id
if not tool_call_id and message_content is None:
# not a function call while the message content is None
return

fut.set_result(False)
return fut
event: api_proto.ClientEvent.ConversationItemCreate | None = None
if tool_call_id:
if message.role == "tool":
Expand All @@ -515,7 +517,8 @@ def create(
message,
extra=self._sess.logging_extra(),
)
return
fut.set_result(False)
return fut
if len(message.tool_calls) > 1:
logger.warning(
"function call message has multiple tool calls, "
Expand All @@ -541,7 +544,8 @@ def create(
message,
extra=self._sess.logging_extra(),
)
return
fut.set_result(False)
return fut
if not isinstance(message_content, list):
message_content = [message_content]

Expand Down Expand Up @@ -630,13 +634,18 @@ def create(
message,
extra=self._sess.logging_extra(),
)
return
fut.set_result(False)
return fut

self._sess._item_created_futs[message.id] = fut
self._sess._queue_msg(event)
return fut

def truncate(
self, *, item_id: str, content_index: int, audio_end_ms: int
) -> None:
) -> asyncio.Future[bool]:
fut = asyncio.Future[bool]()
self._sess._item_truncated_futs[item_id] = fut
self._sess._queue_msg(
{
"type": "conversation.item.truncate",
Expand All @@ -645,35 +654,18 @@ def truncate(
"audio_end_ms": audio_end_ms,
}
)
return fut

def delete(self, *, item_id: str) -> None:
def delete(self, *, item_id: str) -> asyncio.Future[bool]:
fut = asyncio.Future[bool]()
self._sess._item_deleted_futs[item_id] = fut
self._sess._queue_msg(
{
"type": "conversation.item.delete",
"item_id": item_id,
}
)

async def acreate(
self,
message: llm.ChatMessage,
previous_item_id: str | None = None,
_on_create_callback: Callable[[], None] | None = None,
) -> None:
fut = asyncio.Future[None]()
self._sess._item_created_futs[message.id] = fut
self.create(message, previous_item_id)
if _on_create_callback:
_on_create_callback()
await fut
del self._sess._item_created_futs[message.id]

async def adelete(self, *, item_id: str) -> None:
fut = asyncio.Future[None]()
self._sess._item_deleted_futs[item_id] = fut
self.delete(item_id=item_id)
await fut
del self._sess._item_deleted_futs[item_id]
return fut

class Conversation:
def __init__(self, sess: RealtimeSession) -> None:
Expand Down Expand Up @@ -710,8 +702,9 @@ def __init__(
self._remote_converstation_items = remote_items._RemoteConversationItems()

# wait for the item to be created or deleted
self._item_created_futs: dict[str, asyncio.Future[None]] = {}
self._item_deleted_futs: dict[str, asyncio.Future[None]] = {}
self._item_created_futs: dict[str, asyncio.Future[bool]] = {}
self._item_deleted_futs: dict[str, asyncio.Future[bool]] = {}
self._item_truncated_futs: dict[str, asyncio.Future[bool]] = {}

self._fnc_ctx = fnc_ctx
self._loop = loop
Expand Down Expand Up @@ -891,28 +884,16 @@ async def set_chat_ctx(self, new_ctx: llm.ChatContext) -> None:
)
logger.debug("added empty audio message to the chat context")

_futs = []
for msg in changes.to_delete:
fut = asyncio.Future[None]()
self._item_deleted_futs[msg.id] = fut
self.conversation.item.delete(item_id=msg.id)
_futs.append(fut)

for prev, msg in changes.to_add:
fut = asyncio.Future[None]()
self._item_created_futs[msg.id] = fut
_futs = [
self.conversation.item.delete(item_id=msg.id) for msg in changes.to_delete
] + [
self.conversation.item.create(msg, prev.id if prev else None)
_futs.append(fut)
for prev, msg in changes.to_add
]

# wait for all the futures to complete
await asyncio.gather(*_futs)

# clean up the futures
for msg in changes.to_delete:
del self._item_deleted_futs[msg.id]
for _, msg in changes.to_add:
del self._item_created_futs[msg.id]

def _update_converstation_item_content(
self, item_id: str, content: llm.ChatContent | list[llm.ChatContent] | None
) -> None:
Expand Down Expand Up @@ -1028,6 +1009,8 @@ async def _recv_task():
self._handle_conversation_item_created(data)
elif event == "conversation.item.deleted":
self._handle_conversation_item_deleted(data)
elif event == "conversation.item.truncated":
self._handle_conversation_item_truncated(data)
elif event == "response.created":
self._handle_response_created(data)
elif event == "response.output_item.added":
Expand Down Expand Up @@ -1184,7 +1167,8 @@ def _handle_conversation_item_created(
# Insert into conversation items
self._remote_converstation_items.insert_after(previous_item_id, message)
if item_id in self._item_created_futs:
self._item_created_futs[item_id].set_result(None)
self._item_created_futs[item_id].set_result(True)
del self._item_created_futs[item_id]
logger.debug("conversation item created", extra=item_created)

def _handle_conversation_item_deleted(
Expand All @@ -1194,9 +1178,18 @@ def _handle_conversation_item_deleted(
item_id = item_deleted["item_id"]
self._remote_converstation_items.delete(item_id)
if item_id in self._item_deleted_futs:
self._item_deleted_futs[item_id].set_result(None)
self._item_deleted_futs[item_id].set_result(True)
del self._item_deleted_futs[item_id]
logger.debug("conversation item deleted", extra=item_deleted)

def _handle_conversation_item_truncated(
self, item_truncated: api_proto.ServerEvent.ConversationItemTruncated
):
item_id = item_truncated["item_id"]
if item_id in self._item_truncated_futs:
self._item_truncated_futs[item_id].set_result(True)
del self._item_truncated_futs[item_id]

def _handle_response_created(
self, response_created: api_proto.ServerEvent.ResponseCreated
):
Expand Down Expand Up @@ -1414,11 +1407,12 @@ async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str)
tool_call = llm.ChatMessage.create_tool_from_called_function(called_fnc)

if called_fnc.result is not None:
await self.conversation.item.acreate(
create_fut = self.conversation.item.create(
tool_call,
previous_item_id=item_id,
_on_create_callback=self.response.create,
)
self.response.create()
await create_fut

# update the message with the tool call result
msg = self._remote_converstation_items.get(tool_call.id)
Expand Down

0 comments on commit 3c35c37

Please sign in to comment.