diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 7e9f6e29b..a9ba0a95e 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -12,6 +12,13 @@ #include #include +#if __ANDROID_API__ >= 9 +#include + +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-lm.h" @@ -62,14 +69,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { model_(OnlineTransducerModel::Create(config.model_config)), sym_(config.model_config.tokens), endpoint_(config_.endpoint_config) { - if (!config_.hotwords_file.empty()) { - InitHotwords(); - } if (sym_.contains("")) { unk_id_ = sym_[""]; } if (config.decoding_method == "modified_beam_search") { + if (!config_.hotwords_file.empty()) { + InitHotwords(); + } + if (!config_.lm_config.model.empty()) { lm_ = OnlineLM::Create(config.lm_config); } @@ -99,6 +107,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } if (config.decoding_method == "modified_beam_search") { +#if 0 + // TODO(fangjun): Implement it + if (!config_.lm_config.model.empty()) { + lm_ = OnlineLM::Create(mgr, config.lm_config); + } +#endif + + if (!config_.hotwords_file.empty()) { + InitHotwords(mgr); + } + decoder_ = std::make_unique( model_.get(), lm_.get(), config_.max_active_paths, config_.lm_config.scale, unk_id_); @@ -268,6 +287,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { s->Reset(); } + private: void InitHotwords() { // each line in hotwords_file contains space-separated words @@ -286,7 +306,29 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { std::make_shared(hotwords_, config_.hotwords_score); } - private: +#if __ANDROID_API__ >= 9 + void InitHotwords(AAssetManager *mgr) { + // each line in hotwords_file contains space-separated words + + auto buf = ReadFile(mgr, config_.hotwords_file); + + std::istrstream is(buf.data(), buf.size()); + + if (!is) { + SHERPA_ONNX_LOGE("Open hotwords file failed: %s", + config_.hotwords_file.c_str()); + exit(-1); + } + + if (!EncodeHotwords(is, sym_, &hotwords_)) { + SHERPA_ONNX_LOGE("Encode hotwords failed."); + exit(-1); + } + hotwords_graph_ = + std::make_shared(hotwords_, config_.hotwords_score); + } +#endif + void InitOnlineStream(OnlineStream *stream) const { auto r = decoder_->GetEmptyResult();