Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved timing context handling with unit tests #182

Merged
merged 16 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 62 additions & 34 deletions neon_speech/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
from typing import Dict

import ovos_dinkum_listener.plugins

from tempfile import mkstemp
Expand Down Expand Up @@ -80,8 +82,6 @@ def on_started():


class NeonSpeechClient(OVOSDinkumVoiceService):
_stopwatch = Stopwatch("get_stt")

def __init__(self, ready_hook=on_ready, error_hook=on_error,
stopping_hook=on_stopping, alive_hook=on_alive,
started_hook=on_started, watchdog=lambda: None,
Expand Down Expand Up @@ -327,31 +327,46 @@ def handle_get_stt(self, message: Message):
Emits a response to the sender with stt data or error data
:param message: Message associated with request
"""
received_time = time()
if message.data.get("audio_data"):
wav_file_path = self._write_encoded_file(
message.data.pop("audio_data"))
else:
wav_file_path = message.data.get("audio_file")
lang = message.data.get("lang")
ident = message.context.get("ident") or "neon.get_stt.response"

message.context.setdefault("timing", dict())
LOG.info(f"Handling STT request: {ident}")
if not wav_file_path:
message.context['timing']['response_sent'] = time()
self.bus.emit(message.reply(
ident, data={"error": f"audio_file not specified!"}))
return

if not os.path.isfile(wav_file_path):
message.context['timing']['response_sent'] = time()
self.bus.emit(message.reply(
ident, data={"error": f"{wav_file_path} Not found!"}))

try:

_, parser_data, transcriptions = \
self._get_stt_from_file(wav_file_path, lang)
timing = parser_data.pop('timing')
message.context["timing"] = {**message.context["timing"], **timing}
sent_time = message.context["timing"].get("client_sent",
received_time)
if received_time != sent_time:
message.context['timing']['client_to_core'] = \
received_time - sent_time
message.context['timing']['response_sent'] = time()
self.bus.emit(message.reply(ident,
data={"parser_data": parser_data,
"transcripts": transcriptions}))
except Exception as e:
LOG.error(e)
message.context['timing']['response_sent'] = time()
self.bus.emit(message.reply(ident, data={"error": repr(e)}))

def handle_audio_input(self, message):
Expand All @@ -370,14 +385,18 @@ def build_context(msg: Message):
'username': self._default_user["user"]["username"] or
"local",
'user_profiles': [self._default_user.content]}
ctx = {**defaults, **ctx, 'destination': ['skills'],
'timing': {'start': msg.data.get('time'),
'transcribed': time()}}
ctx = {**defaults, **ctx, 'destination': ['skills']}
ctx['timing'] = {**ctx.get('timing', {}),
**{'start': msg.data.get('time'),
'transcribed': time()}}
return ctx

received_time = time()
sent_time = message.context.get("timing", {}).get("client_sent",
received_time)
if received_time != sent_time:
message.context['timing']['client_to_core'] = \
received_time - sent_time
ident = message.context.get("ident") or "neon.audio_input.response"
LOG.info(f"Handling audio input: {ident}")
if message.data.get("audio_data"):
Expand All @@ -387,21 +406,23 @@ def build_context(msg: Message):
wav_file_path = message.data.get("audio_file")
lang = message.data.get("lang")
try:
with self._stopwatch:
# _=transformed audio_data
_, parser_data, transcriptions = \
self._get_stt_from_file(wav_file_path, lang)
# _=transformed audio_data
_, parser_data, transcriptions = \
self._get_stt_from_file(wav_file_path, lang)
timing = parser_data.pop('timing')
message.context["audio_parser_data"] = parser_data
message.context.setdefault('timing', dict())
message.context['timing'] = {**timing, **message.context['timing']}
context = build_context(message)
if received_time != sent_time:
context['timing']['mq_from_client'] = received_time - sent_time
context['timing']['get_stt'] = self._stopwatch.time
data = {
"utterances": transcriptions,
"lang": message.data.get("lang", "en-us")
}
# Send a new message to the skills module with proper routing ctx
handled = self._emit_utterance_to_skills(Message(
'recognizer_loop:utterance', data, context))

# Reply to original message with transcription/audio parser data
self.bus.emit(message.reply(ident,
data={"parser_data": parser_data,
"transcripts": transcriptions,
Expand Down Expand Up @@ -429,7 +450,7 @@ def handle_offline(self, _):
Handle notification to operate in offline mode
"""
LOG.info("Offline mode selected, Reloading STT Plugin")
config = dict(self.config)
config: Dict[str, dict] = dict(self.config)
if config['stt'].get('offline_module'):
config['stt']['module'] = config['stt'].get('offline_module')
self.voice_loop.stt = STTFactory.create(config)
Expand Down Expand Up @@ -462,6 +483,7 @@ def _get_stt_from_file(self, wav_file: str,
:return: (AudioData object, extracted context, transcriptions)
"""
from neon_utils.file_utils import get_audio_file_stream
_stopwatch = Stopwatch()
lang = lang or self.config.get('lang')
desired_sample_rate = self.config['listener'].get('sample_rate', 16000)
desired_sample_width = self.config['listener'].get('sample_width', 2)
Expand All @@ -475,28 +497,34 @@ def _get_stt_from_file(self, wav_file: str,
if not self.api_stt:
raise RuntimeError("api_stt not initialized."
" is `listener['enable_stt_api'] set to False?")
if hasattr(self.api_stt, 'stream_start'):
audio_stream = get_audio_file_stream(wav_file, desired_sample_rate)
if self.lock.acquire(True, 30):
LOG.info(f"Starting STT processing (lang={lang}): {wav_file}")
self.api_stt.stream_start(lang)
while True:
try:
data = audio_stream.read(1024)
self.api_stt.stream_data(data)
except EOFError:
break
transcriptions = self.api_stt.stream_stop()
self.lock.release()
with _stopwatch:
if hasattr(self.api_stt, 'stream_start'):
audio_stream = get_audio_file_stream(wav_file, desired_sample_rate)
if self.lock.acquire(True, 30):
LOG.info(f"Starting STT processing (lang={lang}): {wav_file}")
self.api_stt.stream_start(lang)
while True:
try:
data = audio_stream.read(1024)
self.api_stt.stream_data(data)
except EOFError:
break
transcriptions = self.api_stt.stream_stop()
self.lock.release()
else:
LOG.error(f"Timed out acquiring lock, not processing: {wav_file}")
transcriptions = []
else:
LOG.error(f"Timed out acquiring lock, not processing: {wav_file}")
transcriptions = []
else:
transcriptions = self.api_stt.execute(audio_data, lang)
if isinstance(transcriptions, str):
LOG.warning("Transcriptions is a str, no alternatives provided")
transcriptions = [transcriptions]
audio, audio_context = self.transformers.transform(audio_data)
transcriptions = self.api_stt.execute(audio_data, lang)
if isinstance(transcriptions, str):
LOG.warning("Transcriptions is a str, no alternatives provided")
transcriptions = [transcriptions]

get_stt = float(_stopwatch.time)
with _stopwatch:
audio, audio_context = self.transformers.transform(audio_data)
audio_context["timing"] = {"get_stt": get_stt,
"transform_audio": _stopwatch.time}
LOG.info(f"Transcribed: {transcriptions}")
return audio, audio_context, transcriptions

Expand Down
Loading
Loading