-
Notifications
You must be signed in to change notification settings - Fork 508
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Python API and Python examples for audio tagging (#753)
- Loading branch information
1 parent
904a3cc
commit 34d70a2
Showing 12 changed files with 245 additions and 6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
#!/usr/bin/env python3 | ||
|
||
""" | ||
This script shows how to use audio tagging Python APIs to tag a file. | ||
Please read the code to download the required model files and test wave file. | ||
""" | ||
|
||
import logging | ||
import time | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import sherpa_onnx | ||
import soundfile as sf | ||
|
||
|
||
def read_test_wave():
    """Load the test recording and return its first channel.

    The test wave ships with the model archive available at
    https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models

    Returns:
      A tuple ``(samples, sample_rate)`` where ``samples`` is a contiguous
      1-d float32 array holding the first channel of the file and
      ``sample_rate`` is the sample rate reported by soundfile.

    Raises:
      ValueError: if the test wave file is not present on disk.
    """
    wave_path = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav"

    wave_exists = Path(wave_path).is_file()
    if not wave_exists:
        raise ValueError(
            f"Please download {wave_path} from "
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
        )

    # always_2d=True makes soundfile return a (frames, channels) array even
    # for mono input, so the channel index below is always valid.
    # See https://python-soundfile.readthedocs.io/en/0.11.0/#soundfile.read
    audio, sample_rate = sf.read(wave_path, dtype="float32", always_2d=True)

    # Keep only the first channel and make sure it is stored contiguously.
    first_channel = np.ascontiguousarray(audio[:, 0])

    return first_channel, sample_rate
|
||
|
||
def create_audio_tagger():
    """Build a ``sherpa_onnx.AudioTagging`` instance for the zipformer model.

    The model and label files come from
    https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models

    Returns:
      A ``sherpa_onnx.AudioTagging`` object constructed from the config.

    Raises:
      ValueError: if the model file or the label file is missing, or if the
        assembled config fails its own ``validate()`` check.
    """
    model_file = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.onnx"
    label_file = (
        "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv"
    )

    # The model is checked before the labels, so a completely missing
    # download is reported via the model path.
    for required in (model_file, label_file):
        if Path(required).is_file():
            continue
        raise ValueError(
            f"Please download {required} from "
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
        )

    zipformer_config = sherpa_onnx.OfflineZipformerAudioTaggingModelConfig(
        model=model_file,
    )
    model_config = sherpa_onnx.AudioTaggingModelConfig(
        zipformer=zipformer_config,
        num_threads=1,
        debug=True,
        provider="cpu",
    )
    config = sherpa_onnx.AudioTaggingConfig(
        model=model_config,
        labels=label_file,
        top_k=5,
    )

    config_ok = config.validate()
    if not config_ok:
        raise ValueError(f"Please check the config: {config}")

    print(config)

    return sherpa_onnx.AudioTagging(config)
|
||
|
||
def main():
    """Tag the test wave and log the results together with timing figures.

    Creates the tagger, reads the test wave, runs one ``compute`` call on a
    fresh stream, then logs the elapsed time, the audio duration, the
    real-time factor (elapsed / duration) and every returned event.
    """
    logging.info("Create audio tagger")
    audio_tagger = create_audio_tagger()

    logging.info("Read test wave")
    samples, sample_rate = read_test_wave()

    logging.info("Computing")

    # perf_counter() is a monotonic, high-resolution clock; unlike
    # time.time() it is not affected by system clock adjustments, so it is
    # the right tool for measuring an elapsed interval.
    start_time = time.perf_counter()

    stream = audio_tagger.create_stream()
    stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
    result = audio_tagger.compute(stream)

    elapsed_seconds = time.perf_counter() - start_time
    audio_duration = len(samples) / sample_rate

    real_time_factor = elapsed_seconds / audio_duration
    logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
    logging.info(f"Audio duration in seconds: {audio_duration:.3f}")
    logging.info(
        f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
    )

    # One "<rank>: <event>" line per result, preceded by a newline so the
    # list starts on its own line after the log prefix.
    s = "\n" + "".join(f"{i}: {e}\n" for i, e in enumerate(result))

    logging.info(s)
|
||
|
||
if __name__ == "__main__":
    # Each record shows timestamp, level and source location before the text.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
    )

    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
