Skip to content

Commit

Permalink
Fix reading hotwords file for android
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Oct 10, 2023
1 parent 8455057 commit c33c1f5
Showing 1 changed file with 46 additions and 4 deletions.
50 changes: 46 additions & 4 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include <strstream>

#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-lm.h"
Expand Down Expand Up @@ -62,14 +69,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (!config_.hotwords_file.empty()) {
InitHotwords();
}
if (sym_.contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}

if (config.decoding_method == "modified_beam_search") {
if (!config_.hotwords_file.empty()) {
InitHotwords();
}

if (!config_.lm_config.model.empty()) {
lm_ = OnlineLM::Create(config.lm_config);
}
Expand Down Expand Up @@ -99,6 +107,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}

if (config.decoding_method == "modified_beam_search") {
#if 0
// TODO(fangjun): Implement it
if (!config_.lm_config.model.empty()) {
lm_ = OnlineLM::Create(mgr, config.lm_config);
}
#endif

if (!config_.hotwords_file.empty()) {
InitHotwords(mgr);
}

decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, unk_id_);
Expand Down Expand Up @@ -268,6 +287,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
s->Reset();
}

private:
void InitHotwords() {
// each line in hotwords_file contains space-separated words

Expand All @@ -286,7 +306,29 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
}

private:
#if __ANDROID_API__ >= 9
void InitHotwords(AAssetManager *mgr) {
// each line in hotwords_file contains space-separated words

auto buf = ReadFile(mgr, config_.hotwords_file);

std::istrstream is(buf.data(), buf.size());

if (!is) {
SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
config_.hotwords_file.c_str());
exit(-1);
}

if (!EncodeHotwords(is, sym_, &hotwords_)) {
SHERPA_ONNX_LOGE("Encode hotwords failed.");
exit(-1);
}
hotwords_graph_ =
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
}
#endif

void InitOnlineStream(OnlineStream *stream) const {
auto r = decoder_->GetEmptyResult();

Expand Down

0 comments on commit c33c1f5

Please sign in to comment.