Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Feb 29, 2024
1 parent 857c432 commit 248e692
Show file tree
Hide file tree
Showing 6 changed files with 252 additions and 21 deletions.
58 changes: 58 additions & 0 deletions .github/scripts/test-python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,61 @@ git clone https://github.com/pkufool/sherpa-test-data /tmp/sherpa-test-data
python3 sherpa-onnx/python/tests/test_text2token.py --verbose

rm -rf /tmp/sherpa-test-data

# Download keyword-spotting (KWS) models into a scratch directory, run the
# Python API example against them, then run the Python unit tests.
mkdir -p /tmp/onnx-models
dir=/tmp/onnx-models

log "Test keyword spotting models"

# Sanity-check that sherpa_onnx is importable and report where it comes from.
python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)"
sherpa_onnx_version=$(python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)")

echo "sherpa_onnx version: $sherpa_onnx_version"

pwd
ls -lh

# --- English (gigaspeech) KWS model ---
repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
log "Start testing ${repo}"

# Download and unpack the pre-trained model archive into $dir.
pushd $dir
wget -qq https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
tar xf sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
popd

repo=$dir/$repo
ls -lh $repo

# Run the example keyword spotter on the bundled test waves.
python3 ./python-api-examples/keyword_spotter.py \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
--keywords-file=$repo/test_wavs/test_keywords.txt \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav

# --- Chinese (wenetspeech) KWS model ---
repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
log "Start testing ${repo}"

pushd $dir
wget -qq https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz
tar xf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz
popd

repo=$dir/$repo
ls -lh $repo

python3 ./python-api-examples/keyword_spotter.py \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
--keywords-file=$repo/test_wavs/test_keywords.txt \
$repo/test_wavs/3.wav \
$repo/test_wavs/4.wav \
$repo/test_wavs/5.wav

# The unit tests look for the downloaded models under /tmp/onnx-models.
python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose

rm -r $dir
14 changes: 6 additions & 8 deletions python-api-examples/keyword-spotter-from-microphone.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ def get_args():
)

