Skip to content

Commit

Permalink
Go API for speaker diarization
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Oct 9, 2024
1 parent d468527 commit 3813526
Show file tree
Hide file tree
Showing 9 changed files with 282 additions and 0 deletions.
44 changes: 44 additions & 0 deletions .github/workflows/test-go-package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,50 @@ jobs:
run: |
gcc --version
- name: Test non-streaming speaker diarization
if: matrix.os != 'windows-latest'
shell: bash
run: |
cd go-api-examples/non-streaming-speaker-diarization/
./run.sh
- name: Test non-streaming speaker diarization
if: matrix.os == 'windows-latest' && matrix.arch == 'x64'
shell: bash
run: |
cd go-api-examples/non-streaming-speaker-diarization/
go mod tidy
cat go.mod
go build
echo $PWD
ls -lh /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/
ls -lh /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/*
cp -v /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/sherpa-onnx-go-windows*/lib/x86_64-pc-windows-gnu/*.dll .
./run.sh
- name: Test non-streaming speaker diarization
if: matrix.os == 'windows-latest' && matrix.arch == 'x86'
shell: bash
run: |
cd go-api-examples/non-streaming-speaker-diarization/
go env GOARCH
go env -w GOARCH=386
go env -w CGO_ENABLED=1
go mod tidy
cat go.mod
go build
echo $PWD
ls -lh /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/
ls -lh /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/*
cp -v /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/sherpa-onnx-go-windows*/lib/i686-pc-windows-gnu/*.dll .
./run.sh
- name: Test streaming HLG decoding (Linux/macOS)
if: matrix.os != 'windows-latest'
shell: bash
Expand Down
6 changes: 6 additions & 0 deletions .github/workflows/test-go.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ jobs:
name: ${{ matrix.os }}-libs
path: to-upload/

- name: Test non-streaming speaker diarization
shell: bash
run: |
cd scripts/go/_internal/non-streaming-speaker-diarization/
./run.sh
- name: Test speaker identification
shell: bash
run: |
Expand Down
3 changes: 3 additions & 0 deletions go-api-examples/non-streaming-speaker-diarization/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module non-streaming-speaker-diarization

go 1.12
82 changes: 82 additions & 0 deletions go-api-examples/non-streaming-speaker-diarization/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package main

import (
sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
"log"
)

/*
Usage:
Step 1: Download a speaker segmentation model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
Step 2: Download a speaker embedding extractor model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
Step 3. Download test wave files
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available test wave files. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
Step 4. Run it
*/

func initSpeakerDiarization() *sherpa.OfflineSpeakerDiarization {
config := sherpa.OfflineSpeakerDiarizationConfig{}

config.Segmentation.Pyannote.Model = "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"
config.Embedding.Model = "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"

// The test wave file contains 4 speakers, so we use 4 here
config.Clustering.NumClusters = 4

// if you don't know the actual numbers in the wave file,
// then please don't set NumClusters; you need to use
//
// config.Clustering.Threshold = 0.5
//

// A larger Threshold leads to fewer clusters
// A smaller Threshold leads to more clusters

sd := sherpa.NewOfflineSpeakerDiarization(&config)
return sd
}

func main() {
wave_filename := "./0-four-speakers-zh.wav"
wave := sherpa.ReadWave(wave_filename)
if wave == nil {
log.Printf("Failed to read %v", wave_filename)
return
}

sd := initSpeakerDiarization()
if sd == nil {
log.Printf("Please check your config")
return
}

defer sherpa.DeleteOfflineSpeakerDiarization(sd)

log.Println("Started")
segments := sd.Process(wave.Samples)
n := len(segments)

for i := 0; i < n; i++ {
log.Printf("%.3f -- %.3f speaker_%02d\n", segments[i].Start, segments[i].End, segments[i].Speaker)
}
}
20 changes: 20 additions & 0 deletions go-api-examples/non-streaming-speaker-diarization/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/usr/bin/env bash


