diff --git a/wyoming_satellite/satellite.py b/wyoming_satellite/satellite.py index 9c651d4..336fb86 100644 --- a/wyoming_satellite/satellite.py +++ b/wyoming_satellite/satellite.py @@ -5,6 +5,7 @@ import math import time import wave +from dataclasses import dataclass from enum import Enum, auto from pathlib import Path from typing import Callable, Dict, Final, List, Optional, Set, Union @@ -60,6 +61,12 @@ class State(Enum): STOPPED = auto() +@dataclass +class SoundEvent: + event: Event + is_tts: bool + + # ----------------------------------------------------------------------------- @@ -76,7 +83,7 @@ def __init__(self, settings: SatelliteSettings) -> None: self._mic_task: Optional[asyncio.Task] = None self._mic_webrtc: Optional[Callable[[bytes], bytes]] = None self._snd_task: Optional[asyncio.Task] = None - self._snd_queue: "Optional[asyncio.Queue[Event]]" = None + self._snd_queue: "Optional[asyncio.Queue[SoundEvent]]" = None self._wake_task: Optional[asyncio.Task] = None self._wake_queue: "Optional[asyncio.Queue[Event]]" = None self._event_task: Optional[asyncio.Task] = None @@ -543,10 +550,10 @@ def _process_mic_audio(self, audio_bytes: bytes) -> bytes: # Sound # ------------------------------------------------------------------------- - async def event_to_snd(self, event: Event) -> None: + async def event_to_snd(self, event: Event, is_tts: bool = True) -> None: """Send an event to the sound service.""" if self._snd_queue is not None: - self._snd_queue.put_nowait(event) + self._snd_queue.put_nowait(SoundEvent(event, is_tts)) def _make_snd_client(self) -> Optional[AsyncClient]: """Create client for snd service.""" @@ -581,7 +588,8 @@ async def _disconnect() -> None: if self._snd_queue is None: self._snd_queue = asyncio.Queue() - event = await self._snd_queue.get() + snd_event = await self._snd_queue.get() + event = snd_event.event if snd_client is None: snd_client = self._make_snd_client() @@ -608,7 +616,7 @@ async def _disconnect() -> None: event.type ): await _disconnect() - if not hasattr(event, 'wav'): + if snd_event.is_tts: await self.trigger_played() snd_client = None # reconnect on next event except asyncio.CancelledError: @@ -655,8 +663,7 @@ async def _play_wav( samples_per_chunk=self.settings.snd.samples_per_chunk, volume_multiplier=self.settings.snd.volume_multiplier, ): - event.wav = True - await self.event_to_snd(event) + await self.event_to_snd(event, is_tts=False) except Exception: # Unmute in case of an error self.microphone_muted = False