diff --git a/.github/scripts/test-cxx-api.sh b/.github/scripts/test-cxx-api.sh index 89b2f1dd6..aedf16133 100755 --- a/.github/scripts/test-cxx-api.sh +++ b/.github/scripts/test-cxx-api.sh @@ -9,6 +9,8 @@ log() { } echo "CXX_STREAMING_ZIPFORMER_EXE is $CXX_STREAMING_ZIPFORMER_EXE" +echo "CXX_WHISPER_EXE is $CXX_WHISPER_EXE" +echo "CXX_SENSE_VOICE_EXE is $CXX_SENSE_VOICE_EXE" echo "PATH: $PATH" log "------------------------------------------------------------" @@ -19,3 +21,22 @@ tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 $CXX_STREAMING_ZIPFORMER_EXE rm -rf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 + +log "------------------------------------------------------------" +log "Test Whisper CXX API" +log "------------------------------------------------------------" +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 +tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2 +rm sherpa-onnx-whisper-tiny.en.tar.bz2 +$CXX_WHISPER_EXE +rm -rf sherpa-onnx-whisper-tiny.en + +log "------------------------------------------------------------" +log "Test SenseVoice CXX API" +log "------------------------------------------------------------" +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 + +$CXX_SENSE_VOICE_EXE +rm -rf sherpa-onnx-sense-voice-* diff --git a/.github/workflows/cxx-api.yaml b/.github/workflows/cxx-api.yaml index 1c882f2ff..357aaa227 100644 --- a/.github/workflows/cxx-api.yaml +++ b/.github/workflows/cxx-api.yaml @@ -4,6 +4,7 @@ on: push: branches: - master + - cxx-api-asr-non-streaming paths: - '.github/workflows/cxx-api.yaml' - 'CMakeLists.txt' @@ -82,6 +83,74 @@ jobs: otool -L ./install/lib/libsherpa-onnx-cxx-api.dylib fi + - name: Test whisper + shell: bash + run: | + g++ -std=c++17 -o whisper-cxx-api ./cxx-api-examples/whisper-cxx-api.cc \ + -I ./build/install/include \ + -L ./build/install/lib/ \ + -l sherpa-onnx-cxx-api \ + -l sherpa-onnx-c-api \ + -l onnxruntime + + ls -lh whisper-cxx-api + + if [[ ${{ matrix.os }} == ubuntu-latest ]]; then + ldd ./whisper-cxx-api + echo "----" + readelf -d ./whisper-cxx-api + fi + + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 + tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2 + rm sherpa-onnx-whisper-tiny.en.tar.bz2 + + ls -lh sherpa-onnx-whisper-tiny.en + echo "---" + ls -lh sherpa-onnx-whisper-tiny.en/test_wavs + + export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH + export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH + + ./whisper-cxx-api + + rm -rf sherpa-onnx-whisper-* + rm ./whisper-cxx-api + + - name: Test SenseVoice + shell: bash + run: | + g++ -std=c++17 -o sense-voice-cxx-api ./cxx-api-examples/sense-voice-cxx-api.cc \ + -I ./build/install/include \ + -L ./build/install/lib/ \ + -l sherpa-onnx-cxx-api \ + -l sherpa-onnx-c-api \ + -l onnxruntime + + ls -lh sense-voice-cxx-api + + if [[ ${{ matrix.os }} == ubuntu-latest ]]; then + ldd ./sense-voice-cxx-api + echo "----" + readelf -d ./sense-voice-cxx-api + fi + + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 + tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 + rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 + + ls -lh sherpa-onnx-sense-voice-* + echo "---" + ls -lh sherpa-onnx-sense-voice-*/test_wavs + + export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH + export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH + + ./sense-voice-cxx-api + + rm -rf sherpa-onnx-sense-voice-* + rm ./sense-voice-cxx-api + - name: Test streaming zipformer shell: bash run: | diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index e21c452c0..a37f48618 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -155,6 +155,8 @@ jobs: du -h -d1 . export PATH=$PWD/build/bin:$PATH export CXX_STREAMING_ZIPFORMER_EXE=streaming-zipformer-cxx-api + export CXX_WHISPER_EXE=whisper-cxx-api + export CXX_SENSE_VOICE_EXE=sense-voice-cxx-api .github/scripts/test-cxx-api.sh du -h -d1 . diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index 7b01846a7..849631015 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -127,6 +127,8 @@ jobs: du -h -d1 . export PATH=$PWD/build/bin:$PATH export CXX_STREAMING_ZIPFORMER_EXE=streaming-zipformer-cxx-api + export CXX_WHISPER_EXE=whisper-cxx-api + export CXX_SENSE_VOICE_EXE=sense-voice-cxx-api .github/scripts/test-cxx-api.sh du -h -d1 . diff --git a/.github/workflows/sanitizer.yaml b/.github/workflows/sanitizer.yaml index f12bafa9a..7cda96899 100644 --- a/.github/workflows/sanitizer.yaml +++ b/.github/workflows/sanitizer.yaml @@ -81,6 +81,7 @@ jobs: run: | export PATH=$PWD/build/bin:$PATH export CXX_STREAMING_ZIPFORMER_EXE=streaming-zipformer-cxx-api + export CXX_WHISPER_EXE=whisper-cxx-api .github/scripts/test-cxx-api.sh diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index f07d6c78c..9435dcefd 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -98,6 +98,8 @@ jobs: run: | export PATH=$PWD/build/bin/Release:$PATH export CXX_STREAMING_ZIPFORMER_EXE=streaming-zipformer-cxx-api.exe + export CXX_WHISPER_EXE=whisper-cxx-api.exe + export CXX_SENSE_VOICE_EXE=sense-voice-cxx-api.exe .github/scripts/test-cxx-api.sh diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index bacbdff01..36089b2dd 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -98,6 +98,8 @@ jobs: run: | export PATH=$PWD/build/bin/Release:$PATH export CXX_STREAMING_ZIPFORMER_EXE=streaming-zipformer-cxx-api.exe + export CXX_WHISPER_EXE=whisper-cxx-api.exe + export CXX_SENSE_VOICE_EXE=sense-voice-cxx-api.exe .github/scripts/test-cxx-api.sh diff --git a/c-api-examples/paraformer-c-api.c b/c-api-examples/paraformer-c-api.c index 345aed555..98d38c789 100644 --- a/c-api-examples/paraformer-c-api.c +++ b/c-api-examples/paraformer-c-api.c @@ -54,7 +54,7 @@ int32_t main() { recognizer_config.decoding_method = "greedy_search"; recognizer_config.model_config = offline_model_config; - SherpaOnnxOfflineRecognizer *recognizer = + const SherpaOnnxOfflineRecognizer *recognizer = SherpaOnnxCreateOfflineRecognizer(&recognizer_config); if (recognizer == NULL) { @@ -63,7 +63,8 @@ int32_t main() { return -1; } - SherpaOnnxOfflineStream *stream = SherpaOnnxCreateOfflineStream(recognizer); + const SherpaOnnxOfflineStream *stream = + SherpaOnnxCreateOfflineStream(recognizer); SherpaOnnxAcceptWaveformOffline(stream, wave->sample_rate, wave->samples, wave->num_samples); diff --git a/c-api-examples/sense-voice-c-api.c b/c-api-examples/sense-voice-c-api.c index 06e890636..25d58219e 100644 --- a/c-api-examples/sense-voice-c-api.c +++ b/c-api-examples/sense-voice-c-api.c @@ -56,7 +56,7 @@ int32_t main() { recognizer_config.decoding_method = "greedy_search"; recognizer_config.model_config = offline_model_config; - SherpaOnnxOfflineRecognizer *recognizer = + const SherpaOnnxOfflineRecognizer *recognizer = SherpaOnnxCreateOfflineRecognizer(&recognizer_config); if (recognizer == NULL) { @@ -65,7 +65,8 @@ int32_t main() { return -1; } - SherpaOnnxOfflineStream *stream = SherpaOnnxCreateOfflineStream(recognizer); + const SherpaOnnxOfflineStream *stream = + SherpaOnnxCreateOfflineStream(recognizer); SherpaOnnxAcceptWaveformOffline(stream, wave->sample_rate, wave->samples, wave->num_samples); diff --git a/c-api-examples/streaming-ctc-buffered-tokens-c-api.c b/c-api-examples/streaming-ctc-buffered-tokens-c-api.c index eb834fb70..98f5b4a60 100644 --- a/c-api-examples/streaming-ctc-buffered-tokens-c-api.c +++ b/c-api-examples/streaming-ctc-buffered-tokens-c-api.c @@ -107,7 +107,8 @@ int32_t main() { return -1; } - SherpaOnnxOnlineStream *stream = SherpaOnnxCreateOnlineStream(recognizer); + const SherpaOnnxOnlineStream *stream = + SherpaOnnxCreateOnlineStream(recognizer); const SherpaOnnxDisplay *display = SherpaOnnxCreateDisplay(50); int32_t segment_id = 0; diff --git a/c-api-examples/streaming-paraformer-buffered-tokens-c-api.c b/c-api-examples/streaming-paraformer-buffered-tokens-c-api.c index be08f4149..0c382cc94 100644 --- a/c-api-examples/streaming-paraformer-buffered-tokens-c-api.c +++ b/c-api-examples/streaming-paraformer-buffered-tokens-c-api.c @@ -108,7 +108,8 @@ int32_t main() { return -1; } - SherpaOnnxOnlineStream *stream = SherpaOnnxCreateOnlineStream(recognizer); + const SherpaOnnxOnlineStream *stream = + SherpaOnnxCreateOnlineStream(recognizer); const SherpaOnnxDisplay *display = SherpaOnnxCreateDisplay(50); int32_t segment_id = 0; diff --git a/c-api-examples/streaming-paraformer-c-api.c b/c-api-examples/streaming-paraformer-c-api.c index 11748e084..384ea411b 100644 --- a/c-api-examples/streaming-paraformer-c-api.c +++ b/c-api-examples/streaming-paraformer-c-api.c @@ -66,7 +66,8 @@ int32_t main() { return -1; } - SherpaOnnxOnlineStream *stream = SherpaOnnxCreateOnlineStream(recognizer); + const SherpaOnnxOnlineStream *stream = + SherpaOnnxCreateOnlineStream(recognizer); const SherpaOnnxDisplay *display = SherpaOnnxCreateDisplay(50); int32_t segment_id = 0; diff --git a/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c b/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c index c8afca15a..bd76ea8ab 100644 --- a/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c +++ b/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c @@ -130,7 +130,8 @@ int32_t main() { return -1; } - SherpaOnnxOnlineStream *stream = SherpaOnnxCreateOnlineStream(recognizer); + const SherpaOnnxOnlineStream *stream = + SherpaOnnxCreateOnlineStream(recognizer); const SherpaOnnxDisplay *display = SherpaOnnxCreateDisplay(50); int32_t segment_id = 0; diff --git a/c-api-examples/streaming-zipformer-c-api.c b/c-api-examples/streaming-zipformer-c-api.c index a38d22f02..6011186ea 100644 --- a/c-api-examples/streaming-zipformer-c-api.c +++ b/c-api-examples/streaming-zipformer-c-api.c @@ -72,7 +72,8 @@ int32_t main() { return -1; } - SherpaOnnxOnlineStream *stream = SherpaOnnxCreateOnlineStream(recognizer); + const SherpaOnnxOnlineStream *stream = + SherpaOnnxCreateOnlineStream(recognizer); const SherpaOnnxDisplay *display = SherpaOnnxCreateDisplay(50); int32_t segment_id = 0; diff --git a/c-api-examples/telespeech-c-api.c b/c-api-examples/telespeech-c-api.c index fa7824c3b..9bf34b1a8 100644 --- a/c-api-examples/telespeech-c-api.c +++ b/c-api-examples/telespeech-c-api.c @@ -49,7 +49,7 @@ int32_t main() { recognizer_config.decoding_method = "greedy_search"; recognizer_config.model_config = offline_model_config; - SherpaOnnxOfflineRecognizer *recognizer = + const SherpaOnnxOfflineRecognizer *recognizer = SherpaOnnxCreateOfflineRecognizer(&recognizer_config); if (recognizer == NULL) { @@ -58,7 +58,8 @@ int32_t main() { return -1; } - SherpaOnnxOfflineStream *stream = SherpaOnnxCreateOfflineStream(recognizer); + const SherpaOnnxOfflineStream *stream = + SherpaOnnxCreateOfflineStream(recognizer); SherpaOnnxAcceptWaveformOffline(stream, wave->sample_rate, wave->samples, wave->num_samples); diff --git a/c-api-examples/vad-sense-voice-c-api.c b/c-api-examples/vad-sense-voice-c-api.c index 172ec0a79..3049c9572 100644 --- a/c-api-examples/vad-sense-voice-c-api.c +++ b/c-api-examples/vad-sense-voice-c-api.c @@ -66,7 +66,7 @@ int32_t main() { recognizer_config.decoding_method = "greedy_search"; recognizer_config.model_config = offline_model_config; - SherpaOnnxOfflineRecognizer *recognizer = + const SherpaOnnxOfflineRecognizer *recognizer = SherpaOnnxCreateOfflineRecognizer(&recognizer_config); if (recognizer == NULL) { @@ -108,8 +108,9 @@ int32_t main() { const SherpaOnnxSpeechSegment *segment = SherpaOnnxVoiceActivityDetectorFront(vad); - SherpaOnnxOfflineStream *stream = + const SherpaOnnxOfflineStream *stream = SherpaOnnxCreateOfflineStream(recognizer); + SherpaOnnxAcceptWaveformOffline(stream, wave->sample_rate, segment->samples, segment->n); @@ -138,7 +139,9 @@ int32_t main() { const SherpaOnnxSpeechSegment *segment = SherpaOnnxVoiceActivityDetectorFront(vad); - SherpaOnnxOfflineStream *stream = SherpaOnnxCreateOfflineStream(recognizer); + const SherpaOnnxOfflineStream *stream = + SherpaOnnxCreateOfflineStream(recognizer); + SherpaOnnxAcceptWaveformOffline(stream, wave->sample_rate, segment->samples, segment->n); diff --git a/c-api-examples/whisper-c-api.c b/c-api-examples/whisper-c-api.c index 3a71bcb03..2e795b025 100644 --- a/c-api-examples/whisper-c-api.c +++ b/c-api-examples/whisper-c-api.c @@ -58,7 +58,7 @@ int32_t main() { recognizer_config.decoding_method = "greedy_search"; recognizer_config.model_config = offline_model_config; - SherpaOnnxOfflineRecognizer *recognizer = + const SherpaOnnxOfflineRecognizer *recognizer = SherpaOnnxCreateOfflineRecognizer(&recognizer_config); if (recognizer == NULL) { @@ -69,7 +69,8 @@ int32_t main() { return -1; } - SherpaOnnxOfflineStream *stream = SherpaOnnxCreateOfflineStream(recognizer); + const SherpaOnnxOfflineStream *stream = + SherpaOnnxCreateOfflineStream(recognizer); SherpaOnnxAcceptWaveformOffline(stream, wave->sample_rate, wave->samples, wave->num_samples); diff --git a/c-api-examples/zipformer-c-api.c b/c-api-examples/zipformer-c-api.c index 35393b19c..4db22fc38 100644 --- a/c-api-examples/zipformer-c-api.c +++ b/c-api-examples/zipformer-c-api.c @@ -60,7 +60,7 @@ int32_t main() { recognizer_config.decoding_method = "greedy_search"; recognizer_config.model_config = offline_model_config; - SherpaOnnxOfflineRecognizer *recognizer = + const SherpaOnnxOfflineRecognizer *recognizer = SherpaOnnxCreateOfflineRecognizer(&recognizer_config); if (recognizer == NULL) { @@ -69,7 +69,8 @@ int32_t main() { return -1; } - SherpaOnnxOfflineStream *stream = SherpaOnnxCreateOfflineStream(recognizer); + const SherpaOnnxOfflineStream *stream = + SherpaOnnxCreateOfflineStream(recognizer); SherpaOnnxAcceptWaveformOffline(stream, wave->sample_rate, wave->samples, wave->num_samples); diff --git a/cxx-api-examples/CMakeLists.txt b/cxx-api-examples/CMakeLists.txt index b51d2e50d..7c9853080 100644 --- a/cxx-api-examples/CMakeLists.txt +++ b/cxx-api-examples/CMakeLists.txt @@ -2,3 +2,9 @@ include_directories(${CMAKE_SOURCE_DIR}) add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc) target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api) + +add_executable(whisper-cxx-api ./whisper-cxx-api.cc) +target_link_libraries(whisper-cxx-api sherpa-onnx-cxx-api) + +add_executable(sense-voice-cxx-api ./sense-voice-cxx-api.cc) +target_link_libraries(sense-voice-cxx-api sherpa-onnx-cxx-api) diff --git a/cxx-api-examples/sense-voice-cxx-api.cc b/cxx-api-examples/sense-voice-cxx-api.cc new file mode 100644 index 000000000..15d752058 --- /dev/null +++ b/cxx-api-examples/sense-voice-cxx-api.cc @@ -0,0 +1,78 @@ +// cxx-api-examples/sense-voice-cxx-api.cc +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to use sense voice with sherpa-onnx's C++ API. +// +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +// tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +// rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +// +// clang-format on + +#include // NOLINT +#include +#include + +#include "sherpa-onnx/c-api/cxx-api.h" + +int32_t main() { + using namespace sherpa_onnx::cxx; + OfflineRecognizerConfig config; + + config.model_config.sense_voice.model = + "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx"; + config.model_config.sense_voice.use_itn = true; + config.model_config.sense_voice.language = "auto"; + config.model_config.tokens = + "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt"; + + config.model_config.num_threads = 1; + + std::cout << "Loading model\n"; + OfflineRecognizer recongizer = OfflineRecognizer::Create(config); + if (!recongizer.Get()) { + std::cerr << "Please check your config\n"; + return -1; + } + std::cout << "Loading model done\n"; + + std::string wave_filename = + "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/en.wav"; + + Wave wave = ReadWave(wave_filename); + if (wave.samples.empty()) { + std::cerr << "Failed to read: '" << wave_filename << "'\n"; + return -1; + } + + std::cout << "Start recognition\n"; + const auto begin = std::chrono::steady_clock::now(); + + OfflineStream stream = recongizer.CreateStream(); + stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), + wave.samples.size()); + + recongizer.Decode(&stream); + + OfflineRecognizerResult result = recongizer.GetResult(&stream); + + const auto end = std::chrono::steady_clock::now(); + const float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + float duration = wave.samples.size() / static_cast(wave.sample_rate); + float rtf = elapsed_seconds / duration; + + std::cout << "text: " << result.text << "\n"; + printf("Number of threads: %d\n", config.model_config.num_threads); + printf("Duration: %.3fs\n", duration); + printf("Elapsed seconds: %.3fs\n", elapsed_seconds); + printf("(Real time factor) RTF = %.3f / %.3f = %.3f\n", elapsed_seconds, + duration, rtf); + + return 0; +} diff --git a/cxx-api-examples/streaming-zipformer-cxx-api.cc b/cxx-api-examples/streaming-zipformer-cxx-api.cc index 4f38647f4..5a49dcfc9 100644 --- a/cxx-api-examples/streaming-zipformer-cxx-api.cc +++ b/cxx-api-examples/streaming-zipformer-cxx-api.cc @@ -66,6 +66,8 @@ int32_t main() { OnlineStream stream = recongizer.CreateStream(); stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), wave.samples.size()); + stream.InputFinished(); + while (recongizer.IsReady(&stream)) { recongizer.Decode(&stream); } diff --git a/cxx-api-examples/whisper-cxx-api.cc b/cxx-api-examples/whisper-cxx-api.cc new file mode 100644 index 000000000..82f0ddb53 --- /dev/null +++ b/cxx-api-examples/whisper-cxx-api.cc @@ -0,0 +1,76 @@ +// cxx-api-examples/whisper-cxx-api.cc +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to use whisper with sherpa-onnx's C++ API. +// +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 +// tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2 +// rm sherpa-onnx-whisper-tiny.en.tar.bz2 +// +// clang-format on + +#include // NOLINT +#include +#include + +#include "sherpa-onnx/c-api/cxx-api.h" + +int32_t main() { + using namespace sherpa_onnx::cxx; + OfflineRecognizerConfig config; + + config.model_config.whisper.encoder = + "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"; + config.model_config.whisper.decoder = + "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"; + config.model_config.tokens = + "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"; + + config.model_config.num_threads = 1; + + std::cout << "Loading model\n"; + OfflineRecognizer recongizer = OfflineRecognizer::Create(config); + if (!recongizer.Get()) { + std::cerr << "Please check your config\n"; + return -1; + } + std::cout << "Loading model done\n"; + + std::string wave_filename = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"; + Wave wave = ReadWave(wave_filename); + if (wave.samples.empty()) { + std::cerr << "Failed to read: '" << wave_filename << "'\n"; + return -1; + } + + std::cout << "Start recognition\n"; + const auto begin = std::chrono::steady_clock::now(); + + OfflineStream stream = recongizer.CreateStream(); + stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), + wave.samples.size()); + + recongizer.Decode(&stream); + + OfflineRecognizerResult result = recongizer.GetResult(&stream); + + const auto end = std::chrono::steady_clock::now(); + const float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + float duration = wave.samples.size() / static_cast(wave.sample_rate); + float rtf = elapsed_seconds / duration; + + std::cout << "text: " << result.text << "\n"; + printf("Number of threads: %d\n", config.model_config.num_threads); + printf("Duration: %.3fs\n", duration); + printf("Elapsed seconds: %.3fs\n", elapsed_seconds); + printf("(Real time factor) RTF = %.3f / %.3f = %.3f\n", elapsed_seconds, + duration, rtf); + + return 0; +} diff --git a/ffmpeg-examples/sherpa-onnx-ffmpeg.c b/ffmpeg-examples/sherpa-onnx-ffmpeg.c index 31e7491f0..cf3614518 100644 --- a/ffmpeg-examples/sherpa-onnx-ffmpeg.c +++ b/ffmpeg-examples/sherpa-onnx-ffmpeg.c @@ -320,7 +320,8 @@ int main(int argc, char **argv) { const SherpaOnnxOnlineRecognizer *recognizer = SherpaOnnxCreateOnlineRecognizer(&config); - SherpaOnnxOnlineStream *stream = SherpaOnnxCreateOnlineStream(recognizer); + const SherpaOnnxOnlineStream *stream = + SherpaOnnxCreateOnlineStream(recognizer); const SherpaOnnxDisplay *display = SherpaOnnxCreateDisplay(50); int32_t segment_id = 0; diff --git a/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.cpp b/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.cpp index c559c9321..462959247 100644 --- a/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.cpp +++ b/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.cpp @@ -256,7 +256,7 @@ void CNonStreamingSpeechRecognitionDlg::OnBnClickedOk() { } pa_stream_ = nullptr; - SherpaOnnxOfflineStream *stream = SherpaOnnxCreateOfflineStream(recognizer_); + const SherpaOnnxOfflineStream *stream = SherpaOnnxCreateOfflineStream(recognizer_); SherpaOnnxAcceptWaveformOffline(stream, config_.feat_config.sample_rate, samples_.data(), static_cast(samples_.size())); diff --git a/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.h b/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.h index 77a8992e9..19ab83880 100644 --- a/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.h +++ b/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.h @@ -48,7 +48,7 @@ class CNonStreamingSpeechRecognitionDlg : public CDialogEx { private: Microphone mic_; - SherpaOnnxOfflineRecognizer *recognizer_ = nullptr; + const SherpaOnnxOfflineRecognizer *recognizer_ = nullptr; SherpaOnnxOfflineRecognizerConfig config_; PaStream *pa_stream_ = nullptr; diff --git a/scripts/node-addon-api/src/non-streaming-asr.cc b/scripts/node-addon-api/src/non-streaming-asr.cc index b24a37beb..86badc0ff 100644 --- a/scripts/node-addon-api/src/non-streaming-asr.cc +++ b/scripts/node-addon-api/src/non-streaming-asr.cc @@ -203,7 +203,7 @@ CreateOfflineRecognizerWrapper(const Napi::CallbackInfo &info) { SHERPA_ONNX_ASSIGN_ATTR_STR(rule_fars, ruleFars); SHERPA_ONNX_ASSIGN_ATTR_FLOAT(blank_penalty, blankPenalty); - SherpaOnnxOfflineRecognizer *recognizer = + const SherpaOnnxOfflineRecognizer *recognizer = SherpaOnnxCreateOfflineRecognizer(&c); if (c.model_config.transducer.encoder) { @@ -306,7 +306,7 @@ CreateOfflineRecognizerWrapper(const Napi::CallbackInfo &info) { } return Napi::External::New( - env, recognizer, + env, const_cast(recognizer), [](Napi::Env env, SherpaOnnxOfflineRecognizer *recognizer) { SherpaOnnxDestroyOfflineRecognizer(recognizer); }); @@ -336,10 +336,12 @@ static Napi::External CreateOfflineStreamWrapper( SherpaOnnxOfflineRecognizer *recognizer = info[0].As>().Data(); - SherpaOnnxOfflineStream *stream = SherpaOnnxCreateOfflineStream(recognizer); + const SherpaOnnxOfflineStream *stream = + SherpaOnnxCreateOfflineStream(recognizer); return Napi::External::New( - env, stream, [](Napi::Env env, SherpaOnnxOfflineStream *stream) { + env, const_cast(stream), + [](Napi::Env env, SherpaOnnxOfflineStream *stream) { SherpaOnnxDestroyOfflineStream(stream); }); } diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 653355a66..d7fa383be 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -168,14 +168,14 @@ void SherpaOnnxDestroyOnlineRecognizer( delete recognizer; } -SherpaOnnxOnlineStream *SherpaOnnxCreateOnlineStream( +const SherpaOnnxOnlineStream *SherpaOnnxCreateOnlineStream( const SherpaOnnxOnlineRecognizer *recognizer) { SherpaOnnxOnlineStream *stream = new SherpaOnnxOnlineStream(recognizer->impl->CreateStream()); return stream; } -SherpaOnnxOnlineStream *SherpaOnnxCreateOnlineStreamWithHotwords( +const SherpaOnnxOnlineStream *SherpaOnnxCreateOnlineStreamWithHotwords( const SherpaOnnxOnlineRecognizer *recognizer, const char *hotwords) { SherpaOnnxOnlineStream *stream = new SherpaOnnxOnlineStream(recognizer->impl->CreateStream(hotwords)); @@ -351,7 +351,7 @@ struct SherpaOnnxOfflineStream { static sherpa_onnx::OfflineRecognizerConfig convertConfig( const SherpaOnnxOfflineRecognizerConfig *config); -SherpaOnnxOfflineRecognizer *SherpaOnnxCreateOfflineRecognizer( +const SherpaOnnxOfflineRecognizer *SherpaOnnxCreateOfflineRecognizer( const SherpaOnnxOfflineRecognizerConfig *config) { sherpa_onnx::OfflineRecognizerConfig recognizer_config = convertConfig(config); @@ -490,11 +490,11 @@ void SherpaOnnxOfflineRecognizerSetConfig( } void SherpaOnnxDestroyOfflineRecognizer( - SherpaOnnxOfflineRecognizer *recognizer) { + const SherpaOnnxOfflineRecognizer *recognizer) { delete recognizer; } -SherpaOnnxOfflineStream *SherpaOnnxCreateOfflineStream( +const SherpaOnnxOfflineStream *SherpaOnnxCreateOfflineStream( const SherpaOnnxOfflineRecognizer *recognizer) { SherpaOnnxOfflineStream *stream = new SherpaOnnxOfflineStream(recognizer->impl->CreateStream()); @@ -518,8 +518,8 @@ void SherpaOnnxDecodeOfflineStream( } void SherpaOnnxDecodeMultipleOfflineStreams( - SherpaOnnxOfflineRecognizer *recognizer, SherpaOnnxOfflineStream **streams, - int32_t n) { + const SherpaOnnxOfflineRecognizer *recognizer, + const SherpaOnnxOfflineStream **streams, int32_t n) { std::vector ss(n); for (int32_t i = 0; i != n; ++i) { ss[i] = streams[i]->impl.get(); diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 0a01379d4..e5fc92eb1 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -220,7 +220,7 @@ SHERPA_ONNX_API void SherpaOnnxDestroyOnlineRecognizer( /// @param recognizer A pointer returned by SherpaOnnxCreateOnlineRecognizer() /// @return Return a pointer to an OnlineStream. The user has to invoke /// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak. -SHERPA_ONNX_API SherpaOnnxOnlineStream *SherpaOnnxCreateOnlineStream( +SHERPA_ONNX_API const SherpaOnnxOnlineStream *SherpaOnnxCreateOnlineStream( const SherpaOnnxOnlineRecognizer *recognizer); /// Create an online stream for accepting wave samples with the specified hot @@ -229,7 +229,7 @@ SHERPA_ONNX_API SherpaOnnxOnlineStream *SherpaOnnxCreateOnlineStream( /// @param recognizer A pointer returned by SherpaOnnxCreateOnlineRecognizer() /// @return Return a pointer to an OnlineStream. The user has to invoke /// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak. -SHERPA_ONNX_API SherpaOnnxOnlineStream * +SHERPA_ONNX_API const SherpaOnnxOnlineStream * SherpaOnnxCreateOnlineStreamWithHotwords( const SherpaOnnxOnlineRecognizer *recognizer, const char *hotwords); @@ -453,7 +453,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineStream SherpaOnnxOfflineStream; /// @return Return a pointer to the recognizer. The user has to invoke // SherpaOnnxDestroyOfflineRecognizer() to free it to avoid memory // leak. -SHERPA_ONNX_API SherpaOnnxOfflineRecognizer *SherpaOnnxCreateOfflineRecognizer( +SHERPA_ONNX_API const SherpaOnnxOfflineRecognizer * +SherpaOnnxCreateOfflineRecognizer( const SherpaOnnxOfflineRecognizerConfig *config); /// @param config Config for the recognizer. @@ -465,14 +466,14 @@ SHERPA_ONNX_API void SherpaOnnxOfflineRecognizerSetConfig( /// /// @param p A pointer returned by SherpaOnnxCreateOfflineRecognizer() SHERPA_ONNX_API void SherpaOnnxDestroyOfflineRecognizer( - SherpaOnnxOfflineRecognizer *recognizer); + const SherpaOnnxOfflineRecognizer *recognizer); /// Create an offline stream for accepting wave samples. /// /// @param recognizer A pointer returned by SherpaOnnxCreateOfflineRecognizer() /// @return Return a pointer to an OfflineStream. The user has to invoke /// SherpaOnnxDestroyOfflineStream() to free it to avoid memory leak. -SHERPA_ONNX_API SherpaOnnxOfflineStream *SherpaOnnxCreateOfflineStream( +SHERPA_ONNX_API const SherpaOnnxOfflineStream *SherpaOnnxCreateOfflineStream( const SherpaOnnxOfflineRecognizer *recognizer); /// Destroy an offline stream. @@ -518,8 +519,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeOfflineStream( /// by SherpaOnnxCreateOfflineStream(). /// @param n Number of entries in the given streams. SHERPA_ONNX_API void SherpaOnnxDecodeMultipleOfflineStreams( - SherpaOnnxOfflineRecognizer *recognizer, SherpaOnnxOfflineStream **streams, - int32_t n); + const SherpaOnnxOfflineRecognizer *recognizer, + const SherpaOnnxOfflineStream **streams, int32_t n); SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { const char *text; diff --git a/sherpa-onnx/c-api/cxx-api.cc b/sherpa-onnx/c-api/cxx-api.cc index 243e99738..262ad3cc9 100644 --- a/sherpa-onnx/c-api/cxx-api.cc +++ b/sherpa-onnx/c-api/cxx-api.cc @@ -36,6 +36,10 @@ void OnlineStream::AcceptWaveform(int32_t sample_rate, const float *samples, SherpaOnnxOnlineStreamAcceptWaveform(p_, sample_rate, samples, n); } +void OnlineStream::InputFinished() const { + SherpaOnnxOnlineStreamInputFinished(p_); +} + OnlineRecognizer OnlineRecognizer::Create( const OnlineRecognizerConfig &config) { struct SherpaOnnxOnlineRecognizerConfig c; @@ -119,6 +123,14 @@ void OnlineRecognizer::Decode(const OnlineStream *s) const { SherpaOnnxDecodeOnlineStream(p_, s->Get()); } +void OnlineRecognizer::Reset(const OnlineStream *s) const { + SherpaOnnxOnlineStreamReset(p_, s->Get()); +} + +bool OnlineRecognizer::IsEndpoint(const OnlineStream *s) const { + return SherpaOnnxOnlineStreamIsEndpoint(p_, s->Get()); +} + void OnlineRecognizer::Decode(const OnlineStream *ss, int32_t n) const { if (n <= 0) { return; @@ -156,4 +168,138 @@ OnlineRecognizerResult OnlineRecognizer::GetResult( return ans; } +// ============================================================================ +// Non-streaming ASR +// ============================================================================ +OfflineStream::OfflineStream(const SherpaOnnxOfflineStream *p) + : MoveOnly(p) {} + +void OfflineStream::Destroy(const SherpaOnnxOfflineStream *p) const { + SherpaOnnxDestroyOfflineStream(p); +} + +void OfflineStream::AcceptWaveform(int32_t sample_rate, const float *samples, + int32_t n) const { + SherpaOnnxAcceptWaveformOffline(p_, sample_rate, samples, n); +} + +OfflineRecognizer OfflineRecognizer::Create( + const OfflineRecognizerConfig &config) { + struct SherpaOnnxOfflineRecognizerConfig c; + memset(&c, 0, sizeof(c)); + + c.feat_config.sample_rate = config.feat_config.sample_rate; + c.feat_config.feature_dim = config.feat_config.feature_dim; + c.model_config.transducer.encoder = + config.model_config.transducer.encoder.c_str(); + c.model_config.transducer.decoder = + config.model_config.transducer.decoder.c_str(); + c.model_config.transducer.joiner = + config.model_config.transducer.joiner.c_str(); + + c.model_config.paraformer.model = + config.model_config.paraformer.model.c_str(); + + c.model_config.nemo_ctc.model = config.model_config.nemo_ctc.model.c_str(); + + c.model_config.whisper.encoder = config.model_config.whisper.encoder.c_str(); + c.model_config.whisper.decoder = config.model_config.whisper.decoder.c_str(); + c.model_config.whisper.language = + config.model_config.whisper.language.c_str(); + c.model_config.whisper.task = config.model_config.whisper.task.c_str(); + c.model_config.whisper.tail_paddings = + config.model_config.whisper.tail_paddings; + + c.model_config.tdnn.model = config.model_config.tdnn.model.c_str(); + + c.model_config.tokens = config.model_config.tokens.c_str(); + c.model_config.num_threads = config.model_config.num_threads; + c.model_config.debug = config.model_config.debug; + c.model_config.provider = config.model_config.provider.c_str(); + c.model_config.model_type = config.model_config.model_type.c_str(); + c.model_config.modeling_unit = config.model_config.modeling_unit.c_str(); + c.model_config.bpe_vocab = config.model_config.bpe_vocab.c_str(); + c.model_config.telespeech_ctc = config.model_config.telespeech_ctc.c_str(); + + c.model_config.sense_voice.model = + config.model_config.sense_voice.model.c_str(); + c.model_config.sense_voice.language = + config.model_config.sense_voice.language.c_str(); + c.model_config.sense_voice.use_itn = config.model_config.sense_voice.use_itn; + + c.lm_config.model = config.lm_config.model.c_str(); + c.lm_config.scale = config.lm_config.scale; + + c.decoding_method = config.decoding_method.c_str(); + c.max_active_paths = config.max_active_paths; + c.hotwords_file = config.hotwords_file.c_str(); + c.hotwords_score = config.hotwords_score; + + c.rule_fsts = config.rule_fsts.c_str(); + c.rule_fars = config.rule_fars.c_str(); + + c.blank_penalty = config.blank_penalty; + + auto p = SherpaOnnxCreateOfflineRecognizer(&c); + return OfflineRecognizer(p); +} + +OfflineRecognizer::OfflineRecognizer(const SherpaOnnxOfflineRecognizer *p) + : MoveOnly(p) {} + +void OfflineRecognizer::Destroy(const SherpaOnnxOfflineRecognizer *p) const { + SherpaOnnxDestroyOfflineRecognizer(p_); +} + +OfflineStream OfflineRecognizer::CreateStream() const { + auto p = SherpaOnnxCreateOfflineStream(p_); + return OfflineStream{p}; +} + +void OfflineRecognizer::Decode(const OfflineStream *s) const { + SherpaOnnxDecodeOfflineStream(p_, s->Get()); +} + +void OfflineRecognizer::Decode(const OfflineStream *ss, int32_t n) const { + if (n <= 0) { + return; + } + + std::vector streams(n); + for (int32_t i = 0; i != n; ++i) { + streams[i] = ss[i].Get(); + } + + SherpaOnnxDecodeMultipleOfflineStreams(p_, streams.data(), n); +} + +OfflineRecognizerResult OfflineRecognizer::GetResult( + const OfflineStream *s) const { + auto r = SherpaOnnxGetOfflineStreamResult(s->Get()); + + OfflineRecognizerResult ans; + if (r) { + ans.text = r->text; + + if (r->timestamps) { + ans.timestamps.resize(r->count); + std::copy(r->timestamps, r->timestamps + r->count, ans.timestamps.data()); + } + + ans.tokens.resize(r->count); + for (int32_t i = 0; i != r->count; ++i) { + ans.tokens[i] = r->tokens_arr[i]; + } + + ans.json = r->json; + ans.lang = r->lang ? r->lang : ""; + ans.emotion = r->emotion ? r->emotion : ""; + ans.event = r->event ? r->event : ""; + } + + SherpaOnnxDestroyOfflineRecognizerResult(r); + + return ans; +} + } // namespace sherpa_onnx::cxx diff --git a/sherpa-onnx/c-api/cxx-api.h b/sherpa-onnx/c-api/cxx-api.h index 416333b7f..b078f5c73 100644 --- a/sherpa-onnx/c-api/cxx-api.h +++ b/sherpa-onnx/c-api/cxx-api.h @@ -13,6 +13,9 @@ namespace sherpa_onnx::cxx { +// ============================================================================ +// Streaming ASR +// ============================================================================ struct SHERPA_ONNX_API OnlineTransducerModelConfig { std::string encoder; std::string decoder; @@ -148,6 +151,8 @@ class SHERPA_ONNX_API OnlineStream void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) const; + void InputFinished() const; + void Destroy(const SherpaOnnxOnlineStream *p) const; }; @@ -170,10 +175,134 @@ class SHERPA_ONNX_API OnlineRecognizer OnlineRecognizerResult GetResult(const OnlineStream *s) const; + void Reset(const OnlineStream *s) const; + + bool IsEndpoint(const OnlineStream *s) const; + private: explicit OnlineRecognizer(const SherpaOnnxOnlineRecognizer *p); }; +// ============================================================================ +// Non-streaming ASR +// ============================================================================ +struct SHERPA_ONNX_API OfflineTransducerModelConfig { + std::string encoder; + std::string decoder; + std::string joiner; +}; + +struct SHERPA_ONNX_API OfflineParaformerModelConfig { + std::string model; +}; + +struct SHERPA_ONNX_API OfflineNemoEncDecCtcModelConfig { + std::string model; +}; + +struct SHERPA_ONNX_API OfflineWhisperModelConfig { + std::string encoder; + std::string decoder; + std::string language; + std::string task = "transcribe"; + int32_t tail_paddings = -1; +}; + +struct SHERPA_ONNX_API OfflineTdnnModelConfig { + std::string model; +}; + +struct SHERPA_ONNX_API SherpaOnnxOfflineLMConfig { + std::string model; + float scale = 1.0; +}; + +struct SHERPA_ONNX_API OfflineSenseVoiceModelConfig { + std::string model; + std::string language; + bool use_itn = false; +}; + +struct SHERPA_ONNX_API OfflineModelConfig { + OfflineTransducerModelConfig transducer; + OfflineParaformerModelConfig paraformer; + OfflineNemoEncDecCtcModelConfig nemo_ctc; + OfflineWhisperModelConfig whisper; + OfflineTdnnModelConfig tdnn; + + std::string tokens; + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + std::string model_type; + std::string modeling_unit = "cjkchar"; + std::string bpe_vocab; + std::string telespeech_ctc; + OfflineSenseVoiceModelConfig sense_voice; +}; + +struct SHERPA_ONNX_API OfflineLMConfig { + std::string model; + float scale = 1.0; +}; + +struct SHERPA_ONNX_API OfflineRecognizerConfig { + FeatureConfig feat_config; + OfflineModelConfig model_config; + OfflineLMConfig lm_config; + + std::string decoding_method = "greedy_search"; + int32_t max_active_paths = 4; + + std::string hotwords_file; + + float hotwords_score = 1.5; + std::string rule_fsts; + std::string rule_fars; + float blank_penalty = 0; +}; + +struct SHERPA_ONNX_API OfflineRecognizerResult { + std::string text; + std::vector timestamps; + std::vector tokens; + std::string json; + std::string lang; + std::string emotion; + std::string event; +}; + +class SHERPA_ONNX_API OfflineStream + : public MoveOnly { + public: + explicit OfflineStream(const SherpaOnnxOfflineStream *p); + + void AcceptWaveform(int32_t sample_rate, const float *samples, + int32_t n) const; + + void Destroy(const SherpaOnnxOfflineStream *p) const; +}; + +class SHERPA_ONNX_API OfflineRecognizer + : public MoveOnly { + public: + static OfflineRecognizer Create(const OfflineRecognizerConfig &config); + + void Destroy(const SherpaOnnxOfflineRecognizer *p) const; + + OfflineStream CreateStream() const; + + void Decode(const OfflineStream *s) const; + + void Decode(const OfflineStream *ss, int32_t n) const; + + OfflineRecognizerResult GetResult(const OfflineStream *s) const; + + private: + explicit OfflineRecognizer(const SherpaOnnxOfflineRecognizer *p); +}; + } // namespace sherpa_onnx::cxx #endif // SHERPA_ONNX_C_API_CXX_API_H_ + // diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc index 399dab49e..cf59cb539 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.cc +++ b/sherpa-onnx/csrc/online-recognizer-impl.cc @@ -30,9 +30,13 @@ std::unique_ptr OnlineRecognizerImpl::Create( if (!config.model_config.transducer.encoder.empty()) { Ort::Env env(ORT_LOGGING_LEVEL_ERROR); + Ort::SessionOptions sess_opts; + sess_opts.SetIntraOpNumThreads(1); + sess_opts.SetInterOpNumThreads(1); + auto decoder_model = ReadFile(config.model_config.transducer.decoder); - auto sess = std::make_unique( - env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{}); + auto sess = std::make_unique(env, decoder_model.data(), + decoder_model.size(), sess_opts); size_t node_count = sess->GetOutputCount(); @@ -63,9 +67,13 @@ std::unique_ptr OnlineRecognizerImpl::Create( if (!config.model_config.transducer.encoder.empty()) { Ort::Env env(ORT_LOGGING_LEVEL_ERROR); + Ort::SessionOptions sess_opts; + sess_opts.SetIntraOpNumThreads(1); + sess_opts.SetInterOpNumThreads(1); + auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder); - auto sess = std::make_unique( - env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{}); + auto sess = std::make_unique(env, decoder_model.data(), + decoder_model.size(), sess_opts); size_t node_count = sess->GetOutputCount();