From aac907aadb2b819f853a2b465b03c789e3265769 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 13 Dec 2024 07:22:05 -0800 Subject: [PATCH] riva: make sure we don't block on fastpitch --- src/pipecat/services/riva.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/pipecat/services/riva.py b/src/pipecat/services/riva.py index 470fe6dc9..8ab0e99d8 100644 --- a/src/pipecat/services/riva.py +++ b/src/pipecat/services/riva.py @@ -35,6 +35,8 @@ ) raise Exception(f"Missing module: {e}") +FASTPITCH_TIMEOUT_SECS = 5 + class FastPitchTTSService(TTSService): class InputParams(BaseModel): @@ -102,20 +104,23 @@ def add_response(r): logger.debug(f"Generating TTS: [{text}]") - queue = asyncio.Queue() - await asyncio.to_thread(read_audio_responses, queue) - - # Wait for the thread to start. - resp = await queue.get() - while resp: - await self.stop_ttfb_metrics() - frame = TTSAudioRawFrame( - audio=resp.audio, - sample_rate=self._sample_rate, - num_channels=1, - ) - yield frame - resp = await queue.get() + try: + queue = asyncio.Queue() + await asyncio.to_thread(read_audio_responses, queue) + + # Wait for the thread to start. + resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS) + while resp: + await self.stop_ttfb_metrics() + frame = TTSAudioRawFrame( + audio=resp.audio, + sample_rate=self._sample_rate, + num_channels=1, + ) + yield frame + resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS) + except asyncio.TimeoutError: + logger.error(f"{self} timeout waiting for audio response") await self.start_tts_usage_metrics(text) yield TTSStoppedFrame()