Skip to content

Commit

Permalink
Add JNI support for spoken language identification (k2-fsa#782)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Apr 17, 2024
1 parent b5106fb commit 8e8aaf8
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 42 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/test-go-package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,12 @@ jobs:
./run-vits-vctk.sh
rm -rf vits-vctk
echo "Test vits-zh-aishell3"
git clone https://huggingface.co/csukuangfj/vits-zh-aishell3
echo "Test vits-icefall-zh-aishell3"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-icefall-zh-aishell3.tar.bz2
tar xvf vits-icefall-zh-aishell3.tar.bz2
rm vits-icefall-zh-aishell3.tar.bz2
./run-vits-zh-aishell3.sh
rm -rf vits-zh-aishell3
rm -rf vits-icefall-zh-aishell3*
echo "Test vits-piper-en_US-lessac-medium"
git clone https://huggingface.co/csukuangfj/vits-piper-en_US-lessac-medium
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,4 @@ sr-data

vits-icefall-*
sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
spoken-language-identification-test-wavs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import android.util.Log
private val TAG = "sherpa-onnx"

/**
 * Configuration for the zipformer-based offline audio-tagging model.
 *
 * @property model Path to the model file (read on the native side).
 */
data class OfflineZipformerAudioTaggingModelConfig(
    var model: String,
)

data class AudioTaggingModelConfig(
Expand Down Expand Up @@ -134,4 +134,4 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
}

return null
}
}
36 changes: 36 additions & 0 deletions kotlin-api-examples/Main.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,49 @@ fun callback(samples: FloatArray): Unit {
}

fun main() {
    // Run every demo in the same order as before; each one prints its own output.
    val demos: List<() -> Unit> = listOf(
        ::testSpokenLanguageIdentifcation,
        ::testAudioTagging,
        ::testSpeakerRecognition,
        ::testTts,
        { testAsr("transducer") },
        { testAsr("zipformer2-ctc") },
    )
    demos.forEach { it() }
}

// Demonstrates spoken language identification with a whisper-tiny model.
// NOTE(review): the function name is missing an "i" ("Identifcation");
// kept as-is so existing call sites keep working.
fun testSpokenLanguageIdentifcation() {
    val whisperConfig = SpokenLanguageIdentificationWhisperConfig(
        encoder = "./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx",
        decoder = "./sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx",
        tailPaddings = 33,
    )
    val config = SpokenLanguageIdentificationConfig(
        whisper = whisperConfig,
        numThreads = 1,
        debug = true,
        provider = "cpu",
    )
    val slid = SpokenLanguageIdentification(assetManager = null, config = config)

    val testFiles = listOf(
        "./spoken-language-identification-test-wavs/ar-arabic.wav",
        "./spoken-language-identification-test-wavs/bg-bulgarian.wav",
        "./spoken-language-identification-test-wavs/de-german.wav",
    )

    for (wavePath in testFiles) {
        // WaveReader returns [samples: FloatArray, sampleRate: Int].
        val waveData = WaveReader.readWaveFromFile(
            filename = wavePath,
        )
        val samples = waveData[0] as FloatArray
        val sampleRate = waveData[1] as Int

        // One stream per file; release it as soon as the result is computed.
        val stream = slid.createStream()
        stream.acceptWaveform(samples, sampleRate = sampleRate)
        val detectedLanguage = slid.compute(stream)
        stream.release()
        println(wavePath)
        println(detectedLanguage)
    }
}

