Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Python API for keyword spotting #576

Merged
merged 8 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions .github/scripts/test-python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,61 @@ git clone https://github.com/pkufool/sherpa-test-data /tmp/sherpa-test-data
python3 sherpa-onnx/python/tests/test_text2token.py --verbose

rm -rf /tmp/sherpa-test-data

mkdir -p /tmp/onnx-models
dir=/tmp/onnx-models

log "Test keyword spotting models"

python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)"
sherpa_onnx_version=$(python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)")

echo "sherpa_onnx version: $sherpa_onnx_version"

pwd
ls -lh

repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
log "Start testing ${repo}"

pushd $dir
wget -qq https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
tar xf sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
popd

repo=$dir/$repo
ls -lh $repo

python3 ./python-api-examples/keyword-spotter.py \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
--keywords-file=$repo/test_wavs/test_keywords.txt \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav

repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
log "Start testing ${repo}"

pushd $dir
wget -qq https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz
tar xf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz
popd

repo=$dir/$repo
ls -lh $repo

python3 ./python-api-examples/keyword-spotter.py \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
--keywords-file=$repo/test_wavs/test_keywords.txt \
$repo/test_wavs/3.wav \
$repo/test_wavs/4.wav \
$repo/test_wavs/5.wav

python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose

rm -r $dir
191 changes: 191 additions & 0 deletions python-api-examples/keyword-spotter-from-microphone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
#!/usr/bin/env python3

# Real-time keyword spotting from a microphone with sherpa-onnx Python API
#
# Please refer to
# https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
# to download pre-trained models

import argparse
import sys
from pathlib import Path

from typing import List

try:
import sounddevice as sd
except ImportError:
print("Please install sounddevice first. You can use")
print()
print(" pip install sounddevice")
print()
print("to install it")
sys.exit(-1)

import sherpa_onnx


def assert_file_exists(filename: str):
assert Path(filename).is_file(), (
f"{filename} does not exist!\n"
"Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it"
)


def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument(
"--tokens",
type=str,
help="Path to tokens.txt",
)

parser.add_argument(
"--encoder",
type=str,
help="Path to the transducer encoder model",
)

parser.add_argument(
"--decoder",
type=str,
help="Path to the transducer decoder model",
)

parser.add_argument(
"--joiner",
type=str,
help="Path to the transducer joiner model",
)

parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)

parser.add_argument(
"--provider",
type=str,
default="cpu",
help="Valid values: cpu, cuda, coreml",
)

parser.add_argument(
"--max-active-paths",
type=int,
default=4,
help="""
It specifies number of active paths to keep during decoding.
""",
)

parser.add_argument(
"--num-trailing-blanks",
type=int,
default=1,
help="""The number of trailing blanks a keyword should be followed. Setting
to a larger value (e.g. 8) when your keywords has overlapping tokens
between each other.
""",
)

parser.add_argument(
"--keywords-file",
type=str,
help="""
The file containing keywords, one words/phrases per line, and for each
pkufool marked this conversation as resolved.
Show resolved Hide resolved
phrase the bpe/cjkchar/pinyin are separated by a space. For example:

▁HE LL O ▁WORLD
x iǎo ài t óng x ué
""",
)

parser.add_argument(
"--keywords-score",
type=float,
default=1.0,
help="""
The boosting score of each token for keywords. The larger the easier to
survive beam search.
""",
)

parser.add_argument(
"--keywords-threshold",
type=float,
default=0.25,
help="""
The trigger threshold (i.e. probability) of the keyword. The larger the
harder to trigger.
""",
)

return parser.parse_args()


def main():
args = get_args()

devices = sd.query_devices()
if len(devices) == 0:
print("No microphone devices found")
sys.exit(0)

print(devices)
default_input_device_idx = sd.default.device[0]
print(f'Use default device: {devices[default_input_device_idx]["name"]}')

assert_file_exists(args.tokens)
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
assert_file_exists(args.joiner)

assert Path(
args.keywords_file
).is_file(), (
f"keywords_file : {args.keywords_file} not exist, please provide a valid path."
)

keyword_spotter = sherpa_onnx.KeywordSpotter(
tokens=args.tokens,
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
num_threads=args.num_threads,
max_active_paths=args.max_active_paths,
keywords_file=args.keywords_file,
keywords_score=args.keywords_score,
keywords_threshold=args.keywords_threshold,
num_tailing_blanks=args.rnum_tailing_blanks,
provider=args.provider,
)

print("Started! Please speak")

sample_rate = 16000
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
stream = keyword_spotter.create_stream()
with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
while True:
samples, _ = s.read(samples_per_read) # a blocking read
samples = samples.reshape(-1)
stream.accept_waveform(sample_rate, samples)
while keyword_spotter.is_ready(stream):
keyword_spotter.decode_stream(stream)
result = keyword_spotter.get_result(stream)
if result:
print("\r{}".format(result), end="", flush=True)


if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\nCaught Ctrl + C. Exiting")
Loading
Loading