Skip to content

Commit

Permalink
Add CTC HLG decoding for JNI (#810)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Apr 25, 2024
1 parent 6686c7d commit f7b3735
Show file tree
Hide file tree
Showing 21 changed files with 429 additions and 13 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/jni.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
os: [ubuntu-latest, macos-latest, macos-14]

steps:
- uses: actions/checkout@v4
Expand All @@ -49,6 +49,11 @@ jobs:
with:
key: ${{ matrix.os }}

- name: OS info
shell: bash
run: |
uname -a
- name: Display kotlin version
shell: bash
run: |
Expand All @@ -58,6 +63,7 @@ jobs:
shell: bash
run: |
java -version
javac -help
echo "JAVA_HOME is: ${JAVA_HOME}"
- name: Run JNI test
Expand Down
22 changes: 21 additions & 1 deletion .github/workflows/run-java-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
os: [ubuntu-latest, macos-latest, macos-14]

steps:
- uses: actions/checkout@v4
Expand All @@ -50,10 +50,24 @@ jobs:
with:
key: ${{ matrix.os }}-java

- name: OS info
shell: bash
run: |
uname -a
- uses: actions/setup-java@v4
with:
distribution: 'temurin' # See 'Supported distributions' for available options
java-version: '21'

- name: Display java version
shell: bash
run: |
java -version
java -help
echo "----"
javac -version
javac -help
echo "JAVA_HOME is: ${JAVA_HOME}"
cmake --version
Expand Down Expand Up @@ -100,6 +114,9 @@ jobs:
# Delete model files to save space
rm -rf sherpa-onnx-streaming-*
./run-streaming-decode-file-ctc-hlg.sh
rm -rf sherpa-onnx-streaming-*
./run-streaming-decode-file-paraformer.sh
rm -rf sherpa-onnx-streaming-*
Expand All @@ -118,3 +135,6 @@ jobs:
./run-non-streaming-decode-file-whisper.sh
rm -rf sherpa-onnx-whisper-*
./run-non-streaming-decode-file-nemo.sh
rm -rf sherpa-onnx-nemo-*
4 changes: 0 additions & 4 deletions dotnet-examples/offline-tts-play/offline-tts-play.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@
<Nullable>enable</Nullable>
</PropertyGroup>

<PropertyGroup>
<RestoreSources>/tmp/packages;$(RestoreSources);https://api.nuget.org/v3/index.json</RestoreSources>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="CommandLineParser" Version="2.9.1" />
<PackageReference Include="org.k2fsa.sherpa.onnx" Version="*" />
Expand Down
49 changes: 49 additions & 0 deletions java-api-examples/NonStreamingDecodeFileNemo.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright 2024 Xiaomi Corporation

// This file shows how to use an offline NeMo CTC model, i.e., non-streaming NeMo CTC model,,
// to decode files.
import com.k2fsa.sherpa.onnx.*;

public class NonStreamingDecodeFileNemo {
public static void main(String[] args) {
// please refer to
// https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2
// to download model files
String model = "./sherpa-onnx-nemo-ctc-en-citrinet-512/model.int8.onnx";
String tokens = "./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt";

String waveFilename = "./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav";

WaveReader reader = new WaveReader(waveFilename);

OfflineNemoEncDecCtcModelConfig nemo =
OfflineNemoEncDecCtcModelConfig.builder().setModel(model).build();

OfflineModelConfig modelConfig =
OfflineModelConfig.builder()
.setNemo(nemo)
.setTokens(tokens)
.setNumThreads(1)
.setDebug(true)
.build();

OfflineRecognizerConfig config =
OfflineRecognizerConfig.builder()
.setOfflineModelConfig(modelConfig)
.setDecodingMethod("greedy_search")
.build();

OfflineRecognizer recognizer = new OfflineRecognizer(config);
OfflineStream stream = recognizer.createStream();
stream.acceptWaveform(reader.getSamples(), reader.getSampleRate());

recognizer.decode(stream);

String text = recognizer.getResult(stream).getText();

System.out.printf("filename:%s\nresult:%s\n", waveFilename, text);

stream.release();
recognizer.release();
}
}
2 changes: 2 additions & 0 deletions java-api-examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This directory contains examples for the JAVA API of sherpa-onnx.

```
./run-streaming-decode-file-ctc.sh
./run-streaming-decode-file-ctc-hlg.sh
./run-streaming-decode-file-paraformer.sh
./run-streaming-decode-file-transducer.sh
```
Expand All @@ -18,4 +19,5 @@ This directory contains examples for the JAVA API of sherpa-onnx.
./run-non-streaming-decode-file-paraformer.sh
./run-non-streaming-decode-file-transducer.sh
./run-non-streaming-decode-file-whisper.sh
./run-non-streaming-decode-file-nemo.sh
```
58 changes: 58 additions & 0 deletions java-api-examples/StreamingDecodeFileCtcHLG.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright 2024 Xiaomi Corporation

// This file shows how to use an online CTC model, i.e., streaming CTC model,
// to decode files.
import com.k2fsa.sherpa.onnx.*;

public class StreamingDecodeFileCtcHLG {
public static void main(String[] args) {
// please refer to
// https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
// to download model files
String model =
"./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx";
String tokens = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt";
String hlg = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst";
String waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/8k.wav";

WaveReader reader = new WaveReader(waveFilename);

OnlineZipformer2CtcModelConfig ctc =
OnlineZipformer2CtcModelConfig.builder().setModel(model).build();

OnlineModelConfig modelConfig =
OnlineModelConfig.builder()
.setZipformer2Ctc(ctc)
.setTokens(tokens)
.setNumThreads(1)
.setDebug(true)
.build();

OnlineCtcFstDecoderConfig ctcFstDecoderConfig =
OnlineCtcFstDecoderConfig.builder().setGraph("hlg").build();

OnlineRecognizerConfig config =
OnlineRecognizerConfig.builder()
.setOnlineModelConfig(modelConfig)
.setCtcFstDecoderConfig(ctcFstDecoderConfig)
.build();

OnlineRecognizer recognizer = new OnlineRecognizer(config);
OnlineStream stream = recognizer.createStream();
stream.acceptWaveform(reader.getSamples(), reader.getSampleRate());

float[] tailPaddings = new float[(int) (0.3 * reader.getSampleRate())];
stream.acceptWaveform(tailPaddings, reader.getSampleRate());

while (recognizer.isReady(stream)) {
recognizer.decode(stream);
}

String text = recognizer.getResult(stream).getText();

System.out.printf("filename:%s\nresult:%s\n", waveFilename, text);

stream.release();
recognizer.release();
}
}
51 changes: 51 additions & 0 deletions java-api-examples/run-non-streaming-decode-file-nemo.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env bash

set -ex

if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then
mkdir -p ../build
pushd ../build
cmake \
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
-DBUILD_SHARED_LIBS=ON \
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
-DSHERPA_ONNX_ENABLE_JNI=ON \
..

make -j4
ls -lh lib
popd
fi

if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then
pushd ../sherpa-onnx/java-api
make
popd
fi

if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then
cmake \
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
-DBUILD_SHARED_LIBS=ON \
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
-DSHERPA_ONNX_ENABLE_JNI=ON \
..

make -j4
ls -lh lib
fi

if [ ! -f ./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2
tar xvf sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2
rm sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2
fi

java \
-Djava.library.path=$PWD/../build/lib \
-cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \
NonStreamingDecodeFileNemo.java
36 changes: 36 additions & 0 deletions java-api-examples/run-streaming-decode-file-ctc-hlg.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/usr/bin/env bash
set -ex

if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then
mkdir -p ../build
pushd ../build
cmake \
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
-DBUILD_SHARED_LIBS=ON \
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
-DSHERPA_ONNX_ENABLE_JNI=ON \
..

make -j4
ls -lh lib
popd
fi

if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then
pushd ../sherpa-onnx/java-api
make
popd
fi

if [ ! -f ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
fi

java \
-Djava.library.path=$PWD/../build/lib \
-cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \
StreamingDecodeFileCtcHLG.java
24 changes: 24 additions & 0 deletions kotlin-api-examples/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ function testOnlineAsr() {
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
fi

if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18 ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
fi

out_filename=test_online_asr.jar
kotlinc-jvm -include-runtime -d $out_filename \
test_online_asr.kt \
Expand Down Expand Up @@ -160,6 +166,24 @@ function testOfflineAsr() {
rm sherpa-onnx-whisper-tiny.en.tar.bz2
fi

if [ ! -f ./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2
tar xvf sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2
rm sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2
fi

if [ ! -f ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
tar xvf sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
rm sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
fi

if [ ! -f ./sherpa-onnx-zipformer-multi-zh-hans-2023-9-2/tokens.txt ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-multi-zh-hans-2023-9-2.tar.bz2
tar xvf sherpa-onnx-zipformer-multi-zh-hans-2023-9-2.tar.bz2
rm sherpa-onnx-zipformer-multi-zh-hans-2023-9-2.tar.bz2
fi

out_filename=test_offline_asr.jar
kotlinc-jvm -include-runtime -d $out_filename \
test_offline_asr.kt \
Expand Down
23 changes: 18 additions & 5 deletions kotlin-api-examples/test_offline_asr.kt
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
package com.k2fsa.sherpa.onnx

fun main() {
val recognizer = createOfflineRecognizer()
val types = arrayOf(0, 2, 5, 6)
for (type in types) {
test(type)
}
}

fun test(type: Int) {
val recognizer = createOfflineRecognizer(type)

val waveFilename = "./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs/0.wav"
val waveFilename = when (type) {
0 -> "./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav"
2 -> "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
5 -> "./sherpa-onnx-zipformer-multi-zh-hans-2023-9-2/test_wavs/1.wav"
6 -> "./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav"
else -> null
}

val objArray = WaveReader.readWaveFromFile(
filename = waveFilename,
filename = waveFilename!!,
)
val samples: FloatArray = objArray[0] as FloatArray
val sampleRate: Int = objArray[1] as Int
Expand All @@ -22,10 +35,10 @@ fun main() {
recognizer.release()
}

fun createOfflineRecognizer(): OfflineRecognizer {
fun createOfflineRecognizer(type: Int): OfflineRecognizer {
val config = OfflineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
modelConfig = getOfflineModelConfig(type = 2)!!,
modelConfig = getOfflineModelConfig(type = type)!!,
)

return OfflineRecognizer(config = config)
Expand Down
Loading

0 comments on commit f7b3735

Please sign in to comment.