Skip to content

Commit

Permalink
WebAssembly example for speaker diarization
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Oct 10, 2024
1 parent 67349b5 commit 246f257
Show file tree
Hide file tree
Showing 28 changed files with 830 additions and 16 deletions.
14 changes: 13 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" O
option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
option(SHERPA_ONNX_ENABLE_DIRECTML "Enable ONNX Runtime DirectML support" OFF)
option(SHERPA_ONNX_ENABLE_WASM "Whether to enable WASM" OFF)
option(SHERPA_ONNX_ENABLE_WASM_SPEAKER_DIARIZATION "Whether to enable WASM for speaker diarization" OFF)
option(SHERPA_ONNX_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF)
option(SHERPA_ONNX_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF)
option(SHERPA_ONNX_ENABLE_WASM_KWS "Whether to enable WASM for KWS" OFF)
Expand Down Expand Up @@ -135,6 +136,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_C_API ${SHERPA_ONNX_ENABLE_C_API}")
message(STATUS "SHERPA_ONNX_ENABLE_WEBSOCKET ${SHERPA_ONNX_ENABLE_WEBSOCKET}")
message(STATUS "SHERPA_ONNX_ENABLE_GPU ${SHERPA_ONNX_ENABLE_GPU}")
message(STATUS "SHERPA_ONNX_ENABLE_WASM ${SHERPA_ONNX_ENABLE_WASM}")
message(STATUS "SHERPA_ONNX_ENABLE_WASM_SPEAKER_DIARIZATION ${SHERPA_ONNX_ENABLE_WASM_SPEAKER_DIARIZATION}")
message(STATUS "SHERPA_ONNX_ENABLE_WASM_TTS ${SHERPA_ONNX_ENABLE_WASM_TTS}")
message(STATUS "SHERPA_ONNX_ENABLE_WASM_ASR ${SHERPA_ONNX_ENABLE_WASM_ASR}")
message(STATUS "SHERPA_ONNX_ENABLE_WASM_KWS ${SHERPA_ONNX_ENABLE_WASM_KWS}")
Expand Down Expand Up @@ -196,9 +198,19 @@ else()
add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=0)
endif()

