Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Python API for clustering #1385

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/scripts/test-online-punctuation.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

set -ex

echo "TODO(fangjun): Skip this test since the sanitizer test is failed. We need to fix it"
exit 0

log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
Expand Down
12 changes: 12 additions & 0 deletions .github/scripts/test-python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "test_clustering"
pushd /tmp/
mkdir test-cluster
cd test-cluster
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
git clone https://github.com/csukuangfj/sr-data
popd

python3 ./sherpa-onnx/python/tests/test_fast_clustering.py

rm -rf /tmp/test-cluster

export GIT_CLONE_PROTECTION_ACTIVE=false

log "test offline SenseVoice CTC"
Expand Down
14 changes: 8 additions & 6 deletions .github/workflows/run-python-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ jobs:
fail-fast: false
matrix:
include:
- os: ubuntu-20.04
python-version: "3.7"
- os: ubuntu-20.04
python-version: "3.8"
- os: ubuntu-20.04
python-version: "3.9"
# it fails to install ffmpeg on ubuntu 20.04
#
# - os: ubuntu-20.04
# python-version: "3.7"
# - os: ubuntu-20.04
# python-version: "3.8"
# - os: ubuntu-20.04
# python-version: "3.9"

- os: ubuntu-22.04
python-version: "3.10"
Expand Down
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,14 @@ else()
add_definitions(-DSHERPA_ONNX_ENABLE_TTS=0)
endif()

if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
message(STATUS "speaker diarization is enabled")
add_definitions(-DSHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION=1)
else()
message(WARNING "speaker diarization is disabled")
add_definitions(-DSHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION=0)
endif()

if(SHERPA_ONNX_ENABLE_DIRECTML)
message(STATUS "DirectML is enabled")
add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=1)
Expand Down
5 changes: 5 additions & 0 deletions build-android-arm64-v8a.sh
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ if [ -z $SHERPA_ONNX_ENABLE_TTS ]; then
SHERPA_ONNX_ENABLE_TTS=ON
fi

if [ -z $SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION ]; then
SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION=ON
fi

if [ -z $SHERPA_ONNX_ENABLE_BINARY ]; then
SHERPA_ONNX_ENABLE_BINARY=OFF
fi
Expand All @@ -77,6 +81,7 @@ fi

cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
-DSHERPA_ONNX_ENABLE_TTS=$SHERPA_ONNX_ENABLE_TTS \
-DSHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION=$SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION \
-DSHERPA_ONNX_ENABLE_BINARY=$SHERPA_ONNX_ENABLE_BINARY \
-DBUILD_PIPER_PHONMIZE_EXE=OFF \
-DBUILD_PIPER_PHONMIZE_TESTS=OFF \
Expand Down
5 changes: 5 additions & 0 deletions build-android-armv7-eabi.sh
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ if [ -z $SHERPA_ONNX_ENABLE_TTS ]; then
SHERPA_ONNX_ENABLE_TTS=ON
fi

if [ -z $SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION ]; then
SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION=ON
fi

if [ -z $SHERPA_ONNX_ENABLE_BINARY ]; then
SHERPA_ONNX_ENABLE_BINARY=OFF
fi
Expand All @@ -78,6 +82,7 @@ fi

cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
-DSHERPA_ONNX_ENABLE_TTS=$SHERPA_ONNX_ENABLE_TTS \
-DSHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION=$SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION \
-DSHERPA_ONNX_ENABLE_BINARY=$SHERPA_ONNX_ENABLE_BINARY \
-DBUILD_PIPER_PHONMIZE_EXE=OFF \
-DBUILD_PIPER_PHONMIZE_TESTS=OFF \
Expand Down
5 changes: 5 additions & 0 deletions build-android-x86-64.sh
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ if [ -z $SHERPA_ONNX_ENABLE_TTS ]; then
SHERPA_ONNX_ENABLE_TTS=ON
fi

if [ -z $SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION ]; then
SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION=ON
fi

if [ -z $SHERPA_ONNX_ENABLE_BINARY ]; then
SHERPA_ONNX_ENABLE_BINARY=OFF
fi
Expand All @@ -78,6 +82,7 @@ fi

cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
-DSHERPA_ONNX_ENABLE_TTS=$SHERPA_ONNX_ENABLE_TTS \
-DSHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION=$SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION \
-DSHERPA_ONNX_ENABLE_BINARY=$SHERPA_ONNX_ENABLE_BINARY \
-DBUILD_PIPER_PHONMIZE_EXE=OFF \
-DBUILD_PIPER_PHONMIZE_TESTS=OFF \
Expand Down
5 changes: 5 additions & 0 deletions build-android-x86.sh
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ if [ -z $SHERPA_ONNX_ENABLE_TTS ]; then
SHERPA_ONNX_ENABLE_TTS=ON
fi

