Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Python API and Python examples for audio tagging #753

Merged
merged 3 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/scripts/test-python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "test audio tagging"

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
python3 ./python-api-examples/audio-tagging-from-a-file.py
rm -rf sherpa-onnx-zipformer-audio-tagging-2024-04-09


log "test streaming zipformer2 ctc HLG decoding"

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
Expand Down
1 change: 0 additions & 1 deletion build-ios.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ fi
if [ ! -f $onnxruntime_dir/onnxruntime.xcframework/ios-arm64/onnxruntime.a ]; then
mkdir -p $onnxruntime_dir
pushd $onnxruntime_dir
# rm -f onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
wget -c https://${SHERPA_ONNX_GITHUB}/csukuangfj/onnxruntime-libs/releases/download/v${onnxruntime_version}/onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
tar xvf onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
rm onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
Expand Down
1 change: 0 additions & 1 deletion cmake/kaldi-native-fbank.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ function(download_kaldi_native_fbank)

set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
# set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.19.1.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=0cae8cbb9ea42916b214e088912f9e8f2f648f54756b305f93f552382f31f904")

set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
Expand Down
121 changes: 121 additions & 0 deletions python-api-examples/audio-tagging-from-a-file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#!/usr/bin/env python3

"""
This script shows how to use audio tagging Python APIs to tag a file.

Please read the code to download the required model files and test wave file.
"""

import logging
import time
from pathlib import Path

import numpy as np
import sherpa_onnx
import soundfile as sf


def read_test_wave():
    """Load the example wave file and return its first channel.

    The model files and test wave files can be downloaded from
    https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models

    Returns:
      A tuple ``(samples, sample_rate)``. ``samples`` is a contiguous 1-d
      array of dtype float32 holding the first channel only;
      ``sample_rate`` is the scalar returned by ``soundfile.read``.

    Raises:
      ValueError: If the test wave file is not present on disk.
    """
    wave_path = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav"

    wave_exists = Path(wave_path).is_file()
    if not wave_exists:
        raise ValueError(
            f"Please download {wave_path} from "
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
        )

    # always_2d=True yields a 2-d array even for mono input, so the column
    # selection below works for any channel count.
    # See https://python-soundfile.readthedocs.io/en/0.11.0/#soundfile.read
    audio, rate = sf.read(wave_path, dtype="float32", always_2d=True)

    # Keep channel 0 only; slicing a column is non-contiguous, hence the copy.
    first_channel = np.ascontiguousarray(audio[:, 0])

    return first_channel, rate


def create_audio_tagger():
    """Build a ``sherpa_onnx.AudioTagging`` object for the zipformer model.

    The model files and test wave files can be downloaded from
    https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models

    Returns:
      A ``sherpa_onnx.AudioTagging`` instance created from the config below.

    Raises:
      ValueError: If the model file or the label file is missing, or if
        ``config.validate()`` returns a falsy value.
    """
    model_file = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.onnx"
    label_file = (
        "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv"
    )

    # The model is checked before the labels, so a missing model is
    # reported first.
    for required_file in (model_file, label_file):
        if not Path(required_file).is_file():
            raise ValueError(
                f"Please download {required_file} from "
                "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
            )

    zipformer_config = sherpa_onnx.OfflineZipformerAudioTaggingModelConfig(
        model=model_file,
    )
    model_config = sherpa_onnx.AudioTaggingModelConfig(
        zipformer=zipformer_config,
        num_threads=1,
        debug=True,
        provider="cpu",
    )
    config = sherpa_onnx.AudioTaggingConfig(
        model=model_config,
        labels=label_file,
        top_k=5,
    )

    if not config.validate():
        raise ValueError(f"Please check the config: {config}")

    print(config)

    return sherpa_onnx.AudioTagging(config)


def main():
    """Tag the example wave file and log the timing and the detected events.

    Creates the tagger, reads the test wave, runs ``compute`` on a fresh
    stream, then logs the elapsed time, the audio duration, the real-time
    factor and one line per returned event.
    """
    logging.info("Create audio tagger")
    tagger = create_audio_tagger()

    logging.info("Read test wave")
    samples, sample_rate = read_test_wave()

    logging.info("Computing")

    # The timed region covers stream creation, feeding the samples and
    # the compute call.
    began = time.time()
    stream = tagger.create_stream()
    stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
    events = tagger.compute(stream)
    finished = time.time()

    elapsed = finished - began
    duration = len(samples) / sample_rate
    rtf = elapsed / duration

    logging.info(f"Elapsed seconds: {elapsed:.3f}")
    logging.info(f"Audio duration in seconds: {duration:.3f}")
    logging.info(f"RTF: {elapsed:.3f}/{duration:.3f} = {rtf:.3f}")

    # A leading newline, then one "<rank>: <event>" line per result.
    report = "\n" + "".join(
        f"{rank}: {event}\n" for rank, event in enumerate(events)
    )
    logging.info(report)


if __name__ == "__main__":
    # Each record shows the timestamp, level, source file:line and message.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
    )

    main()
1 change: 1 addition & 0 deletions sherpa-onnx/python/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
include_directories(${CMAKE_SOURCE_DIR})

set(srcs
audio-tagging.cc
circular-buffer.cc
display.cc
endpoint.cc
Expand Down
87 changes: 87 additions & 0 deletions sherpa-onnx/python/csrc/audio-tagging.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// sherpa-onnx/python/csrc/audio-tagging.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/python/csrc/audio-tagging.h"

#include <string>

#include "sherpa-onnx/csrc/audio-tagging.h"

namespace sherpa_onnx {

// Registers OfflineZipformerAudioTaggingModelConfig on the module `m`.
//
// The Python class gets a default constructor, a constructor taking
// `model`, a read/write `model` attribute, `validate()` and `__str__`.
static void PybindOfflineZipformerAudioTaggingModelConfig(py::module *m) {
  using Config = OfflineZipformerAudioTaggingModelConfig;

  py::class_<Config> binder(*m, "OfflineZipformerAudioTaggingModelConfig");

  // Constructors.
  binder.def(py::init<>());
  binder.def(py::init<const std::string &>(), py::arg("model"));

  // Attributes.
  binder.def_readwrite("model", &Config::model);

  // Methods.
  binder.def("validate", &Config::Validate);
  binder.def("__str__", &Config::ToString);
}

// Registers AudioTaggingModelConfig on the module `m`.
//
// The nested zipformer config class is registered first because it is the
// type of the `zipformer` constructor argument and attribute.
static void PybindAudioTaggingModelConfig(py::module *m) {
  PybindOfflineZipformerAudioTaggingModelConfig(m);

  using Config = AudioTaggingModelConfig;

  py::class_<Config> binder(*m, "AudioTaggingModelConfig");

  // Constructors: default, and (zipformer, num_threads, debug, provider)
  // where every argument after `zipformer` has a default value.
  binder.def(py::init<>());
  binder.def(py::init<const OfflineZipformerAudioTaggingModelConfig &, int32_t,
                      bool, const std::string &>(),
             py::arg("zipformer"), py::arg("num_threads") = 1,
             py::arg("debug") = false, py::arg("provider") = "cpu");

  // Read/write attributes.
  binder.def_readwrite("zipformer", &Config::zipformer);
  binder.def_readwrite("num_threads", &Config::num_threads);
  binder.def_readwrite("debug", &Config::debug);
  binder.def_readwrite("provider", &Config::provider);

  // Methods.
  binder.def("validate", &Config::Validate);
  binder.def("__str__", &Config::ToString);
}

// Registers AudioTaggingConfig on the module `m`.
//
// AudioTaggingModelConfig is registered first because it is the type of
// the `model` constructor argument and attribute.
static void PybindAudioTaggingConfig(py::module *m) {
  PybindAudioTaggingModelConfig(m);

  using Config = AudioTaggingConfig;

  py::class_<Config> binder(*m, "AudioTaggingConfig");

  // Constructors: default, and (model, labels, top_k) with top_k
  // defaulting to 5.
  binder.def(py::init<>());
  binder.def(py::init<const AudioTaggingModelConfig &, const std::string &,
                      int32_t>(),
             py::arg("model"), py::arg("labels"), py::arg("top_k") = 5);

  // Read/write attributes.
  binder.def_readwrite("model", &Config::model);
  binder.def_readwrite("labels", &Config::labels);
  binder.def_readwrite("top_k", &Config::top_k);

  // Methods.
  binder.def("validate", &Config::Validate);
  binder.def("__str__", &Config::ToString);
}

// Registers AudioEvent on the module `m`.
//
// No constructor is bound, and `name`, `index` and `prob` are exposed as
// read-only properties.
static void PybindAudioEvent(py::module *m) {
  using Event = AudioEvent;

  py::class_<Event> binder(*m, "AudioEvent");

  // Each getter returns its field by value.
  binder.def_property_readonly(
      "name", [](const Event &self) -> std::string { return self.name; });
  binder.def_property_readonly(
      "index", [](const Event &self) -> int32_t { return self.index; });
  binder.def_property_readonly(
      "prob", [](const Event &self) -> float { return self.prob; });

  binder.def("__str__", &Event::ToString);
}

// Registers all audio tagging classes on the module `m`.
//
// The config classes and AudioEvent are registered before AudioTagging
// itself. Every AudioTagging binding carries a
// py::call_guard<py::gil_scoped_release>, so the GIL is released while the
// C++ constructor / CreateStream / Compute call runs.
void PybindAudioTagging(py::module *m) {
  PybindAudioTaggingConfig(m);
  PybindAudioEvent(m);

  using Tagger = AudioTagging;

  py::class_<Tagger> binder(*m, "AudioTagging");

  binder.def(py::init<const AudioTaggingConfig &>(), py::arg("config"),
             py::call_guard<py::gil_scoped_release>());

  binder.def("create_stream", &Tagger::CreateStream,
             py::call_guard<py::gil_scoped_release>());

  // `top_k` defaults to -1 on the Python side.
  binder.def("compute", &Tagger::Compute, py::arg("s"),
             py::arg("top_k") = -1,
             py::call_guard<py::gil_scoped_release>());
}

} // namespace sherpa_onnx
16 changes: 16 additions & 0 deletions sherpa-onnx/python/csrc/audio-tagging.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/audio-tagging.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_
#define SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

// Registers the audio tagging Python bindings on the module `m`.
// Defined in sherpa-onnx/python/csrc/audio-tagging.cc.
void PybindAudioTagging(py::module *m);

}  // namespace sherpa_onnx

