diff --git a/.github/workflows/run-java-test.yaml b/.github/workflows/run-java-test.yaml index e974f29a8..69d0c0c4e 100644 --- a/.github/workflows/run-java-test.yaml +++ b/.github/workflows/run-java-test.yaml @@ -11,6 +11,7 @@ on: - 'java-api-examples/**' - 'sherpa-onnx/csrc/*' - 'sherpa-onnx/jni/*' + - 'sherpa-onnx/java-api/**' pull_request: branches: - master @@ -21,6 +22,7 @@ on: - 'java-api-examples/**' - 'sherpa-onnx/csrc/*' - 'sherpa-onnx/jni/*' + - 'sherpa-onnx/java-api/**' workflow_dispatch: concurrency: @@ -46,7 +48,7 @@ jobs: - name: ccache uses: hendrikmuhs/ccache-action@v1.2 with: - key: ${{ matrix.os }} + key: ${{ matrix.os }}-java - name: Display java version shell: bash @@ -54,6 +56,42 @@ jobs: java -version echo "JAVA_HOME is: ${JAVA_HOME}" + cmake --version + + - name: Build sherpa-onnx (jar) + shell: bash + run: | + cd sherpa-onnx/java-api/ + make + ls -lh + + - uses: actions/upload-artifact@v4 + with: + name: sherpa-onnx-jar-${{ matrix.os }} + path: sherpa-onnx/java-api/build + + - name: Build sherpa-onnx (C++) + shell: bash + run: | + export CMAKE_CXX_COMPILER_LAUNCHER=ccache + export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH" + + mkdir build + cd build + + cmake \ + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ + -DBUILD_SHARED_LIBS=ON \ + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ + -DSHERPA_ONNX_ENABLE_BINARY=OFF \ + -DSHERPA_ONNX_ENABLE_JNI=ON \ + .. 
+ + make -j4 + ls -lh lib + - name: Run java test shell: bash run: | @@ -62,4 +100,12 @@ jobs: cmake --version cd ./java-api-examples - ./runtest.sh + ./run-streaming-decode-file-ctc.sh + # Delete model files to save space + rm -rf sherpa-onnx-streaming-* + + ./run-streaming-decode-file-paraformer.sh + rm -rf sherpa-onnx-streaming-* + + ./run-streaming-decode-file-transducer.sh + rm -rf sherpa-onnx-streaming-* diff --git a/.github/workflows/test-build-wheel.yaml b/.github/workflows/test-build-wheel.yaml index 2c6070693..91eaea01a 100644 --- a/.github/workflows/test-build-wheel.yaml +++ b/.github/workflows/test-build-wheel.yaml @@ -66,11 +66,11 @@ jobs: - os: macos-14 python-version: "3.12" - - os: windows-2019 + - os: windows-2022 python-version: "3.7" - - os: windows-2019 + - os: windows-2022 python-version: "3.8" - - os: windows-2019 + - os: windows-2022 python-version: "3.9" - os: windows-2022 diff --git a/java-api-examples/.gitignore b/java-api-examples/.gitignore index 0c17ace06..c9db6d42e 100644 --- a/java-api-examples/.gitignore +++ b/java-api-examples/.gitignore @@ -1,2 +1,3 @@ lib hs_err* +!run-streaming*.sh diff --git a/java-api-examples/Makefile b/java-api-examples/Makefile deleted file mode 100755 index 619d6b43e..000000000 --- a/java-api-examples/Makefile +++ /dev/null @@ -1,101 +0,0 @@ -ENTRY_POINT = ./ - -LIB_SRC_DIR := ../sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx - -LIB_FILES = \ - $(LIB_SRC_DIR)/EndpointRule.java \ - $(LIB_SRC_DIR)/EndpointConfig.java \ - $(LIB_SRC_DIR)/FeatureConfig.java \ - $(LIB_SRC_DIR)/OnlineLMConfig.java \ - $(LIB_SRC_DIR)/OnlineTransducerModelConfig.java \ - $(LIB_SRC_DIR)/OnlineParaformerModelConfig.java \ - $(LIB_SRC_DIR)/OnlineZipformer2CtcModelConfig.java \ - $(LIB_SRC_DIR)/OnlineModelConfig.java \ - $(LIB_SRC_DIR)/OnlineRecognizerConfig.java \ - $(LIB_SRC_DIR)/OnlineStream.java \ - $(LIB_SRC_DIR)/OnlineRecognizer.java - -WEBSOCKET_DIR:= ./src/websocketsrv -WEBSOCKET_FILES = \ - $(WEBSOCKET_DIR)/ConnectionData.java 
\ - $(WEBSOCKET_DIR)/DecoderThreadHandler.java \ - $(WEBSOCKET_DIR)/StreamThreadHandler.java \ - $(WEBSOCKET_DIR)/AsrWebsocketServer.java \ - $(WEBSOCKET_DIR)/AsrWebsocketClient.java \ - - -LIB_BUILD_DIR = ./lib - - -EXAMPLE_FILE = DecodeFile.java - -EXAMPLE_Mic = DecodeMic.java - -JAVAC = javac - -BUILD_DIR = build - - -RUNJFLAGS = -Dfile.encoding=utf-8 - -vpath %.class $(BUILD_DIR) -vpath %.java src - - -buildfile: - $(JAVAC) -cp lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 src/$(EXAMPLE_FILE) - -buildmic: - $(JAVAC) -cp lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 src/$(EXAMPLE_Mic) - -rebuild: clean all - -.PHONY: clean run downjar - -downjar: - wget https://repo1.maven.org/maven2/org/slf4j/slf4j-api/1.7.25/slf4j-api-1.7.25.jar -P ./lib/ - wget https://repo1.maven.org/maven2/org/slf4j/slf4j-simple/1.7.25/slf4j-simple-1.7.25.jar -P ./lib/ - wget https://github.com/TooTallNate/Java-WebSocket/releases/download/v1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/ - - -clean: - rm -frv $(BUILD_DIR)/* - rm -frv $(LIB_BUILD_DIR)/* - mkdir -p $(BUILD_DIR) - mkdir -p ./lib - -runfile: packjar buildfile - java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile test.wav - -runhotwords: - java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile hotwords.wav - -runmic: - java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic - -runsrv: - java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer $(shell pwd)/../build/lib/libsherpa-onnx-jni.so ./modeltest.cfg - -runclient: - java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient $(shell pwd)/../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32 - -runclienthotwords: - java -cp 
$(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient $(shell pwd)/../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./hotwords.wav 32 - -buildlib: $(LIB_FILES:.java=.class) - - -%.class: %.java - $(JAVAC) -cp $(BUILD_DIR) -d $(BUILD_DIR) -encoding UTF-8 $< - -buildwebsocket: $(WEBSOCKET_FILES:.java=.class) - - -%.class: %.java - - $(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:../lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 $< - -packjar: buildlib - jar cvfe lib/sherpaonnx.jar . -C $(BUILD_DIR) . - -all: clean buildlib packjar buildfile buildmic downjar buildwebsocket diff --git a/java-api-examples/README.md b/java-api-examples/README.md index 859768469..4427b03ec 100755 --- a/java-api-examples/README.md +++ b/java-api-examples/README.md @@ -1,193 +1,11 @@ -0.Introduction --------------- +# Introduction -Java wrapper `com.k2fsa.sherpa.onnx.OnlineRecognizer` for `sherpa-onnx`. Java is a cross-platform language; you can build jni .so lib according to your system, and then use the same java api for all your platform. -now support multiple threads for websocket server +This directory contains examples for the Java API of sherpa-onnx. -```xml -Depend on: - Openjdk 1.8 -``` - ---- - -1.Compile libsherpa-onnx-jni.so -------------------------------- - -Compile sherpa-onnx/jni/jni.cc according to your system. -Example for Ubuntu 18.04 LTS, Openjdk 1.8.0_362: - -```xml - git clone https://github.com/k2-fsa/sherpa-onnx - cd sherpa-onnx - mkdir build - cd build - cmake -DCMAKE_BUILD_TYPE=Release -DBUILD_SHARED_LIBS=ON -DSHERPA_ONNX_ENABLE_JNI=ON .. 
- make -j6 -``` - ---- - -2.Download asr model files --------------------------- - -[click here for more detail](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html) --------------------------- - -3.Config model config.cfg -------------------------- -/**change model path in config.cfg according to your env**/ -```xml - #model config - sample_rate=16000 - feature_dim=80 - rule1_min_trailing_silence=2.4 - rule2_min_trailing_silence=1.2 - rule3_min_utterance_length=20 - encoder=/sherpa-onnx/build/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx - decoder=/sherpa-onnx/build/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx - joiner=/sherpa-onnx/build/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx - tokens=/sherpa-onnx/build/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt - num_threads=4 - enable_endpoint_detection=false - decoding_method=greedy_search - max_active_paths=4 - - #websocket server config - port=8890 - #number of threads pool for network io - connection_thread_num=16 - #number of threads pool for stream - stream_thread_num=16 - #number of threads pool for decoder - decoder_thread_num=16 - #size of streams for parallel decoding - parallel_decoder_num=16 - #time(ms) idle for decoder thread when no job - decoder_time_idle=10 - #time(ms) out for connection data - deocder_time_out=3000 -``` +# Usage ---- - -4.A simple java example ------------------------ - -refer to [java_api_example](https://github.com/k2-fsa/sherpa-onnx/blob/master/java-api-examples/src/DecodeFile.java) for more detail. 
- -```java - import com.k2fsa.sherpa.onnx.OnlineRecognizer; - import com.k2fsa.sherpa.onnx.OnlineStream; - String cfgpath=appdir+"/modelconfig.cfg"; - OnlineRecognizer.setSoPath(soPath); //set so lib path - - OnlineRecognizer rcgOjb = new OnlineRecognizer(); //create a recognizer - rcgOjb = new OnlineRecognizer(cfgFile); //set model config file - CreateStream streamObj=rcgOjb.CreateStream(); //create a stream for read wav data - float[] buffer = rcgOjb.readWavFile(wavfilename); // read data from file - streamObj.acceptWaveform(buffer); // feed stream with data - streamObj.inputFinished(); // tell engine you done with all data - OnlineStream ssObj[] = new OnlineStream[1]; - while (rcgOjb.isReady(streamObj)) { // engine is ready for unprocessed data - ssObj[0] = streamObj; - rcgOjb.decodeStreams(ssObj); // decode for multiple stream - // rcgOjb.DecodeStream(streamObj); // decode for single stream - } - - String recText = "simple:" + rcgOjb.getResult(streamObj) + "\n"; - byte[] utf8Data = recText.getBytes(StandardCharsets.UTF_8); - System.out.println(new String(utf8Data)); - rcgOjb.reSet(streamObj); - rcgOjb.releaseStream(streamObj); // release stream - rcgOjb.release(); // release recognizer ``` - ---- - -5.Makefile ----------- - -OS Ubuntu 18.04 LTS -Build package path: /sherpa-onnx/java-api-examples/lib/sherpaonnx.jar - -5.1 Build - -```bash - cd sherpa-onnx/java-api-examples - make all +./run-streaming-decode-file-ctc.sh +./run-streaming-decode-file-paraformer.sh +./run-streaming-decode-file-transducer.sh ``` - -5.2 Run DecodeFile example - -```bash - make runfile -``` - -5.3 Run DecodeMic example - -```bash - make runmic -``` - ---- - -6.WebSocket Server ----------- - -support multiple threads for websocket server -6.0 Protocol for communication -1) client connect to server -```shell - ws client -> srv ws address - ws address example: ws://127.0.0.1:8889/ -``` -2) client send 16k pcm_s16le binary stream data to server -```shell - PCM sampleRate 16000 - single 
channel - sampleSize 16bit - little endian - type short -``` -3) client send "Done" text to server when all data is sent -```shell - ws_socket.send("Done") -``` -4) client will receive json message from server whenever asr engine decoded new text -```shell - json example: {"text":"甚至出现交易几乎停滞的情况","eof":false"} -``` - - -6.1 Build - -```bash - cd sherpa-onnx/java-api-examples - make all -``` - -6.2 Run srv example - -usage: AsrWebsocketServer soPath modelCfgPath - -```bash - make runsrv /**change path in Makefile according to your env**/ -``` - -6.3 Run multiple threads client example - -usage: AsrWebsocketClient soPath srvIp srvPort wavPath numThreads - -json result example: {"text":"甚至出现交易几乎停滞的情况","eof":"true"} - -```bash - make runclient /**change path in Makefile according to your env**/ -``` - -7 runtest -this script will download model, compile codes and run test -```bash - cd sherpa-onnx/java-api-examples - runtest.sh -``` \ No newline at end of file diff --git a/java-api-examples/StreamingDecodeFileCtc.java b/java-api-examples/StreamingDecodeFileCtc.java new file mode 100644 index 000000000..bb9f9121b --- /dev/null +++ b/java-api-examples/StreamingDecodeFileCtc.java @@ -0,0 +1,57 @@ +// Copyright 2022-2023 by zhaoming +// Copyright 2024 Xiaomi Corporation + +// This file shows how to use an online CTC model, i.e., streaming CTC model, +// to decode files. 
+import com.k2fsa.sherpa.onnx.*; + +public class StreamingDecodeFileCtc { + public static void main(String[] args) { + // please refer to + // https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + // to download model files + String model = + "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx"; + String tokens = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt"; + String waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/8k.wav"; + + WaveReader reader = new WaveReader(waveFilename); + System.out.println(reader.getSampleRate()); + System.out.println(reader.getSamples().length); + + OnlineZipformer2CtcModelConfig ctc = + OnlineZipformer2CtcModelConfig.builder().setModel(model).build(); + + OnlineModelConfig modelConfig = + OnlineModelConfig.builder() + .setZipformer2Ctc(ctc) + .setTokens(tokens) + .setNumThreads(1) + .setDebug(true) + .build(); + + OnlineRecognizerConfig config = + OnlineRecognizerConfig.builder() + .setOnlineModelConfig(modelConfig) + .setDecodingMethod("greedy_search") + .build(); + + OnlineRecognizer recognizer = new OnlineRecognizer(config); + OnlineStream stream = recognizer.createStream(); + stream.acceptWaveform(reader.getSamples(), reader.getSampleRate()); + + float[] tailPaddings = new float[(int) (0.3 * reader.getSampleRate())]; + stream.acceptWaveform(tailPaddings, reader.getSampleRate()); + + while (recognizer.isReady(stream)) { + recognizer.decode(stream); + } + + String text = recognizer.getResult(stream).getText(); + + System.out.printf("filename:%s\nresult:%s\n", waveFilename, text); + + stream.release(); + recognizer.release(); + } +} diff --git a/java-api-examples/StreamingDecodeFileParaformer.java b/java-api-examples/StreamingDecodeFileParaformer.java new file mode 100644 index 000000000..af44910d8 --- /dev/null +++ 
b/java-api-examples/StreamingDecodeFileParaformer.java @@ -0,0 +1,57 @@ +// Copyright 2022-2023 by zhaoming +// Copyright 2024 Xiaomi Corporation + +// This file shows how to use an online paraformer, i.e., streaming paraformer, +// to decode files. +import com.k2fsa.sherpa.onnx.*; + +public class StreamingDecodeFileParaformer { + public static void main(String[] args) { + // please refer to + // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-streaming-paraformer-bilingual-zh-en-chinese-english + // to download model files + String encoder = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx"; + String decoder = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx"; + String tokens = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt"; + String waveFilename = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/2.wav"; + + WaveReader reader = new WaveReader(waveFilename); + System.out.println(reader.getSampleRate()); + System.out.println(reader.getSamples().length); + + OnlineParaformerModelConfig paraformer = + OnlineParaformerModelConfig.builder().setEncoder(encoder).setDecoder(decoder).build(); + + OnlineModelConfig modelConfig = + OnlineModelConfig.builder() + .setParaformer(paraformer) + .setTokens(tokens) + .setNumThreads(1) + .setDebug(true) + .build(); + + OnlineRecognizerConfig config = + OnlineRecognizerConfig.builder() + .setOnlineModelConfig(modelConfig) + .setDecodingMethod("greedy_search") + .build(); + + OnlineRecognizer recognizer = new OnlineRecognizer(config); + OnlineStream stream = recognizer.createStream(); + stream.acceptWaveform(reader.getSamples(), reader.getSampleRate()); + + float[] tailPaddings = new float[(int) (0.8 * reader.getSampleRate())]; + stream.acceptWaveform(tailPaddings, reader.getSampleRate()); + + while (recognizer.isReady(stream)) { + recognizer.decode(stream); + } + + String text = 
recognizer.getResult(stream).getText(); + + System.out.printf("filename:%s\nresult:%s\n", waveFilename, text); + + stream.release(); + recognizer.release(); + } +} diff --git a/java-api-examples/StreamingDecodeFileTransducer.java b/java-api-examples/StreamingDecodeFileTransducer.java new file mode 100644 index 000000000..b1c0288ba --- /dev/null +++ b/java-api-examples/StreamingDecodeFileTransducer.java @@ -0,0 +1,67 @@ +// Copyright 2022-2023 by zhaoming +// Copyright 2024 Xiaomi Corporation + +// This file shows how to use an online transducer, i.e., streaming transducer, +// to decode files. +import com.k2fsa.sherpa.onnx.*; + +public class StreamingDecodeFileTransducer { + public static void main(String[] args) { + // please refer to + // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english + // to download model files + String encoder = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx"; + String decoder = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx"; + String joiner = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx"; + String tokens = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt"; + + String waveFilename = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav"; + + WaveReader reader = new WaveReader(waveFilename); + System.out.println(reader.getSampleRate()); + System.out.println(reader.getSamples().length); + + OnlineTransducerModelConfig transducer = + OnlineTransducerModelConfig.builder() + .setEncoder(encoder) + .setDecoder(decoder) + .setJoiner(joiner) + .build(); + + OnlineModelConfig modelConfig = + OnlineModelConfig.builder() + .setTransducer(transducer) + .setTokens(tokens) + .setNumThreads(1) + .setDebug(true) + 
.build(); + + OnlineRecognizerConfig config = + OnlineRecognizerConfig.builder() + .setOnlineModelConfig(modelConfig) + .setDecodingMethod("greedy_search") + .build(); + + OnlineRecognizer recognizer = new OnlineRecognizer(config); + OnlineStream stream = recognizer.createStream(); + stream.acceptWaveform(reader.getSamples(), reader.getSampleRate()); + + float[] tailPaddings = new float[(int) (0.8 * reader.getSampleRate())]; + stream.acceptWaveform(tailPaddings, reader.getSampleRate()); + + while (recognizer.isReady(stream)) { + recognizer.decode(stream); + } + + String text = recognizer.getResult(stream).getText(); + + System.out.printf("filename:%s\nresult:%s\n", waveFilename, text); + + stream.release(); + recognizer.release(); + } +} diff --git a/java-api-examples/modelconfig.cfg b/java-api-examples/modelconfig.cfg deleted file mode 100755 index d1ed3b2d0..000000000 --- a/java-api-examples/modelconfig.cfg +++ /dev/null @@ -1,28 +0,0 @@ -#model config -sample_rate=16000 -feature_dim=80 -rule1_min_trailing_silence=2.4 -rule2_min_trailing_silence=1.2 -rule3_min_utterance_length=20 -encoder=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx -decoder=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx -joiner=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx -tokens=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt -num_threads=4 -enable_endpoint_detection=true -decoding_method=modified_beam_search -max_active_paths=4 -hotwords_file= -hotwords_score=1.5 -lm_model= -lm_scale=0.5 -model_type=zipformer - -#websocket server config -port=8890 -connection_thread_num=16 -stream_thread_num=16 -decoder_thread_num=16 -parallel_decoder_num=16 -decoder_time_idle=200 -deocder_time_out=30000 diff --git a/java-api-examples/run-streaming-decode-file-ctc.sh b/java-api-examples/run-streaming-decode-file-ctc.sh new file mode 
100755 index 000000000..d029e1662 --- /dev/null +++ b/java-api-examples/run-streaming-decode-file-ctc.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +set -ex + +if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then + mkdir -p ../build + pushd ../build + cmake \ + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ + -DBUILD_SHARED_LIBS=ON \ + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ + -DSHERPA_ONNX_ENABLE_JNI=ON \ + .. + + make -j4 + ls -lh lib + popd +fi + +if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then + pushd ../sherpa-onnx/java-api + make + popd +fi + +if [ ! -f ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 +fi + +java \ + -Djava.library.path=$PWD/../build/lib \ + -cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \ + StreamingDecodeFileCtc.java diff --git a/java-api-examples/run-streaming-decode-file-paraformer.sh b/java-api-examples/run-streaming-decode-file-paraformer.sh new file mode 100755 index 000000000..435f80503 --- /dev/null +++ b/java-api-examples/run-streaming-decode-file-paraformer.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +set -ex + +if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then + mkdir -p ../build + pushd ../build + cmake \ + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ + -DBUILD_SHARED_LIBS=ON \ + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ + -DSHERPA_ONNX_ENABLE_JNI=ON \ + .. + + make -j4 + ls -lh lib + popd +fi + +if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then + pushd ../sherpa-onnx/java-api + make + popd +fi + +if [ ! 
-f ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 + tar xvf sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 + rm sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 +fi + +java \ + -Djava.library.path=$PWD/../build/lib \ + -cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \ + StreamingDecodeFileParaformer.java diff --git a/java-api-examples/run-streaming-decode-file-transducer.sh b/java-api-examples/run-streaming-decode-file-transducer.sh new file mode 100755 index 000000000..79a20ea1c --- /dev/null +++ b/java-api-examples/run-streaming-decode-file-transducer.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash + +set -ex + +if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then + mkdir -p ../build + pushd ../build + cmake \ + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ + -DBUILD_SHARED_LIBS=ON \ + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ + -DSHERPA_ONNX_ENABLE_JNI=ON \ + .. + + make -j4 + ls -lh lib + popd +fi + +if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then + pushd ../sherpa-onnx/java-api + make + popd +fi + +if [ ! 
-f ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 + rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 +fi + +java \ + -Djava.library.path=$PWD/../build/lib \ + -cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \ + StreamingDecodeFileTransducer.java diff --git a/java-api-examples/runtest.sh b/java-api-examples/runtest.sh deleted file mode 100755 index 82f763c6c..000000000 --- a/java-api-examples/runtest.sh +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env bash -# -# This scripts shows how to test java for sherpa-onnx -# Note: This scripts runs only on Linux and macOS - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - - - - -echo "PATH: $PATH" - - - - - -log "------------------------------------------------------------" -log "Run download model" -log "------------------------------------------------------------" - -repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 -log "Start testing ${repo_url}" -repo=$(basename $repo_url) -log "download dir is $(basename $repo_url)" -if [ ! -d $repo ];then - log "Download pretrained model and test-data from $repo_url" - - GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url - pushd $repo - git lfs pull --include "*.onnx" - ls -lh *.onnx - popd - ln -s $repo/test_wavs/0.wav hotwords.wav - -fi - -log $(pwd) - -sed -e 's?/sherpa/?'$(pwd)'/?g' modelconfig.cfg > modeltest.cfg - -log "display model cfg" -cat modeltest.cfg - -cd .. 
- -export JAVA_HOME=$(readlink -f /usr/bin/javac | sed "s:/bin/javac::") - -mkdir -p build -cd build - -cmake -DCMAKE_BUILD_TYPE=Release -DBUILD_SHARED_LIBS=ON -DSHERPA_ONNX_ENABLE_JNI=ON .. - -make -j4 -ls -lh lib - -export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH - -cd ../java-api-examples - -make all - -make runfile - -echo "礼 拜 二" > hotwords.txt - -sed -i 's/hotwords_file=/hotwords_file=hotwords.txt/g' modeltest.cfg - -make runhotwords diff --git a/java-api-examples/test.wav b/java-api-examples/test.wav deleted file mode 100644 index 256e4afd3..000000000 Binary files a/java-api-examples/test.wav and /dev/null differ diff --git a/kotlin-api-examples/run.sh b/kotlin-api-examples/run.sh index cb9c04b55..bc9184021 100755 --- a/kotlin-api-examples/run.sh +++ b/kotlin-api-examples/run.sh @@ -6,11 +6,9 @@ set -ex -cd .. -mkdir -p build -cd build - if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then + mkdir -p ../build + pushd ../build cmake \ -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ -DSHERPA_ONNX_ENABLE_TESTS=OFF \ @@ -22,12 +20,11 @@ if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa make -j4 ls -lh lib + popd fi -export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH +export LD_LIBRARY_PATH=$PWD/../build/lib:$LD_LIBRARY_PATH -cd ../kotlin-api-examples - function testSpeakerEmbeddingExtractor() { if [ ! 
-f ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx ]; then curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx diff --git a/sherpa-onnx/csrc/online-websocket-client.cc b/sherpa-onnx/csrc/online-websocket-client.cc index 62a6832b8..351bd6acc 100644 --- a/sherpa-onnx/csrc/online-websocket-client.cc +++ b/sherpa-onnx/csrc/online-websocket-client.cc @@ -253,7 +253,7 @@ int32_t main(int32_t argc, char *argv[]) { sherpa_onnx::ReadWave(wave_filename, &actual_sample_rate, &is_ok); if (!is_ok) { - SHERPA_ONNX_LOGE("Failed to read %s", wave_filename.c_str()); + SHERPA_ONNX_LOGE("Failed to read '%s'", wave_filename.c_str()); return -1; } diff --git a/sherpa-onnx/csrc/sherpa-onnx-alsa-offline-speaker-identification.cc b/sherpa-onnx/csrc/sherpa-onnx-alsa-offline-speaker-identification.cc index f4702a836..da00ea626 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-alsa-offline-speaker-identification.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-alsa-offline-speaker-identification.cc @@ -96,7 +96,7 @@ static std::vector> ComputeEmbeddings( sherpa_onnx::ReadWave(f, &sampling_rate, &is_ok); if (!is_ok) { - fprintf(stderr, "Failed to read %s\n", f.c_str()); + fprintf(stderr, "Failed to read '%s'\n", f.c_str()); exit(-1); } diff --git a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc index 72053744d..77d5cd4ef 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc @@ -78,7 +78,7 @@ for a list of pre-trained models to download. 
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); if (!is_ok) { - fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); + fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str()); return -1; } diff --git a/sherpa-onnx/csrc/sherpa-onnx-microphone-offline-speaker-identification.cc b/sherpa-onnx/csrc/sherpa-onnx-microphone-offline-speaker-identification.cc index 769e8b5ff..d3518c090 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-microphone-offline-speaker-identification.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-microphone-offline-speaker-identification.cc @@ -93,7 +93,7 @@ static std::vector> ComputeEmbeddings( sherpa_onnx::ReadWave(f, &sampling_rate, &is_ok); if (!is_ok) { - fprintf(stderr, "Failed to read %s\n", f.c_str()); + fprintf(stderr, "Failed to read '%s'\n", f.c_str()); exit(-1); } diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc index 862818f5c..9367b017e 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-audio-tagging.cc @@ -58,7 +58,7 @@ for more models. sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); if (!is_ok) { - fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); + fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str()); return -1; } diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc index 83756621d..ab5dcc7e7 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc @@ -73,7 +73,7 @@ for a list of pre-trained models to download. 
const std::vector samples = sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); if (!is_ok) { - fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); + fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str()); return -1; } float duration = samples.size() / static_cast(sampling_rate); diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-parallel.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-parallel.cc index 8e4c4ffa4..63409395d 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-parallel.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-parallel.cc @@ -69,7 +69,7 @@ void AsrInference(const std::vector> &chunk_wav_paths, const std::vector samples = sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); if (!is_ok) { - fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); + fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str()); continue; } duration += samples.size() / static_cast(sampling_rate); @@ -96,7 +96,7 @@ void AsrInference(const std::vector> &chunk_wav_paths, const std::vector samples = sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); if (!is_ok) { - fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); + fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str()); continue; } duration += samples.size() / static_cast(sampling_rate); diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline.cc b/sherpa-onnx/csrc/sherpa-onnx-offline.cc index a84266c72..73e77299a 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline.cc @@ -124,7 +124,7 @@ for a list of pre-trained models to download. 
const std::vector samples = sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); if (!is_ok) { - fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); + fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str()); return -1; } duration += samples.size() / static_cast(sampling_rate); diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index 89a21e239..cc2a78392 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -109,7 +109,7 @@ for a list of pre-trained models to download. sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); if (!is_ok) { - fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); + fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str()); return -1; } diff --git a/sherpa-onnx/java-api/.gitignore b/sherpa-onnx/java-api/.gitignore index 4934677e0..4c8b70e05 100644 --- a/sherpa-onnx/java-api/.gitignore +++ b/sherpa-onnx/java-api/.gitignore @@ -1,2 +1,6 @@ .idea java-api.iml +out +META-INF +build +*.jar diff --git a/sherpa-onnx/java-api/Makefile b/sherpa-onnx/java-api/Makefile new file mode 100644 index 000000000..8b4bf3377 --- /dev/null +++ b/sherpa-onnx/java-api/Makefile @@ -0,0 +1,42 @@ + +# all .class and .jar files are put inside out_dir +out_dir := build +out_jar := $(out_dir)/sherpa-onnx.jar + +package_dir := com/k2fsa/sherpa/onnx + +java_files := WaveReader.java +java_files += EndpointRule.java +java_files += EndpointConfig.java +java_files += FeatureConfig.java +java_files += OnlineLMConfig.java +java_files += OnlineParaformerModelConfig.java +java_files += OnlineZipformer2CtcModelConfig.java +java_files += OnlineTransducerModelConfig.java +java_files += OnlineModelConfig.java +java_files += OnlineStream.java +java_files += OnlineRecognizerConfig.java +java_files += OnlineRecognizerResult.java +java_files += OnlineRecognizer.java + +class_files := $(java_files:%.java=%.class) + +java_files := $(addprefix src/$(package_dir)/,$(java_files)) 
+class_files := $(addprefix $(out_dir)/$(package_dir)/,$(class_files)) + +$(info -- java files $(java_files)) +$(info --) +$(info -- class files $(class_files)) + +.phony: all clean + +all: $(out_jar) + +$(out_jar): $(class_files) + jar --create --verbose --file $(out_jar) -C $(out_dir) . + +clean: + $(RM) -rfv $(out_dir) + +$(class_files): $(out_dir)/$(package_dir)/%.class: src/$(package_dir)/%.java + javac -d $(out_dir) --class-path $(out_dir) $< diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointConfig.java index 41c1c9193..e63b142ec 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointConfig.java @@ -1,18 +1,22 @@ -/* - * // Copyright 2022-2023 by zhaoming - */ +// Copyright 2022-2023 by zhaoming +// Copyright 2024 Xiaomi Corporation package com.k2fsa.sherpa.onnx; public class EndpointConfig { + private final EndpointRule rule1; private final EndpointRule rule2; private final EndpointRule rule3; - public EndpointConfig(EndpointRule rule1, EndpointRule rule2, EndpointRule rule3) { - this.rule1 = rule1; - this.rule2 = rule2; - this.rule3 = rule3; + private EndpointConfig(Builder builder) { + this.rule1 = builder.rule1; + this.rule2 = builder.rule2; + this.rule3 = builder.rule3; + } + + public static Builder builder() { + return new Builder(); } public EndpointRule getRule1() { @@ -26,4 +30,42 @@ public EndpointRule getRule2() { public EndpointRule getRule3() { return rule3; } + + public static class Builder { + + private EndpointRule rule1 = EndpointRule.builder(). + setMustContainNonSilence(false). + setMinTrailingSilence(2.4f). + setMinUtteranceLength(0). + build(); + private EndpointRule rule2 = EndpointRule.builder(). + setMustContainNonSilence(true). + setMinTrailingSilence(1.4f). + setMinUtteranceLength(0). + build(); + private EndpointRule rule3 = EndpointRule.builder(). 
+ setMustContainNonSilence(false). + setMinTrailingSilence(0.0f). + setMinUtteranceLength(20.0f). + build(); + + public EndpointConfig build() { + return new EndpointConfig(this); + } + + public Builder setRule1(EndpointRule rule) { + this.rule1 = rule; + return this; + } + + public Builder setRule2(EndpointRule rule) { + this.rule2 = rule; + return this; + } + + public Builder setRul3(EndpointRule rule) { + this.rule3 = rule; + return this; + } + } } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointRule.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointRule.java index 7abcc7c57..97a5dbb33 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointRule.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointRule.java @@ -1,19 +1,21 @@ -/* - * // Copyright 2022-2023 by zhaoming - */ - +// Copyright 2022-2023 by zhaoming +// Copyright 2024 Xiaomi Corporation package com.k2fsa.sherpa.onnx; public class EndpointRule { + private final boolean mustContainNonSilence; private final float minTrailingSilence; private final float minUtteranceLength; - public EndpointRule( - boolean mustContainNonSilence, float minTrailingSilence, float minUtteranceLength) { - this.mustContainNonSilence = mustContainNonSilence; - this.minTrailingSilence = minTrailingSilence; - this.minUtteranceLength = minUtteranceLength; + private EndpointRule(Builder builder) { + this.mustContainNonSilence = builder.mustContainNonSilence; + this.minTrailingSilence = builder.minTrailingSilence; + this.minUtteranceLength = builder.minUtteranceLength; + } + + public static Builder builder() { + return new Builder(); } public float getMinTrailingSilence() { @@ -27,4 +29,29 @@ public float getMinUtteranceLength() { public boolean getMustContainNonSilence() { return mustContainNonSilence; } -} + + public static class Builder { + private boolean mustContainNonSilence = false; + private float minTrailingSilence = 0; + private float minUtteranceLength = 0; + + public 
EndpointRule build() { + return new EndpointRule(this); + } + + public Builder setMustContainNonSilence(boolean mustContainNonSilence) { + this.mustContainNonSilence = mustContainNonSilence; + return this; + } + + public Builder setMinTrailingSilence(float minTrailingSilence) { + this.minTrailingSilence = minTrailingSilence; + return this; + } + + public Builder setMinUtteranceLength(float minUtteranceLength) { + this.minUtteranceLength = minUtteranceLength; + return this; + } + } +} \ No newline at end of file diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/FeatureConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/FeatureConfig.java index 381c28ac6..b4ecf7624 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/FeatureConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/FeatureConfig.java @@ -1,6 +1,5 @@ -/* - * // Copyright 2022-2023 by zhaoming - */ +// Copyright 2022-2023 by zhaoming +// Copyright 2024 Xiaomi Corporation package com.k2fsa.sherpa.onnx; @@ -8,9 +7,13 @@ public class FeatureConfig { private final int sampleRate; private final int featureDim; - public FeatureConfig(int sampleRate, int featureDim) { - this.sampleRate = sampleRate; - this.featureDim = featureDim; + private FeatureConfig(Builder builder) { + this.sampleRate = builder.sampleRate; + this.featureDim = builder.featureDim; + } + + public static Builder builder() { + return new Builder(); } public int getSampleRate() { @@ -20,4 +23,23 @@ public int getSampleRate() { public int getFeatureDim() { return featureDim; } + + public static class Builder { + private int sampleRate = 16000; + private int featureDim = 80; + + public FeatureConfig build() { + return new FeatureConfig(this); + } + + public Builder setSampleRate(int sampleRate) { + this.sampleRate = sampleRate; + return this; + } + + public Builder setFeatureDim(int featureDim) { + this.featureDim = featureDim; + return this; + } + } } diff --git 
a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineLMConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineLMConfig.java index e94ca9653..a41144d73 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineLMConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineLMConfig.java @@ -1,16 +1,20 @@ -/* - * // Copyright 2022-2023 by zhaoming - */ +// Copyright 2022-2023 by zhaoming +// Copyright 2024 Xiaomi Corporation package com.k2fsa.sherpa.onnx; public class OnlineLMConfig { + private final String model; private final float scale; - public OnlineLMConfig(String model, float scale) { - this.model = model; - this.scale = scale; + private OnlineLMConfig(Builder builder) { + this.model = builder.model; + this.scale = builder.scale; + } + + public static Builder builder() { + return new Builder(); } public String getModel() { @@ -20,4 +24,23 @@ public String getModel() { public float getScale() { return scale; } -} + + public static class Builder { + private String model = ""; + private float scale = 1.0f; + + public OnlineLMConfig build() { + return new OnlineLMConfig(this); + } + + public Builder setModel(String model) { + this.model = model; + return this; + } + + public Builder setScale(float scale) { + this.scale = scale; + return this; + } + } +} \ No newline at end of file diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineModelConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineModelConfig.java index eddf73617..bf2e73dc3 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineModelConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineModelConfig.java @@ -1,36 +1,30 @@ -/* - * // Copyright 2022-2023 by zhaoming - */ +// Copyright 2022-2023 by zhaoming +// Copyright 2024 Xiaomi Corporation package com.k2fsa.sherpa.onnx; public class OnlineModelConfig { - private final OnlineParaformerModelConfig paraformer; private final OnlineTransducerModelConfig transducer; + private 
final OnlineParaformerModelConfig paraformer; private final OnlineZipformer2CtcModelConfig zipformer2Ctc; private final String tokens; private final int numThreads; private final boolean debug; - private final String provider = "cpu"; - private String modelType = ""; - - public OnlineModelConfig( - String tokens, - int numThreads, - boolean debug, - String modelType, - OnlineParaformerModelConfig paraformer, - OnlineTransducerModelConfig transducer, - OnlineZipformer2CtcModelConfig zipformer2Ctc - ) { - - this.tokens = tokens; - this.numThreads = numThreads; - this.debug = debug; - this.modelType = modelType; - this.paraformer = paraformer; - this.transducer = transducer; - this.zipformer2Ctc = zipformer2Ctc; + private final String provider; + private final String modelType; + private OnlineModelConfig(Builder builder) { + this.transducer = builder.transducer; + this.paraformer = builder.paraformer; + this.zipformer2Ctc = builder.zipformer2Ctc; + this.tokens = builder.tokens; + this.numThreads = builder.numThreads; + this.debug = builder.debug; + this.provider = builder.provider; + this.modelType = builder.modelType; + } + + public static Builder builder() { + return new Builder(); } public OnlineParaformerModelConfig getParaformer() { @@ -41,6 +35,10 @@ public OnlineTransducerModelConfig getTransducer() { return transducer; } + public OnlineZipformer2CtcModelConfig getZipformer2Ctc() { + return zipformer2Ctc; + } + public String getTokens() { return tokens; } @@ -52,4 +50,67 @@ public int getNumThreads() { public boolean getDebug() { return debug; } + + public String getProvider() { + return provider; + } + + public String getModelType() { + return modelType; + } + + public static class Builder { + private OnlineParaformerModelConfig paraformer = OnlineParaformerModelConfig.builder().build(); + private OnlineTransducerModelConfig transducer = OnlineTransducerModelConfig.builder().build(); + private OnlineZipformer2CtcModelConfig zipformer2Ctc = 
OnlineZipformer2CtcModelConfig.builder().build(); + private String tokens = ""; + private int numThreads = 1; + private boolean debug = true; + private String provider = "cpu"; + private String modelType = ""; + + public OnlineModelConfig build() { + return new OnlineModelConfig(this); + } + + public Builder setTransducer(OnlineTransducerModelConfig transducer) { + this.transducer = transducer; + return this; + } + + public Builder setParaformer(OnlineParaformerModelConfig paraformer) { + this.paraformer = paraformer; + return this; + } + + public Builder setZipformer2Ctc(OnlineZipformer2CtcModelConfig zipformer2Ctc) { + this.zipformer2Ctc = zipformer2Ctc; + return this; + } + + public Builder setTokens(String tokens) { + this.tokens = tokens; + return this; + } + + public Builder setNumThreads(int numThreads) { + this.numThreads = numThreads; + return this; + } + + public Builder setDebug(boolean debug) { + this.debug = debug; + return this; + } + + public Builder setProvider(String provider) { + this.provider = provider; + return this; + } + + public Builder setModelType(String modelType) { + this.modelType = modelType; + return this; + } + } } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineParaformerModelConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineParaformerModelConfig.java index 2f7017a04..2b02d7007 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineParaformerModelConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineParaformerModelConfig.java @@ -1,6 +1,5 @@ -/* - * // Copyright 2022-2023 by zhaoming - */ +// Copyright 2022-2023 by zhaoming +// Copyright 2024 Xiaomi Corporation package com.k2fsa.sherpa.onnx; @@ -8,9 +7,13 @@ public class OnlineParaformerModelConfig { private final String encoder; private final String decoder; - public OnlineParaformerModelConfig(String encoder, String decoder) { - this.encoder = encoder; - this.decoder = decoder; + private 
OnlineParaformerModelConfig(Builder builder) { + this.encoder = builder.encoder; + this.decoder = builder.decoder; + } + + public static Builder builder() { + return new Builder(); } public String getEncoder() { @@ -20,4 +23,23 @@ public String getEncoder() { public String getDecoder() { return decoder; } + + public static class Builder { + private String encoder = ""; + private String decoder = ""; + + public OnlineParaformerModelConfig build() { + return new OnlineParaformerModelConfig(this); + } + + public Builder setEncoder(String encoder) { + this.encoder = encoder; + return this; + } + + public Builder setDecoder(String decoder) { + this.decoder = decoder; + return this; + } + } } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java index 15f07b07a..87076a830 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java @@ -1,234 +1,21 @@ -/* - * // Copyright 2022-2023 by zhaoming - * // the online recognizer for sherpa-onnx, it can load config from a file - * // or by argument - */ -/* -usage example: - - String cfgpath=appdir+"/modelconfig.cfg"; - OnlineRecognizer.setSoPath(soPath); //set so lib path - - OnlineRecognizer rcgOjb = new OnlineRecognizer(); //create a recognizer - rcgOjb = new OnlineRecognizer(cfgFile); //set model config file - CreateStream streamObj=rcgOjb.CreateStream(); //create a stream for read wav data - float[] buffer = rcgOjb.readWavFile(wavfilename); // read data from file - streamObj.acceptWaveform(buffer); // feed stream with data - streamObj.inputFinished(); // tell engine you done with all data - OnlineStream ssObj[] = new OnlineStream[1]; - while (rcgOjb.isReady(streamObj)) { // engine is ready for unprocessed data - ssObj[0] = streamObj; - rcgOjb.decodeStreams(ssObj); // decode for multiple stream - // rcgOjb.DecodeStream(streamObj); // 
decode for single stream - } - - String recText = "simple:" + rcgOjb.getResult(streamObj) + "\n"; - byte[] utf8Data = recText.getBytes(StandardCharsets.UTF_8); - System.out.println(new String(utf8Data)); - rcgOjb.reSet(streamObj); - rcgOjb.releaseStream(streamObj); // release stream - rcgOjb.release(); // release recognizer - -*/ +// Copyright 2022-2023 by zhaoming +// Copyright 2024 Xiaomi Corporation package com.k2fsa.sherpa.onnx; -import java.io.BufferedInputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.InputStream; -import java.util.Enumeration; -import java.util.HashMap; -import java.util.Map; -import java.util.Properties; public class OnlineRecognizer { - private long ptr = 0; // this is the asr engine ptrss - - private int sampleRate = 16000; - - // load config file for OnlineRecognizer - public OnlineRecognizer(String modelCfgPath) { - Map proMap = this.readProperties(modelCfgPath); - try { - int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim()); - this.sampleRate = sampleRate; - EndpointRule rule1 = - new EndpointRule( - false, - Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()), - 0.0F); - EndpointRule rule2 = - new EndpointRule( - true, - Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()), - 0.0F); - EndpointRule rule3 = - new EndpointRule( - false, - 0.0F, - Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim())); - EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); - - OnlineParaformerModelConfig modelParaCfg = - new OnlineParaformerModelConfig( - proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim()); - OnlineTransducerModelConfig modelTranCfg = - new OnlineTransducerModelConfig( - proMap.getOrDefault("encoder", "").trim(), - proMap.getOrDefault("decoder", "").trim(), - proMap.getOrDefault("joiner", "").trim()); - OnlineZipformer2CtcModelConfig 
zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig(""); - OnlineModelConfig modelCfg = - new OnlineModelConfig( - proMap.getOrDefault("tokens", "").trim(), - Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()), - false, - proMap.getOrDefault("model_type", "zipformer").trim(), - modelParaCfg, - modelTranCfg, zipformer2CtcConfig); - FeatureConfig featConfig = - new FeatureConfig( - sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim())); - OnlineLMConfig onlineLmConfig = - new OnlineLMConfig( - proMap.getOrDefault("lm_model", "").trim(), - Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim())); - - OnlineRecognizerConfig rcgCfg = - new OnlineRecognizerConfig( - featConfig, - modelCfg, - endCfg, - onlineLmConfig, - Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()), - proMap.getOrDefault("decoding_method", "modified_beam_search").trim(), - Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()), - proMap.getOrDefault("hotwords_file", "").trim(), - Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim())); - // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 - this.ptr = createOnlineRecognizer(new Object(), rcgCfg); - - } catch (Exception e) { - System.err.println(e); - } + static { + System.loadLibrary("sherpa-onnx-jni"); } - // use for android asset_manager ANDROID_API__ >= 9 - public OnlineRecognizer(Object assetManager, String modelCfgPath) { - Map proMap = this.readProperties(modelCfgPath); - try { - int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim()); - this.sampleRate = sampleRate; - EndpointRule rule1 = - new EndpointRule( - false, - Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()), - 0.0F); - EndpointRule rule2 = - new EndpointRule( - true, - Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()), - 0.0F); - 
EndpointRule rule3 = - new EndpointRule( - false, - 0.0F, - Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim())); - EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); - OnlineParaformerModelConfig modelParaCfg = - new OnlineParaformerModelConfig( - proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim()); - OnlineTransducerModelConfig modelTranCfg = - new OnlineTransducerModelConfig( - proMap.getOrDefault("encoder", "").trim(), - proMap.getOrDefault("decoder", "").trim(), - proMap.getOrDefault("joiner", "").trim()); - OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig(""); - - OnlineModelConfig modelCfg = - new OnlineModelConfig( - proMap.getOrDefault("tokens", "").trim(), - Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()), - false, - proMap.getOrDefault("model_type", "zipformer").trim(), - modelParaCfg, - modelTranCfg, zipformer2CtcConfig); - FeatureConfig featConfig = - new FeatureConfig( - sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim())); - - OnlineLMConfig onlineLmConfig = - new OnlineLMConfig( - proMap.getOrDefault("lm_model", "").trim(), - Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim())); - - OnlineRecognizerConfig rcgCfg = - new OnlineRecognizerConfig( - featConfig, - modelCfg, - endCfg, - onlineLmConfig, - Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()), - proMap.getOrDefault("decoding_method", "modified_beam_search").trim(), - Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()), - proMap.getOrDefault("hotwords_file", "").trim(), - Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim())); - // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 - this.ptr = createOnlineRecognizer(assetManager, rcgCfg); + private long ptr = 0; // this is the asr engine ptrss - } catch (Exception e) { 
- System.err.println(e); - } - } - // set onlineRecognizer by parameter - public OnlineRecognizer( - String tokens, - String encoder, - String decoder, - String joiner, - int numThreads, - int sampleRate, - int featureDim, - boolean enableEndpointDetection, - float rule1MinTrailingSilence, - float rule2MinTrailingSilence, - float rule3MinUtteranceLength, - String decodingMethod, - String lm_model, - float lm_scale, - int maxActivePaths, - String hotwordsFile, - float hotwordsScore, - String modelType) { - this.sampleRate = sampleRate; - EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); - EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F); - EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength); - EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); - OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(encoder, decoder); - OnlineTransducerModelConfig modelTranCfg = - new OnlineTransducerModelConfig(encoder, decoder, joiner); - OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig(""); - OnlineModelConfig modelCfg = - new OnlineModelConfig(tokens, numThreads, false, modelType, modelParaCfg, modelTranCfg, zipformer2CtcConfig); - FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); - OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale); - OnlineRecognizerConfig rcgCfg = - new OnlineRecognizerConfig( - featConfig, - modelCfg, - endCfg, - onlineLmConfig, - enableEndpointDetection, - decodingMethod, - maxActivePaths, - hotwordsFile, - hotwordsScore); - // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 - this.ptr = createOnlineRecognizer(new Object(), rcgCfg); + public OnlineRecognizer(OnlineRecognizerConfig config) { + ptr = newFromFile(config); } + /* public static float[] readWavFile(String fileName) { // read data from the filename Object[] wavdata = 
readWave(fileName); @@ -238,139 +25,67 @@ public static float[] readWavFile(String fileName) { return floatData; } + */ - // load the libsherpa-onnx-jni.so lib - public static void loadSoLib(String soPath) { - // load libsherpa-onnx-jni.so lib from the path - - System.out.println("so lib path=" + soPath + "\n"); - System.load(soPath.trim()); - System.out.println("load so lib succeed\n"); - } - - public static void setSoPath(String soPath) { - OnlineRecognizer.loadSoLib(soPath); - OnlineStream.loadSoLib(soPath); - } - - private static native Object[] readWave(String fileName); // static - - private Map readProperties(String modelCfgPath) { - // read and parse config file - Properties props = new Properties(); - Map proMap = new HashMap<>(); - try { - File file = new File(modelCfgPath); - if (!file.exists()) { - System.out.println("model cfg file not exists!"); - System.exit(0); - } - InputStream in = new BufferedInputStream(new FileInputStream(modelCfgPath)); - props.load(in); - Enumeration en = props.propertyNames(); - while (en.hasMoreElements()) { - String key = (String) en.nextElement(); - String Property = props.getProperty(key); - proMap.put(key, Property); - } - - } catch (Exception e) { - e.printStackTrace(); - } - return proMap; - } - - public void decodeStream(OnlineStream s) throws Exception { - if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); - long streamPtr = s.getPtr(); - if (streamPtr == 0) throw new Exception("null exception for stream ptr"); - // when feeded samples to engine, call DecodeStream to let it process - decodeStream(this.ptr, streamPtr); - } - public void decodeStreams(OnlineStream[] ssOjb) throws Exception { - if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); - // decode for multiple streams - long[] ss = new long[ssOjb.length]; - for (int i = 0; i < ssOjb.length; i++) { - ss[i] = ssOjb[i].getPtr(); - if (ss[i] == 0) throw new Exception("null exception for stream ptr"); - } - 
decodeStreams(this.ptr, ss); + public void decode(OnlineStream s) { + decode(ptr, s.getPtr()); } - public boolean isReady(OnlineStream s) throws Exception { - // whether the engine is ready for decode - if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); - long streamPtr = s.getPtr(); - if (streamPtr == 0) throw new Exception("null exception for stream ptr"); - return isReady(this.ptr, streamPtr); - } - public String getResult(OnlineStream s) throws Exception { - // get text from the engine - if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); - long streamPtr = s.getPtr(); - if (streamPtr == 0) throw new Exception("null exception for stream ptr"); - return getResult(this.ptr, streamPtr); + public boolean isReady(OnlineStream s) { + return isReady(ptr, s.getPtr()); } - public boolean isEndpoint(OnlineStream s) throws Exception { - if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); - long streamPtr = s.getPtr(); - if (streamPtr == 0) throw new Exception("null exception for stream ptr"); - return isEndpoint(this.ptr, streamPtr); + public boolean isEndpoint(OnlineStream s) { + return isEndpoint(ptr, s.getPtr()); } - public void reSet(OnlineStream s) throws Exception { - if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); - long streamPtr = s.getPtr(); - if (streamPtr == 0) throw new Exception("null exception for stream ptr"); - reSet(this.ptr, streamPtr); + public void reset(OnlineStream s) { + reset(ptr, s.getPtr()); } - public OnlineStream createStream() throws Exception { - // create one stream for data to feed in - if (this.ptr == 0) throw new Exception("null exception for recognizer ptr"); - long streamPtr = createStream(this.ptr); - OnlineStream stream = new OnlineStream(streamPtr, this.sampleRate); - return stream; + public OnlineStream createStream() { + long p = createStream(ptr, ""); + return new OnlineStream(p); } + @Override protected void finalize() throws 
Throwable { release(); } // recognizer release, you'd better call it manually if not use anymore public void release() { - if (this.ptr == 0) return; - deleteOnlineRecognizer(this.ptr); + if (this.ptr == 0) { + return; + } + delete(this.ptr); this.ptr = 0; } - // JNI interface libsherpa-onnx-jni.so - - // stream release, you'd better call it manually if not use anymore - public void releaseStream(OnlineStream s) { - s.release(); + public OnlineRecognizerResult getResult(OnlineStream s) { + Object[] arr = getResult(ptr, s.getPtr()); + String text = (String) arr[0]; + String[] tokens = (String[]) arr[1]; + float[] timestamps = (float[]) arr[2]; + return new OnlineRecognizerResult(text, tokens, timestamps); } - private native String getResult(long ptr, long streamPtr); - private native void decodeStream(long ptr, long streamPtr); + private native void delete(long ptr); - private native void decodeStreams(long ptr, long[] ssPtr); + private native long newFromFile(OnlineRecognizerConfig config); - private native boolean isReady(long ptr, long streamPtr); - - // first parameter keep for android asset_manager ANDROID_API__ >= 9 - private native long createOnlineRecognizer(Object asset, OnlineRecognizerConfig config); + private native long createStream(long ptr, String hotwords); - private native long createStream(long ptr); + private native void reset(long ptr, long streamPtr); - private native void deleteOnlineRecognizer(long ptr); + private native void decode(long ptr, long streamPtr); private native boolean isEndpoint(long ptr, long streamPtr); - private native void reSet(long ptr, long streamPtr); -} + private native boolean isReady(long ptr, long streamPtr); + + private native Object[] getResult(long ptr, long streamPtr); +} \ No newline at end of file diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java index 74f035cb0..8a6eeadd5 100644 --- 
a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java @@ -1,66 +1,95 @@ -/* - * // Copyright 2022-2023 by zhaoming - */ - +// Copyright 2022-2023 by zhaoming +// Copyright 2024 Xiaomi Corporation package com.k2fsa.sherpa.onnx; public class OnlineRecognizerConfig { private final FeatureConfig featConfig; private final OnlineModelConfig modelConfig; - private final EndpointConfig endpointConfig; private final OnlineLMConfig lmConfig; + private final EndpointConfig endpointConfig; private final boolean enableEndpoint; private final String decodingMethod; private final int maxActivePaths; private final String hotwordsFile; private final float hotwordsScore; - - public OnlineRecognizerConfig( - FeatureConfig featConfig, - OnlineModelConfig modelConfig, - EndpointConfig endpointConfig, - OnlineLMConfig lmConfig, - boolean enableEndpoint, - String decodingMethod, - int maxActivePaths, - String hotwordsFile, - float hotwordsScore) { - this.featConfig = featConfig; - this.modelConfig = modelConfig; - this.endpointConfig = endpointConfig; - this.lmConfig = lmConfig; - this.enableEndpoint = enableEndpoint; - this.decodingMethod = decodingMethod; - this.maxActivePaths = maxActivePaths; - this.hotwordsFile = hotwordsFile; - this.hotwordsScore = hotwordsScore; + private OnlineRecognizerConfig(Builder builder) { + this.featConfig = builder.featConfig; + this.modelConfig = builder.modelConfig; + this.lmConfig = builder.lmConfig; + this.endpointConfig = builder.endpointConfig; + this.enableEndpoint = builder.enableEndpoint; + this.decodingMethod = builder.decodingMethod; + this.maxActivePaths = builder.maxActivePaths; + this.hotwordsFile = builder.hotwordsFile; + this.hotwordsScore = builder.hotwordsScore; } - public OnlineLMConfig getLmConfig() { - return lmConfig; - } - - public FeatureConfig getFeatConfig() { - return featConfig; + public static Builder builder() { + return 
new Builder(); } public OnlineModelConfig getModelConfig() { return modelConfig; } - public EndpointConfig getEndpointConfig() { - return endpointConfig; - } + public static class Builder { + private FeatureConfig featConfig = FeatureConfig.builder().build(); + private OnlineModelConfig modelConfig = OnlineModelConfig.builder().build(); + private OnlineLMConfig lmConfig = OnlineLMConfig.builder().build(); + private EndpointConfig endpointConfig = EndpointConfig.builder().build(); + private boolean enableEndpoint = true; + private String decodingMethod = "greedy_search"; + private int maxActivePaths = 4; + private String hotwordsFile = ""; + private float hotwordsScore = 1.5f; - public boolean isEnableEndpoint() { - return enableEndpoint; - } + public OnlineRecognizerConfig build() { + return new OnlineRecognizerConfig(this); + } - public String getDecodingMethod() { - return decodingMethod; - } + public Builder setFeatureConfig(FeatureConfig featConfig) { + this.featConfig = featConfig; + return this; + } + + public Builder setOnlineModelConfig(OnlineModelConfig modelConfig) { + this.modelConfig = modelConfig; + return this; + } + + public Builder setOnlineLMConfig(OnlineLMConfig lmConfig) { + this.lmConfig = lmConfig; + return this; + } + + public Builder setEndpointConfig(EndpointConfig endpointConfig) { + this.endpointConfig = endpointConfig; + return this; + } + + public Builder setEnableEndpoint(boolean enableEndpoint) { + this.enableEndpoint = enableEndpoint; + return this; + } + + public Builder setDecodingMethod(String decodingMethod) { + this.decodingMethod = decodingMethod; + return this; + } + + public Builder setMaxActivePaths(int maxActivePaths) { + this.maxActivePaths = maxActivePaths; + return this; + } + + public Builder setHotwordsFile(String hotwordsFile) { + this.hotwordsFile = hotwordsFile; + return this; + } - public int getMaxActivePaths() { - return maxActivePaths; + public Builder setHotwordsScore(float hotwordsScore) { + this.hotwordsScore 
= hotwordsScore; + return this; + } } } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerResult.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerResult.java new file mode 100644 index 000000000..468e325e9 --- /dev/null +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerResult.java @@ -0,0 +1,26 @@ +// Copyright 2024 Xiaomi Corporation +package com.k2fsa.sherpa.onnx; + +public class OnlineRecognizerResult { + private final String text; + private final String[] tokens; + private final float[] timestamps; + + public OnlineRecognizerResult(String text, String[] tokens, float[] timestamps) { + this.text = text; + this.tokens = tokens; + this.timestamps = timestamps; + } + + public String getText() { + return text; + } + + public String[] getTokens() { + return tokens; + } + + public float[] getTimestamps() { + return timestamps; + } +} diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineStream.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineStream.java index 42df01018..960144b56 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineStream.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineStream.java @@ -1,84 +1,56 @@ -/* - * // Copyright 2022-2023 by zhaoming - */ -// Stream is used for feeding data to the asr engine +// Copyright 2022-2023 by zhaoming +// Copyright 2024 Xiaomi Corporation package com.k2fsa.sherpa.onnx; public class OnlineStream { - private long ptr = 0; // this is the stream ptr + static { + System.loadLibrary("sherpa-onnx-jni"); + } - private int sampleRate = 16000; + private long ptr = 0; - // assign ptr to this stream in construction - public OnlineStream(long ptr, int sampleRate) { - this.ptr = ptr; - this.sampleRate = sampleRate; + public OnlineStream() { + this.ptr = 0; } - public static void loadSoLib(String soPath) { - // load .so lib from the path - System.load(soPath.trim()); // ("sherpa-onnx-jni-java"); + public 
OnlineStream(long ptr) { + this.ptr = ptr; } public long getPtr() { return ptr; } - public void acceptWaveform(float[] samples) throws Exception { - if (this.ptr == 0) throw new Exception("null exception for stream ptr"); + public void setPtr(long ptr) { + this.ptr = ptr; + } - // feed wave data to asr engine - acceptWaveform(this.ptr, this.sampleRate, samples); + public void acceptWaveform(float[] samples, int sampleRate) { + acceptWaveform(this.ptr, samples, sampleRate); } public void inputFinished() { - // add some tail padding - int padLen = (int) (this.sampleRate * 0.3); // 0.3 seconds at 16 kHz sample rate - float[] tailPaddings = new float[padLen]; // default value is 0 - acceptWaveform(this.ptr, this.sampleRate, tailPaddings); - - // tell the engine all data are feeded inputFinished(this.ptr); } public void release() { // stream object must be release after used - if (this.ptr == 0) return; - deleteStream(this.ptr); + if (this.ptr == 0) { + return; + } + delete(this.ptr); this.ptr = 0; } + @Override protected void finalize() throws Throwable { release(); + super.finalize(); } - public boolean isLastFrame() throws Exception { - if (this.ptr == 0) throw new Exception("null exception for stream ptr"); - return isLastFrame(this.ptr); - } - - public void reSet() throws Exception { - if (this.ptr == 0) throw new Exception("null exception for stream ptr"); - reSet(this.ptr); - } - - public int featureDim() throws Exception { - if (this.ptr == 0) throw new Exception("null exception for stream ptr"); - return featureDim(this.ptr); - } - - // JNI interface libsherpa-onnx-jni.so - private native void acceptWaveform(long ptr, int sampleRate, float[] samples); + private native void acceptWaveform(long ptr, float[] samples, int sampleRate); private native void inputFinished(long ptr); - private native void deleteStream(long ptr); - - private native int numFramesReady(long ptr); - - private native boolean isLastFrame(long ptr); - - private native void reSet(long ptr); - - 
private native int featureDim(long ptr); -} + private native void delete(long ptr); +} \ No newline at end of file diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java index 6faf5f961..80a786cb5 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java @@ -1,6 +1,5 @@ -/* - * // Copyright 2022-2023 by zhaoming - */ +// Copyright 2022-2023 by zhaoming +// Copyright 2024 Xiaomi Corporation package com.k2fsa.sherpa.onnx; @@ -9,10 +8,14 @@ public class OnlineTransducerModelConfig { private final String decoder; private final String joiner; - public OnlineTransducerModelConfig(String encoder, String decoder, String joiner) { - this.encoder = encoder; - this.decoder = decoder; - this.joiner = joiner; + private OnlineTransducerModelConfig(Builder builder) { + this.encoder = builder.encoder; + this.decoder = builder.decoder; + this.joiner = builder.joiner; + } + + public static Builder builder() { + return new Builder(); } public String getEncoder() { @@ -26,4 +29,29 @@ public String getDecoder() { public String getJoiner() { return joiner; } + + public static class Builder { + private String encoder = ""; + private String decoder = ""; + private String joiner = ""; + + public OnlineTransducerModelConfig build() { + return new OnlineTransducerModelConfig(this); + } + + public Builder setEncoder(String encoder) { + this.encoder = encoder; + return this; + } + + public Builder setDecoder(String decoder) { + this.decoder = decoder; + return this; + } + + public Builder setJoiner(String joiner) { + this.joiner = joiner; + return this; + } + } } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig.java index 
07309b501..eac706302 100644 --- a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig.java +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig.java @@ -1,14 +1,31 @@ +// Copyright 2024 Xiaomi Corporation package com.k2fsa.sherpa.onnx; public class OnlineZipformer2CtcModelConfig { private final String model; - public OnlineZipformer2CtcModelConfig(String model) { - this.model = model; + private OnlineZipformer2CtcModelConfig(Builder builder) { + this.model = builder.model; + } + + public static Builder builder() { + return new Builder(); } public String getModel() { return model; } + public static class Builder { + private String model = ""; + + public OnlineZipformer2CtcModelConfig build() { + return new OnlineZipformer2CtcModelConfig(this); + } + + public Builder setModel(String model) { + this.model = model; + return this; + } + } } diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/WaveReader.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/WaveReader.java new file mode 100644 index 000000000..e835c9db0 --- /dev/null +++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/WaveReader.java @@ -0,0 +1,29 @@ +// Copyright 2024 Xiaomi Corporation +package com.k2fsa.sherpa.onnx; + +public class WaveReader { + static { + System.loadLibrary("sherpa-onnx-jni"); + } + + private final int sampleRate; + private final float[] samples; + + // It supports only single channel, 16-bit wave file. 
+ // It will exit the program if the given file has a wrong format + public WaveReader(String filename) { + Object[] arr = readWaveFromFile(filename); + samples = (float[]) arr[0]; + sampleRate = (int) arr[1]; + } + + public int getSampleRate() { + return sampleRate; + } + + public float[] getSamples() { + return samples; + } + + private native Object[] readWaveFromFile(String filename); +} diff --git a/sherpa-onnx/jni/CMakeLists.txt b/sherpa-onnx/jni/CMakeLists.txt index 339e945a5..ae3180c59 100644 --- a/sherpa-onnx/jni/CMakeLists.txt +++ b/sherpa-onnx/jni/CMakeLists.txt @@ -21,6 +21,7 @@ set(sources speaker-embedding-manager.cc spoken-language-identification.cc voice-activity-detector.cc + wave-reader.cc ) if(SHERPA_ONNX_ENABLE_TTS) diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index e70f5e608..955b4e30a 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -8,7 +8,6 @@ #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" -#include "sherpa-onnx/csrc/wave-reader.h" #include "sherpa-onnx/csrc/wave-writer.h" #include "sherpa-onnx/jni/common.h" @@ -43,69 +42,6 @@ JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl( return ok; } -static jobjectArray ReadWaveImpl(JNIEnv *env, std::istream &is, - const char *p_filename) { - bool is_ok = false; - int32_t sampling_rate = -1; - std::vector samples = - sherpa_onnx::ReadWave(is, &sampling_rate, &is_ok); - - if (!is_ok) { - SHERPA_ONNX_LOGE("Failed to read %s", p_filename); - exit(-1); - } - - jfloatArray samples_arr = env->NewFloatArray(samples.size()); - env->SetFloatArrayRegion(samples_arr, 0, samples.size(), samples.data()); - - jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( - 2, env->FindClass("java/lang/Object"), nullptr); - - env->SetObjectArrayElement(obj_arr, 0, samples_arr); - env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, sampling_rate)); - - return obj_arr; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jobjectArray JNICALL 
-Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWaveFromFile( - JNIEnv *env, jclass /*cls*/, jstring filename) { - const char *p_filename = env->GetStringUTFChars(filename, nullptr); - std::ifstream is(p_filename, std::ios::binary); - - auto obj_arr = ReadWaveImpl(env, is, p_filename); - - env->ReleaseStringUTFChars(filename, p_filename); - - return obj_arr; -} - -SHERPA_ONNX_EXTERN_C -JNIEXPORT jobjectArray JNICALL -Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWaveFromAsset( - JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename) { - const char *p_filename = env->GetStringUTFChars(filename, nullptr); -#if __ANDROID_API__ >= 9 - AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); - if (!mgr) { - SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); - exit(-1); - } - std::vector buffer = sherpa_onnx::ReadFile(mgr, p_filename); - - std::istrstream is(buffer.data(), buffer.size()); -#else - std::ifstream is(p_filename, std::ios::binary); -#endif - - auto obj_arr = ReadWaveImpl(env, is, p_filename); - - env->ReleaseStringUTFChars(filename, p_filename); - - return obj_arr; -} - #if 0 SHERPA_ONNX_EXTERN_C JNIEXPORT void JNICALL diff --git a/sherpa-onnx/jni/wave-reader.cc b/sherpa-onnx/jni/wave-reader.cc new file mode 100644 index 000000000..489240583 --- /dev/null +++ b/sherpa-onnx/jni/wave-reader.cc @@ -0,0 +1,81 @@ +// sherpa-onnx/jni/wave-reader.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-onnx/csrc/wave-reader.h" + +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/jni/common.h" + +static jobjectArray ReadWaveImpl(JNIEnv *env, std::istream &is, + const char *p_filename) { + bool is_ok = false; + int32_t sampling_rate = -1; + std::vector samples = + sherpa_onnx::ReadWave(is, &sampling_rate, &is_ok); + + if (!is_ok) { + SHERPA_ONNX_LOGE("Failed to read '%s'", p_filename); + exit(-1); + } + + jfloatArray samples_arr = env->NewFloatArray(samples.size()); + 
env->SetFloatArrayRegion(samples_arr, 0, samples.size(), samples.data()); + + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( + 2, env->FindClass("java/lang/Object"), nullptr); + + env->SetObjectArrayElement(obj_arr, 0, samples_arr); + env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, sampling_rate)); + + return obj_arr; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWaveFromFile( + JNIEnv *env, jclass /*cls*/, jstring filename) { + const char *p_filename = env->GetStringUTFChars(filename, nullptr); + std::ifstream is(p_filename, std::ios::binary); + + auto obj_arr = ReadWaveImpl(env, is, p_filename); + + env->ReleaseStringUTFChars(filename, p_filename); + + return obj_arr; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_onnx_WaveReader_readWaveFromFile(JNIEnv *env, + jclass /*obj*/, + jstring filename) { + return Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWaveFromFile( + env, nullptr, filename); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWaveFromAsset( + JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename) { + const char *p_filename = env->GetStringUTFChars(filename, nullptr); +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + exit(-1); + } + std::vector buffer = sherpa_onnx::ReadFile(mgr, p_filename); + + std::istrstream is(buffer.data(), buffer.size()); +#else + std::ifstream is(p_filename, std::ios::binary); +#endif + + auto obj_arr = ReadWaveImpl(env, is, p_filename); + + env->ReleaseStringUTFChars(filename, p_filename); + + return obj_arr; +}