if [ -z $SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION ]; then
SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION=ON
fi

if [ -z $SHERPA_ONNX_ENABLE_BINARY ]; then
SHERPA_ONNX_ENABLE_BINARY=OFF
fi
Expand All @@ -78,6 +82,7 @@ fi

cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
-DSHERPA_ONNX_ENABLE_TTS=$SHERPA_ONNX_ENABLE_TTS \
-DSHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION=$SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION \
-DSHERPA_ONNX_ENABLE_BINARY=$SHERPA_ONNX_ENABLE_BINARY \
-DBUILD_PIPER_PHONMIZE_EXE=OFF \
-DBUILD_PIPER_PHONMIZE_TESTS=OFF \
Expand Down
1 change: 1 addition & 0 deletions scripts/apk/build-apk-asr-2pass.sh.in
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " "
log "Building streaming ASR two-pass APK for sherpa-onnx v${SHERPA_ONNX_VERSION}"

export SHERPA_ONNX_ENABLE_TTS=OFF
export SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION=OFF

log "====================arm64-v8a================="
./build-android-arm64-v8a.sh
Expand Down
1 change: 1 addition & 0 deletions scripts/apk/build-apk-asr.sh.in
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " "
log "Building streaming ASR APK for sherpa-onnx v${SHERPA_ONNX_VERSION}"

export SHERPA_ONNX_ENABLE_TTS=OFF
export SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION=OFF

log "====================arm64-v8a================="
./build-android-arm64-v8a.sh
Expand Down
1 change: 1 addition & 0 deletions scripts/apk/build-apk-audio-tagging-wearos.sh.in
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ log "====================x86===================="
./build-android-x86.sh

export SHERPA_ONNX_ENABLE_TTS=OFF
export SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION=OFF

mkdir -p apks

Expand Down
1 change: 1 addition & 0 deletions scripts/apk/build-apk-audio-tagging.sh.in
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ log "====================x86===================="
./build-android-x86.sh

export SHERPA_ONNX_ENABLE_TTS=OFF
export SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION=OFF

mkdir -p apks

Expand Down
1 change: 1 addition & 0 deletions scripts/apk/build-apk-kws.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " "
log "Building keyword spotting APK for sherpa-onnx v${SHERPA_ONNX_VERSION}"

export SHERPA_ONNX_ENABLE_TTS=OFF
export SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION=OFF

log "====================arm64-v8a================="
./build-android-arm64-v8a.sh
Expand Down
1 change: 1 addition & 0 deletions scripts/apk/build-apk-slid.sh.in
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ log "====================x86===================="
./build-android-x86.sh

export SHERPA_ONNX_ENABLE_TTS=OFF
export SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION=OFF

mkdir -p apks

Expand Down
4 changes: 2 additions & 2 deletions scripts/apk/build-apk-speaker-identification.sh.in
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " "

log "Building Speaker identification APK for sherpa-onnx v${SHERPA_ONNX_VERSION}"

export SHERPA_ONNX_ENABLE_TTS=OFF

log "====================arm64-v8a================="
./build-android-arm64-v8a.sh
log "====================armv7-eabi================"
Expand All @@ -29,8 +31,6 @@ log "====================x86-64===================="
log "====================x86===================="
./build-android-x86.sh

export SHERPA_ONNX_ENABLE_TTS=OFF

mkdir -p apks

{% for model in model_list %}
Expand Down
6 changes: 4 additions & 2 deletions sherpa-onnx/csrc/fast-clustering-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ void FastClusteringConfig::Register(ParseOptions *po) {

p.Register("num-clusters", &num_clusters,
"Number of cluster. If greater than 0, then --cluster-thresold is "
"ignored");
"ignored. Please provide it if you know the actual number of "
"clusters in advance.");

p.Register("cluster-threshold", &threshold,
"If --num-clusters is not specified, then it specifies the "
"distance threshold for clustering.");
"distance threshold for clustering. smaller value -> more "
"clusters. larger value -> fewer clusters");
}

