From 4f55ce1c010ca4ee23ea3f5efb0a94aeb32dfa92 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 2 Apr 2024 23:11:39 +0800 Subject: [PATCH 1/3] WIP: Begin to add HLG decoding for streaming CTC models --- cmake/kaldi-decoder.cmake | 16 +-- sherpa-onnx/csrc/CMakeLists.txt | 2 + .../csrc/offline-ctc-fst-decoder-config.cc | 10 ++ .../csrc/offline-ctc-fst-decoder-config.h | 1 + sherpa-onnx/csrc/offline-ctc-fst-decoder.cc | 2 +- sherpa-onnx/csrc/offline-recognizer.cc | 6 + sherpa-onnx/csrc/online-ctc-decoder.h | 11 +- .../csrc/online-ctc-fst-decoder-config.cc | 42 +++++++ .../csrc/online-ctc-fst-decoder-config.h | 32 +++++ sherpa-onnx/csrc/online-ctc-fst-decoder.cc | 111 ++++++++++++++++++ sherpa-onnx/csrc/online-ctc-fst-decoder.h | 36 ++++++ .../csrc/online-ctc-greedy-search-decoder.cc | 3 +- .../csrc/online-ctc-greedy-search-decoder.h | 3 +- sherpa-onnx/csrc/online-recognizer-ctc-impl.h | 14 ++- sherpa-onnx/csrc/online-recognizer.cc | 27 +++-- sherpa-onnx/csrc/online-recognizer.h | 20 ++-- sherpa-onnx/csrc/online-stream.cc | 27 +++++ sherpa-onnx/csrc/online-stream.h | 6 + 18 files changed, 335 insertions(+), 34 deletions(-) create mode 100644 sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc create mode 100644 sherpa-onnx/csrc/online-ctc-fst-decoder-config.h create mode 100644 sherpa-onnx/csrc/online-ctc-fst-decoder.cc create mode 100644 sherpa-onnx/csrc/online-ctc-fst-decoder.h diff --git a/cmake/kaldi-decoder.cmake b/cmake/kaldi-decoder.cmake index 6ebd3f139..99ebf9aa0 100644 --- a/cmake/kaldi-decoder.cmake +++ b/cmake/kaldi-decoder.cmake @@ -1,9 +1,9 @@ function(download_kaldi_decoder) include(FetchContent) - set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.4.tar.gz") - set(kaldi_decoder_URL2 "https://hub.nuaa.cf/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.4.tar.gz") - set(kaldi_decoder_HASH "SHA256=136d96c2f1f8ec44de095205f81a6ce98981cd867fe4ba840f9415a0b58fe601") + set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.5.tar.gz") + set(kaldi_decoder_URL2 "https://hub.nuaa.cf/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.5.tar.gz") + set(kaldi_decoder_HASH "SHA256=f663e58aef31b33cd8086eaa09ff1383628039845f31300b5abef817d8cc2fff") set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE) set(KALDI_DECODER_ENABLE_TESTS OFF CACHE BOOL "" FORCE) @@ -12,11 +12,11 @@ function(download_kaldi_decoder) # If you don't have access to the Internet, # please pre-download kaldi-decoder set(possible_file_locations - $ENV{HOME}/Downloads/kaldi-decoder-0.2.4.tar.gz - ${CMAKE_SOURCE_DIR}/kaldi-decoder-0.2.4.tar.gz - ${CMAKE_BINARY_DIR}/kaldi-decoder-0.2.4.tar.gz - /tmp/kaldi-decoder-0.2.4.tar.gz - /star-fj/fangjun/download/github/kaldi-decoder-0.2.4.tar.gz + $ENV{HOME}/Downloads/kaldi-decoder-0.2.5.tar.gz + ${CMAKE_SOURCE_DIR}/kaldi-decoder-0.2.5.tar.gz + ${CMAKE_BINARY_DIR}/kaldi-decoder-0.2.5.tar.gz + /tmp/kaldi-decoder-0.2.5.tar.gz + /star-fj/fangjun/download/github/kaldi-decoder-0.2.5.tar.gz ) foreach(f IN LISTS possible_file_locations) diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 86dbc12c9..1ebdc6264 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -51,6 +51,8 @@ set(sources offline-zipformer-ctc-model-config.cc offline-zipformer-ctc-model.cc online-conformer-transducer-model.cc + online-ctc-fst-decoder-config.cc + online-ctc-fst-decoder.cc online-ctc-greedy-search-decoder.cc online-ctc-model.cc online-lm-config.cc diff --git a/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc b/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc index bd4126685..fa8353319 100644 --- a/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc +++ b/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc @@ -7,6 +7,9 @@ #include #include +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + namespace sherpa_onnx { std::string OfflineCtcFstDecoderConfig::ToString() const { @@ -29,4 +32,11 @@ void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) { "Decoder max active states. Larger->slower; more accurate"); } +bool OfflineCtcFstDecoderConfig::Validate() const { + if (!graph.empty() && !FileExists(graph)) { + SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str()); + return false; + } +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h b/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h index 6d7f70aed..b87fe89e6 100644 --- a/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h +++ b/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h @@ -24,6 +24,7 @@ struct OfflineCtcFstDecoderConfig { std::string ToString() const; void Register(ParseOptions *po); + bool Validate() const; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc b/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc index efee65a72..e54274df4 100644 --- a/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc +++ b/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc @@ -20,7 +20,7 @@ namespace sherpa_onnx { // @param filename Path to a StdVectorFst or StdConstFst graph // @return The caller should free the returned pointer using `delete` to // avoid memory leak. -static fst::Fst *ReadGraph(const std::string &filename) { +fst::Fst *ReadGraph(const std::string &filename) { // read decoding network FST std::ifstream is(filename, std::ios::binary); if (!is.good()) { diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc index 5c10eb3a1..8005cc855 100644 --- a/sherpa-onnx/csrc/offline-recognizer.cc +++ b/sherpa-onnx/csrc/offline-recognizer.cc @@ -67,6 +67,12 @@ bool OfflineRecognizerConfig::Validate() const { return false; } + if (!ctc_fst_decoder_config.graph.empty() && + !ctc_fst_decoder_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in fst_decoder"); + return false; + } + return model_config.Validate(); } diff --git a/sherpa-onnx/csrc/online-ctc-decoder.h b/sherpa-onnx/csrc/online-ctc-decoder.h index 6690e1bb2..40a36ebac 100644 --- a/sherpa-onnx/csrc/online-ctc-decoder.h +++ b/sherpa-onnx/csrc/online-ctc-decoder.h @@ -7,10 +7,13 @@ #include +#include "kaldi-decoder/csrc/faster-decoder.h" #include "onnxruntime_cxx_api.h" // NOLINT namespace sherpa_onnx { +class OnlineStream; + struct OnlineCtcDecoderResult { /// Number of frames after subsampling we have decoded so far int32_t frame_offset = 0; @@ -37,7 +40,13 @@ class OnlineCtcDecoder { * @param results Input & Output parameters.. */ virtual void Decode(Ort::Value log_probs, - std::vector *results) = 0; + std::vector *results, + OnlineStream **ss = nullptr, int32_t n = 0) = 0; + + virtual std::unique_ptr CreateFasterDecoder() + const { + return nullptr; + } }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc b/sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc new file mode 100644 index 000000000..f9e6ffc6b --- /dev/null +++ b/sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc @@ -0,0 +1,42 @@ +// sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h" + +#include +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +std::string OnlineCtcFstDecoderConfig::ToString() const { + std::ostringstream os; + + os << "OnlineCtcFstDecoderConfig("; + os << "graph=\"" << graph << "\", "; + os << "max_active=" << max_active << ")"; + + return os.str(); +} + +void OnlineCtcFstDecoderConfig::Register(ParseOptions *po) { + std::string prefix = "ctc"; + ParseOptions p(prefix, po); + + p.Register("graph", &graph, "Path to H.fst, HL.fst, or HLG.fst"); + + p.Register("max-active", &max_active, + "Decoder max active states. Larger->slower; more accurate"); +} + +bool OnlineCtcFstDecoderConfig::Validate() const { + if (!graph.empty() && !FileExists(graph)) { + SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str()); + return false; + } +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-ctc-fst-decoder-config.h b/sherpa-onnx/csrc/online-ctc-fst-decoder-config.h new file mode 100644 index 000000000..6f9e5b156 --- /dev/null +++ b/sherpa-onnx/csrc/online-ctc-fst-decoder-config.h @@ -0,0 +1,32 @@ +// sherpa-onnx/csrc/online-ctc-fst-decoder-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ +#define SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OnlineCtcFstDecoderConfig { + // Path to H.fst, HL.fst or HLG.fst + std::string graph; + int32_t max_active = 3000; + + OnlineCtcFstDecoderConfig() = default; + + OnlineCtcFstDecoderConfig(const std::string &graph, int32_t max_active) + : graph(graph), max_active(max_active) {} + + std::string ToString() const; + + void Register(ParseOptions *po); + bool Validate() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ diff --git a/sherpa-onnx/csrc/online-ctc-fst-decoder.cc b/sherpa-onnx/csrc/online-ctc-fst-decoder.cc new file mode 100644 index 000000000..4c0011f7e --- /dev/null +++ b/sherpa-onnx/csrc/online-ctc-fst-decoder.cc @@ -0,0 +1,111 @@ +// sherpa-onnx/csrc/online-ctc-fst-decoder.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h" + +#include +#include +#include + +#include "fst/fstlib.h" +#include "kaldi-decoder/csrc/decodable-ctc.h" +#include "kaldifst/csrc/fstext-utils.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-stream.h" + +namespace sherpa_onnx { + +// defined in ./offline-ctc-fst-decoder.cc +fst::Fst *ReadGraph(const std::string &filename); + +OnlineCtcFstDecoder::OnlineCtcFstDecoder( + const OnlineCtcFstDecoderConfig &config) + : config_(config), fst_(ReadGraph(config.graph)) { + options_.max_active = config_.max_active; +} + +std::unique_ptr +OnlineCtcFstDecoder::CreateFasterDecoder() const { + return std::make_unique(*fst_, options_); +} + +static void DecodeOne(const float *log_probs, int32_t num_rows, + int32_t num_cols, OnlineCtcDecoderResult *result, + OnlineStream *s) { + int32_t &processed_frames = s->GetFasterDecoderProcessedFrames(); + kaldi_decoder::DecodableCtc decodable(log_probs, num_rows, num_cols, + processed_frames); + + kaldi_decoder::FasterDecoder *decoder = s->GetFasterDecoder(); + if (processed_frames == 0) { + decoder->InitDecoding(); + } + decoder->AdvanceDecoding(&decodable); + + if (decoder->ReachedFinal()) { + fst::VectorFst fst_out; + bool ok = decoder->GetBestPath(&fst_out); + if (ok) { + std::vector isymbols_out; + std::vector osymbols_out; + ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out, &osymbols_out, + nullptr); + SHERPA_ONNX_LOGE("num tokens: %d\n", + static_cast(isymbols_out.size())); + std::vector tokens; + tokens.reserve(isymbols_out.size()); + std::ostringstream os; + int32_t prev_id = -1; + for (auto i : isymbols_out) { + i -= 1; + if (i != 0 && i != prev_id) { + tokens.push_back(i); + } + prev_id = i; + // TODO(fangjun): set num_trailing_blanks + } + + result->tokens = std::move(tokens); + } else { + result->tokens.clear(); + } + } else { + result->tokens.clear(); + } + + processed_frames += num_rows; +} + +void OnlineCtcFstDecoder::Decode(Ort::Value log_probs, + std::vector *results, + OnlineStream **ss, int32_t n) { + std::vector log_probs_shape = + log_probs.GetTensorTypeAndShapeInfo().GetShape(); + + if (log_probs_shape[0] != results->size()) { + SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d", + static_cast(log_probs_shape[0]), + static_cast(results->size())); + exit(-1); + } + + if (log_probs_shape[0] != n) { + SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d", + static_cast(log_probs_shape[0]), n); + exit(-1); + } + + int32_t batch_size = static_cast(log_probs_shape[0]); + int32_t num_frames = static_cast(log_probs_shape[1]); + int32_t vocab_size = static_cast(log_probs_shape[2]); + + const float *p = log_probs.GetTensorData(); + + for (int32_t i = 0; i != batch_size; ++i) { + DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size, + &(*results)[i], ss[i]); + } +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-ctc-fst-decoder.h b/sherpa-onnx/csrc/online-ctc-fst-decoder.h new file mode 100644 index 000000000..3612b27e7 --- /dev/null +++ b/sherpa-onnx/csrc/online-ctc-fst-decoder.h @@ -0,0 +1,36 @@ +// sherpa-onnx/csrc/online-ctc-fst-decoder.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_FST_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_FST_DECODER_H_ + +#include + +#include "fst/fst.h" +#include "sherpa-onnx/csrc/online-ctc-decoder.h" +#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h" + +namespace sherpa_onnx { + +class OnlineCtcFstDecoder : public OnlineCtcDecoder { + public: + explicit OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config); + + void Decode(Ort::Value log_probs, + std::vector *results, + OnlineStream **ss = nullptr, int32_t n = 0) override; + + std::unique_ptr CreateFasterDecoder() + const override; + + private: + OnlineCtcFstDecoderConfig config_; + kaldi_decoder::FasterDecoderOptions options_; + + std::unique_ptr> fst_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_FST_DECODER_H_ diff --git a/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc index 909373e71..e813c9873 100644 --- a/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc @@ -13,7 +13,8 @@ namespace sherpa_onnx { void OnlineCtcGreedySearchDecoder::Decode( - Ort::Value log_probs, std::vector *results) { + Ort::Value log_probs, std::vector *results, + OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) { std::vector log_probs_shape = log_probs.GetTensorTypeAndShapeInfo().GetShape(); diff --git a/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h b/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h index fc724f2c3..0af37593e 100644 --- a/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h @@ -17,7 +17,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder { : blank_id_(blank_id) {} void Decode(Ort::Value log_probs, - std::vector *results) override; + std::vector *results, + OnlineStream **ss = nullptr, int32_t n = 0) override; private: int32_t blank_id_; diff --git a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h index 5697a77e8..1790e1879 100644 --- a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h @@ -16,6 +16,7 @@ #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-ctc-decoder.h" +#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h" #include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h" #include "sherpa-onnx/csrc/online-ctc-model.h" #include "sherpa-onnx/csrc/online-recognizer-impl.h" @@ -99,6 +100,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { std::unique_ptr CreateStream() const override { auto stream = std::make_unique(config_.feat_config); stream->SetStates(model_->GetInitStates()); + stream->SetFasterDecoder(decoder_->CreateFasterDecoder()); return stream; } @@ -221,7 +223,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { private: void InitDecoder() { - if (config_.decoding_method == "greedy_search") { + if (!config_.ctc_fst_decoder_config.graph.empty()) { + decoder_ = + std::make_unique(config_.ctc_fst_decoder_config); + } else if (config_.decoding_method == "greedy_search") { if (!sym_.contains("") && !sym_.contains("") && !sym_.contains("")) { SHERPA_ONNX_LOGE( @@ -243,8 +248,9 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { decoder_ = std::make_unique(blank_id); } else { - SHERPA_ONNX_LOGE("Unsupported decoding method: %s", - config_.decoding_method.c_str()); + SHERPA_ONNX_LOGE( + "Unsupported decoding method: %s for streaming CTC models", + config_.decoding_method.c_str()); exit(-1); } } @@ -281,7 +287,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { std::vector results(1); results[0] = std::move(s->GetCtcResult()); - decoder_->Decode(std::move(out[0]), &results); + decoder_->Decode(std::move(out[0]), &results, &s, 1); s->SetCtcResult(results[0]); } diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index ea7e9f905..41bb9bbde 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -19,13 +19,13 @@ namespace sherpa_onnx { /// Helper for `OnlineRecognizerResult::AsJsonString()` -template -std::string VecToString(const std::vector& vec, int32_t precision = 6) { +template +std::string VecToString(const std::vector &vec, int32_t precision = 6) { std::ostringstream oss; oss << std::fixed << std::setprecision(precision); oss << "[ "; std::string sep = ""; - for (const auto& item : vec) { + for (const auto &item : vec) { oss << sep << item; sep = ", "; } @@ -34,13 +34,13 @@ std::string VecToString(const std::vector& vec, int32_t precision = 6) { } /// Helper for `OnlineRecognizerResult::AsJsonString()` -template<> // explicit specialization for T = std::string -std::string VecToString(const std::vector& vec, +template <> // explicit specialization for T = std::string +std::string VecToString(const std::vector &vec, int32_t) { // ignore 2nd arg std::ostringstream oss; oss << "[ "; std::string sep = ""; - for (const auto& item : vec) { + for (const auto &item : vec) { oss << sep << "\"" << item << "\""; sep = ", "; } @@ -51,15 +51,17 @@ std::string VecToString(const std::vector& vec, std::string OnlineRecognizerResult::AsJsonString() const { std::ostringstream os; os << "{ "; - os << "\"text\": " << "\"" << text << "\"" << ", "; + os << "\"text\": " + << "\"" << text << "\"" + << ", "; os << "\"tokens\": " << VecToString(tokens) << ", "; os << "\"timestamps\": " << VecToString(timestamps, 2) << ", "; os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", "; os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", "; os << "\"context_scores\": " << VecToString(context_scores, 6) << ", "; os << "\"segment\": " << segment << ", "; - os << "\"start_time\": " << std::fixed << std::setprecision(2) - << start_time << ", "; + os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time + << ", "; os << "\"is_final\": " << (is_final ? "true" : "false"); os << "}"; return os.str(); @@ -70,6 +72,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { model_config.Register(po); endpoint_config.Register(po); lm_config.Register(po); + ctc_fst_decoder_config.Register(po); po->Register("enable-endpoint", &enable_endpoint, "True to enable endpoint detection. False to disable it."); @@ -116,6 +119,12 @@ bool OnlineRecognizerConfig::Validate() const { return false; } + if (!ctc_fst_decoder_config.graph.empty() && + !ctc_fst_decoder_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in ctc_fst_decoder_config"); + return false; + } + return model_config.Validate(); } diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index ec8875e68..89382ad11 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -16,6 +16,7 @@ #include "sherpa-onnx/csrc/endpoint.h" #include "sherpa-onnx/csrc/features.h" +#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h" #include "sherpa-onnx/csrc/online-lm-config.h" #include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-stream.h" @@ -80,6 +81,7 @@ struct OnlineRecognizerConfig { OnlineModelConfig model_config; OnlineLMConfig lm_config; EndpointConfig endpoint_config; + OnlineCtcFstDecoderConfig ctc_fst_decoder_config; bool enable_endpoint = true; std::string decoding_method = "greedy_search"; @@ -96,19 +98,19 @@ struct OnlineRecognizerConfig { OnlineRecognizerConfig() = default; - OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, - const OnlineModelConfig &model_config, - const OnlineLMConfig &lm_config, - const EndpointConfig &endpoint_config, - bool enable_endpoint, - const std::string &decoding_method, - int32_t max_active_paths, - const std::string &hotwords_file, float hotwords_score, - float blank_penalty) + OnlineRecognizerConfig( + const FeatureExtractorConfig &feat_config, + const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config, + const EndpointConfig &endpoint_config, + const OnlineCtcFstDecoderConfig &&ctc_fst_decoder_config, + bool enable_endpoint, const std::string &decoding_method, + int32_t max_active_paths, const std::string &hotwords_file, + float hotwords_score, float blank_penalty) : feat_config(feat_config), model_config(model_config), lm_config(lm_config), endpoint_config(endpoint_config), + ctc_fst_decoder_config(ctc_fst_decoder_config), enable_endpoint(enable_endpoint), decoding_method(decoding_method), max_active_paths(max_active_paths), diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc index aaddfb545..52cfb899f 100644 --- a/sherpa-onnx/csrc/online-stream.cc +++ b/sherpa-onnx/csrc/online-stream.cc @@ -104,6 +104,18 @@ class OnlineStream::Impl { return paraformer_alpha_cache_; } + void SetFasterDecoder(std::unique_ptr decoder) { + faster_decoder_ = std::move(decoder); + } + + kaldi_decoder::FasterDecoder *GetFasterDecoder() const { + return faster_decoder_.get(); + } + + int32_t &GetFasterDecoderProcessedFrames() { + return faster_decoder_processed_frames_; + } + private: FeatureExtractor feat_extractor_; /// For contextual-biasing @@ -121,6 +133,8 @@ class OnlineStream::Impl { std::vector paraformer_encoder_out_cache_; std::vector paraformer_alpha_cache_; OnlineParaformerDecoderResult paraformer_result_; + std::unique_ptr faster_decoder_; + int32_t faster_decoder_processed_frames_ = 0; }; OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/, @@ -208,6 +222,19 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const { return impl_->GetContextGraph(); } +void OnlineStream::SetFasterDecoder( + std::unique_ptr decoder) { + impl_->SetFasterDecoder(std::move(decoder)); +} + +kaldi_decoder::FasterDecoder *OnlineStream::GetFasterDecoder() const { + return impl_->GetFasterDecoder(); +} + +int32_t &OnlineStream::GetFasterDecoderProcessedFrames() { + return impl_->GetFasterDecoderProcessedFrames(); +} + std::vector &OnlineStream::GetParaformerFeatCache() { return impl_->GetParaformerFeatCache(); } diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h index f648ca5dc..49b7f7402 100644 --- a/sherpa-onnx/csrc/online-stream.h +++ b/sherpa-onnx/csrc/online-stream.h @@ -8,6 +8,7 @@ #include #include +#include "kaldi-decoder/csrc/faster-decoder.h" #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/context-graph.h" #include "sherpa-onnx/csrc/features.h" @@ -97,6 +98,11 @@ class OnlineStream { */ const ContextGraphPtr &GetContextGraph() const; + // for online ctc decoder + void SetFasterDecoder(std::unique_ptr decoder); + kaldi_decoder::FasterDecoder *GetFasterDecoder() const; + int32_t &GetFasterDecoderProcessedFrames(); + // for streaming paraformer std::vector &GetParaformerFeatCache(); std::vector &GetParaformerEncoderOutCache(); From 66ad1523a3be155a2344e7d56e7b83e56a16b141 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 3 Apr 2024 18:01:46 +0800 Subject: [PATCH 2/3] Add fst-based decoding --- sherpa-onnx/csrc/online-ctc-fst-decoder.cc | 42 ++++++++++++------- sherpa-onnx/csrc/online-ctc-fst-decoder.h | 4 +- sherpa-onnx/csrc/online-recognizer-ctc-impl.h | 42 +++++++++---------- 3 files changed, 51 insertions(+), 37 deletions(-) diff --git a/sherpa-onnx/csrc/online-ctc-fst-decoder.cc b/sherpa-onnx/csrc/online-ctc-fst-decoder.cc index 4c0011f7e..76524538c 100644 --- a/sherpa-onnx/csrc/online-ctc-fst-decoder.cc +++ b/sherpa-onnx/csrc/online-ctc-fst-decoder.cc @@ -20,8 +20,8 @@ namespace sherpa_onnx { fst::Fst *ReadGraph(const std::string &filename); OnlineCtcFstDecoder::OnlineCtcFstDecoder( - const OnlineCtcFstDecoderConfig &config) - : config_(config), fst_(ReadGraph(config.graph)) { + const OnlineCtcFstDecoderConfig &config, int32_t blank_id) + : config_(config), fst_(ReadGraph(config.graph)), blank_id_(blank_id) { options_.max_active = config_.max_active; } @@ -32,7 +32,7 @@ OnlineCtcFstDecoder::CreateFasterDecoder() const { static void DecodeOne(const float *log_probs, int32_t num_rows, int32_t num_cols, OnlineCtcDecoderResult *result, - OnlineStream *s) { + OnlineStream *s, int32_t blank_id) { int32_t &processed_frames = s->GetFasterDecoderProcessedFrames(); kaldi_decoder::DecodableCtc decodable(log_probs, num_rows, num_cols, processed_frames); @@ -41,6 +41,7 @@ static void DecodeOne(const float *log_probs, int32_t num_rows, if (processed_frames == 0) { decoder->InitDecoding(); } + decoder->AdvanceDecoding(&decodable); if (decoder->ReachedFinal()) { @@ -48,30 +49,41 @@ static void DecodeOne(const float *log_probs, int32_t num_rows, bool ok = decoder->GetBestPath(&fst_out); if (ok) { std::vector isymbols_out; - std::vector osymbols_out; - ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out, &osymbols_out, - nullptr); - SHERPA_ONNX_LOGE("num tokens: %d\n", - static_cast(isymbols_out.size())); + std::vector osymbols_out_unused; + ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out, + &osymbols_out_unused, nullptr); std::vector tokens; tokens.reserve(isymbols_out.size()); + + std::vector timestamps; + timestamps.reserve(isymbols_out.size()); + std::ostringstream os; int32_t prev_id = -1; + int32_t num_trailing_blanks = 0; + int32_t f = 0; // frame number + for (auto i : isymbols_out) { i -= 1; - if (i != 0 && i != prev_id) { + + if (i == blank_id) { + num_trailing_blanks += 1; + } else { + num_trailing_blanks = 0; + } + + if (i != blank_id && i != prev_id) { tokens.push_back(i); + timestamps.push_back(f); } prev_id = i; - // TODO(fangjun): set num_trailing_blanks + f += 1; } result->tokens = std::move(tokens); - } else { - result->tokens.clear(); + result->timestamps = std::move(timestamps); + // no need to set frame_offset } - } else { - result->tokens.clear(); } processed_frames += num_rows; @@ -104,7 +116,7 @@ void OnlineCtcFstDecoder::Decode(Ort::Value log_probs, for (int32_t i = 0; i != batch_size; ++i) { DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size, - &(*results)[i], ss[i]); + &(*results)[i], ss[i], blank_id_); } } diff --git a/sherpa-onnx/csrc/online-ctc-fst-decoder.h b/sherpa-onnx/csrc/online-ctc-fst-decoder.h index 3612b27e7..aa801c1a1 100644 --- a/sherpa-onnx/csrc/online-ctc-fst-decoder.h +++ b/sherpa-onnx/csrc/online-ctc-fst-decoder.h @@ -15,7 +15,8 @@ namespace sherpa_onnx { class OnlineCtcFstDecoder : public OnlineCtcDecoder { public: - explicit OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config); + OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config, + int32_t blank_id); void Decode(Ort::Value log_probs, std::vector *results, @@ -29,6 +30,7 @@ class OnlineCtcFstDecoder : public OnlineCtcDecoder { kaldi_decoder::FasterDecoderOptions options_; std::unique_ptr> fst_; + int32_t blank_id_ = 0; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h index 1790e1879..0cbae1528 100644 --- a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h @@ -223,29 +223,29 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { private: void InitDecoder() { - if (!config_.ctc_fst_decoder_config.graph.empty()) { - decoder_ = - std::make_unique(config_.ctc_fst_decoder_config); - } else if (config_.decoding_method == "greedy_search") { - if (!sym_.contains("") && !sym_.contains("") && - !sym_.contains("")) { - SHERPA_ONNX_LOGE( - "We expect that tokens.txt contains " - "the symbol or or and its ID."); - exit(-1); - } + if (!sym_.contains("") && !sym_.contains("") && + !sym_.contains("")) { + SHERPA_ONNX_LOGE( + "We expect that tokens.txt contains " + "the symbol or or and its ID."); + exit(-1); + } - int32_t blank_id = 0; - if (sym_.contains("")) { - blank_id = sym_[""]; - } else if (sym_.contains("")) { - // for tdnn models of the yesno recipe from icefall - blank_id = sym_[""]; - } else if (sym_.contains("")) { - // for WeNet CTC models - blank_id = sym_[""]; - } + int32_t blank_id = 0; + if (sym_.contains("")) { + blank_id = sym_[""]; + } else if (sym_.contains("")) { + // for tdnn models of the yesno recipe from icefall + blank_id = sym_[""]; + } else if (sym_.contains("")) { + // for WeNet CTC models + blank_id = sym_[""]; + } + if (!config_.ctc_fst_decoder_config.graph.empty()) { + decoder_ = std::make_unique( + config_.ctc_fst_decoder_config, blank_id); + } else if (config_.decoding_method == "greedy_search") { decoder_ = std::make_unique(blank_id); } else { SHERPA_ONNX_LOGE( From c162950b9ad1209cbbfdbcf9adbf906e9864c188 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 3 Apr 2024 20:27:32 +0800 Subject: [PATCH 3/3] Add CI tests --- .github/scripts/test-online-ctc.sh | 22 ++- .github/scripts/test-python.sh | 19 +- .github/workflows/linux.yaml | 15 +- .../online-zipformer-ctc-hlg-decode-file.py | 172 ++++++++++++++++++ .../csrc/offline-ctc-fst-decoder-config.cc | 1 + sherpa-onnx/csrc/online-ctc-decoder.h | 1 + .../csrc/online-ctc-fst-decoder-config.cc | 10 +- sherpa-onnx/csrc/online-ctc-fst-decoder.cc | 2 + sherpa-onnx/csrc/online-ctc-fst-decoder.h | 7 +- sherpa-onnx/csrc/online-recognizer-ctc-impl.h | 2 +- sherpa-onnx/csrc/online-recognizer.cc | 1 + sherpa-onnx/csrc/online-recognizer.h | 2 +- sherpa-onnx/python/csrc/CMakeLists.txt | 1 + .../csrc/online-ctc-fst-decoder-config.cc | 23 +++ .../csrc/online-ctc-fst-decoder-config.h | 16 ++ sherpa-onnx/python/csrc/online-recognizer.cc | 42 ++--- sherpa-onnx/python/csrc/sherpa-onnx.cc | 2 + .../python/sherpa_onnx/online_recognizer.py | 15 ++ 18 files changed, 312 insertions(+), 41 deletions(-) create mode 100755 python-api-examples/online-zipformer-ctc-hlg-decode-file.py create mode 100644 sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.cc create mode 100644 sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h diff --git a/.github/scripts/test-online-ctc.sh b/.github/scripts/test-online-ctc.sh index fa331be6f..7c631dd05 100755 --- a/.github/scripts/test-online-ctc.sh +++ b/.github/scripts/test-online-ctc.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -set -e +set -ex log() { # This function is from espnet @@ -13,6 +13,26 @@ echo "PATH: $PATH" which $EXE +log "------------------------------------------------------------" +log "Run streaming Zipformer2 CTC HLG decoding " +log "------------------------------------------------------------" +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 +tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 +rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 +repo=$PWD/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18 +ls -lh $repo +echo "pwd: $PWD" + +$EXE \ + --zipformer2-ctc-model=$repo/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \ + --ctc-graph=$repo/HLG.fst \ + --tokens=$repo/tokens.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + +rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18 + log "------------------------------------------------------------" log "Run streaming Zipformer2 CTC " log "------------------------------------------------------------" diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index b454d5310..3604a0059 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -set -e +set -ex log() { # This function is from espnet @@ -8,6 +8,23 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +log "test streaming zipformer2 ctc HLG decoding" + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 +tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 +rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 +repo=sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18 + +python3 ./python-api-examples/online-zipformer-ctc-hlg-decode-file.py \ + --debug 1 \ + --tokens ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt \ + --graph ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst \ + --model ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \ + ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/0.wav + +rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18 + + mkdir -p /tmp/icefall-models dir=/tmp/icefall-models diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index b1f3fa91b..b32362a3d 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -124,6 +124,14 @@ jobs: name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} path: build/bin/* + - name: Test online CTC + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx + + .github/scripts/test-online-ctc.sh + - name: Test C API shell: bash run: | @@ -149,13 +157,6 @@ jobs: .github/scripts/test-kws.sh - - name: Test online CTC - shell: bash - run: | - export PATH=$PWD/build/bin:$PATH - export EXE=sherpa-onnx - - .github/scripts/test-online-ctc.sh - name: Test offline Whisper if: matrix.build_type != 'Debug' diff --git a/python-api-examples/online-zipformer-ctc-hlg-decode-file.py b/python-api-examples/online-zipformer-ctc-hlg-decode-file.py new file mode 100755 index 000000000..869840c7c --- /dev/null +++ b/python-api-examples/online-zipformer-ctc-hlg-decode-file.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 + +# This file shows how to use a streaming zipformer CTC model and an HLG +# graph for decoding. +# +# We use the following model as an example +# +""" +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 +tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 +rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + +python3 ./python-api-examples/online-zipformer-ctc-hlg-decode-file.py \ + --tokens ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt \ + --graph ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst \ + --model ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \ + ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/0.wav + +""" +# (The above model is from https://github.com/k2-fsa/icefall/pull/1557) + +import argparse +import time +import wave +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import sherpa_onnx + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--model", + type=str, + required=True, + help="Path to the ONNX model", + ) + + parser.add_argument( + "--graph", + type=str, + required=True, + help="Path to H.fst, HL.fst, or HLG.fst", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + parser.add_argument( + "--debug", + type=int, + default=0, + help="Valid values: 1, 0", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to decode. It must be of WAVE" + "format with a single channel, and each sample has 16-bit, " + "i.e., int16_t. " + "The sample rate of the file can be arbitrary and does not need to " + "be 16 kHz", + ) + + return parser.parse_args() + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. + - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +def main(): + args = get_args() + print(vars(args)) + + assert_file_exists(args.tokens) + assert_file_exists(args.graph) + assert_file_exists(args.model) + + recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc( + tokens=args.tokens, + model=args.model, + num_threads=args.num_threads, + provider=args.provider, + sample_rate=16000, + feature_dim=80, + ctc_graph=args.graph, + ) + + wave_filename = args.sound_file + assert_file_exists(wave_filename) + samples, sample_rate = read_wave(wave_filename) + duration = len(samples) / sample_rate + + print("Started") + + start_time = time.time() + s = recognizer.create_stream() + s.accept_waveform(sample_rate, samples) + tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + while recognizer.is_ready(s): + recognizer.decode_stream(s) + + result = recognizer.get_result(s).lower() + end_time = time.time() + + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / duration + print(f"num_threads: {args.num_threads}") + print(f"Wave duration: {duration:.3f} s") + print(f"Elapsed time: {elapsed_seconds:.3f} s") + print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") + print(result) + + +if __name__ == "__main__": + main() diff --git a/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc b/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc index fa8353319..481ecaef5 100644 --- a/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc +++ b/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc @@ -37,6 +37,7 @@ bool OfflineCtcFstDecoderConfig::Validate() const { SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str()); return false; } + return true; } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-ctc-decoder.h b/sherpa-onnx/csrc/online-ctc-decoder.h index 40a36ebac..28809e39f 100644 --- a/sherpa-onnx/csrc/online-ctc-decoder.h +++ b/sherpa-onnx/csrc/online-ctc-decoder.h @@ -5,6 +5,7 @@ #ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ #define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ +#include #include #include "kaldi-decoder/csrc/faster-decoder.h" diff --git a/sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc b/sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc index f9e6ffc6b..9eccebea7 100644 --- a/sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc +++ b/sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc @@ -23,13 +23,10 @@ std::string OnlineCtcFstDecoderConfig::ToString() const { } void OnlineCtcFstDecoderConfig::Register(ParseOptions *po) { - std::string prefix = "ctc"; - ParseOptions p(prefix, po); + po->Register("ctc-graph", &graph, "Path to H.fst, HL.fst, or HLG.fst"); - p.Register("graph", &graph, "Path to H.fst, HL.fst, or HLG.fst"); - - p.Register("max-active", &max_active, - "Decoder max active states. Larger->slower; more accurate"); + po->Register("ctc-max-active", &max_active, + "Decoder max active states. Larger->slower; more accurate"); } bool OnlineCtcFstDecoderConfig::Validate() const { @@ -37,6 +34,7 @@ bool OnlineCtcFstDecoderConfig::Validate() const { SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str()); return false; } + return true; } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-ctc-fst-decoder.cc b/sherpa-onnx/csrc/online-ctc-fst-decoder.cc index 76524538c..7619e0db5 100644 --- a/sherpa-onnx/csrc/online-ctc-fst-decoder.cc +++ b/sherpa-onnx/csrc/online-ctc-fst-decoder.cc @@ -5,6 +5,8 @@ #include "sherpa-onnx/csrc/online-ctc-fst-decoder.h" #include +#include +#include #include #include diff --git a/sherpa-onnx/csrc/online-ctc-fst-decoder.h b/sherpa-onnx/csrc/online-ctc-fst-decoder.h index aa801c1a1..992276d6b 100644 --- a/sherpa-onnx/csrc/online-ctc-fst-decoder.h +++ b/sherpa-onnx/csrc/online-ctc-fst-decoder.h @@ -2,9 +2,10 @@ // // Copyright (c) 2024 Xiaomi Corporation -#ifndef SHERPA_ONNX_CSRC_ONLINE_FST_DECODER_H_ -#define SHERPA_ONNX_CSRC_ONLINE_FST_DECODER_H_ +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_ +#include #include #include "fst/fst.h" @@ -35,4 +36,4 @@ class OnlineCtcFstDecoder : public OnlineCtcDecoder { } // namespace sherpa_onnx -#endif // SHERPA_ONNX_CSRC_ONLINE_FST_DECODER_H_ +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_ diff --git a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h index 0cbae1528..4b137e299 100644 --- a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h @@ -167,7 +167,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { std::vector> next_states = model_->UnStackStates(std::move(out_states)); - decoder_->Decode(std::move(out[0]), &results); + decoder_->Decode(std::move(out[0]), &results, ss, n); for (int32_t k = 0; k != n; ++k) { ss[k]->SetCtcResult(results[k]); diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 41bb9bbde..5d3445659 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -136,6 +136,7 @@ std::string OnlineRecognizerConfig::ToString() const { os << "model_config=" << model_config.ToString() << ", "; os << "lm_config=" << lm_config.ToString() << ", "; os << "endpoint_config=" << endpoint_config.ToString() << ", "; + os << "ctc_fst_decoder_config=" << ctc_fst_decoder_config.ToString() << ", "; os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; os << "max_active_paths=" << max_active_paths << ", "; os << "hotwords_score=" << hotwords_score << ", "; diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index 89382ad11..e7f1b38d7 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -102,7 +102,7 @@ struct OnlineRecognizerConfig { const FeatureExtractorConfig &feat_config, const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config, const EndpointConfig &endpoint_config, - const OnlineCtcFstDecoderConfig &&ctc_fst_decoder_config, + const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config, bool enable_endpoint, const std::string &decoding_method, int32_t max_active_paths, const std::string &hotwords_file, float hotwords_score, float blank_penalty) diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 9e5af779d..53aebd78c 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -18,6 +18,7 @@ set(srcs offline-wenet-ctc-model-config.cc offline-whisper-model-config.cc offline-zipformer-ctc-model-config.cc + online-ctc-fst-decoder-config.cc online-lm-config.cc online-model-config.cc online-paraformer-model-config.cc diff --git a/sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.cc b/sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.cc new file mode 100644 index 000000000..116278ec0 --- /dev/null +++ b/sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.cc @@ -0,0 +1,23 @@ +// sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h" + +#include + +#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h" + +namespace sherpa_onnx { + +void PybindOnlineCtcFstDecoderConfig(py::module *m) { + using PyClass = OnlineCtcFstDecoderConfig; + py::class_(*m, "OnlineCtcFstDecoderConfig") + .def(py::init(), py::arg("graph") = "", + py::arg("max_active") = 3000) + .def_readwrite("graph", &PyClass::graph) + .def_readwrite("max_active", &PyClass::max_active) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h b/sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h new file mode 100644 index 000000000..00727646b --- /dev/null +++ b/sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOnlineCtcFstDecoderConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index 0213bd7b2..bd98c94e2 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -24,8 +24,7 @@ static void PybindOnlineRecognizerResult(py::module *m) { "tokens", [](PyClass &self) -> std::vector { return self.tokens; }) .def_property_readonly( - "start_time", - [](PyClass &self) -> float { return self.start_time; }) + "start_time", [](PyClass &self) -> float { return self.start_time; }) .def_property_readonly( "timestamps", [](PyClass &self) -> std::vector { return self.timestamps; }) @@ -35,37 +34,38 @@ static void PybindOnlineRecognizerResult(py::module *m) { .def_property_readonly( "lm_probs", [](PyClass &self) -> std::vector { return self.lm_probs; }) + .def_property_readonly("context_scores", + [](PyClass &self) -> std::vector { + return self.context_scores; + }) .def_property_readonly( - "context_scores", - [](PyClass &self) -> std::vector { - return self.context_scores; - }) + "segment", [](PyClass &self) -> int32_t { return self.segment; }) .def_property_readonly( - "segment", - [](PyClass &self) -> int32_t { return self.segment; }) - .def_property_readonly( - "is_final", - [](PyClass &self) -> bool { return self.is_final; }) + "is_final", [](PyClass &self) -> bool { return self.is_final; }) .def("as_json_string", &PyClass::AsJsonString, - py::call_guard()); + py::call_guard()); } static void PybindOnlineRecognizerConfig(py::module *m) { using PyClass = OnlineRecognizerConfig; py::class_(*m, "OnlineRecognizerConfig") - .def(py::init(), - py::arg("feat_config"), py::arg("model_config"), - py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), - py::arg("enable_endpoint"), py::arg("decoding_method"), - py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", - py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0) + .def( + py::init(), + py::arg("feat_config"), py::arg("model_config"), + py::arg("lm_config") = OnlineLMConfig(), + py::arg("endpoint_config") = EndpointConfig(), + py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(), + py::arg("enable_endpoint"), py::arg("decoding_method"), + py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", + py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("lm_config", &PyClass::lm_config) .def_readwrite("endpoint_config", &PyClass::endpoint_config) + .def_readwrite("ctc_fst_decoder_config", &PyClass::ctc_fst_decoder_config) .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) .def_readwrite("decoding_method", &PyClass::decoding_method) .def_readwrite("max_active_paths", &PyClass::max_active_paths) diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 62c64ec72..4952e150b 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -15,6 +15,7 @@ #include "sherpa-onnx/python/csrc/offline-model-config.h" #include "sherpa-onnx/python/csrc/offline-recognizer.h" #include "sherpa-onnx/python/csrc/offline-stream.h" +#include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h" #include "sherpa-onnx/python/csrc/online-lm-config.h" #include "sherpa-onnx/python/csrc/online-model-config.h" #include "sherpa-onnx/python/csrc/online-recognizer.h" @@ -36,6 +37,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { m.doc() = "pybind11 binding of sherpa-onnx"; PybindFeatures(&m); + PybindOnlineCtcFstDecoderConfig(&m); PybindOnlineModelConfig(&m); PybindOnlineLMConfig(&m); PybindOnlineStream(&m); diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 105043399..a82ab1703 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -16,6 +16,7 @@ OnlineTransducerModelConfig, OnlineWenetCtcModelConfig, OnlineZipformer2CtcModelConfig, + OnlineCtcFstDecoderConfig, ) @@ -314,6 +315,8 @@ def from_zipformer2_ctc( rule2_min_trailing_silence: float = 1.2, rule3_min_utterance_length: float = 20.0, decoding_method: str = "greedy_search", + ctc_graph: str = "", + ctc_max_active: int = 3000, provider: str = "cpu", ): """ @@ -355,6 +358,12 @@ def from_zipformer2_ctc( is detected. decoding_method: The only valid value is greedy_search. + ctc_graph: + If not empty, decoding_method is ignored. It contains the path to + H.fst, HL.fst, or HLG.fst + ctc_max_active: + Used only when ctc_graph is not empty. It specifies the maximum + active paths at a time. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. """ @@ -384,10 +393,16 @@ def from_zipformer2_ctc( rule3_min_utterance_length=rule3_min_utterance_length, ) + ctc_fst_decoder_config = OnlineCtcFstDecoderConfig( + graph=ctc_graph, + max_active=ctc_max_active, + ) + recognizer_config = OnlineRecognizerConfig( feat_config=feat_config, model_config=model_config, endpoint_config=endpoint_config, + ctc_fst_decoder_config=ctc_fst_decoder_config, enable_endpoint=enable_endpoint_detection, decoding_method=decoding_method, )