From 8e8aaf8d13590375461660f6e61074b94184a4f1 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 17 Apr 2024 19:27:15 +0800
Subject: [PATCH] Add JNI support for spoken language identification (#782)

---
 .github/workflows/test-go-package.yaml        |   8 +-
 .gitignore                                    |   1 +
 .../sherpa/onnx/audio/tagging/AudioTagging.kt |   4 +-
 kotlin-api-examples/Main.kt                   |  36 ++++++
 .../SpokenLanguageIdentification.kt           |  45 +++-----
 kotlin-api-examples/run.sh                    |  32 ++++--
 sherpa-onnx/jni/CMakeLists.txt                |   1 +
 .../jni/spoken-language-identification.cc     | 104 ++++++++++++++++++
 8 files changed, 189 insertions(+), 42 deletions(-)
 rename sherpa-onnx/jni/AudioTagging.kt => kotlin-api-examples/SpokenLanguageIdentification.kt (53%)
 create mode 100644 sherpa-onnx/jni/spoken-language-identification.cc

diff --git a/.github/workflows/test-go-package.yaml b/.github/workflows/test-go-package.yaml
index 2713295003..f76157ab29 100644
--- a/.github/workflows/test-go-package.yaml
+++ b/.github/workflows/test-go-package.yaml
@@ -161,10 +161,12 @@ jobs:
           ./run-vits-vctk.sh
           rm -rf vits-vctk

-          echo "Test vits-zh-aishell3"
-          git clone https://huggingface.co/csukuangfj/vits-zh-aishell3
+          echo "Test vits-icefall-zh-aishell3"
+          curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-icefall-zh-aishell3.tar.bz2
+          tar xvf vits-icefall-zh-aishell3.tar.bz2
+          rm vits-icefall-zh-aishell3.tar.bz2
           ./run-vits-zh-aishell3.sh
-          rm -rf vits-zh-aishell3
+          rm -rf vits-icefall-zh-aishell3*

           echo "Test vits-piper-en_US-lessac-medium"
           git clone https://huggingface.co/csukuangfj/vits-piper-en_US-lessac-medium
diff --git a/.gitignore b/.gitignore
index 3047a1e0a3..e0743e07f7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -92,3 +92,4 @@
 sr-data
 vits-icefall-*
 sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
+spoken-language-identification-test-wavs
diff --git a/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/AudioTagging.kt b/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/AudioTagging.kt
index 9c4b5cebdf..437302911a 100644
--- a/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/AudioTagging.kt
+++ b/android/SherpaOnnxAudioTagging/app/src/main/java/com/k2fsa/sherpa/onnx/audio/tagging/AudioTagging.kt
@@ -6,7 +6,7 @@ import android.util.Log
 private val TAG = "sherpa-onnx"

 data class OfflineZipformerAudioTaggingModelConfig(
-    val model: String,
+    var model: String,
 )

 data class AudioTaggingModelConfig(
@@ -134,4 +134,4 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig?
 {
     }
     return null
-}
\ No newline at end of file
+}
diff --git a/kotlin-api-examples/Main.kt b/kotlin-api-examples/Main.kt
index bc82c69934..479ce34285 100644
--- a/kotlin-api-examples/Main.kt
+++ b/kotlin-api-examples/Main.kt
@@ -7,6 +7,7 @@ fun callback(samples: FloatArray): Unit {
 }

 fun main() {
+  testSpokenLanguageIdentification()
   testAudioTagging()
   testSpeakerRecognition()
   testTts()
@@ -14,6 +15,41 @@ fun main() {
   testAsr("zipformer2-ctc")
 }

+fun testSpokenLanguageIdentification() {
+  val config = SpokenLanguageIdentificationConfig(
+    whisper = SpokenLanguageIdentificationWhisperConfig(
+      encoder = "./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx",
+      decoder = "./sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx",
+      tailPaddings = 33,
+    ),
+    numThreads=1,
+    debug=true,
+    provider="cpu",
+  )
+  val slid = SpokenLanguageIdentification(assetManager=null, config=config)
+
+  val testFiles = arrayOf(
+    "./spoken-language-identification-test-wavs/ar-arabic.wav",
+    "./spoken-language-identification-test-wavs/bg-bulgarian.wav",
+    "./spoken-language-identification-test-wavs/de-german.wav",
+  )
+
+  for (waveFilename in testFiles) {
+    val objArray = WaveReader.readWaveFromFile(
+      filename = waveFilename,
+    )
+    val samples: FloatArray = objArray[0] as FloatArray
+    val sampleRate: Int = objArray[1] as Int
+
+    val stream = slid.createStream()
+    stream.acceptWaveform(samples, sampleRate = sampleRate)
+    val lang = slid.compute(stream)
+    stream.release()
+    println(waveFilename)
+    println(lang)
+  }
+}
+
 fun testAudioTagging() {
   val config = AudioTaggingConfig(
     model=AudioTaggingModelConfig(
diff --git a/sherpa-onnx/jni/AudioTagging.kt b/kotlin-api-examples/SpokenLanguageIdentification.kt
similarity index 53%
rename from sherpa-onnx/jni/AudioTagging.kt
rename to kotlin-api-examples/SpokenLanguageIdentification.kt
index f3d8277967..ef117c8bfb 100644
--- a/sherpa-onnx/jni/AudioTagging.kt
+++ b/kotlin-api-examples/SpokenLanguageIdentification.kt
@@ -5,32 +5,22 @@ import android.util.Log

 private val TAG = "sherpa-onnx"

-data class OfflineZipformerAudioTaggingModelConfig (
-    val model: String,
+data class SpokenLanguageIdentificationWhisperConfig (
+    var encoder: String,
+    var decoder: String,
+    var tailPaddings: Int = -1,
 )

-data class AudioTaggingModelConfig (
-    var zipformer: OfflineZipformerAudioTaggingModelConfig,
+data class SpokenLanguageIdentificationConfig (
+    var whisper: SpokenLanguageIdentificationWhisperConfig,
     var numThreads: Int = 1,
     var debug: Boolean = false,
     var provider: String = "cpu",
 )

-data class AudioTaggingConfig (
-    var model: AudioTaggingModelConfig,
-    var labels: String,
-    var topK: Int = 5,
-)
-
-data class AudioEvent (
-    val name: String,
-    val index: Int,
-    val prob: Float,
-)
-
-class AudioTagging(
+class SpokenLanguageIdentification (
     assetManager: AssetManager? = null,
-    config: AudioTaggingConfig,
+    config: SpokenLanguageIdentificationConfig,
 ) {

     private var ptr: Long
@@ -43,10 +33,10 @@ class AudioTagging(
     }

     protected fun finalize() {
-        if(ptr != 0) {
-            delete(ptr)
-            ptr = 0
-        }
+        if (ptr != 0L) {
+            delete(ptr)
+            ptr = 0
+        }
     }

     fun release() = finalize()
@@ -56,25 +46,22 @@ class AudioTagging(
         return OfflineStream(p)
     }

-    // fun compute(stream: OfflineStream, topK: Int=-1): Array<AudioEvent> {
-    fun compute(stream: OfflineStream, topK: Int=-1): Array<AudioEvent> {
-        var events :Array<AudioEvent> = compute(ptr, stream.ptr, topK)
-    }
+    fun compute(stream: OfflineStream) = compute(ptr, stream.ptr)

     private external fun newFromAsset(
         assetManager: AssetManager,
-        config: AudioTaggingConfig,
+        config: SpokenLanguageIdentificationConfig,
     ): Long

     private external fun newFromFile(
-        config: AudioTaggingConfig,
+        config: SpokenLanguageIdentificationConfig,
     ): Long

     private external fun delete(ptr: Long)

     private external fun createStream(ptr: Long): Long

-    private external fun compute(ptr: Long, streamPtr: Long, topK: Int): Array<AudioEvent>
+    private external fun compute(ptr: Long, streamPtr: Long): String

     companion object {
         init {
diff --git a/kotlin-api-examples/run.sh b/kotlin-api-examples/run.sh
index 750a70e5c2..f14e169cda 100755
--- a/kotlin-api-examples/run.sh
+++ b/kotlin-api-examples/run.sh
@@ -30,19 +30,19 @@ cd ../kotlin-api-examples

 function testSpeakerEmbeddingExtractor() {
   if [ ! -f ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx ]; then
-    wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
+    curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
   fi

   if [ ! -f ./speaker1_a_cn_16k.wav ]; then
-    wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav
+    curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav
   fi

   if [ ! -f ./speaker1_b_cn_16k.wav ]; then
-    wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav
+    curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav
   fi

   if [ ! -f ./speaker2_a_cn_16k.wav ]; then
-    wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav
+    curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav
   fi
 }

@@ -53,7 +53,7 @@ function testAsr() {
   fi

   if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then
-    wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
+    curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
     tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
     rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
   fi
@@ -61,7 +61,7 @@

 function testTts() {
   if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then
-    wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
+    curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
     tar xf vits-piper-en_US-amy-low.tar.bz2
     rm vits-piper-en_US-amy-low.tar.bz2
   fi
@@ -75,7 +75,22 @@ function testAudioTagging() {
   fi
 }

+function testSpokenLanguageIdentification() {
+  if [ ! -f ./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx ]; then
+    curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
+    tar xvf sherpa-onnx-whisper-tiny.tar.bz2
+    rm sherpa-onnx-whisper-tiny.tar.bz2
+  fi
+
+  if [ ! -f ./spoken-language-identification-test-wavs/ar-arabic.wav ]; then
+    curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/spoken-language-identification-test-wavs.tar.bz2
+    tar xvf spoken-language-identification-test-wavs.tar.bz2
+    rm spoken-language-identification-test-wavs.tar.bz2
+  fi
+}
+
 function test() {
+  testSpokenLanguageIdentification
   testAudioTagging
   testSpeakerEmbeddingExtractor
   testAsr
@@ -90,6 +105,7 @@ kotlinc-jvm -include-runtime -d main.jar \
   OfflineStream.kt \
   SherpaOnnx.kt \
   Speaker.kt \
+  SpokenLanguageIdentification.kt \
   Tts.kt \
   WaveReader.kt \
   faked-asset-manager.kt \
@@ -101,13 +117,13 @@ java -Djava.library.path=../build/lib -jar main.jar

 function testTwoPass() {
   if [ ! -f ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder-epoch-99-avg-1.int8.onnx ]; then
-    wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
+    curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
     tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
     rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
   fi

   if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then
-    wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
+    curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
     tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2
     rm sherpa-onnx-whisper-tiny.en.tar.bz2
   fi
diff --git a/sherpa-onnx/jni/CMakeLists.txt b/sherpa-onnx/jni/CMakeLists.txt
index 75b6a1bb5b..6f14a35fa7 100644
--- a/sherpa-onnx/jni/CMakeLists.txt
+++ b/sherpa-onnx/jni/CMakeLists.txt
@@ -13,6 +13,7 @@ add_library(sherpa-onnx-jni
   audio-tagging.cc
   jni.cc
   offline-stream.cc
+  spoken-language-identification.cc
 )
 target_link_libraries(sherpa-onnx-jni sherpa-onnx-core)
 install(TARGETS sherpa-onnx-jni DESTINATION lib)
diff --git a/sherpa-onnx/jni/spoken-language-identification.cc b/sherpa-onnx/jni/spoken-language-identification.cc
new file mode 100644
index 0000000000..0bff585d4e
--- /dev/null
+++ b/sherpa-onnx/jni/spoken-language-identification.cc
@@ -0,0 +1,104 @@
+// sherpa-onnx/jni/spoken-language-identification.cc
+//
+// Copyright (c)  2024  Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/spoken-language-identification.h"
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/jni/common.h"
+
+namespace sherpa_onnx {
+
+static SpokenLanguageIdentificationConfig GetSpokenLanguageIdentificationConfig(
+    JNIEnv *env, jobject config) {
+  SpokenLanguageIdentificationConfig ans;
+
+  jclass cls = env->GetObjectClass(config);
+  jfieldID fid = env->GetFieldID(
+      cls, "whisper",
+      "Lcom/k2fsa/sherpa/onnx/SpokenLanguageIdentificationWhisperConfig;");
+
+  jobject whisper = env->GetObjectField(config, fid);
+  jclass whisper_cls = env->GetObjectClass(whisper);
+
+  fid = env->GetFieldID(whisper_cls, "encoder", "Ljava/lang/String;");
+
+  jstring s = (jstring)env->GetObjectField(whisper, fid);
+  const char *p = env->GetStringUTFChars(s, nullptr);
+  ans.whisper.encoder = p;
+  env->ReleaseStringUTFChars(s, p);
+
+  fid = env->GetFieldID(whisper_cls, "decoder", "Ljava/lang/String;");
+  s = (jstring)env->GetObjectField(whisper, fid);
+  p = env->GetStringUTFChars(s, nullptr);
+  ans.whisper.decoder = p;
+  env->ReleaseStringUTFChars(s, p);
+
+  fid = env->GetFieldID(whisper_cls, "tailPaddings", "I");
+  ans.whisper.tail_paddings = env->GetIntField(whisper, fid);
+
+  fid = env->GetFieldID(cls, "numThreads", "I");
+  ans.num_threads = env->GetIntField(config, fid);
+
+  fid = env->GetFieldID(cls, "debug", "Z");
+  ans.debug = env->GetBooleanField(config, fid);
+
+  fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
+  s = (jstring)env->GetObjectField(config, fid);
+  p = env->GetStringUTFChars(s, nullptr);
+  ans.provider = p;
+  env->ReleaseStringUTFChars(s, p);
+
+  return ans;
+}
+
+}  // namespace sherpa_onnx
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jlong JNICALL
+Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_newFromFile(
+    JNIEnv *env, jobject /*obj*/, jobject _config) {
+  auto config =
+      sherpa_onnx::GetSpokenLanguageIdentificationConfig(env, _config);
+  SHERPA_ONNX_LOGE("SpokenLanguageIdentification newFromFile config:\n%s",
+                   config.ToString().c_str());
+
+  if (!config.Validate()) {
+    SHERPA_ONNX_LOGE("Errors found in config!");
+    return 0;
+  }
+
+  auto slid = new sherpa_onnx::SpokenLanguageIdentification(config);
+
+  return (jlong)slid;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jlong JNICALL
+Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_createStream(
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  auto slid =
+      reinterpret_cast<sherpa_onnx::SpokenLanguageIdentification *>(ptr);
+  std::unique_ptr<sherpa_onnx::OfflineStream> s = slid->CreateStream();
+
+  // The user is responsible to free the returned pointer.
+  //
+  // See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from
+  // ./offline-stream.cc
+  sherpa_onnx::OfflineStream *p = s.release();
+  return (jlong)p;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jstring JNICALL
+Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_compute(JNIEnv *env,
+                                                                jobject /*obj*/,
+                                                                jlong ptr,
+                                                                jlong s_ptr) {
+  sherpa_onnx::SpokenLanguageIdentification *slid =
+      reinterpret_cast<sherpa_onnx::SpokenLanguageIdentification *>(ptr);
+  sherpa_onnx::OfflineStream *s =
+      reinterpret_cast<sherpa_onnx::OfflineStream *>(s_ptr);
+  std::string lang = slid->Compute(s);
+  return env->NewStringUTF(lang.c_str());
+}
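
--
Editor's note: the Kotlin example in this patch runs on desktop and passes
assetManager=null, so the wrapper constructs via newFromFile(). On Android one
would pass the activity's AssetManager instead. Below is a minimal, hypothetical
sketch of that usage: the asset file names are placeholders, only the Kotlin API
added in this patch is assumed, and the asset-based path additionally assumes a
matching newFromAsset() JNI binding exists (it is declared in the Kotlin class
but is not part of the .cc file shown above).

    import android.app.Activity
    import android.os.Bundle
    import android.util.Log

    class SlidDemoActivity : Activity() {
        override fun onCreate(savedInstanceState: Bundle?) {
            super.onCreate(savedInstanceState)

            val config = SpokenLanguageIdentificationConfig(
                whisper = SpokenLanguageIdentificationWhisperConfig(
                    encoder = "tiny-encoder.int8.onnx",  // placeholder asset name
                    decoder = "tiny-decoder.int8.onnx",  // placeholder asset name
                    tailPaddings = 33,
                ),
                numThreads = 1,
            )

            // A non-null AssetManager routes construction through newFromAsset()
            // instead of newFromFile().
            val slid = SpokenLanguageIdentification(assetManager = assets, config = config)

            // Placeholder input: one second of 16 kHz silence. Real code would
            // decode a recording into a FloatArray of samples in [-1, 1].
            val samples = FloatArray(16000)

            val stream = slid.createStream()
            stream.acceptWaveform(samples, sampleRate = 16000)
            val lang = slid.compute(stream)  // e.g. "en", "de", "ar"
            stream.release()
            slid.release()

            Log.i("sherpa-onnx", "detected language: $lang")
        }
    }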