diff --git a/.github/scripts/test-nodejs-npm.sh b/.github/scripts/test-nodejs-npm.sh index 8da89ffa1..c205d3880 100755 --- a/.github/scripts/test-nodejs-npm.sh +++ b/.github/scripts/test-nodejs-npm.sh @@ -51,6 +51,13 @@ rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 node ./test-online-transducer.js rm -rf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 +curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + +node ./test-online-zipformer2-ctc.js +rm -rf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 + # offline tts curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 diff --git a/.github/scripts/test-online-ctc.sh b/.github/scripts/test-online-ctc.sh index f74ee3c3e..fa331be6f 100755 --- a/.github/scripts/test-online-ctc.sh +++ b/.github/scripts/test-online-ctc.sh @@ -13,6 +13,37 @@ echo "PATH: $PATH" which $EXE +log "------------------------------------------------------------" +log "Run streaming Zipformer2 CTC " +log "------------------------------------------------------------" + +url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +repo=$(basename -s .tar.bz2 $url) +curl -SL -O $url +tar xvf $repo.tar.bz2 +rm $repo.tar.bz2 + +log "test fp32" + +time $EXE \ + --debug=1 \ + --zipformer2-ctc-model=$repo/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ + --tokens=$repo/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav + +log "test int8" + +time $EXE \ + --debug=1 \ + --zipformer2-ctc-model=$repo/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ + --tokens=$repo/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav + + log "------------------------------------------------------------" log "Run streaming Conformer CTC from WeNet" log "------------------------------------------------------------" diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index 5491ab6fb..f63c2de66 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -8,6 +8,27 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +mkdir -p /tmp/icefall-models +dir=/tmp/icefall-models + +pushd $dir +wget -qq https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +popd +repo=$dir/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 + +python3 ./python-api-examples/online-decode-files.py \ + --tokens=$repo/tokens.txt \ + --zipformer2-ctc=$repo/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav + +python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose + +rm -rf $dir/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 + wenet_models=( sherpa-onnx-zh-wenet-aishell 
sherpa-onnx-zh-wenet-aishell2 @@ -17,8 +38,6 @@ sherpa-onnx-en-wenet-librispeech sherpa-onnx-en-wenet-gigaspeech ) -mkdir -p /tmp/icefall-models -dir=/tmp/icefall-models for name in ${wenet_models[@]}; do repo_url=https://huggingface.co/csukuangfj/$name diff --git a/.github/scripts/test-swift.sh b/.github/scripts/test-swift.sh index 6695b51fb..15642ebc5 100755 --- a/.github/scripts/test-swift.sh +++ b/.github/scripts/test-swift.sh @@ -21,6 +21,9 @@ cat /Users/fangjun/Desktop/Obama.srt ./run-tts.sh ls -lh +./run-decode-file.sh +rm decode-file +sed -i.bak '20d' ./decode-file.swift ./run-decode-file.sh ./run-decode-file-non-streaming.sh diff --git a/.github/workflows/export-wenet-to-onnx.yaml b/.github/workflows/export-wenet-to-onnx.yaml index 3ac14e8d1..10f229dc1 100644 --- a/.github/workflows/export-wenet-to-onnx.yaml +++ b/.github/workflows/export-wenet-to-onnx.yaml @@ -22,7 +22,7 @@ jobs: - uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/export-wespeaker-to-onnx.yaml b/.github/workflows/export-wespeaker-to-onnx.yaml index b35fc71f8..18485ceae 100644 --- a/.github/workflows/export-wespeaker-to-onnx.yaml +++ b/.github/workflows/export-wespeaker-to-onnx.yaml @@ -22,7 +22,7 @@ jobs: - uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/export-whisper-to-onnx.yaml b/.github/workflows/export-whisper-to-onnx.yaml index d00018616..b755b386f 100644 --- a/.github/workflows/export-whisper-to-onnx.yaml +++ b/.github/workflows/export-whisper-to-onnx.yaml @@ -24,7 +24,7 @@ jobs: - uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index e0cc9f319..3c1c16750 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -107,23 +107,23 @@ jobs: name: release-static path: build/bin/* - - name: Test offline Whisper + - name: Test online CTC shell: bash run: | export PATH=$PWD/build/bin:$PATH - export EXE=sherpa-onnx-offline - - readelf -d build/bin/sherpa-onnx-offline + export EXE=sherpa-onnx - .github/scripts/test-offline-whisper.sh + .github/scripts/test-online-ctc.sh - - name: Test online CTC + - name: Test offline Whisper shell: bash run: | export PATH=$PWD/build/bin:$PATH - export EXE=sherpa-onnx + export EXE=sherpa-onnx-offline - .github/scripts/test-online-ctc.sh + readelf -d build/bin/sherpa-onnx-offline + + .github/scripts/test-offline-whisper.sh - name: Test offline CTC shell: bash diff --git a/.github/workflows/npm.yaml b/.github/workflows/npm.yaml index 98f633584..3f96b7ccf 100644 --- a/.github/workflows/npm.yaml +++ b/.github/workflows/npm.yaml @@ -25,7 +25,7 @@ jobs: fetch-depth: 0 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/run-python-test.yaml b/.github/workflows/run-python-test.yaml index 348343fe9..ddde2ff00 100644 --- a/.github/workflows/run-python-test.yaml +++ b/.github/workflows/run-python-test.yaml @@ -55,7 +55,7 @@ jobs: key: ${{ matrix.os }}-python-${{ 
matrix.python-version }} - name: Setup Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/style_check.yaml b/.github/workflows/style_check.yaml index 79630889c..6d59c7134 100644 --- a/.github/workflows/style_check.yaml +++ b/.github/workflows/style_check.yaml @@ -49,7 +49,7 @@ jobs: fetch-depth: 0 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/test-build-wheel.yaml b/.github/workflows/test-build-wheel.yaml index c82d361c6..d4d5117a9 100644 --- a/.github/workflows/test-build-wheel.yaml +++ b/.github/workflows/test-build-wheel.yaml @@ -29,7 +29,7 @@ jobs: fetch-depth: 0 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/test-dot-net.yaml b/.github/workflows/test-dot-net.yaml index f5fabda45..940fce5c9 100644 --- a/.github/workflows/test-dot-net.yaml +++ b/.github/workflows/test-dot-net.yaml @@ -61,7 +61,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, macos-latest] + os: [ubuntu-latest, macos-latest] #, windows-latest] python-version: ["3.8"] steps: @@ -70,7 +70,7 @@ jobs: fetch-depth: 0 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} @@ -143,6 +143,7 @@ jobs: cd dotnet-examples/ cd online-decode-files + ./run-zipformer2-ctc.sh ./run-transducer.sh ./run-paraformer.sh diff --git a/.github/workflows/test-go.yaml b/.github/workflows/test-go.yaml index 1ea5888ee..92b58bb68 100644 --- a/.github/workflows/test-go.yaml +++ b/.github/workflows/test-go.yaml @@ -53,7 +53,7 @@ jobs: mkdir build cd build cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_SHARED_LIBS=ON -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF .. 
- make -j + make -j1 cp -v _deps/onnxruntime-src/lib/libonnxruntime*dylib ./lib/ cd ../scripts/go/_internal/ @@ -153,6 +153,14 @@ jobs: git lfs install + echo "Test zipformer2 CTC" + wget -qq https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + + ./run-zipformer2-ctc.sh + rm -rf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 + echo "Test transducer" git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26 ./run-transducer.sh diff --git a/.github/workflows/test-nodejs-npm.yaml b/.github/workflows/test-nodejs-npm.yaml index 75826cb9f..459ac46e4 100644 --- a/.github/workflows/test-nodejs-npm.yaml +++ b/.github/workflows/test-nodejs-npm.yaml @@ -34,7 +34,7 @@ jobs: fetch-depth: 0 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/test-nodejs.yaml b/.github/workflows/test-nodejs.yaml index 8dd944066..80e10a6ca 100644 --- a/.github/workflows/test-nodejs.yaml +++ b/.github/workflows/test-nodejs.yaml @@ -52,7 +52,7 @@ jobs: ls -lh install/lib - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/test-pip-install.yaml b/.github/workflows/test-pip-install.yaml index 587258f81..264c50271 100644 --- a/.github/workflows/test-pip-install.yaml +++ b/.github/workflows/test-pip-install.yaml @@ -40,7 +40,7 @@ jobs: fetch-depth: 0 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/test-python-offline-websocket-server.yaml b/.github/workflows/test-python-offline-websocket-server.yaml index b13be5ff1..a8415ea07 100644 --- a/.github/workflows/test-python-offline-websocket-server.yaml +++ b/.github/workflows/test-python-offline-websocket-server.yaml @@ -38,7 +38,7 @@ jobs: key: ${{ matrix.os }}-python-${{ matrix.python-version }} - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/test-python-online-websocket-server.yaml b/.github/workflows/test-python-online-websocket-server.yaml index 7ec641228..e32366990 100644 --- a/.github/workflows/test-python-online-websocket-server.yaml +++ b/.github/workflows/test-python-online-websocket-server.yaml @@ -25,7 +25,7 @@ jobs: matrix: os: [ubuntu-latest, windows-latest, macos-latest] python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] - model_type: ["transducer", "paraformer"] + model_type: ["transducer", "paraformer", "zipformer2-ctc"] steps: - uses: actions/checkout@v4 @@ -38,7 +38,7 @@ jobs: key: ${{ matrix.os }}-python-${{ matrix.python-version }} - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} @@ -57,6 +57,26 @@ jobs: python3 -m pip install --no-deps --verbose . 
python3 -m pip install websockets + - name: Start server for zipformer2 CTC models + if: matrix.model_type == 'zipformer2-ctc' + shell: bash + run: | + curl -O -L https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + + python3 ./python-api-examples/streaming_server.py \ + --zipformer2-ctc ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ + --tokens=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt & + echo "sleep 10 seconds to wait for the server to start" + sleep 10 + + - name: Start client for zipformer2 CTC models + if: matrix.model_type == 'zipformer2-ctc' + shell: bash + run: | + python3 ./python-api-examples/online-websocket-client-decode-file.py \ + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav - name: Start server for transducer models if: matrix.model_type == 'transducer' diff --git a/CMakeLists.txt b/CMakeLists.txt index 19c27189d..6992781fb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.9.4") +set(SHERPA_ONNX_VERSION "1.9.6") # Disable warning about # diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt index d34266957..e3d60a207 100644 --- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt +++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt @@ -26,9 +26,14 @@ data class OnlineParaformerModelConfig( var decoder: String = "", ) +data class OnlineZipformer2CtcModelConfig( + var model: String = "", +) + data class OnlineModelConfig( var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(), var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(), + var zipformer2Ctc: OnlineZipformer2CtcModelConfig = OnlineZipformer2CtcModelConfig(), var tokens: String, var numThreads: Int = 1, var debug: Boolean = false, diff --git a/dotnet-examples/.gitignore b/dotnet-examples/.gitignore index 1746e3269..8b3f0df85 100644 --- a/dotnet-examples/.gitignore +++ b/dotnet-examples/.gitignore @@ -1,2 +1,3 @@ bin obj +!*.sh diff --git a/dotnet-examples/online-decode-files/Program.cs b/dotnet-examples/online-decode-files/Program.cs index 72c996ca8..2e0264a89 100644 --- a/dotnet-examples/online-decode-files/Program.cs +++ b/dotnet-examples/online-decode-files/Program.cs @@ -38,6 +38,9 @@ class Options [Option("paraformer-decoder", Required = false, HelpText = "Path to paraformer decoder.onnx")] public string ParaformerDecoder { get; set; } + [Option("zipformer2-ctc", Required = false, HelpText = "Path to zipformer2 CTC onnx model")] + public string Zipformer2Ctc { get; set; } + [Option("num-threads", Required = false, Default = 1, HelpText = "Number of threads for computation")] public int NumThreads { get; set; } @@ -107,7 +110,19 @@ dotnet run \ --files ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav \ ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav -(2) Streaming Paraformer models +(2) Streaming Zipformer2 CTC models + +dotnet run -c Release \ + --tokens
./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \ + --zipformer2-ctc ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ + --files ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \ + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav \ + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000002.wav \ + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000113.wav \ + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000219.wav \ + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000351.wav + +(3) Streaming Paraformer models dotnet run \ --tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \ --paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \ @@ -121,6 +136,7 @@ dotnet run \ Please refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/index.html to download pre-trained streaming models. "; @@ -150,6 +166,8 @@ private static void Run(Options options) config.ModelConfig.Paraformer.Encoder = options.ParaformerEncoder; config.ModelConfig.Paraformer.Decoder = options.ParaformerDecoder; + config.ModelConfig.Zipformer2Ctc.Model = options.Zipformer2Ctc; + config.ModelConfig.Tokens = options.Tokens; config.ModelConfig.Provider = options.Provider; config.ModelConfig.NumThreads = options.NumThreads; diff --git a/dotnet-examples/online-decode-files/run-zipformer2-ctc.sh b/dotnet-examples/online-decode-files/run-zipformer2-ctc.sh new file mode 100755 index 000000000..910b27f36 --- /dev/null +++ b/dotnet-examples/online-decode-files/run-zipformer2-ctc.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html#sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13-chinese +# to download the model files + +if [ ! 
-d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +fi + +dotnet run -c Release \ + --tokens ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \ + --zipformer2-ctc ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ + --files ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \ + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav \ + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000002.wav \ + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000113.wav \ + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000219.wav \ + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000351.wav diff --git a/go-api-examples/.gitignore b/go-api-examples/.gitignore new file mode 100644 index 000000000..6af50295b --- /dev/null +++ b/go-api-examples/.gitignore @@ -0,0 +1 @@ +!*.sh diff --git a/go-api-examples/streaming-decode-files/main.go b/go-api-examples/streaming-decode-files/main.go index fc2922236..5ec2c7cbb 100644 --- a/go-api-examples/streaming-decode-files/main.go +++ b/go-api-examples/streaming-decode-files/main.go @@ -22,6 +22,7 @@ func main() { flag.StringVar(&config.ModelConfig.Transducer.Joiner, "joiner", "", "Path to the transducer joiner model") flag.StringVar(&config.ModelConfig.Paraformer.Encoder, "paraformer-encoder", "", "Path to the paraformer encoder model") flag.StringVar(&config.ModelConfig.Paraformer.Decoder, "paraformer-decoder", "", "Path to the paraformer decoder model") + flag.StringVar(&config.ModelConfig.Zipformer2Ctc.Model, "zipformer2-ctc", "", "Path to the zipformer2 CTC model") flag.StringVar(&config.ModelConfig.Tokens, "tokens", "", "Path to the tokens file") flag.IntVar(&config.ModelConfig.NumThreads, "num-threads", 1, "Number of threads for computing") flag.IntVar(&config.ModelConfig.Debug, "debug", 0, "Whether to show debug message") diff --git a/go-api-examples/streaming-decode-files/run-zipformer2-ctc.sh b/go-api-examples/streaming-decode-files/run-zipformer2-ctc.sh new file mode 100755 index 000000000..4b145439e --- /dev/null +++ b/go-api-examples/streaming-decode-files/run-zipformer2-ctc.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html#sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13-chinese +# to download the model +# before you run this script. 
+# +# You can switch to a different online model if you need to + +./streaming-decode-files \ + --zipformer2-ctc ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ + --tokens ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \ + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav diff --git a/kotlin-api-examples/Main.kt b/kotlin-api-examples/Main.kt index 2d6f9e6fd..220997202 100644 --- a/kotlin-api-examples/Main.kt +++ b/kotlin-api-examples/Main.kt @@ -8,7 +8,8 @@ fun callback(samples: FloatArray): Unit { fun main() { testTts() - testAsr() + testAsr("transducer") + testAsr("zipformer2-ctc") } fun testTts() { @@ -30,25 +31,43 @@ fun testTts() { audio.save(filename="test-en.wav") } -fun testAsr() { +fun testAsr(type: String) { var featConfig = FeatureConfig( sampleRate = 16000, featureDim = 80, ) - // please refer to - // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html - // to dowload pre-trained models - var modelConfig = OnlineModelConfig( - transducer = OnlineTransducerModelConfig( - encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx", - decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx", - joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx", - ), - tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt", - numThreads = 1, - debug = false, - ) + var waveFilename: String + var modelConfig: OnlineModelConfig = when (type) { + "transducer" -> { + waveFilename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav" + // please refer to + // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html + // to download pre-trained models + OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx", + decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx", + joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx", + ), + tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt", + numThreads = 1, + debug = false, + ) + } + "zipformer2-ctc" -> { + waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav" + OnlineModelConfig( + zipformer2Ctc = OnlineZipformer2CtcModelConfig( + model = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx", + ), + tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt", + numThreads = 1, + debug = false, + ) + } + else -> throw IllegalArgumentException(type) + } var endpointConfig = EndpointConfig() @@ -69,7 +88,7 @@ fun testAsr() { ) var objArray = WaveReader.readWaveFromFile( - filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav", + filename = waveFilename, ) var samples: FloatArray = objArray[0] as FloatArray var sampleRate: Int = objArray[1] as Int diff --git a/kotlin-api-examples/run.sh b/kotlin-api-examples/run.sh index 499221e7b..ddf412b20 100755 --- a/kotlin-api-examples/run.sh +++ b/kotlin-api-examples/run.sh @@ -34,6 +34,12 @@ if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 fi +if [ !
-d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +fi + if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 tar xf vits-piper-en_US-amy-low.tar.bz2 diff --git a/nodejs-examples/README.md b/nodejs-examples/README.md index 647609a23..ead8b5291 100644 --- a/nodejs-examples/README.md +++ b/nodejs-examples/README.md @@ -85,7 +85,7 @@ npm install wav naudiodon2 how to decode a file with a NeMo CTC model. In the code we use [stt_en_conformer_ctc_small](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/english.html#stt-en-conformer-ctc-small). -You can use the following command run it: +You can use the following command to run it: ```bash wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-conformer-small.tar.bz2 @@ -99,7 +99,7 @@ node ./test-offline-nemo-ctc.js how to decode a file with a non-streaming Paraformer model. In the code we use [sherpa-onnx-paraformer-zh-2023-03-28](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-03-28-chinese). -You can use the following command run it: +You can use the following command to run it: ```bash wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2 @@ -113,7 +113,7 @@ node ./test-offline-paraformer.js how to decode a file with a non-streaming transducer model. In the code we use [sherpa-onnx-zipformer-en-2023-06-26](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-zipformer-en-2023-06-26-english). -You can use the following command run it: +You can use the following command to run it: ```bash wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-en-2023-06-26.tar.bz2 @@ -126,7 +126,7 @@ node ./test-offline-transducer.js how to decode a file with a Whisper model. In the code we use [sherpa-onnx-whisper-tiny.en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html). -You can use the following command run it: +You can use the following command to run it: ```bash wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 @@ -140,7 +140,7 @@ demonstrates how to do real-time speech recognition from microphone with a streaming Paraformer model. In the code we use [sherpa-onnx-streaming-paraformer-bilingual-zh-en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-streaming-paraformer-bilingual-zh-en-chinese-english). -You can use the following command run it: +You can use the following command to run it: ```bash wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 @@ -153,7 +153,7 @@ node ./test-online-paraformer-microphone.js how to decode a file using a streaming Paraformer model. 
In the code we use [sherpa-onnx-streaming-paraformer-bilingual-zh-en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-streaming-paraformer-bilingual-zh-en-chinese-english). -You can use the following command run it: +You can use the following command to run it: ```bash wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 @@ -167,7 +167,7 @@ demonstrates how to do real-time speech recognition with microphone using a stre we use [sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english). -You can use the following command run it: +You can use the following command to run it: ```bash wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 @@ -180,7 +180,7 @@ node ./test-online-transducer-microphone.js how to decode a file using a streaming transducer model. In the code we use [sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english). -You can use the following command run it: +You can use the following command to run it: ```bash wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 @@ -188,13 +188,26 @@ tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 node ./test-online-transducer.js ``` +## ./test-online-zipformer2-ctc.js +[./test-online-zipformer2-ctc.js](./test-online-zipformer2-ctc.js) demonstrates +how to decode a file using a streaming zipformer2 CTC model. In the code +we use [sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html#sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13-chinese). + +You can use the following command to run it: + +```bash +wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +node ./test-online-zipformer2-ctc.js +``` + ## ./test-vad-microphone-offline-paraformer.js [./test-vad-microphone-offline-paraformer.js](./test-vad-microphone-offline-paraformer.js) demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad) with non-streaming Paraformer for speech recognition from microphone. -You can use the following command run it: +You can use the following command to run it: ```bash wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx @@ -209,7 +222,7 @@ node ./test-vad-microphone-offline-paraformer.js demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad) with a non-streaming transducer model for speech recognition from microphone. 
-You can use the following command run it: +You can use the following command to run it: ```bash wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx @@ -224,7 +237,7 @@ node ./test-vad-microphone-offline-transducer.js demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad) with whisper for speech recognition from microphone. -You can use the following command run it: +You can use the following command to run it: ```bash wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx @@ -238,7 +251,7 @@ node ./test-vad-microphone-offline-whisper.js [./test-vad-microphone.js](./test-vad-microphone.js) demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad). -You can use the following command run it: +You can use the following command to run it: ```bash wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx diff --git a/nodejs-examples/test-online-zipformer2-ctc.js b/nodejs-examples/test-online-zipformer2-ctc.js new file mode 100644 index 000000000..015d79449 --- /dev/null +++ b/nodejs-examples/test-online-zipformer2-ctc.js @@ -0,0 +1,97 @@ +// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang) +// +const fs = require('fs'); +const {Readable} = require('stream'); +const wav = require('wav'); + +const sherpa_onnx = require('sherpa-onnx'); + +function createRecognizer() { + const featConfig = new sherpa_onnx.FeatureConfig(); + featConfig.sampleRate = 16000; + featConfig.featureDim = 80; + + // test online recognizer + const zipformer2Ctc = new sherpa_onnx.OnlineZipformer2CtcModelConfig(); + zipformer2Ctc.model = + './sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx'; + const tokens = + './sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt'; + + const modelConfig = new sherpa_onnx.OnlineModelConfig(); + modelConfig.zipformer2Ctc = zipformer2Ctc; + modelConfig.tokens = tokens; + + const recognizerConfig = new sherpa_onnx.OnlineRecognizerConfig(); + recognizerConfig.featConfig = featConfig; + recognizerConfig.modelConfig = modelConfig; + recognizerConfig.decodingMethod = 'greedy_search'; + + recognizer = new sherpa_onnx.OnlineRecognizer(recognizerConfig); + return recognizer; +} +recognizer = createRecognizer(); +stream = recognizer.createStream(); + +const waveFilename = + './sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav'; + +const reader = new wav.Reader(); +const readable = new Readable().wrap(reader); + +function decode(samples) { + stream.acceptWaveform(recognizer.config.featConfig.sampleRate, samples); + + while (recognizer.isReady(stream)) { + recognizer.decode(stream); + } + const r = recognizer.getResult(stream); + console.log(r.text); +} + +reader.on('format', ({audioFormat, bitDepth, channels, sampleRate}) => { + if (sampleRate != recognizer.config.featConfig.sampleRate) { + throw new Error(`Only support sampleRate ${ + recognizer.config.featConfig.sampleRate}. Given ${sampleRate}`); + } + + if (audioFormat != 1) { + throw new Error(`Only support PCM format. Given ${audioFormat}`); + } + + if (channels != 1) { + throw new Error(`Only support a single channel. Given ${channels}`); + } + + if (bitDepth != 16) { + throw new Error(`Only support 16-bit samples.
Given ${bitDepth}`); + } +}); + +fs.createReadStream(waveFilename, {'highWaterMark': 4096}) + .pipe(reader) + .on('finish', function(err) { + // tail padding + const floatSamples = + new Float32Array(recognizer.config.featConfig.sampleRate * 0.5); + decode(floatSamples); + stream.free(); + recognizer.free(); + }); + +readable.on('readable', function() { + let chunk; + while ((chunk = readable.read()) != null) { + const int16Samples = new Int16Array( + chunk.buffer, chunk.byteOffset, + chunk.length / Int16Array.BYTES_PER_ELEMENT); + + const floatSamples = new Float32Array(int16Samples.length); + + for (let i = 0; i < floatSamples.length; i++) { + floatSamples[i] = int16Samples[i] / 32768.0; + } + + decode(floatSamples); + } +}); diff --git a/python-api-examples/online-decode-files.py b/python-api-examples/online-decode-files.py index 56f9dc525..d4e69046d 100755 --- a/python-api-examples/online-decode-files.py +++ b/python-api-examples/online-decode-files.py @@ -37,7 +37,20 @@ ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \ ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav -(3) Streaming Conformer CTC from WeNet +(3) Streaming Zipformer2 CTC + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +ls -lh sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 + +./python-api-examples/online-decode-files.py \ + --zipformer2-ctc=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ + --tokens=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \ + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \ + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav + +(4) Streaming Conformer CTC from WeNet GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech cd sherpa-onnx-zh-wenet-wenetspeech @@ -51,12 +64,9 @@ ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav - Please refer to -https://k2-fsa.github.io/sherpa/onnx/index.html -and -https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html -to install sherpa-onnx and to download streaming pre-trained models. +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +to download streaming pre-trained models. 
""" import argparse import time @@ -97,6 +107,12 @@ def get_args(): help="Path to the transducer joiner model", ) + parser.add_argument( + "--zipformer2-ctc", + type=str, + help="Path to the zipformer2 ctc model", + ) + parser.add_argument( "--paraformer-encoder", type=str, @@ -112,7 +128,7 @@ def get_args(): parser.add_argument( "--wenet-ctc", type=str, - help="Path to the wenet ctc model model", + help="Path to the wenet ctc model", ) parser.add_argument( @@ -275,6 +291,16 @@ def main(): hotwords_file=args.hotwords_file, hotwords_score=args.hotwords_score, ) + elif args.zipformer2_ctc: + recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc( + tokens=args.tokens, + model=args.zipformer2_ctc, + num_threads=args.num_threads, + provider=args.provider, + sample_rate=16000, + feature_dim=80, + decoding_method="greedy_search", + ) elif args.paraformer_encoder: recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer( tokens=args.tokens, diff --git a/python-api-examples/online-websocket-client-decode-file.py b/python-api-examples/online-websocket-client-decode-file.py index e0ad8d256..cbe55c860 100755 --- a/python-api-examples/online-websocket-client-decode-file.py +++ b/python-api-examples/online-websocket-client-decode-file.py @@ -25,6 +25,7 @@ import argparse import asyncio +import json import logging import wave @@ -112,7 +113,7 @@ async def receive_results(socket: websockets.WebSocketServerProtocol): async for message in socket: if message != "Done!": last_message = message - logging.info(message) + logging.info(json.loads(message)) else: break return last_message @@ -151,7 +152,7 @@ async def run( await websocket.send("Done") decoding_results = await receive_task - logging.info(f"\nFinal result is:\n{decoding_results}") + logging.info(f"\nFinal result is:\n{json.loads(decoding_results)}") async def main(): diff --git a/python-api-examples/streaming_server.py b/python-api-examples/streaming_server.py index c47ea28a9..a06a7deda 100755 --- a/python-api-examples/streaming_server.py +++ b/python-api-examples/streaming_server.py @@ -137,6 +137,12 @@ def add_model_args(parser: argparse.ArgumentParser): help="Path to the transducer joiner model.", ) + parser.add_argument( + "--zipformer2-ctc", + type=str, + help="Path to the model file from zipformer2 ctc", + ) + parser.add_argument( "--wenet-ctc", type=str, @@ -405,6 +411,20 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: rule3_min_utterance_length=args.rule3_min_utterance_length, provider=args.provider, ) + elif args.zipformer2_ctc: + recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc( + tokens=args.tokens, + model=args.zipformer2_ctc, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + enable_endpoint_detection=args.use_endpoint != 0, + rule1_min_trailing_silence=args.rule1_min_trailing_silence, + rule2_min_trailing_silence=args.rule2_min_trailing_silence, + rule3_min_utterance_length=args.rule3_min_utterance_length, + provider=args.provider, + ) elif args.wenet_ctc: recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc( tokens=args.tokens, @@ -748,6 +768,8 @@ def check_args(args): assert args.paraformer_encoder is None, args.paraformer_encoder assert args.paraformer_decoder is None, args.paraformer_decoder + assert args.zipformer2_ctc is None, args.zipformer2_ctc + assert args.wenet_ctc is None, args.wenet_ctc elif args.paraformer_encoder: assert Path( args.paraformer_encoder @@ -756,6 +778,10 @@ def check_args(args): 
assert Path( args.paraformer_decoder ).is_file(), f"{args.paraformer_decoder} does not exist" + elif args.zipformer2_ctc: + assert Path( + args.zipformer2_ctc + ).is_file(), f"{args.zipformer2_ctc} does not exist" elif args.wenet_ctc: assert Path(args.wenet_ctc).is_file(), f"{args.wenet_ctc} does not exist" else: diff --git a/scripts/dotnet/online.cs b/scripts/dotnet/online.cs index f0ca414b0..63470077e 100644 --- a/scripts/dotnet/online.cs +++ b/scripts/dotnet/online.cs @@ -50,6 +50,18 @@ public OnlineParaformerModelConfig() public string Decoder; } + [StructLayout(LayoutKind.Sequential)] + public struct OnlineZipformer2CtcModelConfig + { + public OnlineZipformer2CtcModelConfig() + { + Model = ""; + } + + [MarshalAs(UnmanagedType.LPStr)] + public string Model; + } + [StructLayout(LayoutKind.Sequential)] public struct OnlineModelConfig { @@ -57,6 +69,7 @@ public OnlineModelConfig() { Transducer = new OnlineTransducerModelConfig(); Paraformer = new OnlineParaformerModelConfig(); + Zipformer2Ctc = new OnlineZipformer2CtcModelConfig(); Tokens = ""; NumThreads = 1; Provider = "cpu"; @@ -66,6 +79,7 @@ public OnlineModelConfig() public OnlineTransducerModelConfig Transducer; public OnlineParaformerModelConfig Paraformer; + public OnlineZipformer2CtcModelConfig Zipformer2Ctc; [MarshalAs(UnmanagedType.LPStr)] public string Tokens; diff --git a/scripts/go/_internal/streaming-decode-files/run-zipformer2-ctc.sh b/scripts/go/_internal/streaming-decode-files/run-zipformer2-ctc.sh new file mode 120000 index 000000000..b2e29cca2 --- /dev/null +++ b/scripts/go/_internal/streaming-decode-files/run-zipformer2-ctc.sh @@ -0,0 +1 @@ +../../../../go-api-examples/streaming-decode-files/run-zipformer2-ctc.sh \ No newline at end of file diff --git a/scripts/go/sherpa_onnx.go b/scripts/go/sherpa_onnx.go index d07b18185..9a869e25c 100644 --- a/scripts/go/sherpa_onnx.go +++ b/scripts/go/sherpa_onnx.go @@ -65,6 +65,13 @@ type OnlineParaformerModelConfig struct { Decoder string // Path to the decoder model. } +// Please refer to +// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/index.html +// to download pre-trained models +type OnlineZipformer2CtcModelConfig struct { + Model string // Path to the onnx model +} + // Configuration for online/streaming models // // Please refer to @@ -72,13 +79,14 @@ type OnlineParaformerModelConfig struct { // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html // to download pre-trained models type OnlineModelConfig struct { - Transducer OnlineTransducerModelConfig - Paraformer OnlineParaformerModelConfig - Tokens string // Path to tokens.txt - NumThreads int // Number of threads to use for neural network computation - Provider string // Optional. Valid values are: cpu, cuda, coreml - Debug int // 1 to show model meta information while loading it. - ModelType string // Optional. You can specify it for faster model initialization + Transducer OnlineTransducerModelConfig + Paraformer OnlineParaformerModelConfig + Zipformer2Ctc OnlineZipformer2CtcModelConfig + Tokens string // Path to tokens.txt + NumThreads int // Number of threads to use for neural network computation + Provider string // Optional. Valid values are: cpu, cuda, coreml + Debug int // 1 to show model meta information while loading it. + ModelType string // Optional. 
You can specify it for faster model initialization } // Configuration for the feature extractor @@ -157,6 +165,9 @@ func NewOnlineRecognizer(config *OnlineRecognizerConfig) *OnlineRecognizer { c.model_config.paraformer.decoder = C.CString(config.ModelConfig.Paraformer.Decoder) defer C.free(unsafe.Pointer(c.model_config.paraformer.decoder)) + c.model_config.zipformer2_ctc.model = C.CString(config.ModelConfig.Zipformer2Ctc.Model) + defer C.free(unsafe.Pointer(c.model_config.zipformer2_ctc.model)) + c.model_config.tokens = C.CString(config.ModelConfig.Tokens) defer C.free(unsafe.Pointer(c.model_config.tokens)) diff --git a/scripts/nodejs/index.js b/scripts/nodejs/index.js index b61f29550..da6178b34 100644 --- a/scripts/nodejs/index.js +++ b/scripts/nodejs/index.js @@ -41,9 +41,14 @@ const SherpaOnnxOnlineParaformerModelConfig = StructType({ "decoder" : cstring, }); +const SherpaOnnxOnlineZipformer2CtcModelConfig = StructType({ + "model" : cstring, +}); + const SherpaOnnxOnlineModelConfig = StructType({ "transducer" : SherpaOnnxOnlineTransducerModelConfig, "paraformer" : SherpaOnnxOnlineParaformerModelConfig, + "zipformer2Ctc" : SherpaOnnxOnlineZipformer2CtcModelConfig, "tokens" : cstring, "numThreads" : int32_t, "provider" : cstring, @@ -663,6 +668,7 @@ const OnlineModelConfig = SherpaOnnxOnlineModelConfig; const FeatureConfig = SherpaOnnxFeatureConfig; const OnlineRecognizerConfig = SherpaOnnxOnlineRecognizerConfig; const OnlineParaformerModelConfig = SherpaOnnxOnlineParaformerModelConfig; +const OnlineZipformer2CtcModelConfig = SherpaOnnxOnlineZipformer2CtcModelConfig; // offline asr const OfflineTransducerModelConfig = SherpaOnnxOfflineTransducerModelConfig; @@ -692,6 +698,7 @@ module.exports = { OnlineRecognizer, OnlineStream, OnlineParaformerModelConfig, + OnlineZipformer2CtcModelConfig, // offline asr OfflineRecognizer, diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index b8bffab9c..5d96d6899 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -54,6 +54,9 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( recognizer_config.model_config.paraformer.decoder = SHERPA_ONNX_OR(config->model_config.paraformer.decoder, ""); + recognizer_config.model_config.zipformer2_ctc.model = + SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, ""); + recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, ""); recognizer_config.model_config.num_threads = diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 11971808c..b6c091298 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -66,9 +66,17 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineParaformerModelConfig { const char *decoder; } SherpaOnnxOnlineParaformerModelConfig; -SHERPA_ONNX_API typedef struct SherpaOnnxModelConfig { +// Please visit +// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html# +// to download pre-trained streaming zipformer2 ctc models +SHERPA_ONNX_API typedef struct SherpaOnnxOnlineZipformer2CtcModelConfig { + const char *model; +} SherpaOnnxOnlineZipformer2CtcModelConfig; + +SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig { SherpaOnnxOnlineTransducerModelConfig transducer; SherpaOnnxOnlineParaformerModelConfig paraformer; + SherpaOnnxOnlineZipformer2CtcModelConfig zipformer2_ctc; const char *tokens; int32_t num_threads; const char *provider; diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index c114e08fb..2b325b40f 100644 --- 
a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -70,6 +70,8 @@ set(sources online-wenet-ctc-model-config.cc online-wenet-ctc-model.cc online-zipformer-transducer-model.cc + online-zipformer2-ctc-model-config.cc + online-zipformer2-ctc-model.cc online-zipformer2-transducer-model.cc onnx-utils.cc packed-sequence.cc diff --git a/sherpa-onnx/csrc/online-ctc-decoder.h b/sherpa-onnx/csrc/online-ctc-decoder.h index 3e701bb37..6690e1bb2 100644 --- a/sherpa-onnx/csrc/online-ctc-decoder.h +++ b/sherpa-onnx/csrc/online-ctc-decoder.h @@ -12,6 +12,9 @@ namespace sherpa_onnx { struct OnlineCtcDecoderResult { + /// Number of frames after subsampling we have decoded so far + int32_t frame_offset = 0; + /// The decoded token IDs std::vector tokens; diff --git a/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc index 8a5a606aa..909373e71 100644 --- a/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc @@ -49,12 +49,17 @@ void OnlineCtcGreedySearchDecoder::Decode( if (y != blank_id_ && y != prev_id) { r.tokens.push_back(y); - r.timestamps.push_back(t); + r.timestamps.push_back(t + r.frame_offset); } prev_id = y; } // for (int32_t t = 0; t != num_frames; ++t) { } // for (int32_t b = 0; b != batch_size; ++b) + + // Update frame_offset + for (auto &r : *results) { + r.frame_offset += num_frames; + } } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-ctc-model.cc b/sherpa-onnx/csrc/online-ctc-model.cc index 4ec094e04..5fa76c192 100644 --- a/sherpa-onnx/csrc/online-ctc-model.cc +++ b/sherpa-onnx/csrc/online-ctc-model.cc @@ -11,127 +11,35 @@ #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-wenet-ctc-model.h" +#include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h" #include "sherpa-onnx/csrc/onnx-utils.h" -namespace { - -enum class ModelType { - kZipformerCtc, - kWenetCtc, - kUnkown, -}; - -} // namespace - namespace sherpa_onnx { -static ModelType GetModelType(char *model_data, size_t model_data_length, - bool debug) { - Ort::Env env(ORT_LOGGING_LEVEL_WARNING); - Ort::SessionOptions sess_opts; - - auto sess = std::make_unique(env, model_data, model_data_length, - sess_opts); - - Ort::ModelMetadata meta_data = sess->GetModelMetadata(); - if (debug) { - std::ostringstream os; - PrintModelMetadata(os, meta_data); - SHERPA_ONNX_LOGE("%s", os.str().c_str()); - } - - Ort::AllocatorWithDefaultOptions allocator; - auto model_type = - meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); - if (!model_type) { - SHERPA_ONNX_LOGE( - "No model_type in the metadata!\n" - "If you are using models from WeNet, please refer to\n" - "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/" - "run.sh\n" - "\n" - "for how to add metadta to model.onnx\n"); - return ModelType::kUnkown; - } - - if (model_type.get() == std::string("zipformer2")) { - return ModelType::kZipformerCtc; - } else if (model_type.get() == std::string("wenet_ctc")) { - return ModelType::kWenetCtc; - } else { - SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); - return ModelType::kUnkown; - } -} - std::unique_ptr OnlineCtcModel::Create( const OnlineModelConfig &config) { - ModelType model_type = ModelType::kUnkown; - - std::string filename; if (!config.wenet_ctc.model.empty()) { - filename = config.wenet_ctc.model; + return std::make_unique(config); + } else if (!config.zipformer2_ctc.model.empty()) { + return std::make_unique(config); } else { 
SHERPA_ONNX_LOGE("Please specify a CTC model"); exit(-1); } - - { - auto buffer = ReadFile(filename); - - model_type = GetModelType(buffer.data(), buffer.size(), config.debug); - } - - switch (model_type) { - case ModelType::kZipformerCtc: - return nullptr; - // return std::make_unique(config); - break; - case ModelType::kWenetCtc: - return std::make_unique(config); - break; - case ModelType::kUnkown: - SHERPA_ONNX_LOGE("Unknown model type in online CTC!"); - return nullptr; - } - - return nullptr; } #if __ANDROID_API__ >= 9 std::unique_ptr OnlineCtcModel::Create( AAssetManager *mgr, const OnlineModelConfig &config) { - ModelType model_type = ModelType::kUnkown; - - std::string filename; if (!config.wenet_ctc.model.empty()) { - filename = config.wenet_ctc.model; + return std::make_unique(mgr, config); + } else if (!config.zipformer2_ctc.model.empty()) { + return std::make_unique(mgr, config); } else { SHERPA_ONNX_LOGE("Please specify a CTC model"); exit(-1); } - - { - auto buffer = ReadFile(mgr, filename); - - model_type = GetModelType(buffer.data(), buffer.size(), config.debug); - } - - switch (model_type) { - case ModelType::kZipformerCtc: - return nullptr; - // return std::make_unique(mgr, config); - break; - case ModelType::kWenetCtc: - return std::make_unique(mgr, config); - break; - case ModelType::kUnkown: - SHERPA_ONNX_LOGE("Unknown model type in online CTC!"); - return nullptr; - } - - return nullptr; } #endif diff --git a/sherpa-onnx/csrc/online-ctc-model.h b/sherpa-onnx/csrc/online-ctc-model.h index c89cf6054..17721752d 100644 --- a/sherpa-onnx/csrc/online-ctc-model.h +++ b/sherpa-onnx/csrc/online-ctc-model.h @@ -33,6 +33,26 @@ class OnlineCtcModel { // Return a list of tensors containing the initial states virtual std::vector GetInitStates() const = 0; + /** Stack a list of individual states into a batch. + * + * It is the inverse operation of `UnStackStates`. + * + * @param states states[i] contains the state for the i-th utterance. + * @return Return a single value representing the batched state. + */ + virtual std::vector StackStates( + std::vector> states) const = 0; + + /** Unstack a batch state into a list of individual states. + * + * It is the inverse operation of `StackStates`. + * + * @param states A batched state. + * @return ans[i] contains the state for the i-th utterance. + */ + virtual std::vector> UnStackStates( + std::vector states) const = 0; + /** * * @param x A 3-D tensor of shape (N, T, C). N has to be 1. @@ -60,6 +80,9 @@ class OnlineCtcModel { // ChunkLength() frames, we advance by ChunkShift() frames // before we process the next chunk. virtual int32_t ChunkShift() const = 0; + + // Return true if the model supports batch size > 1 + virtual bool SupportBatchProcessing() const { return true; } }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc index a81ce375d..6e0ab6d77 100644 --- a/sherpa-onnx/csrc/online-model-config.cc +++ b/sherpa-onnx/csrc/online-model-config.cc @@ -14,6 +14,7 @@ void OnlineModelConfig::Register(ParseOptions *po) { transducer.Register(po); paraformer.Register(po); wenet_ctc.Register(po); + zipformer2_ctc.Register(po); po->Register("tokens", &tokens, "Path to tokens.txt"); @@ -26,10 +27,11 @@ void OnlineModelConfig::Register(ParseOptions *po) { po->Register("provider", &provider, "Specify a provider to use: cpu, cuda, coreml"); - po->Register("model-type", &model_type, - "Specify it to reduce model initialization time. 
" - "Valid values are: conformer, lstm, zipformer, zipformer2." - "All other values lead to loading the model twice."); + po->Register( + "model-type", &model_type, + "Specify it to reduce model initialization time. " + "Valid values are: conformer, lstm, zipformer, zipformer2, wenet_ctc" + "All other values lead to loading the model twice."); } bool OnlineModelConfig::Validate() const { @@ -51,6 +53,10 @@ bool OnlineModelConfig::Validate() const { return wenet_ctc.Validate(); } + if (!zipformer2_ctc.model.empty()) { + return zipformer2_ctc.Validate(); + } + return transducer.Validate(); } @@ -61,6 +67,7 @@ std::string OnlineModelConfig::ToString() const { os << "transducer=" << transducer.ToString() << ", "; os << "paraformer=" << paraformer.ToString() << ", "; os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; + os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", "; os << "tokens=\"" << tokens << "\", "; os << "num_threads=" << num_threads << ", "; os << "debug=" << (debug ? "True" : "False") << ", "; diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h index 34369b959..bedabf119 100644 --- a/sherpa-onnx/csrc/online-model-config.h +++ b/sherpa-onnx/csrc/online-model-config.h @@ -9,6 +9,7 @@ #include "sherpa-onnx/csrc/online-paraformer-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h" +#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h" namespace sherpa_onnx { @@ -16,6 +17,7 @@ struct OnlineModelConfig { OnlineTransducerModelConfig transducer; OnlineParaformerModelConfig paraformer; OnlineWenetCtcModelConfig wenet_ctc; + OnlineZipformer2CtcModelConfig zipformer2_ctc; std::string tokens; int32_t num_threads = 1; bool debug = false; @@ -25,7 +27,8 @@ struct OnlineModelConfig { // - conformer, conformer transducer from icefall // - lstm, lstm transducer from icefall // - zipformer, zipformer transducer from icefall - // - zipformer2, zipformer2 transducer from icefall + // - zipformer2, zipformer2 transducer or CTC from icefall + // - wenet_ctc, wenet CTC model // // All other values are invalid and lead to loading the model twice. 
   std::string model_type;
 
@@ -34,11 +37,13 @@ struct OnlineModelConfig {
   OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
                     const OnlineParaformerModelConfig &paraformer,
                     const OnlineWenetCtcModelConfig &wenet_ctc,
+                    const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
                     const std::string &tokens, int32_t num_threads,
                     bool debug, const std::string &provider,
                     const std::string &model_type)
       : transducer(transducer),
         paraformer(paraformer),
         wenet_ctc(wenet_ctc),
+        zipformer2_ctc(zipformer2_ctc),
         tokens(tokens),
         num_threads(num_threads),
         debug(debug),
diff --git a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h
index 16b8ca48a..f59dbd84b 100644
--- a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h
+++ b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h
@@ -96,8 +96,67 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
   }
 
   void DecodeStreams(OnlineStream **ss, int32_t n) const override {
+    if (n == 1 || !model_->SupportBatchProcessing()) {
+      for (int32_t i = 0; i != n; ++i) {
+        DecodeStream(ss[i]);
+      }
+      return;
+    }
+
+    // batch processing
+    int32_t chunk_length = model_->ChunkLength();
+    int32_t chunk_shift = model_->ChunkShift();
+
+    int32_t feat_dim = ss[0]->FeatureDim();
+
+    std::vector<OnlineCtcDecoderResult> results(n);
+    std::vector<float> features_vec(n * chunk_length * feat_dim);
+    std::vector<std::vector<Ort::Value>> states_vec(n);
+    std::vector<int64_t> all_processed_frames(n);
+
     for (int32_t i = 0; i != n; ++i) {
-      DecodeStream(ss[i]);
+      const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
+      std::vector<float> features =
+          ss[i]->GetFrames(num_processed_frames, chunk_length);
+
+      // Question: should num_processed_frames include chunk_shift?
+      ss[i]->GetNumProcessedFrames() += chunk_shift;
+
+      std::copy(features.begin(), features.end(),
+                features_vec.data() + i * chunk_length * feat_dim);
+
+      results[i] = std::move(ss[i]->GetCtcResult());
+      states_vec[i] = std::move(ss[i]->GetStates());
+      all_processed_frames[i] = num_processed_frames;
+    }
+
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    std::array<int64_t, 3> x_shape{n, chunk_length, feat_dim};
+
+    Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
+                                            features_vec.size(), x_shape.data(),
+                                            x_shape.size());
+
+    auto states = model_->StackStates(std::move(states_vec));
+    int32_t num_states = states.size();
+    auto out = model_->Forward(std::move(x), std::move(states));
+    std::vector<Ort::Value> out_states;
+    out_states.reserve(num_states);
+
+    for (int32_t k = 1; k != num_states + 1; ++k) {
+      out_states.push_back(std::move(out[k]));
+    }
+
+    std::vector<std::vector<Ort::Value>> next_states =
+        model_->UnStackStates(std::move(out_states));
+
+    decoder_->Decode(std::move(out[0]), &results);
+
+    for (int32_t k = 0; k != n; ++k) {
+      ss[k]->SetCtcResult(results[k]);
+      ss[k]->SetStates(std::move(next_states[k]));
     }
   }
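Not part of the patch: a condensed sketch of what the batched `DecodeStreams` above enables from the Python API. The model/token paths are placeholders; all method names are taken from the Python test added later in this patch.

```python
import sherpa_onnx

recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
    tokens="tokens.txt",
    model="ctc.onnx",
)

streams = [recognizer.create_stream() for _ in range(3)]
# ... feed each stream via s.accept_waveform(sample_rate, samples)
# and finish with s.input_finished() ...

while True:
    ready = [s for s in streams if recognizer.is_ready(s)]
    if not ready:
        break
    # With a zipformer2 CTC model, all ready streams go through the model in
    # one batched Forward(); a wenet CTC model falls back to the
    # one-stream-at-a-time path because SupportBatchProcessing() is false.
    recognizer.decode_streams(ready)

print([recognizer.get_result(s) for s in streams])
```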
diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc
index 59d8658b8..c5923c608 100644
--- a/sherpa-onnx/csrc/online-recognizer-impl.cc
+++ b/sherpa-onnx/csrc/online-recognizer-impl.cc
@@ -20,7 +20,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
     return std::make_unique<OnlineRecognizerParaformerImpl>(config);
   }
 
-  if (!config.model_config.wenet_ctc.model.empty()) {
+  if (!config.model_config.wenet_ctc.model.empty() ||
+      !config.model_config.zipformer2_ctc.model.empty()) {
     return std::make_unique<OnlineRecognizerCtcImpl>(config);
   }
 
@@ -39,7 +40,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
     return std::make_unique<OnlineRecognizerParaformerImpl>(mgr, config);
   }
 
-  if (!config.model_config.wenet_ctc.model.empty()) {
+  if (!config.model_config.wenet_ctc.model.empty() ||
+      !config.model_config.zipformer2_ctc.model.empty()) {
     return std::make_unique<OnlineRecognizerCtcImpl>(mgr, config);
   }
 
diff --git a/sherpa-onnx/csrc/online-wenet-ctc-model.cc b/sherpa-onnx/csrc/online-wenet-ctc-model.cc
index eac1a21cb..34557bf10 100644
--- a/sherpa-onnx/csrc/online-wenet-ctc-model.cc
+++ b/sherpa-onnx/csrc/online-wenet-ctc-model.cc
@@ -1,4 +1,4 @@
-// sherpa-onnx/csrc/online-paraformer-model.cc
+// sherpa-onnx/csrc/online-wenet-ctc-model.cc
 //
 // Copyright (c)  2023  Xiaomi Corporation
 
@@ -239,4 +239,21 @@ std::vector<Ort::Value> OnlineWenetCtcModel::GetInitStates() const {
   return impl_->GetInitStates();
 }
 
+std::vector<Ort::Value> OnlineWenetCtcModel::StackStates(
+    std::vector<std::vector<Ort::Value>> states) const {
+  if (states.size() != 1) {
+    SHERPA_ONNX_LOGE("wenet CTC model supports only batch_size==1. Given: %d",
+                     static_cast<int32_t>(states.size()));
+  }
+
+  return std::move(states[0]);
+}
+
+std::vector<std::vector<Ort::Value>> OnlineWenetCtcModel::UnStackStates(
+    std::vector<Ort::Value> states) const {
+  std::vector<std::vector<Ort::Value>> ans(1);
+  ans[0] = std::move(states);
+  return ans;
+}
+
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-wenet-ctc-model.h b/sherpa-onnx/csrc/online-wenet-ctc-model.h
index ccfd378b6..1be1034cc 100644
--- a/sherpa-onnx/csrc/online-wenet-ctc-model.h
+++ b/sherpa-onnx/csrc/online-wenet-ctc-model.h
@@ -35,6 +35,12 @@ class OnlineWenetCtcModel : public OnlineCtcModel {
   // - offset
   std::vector<Ort::Value> GetInitStates() const override;
 
+  std::vector<Ort::Value> StackStates(
+      std::vector<std::vector<Ort::Value>> states) const override;
+
+  std::vector<std::vector<Ort::Value>> UnStackStates(
+      std::vector<Ort::Value> states) const override;
+
   /**
    *
    * @param x A 3-D tensor of shape (N, T, C). N has to be 1.
@@ -63,6 +69,8 @@ class OnlineWenetCtcModel : public OnlineCtcModel {
   // before we process the next chunk.
   int32_t ChunkShift() const override;
 
+  bool SupportBatchProcessing() const override { return false; }
+
  private:
   class Impl;
   std::unique_ptr<Impl> impl_;
diff --git a/sherpa-onnx/csrc/online-zipformer2-ctc-model-config.cc b/sherpa-onnx/csrc/online-zipformer2-ctc-model-config.cc
new file mode 100644
index 000000000..836808d6f
--- /dev/null
+++ b/sherpa-onnx/csrc/online-zipformer2-ctc-model-config.cc
@@ -0,0 +1,41 @@
+// sherpa-onnx/csrc/online-zipformer2-ctc-model-config.cc
+//
+// Copyright (c)  2023  Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h"
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OnlineZipformer2CtcModelConfig::Register(ParseOptions *po) {
+  po->Register("zipformer2-ctc-model", &model,
See also " + "https://github.com/k2-fsa/icefall/pull/1413"); +} + +bool OnlineZipformer2CtcModelConfig::Validate() const { + if (model.empty()) { + SHERPA_ONNX_LOGE("--zipformer2-ctc-model is empty!"); + return false; + } + + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("--zipformer2-ctc-model %s does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OnlineZipformer2CtcModelConfig::ToString() const { + std::ostringstream os; + + os << "OnlineZipformer2CtcModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h b/sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h new file mode 100644 index 000000000..18115c8fe --- /dev/null +++ b/sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h @@ -0,0 +1,29 @@ +// sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OnlineZipformer2CtcModelConfig { + std::string model; + + OnlineZipformer2CtcModelConfig() = default; + + explicit OnlineZipformer2CtcModelConfig(const std::string &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc b/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc new file mode 100644 index 000000000..1146f00b2 --- /dev/null +++ b/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc @@ -0,0 +1,464 @@ +// sherpa-onnx/csrc/online-zipformer2-ctc-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h" + +#include +#include + +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-onnx/csrc/cat.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/text-utils.h" +#include "sherpa-onnx/csrc/unbind.h" + +namespace sherpa_onnx { + +class OnlineZipformer2CtcModel::Impl { + public: + explicit Impl(const OnlineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.zipformer2_ctc.model); + Init(buf.data(), buf.size()); + } + } + +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const OnlineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_WARNING), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.zipformer2_ctc.model); + Init(buf.data(), buf.size()); + } + } +#endif + + std::vector Forward(Ort::Value features, + std::vector states) { + std::vector inputs; + inputs.reserve(1 + states.size()); + + inputs.push_back(std::move(features)); + for (auto &v : states) { + inputs.push_back(std::move(v)); + } + + return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + } + + int32_t VocabSize() const { return vocab_size_; } + + int32_t ChunkLength() const { return T_; } + + int32_t ChunkShift() const { 
diff --git a/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc b/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc
new file mode 100644
index 000000000..1146f00b2
--- /dev/null
+++ b/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc
@@ -0,0 +1,464 @@
+// sherpa-onnx/csrc/online-zipformer2-ctc-model.cc
+//
+// Copyright (c)  2023  Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h"
+
+#include <memory>
+#include <numeric>
+
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "sherpa-onnx/csrc/cat.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+#include "sherpa-onnx/csrc/unbind.h"
+
+namespace sherpa_onnx {
+
+class OnlineZipformer2CtcModel::Impl {
+ public:
+  explicit Impl(const OnlineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(config.zipformer2_ctc.model);
+      Init(buf.data(), buf.size());
+    }
+  }
+
+#if __ANDROID_API__ >= 9
+  Impl(AAssetManager *mgr, const OnlineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_WARNING),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(mgr, config.zipformer2_ctc.model);
+      Init(buf.data(), buf.size());
+    }
+  }
+#endif
+
+  std::vector<Ort::Value> Forward(Ort::Value features,
+                                  std::vector<Ort::Value> states) {
+    std::vector<Ort::Value> inputs;
+    inputs.reserve(1 + states.size());
+
+    inputs.push_back(std::move(features));
+    for (auto &v : states) {
+      inputs.push_back(std::move(v));
+    }
+
+    return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
+                      output_names_ptr_.data(), output_names_ptr_.size());
+  }
+
+  int32_t VocabSize() const { return vocab_size_; }
+
+  int32_t ChunkLength() const { return T_; }
+
+  int32_t ChunkShift() const { return decode_chunk_len_; }
+
+  OrtAllocator *Allocator() const { return allocator_; }
+
+  // Return the initial states: 6 cache tensors for each encoder layer,
+  // plus embed_states and processed_lens at the end.
+  std::vector<Ort::Value> GetInitStates() {
+    std::vector<Ort::Value> ans;
+    ans.reserve(initial_states_.size());
+    for (auto &s : initial_states_) {
+      ans.push_back(View(&s));
+    }
+    return ans;
+  }
+
+  std::vector<Ort::Value> StackStates(
+      std::vector<std::vector<Ort::Value>> states) const {
+    int32_t batch_size = static_cast<int32_t>(states.size());
+    int32_t num_encoders = static_cast<int32_t>(num_encoder_layers_.size());
+
+    std::vector<const Ort::Value *> buf(batch_size);
+
+    std::vector<Ort::Value> ans;
+    int32_t num_states = static_cast<int32_t>(states[0].size());
+    ans.reserve(num_states);
+
+    for (int32_t i = 0; i != (num_states - 2) / 6; ++i) {
+      {
+        for (int32_t n = 0; n != batch_size; ++n) {
+          buf[n] = &states[n][6 * i];
+        }
+        auto v = Cat(allocator_, buf, 1);
+        ans.push_back(std::move(v));
+      }
+      {
+        for (int32_t n = 0; n != batch_size; ++n) {
+          buf[n] = &states[n][6 * i + 1];
+        }
+        auto v = Cat(allocator_, buf, 1);
+        ans.push_back(std::move(v));
+      }
+      {
+        for (int32_t n = 0; n != batch_size; ++n) {
+          buf[n] = &states[n][6 * i + 2];
+        }
+        auto v = Cat(allocator_, buf, 1);
+        ans.push_back(std::move(v));
+      }
+      {
+        for (int32_t n = 0; n != batch_size; ++n) {
+          buf[n] = &states[n][6 * i + 3];
+        }
+        auto v = Cat(allocator_, buf, 1);
+        ans.push_back(std::move(v));
+      }
+      {
+        for (int32_t n = 0; n != batch_size; ++n) {
+          buf[n] = &states[n][6 * i + 4];
+        }
+        auto v = Cat(allocator_, buf, 0);
+        ans.push_back(std::move(v));
+      }
+      {
+        for (int32_t n = 0; n != batch_size; ++n) {
+          buf[n] = &states[n][6 * i + 5];
+        }
+        auto v = Cat(allocator_, buf, 0);
+        ans.push_back(std::move(v));
+      }
+    }
+
+    {
+      for (int32_t n = 0; n != batch_size; ++n) {
+        buf[n] = &states[n][num_states - 2];
+      }
+      auto v = Cat(allocator_, buf, 0);
+      ans.push_back(std::move(v));
+    }
+
+    {
+      for (int32_t n = 0; n != batch_size; ++n) {
+        buf[n] = &states[n][num_states - 1];
+      }
+      auto v = Cat<int64_t>(allocator_, buf, 0);
+      ans.push_back(std::move(v));
+    }
+    return ans;
+  }
+
+  std::vector<std::vector<Ort::Value>> UnStackStates(
+      std::vector<Ort::Value> states) const {
+    int32_t m = std::accumulate(num_encoder_layers_.begin(),
+                                num_encoder_layers_.end(), 0);
+    assert(states.size() == m * 6 + 2);
+
+    int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
+    int32_t num_encoders = num_encoder_layers_.size();
+
+    std::vector<std::vector<Ort::Value>> ans;
+    ans.resize(batch_size);
+
+    for (int32_t i = 0; i != m; ++i) {
+      {
+        auto v = Unbind(allocator_, &states[i * 6], 1);
+        assert(v.size() == batch_size);
+
+        for (int32_t n = 0; n != batch_size; ++n) {
+          ans[n].push_back(std::move(v[n]));
+        }
+      }
+      {
+        auto v = Unbind(allocator_, &states[i * 6 + 1], 1);
+        assert(v.size() == batch_size);
+
+        for (int32_t n = 0; n != batch_size; ++n) {
+          ans[n].push_back(std::move(v[n]));
+        }
+      }
+      {
+        auto v = Unbind(allocator_, &states[i * 6 + 2], 1);
+        assert(v.size() == batch_size);
+
+        for (int32_t n = 0; n != batch_size; ++n) {
+          ans[n].push_back(std::move(v[n]));
+        }
+      }
+      {
+        auto v = Unbind(allocator_, &states[i * 6 + 3], 1);
+        assert(v.size() == batch_size);
+
+        for (int32_t n = 0; n != batch_size; ++n) {
+          ans[n].push_back(std::move(v[n]));
+        }
+      }
+      {
+        auto v = Unbind(allocator_, &states[i * 6 + 4], 0);
+        assert(v.size() == batch_size);
+
+        for (int32_t n = 0; n != batch_size; ++n) {
+          ans[n].push_back(std::move(v[n]));
+        }
+      }
+      {
+        auto v = Unbind(allocator_, &states[i * 6 + 5], 0);
+        assert(v.size() == batch_size);
+
+        for (int32_t n = 0; n != batch_size; ++n) {
+          ans[n].push_back(std::move(v[n]));
+        }
+      }
+    }
+
+    {
+      auto v = Unbind(allocator_, &states[m * 6], 0);
+      assert(v.size() == batch_size);
+
+      for (int32_t n = 0; n != batch_size; ++n) {
+        ans[n].push_back(std::move(v[n]));
+      }
+    }
+    {
+      auto v = Unbind<int64_t>(allocator_, &states[m * 6 + 1], 0);
+      assert(v.size() == batch_size);
+
+      for (int32_t n = 0; n != batch_size; ++n) {
+        ans[n].push_back(std::move(v[n]));
+      }
+    }
+
+    return ans;
+  }
+
+ private:
+  void Init(void *model_data, size_t model_data_length) {
+    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
+                                           sess_opts_);
+
+    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
+
+    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
+
+    // get meta data
+    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
+    if (config_.debug) {
+      std::ostringstream os;
+      os << "---zipformer2_ctc---\n";
+      PrintModelMetadata(os, meta_data);
+      SHERPA_ONNX_LOGE("%s", os.str().c_str());
+    }
+
+    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+    SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims_, "encoder_dims");
+    SHERPA_ONNX_READ_META_DATA_VEC(query_head_dims_, "query_head_dims");
+    SHERPA_ONNX_READ_META_DATA_VEC(value_head_dims_, "value_head_dims");
+    SHERPA_ONNX_READ_META_DATA_VEC(num_heads_, "num_heads");
+    SHERPA_ONNX_READ_META_DATA_VEC(num_encoder_layers_, "num_encoder_layers");
+    SHERPA_ONNX_READ_META_DATA_VEC(cnn_module_kernels_, "cnn_module_kernels");
+    SHERPA_ONNX_READ_META_DATA_VEC(left_context_len_, "left_context_len");
+
+    SHERPA_ONNX_READ_META_DATA(T_, "T");
+    SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
+
+    {
+      auto shape =
+          sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
+      vocab_size_ = shape[2];
+    }
+
+    if (config_.debug) {
+      auto print = [](const std::vector<int32_t> &v, const char *name) {
+        fprintf(stderr, "%s: ", name);
+        for (auto i : v) {
+          fprintf(stderr, "%d ", i);
+        }
+        fprintf(stderr, "\n");
+      };
+      print(encoder_dims_, "encoder_dims");
+      print(query_head_dims_, "query_head_dims");
+      print(value_head_dims_, "value_head_dims");
+      print(num_heads_, "num_heads");
+      print(num_encoder_layers_, "num_encoder_layers");
+      print(cnn_module_kernels_, "cnn_module_kernels");
+      print(left_context_len_, "left_context_len");
+      SHERPA_ONNX_LOGE("T: %d", T_);
+      SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
+      SHERPA_ONNX_LOGE("vocab_size_: %d", vocab_size_);
+    }
+
+    InitStates();
+  }
+
+  void InitStates() {
+    int32_t n = static_cast<int32_t>(encoder_dims_.size());
+    int32_t m = std::accumulate(num_encoder_layers_.begin(),
+                                num_encoder_layers_.end(), 0);
+    initial_states_.reserve(m * 6 + 2);
+
+    for (int32_t i = 0; i != n; ++i) {
+      int32_t num_layers = num_encoder_layers_[i];
+      int32_t key_dim = query_head_dims_[i] * num_heads_[i];
+      int32_t value_dim = value_head_dims_[i] * num_heads_[i];
+      int32_t nonlin_attn_head_dim = 3 * encoder_dims_[i] / 4;
+
+      for (int32_t j = 0; j != num_layers; ++j) {
+        {
+          std::array<int64_t, 3> s{left_context_len_[i], 1, key_dim};
+          auto v =
+              Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+          Fill(&v, 0);
+          initial_states_.push_back(std::move(v));
+        }
+
+        {
+          std::array<int64_t, 4> s{1, 1, left_context_len_[i],
+                                   nonlin_attn_head_dim};
+          auto v =
+              Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+          Fill(&v, 0);
+          initial_states_.push_back(std::move(v));
+        }
+
+        {
+          std::array<int64_t, 3> s{left_context_len_[i], 1, value_dim};
+          auto v =
+              Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+          Fill(&v, 0);
+          initial_states_.push_back(std::move(v));
+        }
+
+        {
+          std::array<int64_t, 3> s{left_context_len_[i], 1, value_dim};
+          auto v =
+              Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+          Fill(&v, 0);
+          initial_states_.push_back(std::move(v));
+        }
+
+        {
+          std::array<int64_t, 3> s{1, encoder_dims_[i],
+                                   cnn_module_kernels_[i] / 2};
+          auto v =
+              Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+          Fill(&v, 0);
+          initial_states_.push_back(std::move(v));
+        }
+
+        {
+          std::array<int64_t, 3> s{1, encoder_dims_[i],
+                                   cnn_module_kernels_[i] / 2};
+          auto v =
+              Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+          Fill(&v, 0);
+          initial_states_.push_back(std::move(v));
+        }
+      }
+    }
+
+    {
+      std::array<int64_t, 4> s{1, 128, 3, 19};
+      auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+      Fill(&v, 0);
+      initial_states_.push_back(std::move(v));
+    }
+
+    {
+      std::array<int64_t, 1> s{1};
+      auto v =
+          Ort::Value::CreateTensor<int64_t>(allocator_, s.data(), s.size());
+      Fill<int64_t>(&v, 0);
+      initial_states_.push_back(std::move(v));
+    }
+  }
+
+ private:
+  OnlineModelConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> sess_;
+
+  std::vector<std::string> input_names_;
+  std::vector<const char *> input_names_ptr_;
+
+  std::vector<std::string> output_names_;
+  std::vector<const char *> output_names_ptr_;
+
+  std::vector<Ort::Value> initial_states_;
+
+  std::vector<int32_t> encoder_dims_;
+  std::vector<int32_t> query_head_dims_;
+  std::vector<int32_t> value_head_dims_;
+  std::vector<int32_t> num_heads_;
+  std::vector<int32_t> num_encoder_layers_;
+  std::vector<int32_t> cnn_module_kernels_;
+  std::vector<int32_t> left_context_len_;
+
+  int32_t T_ = 0;
+  int32_t decode_chunk_len_ = 0;
+  int32_t vocab_size_ = 0;
+};
+
+OnlineZipformer2CtcModel::OnlineZipformer2CtcModel(
+    const OnlineModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+#if __ANDROID_API__ >= 9
+OnlineZipformer2CtcModel::OnlineZipformer2CtcModel(
+    AAssetManager *mgr, const OnlineModelConfig &config)
+    : impl_(std::make_unique<Impl>(mgr, config)) {}
+#endif
+
+OnlineZipformer2CtcModel::~OnlineZipformer2CtcModel() = default;
+
+std::vector<Ort::Value> OnlineZipformer2CtcModel::Forward(
+    Ort::Value x, std::vector<Ort::Value> states) const {
+  return impl_->Forward(std::move(x), std::move(states));
+}
+
+int32_t OnlineZipformer2CtcModel::VocabSize() const {
+  return impl_->VocabSize();
+}
+
+int32_t OnlineZipformer2CtcModel::ChunkLength() const {
+  return impl_->ChunkLength();
+}
+
+int32_t OnlineZipformer2CtcModel::ChunkShift() const {
+  return impl_->ChunkShift();
+}
+
+OrtAllocator *OnlineZipformer2CtcModel::Allocator() const {
+  return impl_->Allocator();
+}
+
+std::vector<Ort::Value> OnlineZipformer2CtcModel::GetInitStates() const {
+  return impl_->GetInitStates();
+}
+
+std::vector<Ort::Value> OnlineZipformer2CtcModel::StackStates(
+    std::vector<std::vector<Ort::Value>> states) const {
+  return impl_->StackStates(std::move(states));
+}
+
+std::vector<std::vector<Ort::Value>> OnlineZipformer2CtcModel::UnStackStates(
+    std::vector<Ort::Value> states) const {
+  return impl_->UnStackStates(std::move(states));
+}
+
+}  // namespace sherpa_onnx
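Not part of the patch: a sketch that mirrors `InitStates()` above to show how the state layout is derived from the model metadata. The metadata values below are illustrative placeholders, not read from a real model.

```python
num_encoder_layers = [2, 2, 3, 4, 3, 2]
encoder_dims       = [192, 256, 384, 512, 384, 256]
query_head_dims    = [32] * 6
value_head_dims    = [12] * 6
num_heads          = [4, 4, 4, 8, 4, 4]
cnn_module_kernels = [31, 31, 15, 15, 15, 31]
left_context_len   = [128, 64, 32, 16, 32, 64]

shapes = []
for i, num_layers in enumerate(num_encoder_layers):
    key_dim = query_head_dims[i] * num_heads[i]
    value_dim = value_head_dims[i] * num_heads[i]
    nonlin_attn_head_dim = 3 * encoder_dims[i] // 4
    for _ in range(num_layers):
        shapes.append((left_context_len[i], 1, key_dim))    # cached_key
        shapes.append((1, 1, left_context_len[i],
                       nonlin_attn_head_dim))               # cached nonlin attn
        shapes.append((left_context_len[i], 1, value_dim))  # cached value 1
        shapes.append((left_context_len[i], 1, value_dim))  # cached value 2
        shapes.append((1, encoder_dims[i],
                       cnn_module_kernels[i] // 2))         # cached conv 1
        shapes.append((1, encoder_dims[i],
                       cnn_module_kernels[i] // 2))         # cached conv 2

shapes.append((1, 128, 3, 19))  # embed states
shapes.append((1,))             # processed_lens, int64

# 6 tensors per encoder layer, plus the two trailing states
assert len(shapes) == 6 * sum(num_encoder_layers) + 2
```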
"sherpa-onnx/csrc/online-model-config.h" + +namespace sherpa_onnx { + +class OnlineZipformer2CtcModel : public OnlineCtcModel { + public: + explicit OnlineZipformer2CtcModel(const OnlineModelConfig &config); + +#if __ANDROID_API__ >= 9 + OnlineZipformer2CtcModel(AAssetManager *mgr, const OnlineModelConfig &config); +#endif + + ~OnlineZipformer2CtcModel() override; + + // A list of tensors. + // See also + // https://github.com/k2-fsa/icefall/pull/1413 + // and + // https://github.com/k2-fsa/icefall/pull/1415 + std::vector GetInitStates() const override; + + std::vector StackStates( + std::vector> states) const override; + + std::vector> UnStackStates( + std::vector states) const override; + + /** + * + * @param x A 3-D tensor of shape (N, T, C). N has to be 1. + * @param states It is from GetInitStates() or returned from this method. + * + * @return Return a list of tensors + * - ans[0] contains log_probs, of shape (N, T, C) + * - ans[1:] contains next_states + */ + std::vector Forward( + Ort::Value x, std::vector states) const override; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const override; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const override; + + // The model accepts this number of frames before subsampling as input + int32_t ChunkLength() const override; + + // Similar to frame_shift in feature extractor, after processing + // ChunkLength() frames, we advance by ChunkShift() frames + // before we process the next chunk. + int32_t ChunkShift() const override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index 9e771fc5c..89a21e239 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -26,6 +26,8 @@ int main(int32_t argc, char *argv[]) { const char *kUsageMessage = R"usage( Usage: +(1) Streaming transducer + ./bin/sherpa-onnx \ --tokens=/path/to/tokens.txt \ --encoder=/path/to/encoder.onnx \ @@ -36,6 +38,30 @@ int main(int32_t argc, char *argv[]) { --decoding-method=greedy_search \ /path/to/foo.wav [bar.wav foobar.wav ...] 
diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc
index 9e771fc5c..89a21e239 100644
--- a/sherpa-onnx/csrc/sherpa-onnx.cc
+++ b/sherpa-onnx/csrc/sherpa-onnx.cc
@@ -26,6 +26,8 @@ int main(int32_t argc, char *argv[]) {
   const char *kUsageMessage = R"usage(
 Usage:
 
+(1) Streaming transducer
+
   ./bin/sherpa-onnx \
     --tokens=/path/to/tokens.txt \
     --encoder=/path/to/encoder.onnx \
@@ -36,6 +38,30 @@ int main(int32_t argc, char *argv[]) {
     --decoding-method=greedy_search \
     /path/to/foo.wav [bar.wav foobar.wav ...]
 
+(2) Streaming zipformer2 CTC
+
+  wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
+  tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
+
+  ./bin/sherpa-onnx \
+    --debug=1 \
+    --zipformer2-ctc-model=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
+    --tokens=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \
+    ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \
+    ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav \
+    ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000002.wav
+
+(3) Streaming paraformer
+
+  wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
+  tar xvf sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
+
+  ./bin/sherpa-onnx \
+    --tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
+    --paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.onnx \
+    --paraformer-decoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.onnx \
+    ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav
+
 Note: It supports decoding multiple files in batches
 
 Default value for num_threads is 2.
diff --git a/sherpa-onnx/csrc/symbol-table.cc b/sherpa-onnx/csrc/symbol-table.cc
index e18f0bab3..1d230bca3 100644
--- a/sherpa-onnx/csrc/symbol-table.cc
+++ b/sherpa-onnx/csrc/symbol-table.cc
@@ -8,9 +8,6 @@
 #include <fstream>
 #include <sstream>
 
-#include "sherpa-onnx/csrc/base64-decode.h"
-#include "sherpa-onnx/csrc/onnx-utils.h"
-
 #if __ANDROID_API__ >= 9
 #include <strstream>
 
@@ -18,6 +15,9 @@
 #include "android/asset_manager_jni.h"
 #endif
 
+#include "sherpa-onnx/csrc/base64-decode.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
 namespace sherpa_onnx {
 
 SymbolTable::SymbolTable(const std::string &filename) {
diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc
index 0e2c3794f..9dada0054 100644
--- a/sherpa-onnx/jni/jni.cc
+++ b/sherpa-onnx/jni/jni.cc
@@ -262,22 +262,34 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
   fid = env->GetFieldID(model_config_cls, "paraformer",
                         "Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;");
   jobject paraformer_config = env->GetObjectField(model_config, fid);
-  jclass paraformer_config_config_cls = env->GetObjectClass(paraformer_config);
+  jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);
 
-  fid = env->GetFieldID(paraformer_config_config_cls, "encoder",
-                        "Ljava/lang/String;");
+  fid = env->GetFieldID(paraformer_config_cls, "encoder", "Ljava/lang/String;");
   s = (jstring)env->GetObjectField(paraformer_config, fid);
   p = env->GetStringUTFChars(s, nullptr);
   ans.model_config.paraformer.encoder = p;
   env->ReleaseStringUTFChars(s, p);
 
-  fid = env->GetFieldID(paraformer_config_config_cls, "decoder",
-                        "Ljava/lang/String;");
+  fid = env->GetFieldID(paraformer_config_cls, "decoder", "Ljava/lang/String;");
   s = (jstring)env->GetObjectField(paraformer_config, fid);
   p = env->GetStringUTFChars(s, nullptr);
   ans.model_config.paraformer.decoder = p;
   env->ReleaseStringUTFChars(s, p);
 
+  // streaming zipformer2 CTC
+  fid =
+      env->GetFieldID(model_config_cls, "zipformer2Ctc",
+                      "Lcom/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig;");
+  jobject zipformer2_ctc_config = env->GetObjectField(model_config, fid);
+  jclass zipformer2_ctc_config_cls = env->GetObjectClass(zipformer2_ctc_config);
+
+  fid =
+      env->GetFieldID(zipformer2_ctc_config_cls, "model", "Ljava/lang/String;");
+  s = (jstring)env->GetObjectField(zipformer2_ctc_config, fid);
+  p = env->GetStringUTFChars(s, nullptr);
+  ans.model_config.zipformer2_ctc.model = p;
+  env->ReleaseStringUTFChars(s, p);
+
   fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
   s = (jstring)env->GetObjectField(model_config, fid);
   p = env->GetStringUTFChars(s, nullptr);
diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt
index e346922e5..85120a806 100644
--- a/sherpa-onnx/python/csrc/CMakeLists.txt
+++ b/sherpa-onnx/python/csrc/CMakeLists.txt
@@ -27,6 +27,7 @@ pybind11_add_module(_sherpa_onnx
   online-stream.cc
   online-transducer-model-config.cc
   online-wenet-ctc-model-config.cc
+  online-zipformer2-ctc-model-config.cc
   sherpa-onnx.cc
   silero-vad-model-config.cc
   vad-model-config.cc
diff --git a/sherpa-onnx/python/csrc/offline-model-config.cc b/sherpa-onnx/python/csrc/offline-model-config.cc
index fa742490f..f8a46a3c0 100644
--- a/sherpa-onnx/python/csrc/offline-model-config.cc
+++ b/sherpa-onnx/python/csrc/offline-model-config.cc
@@ -58,6 +58,7 @@ void PybindOfflineModelConfig(py::module *m) {
       .def_readwrite("debug", &PyClass::debug)
       .def_readwrite("provider", &PyClass::provider)
       .def_readwrite("model_type", &PyClass::model_type)
+      .def("validate", &PyClass::Validate)
       .def("__str__", &PyClass::ToString);
 }
 
diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc
index bd4c6798f..9a8473510 100644
--- a/sherpa-onnx/python/csrc/online-model-config.cc
+++ b/sherpa-onnx/python/csrc/online-model-config.cc
@@ -12,6 +12,7 @@
 #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
 #include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
 #include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h"
+#include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h"
 
 namespace sherpa_onnx {
 
@@ -19,26 +20,31 @@ void PybindOnlineModelConfig(py::module *m) {
   PybindOnlineTransducerModelConfig(m);
   PybindOnlineParaformerModelConfig(m);
   PybindOnlineWenetCtcModelConfig(m);
+  PybindOnlineZipformer2CtcModelConfig(m);
 
   using PyClass = OnlineModelConfig;
   py::class_<PyClass>(*m, "OnlineModelConfig")
       .def(py::init<const OnlineTransducerModelConfig &,
                     const OnlineParaformerModelConfig &,
                     const OnlineWenetCtcModelConfig &,
+                    const OnlineZipformer2CtcModelConfig &,
                     const std::string &, int32_t, bool, const std::string &,
                     const std::string &>(),
           py::arg("transducer") = OnlineTransducerModelConfig(),
           py::arg("paraformer") = OnlineParaformerModelConfig(),
           py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
+          py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
           py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
          py::arg("provider") = "cpu", py::arg("model_type") = "")
       .def_readwrite("transducer", &PyClass::transducer)
       .def_readwrite("paraformer", &PyClass::paraformer)
       .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
+      .def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc)
       .def_readwrite("tokens", &PyClass::tokens)
       .def_readwrite("num_threads", &PyClass::num_threads)
       .def_readwrite("debug", &PyClass::debug)
       .def_readwrite("provider", &PyClass::provider)
       .def_readwrite("model_type", &PyClass::model_type)
+      .def("validate", &PyClass::Validate)
       .def("__str__", &PyClass::ToString);
 }
 
diff --git a/sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.cc b/sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.cc
new file mode 100644
index 000000000..bc3ab1f3f
--- /dev/null
+++ b/sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.cc
@@ -0,0 +1,22 @@
+// sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.cc
+//
+// Copyright (c)  2023  Xiaomi Corporation
+
+#include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h"
+
+#include <string>
+#include <vector>
+
+#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h"
+
+namespace sherpa_onnx {
+
+void PybindOnlineZipformer2CtcModelConfig(py::module *m) {
+  using PyClass = OnlineZipformer2CtcModelConfig;
+  py::class_<PyClass>(*m, "OnlineZipformer2CtcModelConfig")
+      .def(py::init<const std::string &>(), py::arg("model"))
+      .def_readwrite("model", &PyClass::model)
+      .def("__str__", &PyClass::ToString);
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h b/sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h
new file mode 100644
index 000000000..a4c1afa43
--- /dev/null
+++ b/sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h
@@ -0,0 +1,16 @@
+// sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h
+//
+// Copyright (c)  2023  Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindOnlineZipformer2CtcModelConfig(py::module *m);
+
+}
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_
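Not part of the patch: a sketch of how the endpoint rules documented in `from_zipformer2_ctc` below are typically enabled. The paths are placeholders; parameter names match the signature added in this patch.

```python
import sherpa_onnx

recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
    tokens="tokens.txt",
    model="ctc.onnx",
    enable_endpoint_detection=True,
    rule1_min_trailing_silence=2.4,   # endpoint after 2.4 s of pure silence
    rule2_min_trailing_silence=1.2,   # 1.2 s of silence after some speech
    rule3_min_utterance_length=20.0,  # force an endpoint after 20 s
)
```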
diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
index 0198ffb29..6af47e11f 100644
--- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
+++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
@@ -8,11 +8,14 @@
     OnlineLMConfig,
     OnlineModelConfig,
     OnlineParaformerModelConfig,
-    OnlineRecognizer as _Recognizer,
+)
+from _sherpa_onnx import OnlineRecognizer as _Recognizer
+from _sherpa_onnx import (
     OnlineRecognizerConfig,
     OnlineStream,
     OnlineTransducerModelConfig,
     OnlineWenetCtcModelConfig,
+    OnlineZipformer2CtcModelConfig,
 )
 
 
@@ -272,6 +275,101 @@ class OnlineRecognizer(object):
         self.config = recognizer_config
         return self
 
+    @classmethod
+    def from_zipformer2_ctc(
+        cls,
+        tokens: str,
+        model: str,
+        num_threads: int = 2,
+        sample_rate: float = 16000,
+        feature_dim: int = 80,
+        enable_endpoint_detection: bool = False,
+        rule1_min_trailing_silence: float = 2.4,
+        rule2_min_trailing_silence: float = 1.2,
+        rule3_min_utterance_length: float = 20.0,
+        decoding_method: str = "greedy_search",
+        provider: str = "cpu",
+    ):
+        """
+        Please refer to
+        `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
+        to download pre-trained models for different languages, e.g., Chinese,
+        English, etc.
+
+        Args:
+          tokens:
+            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
+            columns::
+
+                symbol integer_id
+
+          model:
+            Path to ``model.onnx``.
+          num_threads:
+            Number of threads for neural network computation.
+          sample_rate:
+            Sample rate of the training data used to train the model.
+          feature_dim:
+            Dimension of the feature used to train the model.
+          enable_endpoint_detection:
+            True to enable endpoint detection. False to disable endpoint
+            detection.
+          rule1_min_trailing_silence:
+            Used only when enable_endpoint_detection is True. If the duration
+            of trailing silence in seconds is larger than this value, we assume
+            an endpoint is detected.
+          rule2_min_trailing_silence:
+            Used only when enable_endpoint_detection is True. If we have decoded
+            something that is nonsilence and if the duration of trailing silence
+            in seconds is larger than this value, we assume an endpoint is
+            detected.
+          rule3_min_utterance_length:
+            Used only when enable_endpoint_detection is True. If the utterance
+            length in seconds is larger than this value, we assume an endpoint
+            is detected.
+          decoding_method:
+            The only valid value is greedy_search.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+        """
+        self = cls.__new__(cls)
+        _assert_file_exists(tokens)
+        _assert_file_exists(model)
+
+        assert num_threads > 0, num_threads
+
+        zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model)
+
+        model_config = OnlineModelConfig(
+            zipformer2_ctc=zipformer2_ctc_config,
+            tokens=tokens,
+            num_threads=num_threads,
+            provider=provider,
+        )
+
+        feat_config = FeatureExtractorConfig(
+            sampling_rate=sample_rate,
+            feature_dim=feature_dim,
+        )
+
+        endpoint_config = EndpointConfig(
+            rule1_min_trailing_silence=rule1_min_trailing_silence,
+            rule2_min_trailing_silence=rule2_min_trailing_silence,
+            rule3_min_utterance_length=rule3_min_utterance_length,
+        )
+
+        recognizer_config = OnlineRecognizerConfig(
+            feat_config=feat_config,
+            model_config=model_config,
+            endpoint_config=endpoint_config,
+            enable_endpoint=enable_endpoint_detection,
+            decoding_method=decoding_method,
+        )
+
+        self.recognizer = _Recognizer(recognizer_config)
+        self.config = recognizer_config
+        return self
+
     @classmethod
     def from_wenet_ctc(
         cls,
@@ -352,7 +450,6 @@ class OnlineRecognizer(object):
             tokens=tokens,
             num_threads=num_threads,
             provider=provider,
-            model_type="wenet_ctc",
         )
 
         feat_config = FeatureExtractorConfig(
diff --git a/sherpa-onnx/python/tests/test_online_recognizer.py b/sherpa-onnx/python/tests/test_online_recognizer.py
index 7df00fe09..9193fb0f2 100755
--- a/sherpa-onnx/python/tests/test_online_recognizer.py
+++ b/sherpa-onnx/python/tests/test_online_recognizer.py
@@ -143,6 +143,57 @@ class TestOnlineRecognizer(unittest.TestCase):
             print(f"{wave_filename}\n{result}")
             print("-" * 10)
 
+    def test_zipformer2_ctc(self):
+        m = "sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13"
+        for use_int8 in [True, False]:
+            name = (
+                "ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx"
+                if use_int8
+                else "ctc-epoch-20-avg-1-chunk-16-left-128.onnx"
+            )
+            model = f"{d}/{m}/{name}"
+            tokens = f"{d}/{m}/tokens.txt"
+            wave0 = f"{d}/{m}/test_wavs/DEV_T0000000000.wav"
+            wave1 = f"{d}/{m}/test_wavs/DEV_T0000000001.wav"
+            wave2 = f"{d}/{m}/test_wavs/DEV_T0000000002.wav"
+            if not Path(model).is_file():
+                print("skipping test_zipformer2_ctc()")
+                return
+            print(f"testing {model}")
+
+            recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
+                model=model,
+                tokens=tokens,
+                num_threads=1,
+                provider="cpu",
+            )
+
+            streams = []
+            waves = [wave0, wave1, wave2]
+            for wave in waves:
+                s = recognizer.create_stream()
+                samples, sample_rate = read_wave(wave)
+                s.accept_waveform(sample_rate, samples)
+
+                tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
+                s.accept_waveform(sample_rate, tail_paddings)
+                s.input_finished()
+                streams.append(s)
+
+            while True:
+                ready_list = []
+                for s in streams:
+                    if recognizer.is_ready(s):
+                        ready_list.append(s)
+                if len(ready_list) == 0:
+                    break
+                recognizer.decode_streams(ready_list)
+
+            results = [recognizer.get_result(s) for s in streams]
+            for wave_filename, result in zip(waves, results):
+                print(f"{wave_filename}\n{result}")
+                print("-" * 10)
+
     def test_wenet_ctc(self):
         models = [
             "sherpa-onnx-zh-wenet-aishell",
diff --git a/swift-api-examples/.gitignore b/swift-api-examples/.gitignore
index 95c397c3d..b9d448211 100644
--- a/swift-api-examples/.gitignore
+++ b/swift-api-examples/.gitignore
@@ -5,3 +5,4 @@ tts
 vits-vctk
 sherpa-onnx-paraformer-zh-2023-09-14
 !*.sh
+*.bak
diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift
index cf7e69aad..397d92e67 100644
--- a/swift-api-examples/SherpaOnnx.swift
+++ b/swift-api-examples/SherpaOnnx.swift
@@ -60,6 +60,14 @@ func sherpaOnnxOnlineParaformerModelConfig(
   )
 }
 
+func sherpaOnnxOnlineZipformer2CtcModelConfig(
+  model: String = ""
+) -> SherpaOnnxOnlineZipformer2CtcModelConfig {
+  return SherpaOnnxOnlineZipformer2CtcModelConfig(
+    model: toCPointer(model)
+  )
+}
+
 /// Return an instance of SherpaOnnxOnlineModelConfig.
 ///
 /// Please refer to
@@ -75,6 +83,8 @@ func sherpaOnnxOnlineModelConfig(
   tokens: String,
   transducer: SherpaOnnxOnlineTransducerModelConfig = sherpaOnnxOnlineTransducerModelConfig(),
   paraformer: SherpaOnnxOnlineParaformerModelConfig = sherpaOnnxOnlineParaformerModelConfig(),
+  zipformer2Ctc: SherpaOnnxOnlineZipformer2CtcModelConfig =
+    sherpaOnnxOnlineZipformer2CtcModelConfig(),
   numThreads: Int = 1,
   provider: String = "cpu",
   debug: Int = 0,
@@ -83,6 +93,7 @@ func sherpaOnnxOnlineModelConfig(
   return SherpaOnnxOnlineModelConfig(
     transducer: transducer,
     paraformer: paraformer,
+    zipformer2_ctc: zipformer2Ctc,
     tokens: toCPointer(tokens),
     num_threads: Int32(numThreads),
     provider: toCPointer(provider),
diff --git a/swift-api-examples/decode-file.swift b/swift-api-examples/decode-file.swift
index 5406c665c..ab9bc44d6 100644
--- a/swift-api-examples/decode-file.swift
+++ b/swift-api-examples/decode-file.swift
@@ -13,24 +13,47 @@ extension AVAudioPCMBuffer {
 }
 
 func run() {
-  let encoder =
-    "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx"
-  let decoder =
-    "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx"
-  let joiner =
-    "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx"
-  let tokens = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt"
-
-  let transducerConfig = sherpaOnnxOnlineTransducerModelConfig(
-    encoder: encoder,
-    decoder: decoder,
-    joiner: joiner
-  )
+  var modelConfig: SherpaOnnxOnlineModelConfig
+  var modelType = "zipformer2-ctc"
+  var filePath: String
 
-  let modelConfig = sherpaOnnxOnlineModelConfig(
-    tokens: tokens,
-    transducer: transducerConfig
-  )
+  modelType = "transducer"
+
+  if modelType == "transducer" {
+    filePath = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav"
+    let encoder =
+      "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx"
+    let decoder =
+      "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx"
+    let joiner =
+      "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx"
+    let tokens = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt"
+
+    let transducerConfig = sherpaOnnxOnlineTransducerModelConfig(
+      encoder: encoder,
+      decoder: decoder,
+      joiner: joiner
+    )
+
+    modelConfig = sherpaOnnxOnlineModelConfig(
+      tokens: tokens,
+      transducer: transducerConfig
+    )
+  } else {
+    filePath =
+      "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav"
+    let model =
+      "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx"
+    let tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt"
+    let zipformer2CtcModelConfig = sherpaOnnxOnlineZipformer2CtcModelConfig(
+      model: model
+    )
+
+    modelConfig = sherpaOnnxOnlineModelConfig(
+      tokens: tokens,
+      zipformer2Ctc: zipformer2CtcModelConfig
+    )
+  }
 
   let featConfig = sherpaOnnxFeatureConfig(
     sampleRate: 16000,
@@ -43,7 +66,6 @@ func run() {
 
   let recognizer = SherpaOnnxRecognizer(config: &config)
 
-  let filePath = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav"
   let fileURL: NSURL = NSURL(fileURLWithPath: filePath)
   let audioFile = try! AVAudioFile(forReading: fileURL as URL)
 
diff --git a/swift-api-examples/run-decode-file.sh b/swift-api-examples/run-decode-file.sh
index 8f19ccc17..005b388ae 100755
--- a/swift-api-examples/run-decode-file.sh
+++ b/swift-api-examples/run-decode-file.sh
@@ -20,6 +20,12 @@ if [ ! -d ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 ]; then
   rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
 fi
 
+if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then
+  wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
+  tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
+  rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
+fi
+
 if [ ! -e ./decode-file ]; then
   # Note: We use -lc++ to link against libc++ instead of libstdc++
   swiftc \
diff --git a/swift-api-examples/run-generate-subtitles.sh b/swift-api-examples/run-generate-subtitles.sh
index 6b7b5c49c..43ece12e5 100755
--- a/swift-api-examples/run-generate-subtitles.sh
+++ b/swift-api-examples/run-generate-subtitles.sh
@@ -22,7 +22,7 @@ if [ ! -d ./sherpa-onnx-whisper-tiny.en ]; then
 fi
 
 if [ ! -f ./silero_vad.onnx ]; then
   echo "downloading silero_vad"
-  wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
+  wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
 fi
 
 if [ ! -e ./generate-subtitles ]; then