Skip to content

Commit

Permalink
Add fst-based decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Apr 3, 2024
1 parent 4f55ce1 commit 66ad152
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 37 deletions.
42 changes: 27 additions & 15 deletions sherpa-onnx/csrc/online-ctc-fst-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ namespace sherpa_onnx {
fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename);

OnlineCtcFstDecoder::OnlineCtcFstDecoder(
const OnlineCtcFstDecoderConfig &config)
: config_(config), fst_(ReadGraph(config.graph)) {
const OnlineCtcFstDecoderConfig &config, int32_t blank_id)
: config_(config), fst_(ReadGraph(config.graph)), blank_id_(blank_id) {
options_.max_active = config_.max_active;
}

Expand All @@ -32,7 +32,7 @@ OnlineCtcFstDecoder::CreateFasterDecoder() const {

static void DecodeOne(const float *log_probs, int32_t num_rows,
int32_t num_cols, OnlineCtcDecoderResult *result,
OnlineStream *s) {
OnlineStream *s, int32_t blank_id) {
int32_t &processed_frames = s->GetFasterDecoderProcessedFrames();
kaldi_decoder::DecodableCtc decodable(log_probs, num_rows, num_cols,
processed_frames);
Expand All @@ -41,37 +41,49 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
if (processed_frames == 0) {
decoder->InitDecoding();
}

decoder->AdvanceDecoding(&decodable);

if (decoder->ReachedFinal()) {
fst::VectorFst<fst::LatticeArc> fst_out;
bool ok = decoder->GetBestPath(&fst_out);
if (ok) {
std::vector<int32_t> isymbols_out;
std::vector<int32_t> osymbols_out;
ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out, &osymbols_out,
nullptr);
SHERPA_ONNX_LOGE("num tokens: %d\n",
static_cast<int32_t>(isymbols_out.size()));
std::vector<int32_t> osymbols_out_unused;
ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out,
&osymbols_out_unused, nullptr);
std::vector<int64_t> tokens;
tokens.reserve(isymbols_out.size());

std::vector<int32_t> timestamps;
timestamps.reserve(isymbols_out.size());

std::ostringstream os;
int32_t prev_id = -1;
int32_t num_trailing_blanks = 0;
int32_t f = 0; // frame number

for (auto i : isymbols_out) {
i -= 1;
if (i != 0 && i != prev_id) {

if (i == blank_id) {
num_trailing_blanks += 1;
} else {
num_trailing_blanks = 0;
}

if (i != blank_id && i != prev_id) {
tokens.push_back(i);
timestamps.push_back(f);
}
prev_id = i;
// TODO(fangjun): set num_trailing_blanks
f += 1;
}

result->tokens = std::move(tokens);
} else {
result->tokens.clear();
result->timestamps = std::move(timestamps);
// no need to set frame_offset
}
} else {
result->tokens.clear();
}

processed_frames += num_rows;
Expand Down Expand Up @@ -104,7 +116,7 @@ void OnlineCtcFstDecoder::Decode(Ort::Value log_probs,

for (int32_t i = 0; i != batch_size; ++i) {
DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size,
&(*results)[i], ss[i]);
&(*results)[i], ss[i], blank_id_);
}
}

Expand Down
4 changes: 3 additions & 1 deletion sherpa-onnx/csrc/online-ctc-fst-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ namespace sherpa_onnx {

class OnlineCtcFstDecoder : public OnlineCtcDecoder {
public:
explicit OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config);
OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config,
int32_t blank_id);

void Decode(Ort::Value log_probs,
std::vector<OnlineCtcDecoderResult> *results,
Expand All @@ -29,6 +30,7 @@ class OnlineCtcFstDecoder : public OnlineCtcDecoder {
kaldi_decoder::FasterDecoderOptions options_;

std::unique_ptr<fst::Fst<fst::StdArc>> fst_;
int32_t blank_id_ = 0;
};

} // namespace sherpa_onnx
Expand Down
42 changes: 21 additions & 21 deletions sherpa-onnx/csrc/online-recognizer-ctc-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,29 +223,29 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {

private:
void InitDecoder() {
if (!config_.ctc_fst_decoder_config.graph.empty()) {
decoder_ =
std::make_unique<OnlineCtcFstDecoder>(config_.ctc_fst_decoder_config);
} else if (config_.decoding_method == "greedy_search") {
if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
!sym_.contains("<blank>")) {
SHERPA_ONNX_LOGE(
"We expect that tokens.txt contains "
"the symbol <blk> or <eps> or <blank> and its ID.");
exit(-1);
}
if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
!sym_.contains("<blank>")) {
SHERPA_ONNX_LOGE(
"We expect that tokens.txt contains "
"the symbol <blk> or <eps> or <blank> and its ID.");
exit(-1);
}

int32_t blank_id = 0;
if (sym_.contains("<blk>")) {
blank_id = sym_["<blk>"];
} else if (sym_.contains("<eps>")) {
// for tdnn models of the yesno recipe from icefall
blank_id = sym_["<eps>"];
} else if (sym_.contains("<blank>")) {
// for WeNet CTC models
blank_id = sym_["<blank>"];
}
int32_t blank_id = 0;
if (sym_.contains("<blk>")) {
blank_id = sym_["<blk>"];
} else if (sym_.contains("<eps>")) {
// for tdnn models of the yesno recipe from icefall
blank_id = sym_["<eps>"];
} else if (sym_.contains("<blank>")) {
// for WeNet CTC models
blank_id = sym_["<blank>"];
}

if (!config_.ctc_fst_decoder_config.graph.empty()) {
decoder_ = std::make_unique<OnlineCtcFstDecoder>(
config_.ctc_fst_decoder_config, blank_id);
} else if (config_.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id);
} else {
SHERPA_ONNX_LOGE(
Expand Down

0 comments on commit 66ad152

Please sign in to comment.