From 04edd6965ee7d8666de2386e1f2fa89ea45642d9 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Mon, 13 Nov 2023 16:13:16 -0800 Subject: [PATCH] Separate STT from Audio transformers timing context --- neon_speech/service.py | 75 ++++++++++++++++++++++++------------------ 1 file changed, 43 insertions(+), 32 deletions(-) diff --git a/neon_speech/service.py b/neon_speech/service.py index 91eca03..2d1ea40 100644 --- a/neon_speech/service.py +++ b/neon_speech/service.py @@ -27,6 +27,8 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import os +from typing import Dict + import ovos_dinkum_listener.plugins from tempfile import mkstemp @@ -80,8 +82,6 @@ def on_started(): class NeonSpeechClient(OVOSDinkumVoiceService): - _stopwatch = Stopwatch("get_stt") - def __init__(self, ready_hook=on_ready, error_hook=on_error, stopping_hook=on_stopping, alive_hook=on_alive, started_hook=on_started, watchdog=lambda: None, @@ -327,6 +327,7 @@ def handle_get_stt(self, message: Message): Emits a response to the sender with stt data or error data :param message: Message associated with request """ + received_time = time() if message.data.get("audio_data"): wav_file_path = self._write_encoded_file( message.data.pop("audio_data")) @@ -345,11 +346,14 @@ def handle_get_stt(self, message: Message): ident, data={"error": f"{wav_file_path} Not found!"})) try: + message.context['timing'].setdefault(dict()) + _, parser_data, transcriptions = \ self._get_stt_from_file(wav_file_path, lang) - received_time = time() - sent_time = message.context.get("timing", {}).get("client_sent", - received_time) + timing = parser_data.pop('timing') + message.context["timing"] = {**message.context["timing"], **timing} + sent_time = message.context["timing"].get("client_sent", + received_time) if received_time != sent_time: message.context['timing']['client_to_core'] = \ received_time - sent_time @@ -397,13 +401,13 @@ def build_context(msg: Message): wav_file_path = message.data.get("audio_file") lang = message.data.get("lang") try: - with self._stopwatch: - # _=transformed audio_data - _, parser_data, transcriptions = \ - self._get_stt_from_file(wav_file_path, lang) + # _=transformed audio_data + _, parser_data, transcriptions = \ + self._get_stt_from_file(wav_file_path, lang) + timing = parser_data.pop('timing') message.context["audio_parser_data"] = parser_data message.context.setdefault('timing', dict()) - message.context['timing']['get_stt'] = self._stopwatch.time + message.context['timing'] = {**timing, **message.context['timing']} context = build_context(message) data = { "utterances": transcriptions, @@ -441,7 +445,7 @@ def handle_offline(self, _): Handle notification to operate in offline mode """ LOG.info("Offline mode selected, Reloading STT Plugin") - config = dict(self.config) + config: Dict[str, dict] = dict(self.config) if config['stt'].get('offline_module'): config['stt']['module'] = config['stt'].get('offline_module') self.voice_loop.stt = STTFactory.create(config) @@ -474,6 +478,7 @@ def _get_stt_from_file(self, wav_file: str, :return: (AudioData of object, extracted context, transcriptions) """ from neon_utils.file_utils import get_audio_file_stream + _stopwatch = Stopwatch() lang = lang or self.config.get('lang') desired_sample_rate = self.config['listener'].get('sample_rate', 16000) desired_sample_width = self.config['listener'].get('sample_width', 2) @@ -487,28 +492,34 @@ def _get_stt_from_file(self, wav_file: str, if not self.api_stt: raise RuntimeError("api_stt not initialized." " is `listener['enable_stt_api'] set to False?") - if hasattr(self.api_stt, 'stream_start'): - audio_stream = get_audio_file_stream(wav_file, desired_sample_rate) - if self.lock.acquire(True, 30): - LOG.info(f"Starting STT processing (lang={lang}): {wav_file}") - self.api_stt.stream_start(lang) - while True: - try: - data = audio_stream.read(1024) - self.api_stt.stream_data(data) - except EOFError: - break - transcriptions = self.api_stt.stream_stop() - self.lock.release() + with _stopwatch: + if hasattr(self.api_stt, 'stream_start'): + audio_stream = get_audio_file_stream(wav_file, desired_sample_rate) + if self.lock.acquire(True, 30): + LOG.info(f"Starting STT processing (lang={lang}): {wav_file}") + self.api_stt.stream_start(lang) + while True: + try: + data = audio_stream.read(1024) + self.api_stt.stream_data(data) + except EOFError: + break + transcriptions = self.api_stt.stream_stop() + self.lock.release() + else: + LOG.error(f"Timed out acquiring lock, not processing: {wav_file}") + transcriptions = [] else: - LOG.error(f"Timed out acquiring lock, not processing: {wav_file}") - transcriptions = [] - else: - transcriptions = self.api_stt.execute(audio_data, lang) - if isinstance(transcriptions, str): - LOG.warning("Transcriptions is a str, no alternatives provided") - transcriptions = [transcriptions] - audio, audio_context = self.transformers.transform(audio_data) + transcriptions = self.api_stt.execute(audio_data, lang) + if isinstance(transcriptions, str): + LOG.warning("Transcriptions is a str, no alternatives provided") + transcriptions = [transcriptions] + + get_stt = float(_stopwatch.time) + with _stopwatch: + audio, audio_context = self.transformers.transform(audio_data) + audio_context["timing"] = {"get_stt": get_stt, + "transform_audio": _stopwatch.time} LOG.info(f"Transcribed: {transcriptions}") return audio, audio_context, transcriptions