Skip to content

Commit

Permalink
Add streaming CTC HLG decoding for JNI.
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Apr 25, 2024
1 parent af8344c commit 7b2a3a5
Show file tree
Hide file tree
Showing 13 changed files with 199 additions and 4 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/run-java-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ jobs:
# Delete model files to save space
rm -rf sherpa-onnx-streaming-*
./run-streaming-decode-file-ctc-hlg.sh
rm -rf sherpa-onnx-streaming-*
./run-streaming-decode-file-paraformer.sh
rm -rf sherpa-onnx-streaming-*
Expand All @@ -118,3 +121,6 @@ jobs:
./run-non-streaming-decode-file-whisper.sh
rm -rf sherpa-onnx-whisper-*
./run-non-streaming-decode-file-nemo.sh
rm -rf sherpa-onnx-nemo-*
4 changes: 0 additions & 4 deletions dotnet-examples/offline-tts-play/offline-tts-play.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@
<Nullable>enable</Nullable>
</PropertyGroup>

<PropertyGroup>
<RestoreSources>/tmp/packages;$(RestoreSources);https://api.nuget.org/v3/index.json</RestoreSources>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="CommandLineParser" Version="2.9.1" />
<PackageReference Include="org.k2fsa.sherpa.onnx" Version="*" />
Expand Down
2 changes: 2 additions & 0 deletions java-api-examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This directory contains examples for the JAVA API of sherpa-onnx.