parser.add_argument(
"--num-tailing-blanks",
"--num-trailing-blanks",
type=int,
default=1,
help="""The number of tailing blanks a keyword should be followed. Setting
help="""The number of trailing blanks a keyword should be followed. Setting
to a larger value (e.g. 8) when your keywords has overlapping tokens
between each other.
""",
Expand All @@ -110,7 +110,7 @@ def get_args():
parser.add_argument(
"--keywords-score",
type=float,
default=1.5,
default=1.0,
help="""
The boosting score of each token for keywords. The larger the easier to
survive beam search.
Expand All @@ -120,7 +120,7 @@ def get_args():
parser.add_argument(
"--keywords-threshold",
type=float,
default=0.35,
default=0.25,
help="""
The trigger threshold (i.e. probability) of the keyword. The larger the
harder to trigger.
Expand Down Expand Up @@ -163,15 +163,14 @@ def main():
keywords_file=args.keywords_file,
keywords_score=args.keywords_score,
keywords_threshold=args.keywords_threshold,
num_tailing_blanks=args.num_tailing_blanks,
num_trailing_blanks=args.num_trailing_blanks,
provider=args.provider,
)

print("Started! Please speak")

sample_rate = 16000
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
last_result = ""
stream = keyword_spotter.create_stream()
with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
while True:
Expand All @@ -181,8 +180,7 @@ def main():
while keyword_spotter.is_ready(stream):
keyword_spotter.decode_stream(stream)
result = keyword_spotter.get_result(stream)
if last_result != result:
last_result = result
if result:
print("\r{}".format(result), end="", flush=True)


Expand Down
18 changes: 11 additions & 7 deletions python-api-examples/keyword-spotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ def get_args():
)

parser.add_argument(
"--num-tailing-blanks",
"--num-trailing-blanks",
type=int,
default=1,
help="""The number of tailing blanks a keyword should be followed. Setting
help="""The number of trailing blanks a keyword should be followed. Setting
to a larger value (e.g. 8) when your keywords has overlapping tokens
between each other.
""",
Expand All @@ -95,7 +95,7 @@ def get_args():
parser.add_argument(
"--keywords-score",
type=float,
default=1.5,
default=1.0,
help="""
The boosting score of each token for keywords. The larger the easier to
survive beam search.
Expand All @@ -105,7 +105,7 @@ def get_args():
parser.add_argument(
"--keywords-threshold",
type=float,
default=0.35,
default=0.25,
help="""
The trigger threshold (i.e. probability) of the keyword. The larger the
harder to trigger.
Expand Down Expand Up @@ -182,7 +182,7 @@ def main():
keywords_file=args.keywords_file,
keywords_score=args.keywords_score,
keywords_threshold=args.keywords_threshold,
num_tailing_blanks=args.num_tailing_blanks,
num_trailing_blanks=args.num_trailing_blanks,
provider=args.provider,
)

Expand All @@ -208,15 +208,19 @@ def main():

streams.append(s)

results = [""] * len(streams)
while True:
ready_list = []
for s in streams:
for i, s in enumerate(streams):
if keyword_spotter.is_ready(s):
ready_list.append(s)
r = keyword_spotter.get_result(s)
if r:
results[i] += f"{r}/"
print(f"{r} is detected.")
if len(ready_list) == 0:
break
keyword_spotter.decode_streams(ready_list)
results = [keyword_spotter.get_result(s) for s in streams]
end_time = time.time()
print("Done!")

Expand Down
12 changes: 6 additions & 6 deletions sherpa-onnx/python/sherpa_onnx/keyword_spotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def __init__(
sample_rate: float = 16000,
feature_dim: int = 80,
max_active_paths: int = 4,
keywords_score: float = 1.5,
keywords_threshold: float = 0.35,
num_tailing_blanks: int = 1,
keywords_score: float = 1.0,
keywords_threshold: float = 0.25,
num_trailing_blanks: int = 1,
provider: str = "cpu",
):
"""
Expand Down Expand Up @@ -79,8 +79,8 @@ def __init__(
keywords_threshold:
The trigger threshold (i.e. probability) of the keyword. The larger the
harder to trigger.
num_tailing_blanks:
The number of tailing blanks a keyword should be followed. Setting
num_trailing_blanks:
The number of trailing blanks a keyword should be followed. Setting
to a larger value (e.g. 8) when your keywords has overlapping tokens
between each other.
provider:
Expand Down Expand Up @@ -115,7 +115,7 @@ def __init__(
feat_config=feat_config,
model_config=model_config,
max_active_paths=max_active_paths,
num_tailing_blanks=num_tailing_blanks,
num_trailing_blanks=num_trailing_blanks,
keywords_score=keywords_score,
keywords_threshold=keywords_threshold,
keywords_file=keywords_file,
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/python/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ endfunction()
# please sort the files in alphabetic order
set(py_test_files
test_feature_extractor_config.py
test_keyword_spotter.py
test_offline_recognizer.py
test_online_recognizer.py
test_online_transducer_model_config.py
Expand Down
170 changes: 170 additions & 0 deletions sherpa-onnx/python/tests/test_keyword_spotter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# sherpa-onnx/python/tests/test_keyword_spotter.py
#
# Copyright (c) 2024 Xiaomi Corporation
#
# To run this single test, use
#
# ctest --verbose -R test_keyword_spotter_py

import unittest
import wave
from pathlib import Path
from typing import Tuple

import numpy as np
import sherpa_onnx

d = "/tmp/onnx-models"
# Please refer to
# https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
# to download pre-trained models for testing


def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Load a mono 16-bit PCM wave file.

    Args:
      wave_filename:
        Path to a wave file. It must be single channel with 16-bit
        samples; any sample rate is accepted.
    Returns:
      Return a tuple containing:
        - A 1-D array of dtype np.float32 containing the samples,
          normalized to the range [-1, 1].
        - sample rate of the wave file
    """

    with wave.open(wave_filename) as f:
        assert f.getnchannels() == 1, f.getnchannels()
        assert f.getsampwidth() == 2, f.getsampwidth()  # width is in bytes
        frame_rate = f.getframerate()
        raw = f.readframes(f.getnframes())

    pcm = np.frombuffer(raw, dtype=np.int16)
    # Scale the int16 range [-32768, 32767] into [-1, 1).
    normalized = pcm.astype(np.float32) / 32768
    return normalized, frame_rate


class TestKeywordSpotter(unittest.TestCase):
    """Tests for sherpa_onnx.KeywordSpotter.

    Model files are looked up under ``d`` (/tmp/onnx-models). Please refer
    to the sherpa docs to download them; a test silently skips when its
    model directory is not present.
    """

    def _run_keyword_spotter(self, repo, wave_names, test_name):
        """Decode the given wave files with the model in ``repo``.

        Runs twice: once with the int8-quantized model files and once with
        the float32 ones.

        Args:
          repo: Directory name of the model under ``d``.
          wave_names: Base names of the wave files in ``<repo>/test_wavs``.
          test_name: Name of the calling test, used in the skip message.
        """
        base = f"{d}/{repo}"
        for use_int8 in [True, False]:
            # Quantized model files carry an extra ".int8" infix; the
            # float32 files are plain ".onnx".
            suffix = ".int8.onnx" if use_int8 else ".onnx"
            encoder = f"{base}/encoder-epoch-12-avg-2-chunk-16-left-64{suffix}"
            decoder = f"{base}/decoder-epoch-12-avg-2-chunk-16-left-64{suffix}"
            joiner = f"{base}/joiner-epoch-12-avg-2-chunk-16-left-64{suffix}"

            tokens = f"{base}/tokens.txt"
            keywords_file = f"{base}/test_wavs/test_keywords.txt"
            waves = [f"{base}/test_wavs/{w}" for w in wave_names]

            if not Path(encoder).is_file():
                print(f"skipping {test_name}()")
                return
            keyword_spotter = sherpa_onnx.KeywordSpotter(
                encoder=encoder,
                decoder=decoder,
                joiner=joiner,
                tokens=tokens,
                num_threads=1,
                keywords_file=keywords_file,
                provider="cpu",
            )
            streams = []
            # NOTE: the loop variable is named wave_filename (not "wave")
            # so it does not shadow the imported `wave` module.
            for wave_filename in waves:
                s = keyword_spotter.create_stream()
                samples, sample_rate = read_wave(wave_filename)
                s.accept_waveform(sample_rate, samples)

                # 0.2 s of trailing silence so the final keyword can flush.
                tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
                s.accept_waveform(sample_rate, tail_paddings)
                s.input_finished()
                streams.append(s)

            results = [""] * len(streams)
            while True:
                ready_list = []
                for i, s in enumerate(streams):
                    if keyword_spotter.is_ready(s):
                        ready_list.append(s)
                    r = keyword_spotter.get_result(s)
                    if r:
                        print(f"{r} is detected.")
                        results[i] += f"{r}/"
                if len(ready_list) == 0:
                    break
                keyword_spotter.decode_streams(ready_list)
            for wave_filename, result in zip(waves, results):
                # result[0:-1] drops the trailing "/" separator.
                print(f"{wave_filename}\n{result[0:-1]}")
                print("-" * 10)

    def test_zipformer_transducer_en(self):
        """English zipformer KWS model trained on gigaspeech."""
        self._run_keyword_spotter(
            "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01",
            ["0.wav", "1.wav"],
            "test_zipformer_transducer_en",
        )

    def test_zipformer_transducer_cn(self):
        """Chinese zipformer KWS model trained on wenetspeech."""
        self._run_keyword_spotter(
            "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01",
            ["3.wav", "4.wav", "5.wav"],
            "test_zipformer_transducer_cn",
        )


if __name__ == "__main__":
    # Discover and run all tests in this file when executed directly.
    unittest.main()

0 comments on commit 248e692

Please sign in to comment.