Skip to content

Commit

Permalink
Separate STT from Audio transformers timing context
Browse files: browse the repository at this point in the history
  • Loading branch information
NeonDaniel committed Nov 14, 2023
1 parent e889581 commit 04edd69
Showing 1 changed file with 43 additions and 32 deletions.
75 changes: 43 additions & 32 deletions neon_speech/service.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -27,6 +27,8 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
from typing import Dict

import ovos_dinkum_listener.plugins

from tempfile import mkstemp
Expand Down Expand Up @@ -80,8 +82,6 @@ def on_started():


class NeonSpeechClient(OVOSDinkumVoiceService):
_stopwatch = Stopwatch("get_stt")

def __init__(self, ready_hook=on_ready, error_hook=on_error,
stopping_hook=on_stopping, alive_hook=on_alive,
started_hook=on_started, watchdog=lambda: None,
Expand Down Expand Up @@ -327,6 +327,7 @@ def handle_get_stt(self, message: Message):
Emits a response to the sender with stt data or error data
:param message: Message associated with request
"""
received_time = time()
if message.data.get("audio_data"):
wav_file_path = self._write_encoded_file(
message.data.pop("audio_data"))
Expand All @@ -345,11 +346,14 @@ def handle_get_stt(self, message: Message):
ident, data={"error": f"{wav_file_path} Not found!"}))

try:
message.context['timing'].setdefault(dict())

_, parser_data, transcriptions = \
self._get_stt_from_file(wav_file_path, lang)
received_time = time()
sent_time = message.context.get("timing", {}).get("client_sent",
received_time)
timing = parser_data.pop('timing')
message.context["timing"] = {**message.context["timing"], **timing}
sent_time = message.context["timing"].get("client_sent",
received_time)
if received_time != sent_time:
message.context['timing']['client_to_core'] = \
received_time - sent_time
Expand Down Expand Up @@ -397,13 +401,13 @@ def build_context(msg: Message):
wav_file_path = message.data.get("audio_file")
lang = message.data.get("lang")
try:
with self._stopwatch:
# _=transformed audio_data
_, parser_data, transcriptions = \
self._get_stt_from_file(wav_file_path, lang)
# _=transformed audio_data
_, parser_data, transcriptions = \
self._get_stt_from_file(wav_file_path, lang)
timing = parser_data.pop('timing')
message.context["audio_parser_data"] = parser_data
message.context.setdefault('timing', dict())
message.context['timing']['get_stt'] = self._stopwatch.time
message.context['timing'] = {**timing, **message.context['timing']}
context = build_context(message)
data = {
"utterances": transcriptions,
Expand Down Expand Up @@ -441,7 +445,7 @@ def handle_offline(self, _):
Handle notification to operate in offline mode
"""
LOG.info("Offline mode selected, Reloading STT Plugin")
config = dict(self.config)
config: Dict[str, dict] = dict(self.config)
if config['stt'].get('offline_module'):
config['stt']['module'] = config['stt'].get('offline_module')
self.voice_loop.stt = STTFactory.create(config)
Expand Down Expand Up @@ -474,6 +478,7 @@ def _get_stt_from_file(self, wav_file: str,
:return: (transformed AudioData object, extracted context, transcriptions)
"""
from neon_utils.file_utils import get_audio_file_stream
_stopwatch = Stopwatch()
lang = lang or self.config.get('lang')
desired_sample_rate = self.config['listener'].get('sample_rate', 16000)
desired_sample_width = self.config['listener'].get('sample_width', 2)
Expand All @@ -487,28 +492,34 @@ def _get_stt_from_file(self, wav_file: str,
if not self.api_stt:
raise RuntimeError("api_stt not initialized."
" is `listener['enable_stt_api'] set to False?")
if hasattr(self.api_stt, 'stream_start'):
audio_stream = get_audio_file_stream(wav_file, desired_sample_rate)
if self.lock.acquire(True, 30):
LOG.info(f"Starting STT processing (lang={lang}): {wav_file}")
self.api_stt.stream_start(lang)
while True:
try:
data = audio_stream.read(1024)
self.api_stt.stream_data(data)
except EOFError:
break
transcriptions = self.api_stt.stream_stop()
self.lock.release()
with _stopwatch:
if hasattr(self.api_stt, 'stream_start'):
audio_stream = get_audio_file_stream(wav_file, desired_sample_rate)
if self.lock.acquire(True, 30):
LOG.info(f"Starting STT processing (lang={lang}): {wav_file}")
self.api_stt.stream_start(lang)
while True:
try:
data = audio_stream.read(1024)
self.api_stt.stream_data(data)
except EOFError:
break
transcriptions = self.api_stt.stream_stop()
self.lock.release()
else:
LOG.error(f"Timed out acquiring lock, not processing: {wav_file}")
transcriptions = []
else:
LOG.error(f"Timed out acquiring lock, not processing: {wav_file}")
transcriptions = []
else:
transcriptions = self.api_stt.execute(audio_data, lang)
if isinstance(transcriptions, str):
LOG.warning("Transcriptions is a str, no alternatives provided")
transcriptions = [transcriptions]
audio, audio_context = self.transformers.transform(audio_data)
transcriptions = self.api_stt.execute(audio_data, lang)
if isinstance(transcriptions, str):
LOG.warning("Transcriptions is a str, no alternatives provided")
transcriptions = [transcriptions]

get_stt = float(_stopwatch.time)
with _stopwatch:
audio, audio_context = self.transformers.transform(audio_data)
audio_context["timing"] = {"get_stt": get_stt,
"transform_audio": _stopwatch.time}
LOG.info(f"Transcribed: {transcriptions}")
return audio, audio_context, transcriptions

Expand Down

0 comments on commit 04edd69

Please sign in to comment.