From 5b73b9a9ef6c103497837c1a5b0db6f5232816e4 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 11 Apr 2024 11:00:34 +0800
Subject: [PATCH 1/3] Add audio tagging example for Python API

---
 cmake/kaldi-native-fbank.cmake                |   1 -
 .../audio-tagging-from-a-file.py              | 121 ++++++++++++++++++
 sherpa-onnx/python/csrc/CMakeLists.txt        |   1 +
 sherpa-onnx/python/csrc/audio-tagging.cc      |  87 +++++++++++++
 sherpa-onnx/python/csrc/audio-tagging.h       |  16 +++
 .../csrc/offline-tts-vits-model-config.cc     |   2 +-
 sherpa-onnx/python/csrc/sherpa-onnx.cc        |   2 +
 .../csrc/speaker-embedding-extractor.cc       |   2 +-
 .../csrc/spoken-language-identification.cc    |   4 +-
 sherpa-onnx/python/sherpa_onnx/__init__.py    |   5 +
 10 files changed, 236 insertions(+), 5 deletions(-)
 create mode 100755 python-api-examples/audio-tagging-from-a-file.py
 create mode 100644 sherpa-onnx/python/csrc/audio-tagging.cc
 create mode 100644 sherpa-onnx/python/csrc/audio-tagging.h

diff --git a/cmake/kaldi-native-fbank.cmake b/cmake/kaldi-native-fbank.cmake
index c77ec5fc6..ce76745ed 100644
--- a/cmake/kaldi-native-fbank.cmake
+++ b/cmake/kaldi-native-fbank.cmake
@@ -3,7 +3,6 @@ function(download_kaldi_native_fbank)
 
   set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
   set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
-# set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.19.1.tar.gz")
   set(kaldi_native_fbank_HASH "SHA256=0cae8cbb9ea42916b214e088912f9e8f2f648f54756b305f93f552382f31f904")
 
   set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
