From 2b0152d2a2f2b2dbb775520023f30ea435123b6f Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Mon, 28 Aug 2023 19:39:22 +0800 Subject: [PATCH 1/9] Fix context graph (#292) --- sherpa-onnx/csrc/context-graph-test.cc | 5 +++-- sherpa-onnx/csrc/context-graph.cc | 17 ++++------------- sherpa-onnx/csrc/context-graph.h | 6 +++--- 3 files changed, 10 insertions(+), 18 deletions(-) diff --git a/sherpa-onnx/csrc/context-graph-test.cc b/sherpa-onnx/csrc/context-graph-test.cc index 97d034431..0e7e9b5cd 100644 --- a/sherpa-onnx/csrc/context-graph-test.cc +++ b/sherpa-onnx/csrc/context-graph-test.cc @@ -22,8 +22,9 @@ TEST(ContextGraph, TestBasic) { auto context_graph = ContextGraph(contexts, 1); auto queries = std::map{ - {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6}, - {"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; + {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, + {"SHED", 6}, {"SHELF", 6}, {"HELL", 2}, + {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; for (const auto &iter : queries) { float total_scores = 0; diff --git a/sherpa-onnx/csrc/context-graph.cc b/sherpa-onnx/csrc/context-graph.cc index bc3a1e3ec..05ca04f00 100644 --- a/sherpa-onnx/csrc/context-graph.cc +++ b/sherpa-onnx/csrc/context-graph.cc @@ -19,7 +19,7 @@ void ContextGraph::Build( bool is_end = j == token_ids[i].size() - 1; node->next[token] = std::make_unique( token, context_score_, node->node_score + context_score_, - is_end ? 0 : node->local_node_score + context_score_, is_end); + is_end ? 
node->node_score + context_score_ : 0, is_end); } node = node->next[token].get(); } @@ -34,7 +34,6 @@ std::pair ContextGraph::ForwardOneStep( if (1 == state->next.count(token)) { node = state->next.at(token).get(); score = node->token_score; - if (state->is_end) score += state->node_score; } else { node = state->fail; while (0 == node->next.count(token)) { @@ -44,24 +43,15 @@ std::pair ContextGraph::ForwardOneStep( if (1 == node->next.count(token)) { node = node->next.at(token).get(); } - score = node->node_score - state->local_node_score; + score = node->node_score - state->node_score; } SHERPA_ONNX_CHECK(nullptr != node); - float matched_score = 0; - auto output = node->output; - while (nullptr != output) { - matched_score += output->node_score; - output = output->output; - } - return std::make_pair(score + matched_score, node); + return std::make_pair(score + node->output_score, node); } std::pair ContextGraph::Finalize( const ContextState *state) const { float score = -state->node_score; - if (state->is_end) { - score = 0; - } return std::make_pair(score, root_.get()); } @@ -98,6 +88,7 @@ void ContextGraph::FillFailOutput() const { } } kv.second->output = output; + kv.second->output_score += output == nullptr ? 
0 : output->output_score; node_queue.push(kv.second.get()); } } diff --git a/sherpa-onnx/csrc/context-graph.h b/sherpa-onnx/csrc/context-graph.h index db16ce663..570106896 100644 --- a/sherpa-onnx/csrc/context-graph.h +++ b/sherpa-onnx/csrc/context-graph.h @@ -21,7 +21,7 @@ struct ContextState { int32_t token; float token_score; float node_score; - float local_node_score; + float output_score; bool is_end; std::unordered_map> next; const ContextState *fail = nullptr; @@ -29,11 +29,11 @@ struct ContextState { ContextState() = default; ContextState(int32_t token, float token_score, float node_score, - float local_node_score, bool is_end) + float output_score, bool is_end) : token(token), token_score(token_score), node_score(node_score), - local_node_score(local_node_score), + output_score(output_score), is_end(is_end) {} }; From a0a747a0c0df93cad346144d0f8f9c43bcacca83 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 31 Aug 2023 14:41:04 +0800 Subject: [PATCH 2/9] add endpointing for online websocket server (#294) --- .../csrc/online-recognizer-transducer-impl.h | 17 +++++++++++++++-- sherpa-onnx/csrc/online-stream.cc | 7 +++++++ sherpa-onnx/csrc/online-stream.h | 2 ++ .../csrc/online-websocket-server-impl.cc | 3 +++ 4 files changed, 27 insertions(+), 2 deletions(-) diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 625d02b1c..27f58687b 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -26,7 +26,8 @@ namespace sherpa_onnx { static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, const SymbolTable &sym_table, int32_t frame_shift_ms, - int32_t subsampling_factor) { + int32_t subsampling_factor, + int32_t segment) { OnlineRecognizerResult r; r.tokens.reserve(src.tokens.size()); r.timestamps.reserve(src.tokens.size()); @@ -44,6 +45,8 @@ static OnlineRecognizerResult Convert(const 
OnlineTransducerDecoderResult &src, r.timestamps.push_back(time); } + r.segment = segment; + return r; } @@ -192,7 +195,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { // TODO(fangjun): Remember to change these constants if needed int32_t frame_shift_ms = 10; int32_t subsampling_factor = 4; - return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor); + return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, + s->GetCurrentSegment()); } bool IsEndpoint(OnlineStream *s) const override { @@ -213,6 +217,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } void Reset(OnlineStream *s) const override { + { + // segment is incremented only when the last + // result is not empty + const auto &r = s->GetResult(); + if (!r.tokens.empty() && r.tokens.back() != 0) { + s->GetCurrentSegment() += 1; + } + } + // we keep the decoder_out decoder_->UpdateDecoderOut(&s->GetResult()); Ort::Value decoder_out = std::move(s->GetResult().decoder_out); diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc index 8960ed13e..39dfd7966 100644 --- a/sherpa-onnx/csrc/online-stream.cc +++ b/sherpa-onnx/csrc/online-stream.cc @@ -43,6 +43,8 @@ class OnlineStream::Impl { int32_t &GetNumProcessedFrames() { return num_processed_frames_; } + int32_t &GetCurrentSegment() { return segment_; } + void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } OnlineTransducerDecoderResult &GetResult() { return result_; } @@ -83,6 +85,7 @@ class OnlineStream::Impl { ContextGraphPtr context_graph_; int32_t num_processed_frames_ = 0; // before subsampling int32_t start_frame_index_ = 0; // never reset + int32_t segment_ = 0; OnlineTransducerDecoderResult result_; std::vector states_; std::vector paraformer_feat_cache_; @@ -123,6 +126,10 @@ int32_t &OnlineStream::GetNumProcessedFrames() { return impl_->GetNumProcessedFrames(); } +int32_t &OnlineStream::GetCurrentSegment() { + return 
impl_->GetCurrentSegment(); +} + void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) { impl_->SetResult(r); } diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h index ae920c1d3..6b7a96c44 100644 --- a/sherpa-onnx/csrc/online-stream.h +++ b/sherpa-onnx/csrc/online-stream.h @@ -68,6 +68,8 @@ class OnlineStream { // The returned reference is valid as long as this object is alive. int32_t &GetNumProcessedFrames(); + int32_t &GetCurrentSegment(); + void SetResult(const OnlineTransducerDecoderResult &r); OnlineTransducerDecoderResult &GetResult(); diff --git a/sherpa-onnx/csrc/online-websocket-server-impl.cc b/sherpa-onnx/csrc/online-websocket-server-impl.cc index a62bef25f..9c265de8e 100644 --- a/sherpa-onnx/csrc/online-websocket-server-impl.cc +++ b/sherpa-onnx/csrc/online-websocket-server-impl.cc @@ -194,6 +194,9 @@ void OnlineWebsocketDecoder::Decode() { for (auto c : c_vec) { auto result = recognizer_->GetResult(c->s.get()); + if (recognizer_->IsEndpoint(c->s.get())) { + recognizer_->Reset(c->s.get()); + } asio::post(server_->GetConnectionContext(), [this, hdl = c->hdl, str = result.AsJsonString()]() { From ffeff3b8a3fdf695a318438d31f14c1d8246f171 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 7 Sep 2023 11:29:00 +0800 Subject: [PATCH 3/9] Fix a typo for Go (#298) --- .github/workflows/android.yaml | 10 ++++------ .github/workflows/arm-linux-gnueabihf.yaml | 17 +++-------------- .github/workflows/build-xcframework.yaml | 13 ++----------- .github/workflows/dot-net.yaml | 16 ++-------------- .github/workflows/jni.yaml | 2 ++ .github/workflows/linux-gpu.yaml | 16 ++-------------- .github/workflows/linux.yaml | 16 ++-------------- .github/workflows/macos.yaml | 18 ++++-------------- .github/workflows/mfc.yaml | 16 ++-------------- .github/workflows/windows-x64-cuda.yaml | 16 ++-------------- .github/workflows/windows-x64.yaml | 16 ++-------------- .github/workflows/windows-x86.yaml | 18 ++++-------------- 
CMakeLists.txt | 2 +- scripts/go/sherpa_onnx.go | 2 +- 14 files changed, 33 insertions(+), 145 deletions(-) diff --git a/.github/workflows/android.yaml b/.github/workflows/android.yaml index 3f9f63b08..c7995ff4a 100644 --- a/.github/workflows/android.yaml +++ b/.github/workflows/android.yaml @@ -11,6 +11,8 @@ on: - 'sherpa-onnx/csrc/*' - 'sherpa-onnx/jni/*' - 'build-android*.sh' + tags: + - '*' pull_request: branches: - master @@ -22,10 +24,6 @@ on: - 'sherpa-onnx/jni/*' - 'build-android*.sh' - release: - types: - - published - workflow_dispatch: inputs: release: @@ -112,7 +110,7 @@ jobs: # https://huggingface.co/docs/hub/spaces-github-actions - name: Publish to huggingface - if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && env.RELEASE == 'true' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') env: HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | @@ -136,7 +134,7 @@ jobs: git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-libs main - name: Release android libs - if: env.RELEASE == 'true' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') uses: svenstaro/upload-release-action@v2 with: file_glob: true diff --git a/.github/workflows/arm-linux-gnueabihf.yaml b/.github/workflows/arm-linux-gnueabihf.yaml index 0c1efde92..6d6834c25 100644 --- a/.github/workflows/arm-linux-gnueabihf.yaml +++ b/.github/workflows/arm-linux-gnueabihf.yaml @@ -11,6 +11,8 @@ on: - 'cmake/**' - 'sherpa-onnx/csrc/*' - 'toolchains/arm-linux-gnueabihf.toolchain.cmake' + tags: + - '*' pull_request: branches: - master @@ -20,20 +22,8 @@ on: - 'cmake/**' - 'sherpa-onnx/csrc/*' - 'toolchains/arm-linux-gnueabihf.toolchain.cmake' - release: - types: - - published workflow_dispatch: - inputs: - release: - description: "Whether to release" - type: 
boolean - -env: - RELEASE: - |- # Release if there is a release tag name or a release flag in workflow_dispatch - ${{ github.event.release.tag_name != '' || github.event.inputs.release == 'true' }} concurrency: group: arm-linux-gnueabihf-${{ github.ref }} @@ -131,7 +121,6 @@ jobs: export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH export QEMU_LD_PREFIX=$GITHUB_WORKSPACE/toolchain/arm-linux-gnueabihf/libc - ls -lh ./build-arm-linux-gnueabihf/bin qemu-arm ./build-arm-linux-gnueabihf/bin/sherpa-onnx --help @@ -156,7 +145,7 @@ jobs: path: sherpa-onnx-*linux-arm-gnueabihf.tar.bz2 - name: Release pre-compiled binaries and libs for arm linux gnueabihf - if: env.RELEASE == 'true' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') uses: svenstaro/upload-release-action@v2 with: file_glob: true diff --git a/.github/workflows/build-xcframework.yaml b/.github/workflows/build-xcframework.yaml index 576b2d083..8c2ffaa09 100644 --- a/.github/workflows/build-xcframework.yaml +++ b/.github/workflows/build-xcframework.yaml @@ -14,15 +14,6 @@ on: - published workflow_dispatch: - inputs: - release: - description: "Whether to release" - type: boolean - -env: - RELEASE: - |- # Release if there is a release tag name or a release flag in workflow_dispatch - ${{ github.event.release.tag_name != '' || github.event.inputs.release == 'true' }} concurrency: group: build-xcframework-${{ github.ref }} @@ -76,7 +67,7 @@ jobs: # https://huggingface.co/docs/hub/spaces-github-actions - name: Publish to huggingface - if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') env: HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | @@ -100,7 +91,7 @@ jobs: git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-libs 
main - name: Release xcframework - if: env.RELEASE == 'true' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') uses: svenstaro/upload-release-action@v2 with: file_glob: true diff --git a/.github/workflows/dot-net.yaml b/.github/workflows/dot-net.yaml index bcff2f312..7c33c9090 100644 --- a/.github/workflows/dot-net.yaml +++ b/.github/workflows/dot-net.yaml @@ -6,20 +6,8 @@ on: - dot-net tags: - '*' - release: - types: - - published workflow_dispatch: - inputs: - release: - description: "Whether to release" - type: boolean - -env: - RELEASE: - |- # Release if there is a release tag name or a release flag in workflow_dispatch - ${{ github.event.release.tag_name != '' || github.event.inputs.release == 'true' }} concurrency: group: dot-net-${{ github.ref }} @@ -138,7 +126,7 @@ jobs: path: scripts/dotnet/packages/*.nupkg - name: publish .Net packages to nuget.org - if: github.repository == 'csukuangfj/sherpa-onnx' || github.repository == 'k2-fsa/sherpa-onnx' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') shell: bash env: API_KEY: ${{ secrets.NUGET_API_KEY }} @@ -148,7 +136,7 @@ jobs: dotnet nuget push ./org.k2fsa.sherpa.onnx.*.nupkg --skip-duplicate --api-key $API_KEY --source https://api.nuget.org/v3/index.json - name: Release nuget packages - if: env.RELEASE == 'true' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') uses: svenstaro/upload-release-action@v2 with: file_glob: true diff --git a/.github/workflows/jni.yaml b/.github/workflows/jni.yaml index 1dbccb4f0..853d2dcb9 100644 --- a/.github/workflows/jni.yaml +++ b/.github/workflows/jni.yaml @@ -22,6 +22,8 @@ on: - 'sherpa-onnx/csrc/*' - 'sherpa-onnx/jni/*' + workflow_dispatch: + concurrency: group: 
jni-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/linux-gpu.yaml b/.github/workflows/linux-gpu.yaml index 25350a319..de4e3e5af 100644 --- a/.github/workflows/linux-gpu.yaml +++ b/.github/workflows/linux-gpu.yaml @@ -30,20 +30,8 @@ on: - 'cmake/**' - 'sherpa-onnx/csrc/*' - 'sherpa-onnx/c-api/*' - release: - types: - - published workflow_dispatch: - inputs: - release: - description: "Whether to release" - type: boolean - -env: - RELEASE: - |- # Release if there is a release tag name or a release flag in workflow_dispatch - ${{ github.event.release.tag_name != '' || github.event.inputs.release == 'true' }} concurrency: group: linux-gpu-${{ github.ref }} @@ -136,7 +124,7 @@ jobs: .github/scripts/test-online-transducer.sh - name: Copy files - if: env.RELEASE == 'true' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') shell: bash run: | SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) @@ -153,7 +141,7 @@ jobs: tar cjvf ${dst}.tar.bz2 $dst - name: Release pre-compiled binaries and libs for linux x64 - if: env.RELEASE == 'true' && matrix.build_type == 'Release' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') uses: svenstaro/upload-release-action@v2 with: file_glob: true diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index 2c026fc0e..bbb2c775b 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -30,20 +30,8 @@ on: - 'cmake/**' - 'sherpa-onnx/csrc/*' - 'sherpa-onnx/c-api/*' - release: - types: - - published workflow_dispatch: - inputs: - release: - description: "Whether to release" - type: boolean - -env: - RELEASE: - |- # Release if there is a release tag name or a release flag in workflow_dispatch - ${{ github.event.release.tag_name 
!= '' || github.event.inputs.release == 'true' }} concurrency: group: linux-${{ github.ref }} @@ -135,7 +123,7 @@ jobs: .github/scripts/test-online-transducer.sh - name: Copy files - if: env.RELEASE == 'true' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') shell: bash run: | SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) @@ -152,7 +140,7 @@ jobs: tar cjvf ${dst}.tar.bz2 $dst - name: Release pre-compiled binaries and libs for linux x64 - if: env.RELEASE == 'true' && matrix.build_type == 'Release' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') uses: svenstaro/upload-release-action@v2 with: file_glob: true diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index f103061f7..5cc6e83e2 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -4,6 +4,8 @@ on: push: branches: - master + tags: + - '*' paths: - '.github/workflows/macos.yaml' - '.github/scripts/test-online-transducer.sh' @@ -25,20 +27,8 @@ on: - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' - release: - types: - - published workflow_dispatch: - inputs: - release: - description: "Whether to release" - type: boolean - -env: - RELEASE: - |- # Release if there is a release tag name or a release flag in workflow_dispatch - ${{ github.event.release.tag_name != '' || github.event.inputs.release == 'true' }} concurrency: group: macos-${{ github.ref }} @@ -133,7 +123,7 @@ jobs: .github/scripts/test-online-transducer.sh - name: Copy files - if: env.RELEASE == 'true' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') shell: bash run: | SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | 
cut -d " " -f 2 | cut -d '"' -f 2) @@ -151,7 +141,7 @@ jobs: tar cjvf ${dst}.tar.bz2 $dst - name: Release pre-compiled binaries and libs for macOS - if: env.RELEASE == 'true' && matrix.build_type == 'Release' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') uses: svenstaro/upload-release-action@v2 with: file_glob: true diff --git a/.github/workflows/mfc.yaml b/.github/workflows/mfc.yaml index cdf5f9c15..7d0061c62 100644 --- a/.github/workflows/mfc.yaml +++ b/.github/workflows/mfc.yaml @@ -21,20 +21,8 @@ on: - 'cmake/**' - 'mfc-examples/**' - 'sherpa-onnx/csrc/*' - release: - types: - - published workflow_dispatch: - inputs: - release: - description: "Whether to release" - type: boolean - -env: - RELEASE: - |- # Release if there is a release tag name or a release flag in workflow_dispatch - ${{ github.event.release.tag_name != '' || github.event.inputs.release == 'true' }} concurrency: group: mfc-${{ github.ref }} @@ -114,7 +102,7 @@ jobs: path: ./mfc-examples/${{ matrix.arch }}/Release/NonStreamingSpeechRecognition.exe - name: Release pre-compiled binaries and libs for Windows ${{ matrix.arch }} - if: env.RELEASE == 'true' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') uses: svenstaro/upload-release-action@v2 with: file_glob: true @@ -122,7 +110,7 @@ jobs: file: ./mfc-examples/${{ matrix.arch }}/Release/sherpa-onnx-streaming-*.exe - name: Release pre-compiled binaries and libs for Windows ${{ matrix.arch }} - if: env.RELEASE == 'true' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') uses: svenstaro/upload-release-action@v2 with: file_glob: true diff --git a/.github/workflows/windows-x64-cuda.yaml b/.github/workflows/windows-x64-cuda.yaml index 
17e53d8b1..afbb501c4 100644 --- a/.github/workflows/windows-x64-cuda.yaml +++ b/.github/workflows/windows-x64-cuda.yaml @@ -27,20 +27,8 @@ on: - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' - release: - types: - - published workflow_dispatch: - inputs: - release: - description: "Whether to release" - type: boolean - -env: - RELEASE: - |- # Release if there is a release tag name or a release flag in workflow_dispatch - ${{ github.event.release.tag_name != '' || github.event.inputs.release == 'true' }} concurrency: group: windows-x64-cuda-${{ github.ref }} @@ -125,7 +113,7 @@ jobs: .github/scripts/test-online-transducer.sh - name: Copy files - if: env.RELEASE == 'true' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') shell: bash run: | SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) @@ -140,7 +128,7 @@ jobs: tar cjvf ${dst}.tar.bz2 $dst - name: Release pre-compiled binaries and libs for Windows x64 CUDA - if: env.RELEASE == 'true' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') uses: svenstaro/upload-release-action@v2 with: file_glob: true diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index c63dbae30..c491b37d6 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -27,20 +27,8 @@ on: - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' - release: - types: - - published workflow_dispatch: - inputs: - release: - description: "Whether to release" - type: boolean - -env: - RELEASE: - |- # Release if there is a release tag name or a release flag in workflow_dispatch - ${{ github.event.release.tag_name != '' || github.event.inputs.release == 'true' }} concurrency: group: windows-x64-${{ github.ref }} @@ -126,7 +114,7 @@ jobs: 
.github/scripts/test-online-transducer.sh - name: Copy files - if: env.RELEASE == 'true' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') shell: bash run: | SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) @@ -141,7 +129,7 @@ jobs: tar cjvf ${dst}.tar.bz2 $dst - name: Release pre-compiled binaries and libs for Windows x64 - if: env.RELEASE == 'true' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') uses: svenstaro/upload-release-action@v2 with: file_glob: true diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index b39a1ddcf..e74aa54b2 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -4,6 +4,8 @@ on: push: branches: - master + tags: + - '*' paths: - '.github/workflows/windows-x86.yaml' - '.github/scripts/test-online-transducer.sh' @@ -25,20 +27,8 @@ on: - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' - release: - types: - - published workflow_dispatch: - inputs: - release: - description: "Whether to release" - type: boolean - -env: - RELEASE: - |- # Release if there is a release tag name or a release flag in workflow_dispatch - ${{ github.event.release.tag_name != '' || github.event.inputs.release == 'true' }} concurrency: group: windows-x86-${{ github.ref }} @@ -124,7 +114,7 @@ jobs: .github/scripts/test-online-transducer.sh - name: Copy files - if: env.RELEASE == 'true' && matrix.vs-version == 'vs2015' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') shell: bash run: | SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) @@ -139,7 +129,7 @@ jobs: tar cjvf 
${dst}.tar.bz2 $dst - name: Release pre-compiled binaries and libs for Windows x86 - if: env.RELEASE == 'true' && matrix.vs-version == 'vs2015' + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') uses: svenstaro/upload-release-action@v2 with: file_glob: true diff --git a/CMakeLists.txt b/CMakeLists.txt index 70046c0d2..bf877a53b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.7.11") +set(SHERPA_ONNX_VERSION "1.7.12") # Disable warning about # diff --git a/scripts/go/sherpa_onnx.go b/scripts/go/sherpa_onnx.go index f4fe6998d..96e0db7e8 100644 --- a/scripts/go/sherpa_onnx.go +++ b/scripts/go/sherpa_onnx.go @@ -407,7 +407,7 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer { c.model_config.whisper.decoder = C.CString(config.ModelConfig.Whisper.Decoder) defer C.free(unsafe.Pointer(c.model_config.whisper.decoder)) - c.model_config.tdnn.decoder = C.CString(config.ModelConfig.Tdnn.Model) + c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model) defer C.free(unsafe.Pointer(c.model_config.tdnn.model)) c.model_config.tokens = C.CString(config.ModelConfig.Tokens) From a12ebfab2294f4ede5cd4c85e3e63d994851caba Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 7 Sep 2023 15:12:29 +0800 Subject: [PATCH 4/9] treat unk as blank (#299) --- .../csrc/online-recognizer-transducer-impl.h | 21 +++++++++++++------ ...online-transducer-greedy-search-decoder.cc | 4 +++- .../online-transducer-greedy-search-decoder.h | 6 ++++-- ...transducer-modified-beam-search-decoder.cc | 4 +++- ...-transducer-modified-beam-search-decoder.h | 6 ++++-- 5 files changed, 29 insertions(+), 12 deletions(-) diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 
27f58687b..e08993dc1 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -57,6 +57,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { model_(OnlineTransducerModel::Create(config.model_config)), sym_(config.model_config.tokens), endpoint_(config_.endpoint_config) { + if (sym_.contains("")) { + unk_id_ = sym_[""]; + } + if (config.decoding_method == "modified_beam_search") { if (!config_.lm_config.model.empty()) { lm_ = OnlineLM::Create(config.lm_config); @@ -64,10 +68,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { decoder_ = std::make_unique( model_.get(), lm_.get(), config_.max_active_paths, - config_.lm_config.scale); + config_.lm_config.scale, unk_id_); } else if (config.decoding_method == "greedy_search") { - decoder_ = - std::make_unique(model_.get()); + decoder_ = std::make_unique( + model_.get(), unk_id_); } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config.decoding_method.c_str()); @@ -82,13 +86,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { model_(OnlineTransducerModel::Create(mgr, config.model_config)), sym_(mgr, config.model_config.tokens), endpoint_(config_.endpoint_config) { + if (sym_.contains("")) { + unk_id_ = sym_[""]; + } + if (config.decoding_method == "modified_beam_search") { decoder_ = std::make_unique( model_.get(), lm_.get(), config_.max_active_paths, - config_.lm_config.scale); + config_.lm_config.scale, unk_id_); } else if (config.decoding_method == "greedy_search") { - decoder_ = - std::make_unique(model_.get()); + decoder_ = std::make_unique( + model_.get(), unk_id_); } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config.decoding_method.c_str()); @@ -268,6 +276,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { std::unique_ptr decoder_; SymbolTable sym_; Endpoint endpoint_; + int32_t unk_id_ = -1; }; } // namespace sherpa_onnx diff 
--git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index 965285ce7..e90426bdc 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -108,7 +108,9 @@ void OnlineTransducerGreedySearchDecoder::Decode( static_cast(p_logit), std::max_element(static_cast(p_logit), static_cast(p_logit) + vocab_size))); - if (y != 0) { + // blank id is hardcoded to 0 + // also, it treats unk as blank + if (y != 0 && y != unk_id_) { emitted = true; r.tokens.push_back(y); r.timestamps.push_back(t + r.frame_offset); diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h index f7fa7ddf4..363cefedd 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h @@ -14,8 +14,9 @@ namespace sherpa_onnx { class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { public: - explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model) - : model_(model) {} + OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model, + int32_t unk_id) + : model_(model), unk_id_(unk_id) {} OnlineTransducerDecoderResult GetEmptyResult() const override; @@ -26,6 +27,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { private: OnlineTransducerModel *model_; // Not owned + int32_t unk_id_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index fef673472..a98f19dad 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -155,7 +155,9 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( float context_score = 0; 
auto context_state = new_hyp.context_state; - if (new_token != 0) { + // blank is hardcoded to 0 + // also, it treats unk as blank + if (new_token != 0 && new_token != unk_id_) { new_hyp.ys.push_back(new_token); new_hyp.timestamps.push_back(t + frame_offset); new_hyp.num_trailing_blanks = 0; diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h index d05c5167b..bc0cfb559 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h @@ -21,11 +21,12 @@ class OnlineTransducerModifiedBeamSearchDecoder OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, OnlineLM *lm, int32_t max_active_paths, - float lm_scale) + float lm_scale, int32_t unk_id) : model_(model), lm_(lm), max_active_paths_(max_active_paths), - lm_scale_(lm_scale) {} + lm_scale_(lm_scale), + unk_id_(unk_id) {} OnlineTransducerDecoderResult GetEmptyResult() const override; @@ -45,6 +46,7 @@ class OnlineTransducerModifiedBeamSearchDecoder int32_t max_active_paths_; float lm_scale_; // used only when lm_ is not nullptr + int32_t unk_id_; }; } // namespace sherpa_onnx From 86b18184c99e559d5f65876dc20c7d764393c5d4 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 7 Sep 2023 15:27:41 +0800 Subject: [PATCH 5/9] Fix Go examples (#300) --- .../non-streaming-decode-files/.gitignore | 2 ++ .../non-streaming-decode-files/go.mod | 2 +- .../non-streaming-decode-files/go.sum | 16 ++++++++-------- .../go.mod | 2 +- .../go.sum | 16 ++++++++-------- go-api-examples/streaming-decode-files/go.mod | 2 +- go-api-examples/streaming-decode-files/go.sum | 16 ++++++++-------- 7 files changed, 29 insertions(+), 27 deletions(-) diff --git a/go-api-examples/non-streaming-decode-files/.gitignore b/go-api-examples/non-streaming-decode-files/.gitignore index 0ea122c78..56666688f 100644 --- 
a/go-api-examples/non-streaming-decode-files/.gitignore +++ b/go-api-examples/non-streaming-decode-files/.gitignore @@ -1,2 +1,4 @@ non-streaming-decode-files sherpa-onnx-zipformer-en-2023-06-26 +sherpa-onnx-whisper-tiny.en +sherpa-onnx-tdnn-yesno diff --git a/go-api-examples/non-streaming-decode-files/go.mod b/go-api-examples/non-streaming-decode-files/go.mod index 516e08b10..4f81374ff 100644 --- a/go-api-examples/non-streaming-decode-files/go.mod +++ b/go-api-examples/non-streaming-decode-files/go.mod @@ -3,7 +3,7 @@ module non-streaming-decode-files go 1.12 require ( - github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 + github.com/k2-fsa/sherpa-onnx-go v1.7.12-alpha github.com/spf13/pflag v1.0.5 github.com/youpy/go-wav v0.3.2 ) diff --git a/go-api-examples/non-streaming-decode-files/go.sum b/go-api-examples/non-streaming-decode-files/go.sum index a1565ae52..46db02fa2 100644 --- a/go-api-examples/non-streaming-decode-files/go.sum +++ b/go-api-examples/non-streaming-decode-files/go.sum @@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ= -github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c= -github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM= -github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo= -github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk= -github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s= -github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 
h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4= -github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E= +github.com/k2-fsa/sherpa-onnx-go v1.7.12-alpha h1:pm9VCFe51c59LilgDmGwKGfGB/TalLJX26LSvjrELTk= +github.com/k2-fsa/sherpa-onnx-go v1.7.12-alpha/go.mod h1:JLAytuKK2r1sPf8BcyaUTFfvmGGTLpbfG9g9x/Rq7GA= +github.com/k2-fsa/sherpa-onnx-go-linux v1.7.12 h1:9g6Af3kBtcbDrTH7EqlWB9cSvBsc/xY00r7MeA/qVzo= +github.com/k2-fsa/sherpa-onnx-go-linux v1.7.12/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo= +github.com/k2-fsa/sherpa-onnx-go-macos v1.7.12-alpha h1:G8B6PaPHTFlbe6YtUFc7/H4rJfzmOJRvEzPJMj4h/w8= +github.com/k2-fsa/sherpa-onnx-go-macos v1.7.12-alpha/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s= +github.com/k2-fsa/sherpa-onnx-go-windows v1.7.12 h1:WudeR8tlCsS5uj0d99jJ+jaKjvyND+aCuajFDE9qEY4= +github.com/k2-fsa/sherpa-onnx-go-windows v1.7.12/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/go-api-examples/real-time-speech-recognition-from-microphone/go.mod b/go-api-examples/real-time-speech-recognition-from-microphone/go.mod index b10d2e390..13d8f41db 100644 --- a/go-api-examples/real-time-speech-recognition-from-microphone/go.mod +++ b/go-api-examples/real-time-speech-recognition-from-microphone/go.mod @@ -4,6 +4,6 @@ go 1.12 require ( github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 - github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 + github.com/k2-fsa/sherpa-onnx-go v1.7.12-alpha github.com/spf13/pflag v1.0.5 ) diff --git a/go-api-examples/real-time-speech-recognition-from-microphone/go.sum b/go-api-examples/real-time-speech-recognition-from-microphone/go.sum index a7332280b..ac7a23bb4 100644 --- 
a/go-api-examples/real-time-speech-recognition-from-microphone/go.sum +++ b/go-api-examples/real-time-speech-recognition-from-microphone/go.sum @@ -1,12 +1,12 @@ github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc= github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es= -github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ= -github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c= -github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM= -github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo= -github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk= -github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s= -github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4= -github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E= +github.com/k2-fsa/sherpa-onnx-go v1.7.12-alpha h1:pm9VCFe51c59LilgDmGwKGfGB/TalLJX26LSvjrELTk= +github.com/k2-fsa/sherpa-onnx-go v1.7.12-alpha/go.mod h1:JLAytuKK2r1sPf8BcyaUTFfvmGGTLpbfG9g9x/Rq7GA= +github.com/k2-fsa/sherpa-onnx-go-linux v1.7.12 h1:9g6Af3kBtcbDrTH7EqlWB9cSvBsc/xY00r7MeA/qVzo= +github.com/k2-fsa/sherpa-onnx-go-linux v1.7.12/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo= +github.com/k2-fsa/sherpa-onnx-go-macos v1.7.12-alpha h1:G8B6PaPHTFlbe6YtUFc7/H4rJfzmOJRvEzPJMj4h/w8= +github.com/k2-fsa/sherpa-onnx-go-macos v1.7.12-alpha/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s= +github.com/k2-fsa/sherpa-onnx-go-windows v1.7.12 h1:WudeR8tlCsS5uj0d99jJ+jaKjvyND+aCuajFDE9qEY4= +github.com/k2-fsa/sherpa-onnx-go-windows v1.7.12/go.mod 
h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= diff --git a/go-api-examples/streaming-decode-files/go.mod b/go-api-examples/streaming-decode-files/go.mod index 278520e95..d8e4837b3 100644 --- a/go-api-examples/streaming-decode-files/go.mod +++ b/go-api-examples/streaming-decode-files/go.mod @@ -3,7 +3,7 @@ module streaming-decode-files go 1.12 require ( - github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 + github.com/k2-fsa/sherpa-onnx-go v1.7.12-alpha github.com/spf13/pflag v1.0.5 github.com/youpy/go-wav v0.3.2 ) diff --git a/go-api-examples/streaming-decode-files/go.sum b/go-api-examples/streaming-decode-files/go.sum index a1565ae52..46db02fa2 100644 --- a/go-api-examples/streaming-decode-files/go.sum +++ b/go-api-examples/streaming-decode-files/go.sum @@ -2,14 +2,14 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1 h1:Em5/MJcZUkzqJuZZgTHcZhruQ828qsEyH46wHSHQLjQ= -github.com/k2-fsa/sherpa-onnx-go v1.7.6-alpha.1/go.mod h1:A8I7HnuFkTM5i3qK+mWfPTmoNAD+RYcR+PG/PO9Cf0c= -github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6 h1:gQV7yFVhssfg1ZaVHrlRl3xHJVJ+4O7rXgz15mLMynM= -github.com/k2-fsa/sherpa-onnx-go-linux v1.7.6/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo= -github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6 h1:vHKEL9PMeyShFsS3Dc1iohLk1zAOp02kKoWiGKtV/xk= -github.com/k2-fsa/sherpa-onnx-go-macos v1.7.6/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s= -github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6 h1:5pKmsXioj/eXfS6oE320PwR/aVtTcLWeRiqfrJHOIY4= 
-github.com/k2-fsa/sherpa-onnx-go-windows v1.7.6/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E= +github.com/k2-fsa/sherpa-onnx-go v1.7.12-alpha h1:pm9VCFe51c59LilgDmGwKGfGB/TalLJX26LSvjrELTk= +github.com/k2-fsa/sherpa-onnx-go v1.7.12-alpha/go.mod h1:JLAytuKK2r1sPf8BcyaUTFfvmGGTLpbfG9g9x/Rq7GA= +github.com/k2-fsa/sherpa-onnx-go-linux v1.7.12 h1:9g6Af3kBtcbDrTH7EqlWB9cSvBsc/xY00r7MeA/qVzo= +github.com/k2-fsa/sherpa-onnx-go-linux v1.7.12/go.mod h1:lHZRU/WtBUJetJVPyXHg092diEWYyIEoaob+LMJKWvo= +github.com/k2-fsa/sherpa-onnx-go-macos v1.7.12-alpha h1:G8B6PaPHTFlbe6YtUFc7/H4rJfzmOJRvEzPJMj4h/w8= +github.com/k2-fsa/sherpa-onnx-go-macos v1.7.12-alpha/go.mod h1:o1Cd6Zy+Tpq3bLAWqBoVcDenxi8HSaSubURtbtIqH2s= +github.com/k2-fsa/sherpa-onnx-go-windows v1.7.12 h1:WudeR8tlCsS5uj0d99jJ+jaKjvyND+aCuajFDE9qEY4= +github.com/k2-fsa/sherpa-onnx-go-windows v1.7.12/go.mod h1:R7JSrFkZGkfM/F/gVSR+yTJ+sPaHhJgdqsB5N7dTU6E= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= From 8982984ea28fc452d83419b4b9e89e003357a02b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 10 Sep 2023 17:56:13 +0800 Subject: [PATCH 6/9] add a two-pass python example (#303) --- ...pass-speech-recognition-from-microphone.py | 423 ++++++++++++++++++ 1 file changed, 423 insertions(+) create mode 100755 python-api-examples/two-pass-speech-recognition-from-microphone.py diff --git a/python-api-examples/two-pass-speech-recognition-from-microphone.py b/python-api-examples/two-pass-speech-recognition-from-microphone.py new file mode 100755 index 000000000..12a57ffa8 --- /dev/null +++ b/python-api-examples/two-pass-speech-recognition-from-microphone.py @@ -0,0 +1,423 @@ +#!/usr/bin/env python3 + +# Two-pass real-time speech recognition from a microphone with sherpa-onnx +# Python API. 
+# +# The first pass uses a streaming model, which has two purposes: +# +# (1) Display a temporary result to users +# +# (2) Endpointing +# +# The second pass uses a non-streaming model. It has a higher recognition +# accuracy than the first pass model and its result is used as the final result. +# +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +# to download pre-trained models + +""" +Usage examples: + +(1) Chinese: Streaming zipformer (1st pass) + Non-streaming paraformer (2nd pass) + +python3 ./python-api-examples/two-pass-speech-recognition-from-microphone.py \ + --first-encoder ./sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/encoder-epoch-99-avg-1.onnx \ + --first-decoder ./sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/decoder-epoch-99-avg-1.onnx \ + --first-joiner ./sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/joiner-epoch-99-avg-1.onnx \ + --first-tokens ./sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/tokens.txt \ + \ + --second-paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \ + --second-tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt + +(2) English: Streaming zipformer (1st pass) + Non-streaming whisper (2nd pass) + +python3 ./python-api-examples/two-pass-speech-recognition-from-microphone.py \ + --first-encoder ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder-epoch-99-avg-1.onnx \ + --first-decoder ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/decoder-epoch-99-avg-1.onnx \ + --first-joiner ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/joiner-epoch-99-avg-1.onnx \ + --first-tokens ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/tokens.txt \ + \ + --second-whisper-encoder ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx \ + --second-whisper-decoder ./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx \ + --second-tokens ./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt +""" + +import argparse +import sys +from pathlib import Path +from 
typing import List + +import numpy as np + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + +import sherpa_onnx + + +def assert_file_exists(filename: str, message: str): + if not filename: + raise ValueError(f"Please specify {message}") + + if not Path(filename).is_file(): + raise ValueError(f"{message} {filename} does not exist") + + +def add_first_pass_streaming_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--first-tokens", + type=str, + required=True, + help="Path to tokens.txt for the first pass", + ) + + parser.add_argument( + "--first-encoder", + type=str, + required=True, + help="Path to the encoder model for the first pass", + ) + + parser.add_argument( + "--first-decoder", + type=str, + required=True, + help="Path to the decoder model for the first pass", + ) + + parser.add_argument( + "--first-joiner", + type=str, + help="Path to the joiner model for the first pass", + ) + + parser.add_argument( + "--first-decoding-method", + type=str, + default="greedy_search", + help="""Decoding method for the first pass. Valid values are + greedy_search and modified_beam_search""", + ) + + parser.add_argument( + "--first-max-active-paths", + type=int, + default=4, + help="""Used only when --first-decoding-method is modified_beam_search. + It specifies number of active paths to keep during decoding. 
+ """, + ) + + +def add_second_pass_transducer_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--second-encoder", + default="", + type=str, + help="Path to the transducer encoder model for the second pass", + ) + + parser.add_argument( + "--second-decoder", + default="", + type=str, + help="Path to the transducer decoder model for the second pass", + ) + + parser.add_argument( + "--second-joiner", + default="", + type=str, + help="Path to the transducer joiner model for the second pass", + ) + + +def add_second_pass_paraformer_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--second-paraformer", + default="", + type=str, + help="Path to the model.onnx for Paraformer for the second pass", + ) + + +def add_second_pass_nemo_ctc_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--second-nemo-ctc", + default="", + type=str, + help="Path to the model.onnx for NeMo CTC for the second pass", + ) + + +def add_second_pass_whisper_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--second-whisper-encoder", + default="", + type=str, + help="Path to whisper encoder model for the second pass", + ) + + parser.add_argument( + "--second-whisper-decoder", + default="", + type=str, + help="Path to whisper decoder model for the second pass", + ) + + parser.add_argument( + "--second-whisper-language", + default="", + type=str, + help="""It specifies the spoken language in the input audio file. + Example values: en, fr, de, zh, jp. + Available languages for multilingual models can be found at + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + If not specified, we infer the language from the input audio file. + """, + ) + + parser.add_argument( + "--second-whisper-task", + default="transcribe", + choices=["transcribe", "translate"], + type=str, + help="""For multilingual models, if you specify translate, the output + will be in English. 
+ """, + ) + + +def add_second_pass_non_streaming_model_args(parser: argparse.ArgumentParser): + add_second_pass_transducer_model_args(parser) + add_second_pass_nemo_ctc_model_args(parser) + add_second_pass_paraformer_model_args(parser) + add_second_pass_whisper_model_args(parser) + + parser.add_argument( + "--second-tokens", + type=str, + help="Path to tokens.txt for the second pass", + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + add_first_pass_streaming_model_args(parser) + add_second_pass_non_streaming_model_args(parser) + + return parser.parse_args() + + +def check_first_pass_args(args): + assert_file_exists(args.first_tokens, "--first-tokens") + assert_file_exists(args.first_encoder, "--first-encoder") + assert_file_exists(args.first_decoder, "--first-decoder") + assert_file_exists(args.first_joiner, "--first-joiner") + + +def check_second_pass_args(args): + assert_file_exists(args.second_tokens, "--second-tokens") + + if args.second_encoder: + assert_file_exists(args.second_encoder, "--second-encoder") + assert_file_exists(args.second_decoder, "--second-decoder") + assert_file_exists(args.second_joiner, "--second-joiner") + elif args.second_paraformer: + assert_file_exists(args.second_paraformer, "--second-paraformer") + elif args.second_nemo_ctc: + assert_file_exists(args.second_nemo_ctc, "--second-nemo-ctc") + elif args.second_whisper_encoder: + assert_file_exists(args.second_whisper_encoder, "--second-whisper-encoder") + assert_file_exists(args.second_whisper_decoder, "--second-whisper-decoder") + else: + raise ValueError("Please specify the model for the second pass") + + +def create_first_pass_recognizer(args): + # Please replace the model files if needed. + # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html + # for download links. 
+ recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( + tokens=args.first_tokens, + encoder=args.first_encoder, + decoder=args.first_decoder, + joiner=args.first_joiner, + num_threads=1, + sample_rate=16000, + feature_dim=80, + decoding_method=args.first_decoding_method, + max_active_paths=args.first_max_active_paths, + provider=args.provider, + enable_endpoint_detection=True, + rule1_min_trailing_silence=2.4, + rule2_min_trailing_silence=1.2, + rule3_min_utterance_length=20, + ) + return recognizer + + +def create_second_pass_recognizer(args) -> sherpa_onnx.OfflineRecognizer: + if args.second_encoder: + recognizer = sherpa_onnx.OfflineRecognizer.from_transducer( + encoder=args.second_encoder, + decoder=args.second_decoder, + joiner=args.second_joiner, + tokens=args.second_tokens, + sample_rate=16000, + feature_dim=80, + decoding_method="greedy_search", + max_active_paths=4, + ) + elif args.second_paraformer: + recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( + paraformer=args.second_paraformer, + tokens=args.second_tokens, + num_threads=1, + sample_rate=16000, + feature_dim=80, + decoding_method="greedy_search", + ) + elif args.second_nemo_ctc: + recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc( + model=args.second_nemo_ctc, + tokens=args.second_tokens, + num_threads=1, + sample_rate=16000, + feature_dim=80, + decoding_method="greedy_search", + ) + elif args.second_whisper_encoder: + recognizer = sherpa_onnx.OfflineRecognizer.from_whisper( + encoder=args.second_whisper_encoder, + decoder=args.second_whisper_decoder, + tokens=args.second_tokens, + num_threads=1, + decoding_method="greedy_search", + language=args.second_whisper_language, + task=args.second_whisper_task, + ) + else: + raise ValueError("Please specify at least one model for the second pass") + + return recognizer + + +def run_second_pass( + recognizer: sherpa_onnx.OfflineRecognizer, + sample_buffers: List[np.ndarray], + sample_rate: int, +): + stream = 
recognizer.create_stream() + samples = np.concatenate(sample_buffers) + stream.accept_waveform(sample_rate, samples) + + recognizer.decode_stream(stream) + + return stream.result.text + + +def main(): + args = get_args() + check_first_pass_args(args) + check_second_pass_args(args) + + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + sys.exit(0) + + print(devices) + + # If you want to select a different input device, please use + # sd.default.device[0] = xxx + # where xxx is the device number + + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + print("Creating recognizers. Please wait...") + first_recognizer = create_first_pass_recognizer(args) + second_recognizer = create_second_pass_recognizer(args) + + print("Started! Please speak") + + sample_rate = 16000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + stream = first_recognizer.create_stream() + + last_result = "" + segment_id = 0 + + sample_buffers = [] + with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: + while True: + samples, _ = s.read(samples_per_read) # a blocking read + samples = samples.reshape(-1) + stream.accept_waveform(sample_rate, samples) + + sample_buffers.append(samples) + + while first_recognizer.is_ready(stream): + first_recognizer.decode_stream(stream) + + is_endpoint = first_recognizer.is_endpoint(stream) + + result = first_recognizer.get_result(stream) + result = result.lower().strip() + + if last_result != result: + print( + "\r{}:{}".format(segment_id, " " * len(last_result)), + end="", + flush=True, + ) + last_result = result + print("\r{}:{}".format(segment_id, result), end="", flush=True) + + if is_endpoint: + if result: + result = run_second_pass( + recognizer=second_recognizer, + sample_buffers=sample_buffers, + sample_rate=sample_rate, + ) + result = result.lower().strip() + + sample_buffers = [] + print( + 
"\r{}:{}".format(segment_id, " " * len(last_result)), + end="", + flush=True, + ) + print("\r{}:{}".format(segment_id, result), flush=True) + segment_id += 1 + else: + sample_buffers = [] + + first_recognizer.reset(stream) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") From debab7c0915b335eabcb4ddd1f053109aa9b66be Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 12 Sep 2023 15:40:16 +0800 Subject: [PATCH 7/9] Add two-pass speech recognition Android/iOS demo (#304) --- .gitignore | 7 + .../com/k2fsa/sherpa/onnx/MainActivity.kt | 44 +- .../java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt | 4 +- android/SherpaOnnx2Pass/.gitignore | 15 + android/SherpaOnnx2Pass/.idea/.gitignore | 3 + android/SherpaOnnx2Pass/.idea/compiler.xml | 6 + android/SherpaOnnx2Pass/.idea/gradle.xml | 19 + android/SherpaOnnx2Pass/.idea/misc.xml | 10 + android/SherpaOnnx2Pass/.idea/vcs.xml | 6 + android/SherpaOnnx2Pass/app/.gitignore | 1 + android/SherpaOnnx2Pass/app/build.gradle | 44 ++ .../SherpaOnnx2Pass/app/proguard-rules.pro | 21 + .../sherpa/onnx/ExampleInstrumentedTest.kt | 24 ++ .../SherpaOnnx2Pass/app/src/main/.gitignore | 1 + .../app/src/main/AndroidManifest.xml | 32 ++ .../app/src/main/assets/.gitkeep | 0 .../com/k2fsa/sherpa/onnx/MainActivity.kt | 251 ++++++++++++ .../java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt | 375 +++++++++++++++++ .../java/com/k2fsa/sherpa/onnx/WaveReader.kt | 1 + .../app/src/main/jniLibs/.gitkeep | 0 .../app/src/main/jniLibs/arm64-v8a/.gitkeep | 0 .../app/src/main/jniLibs/armeabi-v7a/.gitkeep | 0 .../app/src/main/jniLibs/x86/.gitkeep | 0 .../app/src/main/jniLibs/x86_64/.gitkeep | 0 .../drawable-v24/ic_launcher_foreground.xml | 30 ++ .../res/drawable/ic_launcher_background.xml | 170 ++++++++ .../app/src/main/res/layout/activity_main.xml | 39 ++ .../res/mipmap-anydpi-v26/ic_launcher.xml | 5 + .../mipmap-anydpi-v26/ic_launcher_round.xml | 5 + .../src/main/res/mipmap-hdpi/ic_launcher.webp | Bin 0 -> 1404 
bytes .../res/mipmap-hdpi/ic_launcher_round.webp | Bin 0 -> 2898 bytes .../src/main/res/mipmap-mdpi/ic_launcher.webp | Bin 0 -> 982 bytes .../res/mipmap-mdpi/ic_launcher_round.webp | Bin 0 -> 1772 bytes .../main/res/mipmap-xhdpi/ic_launcher.webp | Bin 0 -> 1900 bytes .../res/mipmap-xhdpi/ic_launcher_round.webp | Bin 0 -> 3918 bytes .../main/res/mipmap-xxhdpi/ic_launcher.webp | Bin 0 -> 2884 bytes .../res/mipmap-xxhdpi/ic_launcher_round.webp | Bin 0 -> 5914 bytes .../main/res/mipmap-xxxhdpi/ic_launcher.webp | Bin 0 -> 3844 bytes .../res/mipmap-xxxhdpi/ic_launcher_round.webp | Bin 0 -> 7778 bytes .../app/src/main/res/values-night/themes.xml | 16 + .../app/src/main/res/values/colors.xml | 10 + .../app/src/main/res/values/strings.xml | 13 + .../app/src/main/res/values/themes.xml | 16 + .../app/src/main/res/xml/backup_rules.xml | 13 + .../main/res/xml/data_extraction_rules.xml | 19 + .../com/k2fsa/sherpa/onnx/ExampleUnitTest.kt | 17 + android/SherpaOnnx2Pass/build.gradle | 6 + android/SherpaOnnx2Pass/gradle.properties | 23 ++ .../gradle/wrapper/gradle-wrapper.properties | 6 + android/SherpaOnnx2Pass/gradlew | 185 +++++++++ android/SherpaOnnx2Pass/gradlew.bat | 89 ++++ android/SherpaOnnx2Pass/settings.gradle | 16 + ios-swift/.gitignore | 91 +++++ ios-swiftui/.gitignore | 91 +++++ .../SherpaOnnx2Pass.xcodeproj/project.pbxproj | 380 ++++++++++++++++++ .../contents.xcworkspacedata | 7 + .../xcshareddata/IDEWorkspaceChecks.plist | 8 + .../AccentColor.colorset/Contents.json | 11 + .../AppIcon.appiconset/Contents.json | 14 + .../AppIcon.appiconset/k2-1024x1024.png | Bin 0 -> 421090 bytes .../Assets.xcassets/Contents.json | 6 + .../SherpaOnnx2Pass/ContentView.swift | 46 +++ .../SherpaOnnx2Pass/Extension.swift | 20 + .../SherpaOnnx2Pass/Model.swift | 134 ++++++ .../Preview Assets.xcassets/Contents.json | 6 + .../SherpaOnnx2Pass/SherpaOnnx2PassApp.swift | 17 + .../SherpaOnnx2Pass/SherpaOnnxViewModel.swift | 252 ++++++++++++ .../SherpaOnnx2Pass/k2-1024x1024.png | 1 + 
sherpa-onnx/c-api/c-api.cc | 2 + sherpa-onnx/csrc/offline-ctc-model.cc | 38 ++ sherpa-onnx/csrc/offline-ctc-model.h | 11 + sherpa-onnx/csrc/offline-lm.cc | 7 + sherpa-onnx/csrc/offline-lm.h | 10 + .../csrc/offline-nemo-enc-dec-ctc-model.cc | 26 +- .../csrc/offline-nemo-enc-dec-ctc-model.h | 11 + sherpa-onnx/csrc/offline-paraformer-model.cc | 26 +- sherpa-onnx/csrc/offline-paraformer-model.h | 10 + .../csrc/offline-recognizer-ctc-impl.h | 23 +- sherpa-onnx/csrc/offline-recognizer-impl.cc | 117 ++++++ sherpa-onnx/csrc/offline-recognizer-impl.h | 10 + .../csrc/offline-recognizer-paraformer-impl.h | 27 ++ .../csrc/offline-recognizer-transducer-impl.h | 31 ++ .../csrc/offline-recognizer-whisper-impl.h | 25 +- sherpa-onnx/csrc/offline-recognizer.cc | 6 + sherpa-onnx/csrc/offline-recognizer.h | 9 + sherpa-onnx/csrc/offline-rnn-lm.cc | 27 +- sherpa-onnx/csrc/offline-rnn-lm.h | 9 + sherpa-onnx/csrc/offline-tdnn-ctc-model.cc | 26 +- sherpa-onnx/csrc/offline-tdnn-ctc-model.h | 10 + sherpa-onnx/csrc/offline-transducer-model.cc | 29 ++ sherpa-onnx/csrc/offline-transducer-model.h | 10 + sherpa-onnx/csrc/offline-whisper-model.cc | 24 ++ sherpa-onnx/csrc/offline-whisper-model.h | 10 + sherpa-onnx/jni/jni.cc | 207 +++++++++- swift-api-examples/.gitignore | 2 + swift-api-examples/SherpaOnnx.swift | 165 +++++++- .../decode-file-non-streaming.swift | 65 +++ 97 files changed, 3544 insertions(+), 55 deletions(-) create mode 100644 android/SherpaOnnx2Pass/.gitignore create mode 100644 android/SherpaOnnx2Pass/.idea/.gitignore create mode 100644 android/SherpaOnnx2Pass/.idea/compiler.xml create mode 100644 android/SherpaOnnx2Pass/.idea/gradle.xml create mode 100644 android/SherpaOnnx2Pass/.idea/misc.xml create mode 100644 android/SherpaOnnx2Pass/.idea/vcs.xml create mode 100644 android/SherpaOnnx2Pass/app/.gitignore create mode 100644 android/SherpaOnnx2Pass/app/build.gradle create mode 100644 android/SherpaOnnx2Pass/app/proguard-rules.pro create mode 100644 
android/SherpaOnnx2Pass/app/src/androidTest/java/com/k2fsa/sherpa/onnx/ExampleInstrumentedTest.kt create mode 100644 android/SherpaOnnx2Pass/app/src/main/.gitignore create mode 100644 android/SherpaOnnx2Pass/app/src/main/AndroidManifest.xml create mode 100644 android/SherpaOnnx2Pass/app/src/main/assets/.gitkeep create mode 100644 android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt create mode 100644 android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt create mode 120000 android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt create mode 100644 android/SherpaOnnx2Pass/app/src/main/jniLibs/.gitkeep create mode 100644 android/SherpaOnnx2Pass/app/src/main/jniLibs/arm64-v8a/.gitkeep create mode 100644 android/SherpaOnnx2Pass/app/src/main/jniLibs/armeabi-v7a/.gitkeep create mode 100644 android/SherpaOnnx2Pass/app/src/main/jniLibs/x86/.gitkeep create mode 100644 android/SherpaOnnx2Pass/app/src/main/jniLibs/x86_64/.gitkeep create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/drawable-v24/ic_launcher_foreground.xml create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/drawable/ic_launcher_background.xml create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/layout/activity_main.xml create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/mipmap-hdpi/ic_launcher.webp create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/mipmap-mdpi/ic_launcher.webp create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/mipmap-xhdpi/ic_launcher.webp create mode 100644 
android/SherpaOnnx2Pass/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/values-night/themes.xml create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/values/colors.xml create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/values/strings.xml create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/values/themes.xml create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/xml/backup_rules.xml create mode 100644 android/SherpaOnnx2Pass/app/src/main/res/xml/data_extraction_rules.xml create mode 100644 android/SherpaOnnx2Pass/app/src/test/java/com/k2fsa/sherpa/onnx/ExampleUnitTest.kt create mode 100644 android/SherpaOnnx2Pass/build.gradle create mode 100644 android/SherpaOnnx2Pass/gradle.properties create mode 100644 android/SherpaOnnx2Pass/gradle/wrapper/gradle-wrapper.properties create mode 100755 android/SherpaOnnx2Pass/gradlew create mode 100644 android/SherpaOnnx2Pass/gradlew.bat create mode 100644 android/SherpaOnnx2Pass/settings.gradle create mode 100644 ios-swift/.gitignore create mode 100644 ios-swiftui/.gitignore create mode 100644 ios-swiftui/SherpaOnnx2Pass/SherpaOnnx2Pass.xcodeproj/project.pbxproj create mode 100644 ios-swiftui/SherpaOnnx2Pass/SherpaOnnx2Pass.xcodeproj/project.xcworkspace/contents.xcworkspacedata create mode 100644 ios-swiftui/SherpaOnnx2Pass/SherpaOnnx2Pass.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist create mode 100644 ios-swiftui/SherpaOnnx2Pass/SherpaOnnx2Pass/Assets.xcassets/AccentColor.colorset/Contents.json create mode 100644 
ios-swiftui/SherpaOnnx2Pass/SherpaOnnx2Pass/Assets.xcassets/AppIcon.appiconset/Contents.json create mode 100644 ios-swiftui/SherpaOnnx2Pass/SherpaOnnx2Pass/Assets.xcassets/AppIcon.appiconset/k2-1024x1024.png create mode 100644 ios-swiftui/SherpaOnnx2Pass/SherpaOnnx2Pass/Assets.xcassets/Contents.json create mode 100644 ios-swiftui/SherpaOnnx2Pass/SherpaOnnx2Pass/ContentView.swift create mode 100644 ios-swiftui/SherpaOnnx2Pass/SherpaOnnx2Pass/Extension.swift create mode 100644 ios-swiftui/SherpaOnnx2Pass/SherpaOnnx2Pass/Model.swift create mode 100644 ios-swiftui/SherpaOnnx2Pass/SherpaOnnx2Pass/Preview Content/Preview Assets.xcassets/Contents.json create mode 100644 ios-swiftui/SherpaOnnx2Pass/SherpaOnnx2Pass/SherpaOnnx2PassApp.swift create mode 100644 ios-swiftui/SherpaOnnx2Pass/SherpaOnnx2Pass/SherpaOnnxViewModel.swift create mode 120000 ios-swiftui/SherpaOnnx2Pass/SherpaOnnx2Pass/k2-1024x1024.png create mode 100644 swift-api-examples/.gitignore create mode 100644 swift-api-examples/decode-file-non-streaming.swift diff --git a/.gitignore b/.gitignore index d6f49076a..d2d16a1b9 100644 --- a/.gitignore +++ b/.gitignore @@ -60,3 +60,10 @@ run-offline-decode-files-nemo-ctc.sh *.jar sherpa-onnx-nemo-ctc-* *.wav +sherpa-onnx-zipformer-* +sherpa-onnx-conformer-* +sherpa-onnx-whisper-* +swift-api-examples/k2fsa-* +run-*.sh +two-pass-*.sh +build-* diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt index 827ad6a1b..1619f3b27 100644 --- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt +++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt @@ -21,10 +21,6 @@ private const val REQUEST_RECORD_AUDIO_PERMISSION = 200 class MainActivity : AppCompatActivity() { private val permissions: Array = arrayOf(Manifest.permission.RECORD_AUDIO) - // If there is a GPU and useGPU is true, we will use GPU - // If there is no GPU 
and useGPU is true, we won't use GPU - private val useGPU: Boolean = true - private lateinit var model: SherpaOnnx private var audioRecord: AudioRecord? = null private lateinit var recordButton: Button @@ -91,7 +87,7 @@ class MainActivity : AppCompatActivity() { audioRecord!!.startRecording() recordButton.setText(R.string.stop) isRecording = true - model.reset() + model.reset(true) textView.text = "" lastText = "" idx = 0 @@ -125,26 +121,32 @@ class MainActivity : AppCompatActivity() { while (model.isReady()) { model.decode() } - runOnUiThread { - val isEndpoint = model.isEndpoint() - val text = model.text - - if(text.isNotBlank()) { - if (lastText.isBlank()) { - textView.text = "${idx}: ${text}" - } else { - textView.text = "${lastText}\n${idx}: ${text}" - } + + val isEndpoint = model.isEndpoint() + val text = model.text + + var textToDisplay = lastText; + + if(text.isNotBlank()) { + if (lastText.isBlank()) { + textToDisplay = "${idx}: ${text}" + } else { + textToDisplay = "${lastText}\n${idx}: ${text}" } + } - if (isEndpoint) { - model.reset() - if (text.isNotBlank()) { - lastText = "${lastText}\n${idx}: ${text}" - idx += 1 - } + if (isEndpoint) { + model.reset() + if (text.isNotBlank()) { + lastText = "${lastText}\n${idx}: ${text}" + textToDisplay = lastText; + idx += 1 } } + + runOnUiThread { + textView.text = textToDisplay + } } } } diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt index c68703b46..185765622 100644 --- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt +++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt @@ -77,7 +77,7 @@ class SherpaOnnx( acceptWaveform(ptr, samples, sampleRate) fun inputFinished() = inputFinished(ptr) - fun reset() = reset(ptr) + fun reset(recreate: Boolean = false) = reset(ptr, recreate = recreate) fun decode() = decode(ptr) fun isEndpoint(): Boolean = 
isEndpoint(ptr) fun isReady(): Boolean = isReady(ptr) @@ -99,7 +99,7 @@ class SherpaOnnx( private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) private external fun inputFinished(ptr: Long) private external fun getText(ptr: Long): String - private external fun reset(ptr: Long) + private external fun reset(ptr: Long, recreate: Boolean) private external fun decode(ptr: Long) private external fun isEndpoint(ptr: Long): Boolean private external fun isReady(ptr: Long): Boolean diff --git a/android/SherpaOnnx2Pass/.gitignore b/android/SherpaOnnx2Pass/.gitignore new file mode 100644 index 000000000..aa724b770 --- /dev/null +++ b/android/SherpaOnnx2Pass/.gitignore @@ -0,0 +1,15 @@ +*.iml +.gradle +/local.properties +/.idea/caches +/.idea/libraries +/.idea/modules.xml +/.idea/workspace.xml +/.idea/navEditor.xml +/.idea/assetWizardSettings.xml +.DS_Store +/build +/captures +.externalNativeBuild +.cxx +local.properties diff --git a/android/SherpaOnnx2Pass/.idea/.gitignore b/android/SherpaOnnx2Pass/.idea/.gitignore new file mode 100644 index 000000000..26d33521a --- /dev/null +++ b/android/SherpaOnnx2Pass/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/android/SherpaOnnx2Pass/.idea/compiler.xml b/android/SherpaOnnx2Pass/.idea/compiler.xml new file mode 100644 index 000000000..fb7f4a8a4 --- /dev/null +++ b/android/SherpaOnnx2Pass/.idea/compiler.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/android/SherpaOnnx2Pass/.idea/gradle.xml b/android/SherpaOnnx2Pass/.idea/gradle.xml new file mode 100644 index 000000000..a2d7c2133 --- /dev/null +++ b/android/SherpaOnnx2Pass/.idea/gradle.xml @@ -0,0 +1,19 @@ + + + + + + + \ No newline at end of file diff --git a/android/SherpaOnnx2Pass/.idea/misc.xml b/android/SherpaOnnx2Pass/.idea/misc.xml new file mode 100644 index 000000000..bdd92780c --- /dev/null +++ b/android/SherpaOnnx2Pass/.idea/misc.xml @@ -0,0 +1,10 @@ + + + + + + + + + \ No 
newline at end of file diff --git a/android/SherpaOnnx2Pass/.idea/vcs.xml b/android/SherpaOnnx2Pass/.idea/vcs.xml new file mode 100644 index 000000000..b2bdec2d7 --- /dev/null +++ b/android/SherpaOnnx2Pass/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/android/SherpaOnnx2Pass/app/.gitignore b/android/SherpaOnnx2Pass/app/.gitignore new file mode 100644 index 000000000..42afabfd2 --- /dev/null +++ b/android/SherpaOnnx2Pass/app/.gitignore @@ -0,0 +1 @@ +/build \ No newline at end of file diff --git a/android/SherpaOnnx2Pass/app/build.gradle b/android/SherpaOnnx2Pass/app/build.gradle new file mode 100644 index 000000000..d64be8079 --- /dev/null +++ b/android/SherpaOnnx2Pass/app/build.gradle @@ -0,0 +1,44 @@ +plugins { + id 'com.android.application' + id 'org.jetbrains.kotlin.android' +} + +android { + namespace 'com.k2fsa.sherpa.onnx' + compileSdk 32 + + defaultConfig { + applicationId "com.k2fsa.sherpa.onnx" + minSdk 21 + targetSdk 32 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } + kotlinOptions { + jvmTarget = '1.8' + } +} + +dependencies { + + implementation 'androidx.core:core-ktx:1.7.0' + implementation 'androidx.appcompat:appcompat:1.5.1' + implementation 'com.google.android.material:material:1.7.0' + implementation 'androidx.constraintlayout:constraintlayout:2.1.4' + testImplementation 'junit:junit:4.13.2' + androidTestImplementation 'androidx.test.ext:junit:1.1.4' + androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.0' +} \ No newline at end of file diff --git a/android/SherpaOnnx2Pass/app/proguard-rules.pro b/android/SherpaOnnx2Pass/app/proguard-rules.pro new file mode 100644 
index 000000000..481bb4348 --- /dev/null +++ b/android/SherpaOnnx2Pass/app/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/android/SherpaOnnx2Pass/app/src/androidTest/java/com/k2fsa/sherpa/onnx/ExampleInstrumentedTest.kt b/android/SherpaOnnx2Pass/app/src/androidTest/java/com/k2fsa/sherpa/onnx/ExampleInstrumentedTest.kt new file mode 100644 index 000000000..183383202 --- /dev/null +++ b/android/SherpaOnnx2Pass/app/src/androidTest/java/com/k2fsa/sherpa/onnx/ExampleInstrumentedTest.kt @@ -0,0 +1,24 @@ +package com.k2fsa.sherpa.onnx + +import androidx.test.platform.app.InstrumentationRegistry +import androidx.test.ext.junit.runners.AndroidJUnit4 + +import org.junit.Test +import org.junit.runner.RunWith + +import org.junit.Assert.* + +/** + * Instrumented test, which will execute on an Android device. + * + * See [testing documentation](http://d.android.com/tools/testing). + */ +@RunWith(AndroidJUnit4::class) +class ExampleInstrumentedTest { + @Test + fun useAppContext() { + // Context of the app under test. 
+ val appContext = InstrumentationRegistry.getInstrumentation().targetContext + assertEquals("com.k2fsa.sherpa.onnx", appContext.packageName) + } +} \ No newline at end of file diff --git a/android/SherpaOnnx2Pass/app/src/main/.gitignore b/android/SherpaOnnx2Pass/app/src/main/.gitignore new file mode 100644 index 000000000..140f8cf80 --- /dev/null +++ b/android/SherpaOnnx2Pass/app/src/main/.gitignore @@ -0,0 +1 @@ +*.so diff --git a/android/SherpaOnnx2Pass/app/src/main/AndroidManifest.xml b/android/SherpaOnnx2Pass/app/src/main/AndroidManifest.xml new file mode 100644 index 000000000..2a440df14 --- /dev/null +++ b/android/SherpaOnnx2Pass/app/src/main/AndroidManifest.xml @@ -0,0 +1,32 @@ + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/android/SherpaOnnx2Pass/app/src/main/assets/.gitkeep b/android/SherpaOnnx2Pass/app/src/main/assets/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt new file mode 100644 index 000000000..012c0db5e --- /dev/null +++ b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt @@ -0,0 +1,251 @@ +package com.k2fsa.sherpa.onnx + +import android.Manifest +import android.content.pm.PackageManager +import android.media.AudioFormat +import android.media.AudioRecord +import android.media.MediaRecorder +import android.os.Bundle +import android.text.method.ScrollingMovementMethod +import android.util.Log +import android.widget.Button +import android.widget.TextView +import androidx.appcompat.app.AppCompatActivity +import androidx.core.app.ActivityCompat +import kotlin.concurrent.thread + +private const val TAG = "sherpa-onnx" +private const val REQUEST_RECORD_AUDIO_PERMISSION = 200 + +class MainActivity : AppCompatActivity() { + private val permissions: Array = arrayOf(Manifest.permission.RECORD_AUDIO) + + private 
lateinit var onlineRecognizer: SherpaOnnx + private lateinit var offlineRecognizer: SherpaOnnxOffline + private var audioRecord: AudioRecord? = null + private lateinit var recordButton: Button + private lateinit var textView: TextView + private var recordingThread: Thread? = null + + private val audioSource = MediaRecorder.AudioSource.MIC + private val sampleRateInHz = 16000 + private val channelConfig = AudioFormat.CHANNEL_IN_MONO + + private var samplesBuffer = arrayListOf() + + // Note: We don't use AudioFormat.ENCODING_PCM_FLOAT + // since the AudioRecord.read(float[]) needs API level >= 23 + // but we are targeting API level >= 21 + private val audioFormat = AudioFormat.ENCODING_PCM_16BIT + private var idx: Int = 0 + private var lastText: String = "" + + @Volatile + private var isRecording: Boolean = false + + override fun onRequestPermissionsResult( + requestCode: Int, permissions: Array, grantResults: IntArray + ) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults) + val permissionToRecordAccepted = if (requestCode == REQUEST_RECORD_AUDIO_PERMISSION) { + grantResults[0] == PackageManager.PERMISSION_GRANTED + } else { + false + } + + if (!permissionToRecordAccepted) { + Log.e(TAG, "Audio record is disallowed") + finish() + } + + Log.i(TAG, "Audio record is permitted") + } + + override fun onCreate(savedInstanceState: Bundle?) 
{ + super.onCreate(savedInstanceState) + setContentView(R.layout.activity_main) + + ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION) + + Log.i(TAG, "Start to initialize first-pass recognizer") + initOnlineRecognizer() + Log.i(TAG, "Finished initializing first-pass recognizer") + + Log.i(TAG, "Start to initialize second-pass recognizer") + initOfflineRecognizer() + Log.i(TAG, "Finished initializing second-pass recognizer") + + recordButton = findViewById(R.id.record_button) + recordButton.setOnClickListener { onclick() } + + textView = findViewById(R.id.my_text) + textView.movementMethod = ScrollingMovementMethod() + } + + private fun onclick() { + if (!isRecording) { + val ret = initMicrophone() + if (!ret) { + Log.e(TAG, "Failed to initialize microphone") + return + } + Log.i(TAG, "state: ${audioRecord?.state}") + audioRecord!!.startRecording() + recordButton.setText(R.string.stop) + isRecording = true + onlineRecognizer.reset(true) + samplesBuffer.clear() + textView.text = "" + lastText = "" + idx = 0 + + recordingThread = thread(true) { + processSamples() + } + Log.i(TAG, "Started recording") + } else { + isRecording = false + audioRecord!!.stop() + audioRecord!!.release() + audioRecord = null + recordButton.setText(R.string.start) + Log.i(TAG, "Stopped recording") + } + } + + private fun processSamples() { + Log.i(TAG, "processing samples") + + val interval = 0.1 // i.e., 100 ms + val bufferSize = (interval * sampleRateInHz).toInt() // in samples + val buffer = ShortArray(bufferSize) + + while (isRecording) { + val ret = audioRecord?.read(buffer, 0, buffer.size) + if (ret != null && ret > 0) { + val samples = FloatArray(ret) { buffer[it] / 32768.0f } + samplesBuffer.add(samples) + + onlineRecognizer.acceptWaveform(samples, sampleRate = sampleRateInHz) + while (onlineRecognizer.isReady()) { + onlineRecognizer.decode() + } + val isEndpoint = onlineRecognizer.isEndpoint() + var textToDisplay = lastText + + var text = 
onlineRecognizer.text + if (text.isNotBlank()) { + if (lastText.isBlank()) { + // textView.text = "${idx}: ${text}" + textToDisplay = "${idx}: ${text}" + } else { + textToDisplay = "${lastText}\n${idx}: ${text}" + } + } + + if (isEndpoint) { + onlineRecognizer.reset() + + if (text.isNotBlank()) { + text = runSecondPass() + lastText = "${lastText}\n${idx}: ${text}" + idx += 1 + } else { + samplesBuffer.clear() + } + } + + runOnUiThread { + textView.text = textToDisplay.lowercase() + } + } + } + } + + private fun initMicrophone(): Boolean { + if (ActivityCompat.checkSelfPermission( + this, Manifest.permission.RECORD_AUDIO + ) != PackageManager.PERMISSION_GRANTED + ) { + ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION) + return false + } + + val numBytes = AudioRecord.getMinBufferSize(sampleRateInHz, channelConfig, audioFormat) + Log.i( + TAG, "buffer size in milliseconds: ${numBytes * 1000.0f / sampleRateInHz}" + ) + + audioRecord = AudioRecord( + audioSource, + sampleRateInHz, + channelConfig, + audioFormat, + numBytes * 2 // a sample has two bytes as we are using 16-bit PCM + ) + return true + } + + private fun initOnlineRecognizer() { + // Please change getModelConfig() to add new models + // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html + // for a list of available models + val firstType = 1 + println("Select model type ${firstType} for the first pass") + val config = OnlineRecognizerConfig( + featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80), + modelConfig = getModelConfig(type = firstType)!!, + endpointConfig = getEndpointConfig(), + enableEndpoint = true, + ) + + onlineRecognizer = SherpaOnnx( + assetManager = application.assets, + config = config, + ) + } + + private fun initOfflineRecognizer() { + // Please change getOfflineModelConfig() to add new models + // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html + // for a list of available models + val 
secondType = 1 + println("Select model type ${secondType} for the second pass") + + val config = OfflineRecognizerConfig( + featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80), + modelConfig = getOfflineModelConfig(type = secondType)!!, + ) + + offlineRecognizer = SherpaOnnxOffline( + assetManager = application.assets, + config = config, + ) + } + + private fun runSecondPass(): String { + var totalSamples = 0 + for (a in samplesBuffer) { + totalSamples += a.size + } + var i = 0 + + val samples = FloatArray(totalSamples) + + // todo(fangjun): Make it more efficient + for (a in samplesBuffer) { + for (s in a) { + samples[i] = s + i += 1 + } + } + + + val n = maxOf(0, samples.size - 8000) + + samplesBuffer.clear() + samplesBuffer.add(samples.sliceArray(n..samples.size-1)) + + return offlineRecognizer.decode(samples.sliceArray(0..n), sampleRateInHz) + } +} diff --git a/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt new file mode 100644 index 000000000..99ca65827 --- /dev/null +++ b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt @@ -0,0 +1,375 @@ +package com.k2fsa.sherpa.onnx + +import android.content.res.AssetManager + +data class EndpointRule( + var mustContainNonSilence: Boolean, + var minTrailingSilence: Float, + var minUtteranceLength: Float, +) + +data class EndpointConfig( + var rule1: EndpointRule = EndpointRule(false, 2.0f, 0.0f), + var rule2: EndpointRule = EndpointRule(true, 1.2f, 0.0f), + var rule3: EndpointRule = EndpointRule(false, 0.0f, 20.0f) +) + +data class OnlineTransducerModelConfig( + var encoder: String = "", + var decoder: String = "", + var joiner: String = "", +) + +data class OnlineParaformerModelConfig( + var encoder: String = "", + var decoder: String = "", +) + +data class OnlineModelConfig( + var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(), + var 
paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(), + var tokens: String, + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", + var modelType: String = "", +) + +data class OnlineLMConfig( + var model: String = "", + var scale: Float = 0.5f, +) + +data class FeatureConfig( + var sampleRate: Int = 16000, + var featureDim: Int = 80, +) + +data class OnlineRecognizerConfig( + var featConfig: FeatureConfig = FeatureConfig(), + var modelConfig: OnlineModelConfig, + var lmConfig: OnlineLMConfig = OnlineLMConfig(), + var endpointConfig: EndpointConfig = EndpointConfig(), + var enableEndpoint: Boolean = true, + var decodingMethod: String = "greedy_search", + var maxActivePaths: Int = 4, +) + +data class OfflineTransducerModelConfig( + var encoder: String = "", + var decoder: String = "", + var joiner: String = "", +) + +data class OfflineParaformerModelConfig( + var model: String = "", +) + +data class OfflineWhisperModelConfig( + var encoder: String = "", + var decoder: String = "", +) + +data class OfflineModelConfig( + var transducer: OfflineTransducerModelConfig = OfflineTransducerModelConfig(), + var paraformer: OfflineParaformerModelConfig = OfflineParaformerModelConfig(), + var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(), + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", + var modelType: String = "", + var tokens: String, +) + +data class OfflineRecognizerConfig( + var featConfig: FeatureConfig = FeatureConfig(), + var modelConfig: OfflineModelConfig, + // var lmConfig: OfflineLMConfig(), // TODO(fangjun): enable it + var decodingMethod: String = "greedy_search", + var maxActivePaths: Int = 4, +) + +class SherpaOnnx( + assetManager: AssetManager? 
= null, + var config: OnlineRecognizerConfig, +) { + private val ptr: Long + + init { + if (assetManager != null) { + ptr = new(assetManager, config) + } else { + ptr = newFromFile(config) + } + } + + protected fun finalize() { + delete(ptr) + } + + fun acceptWaveform(samples: FloatArray, sampleRate: Int) = + acceptWaveform(ptr, samples, sampleRate) + + fun inputFinished() = inputFinished(ptr) + fun reset(recreate: Boolean = false) = reset(ptr, recreate = recreate) + fun decode() = decode(ptr) + fun isEndpoint(): Boolean = isEndpoint(ptr) + fun isReady(): Boolean = isReady(ptr) + + val text: String + get() = getText(ptr) + + private external fun delete(ptr: Long) + + private external fun new( + assetManager: AssetManager, + config: OnlineRecognizerConfig, + ): Long + + private external fun newFromFile( + config: OnlineRecognizerConfig, + ): Long + + private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) + private external fun inputFinished(ptr: Long) + private external fun getText(ptr: Long): String + private external fun reset(ptr: Long, recreate: Boolean) + private external fun decode(ptr: Long) + private external fun isEndpoint(ptr: Long): Boolean + private external fun isReady(ptr: Long): Boolean + + companion object { + init { + System.loadLibrary("sherpa-onnx-jni") + } + } +} + +class SherpaOnnxOffline( + assetManager: AssetManager? 
= null, + var config: OfflineRecognizerConfig, +) { + private val ptr: Long + + init { + if (assetManager != null) { + ptr = new(assetManager, config) + } else { + ptr = newFromFile(config) + } + } + + protected fun finalize() { + delete(ptr) + } + + fun decode(samples: FloatArray, sampleRate: Int) = decode(ptr, samples, sampleRate) + + private external fun delete(ptr: Long) + + private external fun new( + assetManager: AssetManager, + config: OfflineRecognizerConfig, + ): Long + + private external fun newFromFile( + config: OfflineRecognizerConfig, + ): Long + + private external fun decode(ptr: Long, samples: FloatArray, sampleRate: Int): String + + companion object { + init { + System.loadLibrary("sherpa-onnx-jni") + } + } +} + +fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig { + return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim) +} + +/* +Please see +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models. + +We only add a few here. Please change the following code +to add your own. (It should be straightforward to add a new model +by following the code) + +@param type +0 - csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23 (Chinese) + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-zh-14m-2023-02-23 + encoder/joiner int8, decoder float32 + +1 - csukuangfj/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17 (English) + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-en-20m-2023-02-17-english + encoder/joiner int8, decoder fp32 + + */ +fun getModelConfig(type: Int): OnlineModelConfig? 
{ + when (type) { + 0 -> { + val modelDir = "sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx", + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer", + ) + } + + 1 -> { + val modelDir = "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx", + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer", + ) + } + } + return null +} + +/* +Please see +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models. + +We only add a few here. Please change the following code +to add your own LM model. 
(It should be straightforward to train a new NN LM model +by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py) + +@param type +0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English) + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english + */ +fun getOnlineLMConfig(type: Int): OnlineLMConfig { + when (type) { + 0 -> { + val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20" + return OnlineLMConfig( + model = "$modelDir/with-state-epoch-99-avg-1.int8.onnx", + scale = 0.5f, + ) + } + } + return OnlineLMConfig() +} + +// for English models, use a small value for rule2.minTrailingSilence, e.g., 0.8 +fun getEndpointConfig(): EndpointConfig { + return EndpointConfig( + rule1 = EndpointRule(false, 2.4f, 0.0f), + rule2 = EndpointRule(true, 0.8f, 0.0f), + rule3 = EndpointRule(false, 0.0f, 20.0f) + ) +} + +/* +Please see +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models. + +We only add a few here. Please change the following code +to add your own. 
(It should be straightforward to add a new model +by following the code) + +@param type + +0 - csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 (Chinese) + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-03-28-chinese + int8 + +1 - icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04 (English) + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#icefall-asr-multidataset-pruned-transducer-stateless7-2023-05-04-english + encoder int8, decoder/joiner float32 + +2 - sherpa-onnx-whisper-tiny.en + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en + encoder int8, decoder int8 + +3 - sherpa-onnx-whisper-base.en + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en + encoder int8, decoder int8 + +4 - pkufool/icefall-asr-zipformer-wenetspeech-20230615 (Chinese) + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#pkufool-icefall-asr-zipformer-wenetspeech-20230615-chinese + encoder/joiner int8, decoder fp32 + + */ +fun getOfflineModelConfig(type: Int): OfflineModelConfig? 
{ + when (type) { + 0 -> { + val modelDir = "sherpa-onnx-paraformer-zh-2023-03-28" + return OfflineModelConfig( + paraformer = OfflineParaformerModelConfig( + model = "$modelDir/model.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "paraformer", + ) + } + + 1 -> { + val modelDir = "icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04" + return OfflineModelConfig( + transducer = OfflineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-30-avg-4.int8.onnx", + decoder = "$modelDir/decoder-epoch-30-avg-4.onnx", + joiner = "$modelDir/joiner-epoch-30-avg-4.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer", + ) + } + + 2 -> { + val modelDir = "sherpa-onnx-whisper-tiny.en" + return OfflineModelConfig( + whisper = OfflineWhisperModelConfig( + encoder = "$modelDir/tiny.en-encoder.int8.onnx", + decoder = "$modelDir/tiny.en-decoder.int8.onnx", + ), + tokens = "$modelDir/tiny.en-tokens.txt", + modelType = "whisper", + ) + } + + 3 -> { + val modelDir = "sherpa-onnx-whisper-base.en" + return OfflineModelConfig( + whisper = OfflineWhisperModelConfig( + encoder = "$modelDir/base.en-encoder.int8.onnx", + decoder = "$modelDir/base.en-decoder.int8.onnx", + ), + tokens = "$modelDir/base.en-tokens.txt", + modelType = "whisper", + ) + } + + + 4 -> { + val modelDir = "icefall-asr-zipformer-wenetspeech-20230615" + return OfflineModelConfig( + transducer = OfflineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-12-avg-4.int8.onnx", + decoder = "$modelDir/decoder-epoch-12-avg-4.onnx", + joiner = "$modelDir/joiner-epoch-12-avg-4.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer", + ) + } + + } + return null +} diff --git a/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt new file mode 120000 index 000000000..d65321ad0 --- /dev/null +++ 
b/android/SherpaOnnx2Pass/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt @@ -0,0 +1 @@ +../../../../../../../../../SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt \ No newline at end of file diff --git a/android/SherpaOnnx2Pass/app/src/main/jniLibs/.gitkeep b/android/SherpaOnnx2Pass/app/src/main/jniLibs/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/android/SherpaOnnx2Pass/app/src/main/jniLibs/arm64-v8a/.gitkeep b/android/SherpaOnnx2Pass/app/src/main/jniLibs/arm64-v8a/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/android/SherpaOnnx2Pass/app/src/main/jniLibs/armeabi-v7a/.gitkeep b/android/SherpaOnnx2Pass/app/src/main/jniLibs/armeabi-v7a/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/android/SherpaOnnx2Pass/app/src/main/jniLibs/x86/.gitkeep b/android/SherpaOnnx2Pass/app/src/main/jniLibs/x86/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/android/SherpaOnnx2Pass/app/src/main/jniLibs/x86_64/.gitkeep b/android/SherpaOnnx2Pass/app/src/main/jniLibs/x86_64/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/android/SherpaOnnx2Pass/app/src/main/res/drawable-v24/ic_launcher_foreground.xml b/android/SherpaOnnx2Pass/app/src/main/res/drawable-v24/ic_launcher_foreground.xml new file mode 100644 index 000000000..2b068d114 --- /dev/null +++ b/android/SherpaOnnx2Pass/app/src/main/res/drawable-v24/ic_launcher_foreground.xml @@ -0,0 +1,30 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/android/SherpaOnnx2Pass/app/src/main/res/drawable/ic_launcher_background.xml b/android/SherpaOnnx2Pass/app/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 000000000..07d5da9cb --- /dev/null +++ b/android/SherpaOnnx2Pass/app/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git 
a/android/SherpaOnnx2Pass/app/src/main/res/layout/activity_main.xml b/android/SherpaOnnx2Pass/app/src/main/res/layout/activity_main.xml new file mode 100644 index 000000000..f9b35e862 --- /dev/null +++ b/android/SherpaOnnx2Pass/app/src/main/res/layout/activity_main.xml @@ -0,0 +1,39 @@ + + + + + + + +