From 4a04172e6e3c793748ca16ce3d556e81cb048300 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 29 Feb 2024 10:39:21 +0800 Subject: [PATCH] Fix comments, more fixes --- .../keyword-spotter-from-microphone.py | 12 ------- python-api-examples/keyword-spotter.py | 1 - .../python/sherpa_onnx/keyword_spotter.py | 35 +++++++++++-------- 3 files changed, 20 insertions(+), 28 deletions(-) diff --git a/python-api-examples/keyword-spotter-from-microphone.py b/python-api-examples/keyword-spotter-from-microphone.py index f1e550400..4b0be3159 100755 --- a/python-api-examples/keyword-spotter-from-microphone.py +++ b/python-api-examples/keyword-spotter-from-microphone.py @@ -98,7 +98,6 @@ def get_args(): parser.add_argument( "--keywords-file", type=str, - default="", help=""" The file containing keywords, one words/phrases per line, and for each phrase the bpe/cjkchar/pinyin are separated by a space. For example: @@ -128,17 +127,6 @@ def get_args(): """, ) - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to decode. Each file must be of WAVE" - "format with a single channel, and each sample has 16-bit, " - "i.e., int16_t. " - "The sample rate of the file can be arbitrary and does not need to " - "be 16 kHz", - ) - return parser.parse_args() diff --git a/python-api-examples/keyword-spotter.py b/python-api-examples/keyword-spotter.py index ccd9e4fe1..64debbddb 100755 --- a/python-api-examples/keyword-spotter.py +++ b/python-api-examples/keyword-spotter.py @@ -83,7 +83,6 @@ def get_args(): parser.add_argument( "--keywords-file", type=str, - default="", help=""" The file containing keywords, one words/phrases per line, and for each phrase the bpe/cjkchar/pinyin are separated by a space. For example: diff --git a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py index a176b3487..8373dd091 100644 --- a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py +++ b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py @@ -19,7 +19,7 @@ def _assert_file_exists(f: str): class KeywordSpotter(object): - """A class for streaming speech recognition. + """A class for keyword spotting. Please refer to the following files for usages - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/keyword-spotter.py @@ -32,11 +32,11 @@ def __init__( encoder: str, decoder: str, joiner: str, + keywords_file: str, num_threads: int = 2, sample_rate: float = 16000, feature_dim: int = 80, max_active_paths: int = 4, - keywords_file: str = "", keywords_score: float = 1.5, keywords_threshold: float = 0.35, num_tailing_blanks: int = 1, @@ -61,6 +61,9 @@ def __init__( Path to ``decoder.onnx``. joiner: Path to ``joiner.onnx``. + keywords_file: + The file containing keywords, one word/phrase per line, and for each + phrase the bpe/cjkchar/pinyin are separated by a space. num_threads: Number of threads for neural network computation. sample_rate: @@ -70,14 +73,16 @@ def __init__( max_active_paths: Use only when decoding_method is modified_beam_search. It specifies the maximum number of active paths during beam search. - keywords_file: - The file containing hotwords, one words/phrases per line, and for each - phrase the bpe/cjkchar are separated by a space. keywords_score: - The hotword score of each token for biasing word/phrase. Used only if - hotwords_file is given with modified_beam_search as decoding method. + The boosting score of each token for keywords. The larger the easier to + survive beam search. keywords_threshold: + The trigger threshold (i.e. probability) of the keyword. The larger the + harder to trigger. num_tailing_blanks: + The number of tailing blanks a keyword should be followed. Setting + to a larger value (e.g. 8) when your keywords has overlapping tokens + between each other. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. """ @@ -107,13 +112,13 @@ def __init__( ) keywords_spotter_config = KeywordSpotterConfig( - feat_config, - model_config, - max_active_paths, - num_tailing_blanks, - keywords_score, - keywords_threshold, - keywords_file, + feat_config=feat_config, + model_config=model_config, + max_active_paths=max_active_paths, + num_tailing_blanks=num_tailing_blanks, + keywords_score=keywords_score, + keywords_threshold=keywords_threshold, + keywords_file=keywords_file, ) self.keyword_spotter = _KeywordSpotter(keywords_spotter_config) @@ -121,7 +126,7 @@ def create_stream(self, keywords: Optional[str] = None): if keywords is None: return self.keyword_spotter.create_stream() else: - return self.keyword_spotter.create_stream(hotwords) + return self.keyword_spotter.create_stream(keywords) def decode_stream(self, s: OnlineStream): self.keyword_spotter.decode_stream(s)