diff --git a/CMakeLists.txt b/CMakeLists.txt index 9084a0216..d0b44be2a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,6 +32,7 @@ option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" O option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF) option(SHERPA_ONNX_ENABLE_DIRECTML "Enable ONNX Runtime DirectML support" OFF) option(SHERPA_ONNX_ENABLE_WASM "Whether to enable WASM" OFF) +option(SHERPA_ONNX_ENABLE_WASM_SPEAKER_DIARIZATION "Whether to enable WASM for speaker diarization" OFF) option(SHERPA_ONNX_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF) option(SHERPA_ONNX_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF) option(SHERPA_ONNX_ENABLE_WASM_KWS "Whether to enable WASM for KWS" OFF) @@ -135,6 +136,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_C_API ${SHERPA_ONNX_ENABLE_C_API}") message(STATUS "SHERPA_ONNX_ENABLE_WEBSOCKET ${SHERPA_ONNX_ENABLE_WEBSOCKET}") message(STATUS "SHERPA_ONNX_ENABLE_GPU ${SHERPA_ONNX_ENABLE_GPU}") message(STATUS "SHERPA_ONNX_ENABLE_WASM ${SHERPA_ONNX_ENABLE_WASM}") +message(STATUS "SHERPA_ONNX_ENABLE_WASM_SPEAKER_DIARIZATION ${SHERPA_ONNX_ENABLE_WASM_SPEAKER_DIARIZATION}") message(STATUS "SHERPA_ONNX_ENABLE_WASM_TTS ${SHERPA_ONNX_ENABLE_WASM_TTS}") message(STATUS "SHERPA_ONNX_ENABLE_WASM_ASR ${SHERPA_ONNX_ENABLE_WASM_ASR}") message(STATUS "SHERPA_ONNX_ENABLE_WASM_KWS ${SHERPA_ONNX_ENABLE_WASM_KWS}") @@ -196,9 +198,19 @@ else() add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=0) endif() +if(SHERPA_ONNX_ENABLE_WASM_SPEAKER_DIARIZATION) + if(NOT SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) + message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION to ON if you want to build WASM for speaker diarization") + endif() + + if(NOT SHERPA_ONNX_ENABLE_WASM) + message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_WASM to ON if you enable WASM for speaker diarization") + endif() +endif() + if(SHERPA_ONNX_ENABLE_WASM_TTS) if(NOT SHERPA_ONNX_ENABLE_TTS) - message(FATAL_ERROR "Please set 
SHERPA_ONNX_ENABLE_TTS to ON if you want to build wasm TTS") + message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_TTS to ON if you want to build WASM for TTS") endif() if(NOT SHERPA_ONNX_ENABLE_WASM) diff --git a/README.md b/README.md index 00fe61fd0..c28645099 100644 --- a/README.md +++ b/README.md @@ -173,6 +173,7 @@ We also have spaces built using WebAssembly. They are listed below: | Speaker identification (Speaker ID) | [Address][sid-models] | | Spoken language identification (Language ID)| See multi-lingual [Whisper][Whisper] ASR models from [Speech recognition][asr-models]| | Punctuation | [Address][punct-models] | +| Speaker segmentation | [Address][speaker-segmentation-models] | ### Useful links @@ -303,5 +304,6 @@ Video demo in Chinese: [爆了!炫神教你开打字挂!真正影响胜率 [sid-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models [slid-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models [punct-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models +[speaker-segmentation-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models [GigaSpeech]: https://github.com/SpeechColab/GigaSpeech [WenetSpeech]: https://github.com/wenet-e2e/WenetSpeech diff --git a/build-wasm-simd-asr.sh b/build-wasm-simd-asr.sh index eda18f74d..f5e755047 100755 --- a/build-wasm-simd-asr.sh +++ b/build-wasm-simd-asr.sh @@ -14,8 +14,8 @@ if [ x"$EMSCRIPTEN" == x"" ]; then echo "git clone https://github.com/emscripten-core/emsdk.git" echo "cd emsdk" echo "git pull" - echo "./emsdk install latest" - echo "./emsdk activate latest" + echo "./emsdk install 3.1.53" + echo "./emsdk activate 3.1.53" echo "source ./emsdk_env.sh" exit 1 else diff --git a/build-wasm-simd-kws.sh b/build-wasm-simd-kws.sh index 6fdf8218f..301bd8711 100755 --- a/build-wasm-simd-kws.sh +++ b/build-wasm-simd-kws.sh @@ -9,8 +9,8 @@ if [ x"$EMSCRIPTEN" == x"" ]; then echo "git clone 
https://github.com/emscripten-core/emsdk.git" echo "cd emsdk" echo "git pull" - echo "./emsdk install latest" - echo "./emsdk activate latest" + echo "./emsdk install 3.1.53" + echo "./emsdk activate 3.1.53" echo "source ./emsdk_env.sh" exit 1 else diff --git a/build-wasm-simd-nodejs.sh b/build-wasm-simd-nodejs.sh index 3ad88d5d4..13bf3c854 100755 --- a/build-wasm-simd-nodejs.sh +++ b/build-wasm-simd-nodejs.sh @@ -16,8 +16,8 @@ if [ x"$EMSCRIPTEN" == x"" ]; then echo "git clone https://github.com/emscripten-core/emsdk.git" echo "cd emsdk" echo "git pull" - echo "./emsdk install latest" - echo "./emsdk activate latest" + echo "./emsdk install 3.1.53" + echo "./emsdk activate 3.1.53" echo "source ./emsdk_env.sh" exit 1 else diff --git a/build-wasm-simd-speaker-diarization.sh b/build-wasm-simd-speaker-diarization.sh new file mode 100755 index 000000000..da4be0381 --- /dev/null +++ b/build-wasm-simd-speaker-diarization.sh @@ -0,0 +1,61 @@ +#!/usr/bin/env bash +# Copyright (c) 2024 Xiaomi Corporation +# +# This script is to build sherpa-onnx for WebAssembly (speaker diarization) + +set -ex + +if [ x"$EMSCRIPTEN" == x"" ]; then + if ! command -v emcc &> /dev/null; then + echo "Please install emscripten first" + echo "" + echo "You can use the following commands to install it:" + echo "" + echo "git clone https://github.com/emscripten-core/emsdk.git" + echo "cd emsdk" + echo "git pull" + echo "./emsdk install 3.1.53" + echo "./emsdk activate 3.1.53" + echo "source ./emsdk_env.sh" + exit 1 + else + EMSCRIPTEN=$(dirname $(realpath $(which emcc))) + fi +fi + +export EMSCRIPTEN=$EMSCRIPTEN +echo "EMSCRIPTEN: $EMSCRIPTEN" +if [ ! 
-f $EMSCRIPTEN/cmake/Modules/Platform/Emscripten.cmake ]; then + echo "Cannot find $EMSCRIPTEN/cmake/Modules/Platform/Emscripten.cmake" + echo "Please make sure you have installed emsdk correctly" + exit 1 +fi + +mkdir -p build-wasm-simd-speaker-diarization +pushd build-wasm-simd-speaker-diarization + +export SHERPA_ONNX_IS_USING_BUILD_WASM_SH=ON + +cmake \ + -DCMAKE_INSTALL_PREFIX=./install \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$EMSCRIPTEN/cmake/Modules/Platform/Emscripten.cmake \ + \ + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ + -DBUILD_SHARED_LIBS=OFF \ + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ + -DSHERPA_ONNX_ENABLE_JNI=OFF \ + -DSHERPA_ONNX_ENABLE_C_API=ON \ + -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \ + -DSHERPA_ONNX_ENABLE_GPU=OFF \ + -DSHERPA_ONNX_ENABLE_WASM=ON \ + -DSHERPA_ONNX_ENABLE_WASM_SPEAKER_DIARIZATION=ON \ + -DSHERPA_ONNX_ENABLE_BINARY=OFF \ + -DSHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY=OFF \ + .. +make -j2 +make install + +ls -lh install/bin/wasm/speaker-diarization diff --git a/build-wasm-simd-tts.sh b/build-wasm-simd-tts.sh index 6835e4c43..4e37d2047 100755 --- a/build-wasm-simd-tts.sh +++ b/build-wasm-simd-tts.sh @@ -14,8 +14,8 @@ if [ x"$EMSCRIPTEN" == x"" ]; then echo "git clone https://github.com/emscripten-core/emsdk.git" echo "cd emsdk" echo "git pull" - echo "./emsdk install latest" - echo "./emsdk activate latest" + echo "./emsdk install 3.1.53" + echo "./emsdk activate 3.1.53" echo "source ./emsdk_env.sh" exit 1 else diff --git a/build-wasm-simd-vad-asr.sh b/build-wasm-simd-vad-asr.sh index 5d15cf651..4bf899da3 100755 --- a/build-wasm-simd-vad-asr.sh +++ b/build-wasm-simd-vad-asr.sh @@ -15,8 +15,8 @@ if [ x"$EMSCRIPTEN" == x"" ]; then echo "git clone https://github.com/emscripten-core/emsdk.git" echo "cd emsdk" echo "git pull" - echo "./emsdk install latest" - echo "./emsdk activate latest" + echo "./emsdk install 3.1.53" + echo "./emsdk activate 3.1.53" echo "source 
./emsdk_env.sh" exit 1 else diff --git a/build-wasm-simd-vad.sh b/build-wasm-simd-vad.sh index c74f57d37..b73f7f156 100755 --- a/build-wasm-simd-vad.sh +++ b/build-wasm-simd-vad.sh @@ -14,8 +14,8 @@ if [ x"$EMSCRIPTEN" == x"" ]; then echo "git clone https://github.com/emscripten-core/emsdk.git" echo "cd emsdk" echo "git pull" - echo "./emsdk install latest" - echo "./emsdk activate latest" + echo "./emsdk install 3.1.53" + echo "./emsdk activate 3.1.53" echo "source ./emsdk_env.sh" exit 1 else diff --git a/scripts/dotnet/OfflineSpeakerDiarization.cs b/scripts/dotnet/OfflineSpeakerDiarization.cs index b56cab9b6..cfe28e941 100644 --- a/scripts/dotnet/OfflineSpeakerDiarization.cs +++ b/scripts/dotnet/OfflineSpeakerDiarization.cs @@ -16,6 +16,11 @@ public OfflineSpeakerDiarization(OfflineSpeakerDiarizationConfig config) _handle = new HandleRef(this, h); } + public void SetConfig(OfflineSpeakerDiarizationConfig config) + { + SherpaOnnxOfflineSpeakerDiarizationSetConfig(_handle.Handle, ref config); + } + public OfflineSpeakerDiarizationSegment[] Process(float[] samples) { IntPtr result = SherpaOnnxOfflineSpeakerDiarizationProcess(_handle.Handle, samples, samples.Length); @@ -117,6 +122,9 @@ public int SampleRate [DllImport(Dll.Filename)] private static extern void SherpaOnnxOfflineSpeakerDiarizationDestroySegment(IntPtr handle); + + [DllImport(Dll.Filename)] + private static extern void SherpaOnnxOfflineSpeakerDiarizationSetConfig(IntPtr handle, ref OfflineSpeakerDiarizationConfig config); } } diff --git a/scripts/go/sherpa_onnx.go b/scripts/go/sherpa_onnx.go index b8b9e6ee2..8055380c6 100644 --- a/scripts/go/sherpa_onnx.go +++ b/scripts/go/sherpa_onnx.go @@ -1276,6 +1276,16 @@ func (sd *OfflineSpeakerDiarization) SampleRate() int { return int(C.SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(sd.impl)) } +// only config.Clustering is used. 
All other fields are ignored +func (sd *OfflineSpeakerDiarization) SetConfig(config *OfflineSpeakerDiarizationConfig) { + c := C.struct_SherpaOnnxOfflineSpeakerDiarizationConfig{} + + c.clustering.num_clusters = C.int(config.Clustering.NumClusters) + c.clustering.threshold = C.float(config.Clustering.Threshold) + + C.SherpaOnnxOfflineSpeakerDiarizationSetConfig(sd.impl, &c) +} + type OfflineSpeakerDiarizationSegment struct { Start float32 End float32 diff --git a/scripts/node-addon-api/lib/non-streaming-speaker-diarization.js b/scripts/node-addon-api/lib/non-streaming-speaker-diarization.js index ae5158517..8ec31ee10 100644 --- a/scripts/node-addon-api/lib/non-streaming-speaker-diarization.js +++ b/scripts/node-addon-api/lib/non-streaming-speaker-diarization.js @@ -25,6 +25,11 @@ class OfflineSpeakerDiarization { process(samples) { return addon.offlineSpeakerDiarizationProcess(this.handle, samples); } + + setConfig(config) { + addon.offlineSpeakerDiarizationSetConfig(this.handle, config); + this.config.clustering = config.clustering; + } } module.exports = { diff --git a/scripts/node-addon-api/src/non-streaming-speaker-diarization.cc b/scripts/node-addon-api/src/non-streaming-speaker-diarization.cc index d8ac9033c..56767476a 100644 --- a/scripts/node-addon-api/src/non-streaming-speaker-diarization.cc +++ b/scripts/node-addon-api/src/non-streaming-speaker-diarization.cc @@ -251,6 +251,46 @@ static Napi::Array OfflineSpeakerDiarizationProcessWrapper( return ans; } +static void OfflineSpeakerDiarizationSetConfigWrapper( + const Napi::CallbackInfo &info) { + Napi::Env env = info.Env(); + + if (info.Length() != 2) { + std::ostringstream os; + os << "Expect only 2 arguments. 
Given: " << info.Length(); + + Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException(); + + return; + } + + if (!info[0].IsExternal()) { + Napi::TypeError::New( + env, "Argument 0 should be an offline speaker diarization pointer.") + .ThrowAsJavaScriptException(); + + return; + } + + const SherpaOnnxOfflineSpeakerDiarization *sd = + info[0].As<Napi::External<SherpaOnnxOfflineSpeakerDiarization>>().Data(); + + if (!info[1].IsObject()) { + Napi::TypeError::New(env, "Expect an object as the argument") + .ThrowAsJavaScriptException(); + + return; + } + + Napi::Object o = info[1].As<Napi::Object>(); + + SherpaOnnxOfflineSpeakerDiarizationConfig c; + memset(&c, 0, sizeof(c)); + + c.clustering = GetFastClusteringConfig(o); + SherpaOnnxOfflineSpeakerDiarizationSetConfig(sd, &c); +} + void InitNonStreamingSpeakerDiarization(Napi::Env env, Napi::Object exports) { exports.Set(Napi::String::New(env, "createOfflineSpeakerDiarization"), Napi::Function::New(env, CreateOfflineSpeakerDiarizationWrapper)); @@ -262,4 +302,8 @@ void InitNonStreamingSpeakerDiarization(Napi::Env env, Napi::Object exports) { exports.Set( Napi::String::New(env, "offlineSpeakerDiarizationProcess"), Napi::Function::New(env, OfflineSpeakerDiarizationProcessWrapper)); + + exports.Set( + Napi::String::New(env, "offlineSpeakerDiarizationSetConfig"), + Napi::Function::New(env, OfflineSpeakerDiarizationSetConfigWrapper)); } diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 322c4f79e..abcfc5b82 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -1749,6 +1749,20 @@ int32_t SherpaOnnxOfflineSpeakerDiarizationGetSampleRate( return sd->impl->SampleRate(); } +void SherpaOnnxOfflineSpeakerDiarizationSetConfig( + const SherpaOnnxOfflineSpeakerDiarization *sd, + const SherpaOnnxOfflineSpeakerDiarizationConfig *config) { + sherpa_onnx::OfflineSpeakerDiarizationConfig sd_config; + + sd_config.clustering.num_clusters = + SHERPA_ONNX_OR(config->clustering.num_clusters, -1); + + sd_config.clustering.threshold = + 
SHERPA_ONNX_OR(config->clustering.threshold, 0.5); + + sd->impl->SetConfig(sd_config); +} + int32_t SherpaOnnxOfflineSpeakerDiarizationResultGetNumSpeakers( const SherpaOnnxOfflineSpeakerDiarizationResult *r) { return r->impl.NumSpeakers(); diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index d378dedec..c9e7f9ee1 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -1449,6 +1449,11 @@ SHERPA_ONNX_API void SherpaOnnxDestroyOfflineSpeakerDiarization( SHERPA_ONNX_API int32_t SherpaOnnxOfflineSpeakerDiarizationGetSampleRate( const SherpaOnnxOfflineSpeakerDiarization *sd); +// Only config->clustering is used. All other fields are ignored +SHERPA_ONNX_API void SherpaOnnxOfflineSpeakerDiarizationSetConfig( + const SherpaOnnxOfflineSpeakerDiarization *sd, + const SherpaOnnxOfflineSpeakerDiarizationConfig *config); + SHERPA_ONNX_API typedef struct SherpaOnnxOfflineSpeakerDiarizationResult SherpaOnnxOfflineSpeakerDiarizationResult; diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-impl.h index f7fe39499..3aed9d72f 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-impl.h @@ -20,6 +20,10 @@ class OfflineSpeakerDiarizationImpl { virtual int32_t SampleRate() const = 0; + // Note: Only config.clustering is used. 
All other fields in config are + // ignored + virtual void SetConfig(const OfflineSpeakerDiarizationConfig &config) = 0; + virtual OfflineSpeakerDiarizationResult Process( const float *audio, int32_t n, OfflineSpeakerDiarizationProgressCallback callback = nullptr, diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index 64b087c00..9667088d5 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -60,7 +60,7 @@ class OfflineSpeakerDiarizationPyannoteImpl : config_(config), segmentation_model_(config_.segmentation), embedding_extractor_(config_.embedding), - clustering_(config_.clustering) { + clustering_(std::make_unique<FastClustering>(config_.clustering)) { Init(); } @@ -70,6 +70,15 @@ class OfflineSpeakerDiarizationPyannoteImpl return meta_data.sample_rate; } + void SetConfig(const OfflineSpeakerDiarizationConfig &config) override { + if (!config.clustering.Validate()) { + SHERPA_ONNX_LOGE("Invalid clustering config. 
Skip it"); + return; + } + clustering_ = std::make_unique<FastClustering>(config.clustering); + config_.clustering = config.clustering; + } + OfflineSpeakerDiarizationResult Process( const float *audio, int32_t n, OfflineSpeakerDiarizationProgressCallback callback = nullptr, @@ -105,7 +114,7 @@ class OfflineSpeakerDiarizationPyannoteImpl ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second, std::move(callback), callback_arg); - std::vector<int32_t> cluster_labels = clustering_.Cluster( + std::vector<int32_t> cluster_labels = clustering_->Cluster( &embeddings(0, 0), embeddings.rows(), embeddings.cols()); int32_t max_cluster_index = @@ -636,7 +645,7 @@ class OfflineSpeakerDiarizationPyannoteImpl OfflineSpeakerDiarizationConfig config_; OfflineSpeakerSegmentationPyannoteModel segmentation_model_; SpeakerEmbeddingExtractor embedding_extractor_; - FastClustering clustering_; + std::unique_ptr<FastClustering> clustering_; Matrix2DInt32 powerset_mapping_; }; diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.cc b/sherpa-onnx/csrc/offline-speaker-diarization.cc index 4748b1cb4..00733bfb2 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization.cc @@ -79,6 +79,11 @@ int32_t OfflineSpeakerDiarization::SampleRate() const { return impl_->SampleRate(); } +void OfflineSpeakerDiarization::SetConfig( + const OfflineSpeakerDiarizationConfig &config) { + impl_->SetConfig(config); +} + OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process( const float *audio, int32_t n, OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/, diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.h b/sherpa-onnx/csrc/offline-speaker-diarization.h index e5d02c473..376e5f975 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization.h @@ -62,6 +62,10 @@ class OfflineSpeakerDiarization { // Expected sample rate of the input audio samples int32_t SampleRate() const; + // Note: Only config.clustering is 
used. All other fields in config are + // ignored + void SetConfig(const OfflineSpeakerDiarizationConfig &config); + OfflineSpeakerDiarizationResult Process( const float *audio, int32_t n, OfflineSpeakerDiarizationProgressCallback callback = nullptr, diff --git a/sherpa-onnx/python/csrc/offline-speaker-diarization.cc b/sherpa-onnx/python/csrc/offline-speaker-diarization.cc index c77979b3c..b3a332f70 100644 --- a/sherpa-onnx/python/csrc/offline-speaker-diarization.cc +++ b/sherpa-onnx/python/csrc/offline-speaker-diarization.cc @@ -68,6 +68,7 @@ void PybindOfflineSpeakerDiarization(py::module *m) { .def(py::init(), py::arg("config")) .def_property_readonly("sample_rate", &PyClass::SampleRate) + .def("set_config", &PyClass::SetConfig, py::arg("config")) .def( "process", [](const PyClass &self, const std::vector samples, diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift index 881291fd6..783f59224 100644 --- a/swift-api-examples/SherpaOnnx.swift +++ b/swift-api-examples/SherpaOnnx.swift @@ -1161,6 +1161,11 @@ class SherpaOnnxOfflineSpeakerDiarizationWrapper { return Int(SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(impl)) } + // only config.clustering is used. All other fields are ignored + func setConfig(config: UnsafePointer!) 
{ + SherpaOnnxOfflineSpeakerDiarizationSetConfig(impl, config) + } + func process(samples: [Float]) -> [SherpaOnnxOfflineSpeakerDiarizationSegmentWrapper] { let result = SherpaOnnxOfflineSpeakerDiarizationProcess( impl, samples, Int32(samples.count)) diff --git a/wasm/CMakeLists.txt b/wasm/CMakeLists.txt index b143e57b8..7dd6ce7b5 100644 --- a/wasm/CMakeLists.txt +++ b/wasm/CMakeLists.txt @@ -18,6 +18,10 @@ if(SHERPA_ONNX_ENABLE_WASM_VAD_ASR) add_subdirectory(vad-asr) endif() +if(SHERPA_ONNX_ENABLE_WASM_SPEAKER_DIARIZATION) + add_subdirectory(speaker-diarization) +endif() + if(SHERPA_ONNX_ENABLE_WASM_NODEJS) add_subdirectory(nodejs) endif() diff --git a/wasm/speaker-diarization/CMakeLists.txt b/wasm/speaker-diarization/CMakeLists.txt new file mode 100644 index 000000000..71af018ac --- /dev/null +++ b/wasm/speaker-diarization/CMakeLists.txt @@ -0,0 +1,61 @@ +if(NOT $ENV{SHERPA_ONNX_IS_USING_BUILD_WASM_SH}) + message(FATAL_ERROR "Please use ./build-wasm-simd-speaker-diarization.sh to build for WASM for speaker diarization") +endif() + +if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/assets/segmentation.onnx" OR NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/assets/embedding.onnx") + message(FATAL_ERROR "Please read ${CMAKE_CURRENT_SOURCE_DIR}/assets/README.md before you continue") +endif() + +set(exported_functions + MyPrint + SherpaOnnxCreateOfflineSpeakerDiarization + SherpaOnnxDestroyOfflineSpeakerDiarization + SherpaOnnxOfflineSpeakerDiarizationDestroyResult + SherpaOnnxOfflineSpeakerDiarizationDestroySegment + SherpaOnnxOfflineSpeakerDiarizationGetSampleRate + SherpaOnnxOfflineSpeakerDiarizationProcess + SherpaOnnxOfflineSpeakerDiarizationProcessWithCallback + SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments + SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime + SherpaOnnxOfflineSpeakerDiarizationSetConfig +) +set(mangled_exported_functions) +foreach(x IN LISTS exported_functions) + list(APPEND mangled_exported_functions "_${x}") +endforeach() +list(JOIN 
mangled_exported_functions "," all_exported_functions) + + +include_directories(${CMAKE_SOURCE_DIR}) +set(MY_FLAGS " -s FORCE_FILESYSTEM=1 -s INITIAL_MEMORY=512MB -s ALLOW_MEMORY_GROWTH=1") +string(APPEND MY_FLAGS " -sSTACK_SIZE=10485760 ") # 10MB +string(APPEND MY_FLAGS " -sEXPORTED_FUNCTIONS=[_CopyHeap,_malloc,_free,${all_exported_functions}] ") +string(APPEND MY_FLAGS "--preload-file ${CMAKE_CURRENT_SOURCE_DIR}/assets@. ") +string(APPEND MY_FLAGS " -sEXPORTED_RUNTIME_METHODS=['ccall','stringToUTF8','setValue','getValue','lengthBytesUTF8','UTF8ToString'] ") + +message(STATUS "MY_FLAGS: ${MY_FLAGS}") + +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${MY_FLAGS}") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MY_FLAGS}") +set(CMAKE_EXECUTBLE_LINKER_FLAGS "${CMAKE_EXECUTBLE_LINKER_FLAGS} ${MY_FLAGS}") + +if (NOT CMAKE_EXECUTABLE_SUFFIX STREQUAL ".js") + message(FATAL_ERROR "The default suffix for building executables should be .js!") +endif() +# set(CMAKE_EXECUTABLE_SUFFIX ".html") + +add_executable(sherpa-onnx-wasm-main-speaker-diarization sherpa-onnx-wasm-main-speaker-diarization.cc) +target_link_libraries(sherpa-onnx-wasm-main-speaker-diarization sherpa-onnx-c-api) +install(TARGETS sherpa-onnx-wasm-main-speaker-diarization DESTINATION bin/wasm/speaker-diarization) + +install( + FILES + "$<TARGET_FILE_DIR:sherpa-onnx-wasm-main-speaker-diarization>/sherpa-onnx-wasm-main-speaker-diarization.js" + "index.html" + "sherpa-onnx-speaker-diarization.js" + "app-speaker-diarization.js" + "$<TARGET_FILE_DIR:sherpa-onnx-wasm-main-speaker-diarization>/sherpa-onnx-wasm-main-speaker-diarization.wasm" + "$<TARGET_FILE_DIR:sherpa-onnx-wasm-main-speaker-diarization>/sherpa-onnx-wasm-main-speaker-diarization.data" + DESTINATION + bin/wasm/speaker-diarization +) diff --git a/wasm/speaker-diarization/app-speaker-diarization.js b/wasm/speaker-diarization/app-speaker-diarization.js new file mode 100644 index 000000000..cb757fcfd --- /dev/null +++ b/wasm/speaker-diarization/app-speaker-diarization.js @@ -0,0 +1,124 @@ +const startBtn = document.getElementById('startBtn'); +const hint = document.getElementById('hint'); +const numClustersInput = 
document.getElementById('numClustersInputID'); +const thresholdInput = document.getElementById('thresholdInputID'); +const textArea = document.getElementById('text'); + +const fileSelectCtrl = document.getElementById('file'); + +let sd = null; +let float32Samples = null; + +Module = {}; +Module.onRuntimeInitialized = function() { + console.log('Model files downloaded!'); + + console.log('Initializing speaker diarization ......'); + sd = createOfflineSpeakerDiarization(Module) + console.log('sampleRate', sd.sampleRate); + + hint.innerText = + 'Initialized! Please select a wave file and click the Start button.'; + + fileSelectCtrl.disabled = false; +}; + +function onFileChange() { + var files = document.getElementById('file').files; + + if (files.length == 0) { + console.log('No file selected'); + float32Samples = null; + startBtn.disabled = true; + return; + } + textArea.value = ''; + + console.log('files: ' + files); + + const file = files[0]; + console.log(file); + console.log('file.name ' + file.name); + console.log('file.type ' + file.type); + console.log('file.size ' + file.size); + + let audioCtx = new AudioContext({sampleRate: sd.sampleRate}); + + let reader = new FileReader(); + reader.onload = function() { + console.log('reading file!'); + audioCtx.decodeAudioData(reader.result, decodedDone); + }; + + function decodedDone(decoded) { + let typedArray = new Float32Array(decoded.length); + float32Samples = decoded.getChannelData(0); + + startBtn.disabled = false; + } + + reader.readAsArrayBuffer(file); +} + +startBtn.onclick = function() { + textArea.value = ''; + if (float32Samples == null) { + alert('Empty audio samples!'); + + startBtn.disabled = true; + return; + } + + let numClusters = numClustersInput.value; + if (numClusters.trim().length == 0) { + alert( + 'Please provide numClusters. 
Use -1 if you are not sure how many speakers are there'); + return; + } + + if (!numClusters.match(/^-?\d+$/)) { + alert(`number of clusters ${ + numClusters} is not an integer .\nPlease enter an integer`); + return; + } + numClusters = parseInt(numClusters, 10); + if (numClusters < -1) { + alert(`Number of clusters should be >= -1`); + return; + } + + let threshold = 0.5; + if (numClusters <= 0) { + threshold = thresholdInput.value; + if (threshold.trim().length == 0) { + alert('Please provide a threshold.'); + return; + } + + threshold = parseFloat(threshold); + if (Number.isNaN(threshold) || threshold < 0) { + alert(`Please enter a positive threshold`); + return; + } + } + + let config = sd.config + config.clustering = {numClusters: numClusters, threshold: threshold}; + sd.setConfig(config); + let segments = sd.process(float32Samples); + if (segments == null) { + textArea.value = 'No speakers detected'; + return + } + + let s = ''; + let sep = ''; + + for (const seg of segments) { + // clang-format off + s += sep + `${seg.start.toFixed(2)} -- ${seg.end.toFixed(2)} speaker_${seg.speaker}` + // clang-format on + sep = '\n'; + } + textArea.value = s; +} diff --git a/wasm/speaker-diarization/assets/README.md b/wasm/speaker-diarization/assets/README.md new file mode 100644 index 000000000..5c06139e2 --- /dev/null +++ b/wasm/speaker-diarization/assets/README.md @@ -0,0 +1,30 @@ +# Introduction + +Please refer to +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models +to download a speaker segmentation model +and +refer to +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +to download a speaker embedding extraction model. + +Remember to rename the downloaded files. + +The following is an example. 
+ + +```bash +cd wasm/speaker-diarization/assets/ + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 +tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 +rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 +cp sherpa-onnx-pyannote-segmentation-3-0/model.onnx ./segmentation.onnx +rm -rf sherpa-onnx-pyannote-segmentation-3-0 + + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx +mv 3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ./embedding.onnx + + +``` diff --git a/wasm/speaker-diarization/index.html b/wasm/speaker-diarization/index.html new file mode 100644 index 000000000..55de8bd3b --- /dev/null +++ b/wasm/speaker-diarization/index.html @@ -0,0 +1,48 @@ + + + + + + Next-gen Kaldi WebAssembly with sherpa-onnx for Speaker Diarization + + + + +

+ Next-gen Kaldi + WebAssembly
+ Speaker Diarization
with sherpa-onnx +

+
+ Loading model ... ... +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ +
+ + + + + diff --git a/wasm/speaker-diarization/sherpa-onnx-speaker-diarization.js b/wasm/speaker-diarization/sherpa-onnx-speaker-diarization.js new file mode 100644 index 000000000..ccfc8373c --- /dev/null +++ b/wasm/speaker-diarization/sherpa-onnx-speaker-diarization.js @@ -0,0 +1,295 @@ + +function freeConfig(config, Module) { + if ('buffer' in config) { + Module._free(config.buffer); + } + + if ('config' in config) { + freeConfig(config.config, Module) + } + + if ('segmentation' in config) { + freeConfig(config.segmentation, Module) + } + + if ('embedding' in config) { + freeConfig(config.embedding, Module) + } + + if ('clustering' in config) { + freeConfig(config.clustering, Module) + } + + Module._free(config.ptr); +} + +function initSherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig( + config, Module) { + const modelLen = Module.lengthBytesUTF8(config.model || '') + 1; + const n = modelLen; + const buffer = Module._malloc(n); + + const len = 1 * 4; + const ptr = Module._malloc(len); + + let offset = 0; + Module.stringToUTF8(config.model || '', buffer + offset, modelLen); + offset += modelLen; + + offset = 0; + Module.setValue(ptr, buffer + offset, 'i8*'); + + return { + buffer: buffer, ptr: ptr, len: len, + } +} + +function initSherpaOnnxOfflineSpeakerSegmentationModelConfig(config, Module) { + if (!('pyannote' in config)) { + config.pyannote = { + model: '', + }; + } + + const pyannote = initSherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig( + config.pyannote, Module); + + const len = pyannote.len + 3 * 4; + const ptr = Module._malloc(len); + + let offset = 0; + Module._CopyHeap(pyannote.ptr, pyannote.len, ptr + offset); + offset += pyannote.len; + + Module.setValue(ptr + offset, config.numThreads || 1, 'i32'); + offset += 4; + + Module.setValue(ptr + offset, config.debug || 1, 'i32'); + offset += 4; + + const providerLen = Module.lengthBytesUTF8(config.provider || 'cpu') + 1; + const buffer = Module._malloc(providerLen); + 
Module.stringToUTF8(config.provider || 'cpu', buffer, providerLen); + Module.setValue(ptr + offset, buffer, 'i8*'); + + return { + buffer: buffer, + ptr: ptr, + len: len, + config: pyannote, + }; +} + +function initSherpaOnnxSpeakerEmbeddingExtractorConfig(config, Module) { + const modelLen = Module.lengthBytesUTF8(config.model || '') + 1; + const providerLen = Module.lengthBytesUTF8(config.provider || 'cpu') + 1; + const n = modelLen + providerLen; + const buffer = Module._malloc(n); + + const len = 4 * 4; + const ptr = Module._malloc(len); + + let offset = 0; + Module.stringToUTF8(config.model || '', buffer + offset, modelLen); + offset += modelLen; + + Module.stringToUTF8(config.provider || 'cpu', buffer + offset, providerLen); + offset += providerLen; + + offset = 0 + Module.setValue(ptr + offset, buffer, 'i8*'); + offset += 4; + + Module.setValue(ptr + offset, config.numThreads || 1, 'i32'); + offset += 4; + + Module.setValue(ptr + offset, config.debug || 1, 'i32'); + offset += 4; + + Module.setValue(ptr + offset, buffer + modelLen, 'i8*'); + offset += 4; + + return { + buffer: buffer, + ptr: ptr, + len: len, + }; +} + +function initSherpaOnnxFastClusteringConfig(config, Module) { + const len = 2 * 4; + const ptr = Module._malloc(len); + + let offset = 0; + Module.setValue(ptr + offset, config.numClusters || -1, 'i32'); + offset += 4; + + Module.setValue(ptr + offset, config.threshold || 0.5, 'float'); + offset += 4; + + return { + ptr: ptr, + len: len, + }; +} + +function initSherpaOnnxOfflineSpeakerDiarizationConfig(config, Module) { + if (!('segmentation' in config)) { + config.segmentation = { + pyannote: {model: ''}, + numThreads: 1, + debug: 0, + provider: 'cpu', + }; + } + + if (!('embedding' in config)) { + config.embedding = { + model: '', + numThreads: 1, + debug: 0, + provider: 'cpu', + }; + } + + if (!('clustering' in config)) { + config.clustering = { + numClusters: -1, + threshold: 0.5, + }; + } + + const segmentation = 
initSherpaOnnxOfflineSpeakerSegmentationModelConfig( + config.segmentation, Module); + + const embedding = + initSherpaOnnxSpeakerEmbeddingExtractorConfig(config.embedding, Module); + + const clustering = + initSherpaOnnxFastClusteringConfig(config.clustering, Module); + + const len = segmentation.len + embedding.len + clustering.len + 2 * 4; + const ptr = Module._malloc(len); + + let offset = 0; + Module._CopyHeap(segmentation.ptr, segmentation.len, ptr + offset); + offset += segmentation.len; + + Module._CopyHeap(embedding.ptr, embedding.len, ptr + offset); + offset += embedding.len; + + Module._CopyHeap(clustering.ptr, clustering.len, ptr + offset); + offset += clustering.len; + + Module.setValue(ptr + offset, config.minDurationOn || 0.2, 'float'); + offset += 4; + + Module.setValue(ptr + offset, config.minDurationOff || 0.5, 'float'); + offset += 4; + + return { + ptr: ptr, len: len, segmentation: segmentation, embedding: embedding, + clustering: clustering, + } +} + +class OfflineSpeakerDiarization { + constructor(configObj, Module) { + const config = + initSherpaOnnxOfflineSpeakerDiarizationConfig(configObj, Module) + // Module._MyPrint(config.ptr); + + const handle = + Module._SherpaOnnxCreateOfflineSpeakerDiarization(config.ptr); + + freeConfig(config, Module); + + this.handle = handle; + this.sampleRate = + Module._SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(this.handle); + this.Module = Module + + this.config = configObj; + } + + free() { + this.Module._SherpaOnnxDestroyOfflineSpeakerDiarization(this.handle); + this.handle = 0 + } + + setConfig(configObj) { + if (!('clustering' in configObj)) { + return; + } + + const config = + initSherpaOnnxOfflineSpeakerDiarizationConfig(configObj, this.Module); + + this.Module._SherpaOnnxOfflineSpeakerDiarizationSetConfig( + this.handle, config.ptr); + + freeConfig(config, this.Module); + + this.config.clustering = configObj.clustering; + } + + process(samples) { + const pointer = + this.Module._malloc(samples.length 
* samples.BYTES_PER_ELEMENT); + this.Module.HEAPF32.set(samples, pointer / samples.BYTES_PER_ELEMENT); + + let r = this.Module._SherpaOnnxOfflineSpeakerDiarizationProcess( + this.handle, pointer, samples.length); + this.Module._free(pointer); + + let numSegments = + this.Module._SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(r); + + let segments = + this.Module._SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime( + r); + + let ans = []; + + let sizeOfSegment = 3 * 4; + for (let i = 0; i < numSegments; ++i) { + let p = segments + i * sizeOfSegment + + let start = this.Module.HEAPF32[p / 4 + 0]; + let end = this.Module.HEAPF32[p / 4 + 1]; + let speaker = this.Module.HEAP32[p / 4 + 2]; + + ans.push({start: start, end: end, speaker: speaker}); + } + + this.Module._SherpaOnnxOfflineSpeakerDiarizationDestroySegment(segments); + this.Module._SherpaOnnxOfflineSpeakerDiarizationDestroyResult(r); + + return ans; + } +} + +function createOfflineSpeakerDiarization(Module, myConfig) { + const config = { + segmentation: { + pyannote: {model: './segmentation.onnx'}, + }, + embedding: {model: './embedding.onnx'}, + clustering: {numClusters: -1, threshold: 0.5}, + minDurationOn: 0.3, + minDurationOff: 0.5, + }; + + if (myConfig) { + config = myConfig; + } + + return new OfflineSpeakerDiarization(config, Module); +} + +if (typeof process == 'object' && typeof process.versions == 'object' && + typeof process.versions.node == 'string') { + module.exports = { + createOfflineSpeakerDiarization, + }; +} diff --git a/wasm/speaker-diarization/sherpa-onnx-wasm-main-speaker-diarization.cc b/wasm/speaker-diarization/sherpa-onnx-wasm-main-speaker-diarization.cc new file mode 100644 index 000000000..6e83f61d8 --- /dev/null +++ b/wasm/speaker-diarization/sherpa-onnx-wasm-main-speaker-diarization.cc @@ -0,0 +1,63 @@ +// wasm/sherpa-onnx-wasm-main-speaker-diarization.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include + +#include +#include + +#include 
"sherpa-onnx/c-api/c-api.h" + +// see also +// https://emscripten.org/docs/porting/connecting_cpp_and_javascript/Interacting-with-code.html + +extern "C" { + +static_assert(sizeof(SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig) == + 1 * 4, + ""); + +static_assert( + sizeof(SherpaOnnxOfflineSpeakerSegmentationModelConfig) == + sizeof(SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig) + 3 * 4, + ""); + +static_assert(sizeof(SherpaOnnxFastClusteringConfig) == 2 * 4, ""); + +static_assert(sizeof(SherpaOnnxSpeakerEmbeddingExtractorConfig) == 4 * 4, ""); + +static_assert(sizeof(SherpaOnnxOfflineSpeakerDiarizationConfig) == + sizeof(SherpaOnnxOfflineSpeakerSegmentationModelConfig) + + sizeof(SherpaOnnxSpeakerEmbeddingExtractorConfig) + + sizeof(SherpaOnnxFastClusteringConfig) + 2 * 4, + ""); + +void MyPrint(const SherpaOnnxOfflineSpeakerDiarizationConfig *sd_config) { + const auto &segmentation = sd_config->segmentation; + const auto &embedding = sd_config->embedding; + const auto &clustering = sd_config->clustering; + + fprintf(stdout, "----------segmentation config----------\n"); + fprintf(stdout, "pyannote model: %s\n", segmentation.pyannote.model); + fprintf(stdout, "num threads: %d\n", segmentation.num_threads); + fprintf(stdout, "debug: %d\n", segmentation.debug); + fprintf(stdout, "provider: %s\n", segmentation.provider); + + fprintf(stdout, "----------embedding config----------\n"); + fprintf(stdout, "model: %s\n", embedding.model); + fprintf(stdout, "num threads: %d\n", embedding.num_threads); + fprintf(stdout, "debug: %d\n", embedding.debug); + fprintf(stdout, "provider: %s\n", embedding.provider); + + fprintf(stdout, "----------clustering config----------\n"); + fprintf(stdout, "num_clusters: %d\n", clustering.num_clusters); + fprintf(stdout, "threshold: %.3f\n", clustering.threshold); + + fprintf(stdout, "min_duration_on: %.3f\n", sd_config->min_duration_on); + fprintf(stdout, "min_duration_off: %.3f\n", sd_config->min_duration_off); +} + +void 
CopyHeap(const char *src, int32_t num_bytes, char *dst) { + std::copy(src, src + num_bytes, dst); +} +}