fun testAudioTagging() {
val config = AudioTaggingConfig(
model=AudioTaggingModelConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,22 @@ import android.util.Log

private val TAG = "sherpa-onnx"

/**
 * Whisper encoder/decoder model paths used for spoken language identification.
 *
 * @property encoder Path to the whisper encoder model file.
 * @property decoder Path to the whisper decoder model file.
 * @property tailPaddings Tail padding amount forwarded to the native config;
 *   -1 presumably means "use the native default" — confirm against the C++ side.
 */
data class SpokenLanguageIdentificationWhisperConfig (
    var encoder: String,
    var decoder: String,
    var tailPaddings: Int = -1,
)

/**
 * Top-level configuration for spoken language identification.
 *
 * All fields are copied into the native config by the JNI layer.
 *
 * @property whisper The whisper encoder/decoder model files to use.
 * @property numThreads Number of threads used by the native implementation.
 * @property debug Debug flag forwarded to the native config.
 * @property provider Compute provider name, e.g. "cpu".
 */
data class SpokenLanguageIdentificationConfig (
    var whisper: SpokenLanguageIdentificationWhisperConfig,
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
)

data class AudioTaggingConfig (
var model: AudioTaggingModelConfig,
var labels: String,
var topK: Int = 5,
)

data class AudioEvent (
val name: String,
val index: Int,
val prob: Float,
)

class AudioTagging(
class SpokenLanguageIdentification (
assetManager: AssetManager? = null,
config: AudioTaggingConfig,
config: SpokenLanguageIdentificationConfig,
) {
private var ptr: Long

Expand All @@ -43,10 +33,10 @@ class AudioTagging(
}

protected fun finalize() {
if(ptr != 0) {
delete(ptr)
ptr = 0
}
if (ptr != 0L) {
delete(ptr)
ptr = 0
}
}

fun release() = finalize()
Expand All @@ -56,25 +46,22 @@ class AudioTagging(
return OfflineStream(p)
}

// fun compute(stream: OfflineStream, topK: Int=-1): Array<AudioEvent> {
fun compute(stream: OfflineStream, topK: Int=-1): Array<Any> {
var events :Array<Any> = compute(ptr, stream.ptr, topK)
}
fun compute(stream: OfflineStream) = compute(ptr, stream.ptr)

private external fun newFromAsset(
assetManager: AssetManager,
config: AudioTaggingConfig,
config: SpokenLanguageIdentificationConfig,
): Long

private external fun newFromFile(
config: AudioTaggingConfig,
config: SpokenLanguageIdentificationConfig,
): Long

private external fun delete(ptr: Long)

private external fun createStream(ptr: Long): Long

private external fun compute(ptr: Long, streamPtr: Long, topK: Int): Array<Any>
private external fun compute(ptr: Long, streamPtr: Long): String

companion object {
init {
Expand Down
32 changes: 24 additions & 8 deletions kotlin-api-examples/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ cd ../kotlin-api-examples

function testSpeakerEmbeddingExtractor() {
if [ ! -f ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
fi

if [ ! -f ./speaker1_a_cn_16k.wav ]; then
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav
curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav
fi

if [ ! -f ./speaker1_b_cn_16k.wav ]; then
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav
curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav
fi

if [ ! -f ./speaker2_a_cn_16k.wav ]; then
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav
curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav
fi
}

Expand All @@ -53,15 +53,15 @@ function testAsr() {
fi

if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
fi
}

function testTts() {
if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
tar xf vits-piper-en_US-amy-low.tar.bz2
rm vits-piper-en_US-amy-low.tar.bz2
fi
Expand All @@ -75,7 +75,22 @@ function testAudioTagging() {
fi
}

function testSpokenLanguageIdentification() {
  # Each entry is "<sentinel file>:<archive>"; download and extract an archive
  # only when its sentinel file is missing.
  for spec in \
    "./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx:sherpa-onnx-whisper-tiny.tar.bz2" \
    "./spoken-language-identification-test-wavs/ar-arabic.wav:spoken-language-identification-test-wavs.tar.bz2"; do
    sentinel=${spec%%:*}
    archive=${spec#*:}
    if [ ! -f "$sentinel" ]; then
      curl -SL -O "https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/$archive"
      tar xvf "$archive"
      rm "$archive"
    fi
  done
}

function test() {
testSpokenLanguageIdentification
testAudioTagging
testSpeakerEmbeddingExtractor
testAsr
Expand All @@ -90,6 +105,7 @@ kotlinc-jvm -include-runtime -d main.jar \
OfflineStream.kt \
SherpaOnnx.kt \
Speaker.kt \
SpokenLanguageIdentification.kt \
Tts.kt \
WaveReader.kt \
faked-asset-manager.kt \
Expand All @@ -101,13 +117,13 @@ java -Djava.library.path=../build/lib -jar main.jar

function testTwoPass() {
if [ ! -f ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder-epoch-99-avg-1.int8.onnx ]; then
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
fi

if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2
rm sherpa-onnx-whisper-tiny.en.tar.bz2
fi
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_library(sherpa-onnx-jni
audio-tagging.cc
jni.cc
offline-stream.cc
spoken-language-identification.cc
)
target_link_libraries(sherpa-onnx-jni sherpa-onnx-core)
install(TARGETS sherpa-onnx-jni DESTINATION lib)
104 changes: 104 additions & 0 deletions sherpa-onnx/jni/spoken-language-identification.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// sherpa-onnx/jni/spoken-language-identification.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/spoken-language-identification.h"

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"

namespace sherpa_onnx {

// Builds the C++ SpokenLanguageIdentificationConfig from the Java-side
// com.k2fsa.sherpa.onnx.SpokenLanguageIdentificationConfig object.
static SpokenLanguageIdentificationConfig GetSpokenLanguageIdentificationConfig(
    JNIEnv *env, jobject config) {
  SpokenLanguageIdentificationConfig ans;

  jclass config_cls = env->GetObjectClass(config);

  // Nested whisper config object.
  jfieldID field_id = env->GetFieldID(
      config_cls, "whisper",
      "Lcom/k2fsa/sherpa/onnx/SpokenLanguageIdentificationWhisperConfig;");
  jobject whisper = env->GetObjectField(config, field_id);
  jclass whisper_cls = env->GetObjectClass(whisper);

  // whisper.encoder
  field_id = env->GetFieldID(whisper_cls, "encoder", "Ljava/lang/String;");
  jstring js = (jstring)env->GetObjectField(whisper, field_id);
  const char *chars = env->GetStringUTFChars(js, nullptr);
  ans.whisper.encoder = chars;
  env->ReleaseStringUTFChars(js, chars);

  // whisper.decoder
  field_id = env->GetFieldID(whisper_cls, "decoder", "Ljava/lang/String;");
  js = (jstring)env->GetObjectField(whisper, field_id);
  chars = env->GetStringUTFChars(js, nullptr);
  ans.whisper.decoder = chars;
  env->ReleaseStringUTFChars(js, chars);

  // whisper.tailPaddings
  field_id = env->GetFieldID(whisper_cls, "tailPaddings", "I");
  ans.whisper.tail_paddings = env->GetIntField(whisper, field_id);

  // Top-level scalar fields.
  field_id = env->GetFieldID(config_cls, "numThreads", "I");
  ans.num_threads = env->GetIntField(config, field_id);

  field_id = env->GetFieldID(config_cls, "debug", "Z");
  ans.debug = env->GetBooleanField(config, field_id);

  field_id = env->GetFieldID(config_cls, "provider", "Ljava/lang/String;");
  js = (jstring)env->GetObjectField(config, field_id);
  chars = env->GetStringUTFChars(js, nullptr);
  ans.provider = chars;
  env->ReleaseStringUTFChars(js, chars);

  return ans;
}

} // namespace sherpa_onnx

SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_newFromFile(
    JNIEnv *env, jobject /*obj*/, jobject _config) {
  // Convert the Java config object into its C++ counterpart.
  auto config =
      sherpa_onnx::GetSpokenLanguageIdentificationConfig(env, _config);
  SHERPA_ONNX_LOGE("SpokenLanguageIdentification newFromFile config:\n%s",
                   config.ToString().c_str());

  if (!config.Validate()) {
    SHERPA_ONNX_LOGE("Errors found in config!");
    return 0;  // 0 signals failure to the Kotlin side
  }

  // Ownership is transferred to the Java object; the Kotlin wrapper releases
  // it later through its external delete().
  auto *slid = new sherpa_onnx::SpokenLanguageIdentification(config);
  return (jlong)slid;
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_createStream(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto *slid =
      reinterpret_cast<sherpa_onnx::SpokenLanguageIdentification *>(ptr);

  // The user is responsible to free the returned pointer.
  //
  // See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from
  // ./offline-stream.cc
  return (jlong)slid->CreateStream().release();
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_compute(JNIEnv *env,
                                                                jobject /*obj*/,
                                                                jlong ptr,
                                                                jlong s_ptr) {
  auto *slid =
      reinterpret_cast<sherpa_onnx::SpokenLanguageIdentification *>(ptr);
  auto *stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(s_ptr);

  // Compute() returns the identified language; hand it back as a Java string.
  return env->NewStringUTF(slid->Compute(stream).c_str());
}

0 comments on commit 8e8aaf8

Please sign in to comment.