diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index f63c2de66..1273494d2 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -293,3 +293,61 @@ git clone https://github.com/pkufool/sherpa-test-data /tmp/sherpa-test-data python3 sherpa-onnx/python/tests/test_text2token.py --verbose rm -rf /tmp/sherpa-test-data + +mkdir -p /tmp/onnx-models +dir=/tmp/onnx-models + +log "Test keyword spotting models" + +python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)" +sherpa_onnx_version=$(python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)") + +echo "sherpa_onnx version: $sherpa_onnx_version" + +pwd +ls -lh + +repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01 +log "Start testing ${repo}" + +pushd $dir +wget -qq https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz +tar xf sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz +popd + +repo=$dir/$repo +ls -lh $repo + +python3 ./python-api-examples/keyword-spotter.py \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \ + --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \ + --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \ + --keywords-file=$repo/test_wavs/test_keywords.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav + +repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01 +log "Start testing ${repo}" + +pushd $dir +wget -qq https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz +tar xf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz +popd + +repo=$dir/$repo +ls -lh $repo + +python3 ./python-api-examples/keyword-spotter.py \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \ + --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \ 
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \ + --keywords-file=$repo/test_wavs/test_keywords.txt \ + $repo/test_wavs/3.wav \ + $repo/test_wavs/4.wav \ + $repo/test_wavs/5.wav + +python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose + +rm -r $dir diff --git a/python-api-examples/keyword-spotter-from-microphone.py b/python-api-examples/keyword-spotter-from-microphone.py index 4b0be3159..5a0ebafe7 100755 --- a/python-api-examples/keyword-spotter-from-microphone.py +++ b/python-api-examples/keyword-spotter-from-microphone.py @@ -86,10 +86,10 @@ def get_args(): ) parser.add_argument( - "--num-tailing-blanks", + "--num-trailing-blanks", type=int, default=1, - help="""The number of tailing blanks a keyword should be followed. Setting + help="""The number of trailing blanks a keyword should be followed. Setting to a larger value (e.g. 8) when your keywords has overlapping tokens between each other. """, @@ -110,7 +110,7 @@ def get_args(): parser.add_argument( "--keywords-score", type=float, - default=1.5, + default=1.0, help=""" The boosting score of each token for keywords. The larger the easier to survive beam search. @@ -120,7 +120,7 @@ def get_args(): parser.add_argument( "--keywords-threshold", type=float, - default=0.35, + default=0.25, help=""" The trigger threshold (i.e. probability) of the keyword. The larger the harder to trigger. 
@@ -163,7 +163,7 @@ def main(): keywords_file=args.keywords_file, keywords_score=args.keywords_score, keywords_threshold=args.keywords_threshold, - num_tailing_blanks=args.num_tailing_blanks, + num_trailing_blanks=args.num_trailing_blanks, provider=args.provider, ) @@ -171,7 +171,6 @@ def main(): sample_rate = 16000 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms - last_result = "" stream = keyword_spotter.create_stream() with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: while True: @@ -181,8 +180,7 @@ def main(): while keyword_spotter.is_ready(stream): keyword_spotter.decode_stream(stream) result = keyword_spotter.get_result(stream) - if last_result != result: - last_result = result + if result: print("\r{}".format(result), end="", flush=True) diff --git a/python-api-examples/keyword-spotter.py b/python-api-examples/keyword-spotter.py index 64debbddb..643718f2d 100755 --- a/python-api-examples/keyword-spotter.py +++ b/python-api-examples/keyword-spotter.py @@ -71,10 +71,10 @@ def get_args(): ) parser.add_argument( - "--num-tailing-blanks", + "--num-trailing-blanks", type=int, default=1, - help="""The number of tailing blanks a keyword should be followed. Setting + help="""The number of trailing blanks a keyword should be followed. Setting to a larger value (e.g. 8) when your keywords has overlapping tokens between each other. """, ) @@ -95,7 +95,7 @@ def get_args(): parser.add_argument( "--keywords-score", type=float, - default=1.5, + default=1.0, help=""" The boosting score of each token for keywords. The larger the easier to survive beam search. @@ -105,7 +105,7 @@ def get_args(): parser.add_argument( "--keywords-threshold", type=float, - default=0.35, + default=0.25, help=""" The trigger threshold (i.e. probability) of the keyword. The larger the harder to trigger. 
@@ -182,7 +182,7 @@ def main(): keywords_file=args.keywords_file, keywords_score=args.keywords_score, keywords_threshold=args.keywords_threshold, - num_tailing_blanks=args.num_tailing_blanks, + num_trailing_blanks=args.num_trailing_blanks, provider=args.provider, ) @@ -208,15 +208,19 @@ def main(): streams.append(s) + results = [""] * len(streams) while True: ready_list = [] - for s in streams: + for i, s in enumerate(streams): if keyword_spotter.is_ready(s): ready_list.append(s) + r = keyword_spotter.get_result(s) + if r: + results[i] += f"{r}/" + print(f"{r} is detected.") if len(ready_list) == 0: break keyword_spotter.decode_streams(ready_list) - results = [keyword_spotter.get_result(s) for s in streams] end_time = time.time() print("Done!") diff --git a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py index 8373dd091..218628ea9 100644 --- a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py +++ b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py @@ -37,9 +37,9 @@ def __init__( sample_rate: float = 16000, feature_dim: int = 80, max_active_paths: int = 4, - keywords_score: float = 1.5, - keywords_threshold: float = 0.35, - num_tailing_blanks: int = 1, + keywords_score: float = 1.0, + keywords_threshold: float = 0.25, + num_trailing_blanks: int = 1, provider: str = "cpu", ): """ @@ -79,8 +79,8 @@ def __init__( keywords_threshold: The trigger threshold (i.e. probability) of the keyword. The larger the harder to trigger. - num_tailing_blanks: - The number of tailing blanks a keyword should be followed. Setting + num_trailing_blanks: + The number of trailing blanks a keyword should be followed. Setting to a larger value (e.g. 8) when your keywords has overlapping tokens between each other. 
provider: @@ -115,7 +115,7 @@ def __init__( feat_config=feat_config, model_config=model_config, max_active_paths=max_active_paths, - num_tailing_blanks=num_tailing_blanks, + num_trailing_blanks=num_trailing_blanks, keywords_score=keywords_score, keywords_threshold=keywords_threshold, keywords_file=keywords_file, diff --git a/sherpa-onnx/python/tests/CMakeLists.txt b/sherpa-onnx/python/tests/CMakeLists.txt index e99636e2b..c82edc612 100644 --- a/sherpa-onnx/python/tests/CMakeLists.txt +++ b/sherpa-onnx/python/tests/CMakeLists.txt @@ -20,6 +20,7 @@ endfunction() # please sort the files in alphabetic order set(py_test_files test_feature_extractor_config.py + test_keyword_spotter.py test_offline_recognizer.py test_online_recognizer.py test_online_transducer_model_config.py diff --git a/sherpa-onnx/python/tests/test_keyword_spotter.py b/sherpa-onnx/python/tests/test_keyword_spotter.py new file mode 100755 index 000000000..bdefa5d10 --- /dev/null +++ b/sherpa-onnx/python/tests/test_keyword_spotter.py @@ -0,0 +1,170 @@ +# sherpa-onnx/python/tests/test_keyword_spotter.py +# +# Copyright (c) 2024 Xiaomi Corporation +# +# To run this single test, use +# +# ctest --verbose -R test_keyword_spotter_py + +import unittest +import wave +from pathlib import Path +from typing import Tuple + +import numpy as np +import sherpa_onnx + +d = "/tmp/onnx-models" +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html +# to download pre-trained models for testing + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. 
+ - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +class TestKeywordSpotter(unittest.TestCase): + def test_zipformer_transducer_en(self): + for use_int8 in [True, False]: + if use_int8: + encoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx" + decoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx" + joiner = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx" + else: + encoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx" + decoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx" + joiner = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx" + + tokens = ( + f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/tokens.txt" + ) + keywords_file = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt" + wave0 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/0.wav" + wave1 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/1.wav" + + if not Path(encoder).is_file(): + print("skipping test_zipformer_transducer_en()") + return + keyword_spotter = sherpa_onnx.KeywordSpotter( + encoder=encoder, + decoder=decoder, + joiner=joiner, + tokens=tokens, + num_threads=1, + keywords_file=keywords_file, + provider="cpu", + ) + 
streams = [] + waves = [wave0, wave1] + for wave in waves: + s = keyword_spotter.create_stream() + samples, sample_rate = read_wave(wave) + s.accept_waveform(sample_rate, samples) + + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + streams.append(s) + + results = [""] * len(streams) + while True: + ready_list = [] + for i, s in enumerate(streams): + if keyword_spotter.is_ready(s): + ready_list.append(s) + r = keyword_spotter.get_result(s) + if r: + print(f"{r} is detected.") + results[i] += f"{r}/" + if len(ready_list) == 0: + break + keyword_spotter.decode_streams(ready_list) + for wave_filename, result in zip(waves, results): + print(f"{wave_filename}\n{result[0:-1]}") + print("-" * 10) + + def test_zipformer_transducer_cn(self): + for use_int8 in [True, False]: + if use_int8: + encoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx" + decoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx" + joiner = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx" + else: + encoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx" + decoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx" + joiner = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx" + + tokens = ( + f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt" + ) + keywords_file = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt" + wave0 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav" + wave1 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/4.wav" + 
wave2 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/5.wav" + + if not Path(encoder).is_file(): + print("skipping test_zipformer_transducer_cn()") + return + keyword_spotter = sherpa_onnx.KeywordSpotter( + encoder=encoder, + decoder=decoder, + joiner=joiner, + tokens=tokens, + num_threads=1, + keywords_file=keywords_file, + provider="cpu", + ) + streams = [] + waves = [wave0, wave1, wave2] + for wave in waves: + s = keyword_spotter.create_stream() + samples, sample_rate = read_wave(wave) + s.accept_waveform(sample_rate, samples) + + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + streams.append(s) + + results = [""] * len(streams) + while True: + ready_list = [] + for i, s in enumerate(streams): + if keyword_spotter.is_ready(s): + ready_list.append(s) + r = keyword_spotter.get_result(s) + if r: + print(f"{r} is detected.") + results[i] += f"{r}/" + if len(ready_list) == 0: + break + keyword_spotter.decode_streams(ready_list) + for wave_filename, result in zip(waves, results): + print(f"{wave_filename}\n{result[0:-1]}") + print("-" * 10) + + +if __name__ == "__main__": + unittest.main()