#endif // SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_
2 changes: 1 addition & 1 deletion sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ void PybindOfflineTtsVitsModelConfig(py::module *m) {
py::class_<PyClass>(*m, "OfflineTtsVitsModelConfig")
.def(py::init<>())
.def(py::init<const std::string &, const std::string &,
const std::string &, const std::string, float, float,
const std::string &, const std::string &, float, float,
float>(),
py::arg("model"), py::arg("lexicon"), py::arg("tokens"),
py::arg("data_dir") = "", py::arg("noise_scale") = 0.667,
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/python/csrc/sherpa-onnx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

#include "sherpa-onnx/python/csrc/alsa.h"
#include "sherpa-onnx/python/csrc/audio-tagging.h"
#include "sherpa-onnx/python/csrc/circular-buffer.h"
#include "sherpa-onnx/python/csrc/display.h"
#include "sherpa-onnx/python/csrc/endpoint.h"
Expand Down Expand Up @@ -38,6 +39,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
m.doc() = "pybind11 binding of sherpa-onnx";

PybindWaveWriter(&m);
PybindAudioTagging(&m);

PybindFeatures(&m);
PybindOnlineCtcFstDecoderConfig(&m);
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/python/csrc/speaker-embedding-extractor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ static void PybindSpeakerEmbeddingExtractorConfig(py::module *m) {
using PyClass = SpeakerEmbeddingExtractorConfig;
py::class_<PyClass>(*m, "SpeakerEmbeddingExtractorConfig")
.def(py::init<>())
.def(py::init<const std::string &, int32_t, bool, const std::string>(),
.def(py::init<const std::string &, int32_t, bool, const std::string &>(),
py::arg("model"), py::arg("num_threads") = 1,
py::arg("debug") = false, py::arg("provider") = "cpu")
.def_readwrite("model", &PyClass::model)
Expand Down
4 changes: 2 additions & 2 deletions sherpa-onnx/python/csrc/spoken-language-identification.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ static void PybindSpokenLanguageIdentificationConfig(py::module *m) {
py::class_<PyClass>(*m, "SpokenLanguageIdentificationConfig")
.def(py::init<>())
.def(py::init<const SpokenLanguageIdentificationWhisperConfig &, int32_t,
bool, const std::string>(),
bool, const std::string &>(),
py::arg("whisper"), py::arg("num_threads") = 1,
py::arg("debug") = false, py::arg("provider") = "cpu")
.def_readwrite("whisper", &PyClass::whisper)
Expand All @@ -53,7 +53,7 @@ void PybindSpokenLanguageIdentification(py::module *m) {
py::arg("config"), py::call_guard<py::gil_scoped_release>())
.def("create_stream", &PyClass::CreateStream,
py::call_guard<py::gil_scoped_release>())
.def("compute", &PyClass::Compute,
.def("compute", &PyClass::Compute, py::arg("s"),
py::call_guard<py::gil_scoped_release>());
}

Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/python/sherpa_onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from _sherpa_onnx import (
Alsa,
AudioEvent,
AudioTagging,
AudioTaggingConfig,
AudioTaggingModelConfig,
CircularBuffer,
Display,
OfflineStream,
OfflineTts,
OfflineTtsConfig,
OfflineTtsModelConfig,
OfflineTtsVitsModelConfig,
OfflineZipformerAudioTaggingModelConfig,
OnlineStream,
SileroVadModelConfig,
SpeakerEmbeddingExtractor,
Expand Down
Loading