-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1. Decouples all the elements of the ASR pipeline. 2. Unifies the API of each element of the ASR pipeline. 3. Introduces a new `AsrPipeline` class to construct a pipeline using appropriate elements. This class serves as a manager of the pipeline, and its methods `open`, `start`, `stop` and `close` act as the controlling interfaces to the pipeline. Each element of the ASR pipeline can still be used standalone by supplying the appropriate input to the element's `next_chunk` method. Supplying input to `Source` elements using `next_chunk` does nothing, as they generate data themselves, but they accept input to keep a unified API.
- Loading branch information
Showing
12 changed files
with
523 additions
and
273 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
""" | ||
Yapykaldi ASR: Classes and functions for ASR pipeline | ||
""" | ||
|
||
# Public names of the package; the trailing comment on each entry names the
# submodule the symbol is imported from.
__all__ = [
    "Asr",  # .asr
    "AsrPipeline",  # .pipeline
    "PyAudioMicrophoneSource",  # .sources
    "WaveFileSource",  # .sources
    "WaveFileSink",  # .sinks
]
|
||
from .asr import Asr | ||
from .pipeline import AsrPipeline | ||
from .sources import PyAudioMicrophoneSource, WaveFileSource | ||
from .sinks import WaveFileSink |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
"""Base classes for the ASR pipeline""" | ||
from __future__ import print_function, division, absolute_import, unicode_literals | ||
from builtins import * | ||
from abc import ABC, abstractmethod | ||
from threading import Event | ||
import pyaudio | ||
|
||
|
||
class AsrPipelineElementBase(ABC):
    """Class AsrPipelineElementBase is the base class for all Asr Pipeline elements.

    It requires three abstract methods to be implemented:
        1. open
        2. close
        3. next_chunk

    The right order of setting up an element is:
        1. element = AsrPipelineElementBase()
        2. element.open()        # To open the file, connect the mic etc.
        3. element.start()       # Start streaming audio data
        4. element.next_chunk()  # Use the audio data
        5. element.stop()        # stop getting audio data
        6. element.close()       # close the file

    Elements need to support open and close at least once but must support
    start, next_chunk, stop several times
    """
    # pylint: disable=too-many-instance-attributes

    def __init__(self, source=None, sink=None, rate=16000, chunksize=1024, fmt=pyaudio.paInt16, channels=1, timeout=1):
        """
        :param source: (default None) Element to link as the source of this element
        :param sink: (default None) Element to link as the sink of this element
        :param rate: (default 16000) stored as `self.rate`
        :param chunksize: (default 1024) stored as `self.chunksize`
        :param fmt: (default pyaudio.paInt16) stored as `self.format`
        :param channels: (default 1) stored as `self.channels`
        :param timeout: (default 1) stored as `self.timeout`; not used by the base class itself
        """
        # Both ends start unlinked; link() below only fills ends that are still empty
        self._source = None
        self._sink = None

        self.rate = rate
        self.chunksize = chunksize
        self.format = fmt
        self.channels = channels
        self.timeout = timeout

        # Flag raised by finalize(); subclasses decide how to react to it
        self._finalize = Event()

        self.link(source=source, sink=sink)

    @abstractmethod
    def open(self):
        """Abstract method to open the stream of the element. Opening may or may not start the stream."""

    @abstractmethod
    def next_chunk(self, chunk):
        """Abstract method to process a chunk generated in the source element or received from the source element"""

    @abstractmethod
    def close(self):
        """Abstract method to close the stream of the element. In this method all resources of the stream should be
        freed."""

    def start(self):
        """Optional method to start the stream of the element"""

    def stop(self):
        """Optional method to stop the stream of the element"""

    def register_callback(self, callback):
        """Register a callback to the element outside the pipeline

        The base class does not support callbacks; subclasses that do must override this method.

        :param callback: callable to be registered
        :raises NotImplementedError: always, unless overridden
        """
        raise NotImplementedError()

    def link(self, source=None, sink=None):
        """Link a source or a sink to the element

        This method does not override preset source or sink of the element. The link is made in both directions:
        the given source gets this element as its sink and the given sink gets this element as its source.

        :param source: (default None) A source object
        :param sink: (default None) A sink object
        """
        # Compare against None explicitly so that an element which happens to be falsy
        # (e.g. one defining __len__ or __bool__) is still treated as a real element.
        if source is not None and self._source is None:
            # Store the reference before linking back: the reciprocal call ends up in
            # self.link(source=source), which must find the slot taken to stop the recursion.
            self._source = source
            source.link(sink=self)

        if sink is not None and self._sink is None:
            self._sink = sink
            sink.link(source=self)

    def finalize(self):
        """Set the finalize flag of the element"""
        self._finalize.set()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
