Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Go API for speaker diarization #1403

Merged
merged 2 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions .github/workflows/test-go-package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,50 @@ jobs:
run: |
gcc --version

# Non-Windows runners: run.sh downloads the models, builds and runs the example.
# The platform suffix keeps this step distinguishable from the two Windows
# variants that follow it in the job log.
- name: Test non-streaming speaker diarization (Linux/macOS)
  if: matrix.os != 'windows-latest'
  shell: bash
  run: |
    cd go-api-examples/non-streaming-speaker-diarization/
    ./run.sh

# Windows x64: build first so the module cache is populated, then copy the
# prebuilt x86_64 DLLs from the sherpa-onnx-go-windows module into the
# current directory before running the example.
- name: Test non-streaming speaker diarization (Win64)
  if: matrix.os == 'windows-latest' && matrix.arch == 'x64'
  shell: bash
  run: |
    cd go-api-examples/non-streaming-speaker-diarization/
    go mod tidy
    cat go.mod
    go build

    echo $PWD
    ls -lh /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/
    ls -lh /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/*
    cp -v /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/sherpa-onnx-go-windows*/lib/x86_64-pc-windows-gnu/*.dll .

    ./run.sh

# Windows x86: switch the Go toolchain to GOARCH=386 with cgo enabled, build,
# then copy the prebuilt i686 DLLs from the sherpa-onnx-go-windows module into
# the current directory before running the example.
- name: Test non-streaming speaker diarization (Win32)
  if: matrix.os == 'windows-latest' && matrix.arch == 'x86'
  shell: bash
  run: |
    cd go-api-examples/non-streaming-speaker-diarization/

    go env GOARCH
    go env -w GOARCH=386
    go env -w CGO_ENABLED=1

    go mod tidy
    cat go.mod
    go build

    echo $PWD
    ls -lh /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/
    ls -lh /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/*
    cp -v /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/sherpa-onnx-go-windows*/lib/i686-pc-windows-gnu/*.dll .

    ./run.sh

- name: Test streaming HLG decoding (Linux/macOS)
if: matrix.os != 'windows-latest'
shell: bash
Expand Down
6 changes: 6 additions & 0 deletions .github/workflows/test-go.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ jobs:
name: ${{ matrix.os }}-libs
path: to-upload/

# Runs the example under scripts/go/_internal/, whose go.mod replaces
# github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx with the parent directory (../).
- name: Test non-streaming speaker diarization
shell: bash
run: |
cd scripts/go/_internal/non-streaming-speaker-diarization/
./run.sh

- name: Test speaker identification
shell: bash
run: |
Expand Down
3 changes: 3 additions & 0 deletions go-api-examples/non-streaming-speaker-diarization/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module non-streaming-speaker-diarization

go 1.12
87 changes: 87 additions & 0 deletions go-api-examples/non-streaming-speaker-diarization/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package main

import (
sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
"log"
)

/*
Usage:

Step 1: Download a speaker segmentation model

Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available models. The following is an example

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2

Step 2: Download a speaker embedding extractor model

Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
for a list of available models. The following is an example

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx

Step 3: Download test wave files

Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available test wave files. The following is an example

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav

Step 4: Run it
*/

// initSpeakerDiarization fills in the diarization config used by this example
// (segmentation model, embedding model and cluster count) and returns the
// diarizer created from it. The result may be nil; main checks for that.
func initSpeakerDiarization() *sherpa.OfflineSpeakerDiarization {
	var cfg sherpa.OfflineSpeakerDiarizationConfig

	// Model files are looked up relative to the current directory.
	cfg.Segmentation.Pyannote.Model = "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"
	cfg.Embedding.Model = "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"

	// The test wave file contains 4 speakers, so the cluster count is fixed
	// at 4.
	//
	// If the number of speakers in the wave file is unknown, do not set
	// NumClusters; set a threshold instead, e.g.
	//
	//   cfg.Clustering.Threshold = 0.5
	//
	// A larger Threshold leads to fewer clusters, a smaller one to more.
	cfg.Clustering.NumClusters = 4

	return sherpa.NewOfflineSpeakerDiarization(&cfg)
}

