diff --git a/.github/scripts/test-nodejs-npm.sh b/.github/scripts/test-nodejs-npm.sh index 95dcf0271..a27214383 100755 --- a/.github/scripts/test-nodejs-npm.sh +++ b/.github/scripts/test-nodejs-npm.sh @@ -58,7 +58,6 @@ rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 node ./test-online-zipformer2-ctc.js rm -rf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 - curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 @@ -70,9 +69,9 @@ rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18 curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 tar xf vits-piper-en_US-amy-low.tar.bz2 node ./test-offline-tts-en.js -rm vits-piper-en_US-amy-low* +rm -rf vits-piper-en_US-amy-low* curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-icefall-zh-aishell3.tar.bz2 tar xvf vits-icefall-zh-aishell3.tar.bz2 node ./test-offline-tts-zh.js -rm vits-icefall-zh-aishell3* +rm -rf vits-icefall-zh-aishell3* diff --git a/.github/workflows/build-wheels-aarch64.yaml b/.github/workflows/build-wheels-aarch64.yaml index 17ff53d7e..4ecc7a415 100644 --- a/.github/workflows/build-wheels-aarch64.yaml +++ b/.github/workflows/build-wheels-aarch64.yaml @@ -59,8 +59,27 @@ jobs: run: | ls -lh ./wheelhouse/ + - name: Install patchelf + if: matrix.os == 'ubuntu-latest' + shell: bash + run: | + sudo apt-get update -q + sudo apt-get install -q -y patchelf + patchelf --help + + - name: Patch wheels + shell: bash + if: matrix.os == 'ubuntu-latest' + run: | + mkdir ./wheels + sudo ./scripts/wheel/patch_wheel.py --in-dir ./wheelhouse --out-dir ./wheels + + ls -lh ./wheels/ + rm -rf ./wheelhouse + mv ./wheels ./wheelhouse + - name: Publish to huggingface - if: matrix.python-version == 'cp38' && matrix.manylinux == 'manylinux2014' + if: (matrix.python-version == 'cp38' || matrix.python-version == 'cp39' ) && matrix.manylinux == 'manylinux2014' env: HF_TOKEN: ${{ secrets.HF_TOKEN }} uses: nick-fields/retry@v3 diff --git a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt index 9f8e6325f..369aaa8c5 100644 --- a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt +++ b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt @@ -186,7 +186,7 @@ class MainActivity : AppCompatActivity() { // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-icefall-zh-aishell3.tar.bz2 // modelDir = "vits-icefall-zh-aishell3" // modelName = "model.onnx" - // ruleFsts = "vits-icefall-zh-aishell3/phone.fst,vits-icefall-zh-aishell3/date.fst,vits-icefall-zh-aishell3/number.fst," + // ruleFsts = "vits-icefall-zh-aishell3/phone.fst,vits-icefall-zh-aishell3/date.fst,vits-icefall-zh-aishell3/number.fst,vits-icefall-zh-aishell3/new_heteronym.fst" // ruleFars = "vits-icefall-zh-aishell3/rule.far" // lexicon = "lexicon.txt" diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py index ea52bdd64..75b09a5c5 100644 --- a/cmake/cmake_extension.py +++ b/cmake/cmake_extension.py @@ -67,6 +67,7 @@ def get_binaries(): "sherpa-onnx-alsa-offline", "sherpa-onnx-alsa-offline-speaker-identification", "sherpa-onnx-offline-tts-play-alsa", + "sherpa-onnx-vad-alsa", ] if is_windows(): diff --git a/cmake/openfst.cmake b/cmake/openfst.cmake index 575ea8aed..cb0826a98 100644 --- a/cmake/openfst.cmake +++ b/cmake/openfst.cmake @@ -75,6 +75,10 @@ function(download_openfst) set_target_properties(fst PROPERTIES OUTPUT_NAME "sherpa-onnx-fst") set_target_properties(fstfar PROPERTIES OUTPUT_NAME "sherpa-onnx-fstfar") + if(LINUX) + target_compile_options(fst PUBLIC -Wno-missing-template-keyword) + endif() + target_include_directories(fst PUBLIC ${openfst_SOURCE_DIR}/src/include diff --git a/python-api-examples/vad-alsa.py b/python-api-examples/vad-alsa.py new file mode 100755 index 000000000..8f23d477e --- /dev/null +++ b/python-api-examples/vad-alsa.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 + +""" +This script works only on Linux. It uses ALSA for recording. +""" + +import argparse +from pathlib import Path + +import sherpa_onnx + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--silero-vad-model", + type=str, + required=True, + help="Path to silero_vad.onnx", + ) + + parser.add_argument( + "--device-name", + type=str, + required=True, + help=""" +The device name specifies which microphone to use in case there are several +on your system. You can use + + arecord -l + +to find all available microphones on your computer. For instance, if it outputs + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and the device 0 on that card, please use: + + plughw:3,0 + +as the device_name. + """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + if not Path(args.silero_vad_model).is_file(): + raise RuntimeError( + f"{args.silero_vad_model} does not exist. Please download it from " + "https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx" + ) + + device_name = args.device_name + print(f"device_name: {device_name}") + alsa = sherpa_onnx.Alsa(device_name) + + sample_rate = 16000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + + config = sherpa_onnx.VadModelConfig() + config.silero_vad.model = args.silero_vad_model + config.sample_rate = sample_rate + + vad = sherpa_onnx.VoiceActivityDetector(config, buffer_size_in_seconds=30) + + print("Started! Please speak. Press Ctrl C to exit") + + printed = False + k = 0 + try: + while True: + samples = alsa.read(samples_per_read) # a blocking read + + vad.accept_waveform(samples) + + if vad.is_speech_detected() and not printed: + print("Detected speech") + printed = True + + if not vad.is_speech_detected(): + printed = False + + while not vad.empty(): + samples = vad.front.samples + duration = len(samples) / sample_rate + filename = f"seg-{k}-{duration:.3f}-seconds.wav" + k += 1 + sherpa_onnx.write_wave(filename, samples, sample_rate) + print(f"Duration: {duration:.3f} seconds") + print(f"Saved to {filename}") + print("----------") + + vad.pop() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exit") + + +if __name__ == "__main__": + main() diff --git a/python-api-examples/vad-microphone.py b/python-api-examples/vad-microphone.py new file mode 100755 index 000000000..85cde0830 --- /dev/null +++ b/python-api-examples/vad-microphone.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 + +import argparse +import os +import sys +from pathlib import Path + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + +import sherpa_onnx + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--silero-vad-model", + type=str, + required=True, + help="Path to silero_vad.onnx", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + if not Path(args.silero_vad_model).is_file(): + raise RuntimeError( + f"{args.silero_vad_model} does not exist. Please download it from " + "https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx" + ) + + mic_sample_rate = 16000 + if "SHERPA_ONNX_MIC_SAMPLE_RATE" in os.environ: + mic_sample_rate = int(os.environ.get("SHERPA_ONNX_MIC_SAMPLE_RATE")) + print(f"Change microphone sample rate to {mic_sample_rate}") + + sample_rate = 16000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + + config = sherpa_onnx.VadModelConfig() + config.silero_vad.model = args.silero_vad_model + config.sample_rate = sample_rate + + vad = sherpa_onnx.VoiceActivityDetector(config, buffer_size_in_seconds=30) + + # python3 -m sounddevice + # can also be used to list all devices + + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + print( + "If you are using Linux and you are sure there is a microphone " + "on your system, please use " + "./vad-alsa.py" + ) + sys.exit(0) + + print(devices) + + if "SHERPA_ONNX_MIC_DEVICE" in os.environ: + input_device_idx = int(os.environ.get("SHERPA_ONNX_MIC_DEVICE")) + sd.default.device[0] = input_device_idx + print(f'Use selected device: {devices[input_device_idx]["name"]}') + else: + input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[input_device_idx]["name"]}') + + print("Started! Please speak. Press Ctrl C to exit") + + printed = False + k = 0 + try: + with sd.InputStream( + channels=1, dtype="float32", samplerate=mic_sample_rate + ) as s: + while True: + samples, _ = s.read(samples_per_read) # a blocking read + samples = samples.reshape(-1) + + if mic_sample_rate != sample_rate: + import librosa + + samples = librosa.resample( + samples, orig_sr=mic_sample_rate, target_sr=sample_rate + ) + + vad.accept_waveform(samples) + + if vad.is_speech_detected() and not printed: + print("Detected speech") + printed = True + + if not vad.is_speech_detected(): + printed = False + + while not vad.empty(): + samples = vad.front.samples + duration = len(samples) / sample_rate + filename = f"seg-{k}-{duration:.3f}-seconds.wav" + k += 1 + sherpa_onnx.write_wave(filename, samples, sample_rate) + print(f"Duration: {duration:.3f} seconds") + print(f"Saved to {filename}") + print("----------") + + vad.pop() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exit") + + +if __name__ == "__main__": + main() diff --git a/python-api-examples/vad-remove-non-speech-segments-alsa.py b/python-api-examples/vad-remove-non-speech-segments-alsa.py new file mode 100755 index 000000000..34f88e40f --- /dev/null +++ b/python-api-examples/vad-remove-non-speech-segments-alsa.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 + +""" +This file shows how to remove non-speech segments +and merge all speech segments into a large segment +and save it to a file. + +Different from ./vad-remove-non-speech-segments.py, this file supports only +Linux. + +Usage + +python3 ./vad-remove-non-speech-segments-alsa.py \ + --silero-vad-model silero_vad.onnx + +Please visit +https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx +to download silero_vad.onnx + +For instance, + +wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx +""" + +import argparse +import time +from pathlib import Path + +import numpy as np +import sherpa_onnx +import soundfile as sf + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--silero-vad-model", + type=str, + required=True, + help="Path to silero_vad.onnx", + ) + + parser.add_argument( + "--device-name", + type=str, + required=True, + help=""" +The device name specifies which microphone to use in case there are several +on your system. You can use + + arecord -l + +to find all available microphones on your computer. For instance, if it outputs + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and the device 0 on that card, please use: + + plughw:3,0 + +as the device_name. + """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + assert_file_exists(args.silero_vad_model) + + device_name = args.device_name + print(f"device_name: {device_name}") + alsa = sherpa_onnx.Alsa(device_name) + + sample_rate = 16000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + + config = sherpa_onnx.VadModelConfig() + config.silero_vad.model = args.silero_vad_model + config.sample_rate = sample_rate + + window_size = config.silero_vad.window_size + + buffer = [] + vad = sherpa_onnx.VoiceActivityDetector(config, buffer_size_in_seconds=30) + + all_samples = [] + + print("Started! Please speak. Press Ctrl C to exit") + + try: + while True: + samples = alsa.read(samples_per_read) # a blocking read + samples = np.array(samples) + + buffer = np.concatenate([buffer, samples]) + + all_samples = np.concatenate([all_samples, samples]) + + while len(buffer) > window_size: + vad.accept_waveform(buffer[:window_size]) + buffer = buffer[window_size:] + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Saving & Exiting") + + speech_samples = [] + while not vad.empty(): + speech_samples.extend(vad.front.samples) + vad.pop() + + speech_samples = np.array(speech_samples, dtype=np.float32) + + filename_for_speech = time.strftime("%Y%m%d-%H%M%S-speech.wav") + sf.write(filename_for_speech, speech_samples, samplerate=sample_rate) + + filename_for_all = time.strftime("%Y%m%d-%H%M%S-all.wav") + sf.write(filename_for_all, all_samples, samplerate=sample_rate) + + print(f"Saved to {filename_for_speech} and {filename_for_all}") + + +if __name__ == "__main__": + main() diff --git a/python-api-examples/vad-remove-non-speech-segments.py b/python-api-examples/vad-remove-non-speech-segments.py index e55d88b07..e242801aa 100755 --- a/python-api-examples/vad-remove-non-speech-segments.py +++ b/python-api-examples/vad-remove-non-speech-segments.py @@ -66,6 +66,11 @@ def main(): devices = sd.query_devices() if len(devices) == 0: print("No microphone devices found") + print( + "If you are using Linux and you are sure there is a microphone " + "on your system, please use " + "./vad-remove-non-speech-segments-alsa.py" + ) sys.exit(0) print(devices) @@ -89,7 +94,7 @@ def main(): all_samples = [] - print("Started! Please speak") + print("Started! Please speak. Press Ctrl C to exit") try: with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 423a777f7..bedd1ed2a 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -251,6 +251,7 @@ if(SHERPA_ONNX_HAS_ALSA AND SHERPA_ONNX_ENABLE_BINARY) add_executable(sherpa-onnx-keyword-spotter-alsa sherpa-onnx-keyword-spotter-alsa.cc alsa.cc) add_executable(sherpa-onnx-alsa-offline sherpa-onnx-alsa-offline.cc alsa.cc) add_executable(sherpa-onnx-alsa-offline-speaker-identification sherpa-onnx-alsa-offline-speaker-identification.cc alsa.cc) + add_executable(sherpa-onnx-vad-alsa sherpa-onnx-vad-alsa.cc alsa.cc) if(SHERPA_ONNX_ENABLE_TTS) @@ -259,9 +260,10 @@ if(SHERPA_ONNX_HAS_ALSA AND SHERPA_ONNX_ENABLE_BINARY) set(exes sherpa-onnx-alsa - sherpa-onnx-keyword-spotter-alsa sherpa-onnx-alsa-offline sherpa-onnx-alsa-offline-speaker-identification + sherpa-onnx-keyword-spotter-alsa + sherpa-onnx-vad-alsa ) if(SHERPA_ONNX_ENABLE_TTS) diff --git a/sherpa-onnx/csrc/sherpa-onnx-vad-alsa.cc b/sherpa-onnx/csrc/sherpa-onnx-vad-alsa.cc new file mode 100644 index 000000000..31a3f39b0 --- /dev/null +++ b/sherpa-onnx/csrc/sherpa-onnx-vad-alsa.cc @@ -0,0 +1,132 @@ +// sherpa-onnx/csrc/sherpa-onnx-vad-alsa.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include +#include +#include + +#include + +#include "sherpa-onnx/csrc/alsa.h" +#include "sherpa-onnx/csrc/circular-buffer.h" +#include "sherpa-onnx/csrc/voice-activity-detector.h" +#include "sherpa-onnx/csrc/wave-writer.h" + +bool stop = false; +static void Handler(int32_t sig) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n"); +} + +int32_t main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +This program shows how to use VAD in sherpa-onnx. + + ./bin/sherpa-onnx-vad-alsa \ + --silero-vad-model=/path/to/silero_vad.onnx \ + device_name + +Please download silero_vad.onnx from +https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx + +For instance, use +wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx + +The device name specifies which microphone to use in case there are several +on your system. You can use + + arecord -l + +to find all available microphones on your computer. For instance, if it outputs + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and the device 0 on that card, please use: + + plughw:3,0 + +as the device_name. +)usage"; + + sherpa_onnx::ParseOptions po(kUsageMessage); + sherpa_onnx::VadModelConfig config; + + config.Register(&po); + po.Read(argc, argv); + if (po.NumArgs() != 1) { + fprintf(stderr, "Please provide only 1 argument: the device name\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + std::string device_name = po.GetArg(1); + sherpa_onnx::Alsa alsa(device_name.c_str()); + fprintf(stderr, "Use recording device: %s\n", device_name.c_str()); + + int32_t sample_rate = 16000; + + if (alsa.GetExpectedSampleRate() != sample_rate) { + fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(), + sample_rate); + exit(-1); + } + + int32_t chunk = 0.1 * alsa.GetActualSampleRate(); + + auto vad = std::make_unique(config); + + fprintf(stderr, "Started. Please speak\n"); + + int32_t window_size = config.silero_vad.window_size; + bool printed = false; + + int32_t k = 0; + while (!stop) { + { + const std::vector &samples = alsa.Read(chunk); + + vad->AcceptWaveform(samples.data(), samples.size()); + + if (vad->IsSpeechDetected() && !printed) { + printed = true; + fprintf(stderr, "\nDetected speech!\n"); + } + if (!vad->IsSpeechDetected()) { + printed = false; + } + + while (!vad->Empty()) { + const auto &segment = vad->Front(); + float duration = + segment.samples.size() / static_cast(sample_rate); + + fprintf(stderr, "Duration: %.3f seconds\n", duration); + + char filename[128]; + snprintf(filename, sizeof(filename), "seg-%d-%.3fs.wav", k, duration); + k += 1; + sherpa_onnx::WriteWave(filename, 16000, segment.samples.data(), + segment.samples.size()); + fprintf(stderr, "Saved to %s\n", filename); + fprintf(stderr, "----------\n"); + + vad->Pop(); + } + } + } + + return 0; +} diff --git a/sherpa-onnx/csrc/sherpa-onnx-vad-microphone.cc b/sherpa-onnx/csrc/sherpa-onnx-vad-microphone.cc index 19dd1d85f..da013b9e8 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-vad-microphone.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-vad-microphone.cc @@ -13,6 +13,7 @@ #include "sherpa-onnx/csrc/circular-buffer.h" #include "sherpa-onnx/csrc/microphone.h" #include "sherpa-onnx/csrc/voice-activity-detector.h" +#include "sherpa-onnx/csrc/wave-writer.h" bool stop = false; std::mutex mutex; @@ -122,6 +123,7 @@ wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx int32_t window_size = config.silero_vad.window_size; bool printed = false; + int32_t k = 0; while (!stop) { { std::lock_guard lock(mutex); @@ -140,9 +142,19 @@ wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx } while (!vad->Empty()) { - float duration = vad->Front().samples.size() / sample_rate; - vad->Pop(); + const auto &segment = vad->Front(); + float duration = segment.samples.size() / sample_rate; fprintf(stderr, "Duration: %.3f seconds\n", duration); + + char filename[128]; + snprintf(filename, sizeof(filename), "seg-%d-%.3fs.wav", k, duration); + k += 1; + sherpa_onnx::WriteWave(filename, 16000, segment.samples.data(), + segment.samples.size()); + fprintf(stderr, "Saved to %s\n", filename); + fprintf(stderr, "----------\n"); + + vad->Pop(); } } } diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 53aebd78c..12409a9be 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -35,6 +35,7 @@ set(srcs vad-model-config.cc vad-model.cc voice-activity-detector.cc + wave-writer.cc ) if(SHERPA_ONNX_HAS_ALSA) list(APPEND srcs ${CMAKE_SOURCE_DIR}/sherpa-onnx/csrc/alsa.cc alsa.cc) diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 4952e150b..8a5ae5cd3 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -26,6 +26,7 @@ #include "sherpa-onnx/python/csrc/vad-model-config.h" #include "sherpa-onnx/python/csrc/vad-model.h" #include "sherpa-onnx/python/csrc/voice-activity-detector.h" +#include "sherpa-onnx/python/csrc/wave-writer.h" #if SHERPA_ONNX_ENABLE_TTS == 1 #include "sherpa-onnx/python/csrc/offline-tts.h" @@ -36,6 +37,8 @@ namespace sherpa_onnx { PYBIND11_MODULE(_sherpa_onnx, m) { m.doc() = "pybind11 binding of sherpa-onnx"; + PybindWaveWriter(&m); + PybindFeatures(&m); PybindOnlineCtcFstDecoderConfig(&m); PybindOnlineModelConfig(&m); diff --git a/sherpa-onnx/python/csrc/wave-writer.cc b/sherpa-onnx/python/csrc/wave-writer.cc new file mode 100644 index 000000000..6ec4d65df --- /dev/null +++ b/sherpa-onnx/python/csrc/wave-writer.cc @@ -0,0 +1,27 @@ +// sherpa-onnx/python/csrc/wave-writer.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/wave-writer.h" + +#include +#include + +#include "sherpa-onnx/csrc/wave-writer.h" + +namespace sherpa_onnx { + +void PybindWaveWriter(py::module *m) { + m->def( + "write_wave", + [](const std::string &filename, const std::vector &samples, + int32_t sample_rate) -> bool { + bool ok = + WriteWave(filename, sample_rate, samples.data(), samples.size()); + + return ok; + }, + py::arg("filename"), py::arg("samples"), py::arg("sample_rate")); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/wave-writer.h b/sherpa-onnx/python/csrc/wave-writer.h new file mode 100644 index 000000000..c8ab58d5b --- /dev/null +++ b/sherpa-onnx/python/csrc/wave-writer.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/wave-writer.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_WAVE_WRITER_H_ +#define SHERPA_ONNX_PYTHON_CSRC_WAVE_WRITER_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindWaveWriter(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_WAVE_WRITER_H_ diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index 1f98bef69..2282687ea 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -19,6 +19,7 @@ VadModel, VadModelConfig, VoiceActivityDetector, + write_wave, ) from .keyword_spotter import KeywordSpotter