Skip to content

Commit

Permalink
feat: Enhancements to ASR API (#5)
Browse files Browse the repository at this point in the history
1. Decouples all the elements of the ASR pipeline.
2. Unifies the API of each element of the ASR pipeline.
3. Introduces a new `AsrPipeline` class to construct a pipeline from the appropriate elements. It serves as the manager of the pipeline: its methods `open`, `start`, `stop` and `close` act as the controlling interfaces to the pipeline.

Each element of the ASR pipeline can still be used standalone by supplying the appropriate input through the element's `next_chunk` method. Supplying input to `Source` elements through `next_chunk` does nothing, because they generate their own data, but they still accept input to keep a unified API.
  • Loading branch information
ar13pit authored Jun 23, 2020
2 parents 519e76e + e42e8a1 commit 334a34a
Show file tree
Hide file tree
Showing 12 changed files with 523 additions and 273 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pkgconfig


VERSION = "0.2.0"
VERSION = "0.3.0"
PACKAGE = "yapykaldi"
PACKAGE_DIR = os.path.join('src', 'python')

Expand Down
120 changes: 0 additions & 120 deletions src/python/yapykaldi/asr.py

This file was deleted.

22 changes: 22 additions & 0 deletions src/python/yapykaldi/asr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
Yapykaldi ASR: Classes and functions for ASR pipeline
"""

# Public API of the ``yapykaldi.asr`` package, grouped by the submodule that
# defines each name.
__all__ = [
    # .asr
    "Asr",
    # .pipeline
    "AsrPipeline",
    # .sources
    "PyAudioMicrophoneSource",
    "WaveFileSource",
    # .sinks
    "WaveFileSink",
]

from .asr import Asr
from .pipeline import AsrPipeline
from .sources import PyAudioMicrophoneSource, WaveFileSource
from .sinks import WaveFileSink
84 changes: 84 additions & 0 deletions src/python/yapykaldi/asr/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Base classes for the ASR pipeline"""
from __future__ import print_function, division, absolute_import, unicode_literals
from builtins import *
from abc import ABC, abstractmethod
from threading import Event
import pyaudio


class AsrPipelineElementBase(ABC):
    """Common base class for every element of an ASR pipeline.

    Concrete elements must implement the three abstract methods ``open``,
    ``close`` and ``next_chunk``. ``start`` and ``stop`` are optional hooks
    that do nothing unless overridden.

    Intended life cycle of an element::

        element = SomeElement()   # a concrete subclass
        element.open()            # open the file, connect the mic, ...
        element.start()           # start streaming audio data
        element.next_chunk(data)  # use the audio data
        element.stop()            # stop getting audio data
        element.close()           # close the file / free resources

    ``open`` and ``close`` have to work at least once per element, whereas
    ``start``, ``next_chunk`` and ``stop`` must be usable several times.
    """
    # pylint: disable=too-many-instance-attributes

    def __init__(self, source=None, sink=None, rate=16000, chunksize=1024, fmt=pyaudio.paInt16, channels=1, timeout=1):
        """
        :param source: (default None) element to link as the source of this element
        :param sink: (default None) element to link as the sink of this element
        :param rate: (default 16000) stored as ``self.rate``
        :param chunksize: (default 1024) stored as ``self.chunksize``
        :param fmt: (default pyaudio.paInt16) stored as ``self.format``
        :param channels: (default 1) stored as ``self.channels``
        :param timeout: (default 1) stored as ``self.timeout``
        """
        # Stream parameters shared by all elements
        self.rate = rate
        self.chunksize = chunksize
        self.format = fmt
        self.channels = channels
        self.timeout = timeout

        # Flag raised through finalize(); how it is consumed is up to the subclass
        self._finalize = Event()

        # Neighbours start out unset so that link() is allowed to fill them in
        self._source = None
        self._sink = None
        self.link(source=source, sink=sink)

    @abstractmethod
    def open(self):
        """Open the stream of the element. Opening does not necessarily start the stream."""

    @abstractmethod
    def next_chunk(self, chunk):
        """Process a chunk, either generated by this element or received from its source element."""

    @abstractmethod
    def close(self):
        """Close the stream of the element and free all resources held by the stream."""

    def start(self):
        """Optional hook to start the stream of the element; a no-op by default."""

    def stop(self):
        """Optional hook to stop the stream of the element; a no-op by default."""

    def register_callback(self, callback):
        """Register a callback on the element from outside the pipeline.

        The base class offers no callback support.

        :raises NotImplementedError: always, unless overridden by a subclass
        """
        raise NotImplementedError()

    def link(self, source=None, sink=None):
        """Connect a source and/or a sink to this element.

        A neighbour that has already been set is never replaced. Each newly
        accepted neighbour is linked back to this element as well, so both
        ends of the connection know about each other.

        :param source: (default None) A source object
        :param sink: (default None) A sink object
        """
        if source and not self._source:
            self._source = source
            # The back-link stops recursing because self._source is now set
            source.link(sink=self)

        if sink and not self._sink:
            self._sink = sink
            # The back-link stops recursing because self._sink is now set
            sink.link(source=self)

    def finalize(self):
        """Raise the finalize flag of the element."""
        self._finalize.set()
122 changes: 122 additions & 0 deletions src/python/yapykaldi/asr/asr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
Yapykaldi ASR: Class definition for ASR component. It connects to a source and an optional sink
"""
from __future__ import (print_function, division, absolute_import, unicode_literals)
from builtins import *
import struct
import numpy as np
from ._base import AsrPipelineElementBase
from ..logger import logger
from ..nnet3 import KaldiNNet3OnlineDecoder, KaldiNNet3OnlineModel
from ..gmm import KaldiGmmOnlineDecoder, KaldiGmmOnlineModel
from ..utils import volume_indicator


# Lookup tables keyed by the ``model_type`` string: the online model class to
# instantiate and the decoder class that goes with it.
ONLINE_MODELS = {
    'nnet3': KaldiNNet3OnlineModel,
    'gmm': KaldiGmmOnlineModel,
}
ONLINE_DECODERS = {
    'nnet3': KaldiNNet3OnlineDecoder,
    'gmm': KaldiGmmOnlineDecoder,
}


class Asr(AsrPipelineElementBase):
    """ASR pipeline element that feeds audio chunks to an online Kaldi decoder.

    The model and decoder are created in ``start``; ``next_chunk`` decodes one
    chunk at a time and passes the chunk through unchanged, and ``stop``
    reports the final result to the registered callbacks.
    """
    # pylint: disable=too-many-instance-attributes, useless-object-inheritance

    def __init__(self, model_dir, model_type, rate=16000, chunksize=1024, debug=False, source=None, sink=None):
        """
        :param model_dir: Path to model directory
        :param model_type: Type of ASR model, 'nnet3' or 'gmm' (the keys of ONLINE_MODELS / ONLINE_DECODERS)
        :param rate: (default 16000) sampling frequency of audio data. This must be the same as the audio source
        :param chunksize: (default 1024) size of audio data buffer. This must be the same as the audio source
        :param debug: (default False) Flag to set logger to log audio chunk volume and partially decoded string and
            likelihood
        :param source: (default None) Element to be connected as source when constructing an AsrPipeline
        :type source: AsrPipelineElementBase
        :param sink: (default None) Element to be connected as sink when constructing an AsrPipeline
        :type sink: AsrPipelineElementBase
        """
        super().__init__(chunksize=chunksize, rate=rate, source=source, sink=sink)
        self.model_dir = model_dir
        self.model_type = model_type

        # Created lazily in start(); None means the element has not been started yet
        self._model = None
        self._decoder = None
        self._decoded_string = None
        self._likelihood = None

        self._string_partially_recognized_callbacks = []
        self._string_fully_recognized_callbacks = []

        self._debug = debug

    def open(self):
        """Nothing to open; present only to satisfy the abstract base class."""

    def close(self):
        """Nothing to close; present only to satisfy the abstract base class."""

    def next_chunk(self, chunk):
        """Decode one chunk of audio and notify the partial-recognition callbacks.

        :param chunk: buffer holding at least ``chunksize`` little-endian signed 16-bit samples
        :return: the unmodified input chunk, so that it can be forwarded to a sink
        :raises RuntimeError: if ``start`` has not been called yet or the decoder reports a failure
        :raises struct.error: if the chunk holds fewer than ``chunksize`` samples
        """
        if self._decoder is None:
            # Fail with a clear message instead of an AttributeError on None
            raise RuntimeError("Asr decoder is not initialized; call start() before next_chunk()")

        try:
            data = np.array(struct.unpack_from('<%dh' % self.chunksize, chunk), dtype=np.float32)
        except Exception as e:  # pylint: disable=invalid-name, broad-except
            # Only log here; the caller still gets the original exception
            logger.error("Other exception happened: %s", e)
            raise

        # Read the flag once so the log below matches what the decoder was told
        finalize = self._finalize.is_set()
        if not self._decoder.decode(self.rate, data, finalize):
            raise RuntimeError("Decoding failed")

        if finalize:
            logger.info("Finalized decoding with latest data chunk")

        self._decoded_string, self._likelihood = self._decoder.get_decoded_string()
        if self._debug:
            chunk_volume_level = volume_indicator(data)
            logger.info("Chunk volume level: %s", chunk_volume_level)
            logger.info("Partially decoded (%s): %s", self._likelihood, self._decoded_string)

        for callback in self._string_partially_recognized_callbacks:
            callback(self._decoded_string)

        return chunk

    def stop(self):
        """Stop ASR process: log the final result and pass it to the full-recognition callbacks."""
        logger.info("Stop ASR")

        logger.info("Decoding of input stream is complete")
        logger.info("Final result (%s): %s", self._likelihood, self._decoded_string)

        for callback in self._string_fully_recognized_callbacks:
            callback(self._decoded_string)

    def start(self):
        """Begin ASR process.

        Clears the finalize flag, builds a fresh model and decoder for
        ``model_type`` and resets the decoded string and likelihood.

        :raises KeyError: if ``model_type`` is not a key of ONLINE_MODELS / ONLINE_DECODERS
        """
        logger.info("Starting speech recognition")
        # Reset internal states at the start of a new call

        self._finalize.clear()

        logger.info("Trying to initialize %s model from %s", self.model_type, self.model_dir)
        self._model = ONLINE_MODELS[self.model_type](self.model_dir)
        logger.info("Successfully initialized %s model from %s", self.model_type, self.model_dir)

        logger.info("Trying to initialize %s model decoder", self.model_type)
        self._decoder = ONLINE_DECODERS[self.model_type](self._model)
        logger.info("Successfully initialized %s model decoder", self.model_type)

        self._decoded_string = ""
        self._likelihood = None

    def register_callback(self, callback, partial=False):
        """
        Register a callback to receive the decoded string, either partial or complete.

        :param callback: a function taking a single string as its parameter
        :param partial: (default False) if True the callback is invoked after every decoded chunk,
            otherwise it is invoked once from ``stop`` with the final result
        :return: None
        """
        if partial:
            self._string_partially_recognized_callbacks.append(callback)
        else:
            self._string_fully_recognized_callbacks.append(callback)
Loading

0 comments on commit 334a34a

Please sign in to comment.