Skip to content

Commit

Permalink
add more models for speaker diarization (#1440)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 17, 2024
1 parent 4783c8f commit e0586f1
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@ val segmentationModel = "segmentation.onnx"

// please download it from
// https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
// and rename it to embedding.onnx
// and move it to the assets folder
val embeddingModel = "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"
val embeddingModel = "embedding.onnx"

// in the end, your assets folder should look like below
/*
(py38) fangjuns-MacBook-Pro:assets fangjun$ pwd
/Users/fangjun/open-source/sherpa-onnx/android/SherpaOnnxSpeakerDiarization/app/src/main/assets
(py38) fangjuns-MacBook-Pro:assets fangjun$ ls -lh
total 89048
-rw-r--r-- 1 fangjun staff 38M Oct 12 20:28 3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
-rw-r--r-- 1 fangjun staff 38M Oct 12 20:28 embedding.onnx
-rw-r--r-- 1 fangjun staff 5.7M Oct 12 20:28 segmentation.onnx
*/

Expand Down Expand Up @@ -63,4 +64,4 @@ object SpeakerDiarizationObject {
_sd = OfflineSpeakerDiarization(assetManager = assetManager, config = config)
}
}
}
}
26 changes: 14 additions & 12 deletions scripts/apk/build-apk-speaker-diarization.sh.in
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,20 @@ pushd ./android/SherpaOnnxSpeakerDiarization/app/src/main/assets/

ls -lh

model_name={{ model.model_name }}
short_name={{ model.short_name }}
segmentation_model_name={{ model.segmentation.model_name }}
segmentation_short_name={{ model.segmentation.short_name }}

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/$model_name.tar.bz2
tar xvf $model_name.tar.bz2
rm $model_name.tar.bz2
mv $model_name/model.onnx segmentation.onnx
rm -rf $model_name
embedding_model_name={{ model.embedding.model_name }}
embedding_short_name={{ model.embedding.short_name }}

if [ ! -f 3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
fi
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/$segmentation_model_name.tar.bz2
tar xvf $segmentation_model_name.tar.bz2
rm $segmentation_model_name.tar.bz2
mv $segmentation_model_name/model.onnx segmentation.onnx
rm -rf $segmentation_model_name

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$embedding_model_name.onnx
mv $embedding_model_name.onnx embedding.onnx

echo "pwd: $PWD"
ls -lh
Expand All @@ -74,12 +76,12 @@ for arch in arm64-v8a armeabi-v7a x86_64 x86; do
./gradlew build
popd

mv android/SherpaOnnxSpeakerDiarization/app/build/outputs/apk/debug/app-debug.apk ./apks/sherpa-onnx-${SHERPA_ONNX_VERSION}-$arch-speaker-diarization-$short_name-3dspeaker.apk
mv android/SherpaOnnxSpeakerDiarization/app/build/outputs/apk/debug/app-debug.apk ./apks/sherpa-onnx-${SHERPA_ONNX_VERSION}-$arch-speaker-diarization-$segmentation_short_name-$embedding_short_name.apk
ls -lh apks
rm -v ./android/SherpaOnnxSpeakerDiarization/app/src/main/jniLibs/$arch/*.so
done

rm -rf ./android/SherpaOnnxSpeakerDiarization/app/src/main/assets/segmentation.onnx
rm -rf ./android/SherpaOnnxSpeakerDiarization/app/src/main/assets/*.onnx

{% endfor %}

Expand Down
38 changes: 35 additions & 3 deletions scripts/apk/generate-speaker-diarization-apk-script.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,22 @@ def get_args():
@dataclass
class SpeakerSegmentationModel:
model_name: str
short_name: str = ""
short_name: str


def get_models() -> List[SpeakerSegmentationModel]:
@dataclass
class SpeakerEmbeddingModel:
    """Describes one speaker embedding model used to build an APK."""

    # Name of the model; the build template reads it as
    # ``model.embedding.model_name``.
    model_name: str

    # Short tag for the model; the build template reads it as
    # ``model.embedding.short_name``.
    short_name: str


@dataclass
class Model:
    """Pairs one segmentation model with one embedding model.

    Each instance is one (segmentation, embedding) combination.
    """

    # The speaker segmentation half of the pair.
    segmentation: SpeakerSegmentationModel

    # The speaker embedding half of the pair.
    embedding: SpeakerEmbeddingModel


def get_segmentation_models() -> List[SpeakerSegmentationModel]:
models = [
SpeakerSegmentationModel(
model_name="sherpa-onnx-pyannote-segmentation-3-0",
Expand All @@ -45,13 +57,33 @@ def get_models() -> List[SpeakerSegmentationModel]:
return models


def get_embedding_models() -> List[SpeakerEmbeddingModel]:
    """Return the speaker embedding models to combine with segmentation models.

    Returns:
        A list of ``SpeakerEmbeddingModel`` entries, each carrying the
        model's ``model_name`` and a ``short_name`` tag.
    """
    # Build instances of the declared return type. The entries were
    # previously created as SpeakerSegmentationModel, which contradicted
    # the annotation even though both classes share the same fields.
    return [
        SpeakerEmbeddingModel(
            model_name="3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k",
            short_name="3dspeaker",
        ),
        SpeakerEmbeddingModel(
            model_name="nemo_en_titanet_small",
            short_name="nemo",
        ),
    ]


def main():
args = get_args()
index = args.index
total = args.total
assert 0 <= index < total, (index, total)

all_model_list = get_models()
segmentation_models = get_segmentation_models()
embedding_models = get_embedding_models()

all_model_list = []
for s in segmentation_models:
for e in embedding_models:
all_model_list.append(Model(segmentation=s, embedding=e))

num_models = len(all_model_list)

Expand Down

0 comments on commit e0586f1

Please sign in to comment.