From f7b3735621a4b5375a49e7256dc0a9331b19d7a3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 25 Apr 2024 17:20:02 +0800 Subject: [PATCH] Add CTC HLG decoding for JNI (#810) --- .github/workflows/jni.yaml | 8 ++- .github/workflows/run-java-test.yaml | 22 ++++++- .../offline-tts-play/offline-tts-play.csproj | 4 -- .../NonStreamingDecodeFileNemo.java | 49 ++++++++++++++++ java-api-examples/README.md | 2 + .../StreamingDecodeFileCtcHLG.java | 58 +++++++++++++++++++ .../run-non-streaming-decode-file-nemo.sh | 51 ++++++++++++++++ .../run-streaming-decode-file-ctc-hlg.sh | 36 ++++++++++++ kotlin-api-examples/run.sh | 24 ++++++++ kotlin-api-examples/test_offline_asr.kt | 23 ++++++-- kotlin-api-examples/test_online_asr.kt | 15 +++++ sherpa-onnx/java-api/Makefile | 8 ++- .../k2fsa/sherpa/onnx/OfflineModelConfig.java | 8 +++ .../onnx/OfflineNemoEncDecCtcModelConfig.java | 31 ++++++++++ .../onnx/OnlineCtcFstDecoderConfig.java | 43 ++++++++++++++ .../sherpa/onnx/OnlineRecognizerConfig.java | 9 +++ sherpa-onnx/jni/offline-recognizer.cc | 13 +++++ sherpa-onnx/jni/online-recognizer.cc | 16 +++++ sherpa-onnx/jni/wave-reader.cc | 1 + sherpa-onnx/kotlin-api/OfflineRecognizer.kt | 15 +++++ sherpa-onnx/kotlin-api/OnlineRecognizer.kt | 6 ++ 21 files changed, 429 insertions(+), 13 deletions(-) create mode 100644 java-api-examples/NonStreamingDecodeFileNemo.java create mode 100644 java-api-examples/StreamingDecodeFileCtcHLG.java create mode 100755 java-api-examples/run-non-streaming-decode-file-nemo.sh create mode 100755 java-api-examples/run-streaming-decode-file-ctc-hlg.sh create mode 100644 sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OfflineNemoEncDecCtcModelConfig.java create mode 100644 sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineCtcFstDecoderConfig.java diff --git a/.github/workflows/jni.yaml b/.github/workflows/jni.yaml index 459cc3836..fdee728fd 100644 --- a/.github/workflows/jni.yaml +++ b/.github/workflows/jni.yaml @@ -37,7 +37,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, macos-latest] + os: [ubuntu-latest, macos-latest, macos-14] steps: - uses: actions/checkout@v4 @@ -49,6 +49,11 @@ jobs: with: key: ${{ matrix.os }} + - name: OS info + shell: bash + run: | + uname -a + - name: Display kotlin version shell: bash run: | @@ -58,6 +63,7 @@ jobs: shell: bash run: | java -version + javac -help echo "JAVA_HOME is: ${JAVA_HOME}" - name: Run JNI test diff --git a/.github/workflows/run-java-test.yaml b/.github/workflows/run-java-test.yaml index 533d868c6..d84e2cee4 100644 --- a/.github/workflows/run-java-test.yaml +++ b/.github/workflows/run-java-test.yaml @@ -38,7 +38,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest] + os: [ubuntu-latest, macos-latest, macos-14] steps: - uses: actions/checkout@v4 @@ -50,10 +50,24 @@ jobs: with: key: ${{ matrix.os }}-java + - name: OS info + shell: bash + run: | + uname -a + + - uses: actions/setup-java@v4 + with: + distribution: 'temurin' # See 'Supported distributions' for available options + java-version: '21' + - name: Display java version shell: bash run: | java -version + java -help + echo "----" + javac -version + javac -help echo "JAVA_HOME is: ${JAVA_HOME}" cmake --version @@ -100,6 +114,9 @@ jobs: # Delete model files to save space rm -rf sherpa-onnx-streaming-* + ./run-streaming-decode-file-ctc-hlg.sh + rm -rf sherpa-onnx-streaming-* + ./run-streaming-decode-file-paraformer.sh rm -rf sherpa-onnx-streaming-* @@ -118,3 +135,6 @@ jobs: ./run-non-streaming-decode-file-whisper.sh rm -rf sherpa-onnx-whisper-* + + ./run-non-streaming-decode-file-nemo.sh + rm -rf sherpa-onnx-nemo-* diff --git a/dotnet-examples/offline-tts-play/offline-tts-play.csproj b/dotnet-examples/offline-tts-play/offline-tts-play.csproj index f0ced2453..85caf7e46 100644 --- a/dotnet-examples/offline-tts-play/offline-tts-play.csproj +++ b/dotnet-examples/offline-tts-play/offline-tts-play.csproj @@ -8,10 +8,6 @@ enable - - /tmp/packages;$(RestoreSources);https://api.nuget.org/v3/index.json - - diff --git a/java-api-examples/NonStreamingDecodeFileNemo.java b/java-api-examples/NonStreamingDecodeFileNemo.java new file mode 100644 index 000000000..5286e172c --- /dev/null +++ b/java-api-examples/NonStreamingDecodeFileNemo.java @@ -0,0 +1,49 @@ +// Copyright 2024 Xiaomi Corporation + +// This file shows how to use an offline NeMo CTC model, i.e., non-streaming NeMo CTC model,, +// to decode files. +import com.k2fsa.sherpa.onnx.*; + +public class NonStreamingDecodeFileNemo { + public static void main(String[] args) { + // please refer to + // https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2 + // to download model files + String model = "./sherpa-onnx-nemo-ctc-en-citrinet-512/model.int8.onnx"; + String tokens = "./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt"; + + String waveFilename = "./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav"; + + WaveReader reader = new WaveReader(waveFilename); + + OfflineNemoEncDecCtcModelConfig nemo = + OfflineNemoEncDecCtcModelConfig.builder().setModel(model).build(); + + OfflineModelConfig modelConfig = + OfflineModelConfig.builder() + .setNemo(nemo) + .setTokens(tokens) + .setNumThreads(1) + .setDebug(true) + .build(); + + OfflineRecognizerConfig config = + OfflineRecognizerConfig.builder() + .setOfflineModelConfig(modelConfig) + .setDecodingMethod("greedy_search") + .build(); + + OfflineRecognizer recognizer = new OfflineRecognizer(config); + OfflineStream stream = recognizer.createStream(); + stream.acceptWaveform(reader.getSamples(), reader.getSampleRate()); + + recognizer.decode(stream); + + String text = recognizer.getResult(stream).getText(); + + System.out.printf("filename:%s\nresult:%s\n", waveFilename, text); + + stream.release(); + recognizer.release(); + } +} diff --git a/java-api-examples/README.md b/java-api-examples/README.md index ae346843a..a217360e9 100755 --- a/java-api-examples/README.md +++ b/java-api-examples/README.md @@ -8,6 +8,7 @@ This directory contains examples for the JAVA API of sherpa-onnx. ``` ./run-streaming-decode-file-ctc.sh +./run-streaming-decode-file-ctc-hlg.sh ./run-streaming-decode-file-paraformer.sh ./run-streaming-decode-file-transducer.sh ``` @@ -18,4 +19,5 @@ This directory contains examples for the JAVA API of sherpa-onnx. ./run-non-streaming-decode-file-paraformer.sh ./run-non-streaming-decode-file-transducer.sh ./run-non-streaming-decode-file-whisper.sh +./run-non-streaming-decode-file-nemo.sh ``` diff --git a/java-api-examples/StreamingDecodeFileCtcHLG.java b/java-api-examples/StreamingDecodeFileCtcHLG.java new file mode 100644 index 000000000..73a738732 --- /dev/null +++ b/java-api-examples/StreamingDecodeFileCtcHLG.java @@ -0,0 +1,58 @@ +// Copyright 2024 Xiaomi Corporation + +// This file shows how to use an online CTC model, i.e., streaming CTC model, +// to decode files. +import com.k2fsa.sherpa.onnx.*; + +public class StreamingDecodeFileCtcHLG { + public static void main(String[] args) { + // please refer to + // https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + // to download model files + String model = + "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx"; + String tokens = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt"; + String hlg = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst"; + String waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/8k.wav"; + + WaveReader reader = new WaveReader(waveFilename); + + OnlineZipformer2CtcModelConfig ctc = + OnlineZipformer2CtcModelConfig.builder().setModel(model).build(); + + OnlineModelConfig modelConfig = + OnlineModelConfig.builder() + .setZipformer2Ctc(ctc) + .setTokens(tokens) + .setNumThreads(1) + .setDebug(true) + .build(); + + OnlineCtcFstDecoderConfig ctcFstDecoderConfig = + OnlineCtcFstDecoderConfig.builder().setGraph("hlg").build(); + + OnlineRecognizerConfig config = + OnlineRecognizerConfig.builder() + .setOnlineModelConfig(modelConfig) + .setCtcFstDecoderConfig(ctcFstDecoderConfig) + .build(); + + OnlineRecognizer recognizer = new OnlineRecognizer(config); + OnlineStream stream = recognizer.createStream(); + stream.acceptWaveform(reader.getSamples(), reader.getSampleRate()); + + float[] tailPaddings = new float[(int) (0.3 * reader.getSampleRate())]; + stream.acceptWaveform(tailPaddings, reader.getSampleRate()); + + while (recognizer.isReady(stream)) { + recognizer.decode(stream); + } + + String text = recognizer.getResult(stream).getText(); + + System.out.printf("filename:%s\nresult:%s\n", waveFilename, text); + + stream.release(); + recognizer.release(); + } +} diff --git a/java-api-examples/run-non-streaming-decode-file-nemo.sh b/java-api-examples/run-non-streaming-decode-file-nemo.sh new file mode 100755 index 000000000..9ccfa9df7 --- /dev/null +++ b/java-api-examples/run-non-streaming-decode-file-nemo.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash + +set -ex + +if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then + mkdir -p ../build + pushd ../build + cmake \ + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ + -DBUILD_SHARED_LIBS=ON \ + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ + -DSHERPA_ONNX_ENABLE_JNI=ON \ + .. + + make -j4 + ls -lh lib + popd +fi + +if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then + pushd ../sherpa-onnx/java-api + make + popd +fi + +if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then + cmake \ + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ + -DBUILD_SHARED_LIBS=ON \ + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ + -DSHERPA_ONNX_ENABLE_JNI=ON \ + .. + + make -j4 + ls -lh lib +fi + +if [ ! -f ./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2 + tar xvf sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2 + rm sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2 +fi + +java \ + -Djava.library.path=$PWD/../build/lib \ + -cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \ + NonStreamingDecodeFileNemo.java diff --git a/java-api-examples/run-streaming-decode-file-ctc-hlg.sh b/java-api-examples/run-streaming-decode-file-ctc-hlg.sh new file mode 100755 index 000000000..5400f11b7 --- /dev/null +++ b/java-api-examples/run-streaming-decode-file-ctc-hlg.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +set -ex + +if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then + mkdir -p ../build + pushd ../build + cmake \ + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ + -DBUILD_SHARED_LIBS=ON \ + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ + -DSHERPA_ONNX_ENABLE_JNI=ON \ + .. + + make -j4 + ls -lh lib + popd +fi + +if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then + pushd ../sherpa-onnx/java-api + make + popd +fi + +if [ ! -f ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 +fi + +java \ + -Djava.library.path=$PWD/../build/lib \ + -cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \ + StreamingDecodeFileCtcHLG.java diff --git a/kotlin-api-examples/run.sh b/kotlin-api-examples/run.sh index bc9184021..310748375 100755 --- a/kotlin-api-examples/run.sh +++ b/kotlin-api-examples/run.sh @@ -69,6 +69,12 @@ function testOnlineAsr() { rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 fi + if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18 ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + fi + out_filename=test_online_asr.jar kotlinc-jvm -include-runtime -d $out_filename \ test_online_asr.kt \ @@ -160,6 +166,24 @@ function testOfflineAsr() { rm sherpa-onnx-whisper-tiny.en.tar.bz2 fi + if [ ! -f ./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2 + tar xvf sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2 + rm sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2 + fi + + if [ ! -f ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2 + tar xvf sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2 + rm sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2 + fi + + if [ ! -f ./sherpa-onnx-zipformer-multi-zh-hans-2023-9-2/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-multi-zh-hans-2023-9-2.tar.bz2 + tar xvf sherpa-onnx-zipformer-multi-zh-hans-2023-9-2.tar.bz2 + rm sherpa-onnx-zipformer-multi-zh-hans-2023-9-2.tar.bz2 + fi + out_filename=test_offline_asr.jar kotlinc-jvm -include-runtime -d $out_filename \ test_offline_asr.kt \ diff --git a/kotlin-api-examples/test_offline_asr.kt b/kotlin-api-examples/test_offline_asr.kt index d218e4b6a..a0db54f90 100644 --- a/kotlin-api-examples/test_offline_asr.kt +++ b/kotlin-api-examples/test_offline_asr.kt @@ -1,12 +1,25 @@ package com.k2fsa.sherpa.onnx fun main() { - val recognizer = createOfflineRecognizer() + val types = arrayOf(0, 2, 5, 6) + for (type in types) { + test(type) + } +} + +fun test(type: Int) { + val recognizer = createOfflineRecognizer(type) - val waveFilename = "./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs/0.wav" + val waveFilename = when (type) { + 0 -> "./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav" + 2 -> "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav" + 5 -> "./sherpa-onnx-zipformer-multi-zh-hans-2023-9-2/test_wavs/1.wav" + 6 -> "./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav" + else -> null + } val objArray = WaveReader.readWaveFromFile( - filename = waveFilename, + filename = waveFilename!!, ) val samples: FloatArray = objArray[0] as FloatArray val sampleRate: Int = objArray[1] as Int @@ -22,10 +35,10 @@ fun main() { recognizer.release() } -fun createOfflineRecognizer(): OfflineRecognizer { +fun createOfflineRecognizer(type: Int): OfflineRecognizer { val config = OfflineRecognizerConfig( featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80), - modelConfig = getOfflineModelConfig(type = 2)!!, + modelConfig = getOfflineModelConfig(type = type)!!, ) return OfflineRecognizer(config = config) diff --git a/kotlin-api-examples/test_online_asr.kt b/kotlin-api-examples/test_online_asr.kt index d6236f8af..7376ab61e 100644 --- a/kotlin-api-examples/test_online_asr.kt +++ b/kotlin-api-examples/test_online_asr.kt @@ -3,6 +3,7 @@ package com.k2fsa.sherpa.onnx fun main() { testOnlineAsr("transducer") testOnlineAsr("zipformer2-ctc") + testOnlineAsr("ctc-hlg") } fun testOnlineAsr(type: String) { @@ -11,6 +12,7 @@ fun testOnlineAsr(type: String) { featureDim = 80, ) + var ctcFstDecoderConfig = OnlineCtcFstDecoderConfig() val waveFilename: String val modelConfig: OnlineModelConfig = when (type) { "transducer" -> { @@ -40,6 +42,18 @@ fun testOnlineAsr(type: String) { debug = false, ) } + "ctc-hlg" -> { + waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/1.wav" + ctcFstDecoderConfig.graph = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst" + OnlineModelConfig( + zipformer2Ctc = OnlineZipformer2CtcModelConfig( + model = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx", + ), + tokens = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt", + numThreads = 1, + debug = false, + ) + } else -> throw IllegalArgumentException(type) } @@ -51,6 +65,7 @@ fun testOnlineAsr(type: String) { modelConfig = modelConfig, lmConfig = lmConfig, featConfig = featConfig, + ctcFstDecoderConfig=ctcFstDecoderConfig, endpointConfig = endpointConfig, enableEndpoint = true, decodingMethod = "greedy_search", diff --git a/sherpa-onnx/java-api/Makefile b/sherpa-onnx/java-api/Makefile index 7c28a278d..f93f5e795 100644 --- a/sherpa-onnx/java-api/Makefile +++ b/sherpa-onnx/java-api/Makefile @@ -14,6 +14,7 @@ java_files += OnlineParaformerModelConfig.java java_files += OnlineZipformer2CtcModelConfig.java java_files += OnlineTransducerModelConfig.java java_files += OnlineModelConfig.java +java_files += OnlineCtcFstDecoderConfig.java java_files += OnlineStream.java java_files += OnlineRecognizerConfig.java java_files += OnlineRecognizerResult.java @@ -22,6 +23,7 @@ java_files += OnlineRecognizer.java java_files += OfflineTransducerModelConfig.java java_files += OfflineParaformerModelConfig.java java_files += OfflineWhisperModelConfig.java +java_files += OfflineNemoEncDecCtcModelConfig.java java_files += OfflineModelConfig.java java_files += OfflineRecognizerConfig.java java_files += OfflineRecognizerResult.java @@ -42,10 +44,12 @@ $(info -- class files $(class_files)) all: $(out_jar) $(out_jar): $(class_files) - jar --create --verbose --file $(out_jar) -C $(out_dir) . + # jar --create --verbose --file $(out_jar) -C $(out_dir) ./ + jar cvf $(out_jar) -C $(out_dir) ./ clean: $(RM) -rfv $(out_dir) $(class_files): $(out_dir)/$(package_dir)/%.class: src/$(package_dir)/%.java - javac -d $(out_dir) --class-path $(out_dir) $< + mkdir -p build + javac -d $(out_dir) -cp $(out_dir) $< diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OfflineModelConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OfflineModelConfig.java index 9de59f7a0..c51f789a8 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OfflineModelConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OfflineModelConfig.java @@ -5,6 +5,7 @@ public class OfflineModelConfig { private final OfflineTransducerModelConfig transducer; private final OfflineParaformerModelConfig paraformer; private final OfflineWhisperModelConfig whisper; + private final OfflineNemoEncDecCtcModelConfig nemo; private final String tokens; private final int numThreads; private final boolean debug; @@ -16,6 +17,7 @@ private OfflineModelConfig(Builder builder) { this.transducer = builder.transducer; this.paraformer = builder.paraformer; this.whisper = builder.whisper; + this.nemo = builder.nemo; this.tokens = builder.tokens; this.numThreads = builder.numThreads; this.debug = builder.debug; @@ -64,6 +66,7 @@ public static class Builder { private OfflineParaformerModelConfig paraformer = OfflineParaformerModelConfig.builder().build(); private OfflineTransducerModelConfig transducer = OfflineTransducerModelConfig.builder().build(); private OfflineWhisperModelConfig whisper = OfflineWhisperModelConfig.builder().build(); + private OfflineNemoEncDecCtcModelConfig nemo = OfflineNemoEncDecCtcModelConfig.builder().build(); private String tokens = ""; private int numThreads = 1; private boolean debug = true; @@ -84,6 +87,11 @@ public Builder setParaformer(OfflineParaformerModelConfig paraformer) { return this; } + public Builder setNemo(OfflineNemoEncDecCtcModelConfig nemo) { + this.nemo = nemo; + return this; + } + public Builder setWhisper(OfflineWhisperModelConfig whisper) { this.whisper = whisper; return this; diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OfflineNemoEncDecCtcModelConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OfflineNemoEncDecCtcModelConfig.java new file mode 100644 index 000000000..d921c0332 --- /dev/null +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OfflineNemoEncDecCtcModelConfig.java @@ -0,0 +1,31 @@ +// Copyright 2024 Xiaomi Corporation +package com.k2fsa.sherpa.onnx; + +public class OfflineNemoEncDecCtcModelConfig { + private final String model; + + private OfflineNemoEncDecCtcModelConfig(Builder builder) { + this.model = builder.model; + } + + public static Builder builder() { + return new Builder(); + } + + public String getModel() { + return model; + } + + public static class Builder { + private String model = ""; + + public OfflineNemoEncDecCtcModelConfig build() { + return new OfflineNemoEncDecCtcModelConfig(this); + } + + public Builder setModel(String model) { + this.model = model; + return this; + } + } +} diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineCtcFstDecoderConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineCtcFstDecoderConfig.java new file mode 100644 index 000000000..7d8bc85ba --- /dev/null +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineCtcFstDecoderConfig.java @@ -0,0 +1,43 @@ +// Copyright 2024 Xiaomi Corporation +package com.k2fsa.sherpa.onnx; + +public class OnlineCtcFstDecoderConfig { + private final String graph; + private final int maxActive; + + private OnlineCtcFstDecoderConfig(Builder builder) { + this.graph = builder.graph; + this.maxActive = builder.maxActive; + } + + public static Builder builder() { + return new Builder(); + } + + public String getGraph() { + return graph; + } + + public float getMaxActive() { + return maxActive; + } + + public static class Builder { + private String graph = ""; + private int maxActive = 3000; + + public OnlineCtcFstDecoderConfig build() { + return new OnlineCtcFstDecoderConfig(this); + } + + public Builder setGraph(String model) { + this.graph = graph; + return this; + } + + public Builder setMaxActive(int maxActive) { + this.maxActive = maxActive; + return this; + } + } +} \ No newline at end of file diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java index 92bfaa054..e124088e2 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java @@ -6,6 +6,8 @@ public class OnlineRecognizerConfig { private final FeatureConfig featConfig; private final OnlineModelConfig modelConfig; private final OnlineLMConfig lmConfig; + + private final OnlineCtcFstDecoderConfig ctcFstDecoderConfig; private final EndpointConfig endpointConfig; private final boolean enableEndpoint; private final String decodingMethod; @@ -17,6 +19,7 @@ private OnlineRecognizerConfig(Builder builder) { this.featConfig = builder.featConfig; this.modelConfig = builder.modelConfig; this.lmConfig = builder.lmConfig; + this.ctcFstDecoderConfig = builder.ctcFstDecoderConfig; this.endpointConfig = builder.endpointConfig; this.enableEndpoint = builder.enableEndpoint; this.decodingMethod = builder.decodingMethod; @@ -37,6 +40,7 @@ public static class Builder { private FeatureConfig featConfig = FeatureConfig.builder().build(); private OnlineModelConfig modelConfig = OnlineModelConfig.builder().build(); private OnlineLMConfig lmConfig = OnlineLMConfig.builder().build(); + private OnlineCtcFstDecoderConfig ctcFstDecoderConfig = OnlineCtcFstDecoderConfig.builder().build(); private EndpointConfig endpointConfig = EndpointConfig.builder().build(); private boolean enableEndpoint = true; private String decodingMethod = "greedy_search"; @@ -63,6 +67,11 @@ public Builder setOnlineLMConfig(OnlineLMConfig lmConfig) { return this; } + public Builder setCtcFstDecoderConfig(OnlineCtcFstDecoderConfig ctcFstDecoderConfig) { + this.ctcFstDecoderConfig = ctcFstDecoderConfig; + return this; + } + public Builder setEndpointConfig(EndpointConfig endpointConfig) { this.endpointConfig = endpointConfig; return this; diff --git a/sherpa-onnx/jni/offline-recognizer.cc b/sherpa-onnx/jni/offline-recognizer.cc index 0103417a2..0bbcefd49 100644 --- a/sherpa-onnx/jni/offline-recognizer.cc +++ b/sherpa-onnx/jni/offline-recognizer.cc @@ -147,6 +147,19 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { ans.model_config.whisper.tail_paddings = env->GetIntField(whisper_config, fid); + fid = env->GetFieldID( + model_config_cls, "nemo", + "Lcom/k2fsa/sherpa/onnx/OfflineNemoEncDecCtcModelConfig;"); + jobject nemo_config = env->GetObjectField(model_config, fid); + jclass nemo_config_cls = env->GetObjectClass(nemo_config); + + fid = env->GetFieldID(paraformer_config_cls, "model", "Ljava/lang/String;"); + + s = (jstring)env->GetObjectField(nemo_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.nemo_ctc.model = p; + env->ReleaseStringUTFChars(s, p); + return ans; } diff --git a/sherpa-onnx/jni/online-recognizer.cc b/sherpa-onnx/jni/online-recognizer.cc index 8fa069c05..c88cfd6df 100644 --- a/sherpa-onnx/jni/online-recognizer.cc +++ b/sherpa-onnx/jni/online-recognizer.cc @@ -198,6 +198,22 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { fid = env->GetFieldID(lm_model_config_cls, "scale", "F"); ans.lm_config.scale = env->GetFloatField(lm_model_config, fid); + fid = env->GetFieldID(cls, "ctcFstDecoderConfig", + "Lcom/k2fsa/sherpa/onnx/OnlineCtcFstDecoderConfig;"); + + jobject fst_decoder_config = env->GetObjectField(config, fid); + jclass fst_decoder_config_cls = env->GetObjectClass(fst_decoder_config); + + fid = env->GetFieldID(fst_decoder_config_cls, "graph", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(fst_decoder_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.ctc_fst_decoder_config.graph = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(fst_decoder_config_cls, "maxActive", "I"); + ans.ctc_fst_decoder_config.max_active = + env->GetIntField(fst_decoder_config, fid); + return ans; } } // namespace sherpa_onnx diff --git a/sherpa-onnx/jni/wave-reader.cc b/sherpa-onnx/jni/wave-reader.cc index 489240583..a3ca55365 100644 --- a/sherpa-onnx/jni/wave-reader.cc +++ b/sherpa-onnx/jni/wave-reader.cc @@ -6,6 +6,7 @@ #include #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/jni/common.h" static jobjectArray ReadWaveImpl(JNIEnv *env, std::istream &is, diff --git a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt index a559e662b..cca5e8757 100644 --- a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt @@ -18,6 +18,10 @@ data class OfflineParaformerModelConfig( var model: String = "", ) +data class OfflineNemoEncDecCtcModelConfig( + var model: String = "", +) + data class OfflineWhisperModelConfig( var encoder: String = "", var decoder: String = "", @@ -30,6 +34,7 @@ data class OfflineModelConfig( var transducer: OfflineTransducerModelConfig = OfflineTransducerModelConfig(), var paraformer: OfflineParaformerModelConfig = OfflineParaformerModelConfig(), var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(), + var nemo: OfflineNemoEncDecCtcModelConfig = OfflineNemoEncDecCtcModelConfig(), var numThreads: Int = 1, var debug: Boolean = false, var provider: String = "cpu", @@ -216,6 +221,16 @@ fun getOfflineModelConfig(type: Int): OfflineModelConfig? { ) } + 6 -> { + val modelDir = "sherpa-onnx-nemo-ctc-en-citrinet-512" + return OfflineModelConfig( + nemo = OfflineNemoEncDecCtcModelConfig( + model = "$modelDir/model.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + ) + } + } return null } diff --git a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt index cd2629c97..894943768 100644 --- a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt @@ -45,11 +45,17 @@ data class OnlineLMConfig( var scale: Float = 0.5f, ) +data class OnlineCtcFstDecoderConfig( + var graph: String = "", + var maxActive: Int = 3000, +) + data class OnlineRecognizerConfig( var featConfig: FeatureConfig = FeatureConfig(), var modelConfig: OnlineModelConfig, var lmConfig: OnlineLMConfig = OnlineLMConfig(), + var ctcFstDecoderConfig : OnlineCtcFstDecoderConfig = OnlineCtcFstDecoderConfig(), var endpointConfig: EndpointConfig = EndpointConfig(), var enableEndpoint: Boolean = true, var decodingMethod: String = "greedy_search",