# Sanity checks for the speaker-diarization WASM build: it can only be
# requested when both speaker diarization itself and WASM are enabled.
# Each violated precondition aborts configuration with a FATAL_ERROR.
if(SHERPA_ONNX_ENABLE_WASM_SPEAKER_DIARIZATION AND NOT SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
  message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION to ON if you want to build WASM for speaker diarization")
endif()

if(SHERPA_ONNX_ENABLE_WASM_SPEAKER_DIARIZATION AND NOT SHERPA_ONNX_ENABLE_WASM)
  message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_WASM to ON if you enable WASM for speaker diarization")
endif()

if(SHERPA_ONNX_ENABLE_WASM_TTS)
if(NOT SHERPA_ONNX_ENABLE_TTS)
message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_TTS to ON if you want to build wasm TTS")
message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_TTS to ON if you want to build WASM for TTS")
endif()

if(NOT SHERPA_ONNX_ENABLE_WASM)
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ We also have spaces built using WebAssembly. They are listed below:
| Speaker identification (Speaker ID) | [Address][sid-models] |
| Spoken language identification (Language ID)| See multi-lingual [Whisper][Whisper] ASR models from [Speech recognition][asr-models]|
| Punctuation | [Address][punct-models] |
| Speaker segmentation | [Address][speaker-segmentation-models] |

### Useful links

Expand Down Expand Up @@ -303,5 +304,6 @@ Video demo in Chinese: [爆了!炫神教你开打字挂!真正影响胜率
[sid-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
[slid-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
[punct-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models
[speaker-segmentation-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
[GigaSpeech]: https://github.com/SpeechColab/GigaSpeech
[WenetSpeech]: https://github.com/wenet-e2e/WenetSpeech
4 changes: 2 additions & 2 deletions build-wasm-simd-asr.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ if [ x"$EMSCRIPTEN" == x"" ]; then
echo "git clone https://github.com/emscripten-core/emsdk.git"
echo "cd emsdk"
echo "git pull"
echo "./emsdk install latest"
echo "./emsdk activate latest"
echo "./emsdk install 3.1.53"
echo "./emsdk activate 3.1.53"
echo "source ./emsdk_env.sh"
exit 1
else
Expand Down
4 changes: 2 additions & 2 deletions build-wasm-simd-kws.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ if [ x"$EMSCRIPTEN" == x"" ]; then
echo "git clone https://github.com/emscripten-core/emsdk.git"
echo "cd emsdk"
echo "git pull"
echo "./emsdk install latest"
echo "./emsdk activate latest"
echo "./emsdk install 3.1.53"
echo "./emsdk activate 3.1.53"
echo "source ./emsdk_env.sh"
exit 1
else
Expand Down
4 changes: 2 additions & 2 deletions build-wasm-simd-nodejs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ if [ x"$EMSCRIPTEN" == x"" ]; then
echo "git clone https://github.com/emscripten-core/emsdk.git"
echo "cd emsdk"
echo "git pull"
echo "./emsdk install latest"
echo "./emsdk activate latest"
echo "./emsdk install 3.1.53"
echo "./emsdk activate 3.1.53"
echo "source ./emsdk_env.sh"
exit 1
else
Expand Down
61 changes: 61 additions & 0 deletions build-wasm-simd-speaker-diarization.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#!/usr/bin/env bash
# Copyright (c) 2024 Xiaomi Corporation
#
# This script is to build sherpa-onnx for WebAssembly (speaker diarization)
#
# It locates the emscripten toolchain (from $EMSCRIPTEN, or from the emcc
# found in PATH), configures the project with CMake inside
# ./build-wasm-simd-speaker-diarization, then builds and installs it.
#
# All path expansions are quoted so that an emsdk installed in a directory
# whose name contains spaces still works.

set -ex

# Use $EMSCRIPTEN when the caller provides it; otherwise derive it from emcc.
if [ x"$EMSCRIPTEN" == x"" ]; then
  if ! command -v emcc &> /dev/null; then
    echo "Please install emscripten first"
    echo ""
    echo "You can use the following commands to install it:"
    echo ""
    echo "git clone https://github.com/emscripten-core/emsdk.git"
    echo "cd emsdk"
    echo "git pull"
    echo "./emsdk install 3.1.53"
    echo "./emsdk activate 3.1.53"
    echo "source ./emsdk_env.sh"
    exit 1
  else
    # Resolve symlinks so we end up in the real emscripten directory.
    EMSCRIPTEN="$(dirname "$(realpath "$(which emcc)")")"
  fi
fi

export EMSCRIPTEN="$EMSCRIPTEN"
echo "EMSCRIPTEN: $EMSCRIPTEN"

# The CMake toolchain file shipped with emscripten is required below.
if [ ! -f "$EMSCRIPTEN/cmake/Modules/Platform/Emscripten.cmake" ]; then
  echo "Cannot find $EMSCRIPTEN/cmake/Modules/Platform/Emscripten.cmake"
  echo "Please make sure you have installed emsdk correctly"
  exit 1
fi

mkdir -p build-wasm-simd-speaker-diarization
pushd build-wasm-simd-speaker-diarization

# NOTE(review): presumably read by the project's CMake files to detect that
# the build is driven by this script -- confirm against CMakeLists.txt.
export SHERPA_ONNX_IS_USING_BUILD_WASM_SH=ON

cmake \
  -DCMAKE_INSTALL_PREFIX=./install \
  -DCMAKE_BUILD_TYPE=Release \
  -DCMAKE_TOOLCHAIN_FILE="$EMSCRIPTEN/cmake/Modules/Platform/Emscripten.cmake" \
  \
  -DSHERPA_ONNX_ENABLE_PYTHON=OFF \
  -DSHERPA_ONNX_ENABLE_TESTS=OFF \
  -DSHERPA_ONNX_ENABLE_CHECK=OFF \
  -DBUILD_SHARED_LIBS=OFF \
  -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
  -DSHERPA_ONNX_ENABLE_JNI=OFF \
  -DSHERPA_ONNX_ENABLE_C_API=ON \
  -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \
  -DSHERPA_ONNX_ENABLE_GPU=OFF \
  -DSHERPA_ONNX_ENABLE_WASM=ON \
  -DSHERPA_ONNX_ENABLE_WASM_SPEAKER_DIARIZATION=ON \
  -DSHERPA_ONNX_ENABLE_BINARY=OFF \
  -DSHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY=OFF \
  ..
make -j2
make install

# Show what was installed for the speaker-diarization WASM demo.
ls -lh install/bin/wasm/speaker-diarization
4 changes: 2 additions & 2 deletions build-wasm-simd-tts.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ if [ x"$EMSCRIPTEN" == x"" ]; then
echo "git clone https://github.com/emscripten-core/emsdk.git"
echo "cd emsdk"
echo "git pull"
echo "./emsdk install latest"
echo "./emsdk activate latest"
echo "./emsdk install 3.1.53"
echo "./emsdk activate 3.1.53"
echo "source ./emsdk_env.sh"
exit 1
else
Expand Down
4 changes: 2 additions & 2 deletions build-wasm-simd-vad-asr.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ if [ x"$EMSCRIPTEN" == x"" ]; then
echo "git clone https://github.com/emscripten-core/emsdk.git"
echo "cd emsdk"
echo "git pull"
echo "./emsdk install latest"
echo "./emsdk activate latest"
echo "./emsdk install 3.1.53"
echo "./emsdk activate 3.1.53"
echo "source ./emsdk_env.sh"
exit 1
else
Expand Down
4 changes: 2 additions & 2 deletions build-wasm-simd-vad.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ if [ x"$EMSCRIPTEN" == x"" ]; then
echo "git clone https://github.com/emscripten-core/emsdk.git"
echo "cd emsdk"
echo "git pull"
echo "./emsdk install latest"
echo "./emsdk activate latest"
echo "./emsdk install 3.1.53"
echo "./emsdk activate 3.1.53"
echo "source ./emsdk_env.sh"
exit 1
else
Expand Down
8 changes: 8 additions & 0 deletions scripts/dotnet/OfflineSpeakerDiarization.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ public OfflineSpeakerDiarization(OfflineSpeakerDiarizationConfig config)
_handle = new HandleRef(this, h);
}

/// <summary>
/// Forwards <paramref name="config"/> to the native
/// SherpaOnnxOfflineSpeakerDiarizationSetConfig function for this instance.
/// </summary>
/// <param name="config">Configuration passed by reference to the native call.</param>
public void SetConfig(OfflineSpeakerDiarizationConfig config) =>
    SherpaOnnxOfflineSpeakerDiarizationSetConfig(_handle.Handle, ref config);

public OfflineSpeakerDiarizationSegment[] Process(float[] samples)
{
IntPtr result = SherpaOnnxOfflineSpeakerDiarizationProcess(_handle.Handle, samples, samples.Length);
Expand Down Expand Up @@ -117,6 +122,9 @@ public int SampleRate

[DllImport(Dll.Filename)]
private static extern void SherpaOnnxOfflineSpeakerDiarizationDestroySegment(IntPtr handle);

[DllImport(Dll.Filename)]
private static extern void SherpaOnnxOfflineSpeakerDiarizationSetConfig(IntPtr handle, ref OfflineSpeakerDiarizationConfig config);
}
}

10 changes: 10 additions & 0 deletions scripts/go/sherpa_onnx.go
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,16 @@ func (sd *OfflineSpeakerDiarization) SampleRate() int {
return int(C.SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(sd.impl))
}

// SetConfig updates the clustering settings of an existing diarization object.
//
// Only config.Clustering is used. All other fields are ignored.
func (sd *OfflineSpeakerDiarization) SetConfig(config *OfflineSpeakerDiarizationConfig) {
	// Start from a zero-valued C struct and fill in only the clustering part.
	c := C.struct_SherpaOnnxOfflineSpeakerDiarizationConfig{}

	c.clustering.num_clusters = C.int(config.Clustering.NumClusters)
	c.clustering.threshold = C.float(config.Clustering.Threshold)

	// C API functions are only reachable through the cgo "C" pseudo-package;
	// without the qualifier the identifier is undefined in Go.
	C.SherpaOnnxOfflineSpeakerDiarizationSetConfig(sd.impl, &c)
}

type OfflineSpeakerDiarizationSegment struct {
Start float32
End float32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ class OfflineSpeakerDiarization {
process(samples) {
return addon.offlineSpeakerDiarizationProcess(this.handle, samples);
}

setConfig(config) {
addon.offlineSpeakerDiarizationSetConfig(config);
this.config.clustering = config.clustering;
}
}

module.exports = {
Expand Down
44 changes: 44 additions & 0 deletions scripts/node-addon-api/src/non-streaming-speaker-diarization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,46 @@ static Napi::Array OfflineSpeakerDiarizationProcessWrapper(
return ans;
}

// JavaScript binding for SherpaOnnxOfflineSpeakerDiarizationSetConfig().
//
// Expected JavaScript arguments:
//   info[0]: external pointer to a SherpaOnnxOfflineSpeakerDiarization
//   info[1]: config object; its clustering part is extracted with
//            GetFastClusteringConfig()
//
// Throws a JavaScript TypeError and returns without touching the native
// object if the argument count or argument types are wrong.
static void OfflineSpeakerDiarizationSetConfigWrapper(
    const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();

  if (info.Length() != 2) {
    std::ostringstream os;
    os << "Expect only 2 arguments. Given: " << info.Length();

    Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();

    return;
  }

  if (!info[0].IsExternal()) {
    Napi::TypeError::New(
        env, "Argument 0 should be an offline speaker diarization pointer.")
        .ThrowAsJavaScriptException();

    return;
  }

  const SherpaOnnxOfflineSpeakerDiarization *sd =
      info[0].As<Napi::External<SherpaOnnxOfflineSpeakerDiarization>>().Data();

  if (!info[1].IsObject()) {
    Napi::TypeError::New(env, "Expect an object as the argument")
        .ThrowAsJavaScriptException();

    return;
  }

  // The config object is the second argument (info[1]), which was validated
  // just above. info[0] is the external handle, not the config.
  Napi::Object o = info[1].As<Napi::Object>();

  // Zero everything so that fields other than clustering are left unset.
  SherpaOnnxOfflineSpeakerDiarizationConfig c;
  memset(&c, 0, sizeof(c));

  c.clustering = GetFastClusteringConfig(o);
  SherpaOnnxOfflineSpeakerDiarizationSetConfig(sd, &c);
}

void InitNonStreamingSpeakerDiarization(Napi::Env env, Napi::Object exports) {
exports.Set(Napi::String::New(env, "createOfflineSpeakerDiarization"),
Napi::Function::New(env, CreateOfflineSpeakerDiarizationWrapper));
Expand All @@ -262,4 +302,8 @@ void InitNonStreamingSpeakerDiarization(Napi::Env env, Napi::Object exports) {
exports.Set(
Napi::String::New(env, "offlineSpeakerDiarizationProcess"),
Napi::Function::New(env, OfflineSpeakerDiarizationProcessWrapper));

exports.Set(
Napi::String::New(env, "offlineSpeakerDiarizationSetConfig"),
Napi::Function::New(env, OfflineSpeakerDiarizationSetConfigWrapper));
}
14 changes: 14 additions & 0 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1749,6 +1749,20 @@ int32_t SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(
return sd->impl->SampleRate();
}

void SherpaOnnxOfflineSpeakerDiarizationSetConfig(
const SherpaOnnxOfflineSpeakerDiarization *sd,
const SherpaOnnxOfflineSpeakerDiarizationConfig *config) {
sherpa_onnx::OfflineSpeakerDiarizationConfig sd_config;

sd_config.clustering.num_clusters =
SHERPA_ONNX_OR(config->clustering.num_clusters, -1);

sd_config.clustering.threshold =
SHERPA_ONNX_OR(config->clustering.threshold, 0.5);

sd->impl->SetConfig(sd_config);
}

int32_t SherpaOnnxOfflineSpeakerDiarizationResultGetNumSpeakers(
const SherpaOnnxOfflineSpeakerDiarizationResult *r) {
return r->impl.NumSpeakers();
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1449,6 +1449,11 @@ SHERPA_ONNX_API void SherpaOnnxDestroyOfflineSpeakerDiarization(
SHERPA_ONNX_API int32_t SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(
const SherpaOnnxOfflineSpeakerDiarization *sd);

// Only config->clustering is used. All other fields are ignored
SHERPA_ONNX_API void SherpaOnnxOfflineSpeakerDiarizationSetConfig(
const SherpaOnnxOfflineSpeakerDiarization *sd,
const SherpaOnnxOfflineSpeakerDiarizationConfig *config);

SHERPA_ONNX_API typedef struct SherpaOnnxOfflineSpeakerDiarizationResult
SherpaOnnxOfflineSpeakerDiarizationResult;

Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ class OfflineSpeakerDiarizationImpl {

virtual int32_t SampleRate() const = 0;

// Note: Only config.clustering is used. All other fields in config are
// ignored
virtual void SetConfig(const OfflineSpeakerDiarizationConfig &config) = 0;

virtual OfflineSpeakerDiarizationResult Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
Expand Down
15 changes: 12 additions & 3 deletions sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
: config_(config),
segmentation_model_(config_.segmentation),
embedding_extractor_(config_.embedding),
clustering_(config_.clustering) {
clustering_(std::make_unique<FastClustering>(config_.clustering)) {
Init();
}

Expand All @@ -70,6 +70,15 @@ class OfflineSpeakerDiarizationPyannoteImpl
return meta_data.sample_rate;
}

// Replaces the clustering stage with one built from config.clustering and
// records the new clustering settings in config_.
//
// If config.clustering fails Validate(), an error is logged and neither
// clustering_ nor config_ is modified. Other fields of config are not read.
void SetConfig(const OfflineSpeakerDiarizationConfig &config) override {
  const auto &new_clustering = config.clustering;

  if (new_clustering.Validate()) {
    clustering_ = std::make_unique<FastClustering>(new_clustering);
    config_.clustering = new_clustering;
    return;
  }

  SHERPA_ONNX_LOGE("Invalid clustering config. Skip it");
}

OfflineSpeakerDiarizationResult Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
Expand Down Expand Up @@ -105,7 +114,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
std::move(callback), callback_arg);

std::vector<int32_t> cluster_labels = clustering_.Cluster(
std::vector<int32_t> cluster_labels = clustering_->Cluster(
&embeddings(0, 0), embeddings.rows(), embeddings.cols());

int32_t max_cluster_index =
Expand Down Expand Up @@ -636,7 +645,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
OfflineSpeakerDiarizationConfig config_;
OfflineSpeakerSegmentationPyannoteModel segmentation_model_;
SpeakerEmbeddingExtractor embedding_extractor_;
FastClustering clustering_;
std::unique_ptr<FastClustering> clustering_;
Matrix2DInt32 powerset_mapping_;
};

Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ int32_t OfflineSpeakerDiarization::SampleRate() const {
return impl_->SampleRate();
}

// Forwards the new configuration to the underlying implementation object.
void OfflineSpeakerDiarization::SetConfig(
const OfflineSpeakerDiarizationConfig &config) {
impl_->SetConfig(config);
}

OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/,
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ class OfflineSpeakerDiarization {
// Expected sample rate of the input audio samples
int32_t SampleRate() const;

// Note: Only config.clustering is used. All other fields in config are
// ignored
void SetConfig(const OfflineSpeakerDiarizationConfig &config);

OfflineSpeakerDiarizationResult Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
Expand Down
Loading

0 comments on commit 246f257

Please sign in to comment.