Cancel main task in aclose #1003

Open · wants to merge 5 commits into main
5 changes: 5 additions & 0 deletions .changeset/young-walls-fry.md
@@ -0,0 +1,5 @@
+---
+"livekit-agents": patch
+---
+
+Cancel the _main_atask on aclose
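
For context, the fix cancels the pipeline's `_main_atask` before the rest of teardown, using livekit-agents' `utils.aio.gracefully_cancel` helper. A minimal sketch of the cancel-and-await pattern being applied (the helper body and the `PipelineAgentSketch` class below are illustrative assumptions, not the library's actual implementation):

```python
import asyncio


async def gracefully_cancel(task: asyncio.Task) -> None:
    # Request cancellation, then wait for the task to actually unwind,
    # suppressing the CancelledError it raises on exit.
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass


class PipelineAgentSketch:
    def __init__(self) -> None:
        self._started = False
        self._main_atask: asyncio.Task | None = None

    async def aclose(self) -> None:
        if not self._started:
            return
        # The change in this PR: stop the main task before detaching
        # event handlers and closing the rest of the pipeline.
        if self._main_atask is not None:
            await gracefully_cancel(self._main_atask)
```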
60 changes: 26 additions & 34 deletions livekit-agents/livekit/agents/pipeline/pipeline_agent.py
@@ -269,6 +269,7 @@ def __init__(
         self._speech_q_changed = asyncio.Event()

         self._update_state_task: asyncio.Task | None = None
+        self._main_atask: asyncio.Task | None = None

         self._last_final_transcript_time: float | None = None
         self._last_speech_time: float | None = None
@@ -437,6 +438,9 @@ async def aclose(self) -> None:
         if not self._started:
             return

+        if self._main_atask is not None:
+            await utils.aio.gracefully_cancel(self._main_atask)
+
         self._room.off("participant_connected", self._on_participant_connected)
         await self._deferred_validation.aclose()

@@ -654,48 +658,43 @@ async def _play_speech(self, speech_handle: SpeechHandle) -> None:

         await self._agent_publication.wait_for_subscription()

-        synthesis_handle = speech_handle.synthesis_handle
-        if synthesis_handle.interrupted:
+        if speech_handle.interrupted:
             return

-        user_question = speech_handle.user_question
-
-        play_handle = synthesis_handle.play()
+        play_handle = speech_handle.synthesis_handle.play()
         join_fut = play_handle.join()

         def _commit_user_question_if_needed() -> None:
             if (
-                not user_question
-                or synthesis_handle.interrupted
+                not speech_handle.user_question
+                or speech_handle.interrupted
                 or speech_handle.user_commited
             ):
                 return

-            is_using_tools = isinstance(speech_handle.source, LLMStream) and len(
-                speech_handle.source.function_calls
-            )
-
             # make sure at least some speech was played before committing the user message
             # since we try to validate as fast as possible it is possible the agent gets interrupted
             # really quickly (barely audible), we don't want to mark this question as "answered".
             if (
                 speech_handle.allow_interruptions
-                and not is_using_tools
+                and not speech_handle.is_using_tools()
                 and (
                     play_handle.time_played < self.MIN_TIME_PLAYED_FOR_COMMIT
                     and not join_fut.done()
                 )
             ):
                 return

-            user_msg = ChatMessage.create(text=user_question, role="user")
+            user_msg = ChatMessage.create(text=speech_handle.user_question, role="user")
             self._chat_ctx.messages.append(user_msg)
             self.emit("user_speech_committed", user_msg)

-            self._transcribed_text = self._transcribed_text[len(user_question) :]
+            self._transcribed_text = self._transcribed_text[
+                len(speech_handle.user_question) :
+            ]
             speech_handle.mark_user_commited()

-        # wait for the play_handle to finish and check every 1s if the user question should be committed
+        # wait for the play_handle to finish and check every 0.2s if the user question should be committed
         _commit_user_question_if_needed()

         while not join_fut.done():
@@ -710,20 +709,14 @@ def _commit_user_question_if_needed() -> None:

         _commit_user_question_if_needed()

-        collected_text = speech_handle.synthesis_handle.tts_forwarder.played_text
-        interrupted = speech_handle.interrupted
-        is_using_tools = isinstance(speech_handle.source, LLMStream) and len(
-            speech_handle.source.function_calls
-        )
-
         extra_tools_messages = []  # additional messages from the functions to add to the context if needed

         # if the answer is using tools, execute the functions and automatically generate
         # a response to the user question from the returned values
-        if is_using_tools and not interrupted:
+        if speech_handle.is_using_tools() and not speech_handle.interrupted:
             assert isinstance(speech_handle.source, LLMStream)
             assert (
-                not user_question or speech_handle.user_commited
+                not speech_handle.user_question or speech_handle.user_commited
             ), "user speech should have been committed before using tools"

             llm_stream = speech_handle.source
@@ -778,7 +771,9 @@ def _commit_user_question_if_needed() -> None:

             # generate an answer from the tool calls
             extra_tools_messages.append(
-                ChatMessage.create_tool_calls(tool_calls_info, text=collected_text)
+                ChatMessage.create_tool_calls(
+                    tool_calls_info, text=speech_handle.collected_text()
+                )
             )
             extra_tools_messages.extend(tool_calls_results)

@@ -799,8 +794,6 @@ def _commit_user_question_if_needed() -> None:
             play_handle = answer_synthesis.play()
             await play_handle.join()

-            collected_text = answer_synthesis.tts_forwarder.played_text
-            interrupted = answer_synthesis.interrupted
             new_function_calls = answer_llm_stream.function_calls

             self.emit("function_calls_finished", called_fncs)
@@ -811,28 +804,27 @@ def _commit_user_question_if_needed() -> None:
             _CallContextVar.reset(tk)

         if speech_handle.add_to_chat_ctx and (
-            not user_question or speech_handle.user_commited
+            not speech_handle.user_question or speech_handle.user_commited
         ):
             self._chat_ctx.messages.extend(extra_tools_messages)

-            if interrupted:
-                collected_text += "..."
-
-            msg = ChatMessage.create(text=collected_text, role="assistant")
+            msg = ChatMessage.create(
+                text=speech_handle.collected_text(), role="assistant"
+            )
             self._chat_ctx.messages.append(msg)

             speech_handle.mark_speech_commited()

-            if interrupted:
+            if speech_handle.interrupted:
                 self.emit("agent_speech_interrupted", msg)
             else:
                 self.emit("agent_speech_committed", msg)

             logger.debug(
                 "committed agent speech",
                 extra={
-                    "agent_transcript": collected_text,
-                    "interrupted": interrupted,
+                    "agent_transcript": speech_handle.collected_text(),
+                    "interrupted": speech_handle.interrupted,
                     "speech_id": speech_handle.id,
                 },
             )
8 changes: 8 additions & 0 deletions livekit-agents/livekit/agents/pipeline/speech_handle.py
@@ -64,6 +64,14 @@ def create_assistant_speech(
             user_question="",
         )

+    def collected_text(self) -> str:
+        if self.interrupted:
+            return self.synthesis_handle.tts_forwarder.played_text + "..."
+        return self.synthesis_handle.tts_forwarder.played_text
+
+    def is_using_tools(self) -> bool:
+        return isinstance(self.source, LLMStream) and len(self.source.function_calls)
+
     async def wait_for_initialization(self) -> None:
         await asyncio.shield(self._init_fut)
