Skip to content

Commit

Permalink
fill lm_probs/context_scores only if LM/ContextGraph is present (make…
Browse files Browse the repository at this point in the history
… Result smaller)
  • Loading branch information
KarelVesely84 committed Feb 27, 2024
1 parent 0f1107b commit e442564
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/hypothesis.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ struct Hypothesis {

// lm_probs[i] contains the lm score for each token in ys.
// Used only in transducer mofified beam-search.
// Elements filled only if LM is used.
std::vector<float> lm_probs;

// context_scores[i] contains the context-graph score for each token in ys.
// Used only in transducer mofified beam-search.
// Elements filled only if `ContextGraph` is used.
std::vector<float> context_scores;

// The total score of ys in log space.
Expand Down
12 changes: 6 additions & 6 deletions sherpa-onnx/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace sherpa_onnx {

/// Helper for `OnlineRecognizerResult::AsJsonString()`
template<typename T>
const std::string& VecToString(const std::vector<T>& vec, int32_t precision = 6) {
std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
std::ostringstream oss;
oss << std::fixed << std::setprecision(precision);
oss << "[ ";
Expand All @@ -35,9 +35,8 @@ const std::string& VecToString(const std::vector<T>& vec, int32_t precision = 6)

/// Helper for `OnlineRecognizerResult::AsJsonString()`
template<> // explicit specialization for T = std::string
const std::string& VecToString<std::string>(const std::vector<std::string>& vec,
int32_t) // ignore 2nd arg
{
std::string VecToString<std::string>(const std::vector<std::string>& vec,
int32_t) { // ignore 2nd arg
std::ostringstream oss;
oss << "[ ";
std::string sep = "";
Expand All @@ -57,9 +56,10 @@ std::string OnlineRecognizerResult::AsJsonString() const {
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
os << "\"constext_scores\": " << VecToString(context_scores, 6) << ", ";
os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
os << "\"segment\": " << segment << ", ";
os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time << ", ";
os << "\"start_time\": " << std::fixed << std::setprecision(2)
<< start_time << ", ";
os << "\"is_final\": " << (is_final ? "true" : "false");
os << "}";
return os.str();
Expand Down
17 changes: 11 additions & 6 deletions sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,17 +190,22 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
if (new_token != 0 && new_token != unk_id_) {
const Hypothesis& prev_i = prev[hyp_index];
// subtract 'prev[i]' path scores, which were added before
// for getting topk tokens
// getting topk tokens
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
new_hyp.ys_probs.push_back(y_prob);

float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
if (lm_scale_ != 0.0) {
lm_prob /= lm_scale_; // remove lm-scale
if (lm_) { // export only when LM is used
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
if (lm_scale_ != 0.0) {
lm_prob /= lm_scale_; // remove lm-scale
}
new_hyp.lm_probs.push_back(lm_prob);
}
new_hyp.lm_probs.push_back(lm_prob);

new_hyp.context_scores.push_back(context_score);
// export only when `ContextGraph` is used
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
new_hyp.context_scores.push_back(context_score);
}
}

hyps.Add(std::move(new_hyp));
Expand Down

0 comments on commit e442564

Please sign in to comment.