""" | ||
Yapykaldi ASR: Class definition for ASR component. It connects to a source and an optional sink | ||
""" | ||
from __future__ import (print_function, division, absolute_import, unicode_literals) | ||
from builtins import * | ||
import struct | ||
import numpy as np | ||
from ._base import AsrPipelineElementBase | ||
from ..logger import logger | ||
from ..nnet3 import KaldiNNet3OnlineDecoder, KaldiNNet3OnlineModel | ||
from ..gmm import KaldiGmmOnlineDecoder, KaldiGmmOnlineModel | ||
from ..utils import volume_indicator | ||
|
||
|
||
# Lookup tables keyed by model type name: the online model class to load and
# the online decoder class that is constructed from that model.
ONLINE_MODELS = {
    'nnet3': KaldiNNet3OnlineModel,
    'gmm': KaldiGmmOnlineModel,
}
ONLINE_DECODERS = {
    'nnet3': KaldiNNet3OnlineDecoder,
    'gmm': KaldiGmmOnlineDecoder,
}
|
||
|
||
class Asr(AsrPipelineElementBase):
    """API for ASR

    Pipeline element that feeds audio chunks to an online decoder and reports the decoded string through
    registered callbacks. The model and decoder are created in `start`, so `start` must be called before
    `next_chunk`.
    """
    # pylint: disable=too-many-instance-attributes, useless-object-inheritance

    def __init__(self, model_dir, model_type, rate=16000, chunksize=1024, debug=False, source=None, sink=None):
        """
        :param model_dir: Path to model directory
        :param model_type: Type of ASR model 'nnet3' or 'gmm' (the keys of ONLINE_MODELS / ONLINE_DECODERS)
        :param rate: (default 16000) sampling frequency of audio data. This must be the same as the audio source
        :param chunksize: (default 1024) size of audio data buffer. This must be the same as the audio source
        :param debug: (default False) Flag to set logger to log audio chunk volume and partially decoded string and
            likelihood
        :param source: (default None) Element to be connected as source when constructing an AsrPipeline
        :type source: AsrPipelineElementBase
        :param sink: (default None) Element to be connected as sink when constructing an AsrPipeline
        :type sink: AsrPipelineElementBase
        """
        super().__init__(chunksize=chunksize, rate=rate, source=source, sink=sink)
        self.model_dir = model_dir
        self.model_type = model_type

        # Created in start(); None until then
        self._model = None
        self._decoder = None
        self._decoded_string = None
        self._likelihood = None

        self._string_partially_recognized_callbacks = []
        self._string_fully_recognized_callbacks = []

        self._debug = debug

    def open(self):
        """Nothing to open for this element; defined to satisfy the abstract base class"""

    def close(self):
        """Nothing to close for this element; defined to satisfy the abstract base class"""

    def next_chunk(self, chunk):
        """Run the recognition process on one chunk of audio data

        The first `chunksize` little-endian signed 16-bit samples of the chunk are decoded and every callback
        registered for partial recognitions is called with the decoded string.

        :param chunk: buffer holding at least `chunksize` 16-bit samples
        :return: the input chunk, unchanged
        :raises RuntimeError: if called before `start` or if the decoder reports a failure
        """
        if self._decoder is None:
            # Fail with a clear message instead of an AttributeError on None
            raise RuntimeError("Decoder is not initialized, call start() before next_chunk()")

        try:
            data = np.array(struct.unpack_from('<%dh' % self.chunksize, chunk), dtype=np.float32)
        except Exception as e:  # pylint: disable=invalid-name, broad-except
            logger.error("Other exception happened: %s", e)
            raise

        # Read the flag once so the value handed to the decoder and the log message below agree
        finalize = self._finalize.is_set()

        if not self._decoder.decode(self.rate, data, finalize):
            raise RuntimeError("Decoding failed")

        if finalize:
            logger.info("Finalized decoding with latest data chunk")

        self._decoded_string, self._likelihood = self._decoder.get_decoded_string()
        if self._debug:
            chunk_volume_level = volume_indicator(data)
            logger.info("Chunk volume level: %s", chunk_volume_level)
            logger.info("Partially decoded (%s): %s", self._likelihood, self._decoded_string)

        for callback in self._string_partially_recognized_callbacks:
            callback(self._decoded_string)

        return chunk

    def stop(self):
        """Stop ASR process

        Logs the final result and calls every callback registered for full recognitions with the last
        decoded string.
        """
        logger.info("Stop ASR")

        logger.info("Decoding of input stream is complete")
        logger.info("Final result (%s): %s", self._likelihood, self._decoded_string)

        for callback in self._string_fully_recognized_callbacks:
            callback(self._decoded_string)

    def start(self):
        """Begin ASR process

        Clears the finalize flag, creates a new model and decoder for `model_type` and resets the decoded
        string and likelihood.

        :raises KeyError: if `model_type` is not a key of ONLINE_MODELS
        """
        logger.info("Starting speech recognition")
        # Reset internal states at the start of a new call

        self._finalize.clear()

        logger.info("Trying to initialize %s model from %s", self.model_type, self.model_dir)
        self._model = ONLINE_MODELS[self.model_type](self.model_dir)
        logger.info("Successfully initialized %s model from %s", self.model_type, self.model_dir)

        logger.info("Trying to initialize %s model decoder", self.model_type)
        self._decoder = ONLINE_DECODERS[self.model_type](self._model)
        logger.info("Successfully initialized %s model decoder", self.model_type)

        self._decoded_string = ""
        self._likelihood = None

    def register_callback(self, callback, partial=False):
        """
        Register a callback to receive the decoded string, either partial or complete.

        :param callback: a function taking a single string as its parameter
        :param partial: (default False) flag to set callback for partial recognitions
        :return: None
        """
        if partial:
            self._string_partially_recognized_callbacks.append(callback)
        else:
            self._string_fully_recognized_callbacks.append(callback)
Oops, something went wrong.