Skip to content

Commit

Permalink
Add Python API and Python examples for audio tagging (#753)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Apr 11, 2024
1 parent 904a3cc commit 34d70a2
Show file tree
Hide file tree
Showing 12 changed files with 245 additions and 6 deletions.
9 changes: 9 additions & 0 deletions .github/scripts/test-python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

# Smoke-test the audio tagging Python example.
log "test audio tagging"

# Fetch the model archive, unpack it into the current directory, and drop the
# archive right away since only the extracted files are needed.
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2
# The example reads the model and the test wave from the extracted directory.
python3 ./python-api-examples/audio-tagging-from-a-file.py
# Remove the extracted model directory once the example has run.
rm -rf sherpa-onnx-zipformer-audio-tagging-2024-04-09


log "test streaming zipformer2 ctc HLG decoding"

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
Expand Down
1 change: 0 additions & 1 deletion build-ios.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ fi
if [ ! -f $onnxruntime_dir/onnxruntime.xcframework/ios-arm64/onnxruntime.a ]; then
mkdir -p $onnxruntime_dir
pushd $onnxruntime_dir
# rm -f onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
wget -c https://${SHERPA_ONNX_GITHUB}/csukuangfj/onnxruntime-libs/releases/download/v${onnxruntime_version}/onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
tar xvf onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
rm onnxruntime.xcframework-${onnxruntime_version}.tar.bz2
Expand Down
1 change: 0 additions & 1 deletion cmake/kaldi-native-fbank.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ function(download_kaldi_native_fbank)

set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
set(kaldi_native_fbank_URL2 "https://hub.nuaa.cf/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
# set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.19.1.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=0cae8cbb9ea42916b214e088912f9e8f2f648f54756b305f93f552382f31f904")

set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
Expand Down
121 changes: 121 additions & 0 deletions python-api-examples/audio-tagging-from-a-file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#!/usr/bin/env python3

"""
This script shows how to use the audio tagging Python API to tag a file.
Please read the code below to find out where to download the required model
files and the test wave file.
"""

import logging
import time
from pathlib import Path

import numpy as np
import sherpa_onnx
import soundfile as sf


def read_test_wave(
    test_wave="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav",
):
    """Read a wave file and return the samples of its first channel.

    Please download the model files and test wave files from
    https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models

    Args:
      test_wave:
        Path to the wave file to read. It defaults to the test wave that is
        shipped with the 2024-04-09 zipformer audio tagging model.

    Returns:
      A tuple ``(samples, sample_rate)``. ``samples`` is a contiguous 1-d
      array of dtype float32 and ``sample_rate`` is a scalar.

    Raises:
      ValueError: If ``test_wave`` is not an existing file.
    """
    if not Path(test_wave).is_file():
        raise ValueError(
            f"Please download {test_wave} from "
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
        )

    # See https://python-soundfile.readthedocs.io/en/0.11.0/#soundfile.read
    # always_2d=True is requested so that the channel can be selected with a
    # column index below, regardless of how many channels the file has.
    data, sample_rate = sf.read(
        test_wave,
        always_2d=True,
        dtype="float32",
    )

    # Use only the first channel. Selecting a column yields a strided view,
    # so copy it into a contiguous buffer before returning it.
    samples = np.ascontiguousarray(data[:, 0])

    return samples, sample_rate


def create_audio_tagger(
    model_file="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.onnx",
    label_file=(
        "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv"
    ),
):
    """Build a sherpa_onnx.AudioTagging object for a zipformer model.

    Please download the model files and test wave files from
    https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models

    Args:
      model_file:
        Path to the zipformer audio tagging ONNX model.
      label_file:
        Path to the CSV file passed as ``labels`` to the config.

    Returns:
      A ``sherpa_onnx.AudioTagging`` instance created from the config.

    Raises:
      ValueError: If ``model_file`` or ``label_file`` does not exist, or if
        the resulting config fails its own ``validate()`` check.
    """
    # Fail early with a download hint instead of an error from the runtime.
    # The model is checked first so the reported file matches the old order.
    for required_file in (model_file, label_file):
        if not Path(required_file).is_file():
            raise ValueError(
                f"Please download {required_file} from "
                "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
            )

    config = sherpa_onnx.AudioTaggingConfig(
        model=sherpa_onnx.AudioTaggingModelConfig(
            zipformer=sherpa_onnx.OfflineZipformerAudioTaggingModelConfig(
                model=model_file,
            ),
            num_threads=1,
            debug=True,
            provider="cpu",
        ),
        labels=label_file,
        top_k=5,
    )
    if not config.validate():
        raise ValueError(f"Please check the config: {config}")

    print(config)

    return sherpa_onnx.AudioTagging(config)


def main():
    """Tag the test wave and log the timing and the returned events.

    Creates the tagger, reads the test wave, runs the tagger on it once and
    logs the elapsed time, the audio duration, the real-time factor (RTF)
    and every entry of the result.
    """
    logging.info("Create audio tagger")
    audio_tagger = create_audio_tagger()

    logging.info("Read test wave")
    samples, sample_rate = read_test_wave()

    logging.info("Computing")

    # perf_counter() is a monotonic clock intended for measuring intervals;
    # unlike time.time() it is not affected by system clock adjustments.
    start_time = time.perf_counter()

    stream = audio_tagger.create_stream()
    stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
    result = audio_tagger.compute(stream)

    elapsed_seconds = time.perf_counter() - start_time
    audio_duration = len(samples) / sample_rate

    logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
    logging.info(f"Audio duration in seconds: {audio_duration:.3f}")
    if audio_duration > 0:
        real_time_factor = elapsed_seconds / audio_duration
        logging.info(
            f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
        )
    else:
        # An empty wave has zero duration; the RTF is undefined in that case.
        logging.warning("Audio duration is 0, skip computing RTF")

    # One "<rank>: <event>" line per entry, preceded by a newline so that the
    # list starts on its own line after the log prefix.
    events = "".join(f"{i}: {e}\n" for i, e in enumerate(result))
    logging.info("\n" + events)


if __name__ == "__main__":
    # Log lines carry a timestamp, the level and the file:line of the caller.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    main()
1 change: 1 addition & 0 deletions sherpa-onnx/python/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
include_directories(${CMAKE_SOURCE_DIR})

set(srcs
audio-tagging.cc
circular-buffer.cc
display.cc
endpoint.cc
Expand Down
87 changes: 87 additions & 0 deletions sherpa-onnx/python/csrc/audio-tagging.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// sherpa-onnx/python/csrc/audio-tagging.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/python/csrc/audio-tagging.h"

#include <string>

#include "sherpa-onnx/csrc/audio-tagging.h"

namespace sherpa_onnx {

// Registers OfflineZipformerAudioTaggingModelConfig on the module `m`.
//
// The Python class can be created without arguments or with a `model`
// string, exposes `model` as a read/write attribute, and forwards
// `validate()` and `__str__()` to Validate() and ToString().
static void PybindOfflineZipformerAudioTaggingModelConfig(py::module *m) {
  using Config = OfflineZipformerAudioTaggingModelConfig;

  py::class_<Config> cls(*m, "OfflineZipformerAudioTaggingModelConfig");

  // constructors
  cls.def(py::init<>());
  cls.def(py::init<const std::string &>(), py::arg("model"));

  // attributes
  cls.def_readwrite("model", &Config::model);

  // methods
  cls.def("validate", &Config::Validate);
  cls.def("__str__", &Config::ToString);
}

// Registers AudioTaggingModelConfig on the module `m`.
//
// The nested zipformer config class is registered first. The keyword
// constructor takes `zipformer`, `num_threads` (default 1), `debug`
// (default false) and `provider` (default "cpu"); the same four fields are
// exposed as read/write attributes.
static void PybindAudioTaggingModelConfig(py::module *m) {
  PybindOfflineZipformerAudioTaggingModelConfig(m);

  using Config = AudioTaggingModelConfig;

  py::class_<Config> cls(*m, "AudioTaggingModelConfig");

  // constructors
  cls.def(py::init<>());
  cls.def(py::init<const OfflineZipformerAudioTaggingModelConfig &, int32_t,
                   bool, const std::string &>(),
          py::arg("zipformer"), py::arg("num_threads") = 1,
          py::arg("debug") = false, py::arg("provider") = "cpu");

  // attributes
  cls.def_readwrite("zipformer", &Config::zipformer);
  cls.def_readwrite("num_threads", &Config::num_threads);
  cls.def_readwrite("debug", &Config::debug);
  cls.def_readwrite("provider", &Config::provider);

  // methods
  cls.def("validate", &Config::Validate);
  cls.def("__str__", &Config::ToString);
}

// Registers AudioTaggingConfig on the module `m`.
//
// The model config class is registered first. The keyword constructor
// takes `model`, `labels` and `top_k` (default 5); the same three fields
// are exposed as read/write attributes.
static void PybindAudioTaggingConfig(py::module *m) {
  PybindAudioTaggingModelConfig(m);

  using Config = AudioTaggingConfig;

  py::class_<Config> cls(*m, "AudioTaggingConfig");

  // constructors
  cls.def(py::init<>());
  cls.def(py::init<const AudioTaggingModelConfig &, const std::string &,
                   int32_t>(),
          py::arg("model"), py::arg("labels"), py::arg("top_k") = 5);

  // attributes
  cls.def_readwrite("model", &Config::model);
  cls.def_readwrite("labels", &Config::labels);
  cls.def_readwrite("top_k", &Config::top_k);

  // methods
  cls.def("validate", &Config::Validate);
  cls.def("__str__", &Config::ToString);
}

// Registers AudioEvent on the module `m`.
//
// No constructor is bound. `name`, `index` and `prob` are read-only
// properties that return the corresponding field by value, and `__str__()`
// forwards to ToString().
static void PybindAudioEvent(py::module *m) {
  using Event = AudioEvent;

  py::class_<Event> cls(*m, "AudioEvent");

  cls.def_property_readonly(
      "name", [](const Event &self) -> std::string { return self.name; });
  cls.def_property_readonly(
      "index", [](const Event &self) -> int32_t { return self.index; });
  cls.def_property_readonly(
      "prob", [](const Event &self) -> float { return self.prob; });

  cls.def("__str__", &Event::ToString);
}

// Entry point for the audio tagging bindings: registers the config classes,
// AudioEvent and AudioTagging on the module `m`.
//
// The constructor, `create_stream()` and `compute()` are all bound with a
// call guard that releases the GIL for the duration of the C++ call.
void PybindAudioTagging(py::module *m) {
  PybindAudioTaggingConfig(m);
  PybindAudioEvent(m);

  using Tagger = AudioTagging;

  py::class_<Tagger> cls(*m, "AudioTagging");

  cls.def(py::init<const AudioTaggingConfig &>(), py::arg("config"),
          py::call_guard<py::gil_scoped_release>());

  cls.def("create_stream", &Tagger::CreateStream,
          py::call_guard<py::gil_scoped_release>());

  // NOTE(review): top_k defaults to -1 here; presumably that makes Compute()
  // fall back to the top_k of the config -- confirm against
  // AudioTagging::Compute.
  cls.def("compute", &Tagger::Compute, py::arg("s"), py::arg("top_k") = -1,
          py::call_guard<py::gil_scoped_release>());
}

} // namespace sherpa_onnx
16 changes: 16 additions & 0 deletions sherpa-onnx/python/csrc/audio-tagging.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/audio-tagging.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_
#define SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

// Registers the audio tagging Python bindings on the module `m`.
// Defined in sherpa-onnx/python/csrc/audio-tagging.cc.
void PybindAudioTagging(py::module *m);

}  // namespace sherpa_onnx

#endif // SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_
2 changes: 1 addition & 1 deletion sherpa-onnx/python/csrc/offline-tts-vits-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ void PybindOfflineTtsVitsModelConfig(py::module *m) {
py::class_<PyClass>(*m, "OfflineTtsVitsModelConfig")
.def(py::init<>())
.def(py::init<const std::string &, const std::string &,
const std::string &, const std::string, float, float,
const std::string &, const std::string &, float, float,
float>(),
py::arg("model"), py::arg("lexicon"), py::arg("tokens"),
py::arg("data_dir") = "", py::arg("noise_scale") = 0.667,
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/python/csrc/sherpa-onnx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

#include "sherpa-onnx/python/csrc/alsa.h"
#include "sherpa-onnx/python/csrc/audio-tagging.h"
#include "sherpa-onnx/python/csrc/circular-buffer.h"
#include "sherpa-onnx/python/csrc/display.h"
#include "sherpa-onnx/python/csrc/endpoint.h"
Expand Down Expand Up @@ -38,6 +39,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
m.doc() = "pybind11 binding of sherpa-onnx";

PybindWaveWriter(&m);
PybindAudioTagging(&m);

PybindFeatures(&m);
PybindOnlineCtcFstDecoderConfig(&m);
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/python/csrc/speaker-embedding-extractor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ static void PybindSpeakerEmbeddingExtractorConfig(py::module *m) {
using PyClass = SpeakerEmbeddingExtractorConfig;
py::class_<PyClass>(*m, "SpeakerEmbeddingExtractorConfig")
.def(py::init<>())
.def(py::init<const std::string &, int32_t, bool, const std::string>(),
.def(py::init<const std::string &, int32_t, bool, const std::string &>(),
py::arg("model"), py::arg("num_threads") = 1,
py::arg("debug") = false, py::arg("provider") = "cpu")
.def_readwrite("model", &PyClass::model)
Expand Down
4 changes: 2 additions & 2 deletions sherpa-onnx/python/csrc/spoken-language-identification.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ static void PybindSpokenLanguageIdentificationConfig(py::module *m) {
py::class_<PyClass>(*m, "SpokenLanguageIdentificationConfig")
.def(py::init<>())
.def(py::init<const SpokenLanguageIdentificationWhisperConfig &, int32_t,
bool, const std::string>(),
bool, const std::string &>(),
py::arg("whisper"), py::arg("num_threads") = 1,
py::arg("debug") = false, py::arg("provider") = "cpu")
.def_readwrite("whisper", &PyClass::whisper)
Expand All @@ -53,7 +53,7 @@ void PybindSpokenLanguageIdentification(py::module *m) {
py::arg("config"), py::call_guard<py::gil_scoped_release>())
.def("create_stream", &PyClass::CreateStream,
py::call_guard<py::gil_scoped_release>())
.def("compute", &PyClass::Compute,
.def("compute", &PyClass::Compute, py::arg("s"),
py::call_guard<py::gil_scoped_release>());
}

Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/python/sherpa_onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from _sherpa_onnx import (
Alsa,
AudioEvent,
AudioTagging,
AudioTaggingConfig,
AudioTaggingModelConfig,
CircularBuffer,
Display,
OfflineStream,
OfflineTts,
OfflineTtsConfig,
OfflineTtsModelConfig,
OfflineTtsVitsModelConfig,
OfflineZipformerAudioTaggingModelConfig,
OnlineStream,
SileroVadModelConfig,
SpeakerEmbeddingExtractor,
Expand Down

0 comments on commit 34d70a2

Please sign in to comment.