// main reads the test wave file, runs offline speaker diarization on it and
// logs one line per returned segment (start, end and speaker index).
//
// Every failure ends the process through log.Fatalf, i.e. with a non-zero exit
// status, so that a wrapping script or CI job can tell a failed run from a
// successful one.
func main() {
	waveFilename := "./0-four-speakers-zh.wav"

	wave := sherpa.ReadWave(waveFilename)
	if wave == nil {
		log.Fatalf("Failed to read %v", waveFilename)
	}

	sd := initSpeakerDiarization()
	if sd == nil {
		log.Fatalf("Please check your config")
	}

	// This check sits before the defer below on purpose: log.Fatalf exits
	// without running deferred calls, so the diarizer is released by hand.
	if expected := sd.SampleRate(); wave.SampleRate != expected {
		sherpa.DeleteOfflineSpeakerDiarization(sd)
		log.Fatalf("Expected sample rate: %v, given: %d\n", expected, wave.SampleRate)
	}

	defer sherpa.DeleteOfflineSpeakerDiarization(sd)

	log.Println("Started")
	for _, seg := range sd.Process(wave.Samples) {
		log.Printf("%.3f -- %.3f speaker_%02d\n", seg.Start, seg.End, seg.Speaker)
	}
}
20 changes: 20 additions & 0 deletions go-api-examples/non-streaming-speaker-diarization/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/usr/bin/env bash
#
# Downloads the segmentation model, the speaker embedding model and the test
# wave file if they are not already present, then builds and runs the
# non-streaming speaker diarization example.

# Stop at the first failing command; otherwise a failed download or build is
# hidden behind the exit status of the last command in the script.
set -e

if [ ! -f ./sherpa-onnx-pyannote-segmentation-3-0/model.onnx ]; then
  # -f makes curl fail on an HTTP error instead of saving the error page.
  curl -fSL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
fi

if [ ! -f ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]; then
  curl -fSL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
fi

if [ ! -f ./0-four-speakers-zh.wav ]; then
  curl -fSL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
fi

go mod tidy
go build
./non-streaming-speaker-diarization
5 changes: 5 additions & 0 deletions scripts/go/_internal/non-streaming-speaker-diarization/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module non-streaming-speaker-diarization

go 1.12

replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../
118 changes: 118 additions & 0 deletions scripts/go/sherpa_onnx.go
Original file line number Diff line number Diff line change
Expand Up @@ -1175,7 +1175,14 @@ func ReadWave(filename string) *Wave {
w := C.SherpaOnnxReadWave(s)
defer C.SherpaOnnxFreeWave(w)

if w == nil {
return nil
}

n := int(w.num_samples)
if n == 0 {
return nil
}

ans := &Wave{}
ans.SampleRate = int(w.sample_rate)
Expand All @@ -1189,3 +1196,114 @@ func ReadWave(filename string) *Wave {

return ans
}

// ============================================================
// For offline speaker diarization
// ============================================================
// OfflineSpeakerSegmentationPyannoteModelConfig holds the settings of a
// pyannote speaker segmentation model.
type OfflineSpeakerSegmentationPyannoteModelConfig struct {
// Model is the path of the segmentation model file; it is handed to the
// C API as a C string.
Model string
}

// OfflineSpeakerSegmentationModelConfig configures the speaker segmentation
// model used for offline speaker diarization. Every field is copied into the
// C config's segmentation section by NewOfflineSpeakerDiarization.
type OfflineSpeakerSegmentationModelConfig struct {
// Pyannote selects the pyannote segmentation model.
Pyannote OfflineSpeakerSegmentationPyannoteModelConfig
// NumThreads is forwarded as segmentation.num_threads.
NumThreads int
// Debug is forwarded as segmentation.debug.
// NOTE(review): presumably non-zero enables debug output — confirm against the C API.
Debug int
// Provider is forwarded as segmentation.provider.
// NOTE(review): presumably the execution provider name — confirm against the C API.
Provider string
}

// FastClusteringConfig configures how speaker embeddings are clustered.
// Both fields are copied into the C config's clustering section.
type FastClusteringConfig struct {
// NumClusters is the number of clusters to produce; the example program
// sets it to the known number of speakers in the recording.
NumClusters int
// Threshold is forwarded as clustering.threshold.
// NOTE(review): per the example program it is meant for the case where the
// speaker count is unknown, with larger values giving fewer clusters —
// confirm against the C API documentation.
Threshold float32
}

// OfflineSpeakerDiarizationConfig is the complete configuration passed to
// NewOfflineSpeakerDiarization.
type OfflineSpeakerDiarizationConfig struct {
// Segmentation configures the speaker segmentation model.
Segmentation OfflineSpeakerSegmentationModelConfig
// Embedding configures the speaker embedding extractor model.
Embedding SpeakerEmbeddingExtractorConfig
// Clustering configures the clustering of speaker embeddings.
Clustering FastClusteringConfig
// MinDurationOn is forwarded as min_duration_on.
// NOTE(review): unit and exact meaning are not visible here — presumably seconds; confirm.
MinDurationOn float32
// MinDurationOff is forwarded as min_duration_off.
// NOTE(review): unit and exact meaning are not visible here — presumably seconds; confirm.
MinDurationOff float32
}

// OfflineSpeakerDiarization wraps a diarizer object owned by the C library.
// Create it with NewOfflineSpeakerDiarization and release it with
// DeleteOfflineSpeakerDiarization.
type OfflineSpeakerDiarization struct {
// impl is the pointer to the underlying C object; Delete sets it to nil.
impl *C.struct_SherpaOnnxOfflineSpeakerDiarization
}

// DeleteOfflineSpeakerDiarization destroys the C object held by sd and clears
// the handle.
//
// Passing nil, or an sd that has already been deleted, is a no-op. That makes
// the function safe to defer unconditionally and safe to call more than once.
func DeleteOfflineSpeakerDiarization(sd *OfflineSpeakerDiarization) {
	if sd == nil || sd.impl == nil {
		return
	}

	C.SherpaOnnxDestroyOfflineSpeakerDiarization(sd.impl)
	sd.impl = nil
}

// NewOfflineSpeakerDiarization copies config into the equivalent C config
// struct, asks the C library to create a diarizer from it and wraps the
// returned handle.
//
// It returns nil when the C library returns a nil handle. The C strings
// allocated for the model paths and provider names are freed when this
// function returns.
func NewOfflineSpeakerDiarization(config *OfflineSpeakerDiarizationConfig) *OfflineSpeakerDiarization {
	var cfg C.struct_SherpaOnnxOfflineSpeakerDiarizationConfig

	// Segmentation model settings.
	seg := &config.Segmentation
	cfg.segmentation.pyannote.model = C.CString(seg.Pyannote.Model)
	defer C.free(unsafe.Pointer(cfg.segmentation.pyannote.model))
	cfg.segmentation.provider = C.CString(seg.Provider)
	defer C.free(unsafe.Pointer(cfg.segmentation.provider))
	cfg.segmentation.num_threads = C.int(seg.NumThreads)
	cfg.segmentation.debug = C.int(seg.Debug)

	// Speaker embedding extractor settings.
	emb := &config.Embedding
	cfg.embedding.model = C.CString(emb.Model)
	defer C.free(unsafe.Pointer(cfg.embedding.model))
	cfg.embedding.provider = C.CString(emb.Provider)
	defer C.free(unsafe.Pointer(cfg.embedding.provider))
	cfg.embedding.num_threads = C.int(emb.NumThreads)
	cfg.embedding.debug = C.int(emb.Debug)

	// Clustering and duration settings.
	cfg.clustering.num_clusters = C.int(config.Clustering.NumClusters)
	cfg.clustering.threshold = C.float(config.Clustering.Threshold)
	cfg.min_duration_on = C.float(config.MinDurationOn)
	cfg.min_duration_off = C.float(config.MinDurationOff)

	handle := C.SherpaOnnxCreateOfflineSpeakerDiarization(&cfg)
	if handle == nil {
		return nil
	}

	return &OfflineSpeakerDiarization{impl: handle}
}

// SampleRate returns the value reported by
// SherpaOnnxOfflineSpeakerDiarizationGetSampleRate for the underlying C
// object, converted to int.
func (sd *OfflineSpeakerDiarization) SampleRate() int {
	rate := C.SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(sd.impl)
	return int(rate)
}

// OfflineSpeakerDiarizationSegment is one entry of the result returned by
// OfflineSpeakerDiarization.Process.
type OfflineSpeakerDiarizationSegment struct {
// Start is copied from the C segment's start field.
// NOTE(review): presumably seconds from the beginning of the audio — confirm.
Start float32
// End is copied from the C segment's end field.
// NOTE(review): presumably seconds from the beginning of the audio — confirm.
End float32
// Speaker is the speaker index copied from the C segment's speaker field.
Speaker int
}

// Process runs speaker diarization on samples and returns the resulting
// segments in the order produced by
// SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime.
//
// It returns nil when samples is empty or when the C library reports no
// segments. The C result and the C segment array are released before
// returning; the returned slice is an independent Go copy.
func (sd *OfflineSpeakerDiarization) Process(samples []float32) []OfflineSpeakerDiarizationSegment {
	// Taking &samples[0] on an empty slice would panic, and there is nothing
	// to diarize anyway.
	if len(samples) == 0 {
		return nil
	}

	r := C.SherpaOnnxOfflineSpeakerDiarizationProcess(sd.impl, (*C.float)(&samples[0]), C.int(len(samples)))
	defer C.SherpaOnnxOfflineSpeakerDiarizationDestroyResult(r)

	n := int(C.SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(r))

	// <= rather than ==: a negative count must never reach make() below.
	if n <= 0 {
		return nil
	}

	s := C.SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(r)
	defer C.SherpaOnnxOfflineSpeakerDiarizationDestroySegment(s)

	// View the C array as a Go slice capped at n elements; no data is copied
	// by this conversion.
	p := (*[1 << 28]C.struct_SherpaOnnxOfflineSpeakerDiarizationSegment)(unsafe.Pointer(s))[:n:n]

	ans := make([]OfflineSpeakerDiarizationSegment, n)
	for i := range ans {
		ans[i].Start = float32(p[i].start)
		ans[i].End = float32(p[i].end)
		ans[i].Speaker = int(p[i].speaker)
	}

	return ans
}
Loading