include_directories(${CMAKE_SOURCE_DIR}) | ||
|
||
set(srcs | ||
audio-tagging.cc | ||
circular-buffer.cc | ||
display.cc | ||
endpoint.cc | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
// sherpa-onnx/python/csrc/audio-tagging.cc | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
|
||
#include "sherpa-onnx/python/csrc/audio-tagging.h" | ||
|
||
#include <string> | ||
|
||
#include "sherpa-onnx/csrc/audio-tagging.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
// Registers OfflineZipformerAudioTaggingModelConfig on module `m`: a default
// constructor, a constructor taking the model path, the `model` attribute,
// and the `validate` / `__str__` methods.
static void PybindOfflineZipformerAudioTaggingModelConfig(py::module *m) {
  using Config = OfflineZipformerAudioTaggingModelConfig;

  py::class_<Config> cls(*m, "OfflineZipformerAudioTaggingModelConfig");
  cls.def(py::init<>());
  cls.def(py::init<const std::string &>(), py::arg("model"));
  cls.def_readwrite("model", &Config::model);
  cls.def("validate", &Config::Validate);
  cls.def("__str__", &Config::ToString);
}
|
||
// Registers AudioTaggingModelConfig on module `m`, together with the
// zipformer sub-config it embeds.
static void PybindAudioTaggingModelConfig(py::module *m) {
  // Register the nested config type before the class that refers to it.
  PybindOfflineZipformerAudioTaggingModelConfig(m);

  using Config = AudioTaggingModelConfig;

  py::class_<Config> cls(*m, "AudioTaggingModelConfig");
  cls.def(py::init<>());
  // Only `zipformer` is required from Python; the rest have defaults.
  cls.def(py::init<const OfflineZipformerAudioTaggingModelConfig &, int32_t,
                   bool, const std::string &>(),
          py::arg("zipformer"), py::arg("num_threads") = 1,
          py::arg("debug") = false, py::arg("provider") = "cpu");
  cls.def_readwrite("zipformer", &Config::zipformer);
  cls.def_readwrite("num_threads", &Config::num_threads);
  cls.def_readwrite("debug", &Config::debug);
  cls.def_readwrite("provider", &Config::provider);
  cls.def("validate", &Config::Validate);
  cls.def("__str__", &Config::ToString);
}
|
||
// Registers AudioTaggingConfig on module `m`, together with the model config
// types it embeds.
static void PybindAudioTaggingConfig(py::module *m) {
  // Register the nested model config before the class that refers to it.
  PybindAudioTaggingModelConfig(m);

  using Config = AudioTaggingConfig;

  py::class_<Config> cls(*m, "AudioTaggingConfig");
  cls.def(py::init<>());
  // `model` and `labels` are required from Python; `top_k` defaults to 5.
  cls.def(py::init<const AudioTaggingModelConfig &, const std::string &,
                   int32_t>(),
          py::arg("model"), py::arg("labels"), py::arg("top_k") = 5);
  cls.def_readwrite("model", &Config::model);
  cls.def_readwrite("labels", &Config::labels);
  cls.def_readwrite("top_k", &Config::top_k);
  cls.def("validate", &Config::Validate);
  cls.def("__str__", &Config::ToString);
}
|
||
// Registers AudioEvent on module `m` as a read-only record: `name`, `index`
// and `prob` are exposed as properties returning copies of the fields, and
// `__str__` maps to ToString. No constructor is exposed to Python.
static void PybindAudioEvent(py::module *m) {
  using Event = AudioEvent;

  py::class_<Event> cls(*m, "AudioEvent");
  cls.def_property_readonly(
      "name", [](const Event &event) -> std::string { return event.name; });
  cls.def_property_readonly(
      "index", [](const Event &event) -> int32_t { return event.index; });
  cls.def_property_readonly(
      "prob", [](const Event &event) -> float { return event.prob; });
  cls.def("__str__", &Event::ToString);
}
|
||
// Entry point for the audio tagging bindings: registers the config classes,
// AudioEvent and the AudioTagging class itself on module `m`.
void PybindAudioTagging(py::module *m) {
  PybindAudioTaggingConfig(m);
  PybindAudioEvent(m);

  using Tagger = AudioTagging;

  py::class_<Tagger> cls(*m, "AudioTagging");

  // Every binding below releases the GIL for the duration of the C++ call.
  cls.def(py::init<const AudioTaggingConfig &>(), py::arg("config"),
          py::call_guard<py::gil_scoped_release>());
  cls.def("create_stream", &Tagger::CreateStream,
          py::call_guard<py::gil_scoped_release>());
  // top_k defaults to -1 on the Python side.
  // NOTE(review): presumably -1 means "use config.top_k" - confirm in
  // sherpa-onnx/csrc/audio-tagging.h.
  cls.def("compute", &Tagger::Compute, py::arg("s"), py::arg("top_k") = -1,
          py::call_guard<py::gil_scoped_release>());
}
|
||
} // namespace sherpa_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
// sherpa-onnx/python/csrc/audio-tagging.h | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
|
||
#ifndef SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_ | ||
#define SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_ | ||
|
||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
void PybindAudioTagging(py::module *m); | ||
|
||
} | ||
|
||
#endif // SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters