Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Python API for keyword spotting #576

Merged
merged 8 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 205 additions & 0 deletions python-api-examples/keyword-spotter-from-microphone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
#!/usr/bin/env python3

# Real-time keyword spotting from a microphone with sherpa-onnx Python API
#
# Please refer to
# https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
# to download pre-trained models

import argparse
import sys
from pathlib import Path

from typing import List

try:
import sounddevice as sd
except ImportError:
print("Please install sounddevice first. You can use")
print()
print(" pip install sounddevice")
print()
print("to install it")
sys.exit(-1)

import sherpa_onnx


def assert_file_exists(filename: str):
assert Path(filename).is_file(), (
f"{filename} does not exist!\n"
"Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it"
)


def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument(
"--tokens",
type=str,
help="Path to tokens.txt",
)

parser.add_argument(
"--encoder",
type=str,
help="Path to the transducer encoder model",
)

parser.add_argument(
"--decoder",
type=str,
help="Path to the transducer decoder model",
)

parser.add_argument(
"--joiner",
type=str,
help="Path to the transducer joiner model",
)

parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)

parser.add_argument(
"--provider",
type=str,
default="cpu",
help="Valid values: cpu, cuda, coreml",
)

parser.add_argument(
"--max-active-paths",
type=int,
default=4,
help="""
It specifies number of active paths to keep during decoding.
""",
)

parser.add_argument(
"--num-tailing-blanks",
type=int,
default=1,
help="""The number of tailing blanks a keyword should be followed. Setting
to a larger value (e.g. 8) when your keywords has overlapping tokens
between each other.
""",
)

parser.add_argument(
"--keywords-file",
type=str,
default="",
pkufool marked this conversation as resolved.
Show resolved Hide resolved
help="""
The file containing keywords, one words/phrases per line, and for each
pkufool marked this conversation as resolved.
Show resolved Hide resolved
phrase the bpe/cjkchar/pinyin are separated by a space. For example:

▁HE LL O ▁WORLD
x iǎo ài t óng x ué
""",
)

parser.add_argument(
"--keywords-score",
type=float,
default=1.5,
help="""
The boosting score of each token for keywords. The larger the easier to
survive beam search.
""",
)

parser.add_argument(
"--keywords-threshold",
type=float,
default=0.35,
help="""
The trigger threshold (i.e. probability) of the keyword. The larger the
harder to trigger.
""",
)

parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to decode. Each file must be of WAVE"
"format with a single channel, and each sample has 16-bit, "
"i.e., int16_t. "
"The sample rate of the file can be arbitrary and does not need to "
"be 16 kHz",
)
pkufool marked this conversation as resolved.
Show resolved Hide resolved

return parser.parse_args()


def main():
args = get_args()

devices = sd.query_devices()
if len(devices) == 0:
print("No microphone devices found")
sys.exit(0)

print(devices)
default_input_device_idx = sd.default.device[0]
print(f'Use default device: {devices[default_input_device_idx]["name"]}')

assert_file_exists(args.tokens)
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
assert_file_exists(args.joiner)

assert Path(
args.keywords_file
).is_file(), (
f"keywords_file : {args.keywords_file} not exist, please provide a valid path."
)

keyword_spotter = sherpa_onnx.KeywordSpotter(
tokens=args.tokens,
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
num_threads=args.num_threads,
max_active_paths=args.max_active_paths,
keywords_file=args.keywords_file,
keywords_score=args.keywords_score,
keywords_threshold=args.keywords_threshold,
num_tailing_blanks=args.num_tailing_blanks,
provider=args.provider,
)

print("Started! Please speak")

sample_rate = 16000
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
last_result = ""
stream = keyword_spotter.create_stream()
with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
while True:
samples, _ = s.read(samples_per_read) # a blocking read
samples = samples.reshape(-1)
stream.accept_waveform(sample_rate, samples)
while keyword_spotter.is_ready(stream):
keyword_spotter.decode_stream(stream)
result = keyword_spotter.get_result(stream)
if last_result != result:
last_result = result
print("\r{}".format(result), end="", flush=True)


if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\nCaught Ctrl + C. Exiting")
Loading