Skip to content

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
synesthesiam committed Feb 18, 2024
1 parent f0dcf39 commit 70652bb
Show file tree
Hide file tree
Showing 8 changed files with 257 additions and 8 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
---
name: test

# yamllint disable-line rule:truthy
on:
workflow_dispatch:
pull_request:
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
wyoming==1.5.2
wyoming==1.5.3
zeroconf==0.88.0
pyring-buffer==1.0.0
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def get_requirements(req_path: Path) -> List[str]:
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
Expand Down
198 changes: 198 additions & 0 deletions tests/test_satellite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import asyncio
import io
import logging
from collections.abc import Iterable
from pathlib import Path
from typing import Final, Optional
from unittest.mock import patch

import pytest
from wyoming.asr import Transcript
from wyoming.audio import AudioChunk
from wyoming.client import AsyncClient
from wyoming.event import Event, async_read_event
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite, StreamingStarted, StreamingStopped
from wyoming.wake import Detection

from wyoming_satellite import (
EventSettings,
MicSettings,
SatelliteSettings,
WakeSettings,
WakeStreamingSatellite,
)

from .shared import AUDIO_CHUNK

_LOGGER = logging.getLogger()

TIMEOUT: Final = 1


class MicClient(AsyncClient):
def __init__(self) -> None:
super().__init__()

async def read_event(self) -> Optional[Event]:
await asyncio.sleep(AUDIO_CHUNK.seconds)
return AUDIO_CHUNK.event()

async def write_event(self, event: Event) -> None:
# Output only
pass


class WakeClient(AsyncClient):
def __init__(self) -> None:
super().__init__()
self._event_ready = asyncio.Event()
self._event: Optional[Event] = None
self._detected: bool = False

async def read_event(self) -> Optional[Event]:
await self._event_ready.wait()
self._event_ready.clear()
return self._event

async def write_event(self, event: Event) -> None:
if AudioChunk.is_type(event.type):
if not self._detected:
self._detected = True
self._event = Detection().event()
self._event_ready.set()


class EventClient(AsyncClient):
def __init__(self) -> None:
super().__init__()
self.detection = asyncio.Event()
self.streaming_started = asyncio.Event()
self.streaming_stopped = asyncio.Event()

async def read_event(self) -> Optional[Event]:
# Input only
return None

async def write_event(self, event: Event) -> None:
if Detection.is_type(event.type):
self.detection.set()
elif StreamingStarted.is_type(event.type):
self.streaming_started.set()
elif StreamingStopped.is_type(event.type):
self.streaming_stopped.set()


class FakeStreamReaderWriter:
def __init__(self) -> None:
self._undrained_data = bytes()
self._value = bytes()
self._data_ready = asyncio.Event()

def write(self, data: bytes) -> None:
self._undrained_data += data

def writelines(self, data: Iterable[bytes]) -> None:
for line in data:
self.write(line)

async def drain(self) -> None:
self._value += self._undrained_data
self._undrained_data = bytes()
self._data_ready.set()
self._data_ready.clear()

async def readline(self) -> bytes:
while b"\n" not in self._value:
await self._data_ready.wait()

with io.BytesIO(self._value) as value_io:
data = value_io.readline()
self._value = self._value[len(data) :]
return data

async def readexactly(self, n: int) -> bytes:
while len(self._value) < n:
await self._data_ready.wait()

data = self._value[:n]
self._value = self._value[n:]
return data


@pytest.mark.asyncio
async def test_satellite_and_server(tmp_path: Path) -> None:
mic_client = MicClient()
wake_client = WakeClient()
event_client = EventClient()

with patch(
"wyoming_satellite.satellite.SatelliteBase._make_mic_client",
return_value=mic_client,
), patch(
"wyoming_satellite.satellite.SatelliteBase._make_wake_client",
return_value=wake_client,
), patch(
"wyoming_satellite.satellite.SatelliteBase._make_event_client",
return_value=event_client,
):
satellite = WakeStreamingSatellite(
SatelliteSettings(
mic=MicSettings(uri="test"),
wake=WakeSettings(uri="test"),
event=EventSettings(uri="test"),
)
)

# Fake server connection
server_io = FakeStreamReaderWriter()
await satellite.set_server("test", server_io) # type: ignore

async def event_from_satellite() -> Optional[Event]:
return await async_read_event(server_io)

satellite_task = asyncio.create_task(satellite.run(), name="satellite")
await satellite.event_from_server(RunSatellite().event())

# Trigger detection
event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT)
assert event is not None
assert Detection.is_type(event.type), event

# Pipeline should start
event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT)
assert event is not None
assert RunPipeline.is_type(event.type), event
run_pipeline = RunPipeline.from_event(event)
assert run_pipeline.start_stage == PipelineStage.ASR

# No TTS
assert run_pipeline.end_stage == PipelineStage.HANDLE

# Event service should have received detection
await asyncio.wait_for(event_client.detection.wait(), timeout=TIMEOUT)

# Server should be receiving audio now
assert satellite.is_streaming, "Not streaming"
for _ in range(5):
event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT)
assert event is not None
assert AudioChunk.is_type(event.type)

# Event service should have received streaming start
await asyncio.wait_for(event_client.streaming_started.wait(), timeout=TIMEOUT)

# Send transcript
await satellite.event_from_server(Transcript(text="test").event())

# Wait for streaming to stop
while satellite.is_streaming:
event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT)
assert event is not None
assert AudioChunk.is_type(event.type)

# Event service should have received streaming stop
await asyncio.wait_for(event_client.streaming_stopped.wait(), timeout=TIMEOUT)

await satellite.stop()
await satellite_task
4 changes: 2 additions & 2 deletions tests/test_wake_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ async def test_multiple_wakeups(tmp_path: Path) -> None:
await satellite.event_from_server(Transcript("test").event())

# Should not trigger again within refractory period (default: 5 sec)
# with pytest.raises(asyncio.TimeoutError):
# await asyncio.wait_for(event_client.wake_event.wait(), timeout=0.15)
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(event_client.wake_event.wait(), timeout=0.15)

await satellite.stop()
await satellite_task
14 changes: 14 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[tox]
env_list =
py{39,310,311}
minversion = 4.12.1

[testenv]
description = run the tests with pytest
package = wheel
wheel_build_env = .pkg
deps =
pytest>=7,<8
pytest-asyncio<1
commands =
pytest {tty:--color=yes} {posargs}
5 changes: 3 additions & 2 deletions wyoming_satellite/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(
super().__init__(*args, **kwargs)

self.cli_args = cli_args
self.wyoming_info_event = wyoming_info.event()
self.wyoming_info = wyoming_info
self.client_id = str(time.monotonic_ns())
self.satellite = satellite

Expand All @@ -35,7 +35,8 @@ def __init__(
async def handle_event(self, event: Event) -> bool:
"""Handle events from the server."""
if Describe.is_type(event.type):
await self.write_event(self.wyoming_info_event)
await self.satellite.update_info(self.wyoming_info)
await self.write_event(self.wyoming_info.event())
return True

if self.satellite.server_id is None:
Expand Down
38 changes: 37 additions & 1 deletion wyoming_satellite/satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from wyoming.client import AsyncClient
from wyoming.error import Error
from wyoming.event import Event, async_write_event
from wyoming.info import Describe, Info
from wyoming.mic import MicProcessAsyncClient
from wyoming.ping import Ping, Pong
from wyoming.pipeline import PipelineStage, RunPipeline
Expand Down Expand Up @@ -46,6 +47,7 @@

_PONG_TIMEOUT: Final = 5
_PING_SEND_DELAY: Final = 2
_WAKE_INFO_TIMEOUT: Final = 2


class State(Enum):
Expand Down Expand Up @@ -898,6 +900,13 @@ async def _disconnect() -> None:

await _disconnect()

# -------------------------------------------------------------------------
# Info
# -------------------------------------------------------------------------

async def update_info(self, info: Info) -> None:
pass


# -----------------------------------------------------------------------------

Expand Down Expand Up @@ -1150,6 +1159,9 @@ def __init__(self, settings: SatelliteSettings) -> None:

self._is_paused = False

self._wake_info: Optional[Info] = None
self._wake_info_ready = asyncio.Event()

async def event_from_server(self, event: Event) -> None:
# Only check event types once
is_run_satellite = False
Expand Down Expand Up @@ -1243,8 +1255,13 @@ async def event_from_mic(
await self.event_to_wake(event)

async def event_from_wake(self, event: Event) -> None:
if Info.is_type(event.type):
self._wake_info = Info.from_event(event)
self._wake_info_ready.set()
return

if self.is_streaming or (self.server_id is None):
# Not streaming or no server connected
# Not detecting or no server connected
return

if Detection.is_type(event.type):
Expand Down Expand Up @@ -1281,6 +1298,9 @@ async def event_from_wake(self, event: Event) -> None:
# No refractory period
self.refractory_timestamp.pop(detection.name, None)

# Forward to the server
await self.event_to_server(event)

# Match detected wake word name with pipeline name
pipeline_name: Optional[str] = None
if self.settings.wake.names:
Expand All @@ -1294,3 +1314,19 @@ async def event_from_wake(self, event: Event) -> None:
await self.forward_event(event) # forward to event service
await self.trigger_detection(Detection.from_event(event))
await self.trigger_streaming_start()

async def update_info(self, info: Info) -> None:
self._wake_info = None
self._wake_info_ready.clear()
await self.event_to_wake(Describe().event())

try:
await asyncio.wait_for(
self._wake_info_ready.wait(), timeout=_WAKE_INFO_TIMEOUT
)

if self._wake_info is not None:
# Update wake info only
info.wake = self._wake_info.wake
except asyncio.TimeoutError:
_LOGGER.warning("Failed to get info from wake service")

0 comments on commit 70652bb

Please sign in to comment.