```
./run-streaming-decode-file-ctc.sh
./run-streaming-decode-file-ctc-hlg.sh
./run-streaming-decode-file-paraformer.sh
./run-streaming-decode-file-transducer.sh
```
Expand All @@ -18,4 +19,5 @@ This directory contains examples for the JAVA API of sherpa-onnx.
./run-non-streaming-decode-file-paraformer.sh
./run-non-streaming-decode-file-transducer.sh
./run-non-streaming-decode-file-whisper.sh
./run-non-streaming-decode-file-nemo.sh
```
58 changes: 58 additions & 0 deletions java-api-examples/StreamingDecodeFileCtcHLG.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright 2024 Xiaomi Corporation

// This file shows how to use an online CTC model, i.e., streaming CTC model,
// to decode files.
import com.k2fsa.sherpa.onnx.*;

public class StreamingDecodeFileCtcHLG {
public static void main(String[] args) {
// please refer to
// https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
// to download model files
String model =
"./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx";
String tokens = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt";
String hlg = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst";
String waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/8k.wav";

WaveReader reader = new WaveReader(waveFilename);

OnlineZipformer2CtcModelConfig ctc =
OnlineZipformer2CtcModelConfig.builder().setModel(model).build();

OnlineModelConfig modelConfig =
OnlineModelConfig.builder()
.setZipformer2Ctc(ctc)
.setTokens(tokens)
.setNumThreads(1)
.setDebug(true)
.build();

OnlineCtcFstDecoderConfig ctcFstDecoderConfig =
OnlineCtcFstDecoderConfig.builder().setGraph("hlg").build();

OnlineRecognizerConfig config =
OnlineRecognizerConfig.builder()
.setOnlineModelConfig(modelConfig)
.setCtcFstDecoderConfig(ctcFstDecoderConfig)
.build();

OnlineRecognizer recognizer = new OnlineRecognizer(config);
OnlineStream stream = recognizer.createStream();
stream.acceptWaveform(reader.getSamples(), reader.getSampleRate());

float[] tailPaddings = new float[(int) (0.3 * reader.getSampleRate())];
stream.acceptWaveform(tailPaddings, reader.getSampleRate());

while (recognizer.isReady(stream)) {
recognizer.decode(stream);
}

String text = recognizer.getResult(stream).getText();

System.out.printf("filename:%s\nresult:%s\n", waveFilename, text);

stream.release();
recognizer.release();
}
}
36 changes: 36 additions & 0 deletions java-api-examples/run-streaming-decode-file-ctc-hlg.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/usr/bin/env bash
set -ex

if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then
mkdir -p ../build
pushd ../build
cmake \
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
-DBUILD_SHARED_LIBS=ON \
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
-DSHERPA_ONNX_ENABLE_JNI=ON \
..

make -j4
ls -lh lib
popd
fi

if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then
pushd ../sherpa-onnx/java-api
make
popd
fi

if [ ! -f ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
fi

java \
-Djava.library.path=$PWD/../build/lib \
-cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \
StreamingDecodeFileCtcHLG.java
6 changes: 6 additions & 0 deletions kotlin-api-examples/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ function testOnlineAsr() {
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
fi

if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18 ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
fi

out_filename=test_online_asr.jar
kotlinc-jvm -include-runtime -d $out_filename \
test_online_asr.kt \
Expand Down
15 changes: 15 additions & 0 deletions kotlin-api-examples/test_online_asr.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.k2fsa.sherpa.onnx
fun main() {
testOnlineAsr("transducer")
testOnlineAsr("zipformer2-ctc")
testOnlineAsr("ctc-hlg")
}

fun testOnlineAsr(type: String) {
Expand All @@ -11,6 +12,7 @@ fun testOnlineAsr(type: String) {
featureDim = 80,
)

var ctcFstDecoderConfig = OnlineCtcFstDecoderConfig()
val waveFilename: String
val modelConfig: OnlineModelConfig = when (type) {
"transducer" -> {
Expand Down Expand Up @@ -40,6 +42,18 @@ fun testOnlineAsr(type: String) {
debug = false,
)
}
"ctc-hlg" -> {
waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/1.wav"
ctcFstDecoderConfig.graph = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst"
OnlineModelConfig(
zipformer2Ctc = OnlineZipformer2CtcModelConfig(
model = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx",
),
tokens = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt",
numThreads = 1,
debug = false,
)
}
else -> throw IllegalArgumentException(type)
}

Expand All @@ -51,6 +65,7 @@ fun testOnlineAsr(type: String) {
modelConfig = modelConfig,
lmConfig = lmConfig,
featConfig = featConfig,
ctcFstDecoderConfig=ctcFstDecoderConfig,
endpointConfig = endpointConfig,
enableEndpoint = true,
decodingMethod = "greedy_search",
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/java-api/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ java_files += OnlineParaformerModelConfig.java
java_files += OnlineZipformer2CtcModelConfig.java
java_files += OnlineTransducerModelConfig.java
java_files += OnlineModelConfig.java
java_files += OnlineCtcFstDecoderConfig.java
java_files += OnlineStream.java
java_files += OnlineRecognizerConfig.java
java_files += OnlineRecognizerResult.java
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;

public class OfflineNemoEncDecCtcModelConfig {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;

public class OnlineCtcFstDecoderConfig {
private final String graph;
private final int maxActive;

private OnlineCtcFstDecoderConfig(Builder builder) {
this.graph = builder.graph;
this.maxActive = builder.maxActive;
}

public static Builder builder() {
return new Builder();
}

public String getGraph() {
return graph;
}

public float getMaxActive() {
return maxActive;
}

public static class Builder {
private String graph = "";
private int maxActive = 3000;

public OnlineCtcFstDecoderConfig build() {
return new OnlineCtcFstDecoderConfig(this);
}

public Builder setGraph(String model) {
this.graph = graph;
return this;
}

public Builder setMaxActive(int maxActive) {
this.maxActive = maxActive;
return this;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ public class OnlineRecognizerConfig {
private final FeatureConfig featConfig;
private final OnlineModelConfig modelConfig;
private final OnlineLMConfig lmConfig;

private final OnlineCtcFstDecoderConfig ctcFstDecoderConfig;
private final EndpointConfig endpointConfig;
private final boolean enableEndpoint;
private final String decodingMethod;
Expand All @@ -17,6 +19,7 @@ private OnlineRecognizerConfig(Builder builder) {
this.featConfig = builder.featConfig;
this.modelConfig = builder.modelConfig;
this.lmConfig = builder.lmConfig;
this.ctcFstDecoderConfig = builder.ctcFstDecoderConfig;
this.endpointConfig = builder.endpointConfig;
this.enableEndpoint = builder.enableEndpoint;
this.decodingMethod = builder.decodingMethod;
Expand All @@ -37,6 +40,7 @@ public static class Builder {
private FeatureConfig featConfig = FeatureConfig.builder().build();
private OnlineModelConfig modelConfig = OnlineModelConfig.builder().build();
private OnlineLMConfig lmConfig = OnlineLMConfig.builder().build();
private OnlineCtcFstDecoderConfig ctcFstDecoderConfig = OnlineCtcFstDecoderConfig.builder().build();
private EndpointConfig endpointConfig = EndpointConfig.builder().build();
private boolean enableEndpoint = true;
private String decodingMethod = "greedy_search";
Expand All @@ -63,6 +67,11 @@ public Builder setOnlineLMConfig(OnlineLMConfig lmConfig) {
return this;
}

public Builder setCtcFstDecoderConfig(OnlineCtcFstDecoderConfig ctcFstDecoderConfig) {
this.ctcFstDecoderConfig = ctcFstDecoderConfig;
return this;
}

public Builder setEndpointConfig(EndpointConfig endpointConfig) {
this.endpointConfig = endpointConfig;
return this;
Expand Down
16 changes: 16 additions & 0 deletions sherpa-onnx/jni/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,22 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(lm_model_config_cls, "scale", "F");
ans.lm_config.scale = env->GetFloatField(lm_model_config, fid);

fid = env->GetFieldID(cls, "ctcFstDecoderConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineCtcFstDecoderConfig;");

jobject fst_decoder_config = env->GetObjectField(config, fid);
jclass fst_decoder_config_cls = env->GetObjectClass(fst_decoder_config);

fid = env->GetFieldID(fst_decoder_config_cls, "graph", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(fst_decoder_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.ctc_fst_decoder_config.graph = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(fst_decoder_config_cls, "maxActive", "I");
ans.ctc_fst_decoder_config.max_active =
env->GetIntField(fst_decoder_config, fid);

return ans;
}
} // namespace sherpa_onnx
Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/kotlin-api/OnlineRecognizer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,17 @@ data class OnlineLMConfig(
var scale: Float = 0.5f,
)

data class OnlineCtcFstDecoderConfig(
var graph: String = "",
var maxActive: Int = 3000,
)


data class OnlineRecognizerConfig(
var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OnlineModelConfig,
var lmConfig: OnlineLMConfig = OnlineLMConfig(),
var ctcFstDecoderConfig : OnlineCtcFstDecoderConfig = OnlineCtcFstDecoderConfig(),
var endpointConfig: EndpointConfig = EndpointConfig(),
var enableEndpoint: Boolean = true,
var decodingMethod: String = "greedy_search",
Expand Down

0 comments on commit 7b2a3a5

Please sign in to comment.