diff --git a/python-api-examples/audio-tagging-from-a-file.py b/python-api-examples/audio-tagging-from-a-file.py
new file mode 100755
index 000000000..e5cb9feb7
--- /dev/null
+++ b/python-api-examples/audio-tagging-from-a-file.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python3
+
+"""
+This script shows how to use audio tagging Python APIs to tag a file.
+
+Please read the code to download the required model files and test wave file.
+"""
+
+import logging
+import time
+from pathlib import Path
+
+import numpy as np
+import sherpa_onnx
+import soundfile as sf
+
+
+def read_test_wave():
+    # Please download the model files and test wave files from
+    # https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
+    test_wave = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav"
+
+    if not Path(test_wave).is_file():
+        raise ValueError(
+            f"Please download {test_wave} from "
+            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
+        )
+
+    # See https://python-soundfile.readthedocs.io/en/0.11.0/#soundfile.read
+    data, sample_rate = sf.read(
+        test_wave,
+        always_2d=True,
+        dtype="float32",
+    )
+    data = data[:, 0]  # use only the first channel
+    samples = np.ascontiguousarray(data)
+
+    # samples is a 1-d array of dtype float32
+    # sample_rate is a scalar
+    return samples, sample_rate
+
+
+def create_audio_tagger():
+    # Please download the model files and test wave files from
+    # https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
+    model_file = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.onnx"
+    label_file = (
+        "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv"
+    )
+
+    if not Path(model_file).is_file():
+        raise ValueError(
+            f"Please download {model_file} from "
+            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
+        )
+
+    if not Path(label_file).is_file():
+        raise ValueError(
+            f"Please download {label_file} from "
+            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
+        )
+
+    config = sherpa_onnx.AudioTaggingConfig(
+        model=sherpa_onnx.AudioTaggingModelConfig(
+            zipformer=sherpa_onnx.OfflineZipformerAudioTaggingModelConfig(
+                model=model_file,
+            ),
+            num_threads=1,
+            debug=True,
+            provider="cpu",
+        ),
+        labels=label_file,
+        top_k=5,
+    )
+    if not config.validate():
+        raise ValueError(f"Please check the config: {config}")
+
+    print(config)
+
+    return sherpa_onnx.AudioTagging(config)
+
+
+def main():
+    logging.info("Create audio tagger")
+    audio_tagger = create_audio_tagger()
+
+    logging.info("Read test wave")
+    samples, sample_rate = read_test_wave()
+
+    logging.info("Computing")
+
+    start_time = time.time()
+
+    stream = audio_tagger.create_stream()
+    stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
+    result = audio_tagger.compute(stream)
+    end_time = time.time()
+
+    elapsed_seconds = end_time - start_time
+    audio_duration = len(samples) / sample_rate
+
+    real_time_factor = elapsed_seconds / audio_duration
+    logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
+    logging.info(f"Audio duration in seconds: {audio_duration:.3f}")
+    logging.info(
+        f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
+    )
+
+    s = "\n"
+    for i, e in enumerate(result):
+        s += f"{i}: {e}\n"
+
+    logging.info(s)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt
index 12409a9be..266b7c312 100644
--- a/sherpa-onnx/python/csrc/CMakeLists.txt
+++ b/sherpa-onnx/python/csrc/CMakeLists.txt
@@ -1,6 +1,7 @@
 include_directories(${CMAKE_SOURCE_DIR})
 
 set(srcs
+  audio-tagging.cc
   circular-buffer.cc
   display.cc
   endpoint.cc
diff --git a/sherpa-onnx/python/csrc/audio-tagging.cc b/sherpa-onnx/python/csrc/audio-tagging.cc
new file mode 100644
index 000000000..170bbc6c2
--- /dev/null
+++ b/sherpa-onnx/python/csrc/audio-tagging.cc
@@ -0,0 +1,87 @@
+// sherpa-onnx/python/csrc/audio-tagging.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/python/csrc/audio-tagging.h"
+
+#include <string>
+
+#include "sherpa-onnx/csrc/audio-tagging.h"
+
+namespace sherpa_onnx {
+
+static void PybindOfflineZipformerAudioTaggingModelConfig(py::module *m) {
+  using PyClass = OfflineZipformerAudioTaggingModelConfig;
+  py::class_<PyClass>(*m, "OfflineZipformerAudioTaggingModelConfig")
+      .def(py::init<>())
+      .def(py::init<const std::string &>(), py::arg("model"))
+      .def_readwrite("model", &PyClass::model)
+      .def("validate", &PyClass::Validate)
+      .def("__str__", &PyClass::ToString);
+}
+
+static void PybindAudioTaggingModelConfig(py::module *m) {
+  PybindOfflineZipformerAudioTaggingModelConfig(m);
+
+  using PyClass = AudioTaggingModelConfig;
+
+  py::class_<PyClass>(*m, "AudioTaggingModelConfig")
+      .def(py::init<>())
+      .def(py::init<const OfflineZipformerAudioTaggingModelConfig &, int32_t,
+                    bool, const std::string &>(),
+           py::arg("zipformer"), py::arg("num_threads") = 1,
+           py::arg("debug") = false, py::arg("provider") = "cpu")
+      .def_readwrite("zipformer", &PyClass::zipformer)
+      .def_readwrite("num_threads", &PyClass::num_threads)
+      .def_readwrite("debug", &PyClass::debug)
+      .def_readwrite("provider", &PyClass::provider)
+      .def("validate", &PyClass::Validate)
+      .def("__str__", &PyClass::ToString);
+}
+
+static void PybindAudioTaggingConfig(py::module *m) {
+  PybindAudioTaggingModelConfig(m);
+
+  using PyClass = AudioTaggingConfig;
+
+  py::class_<PyClass>(*m, "AudioTaggingConfig")
+      .def(py::init<>())
+      .def(py::init<const AudioTaggingModelConfig &, const std::string &,
+                    int32_t>(),
+           py::arg("model"), py::arg("labels"), py::arg("top_k") = 5)
+      .def_readwrite("model", &PyClass::model)
+      .def_readwrite("labels", &PyClass::labels)
+      .def_readwrite("top_k", &PyClass::top_k)
+      .def("validate", &PyClass::Validate)
+      .def("__str__", &PyClass::ToString);
+}
+
+static void PybindAudioEvent(py::module *m) {
+  using PyClass = AudioEvent;
+
+  py::class_<PyClass>(*m, "AudioEvent")
+      .def_property_readonly(
+          "name", [](const PyClass &self) -> std::string { return self.name; })
+      .def_property_readonly(
+          "index", [](const PyClass &self) -> int32_t { return self.index; })
+      .def_property_readonly(
+          "prob", [](const PyClass &self) -> float { return self.prob; })
+      .def("__str__", &PyClass::ToString);
+}
+
+void PybindAudioTagging(py::module *m) {
+  PybindAudioTaggingConfig(m);
+  PybindAudioEvent(m);
+
+  using PyClass = AudioTagging;
+
+  py::class_<PyClass>(*m, "AudioTagging")
+      .def(py::init<const AudioTaggingConfig &>(), py::arg("config"),
+           py::call_guard<py::gil_scoped_release>())
+      .def("create_stream", &PyClass::CreateStream,
+           py::call_guard<py::gil_scoped_release>())
+      .def("compute", &PyClass::Compute, py::arg("s"), py::arg("top_k") = -1,
+           py::call_guard<py::gil_scoped_release>());
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/audio-tagging.h b/sherpa-onnx/python/csrc/audio-tagging.h
new file mode 100644
index 000000000..1cf3eaefb
--- /dev/null
+++ b/sherpa-onnx/python/csrc/audio-tagging.h
@@ -0,0 +1,16 @@
+// sherpa-onnx/python/csrc/audio-tagging.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_
+#define SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindAudioTagging(py::module *m);
+
+}
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_
diff --git a/sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc b/sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc
index 6e016715d..c88c92e0b 100644
--- a/sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc
+++ b/sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc
@@ -16,7 +16,7 @@ void PybindOfflineTtsVitsModelConfig(py::module *m) {
   py::class_<PyClass>(*m, "OfflineTtsVitsModelConfig")
       .def(py::init<>())
       .def(py::init<const std::string &, const std::string &,
-                    const std::string &, const std::string, float, float,
+                    const std::string &, const std::string &, float, float,
                     float>(),
            py::arg("model"), py::arg("lexicon"), py::arg("tokens"),
            py::arg("data_dir") = "", py::arg("noise_scale") = 0.667,
diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc
index 8a5ae5cd3..31dd9bafd 100644
--- a/sherpa-onnx/python/csrc/sherpa-onnx.cc
+++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc
@@ -5,6 +5,7 @@
 #include "sherpa-onnx/python/csrc/sherpa-onnx.h"
 
 #include "sherpa-onnx/python/csrc/alsa.h"
+#include "sherpa-onnx/python/csrc/audio-tagging.h"
 #include "sherpa-onnx/python/csrc/circular-buffer.h"
 #include "sherpa-onnx/python/csrc/display.h"
 #include "sherpa-onnx/python/csrc/endpoint.h"
@@ -38,6 +39,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
   m.doc() = "pybind11 binding of sherpa-onnx";
 
   PybindWaveWriter(&m);
+  PybindAudioTagging(&m);
 
   PybindFeatures(&m);
   PybindOnlineCtcFstDecoderConfig(&m);
diff --git a/sherpa-onnx/python/csrc/speaker-embedding-extractor.cc b/sherpa-onnx/python/csrc/speaker-embedding-extractor.cc
index 2749ba3bd..e5703caa6 100644
--- a/sherpa-onnx/python/csrc/speaker-embedding-extractor.cc
+++ b/sherpa-onnx/python/csrc/speaker-embedding-extractor.cc
@@ -14,7 +14,7 @@ static void PybindSpeakerEmbeddingExtractorConfig(py::module *m) {
   using PyClass = SpeakerEmbeddingExtractorConfig;
   py::class_<PyClass>(*m, "SpeakerEmbeddingExtractorConfig")
       .def(py::init<>())
-      .def(py::init<const std::string &, int32_t, bool, const std::string>(),
+      .def(py::init<const std::string &, int32_t, bool, const std::string &>(),
            py::arg("model"), py::arg("num_threads") = 1,
            py::arg("debug") = false, py::arg("provider") = "cpu")
       .def_readwrite("model", &PyClass::model)
diff --git a/sherpa-onnx/python/csrc/spoken-language-identification.cc b/sherpa-onnx/python/csrc/spoken-language-identification.cc
index f528e5561..b49f9a9bc 100644
--- a/sherpa-onnx/python/csrc/spoken-language-identification.cc
+++ b/sherpa-onnx/python/csrc/spoken-language-identification.cc
@@ -33,7 +33,7 @@ static void PybindSpokenLanguageIdentificationConfig(py::module *m) {
   py::class_<PyClass>(*m, "SpokenLanguageIdentificationConfig")
       .def(py::init<>())
       .def(py::init<const SpokenLanguageIdentificationWhisperConfig &, int32_t,
-                    bool, const std::string>(),
+                    bool, const std::string &>(),
            py::arg("whisper"), py::arg("num_threads") = 1,
            py::arg("debug") = false, py::arg("provider") = "cpu")
       .def_readwrite("whisper", &PyClass::whisper)
@@ -53,7 +53,7 @@ void PybindSpokenLanguageIdentification(py::module *m) {
            py::arg("config"), py::call_guard<py::gil_scoped_release>())
       .def("create_stream", &PyClass::CreateStream,
            py::call_guard<py::gil_scoped_release>())
-      .def("compute", &PyClass::Compute,
+      .def("compute", &PyClass::Compute, py::arg("s"),
            py::call_guard<py::gil_scoped_release>());
 }
 
diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py
index 2282687ea..2b1607312 100644
--- a/sherpa-onnx/python/sherpa_onnx/__init__.py
+++ b/sherpa-onnx/python/sherpa_onnx/__init__.py
@@ -1,5 +1,9 @@
 from _sherpa_onnx import (
     Alsa,
+    AudioEvent,
+    AudioTagging,
+    AudioTaggingConfig,
+    AudioTaggingModelConfig,
     CircularBuffer,
     Display,
     OfflineStream,
@@ -7,6 +11,7 @@
     OfflineTtsConfig,
     OfflineTtsModelConfig,
     OfflineTtsVitsModelConfig,
+    OfflineZipformerAudioTaggingModelConfig,
     OnlineStream,
     SileroVadModelConfig,
     SpeakerEmbeddingExtractor,

From 383c4b7c1ccdd8907d82716e19f2c8e18beca652 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 11 Apr 2024 11:03:29 +0800
Subject: [PATCH 2/3] Add CI for Python audio tagging examples

---
 .github/scripts/test-python.sh | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh
index 3604a0059..aa9b795f1 100755
--- a/.github/scripts/test-python.sh
+++ b/.github/scripts/test-python.sh
@@ -8,6 +8,15 @@ log() {
   echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
 }
 
+log "test audio tagging"
+
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
+tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
+rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
+ python3 ./python-api-examples/audio-tagging-from-a-file.py
+rm -rf sherpa-onnx-zipformer-audio-tagging-2024-04-09
+
+
 log "test streaming zipformer2 ctc HLG decoding"
 
 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2

From 330cf8c3cdb06baf4dfdf460cb8f942a4e257963 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 11 Apr 2024 11:08:31 +0800
Subject: [PATCH 3/3] remove unused code

---
 build-ios.sh | 1 -
 1 file changed, 1 deletion(-)

diff --git a/build-ios.sh b/build-ios.sh
index 265f50645..ac81b4fb1 100755
--- a/build-ios.sh
+++ b/build-ios.sh
@@ -17,7 +17,6 @@ fi
 if [ ! -f $onnxruntime_dir/onnxruntime.xcframework/ios-arm64/onnxruntime.a ]; then
   mkdir -p $onnxruntime_dir
   pushd $onnxruntime_dir
-# rm -f onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
   wget -c https://${SHERPA_ONNX_GITHUB}/csukuangfj/onnxruntime-libs/releases/download/v${onnxruntime_version}/onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
   tar xvf onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
   rm onnxruntime.xcframework-${onnxruntime_version}.tar.bz2