if [ ! -f ./sherpa-onnx-pyannote-segmentation-3-0/model.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
fi

if [ ! -f ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
fi

if [ ! -f ./0-four-speakers-zh.wav ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
fi

go mod tidy
go build
./non-streaming-speaker-diarization
5 changes: 5 additions & 0 deletions scripts/go/_internal/non-streaming-speaker-diarization/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module non-streaming-speaker-diarization

go 1.12

replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../
120 changes: 120 additions & 0 deletions scripts/go/sherpa_onnx.go
Original file line number Diff line number Diff line change
Expand Up @@ -1175,7 +1175,14 @@ func ReadWave(filename string) *Wave {
w := C.SherpaOnnxReadWave(s)
defer C.SherpaOnnxFreeWave(w)

if w == nil {
return nil
}

n := int(w.num_samples)
if n == 0 {
return nil
}

ans := &Wave{}
ans.SampleRate = int(w.sample_rate)
Expand All @@ -1189,3 +1196,116 @@ func ReadWave(filename string) *Wave {

return ans
}

// ============================================================
// For offline speaker diarization
// ============================================================
type OfflineSpeakerSegmentationPyannoteModelConfig struct {
Model string
}

type OfflineSpeakerSegmentationModelConfig struct {
Pyannote OfflineSpeakerSegmentationPyannoteModelConfig
NumThreads int
Debug int
Provider string
}

type FastClusteringConfig struct {
NumClusters int
Threshold float32
}

type OfflineSpeakerDiarizationConfig struct {
Segmentation OfflineSpeakerSegmentationModelConfig
Embedding SpeakerEmbeddingExtractorConfig
Clustering FastClusteringConfig
MinDurationOn float32
MinDurationOff float32
}

type OfflineSpeakerDiarization struct {
impl *C.struct_SherpaOnnxOfflineSpeakerDiarization
}

func DeleteOfflineSpeakerDiarization(sd *OfflineSpeakerDiarization) {
C.SherpaOnnxDestroyOfflineSpeakerDiarization(sd.impl)
sd.impl = nil
}

func NewOfflineSpeakerDiarization(config *OfflineSpeakerDiarizationConfig) *OfflineSpeakerDiarization {
c := C.struct_SherpaOnnxOfflineSpeakerDiarizationConfig{}
c.segmentation.pyannote.model = C.CString(config.Segmentation.Pyannote.Model)
defer C.free(unsafe.Pointer(c.segmentation.pyannote.model))

c.segmentation.num_threads = C.int(config.Segmentation.NumThreads)

c.segmentation.debug = C.int(config.Segmentation.Debug)

c.segmentation.provider = C.CString(config.Segmentation.Provider)
defer C.free(unsafe.Pointer(c.segmentation.provider))

c.embedding.model = C.CString(config.Embedding.Model)
defer C.free(unsafe.Pointer(c.embedding.model))

c.embedding.num_threads = C.int(config.Embedding.NumThreads)

c.embedding.debug = C.int(config.Embedding.Debug)

c.embedding.provider = C.CString(config.Embedding.Provider)
defer C.free(unsafe.Pointer(c.embedding.provider))

c.clustering.num_clusters = C.int(config.Clustering.NumClusters)
c.clustering.threshold = C.float(config.Clustering.Threshold)
c.min_duration_on = C.float(config.MinDurationOn)
c.min_duration_off = C.float(config.MinDurationOff)

p := C.SherpaOnnxCreateOfflineSpeakerDiarization(&c)

if p == nil {
return nil
}

sd := &OfflineSpeakerDiarization{}
sd.impl = p

return sd
}

func (sd *OfflineSpeakerDiarization) SampleRate() int {
return int(C.SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(sd.impl))
}

type OfflineSpeakerDiarizationSegment struct {
Start float32
End float32
Speaker int
}

// The user has to invoke DeleteOfflineSpeakerDiarizationResult() to free the returned
// pointer to avoid memory leak
func (sd *OfflineSpeakerDiarization) Process(samples []float32) []OfflineSpeakerDiarizationSegment {
r := C.SherpaOnnxOfflineSpeakerDiarizationProcess(sd.impl, (*C.float)(&samples[0]), C.int(len(samples)))
defer C.SherpaOnnxOfflineSpeakerDiarizationDestroyResult(r)

n := int(C.SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(r))

if n == 0 {
return nil
}

s := C.SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(r)
defer C.SherpaOnnxOfflineSpeakerDiarizationDestroySegment(s)

ans := make([]OfflineSpeakerDiarizationSegment, n)

p := (*[1 << 28]C.struct_SherpaOnnxOfflineSpeakerDiarizationSegment)(unsafe.Pointer(s))[:n:n]

for i := 0; i < n; i++ {
ans[i].Start = float32(p[i].start)
ans[i].End = float32(p[i].end)
ans[i].Speaker = int(p[i].speaker)
}

return ans
}

0 comments on commit 3813526

Please sign in to comment.