bool FastClusteringConfig::Validate() const {
Expand Down
15 changes: 13 additions & 2 deletions sherpa-onnx/csrc/fast-clustering-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,23 @@
namespace sherpa_onnx {

struct FastClusteringConfig {
// If greater than 0, then threshold is ignored
// If greater than 0, then threshold is ignored.
//
// We strongly recommend that you set it if you know the number of clusters
// in advance
int32_t num_clusters = -1;

// distance threshold
// distance threshold.
//
// The lower, the more clusters it will generate.
// The higher, the fewer clusters it will generate.
float threshold = 0.5;

FastClusteringConfig() = default;

FastClusteringConfig(int32_t num_clusters, float threshold)
: num_clusters(num_clusters), threshold(threshold) {}

std::string ToString() const;

void Register(ParseOptions *po);
Expand Down
4 changes: 2 additions & 2 deletions sherpa-onnx/csrc/fast-clustering.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class FastClustering::Impl {
explicit Impl(const FastClusteringConfig &config) : config_(config) {}

std::vector<int32_t> Cluster(float *features, int32_t num_rows,
int32_t num_cols) {
int32_t num_cols) const {
if (num_rows <= 0) {
return {};
}
Expand Down Expand Up @@ -77,7 +77,7 @@ FastClustering::FastClustering(const FastClusteringConfig &config)
FastClustering::~FastClustering() = default;

std::vector<int32_t> FastClustering::Cluster(float *features, int32_t num_rows,
int32_t num_cols) {
int32_t num_cols) const {
return impl_->Cluster(features, num_rows, num_cols);
}
} // namespace sherpa_onnx
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/fast-clustering.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class FastClustering {
* matrix.
*/
std::vector<int32_t> Cluster(float *features, int32_t num_rows,
int32_t num_cols);
int32_t num_cols) const;

private:
class Impl;
Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/python/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ if(SHERPA_ONNX_ENABLE_TTS)
)
endif()

if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
list(APPEND srcs
fast-clustering.cc
)
endif()

pybind11_add_module(_sherpa_onnx ${srcs})

if(APPLE)
Expand Down
52 changes: 52 additions & 0 deletions sherpa-onnx/python/csrc/fast-clustering.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// sherpa-onnx/python/csrc/fast-clustering.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/python/csrc/fast-clustering.h"

#include <sstream>
#include <vector>

#include "sherpa-onnx/csrc/fast-clustering.h"

namespace sherpa_onnx {

static void PybindFastClusteringConfig(py::module *m) {
using PyClass = FastClusteringConfig;
py::class_<PyClass>(*m, "FastClusteringConfig")
.def(py::init<int32_t, float>(), py::arg("num_clusters") = -1,
py::arg("threshold") = 0.5)
.def_readwrite("num_clusters", &PyClass::num_clusters)
.def_readwrite("threshold", &PyClass::threshold)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);
}

void PybindFastClustering(py::module *m) {
PybindFastClusteringConfig(m);

using PyClass = FastClustering;
py::class_<PyClass>(*m, "FastClustering")
.def(py::init<const FastClusteringConfig &>(), py::arg("config"))
.def(
"__call__",
[](const PyClass &self,
py::array_t<float> features) -> std::vector<int32_t> {
int num_dim = features.ndim();
if (num_dim != 2) {
std::ostringstream os;
os << "Expect an array of 2 dimensions. Given dim: " << num_dim
<< "\n";
throw py::value_error(os.str());
}

int32_t num_rows = features.shape(0);
int32_t num_cols = features.shape(1);
float *p = features.mutable_data();
py::gil_scoped_release release;
return self.Cluster(p, num_rows, num_cols);
},
py::arg("features"));
}

} // namespace sherpa_onnx
16 changes: 16 additions & 0 deletions sherpa-onnx/python/csrc/fast-clustering.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/fast-clustering.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_FAST_CLUSTERING_H_
#define SHERPA_ONNX_PYTHON_CSRC_FAST_CLUSTERING_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

void PybindFastClustering(py::module *m);

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_PYTHON_CSRC_FAST_CLUSTERING_H_
8 changes: 8 additions & 0 deletions sherpa-onnx/python/csrc/sherpa-onnx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
#include "sherpa-onnx/python/csrc/offline-tts.h"
#endif

#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
#include "sherpa-onnx/python/csrc/fast-clustering.h"
#endif

namespace sherpa_onnx {

PYBIND11_MODULE(_sherpa_onnx, m) {
Expand Down Expand Up @@ -70,6 +74,10 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindOfflineTts(&m);
#endif

#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
PybindFastClustering(&m);
#endif

PybindSpeakerEmbeddingExtractor(&m);
PybindSpeakerEmbeddingManager(&m);
PybindSpokenLanguageIdentification(&m);
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/python/sherpa_onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
AudioTaggingModelConfig,
CircularBuffer,
Display,
FastClustering,
FastClusteringConfig,
OfflinePunctuation,
OfflinePunctuationConfig,
OfflinePunctuationModelConfig,
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/python/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ endfunction()

# please sort the files in alphabetic order
set(py_test_files
test_fast_clustering.py
test_feature_extractor_config.py
test_keyword_spotter.py
test_offline_recognizer.py
Expand Down
Loading
Loading