Skip to content

Commit

Permalink
Merge branch 'rhasspy:master' into bugfix-wav-volume
Browse files Browse the repository at this point in the history
  • Loading branch information
llluis authored Feb 19, 2024
2 parents 61ae4ce + 70652bb commit 886019e
Show file tree
Hide file tree
Showing 13 changed files with 377 additions and 22 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
---
name: test

# yamllint disable-line rule:truthy
on:
workflow_dispatch:
pull_request:
Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## 1.2.0

- Add `--tts-played-command`
- Add `--mic-seconds-to-mute-after-awake-wav` and `--mic-no-mute-during-awake-wav`
- Send preferred sound format to server

## 1.1.1

- Bump to wyoming 1.5.2 (package fix)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
wyoming==1.5.2
wyoming==1.5.3
zeroconf==0.88.0
pyring-buffer==1.0.0
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def get_requirements(req_path: Path) -> List[str]:
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
Expand Down
198 changes: 198 additions & 0 deletions tests/test_satellite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import asyncio
import io
import logging
from collections.abc import Iterable
from pathlib import Path
from typing import Final, Optional
from unittest.mock import patch

import pytest
from wyoming.asr import Transcript
from wyoming.audio import AudioChunk
from wyoming.client import AsyncClient
from wyoming.event import Event, async_read_event
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite, StreamingStarted, StreamingStopped
from wyoming.wake import Detection

from wyoming_satellite import (
EventSettings,
MicSettings,
SatelliteSettings,
WakeSettings,
WakeStreamingSatellite,
)

from .shared import AUDIO_CHUNK

_LOGGER = logging.getLogger()

TIMEOUT: Final = 1


class MicClient(AsyncClient):
def __init__(self) -> None:
super().__init__()

async def read_event(self) -> Optional[Event]:
await asyncio.sleep(AUDIO_CHUNK.seconds)
return AUDIO_CHUNK.event()

async def write_event(self, event: Event) -> None:
# Output only
pass


class WakeClient(AsyncClient):
def __init__(self) -> None:
super().__init__()
self._event_ready = asyncio.Event()
self._event: Optional[Event] = None
self._detected: bool = False

async def read_event(self) -> Optional[Event]:
await self._event_ready.wait()
self._event_ready.clear()
return self._event

async def write_event(self, event: Event) -> None:
if AudioChunk.is_type(event.type):
if not self._detected:
self._detected = True
self._event = Detection().event()
self._event_ready.set()


class EventClient(AsyncClient):
def __init__(self) -> None:
super().__init__()
self.detection = asyncio.Event()
self.streaming_started = asyncio.Event()
self.streaming_stopped = asyncio.Event()

async def read_event(self) -> Optional[Event]:
# Input only
return None

async def write_event(self, event: Event) -> None:
if Detection.is_type(event.type):
self.detection.set()
elif StreamingStarted.is_type(event.type):
self.streaming_started.set()
elif StreamingStopped.is_type(event.type):
self.streaming_stopped.set()


class FakeStreamReaderWriter:
def __init__(self) -> None:
self._undrained_data = bytes()
self._value = bytes()
self._data_ready = asyncio.Event()

def write(self, data: bytes) -> None:
self._undrained_data += data

def writelines(self, data: Iterable[bytes]) -> None:
for line in data:
self.write(line)

async def drain(self) -> None:
self._value += self._undrained_data
self._undrained_data = bytes()
self._data_ready.set()
self._data_ready.clear()

async def readline(self) -> bytes:
while b"\n" not in self._value:
await self._data_ready.wait()

with io.BytesIO(self._value) as value_io:
data = value_io.readline()
self._value = self._value[len(data) :]
return data

async def readexactly(self, n: int) -> bytes:
while len(self._value) < n:
await self._data_ready.wait()

data = self._value[:n]
self._value = self._value[n:]
return data


@pytest.mark.asyncio
async def test_satellite_and_server(tmp_path: Path) -> None:
mic_client = MicClient()
wake_client = WakeClient()
event_client = EventClient()

with patch(
"wyoming_satellite.satellite.SatelliteBase._make_mic_client",
return_value=mic_client,
), patch(
"wyoming_satellite.satellite.SatelliteBase._make_wake_client",
return_value=wake_client,
), patch(
"wyoming_satellite.satellite.SatelliteBase._make_event_client",
return_value=event_client,
):
satellite = WakeStreamingSatellite(
SatelliteSettings(
mic=MicSettings(uri="test"),
wake=WakeSettings(uri="test"),
event=EventSettings(uri="test"),
)
)

# Fake server connection
server_io = FakeStreamReaderWriter()
await satellite.set_server("test", server_io) # type: ignore

async def event_from_satellite() -> Optional[Event]:
return await async_read_event(server_io)

satellite_task = asyncio.create_task(satellite.run(), name="satellite")
await satellite.event_from_server(RunSatellite().event())

# Trigger detection
event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT)
assert event is not None
assert Detection.is_type(event.type), event

# Pipeline should start
event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT)
assert event is not None
assert RunPipeline.is_type(event.type), event
run_pipeline = RunPipeline.from_event(event)
assert run_pipeline.start_stage == PipelineStage.ASR

# No TTS
assert run_pipeline.end_stage == PipelineStage.HANDLE

# Event service should have received detection
await asyncio.wait_for(event_client.detection.wait(), timeout=TIMEOUT)

# Server should be receiving audio now
assert satellite.is_streaming, "Not streaming"
for _ in range(5):
event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT)
assert event is not None
assert AudioChunk.is_type(event.type)

# Event service should have received streaming start
await asyncio.wait_for(event_client.streaming_started.wait(), timeout=TIMEOUT)

# Send transcript
await satellite.event_from_server(Transcript(text="test").event())

# Wait for streaming to stop
while satellite.is_streaming:
event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT)
assert event is not None
assert AudioChunk.is_type(event.type)

# Event service should have received streaming stop
await asyncio.wait_for(event_client.streaming_stopped.wait(), timeout=TIMEOUT)

await satellite.stop()
await satellite_task
4 changes: 2 additions & 2 deletions tests/test_wake_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ async def test_multiple_wakeups(tmp_path: Path) -> None:
await satellite.event_from_server(Transcript("test").event())

# Should not trigger again within refractory period (default: 5 sec)
# with pytest.raises(asyncio.TimeoutError):
# await asyncio.wait_for(event_client.wake_event.wait(), timeout=0.15)
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(event_client.wake_event.wait(), timeout=0.15)

await satellite.stop()
await satellite_task
14 changes: 14 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[tox]
env_list =
py{39,310,311}
minversion = 4.12.1

[testenv]
description = run the tests with pytest
package = wheel
wheel_build_env = .pkg
deps =
pytest>=7,<8
pytest-asyncio<1
commands =
pytest {tty:--color=yes} {posargs}
15 changes: 13 additions & 2 deletions wyoming_satellite/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
SndSettings,
VadSettings,
WakeSettings,
WakeWordAndPipeline,
)
from .utils import (
get_mac_address,
Expand Down Expand Up @@ -84,6 +85,11 @@ async def main() -> None:
action="store_true",
help="Don't mute the microphone while awake wav is being played",
)
parser.add_argument(
"--mic-channel-index",
type=int,
help="Take microphone input from a specific channel (first channel is 0)",
)

# Sound output
parser.add_argument("--snd-uri", help="URI of Wyoming sound service")
Expand Down Expand Up @@ -114,7 +120,9 @@ async def main() -> None:
"--wake-word-name",
action="append",
default=[],
help="Name of wake word to listen for (requires --wake-uri)",
nargs="+",
metavar=("name", "pipeline"),
help="Name of wake word to listen for and optional pipeline name to run (requires --wake-uri)",
)
parser.add_argument("--wake-command", help="Program to run for wake word detection")
parser.add_argument(
Expand Down Expand Up @@ -318,6 +326,7 @@ async def main() -> None:
noise_suppression=args.mic_noise_suppression,
seconds_to_mute_after_awake_wav=args.mic_seconds_to_mute_after_awake_wav,
mute_during_awake_wav=(not args.mic_no_mute_during_awake_wav),
channel_index=args.mic_channel_index,
),
vad=VadSettings(
enabled=args.vad,
Expand All @@ -329,7 +338,9 @@ async def main() -> None:
wake=WakeSettings(
uri=args.wake_uri,
command=split_command(args.wake_command),
names=args.wake_word_name,
names=[
WakeWordAndPipeline(*wake_name) for wake_name in args.wake_word_name
],
refractory_seconds=args.wake_refractory_seconds
if args.wake_refractory_seconds > 0
else None,
Expand Down
5 changes: 3 additions & 2 deletions wyoming_satellite/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(
super().__init__(*args, **kwargs)

self.cli_args = cli_args
self.wyoming_info_event = wyoming_info.event()
self.wyoming_info = wyoming_info
self.client_id = str(time.monotonic_ns())
self.satellite = satellite

Expand All @@ -35,7 +35,8 @@ def __init__(
async def handle_event(self, event: Event) -> bool:
"""Handle events from the server."""
if Describe.is_type(event.type):
await self.write_event(self.wyoming_info_event)
await self.satellite.update_info(self.wyoming_info)
await self.write_event(self.wyoming_info.event())
return True

if self.satellite.server_id is None:
Expand Down
Loading

0 comments on commit 886019e

Please sign in to comment.