diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 23fad3df11..0f80450712 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -103,11 +103,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } decoder_ = std::make_unique( - model_.get(), lm_.get(), config_.max_active_paths, - config_.lm_config.scale, unk_id_, config_.blank_penalty); + model_.get(), + lm_.get(), + config_.max_active_paths, + config_.lm_config.scale, + unk_id_, + config_.blank_penalty, + config_.temperature_scale); + } else if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique( - model_.get(), unk_id_, config_.blank_penalty); + model_.get(), + unk_id_, + config_.blank_penalty, + config_.temperature_scale); + } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config.decoding_method.c_str()); diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 8bd0c16ada..a7fdbdff39 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -96,6 +96,8 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { po->Register("decoding-method", &decoding_method, "decoding method," "now support greedy_search and modified_beam_search."); + po->Register("temperature-scale", &temperature_scale, + "Temperature scale for confidence computation in decoding."); } bool OnlineRecognizerConfig::Validate() const { @@ -142,7 +144,8 @@ std::string OnlineRecognizerConfig::ToString() const { os << "hotwords_score=" << hotwords_score << ", "; os << "hotwords_file=\"" << hotwords_file << "\", "; os << "decoding_method=\"" << decoding_method << "\", "; - os << "blank_penalty=" << blank_penalty << ")"; + os << "blank_penalty=" << blank_penalty << ", "; + os << "temperature_scale=" << temperature_scale << ")"; return os.str(); } diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index 308cb08f7d..d8503bd130 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -96,16 +96,23 @@ struct OnlineRecognizerConfig { float blank_penalty = 0.0; + float temperature_scale = 2.0; + OnlineRecognizerConfig() = default; OnlineRecognizerConfig( const FeatureExtractorConfig &feat_config, - const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config, + const OnlineModelConfig &model_config, + const OnlineLMConfig &lm_config, const EndpointConfig &endpoint_config, const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config, - bool enable_endpoint, const std::string &decoding_method, - int32_t max_active_paths, const std::string &hotwords_file, - float hotwords_score, float blank_penalty) + bool enable_endpoint, + const std::string &decoding_method, + int32_t max_active_paths, + const std::string &hotwords_file, + float hotwords_score, + float blank_penalty, + float temperature_scale) : feat_config(feat_config), model_config(model_config), lm_config(lm_config), @@ -114,9 +121,10 @@ struct OnlineRecognizerConfig { enable_endpoint(enable_endpoint), decoding_method(decoding_method), max_active_paths(max_active_paths), - hotwords_score(hotwords_score), hotwords_file(hotwords_file), - blank_penalty(blank_penalty) {} + hotwords_score(hotwords_score), + blank_penalty(blank_penalty), + temperature_scale(temperature_scale) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index d568026a6f..03447fc182 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -144,11 +144,9 @@ void OnlineTransducerGreedySearchDecoder::Decode( // export the per-token log scores if (y != 0 && y != unk_id_) { - // TODO(KarelVesely84): configure externally ? // apply temperature-scaling - float temperature_scale = 2.0; for (int32_t n = 0; n < vocab_size; ++n) { - p_logit[n] /= temperature_scale; + p_logit[n] /= temperature_scale_; } LogSoftmax(p_logit, vocab_size); // renormalize probabilities, // save time by doing it only for diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h index c68c32dcf0..716f884847 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h @@ -15,8 +15,13 @@ namespace sherpa_onnx { class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { public: OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model, - int32_t unk_id, float blank_penalty) - : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {} + int32_t unk_id, + float blank_penalty, + float temperature_scale) + : model_(model), + unk_id_(unk_id), + blank_penalty_(blank_penalty), + temperature_scale_(temperature_scale) {} OnlineTransducerDecoderResult GetEmptyResult() const override; @@ -29,6 +34,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { OnlineTransducerModel *model_; // Not owned int32_t unk_id_; float blank_penalty_; + float temperature_scale_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index ed5e05f89b..ea3f78f4bc 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -130,17 +130,17 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( float *p_logit = logit.GetTensorMutableData(); - // copy raw logits, apply temperature-scaling (for confidences) + // copy raw logits, apply temperature-scaling (for confidences) + // Note: temperature scaling is used only for the confidences, + // the decoding algorithm uses the original logits int32_t p_logit_items = vocab_size * num_hyps; std::vector logit_with_temperature(p_logit_items); { std::copy(p_logit, p_logit + p_logit_items, logit_with_temperature.begin()); - // TODO(KarelVesely84): configure externally ? - float temperature_scale = 2.0; for (float& elem : logit_with_temperature) { - elem /= temperature_scale; + elem /= temperature_scale_; } LogSoftmax(logit_with_temperature.data(), vocab_size, num_hyps); } diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h index 92e9a69c9d..839aa768a4 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h @@ -22,13 +22,15 @@ class OnlineTransducerModifiedBeamSearchDecoder OnlineLM *lm, int32_t max_active_paths, float lm_scale, int32_t unk_id, - float blank_penalty) + float blank_penalty, + float temperature_scale) : model_(model), lm_(lm), max_active_paths_(max_active_paths), lm_scale_(lm_scale), unk_id_(unk_id), - blank_penalty_(blank_penalty) {} + blank_penalty_(blank_penalty), + temperature_scale_(temperature_scale) {} OnlineTransducerDecoderResult GetEmptyResult() const override; @@ -50,6 +52,7 @@ class OnlineTransducerModifiedBeamSearchDecoder float lm_scale_; // used only when lm_ is not nullptr int32_t unk_id_; float blank_penalty_; + float temperature_scale_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index bd98c94e25..79f1546999 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -50,17 +50,30 @@ static void PybindOnlineRecognizerConfig(py::module *m) { using PyClass = OnlineRecognizerConfig; py::class_(*m, "OnlineRecognizerConfig") .def( - py::init(), - py::arg("feat_config"), py::arg("model_config"), + py::init(), + py::arg("feat_config"), + py::arg("model_config"), py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config") = EndpointConfig(), py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(), - py::arg("enable_endpoint"), py::arg("decoding_method"), - py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", - py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0) + py::arg("enable_endpoint"), + py::arg("decoding_method"), + py::arg("max_active_paths") = 4, + py::arg("hotwords_file") = "", + py::arg("hotwords_score") = 0, + py::arg("blank_penalty") = 0.0, + py::arg("temperature_scale") = 2.0) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("lm_config", &PyClass::lm_config) @@ -72,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { .def_readwrite("hotwords_file", &PyClass::hotwords_file) .def_readwrite("hotwords_score", &PyClass::hotwords_score) .def_readwrite("blank_penalty", &PyClass::blank_penalty) + .def_readwrite("temperature_scale", &PyClass::temperature_scale) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index a82ab1703b..5200000288 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -58,6 +58,7 @@ def from_transducer( model_type: str = "", lm: str = "", lm_scale: float = 0.1, + temperature_scale: float = 2.0, ): """ Please refer to @@ -123,6 +124,10 @@ def from_transducer( hotwords_score: The hotword score of each token for biasing word/phrase. Used only if hotwords_file is given with modified_beam_search as decoding method. + temperature_scale: + Temperature scaling for output symbol confidence estiamation. + It affects only confidence values, the decoding uses the original + logits without temperature. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. model_type: @@ -193,6 +198,7 @@ def from_transducer( hotwords_score=hotwords_score, hotwords_file=hotwords_file, blank_penalty=blank_penalty, + temperature_scale=temperature_scale, ) self.recognizer = _Recognizer(recognizer_config)