From 4771c9275ce354390c41f596dce5df500deda997 Mon Sep 17 00:00:00 2001 From: Peng He <34941901+kamirdin@users.noreply.github.com> Date: Fri, 13 Oct 2023 11:15:16 +0800 Subject: [PATCH] Add lm decode for the Python API. (#353) * Add lm decode for the Python API. * fix style. * Fix LogAdd, Shouldn't double lm_log_prob when merge same prefix path * sort the import alphabetically --- python-api-examples/online-decode-files.py | 20 +++++++++++++++++++ sherpa-onnx/csrc/hypothesis.cc | 5 ----- sherpa-onnx/python/csrc/online-recognizer.cc | 1 + .../python/sherpa_onnx/online_recognizer.py | 15 ++++++++++++++ 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/python-api-examples/online-decode-files.py b/python-api-examples/online-decode-files.py index cdf7870fb..c6606f94b 100755 --- a/python-api-examples/online-decode-files.py +++ b/python-api-examples/online-decode-files.py @@ -115,6 +115,24 @@ def get_args(): """, ) + parser.add_argument( + "--lm", + type=str, + default="", + help="""Used only when --decoding-method is modified_beam_search. + path of language model. + """, + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.1, + help="""Used only when --decoding-method is modified_beam_search. + scale of language model. + """, + ) + parser.add_argument( "--provider", type=str, @@ -215,6 +233,8 @@ def main(): feature_dim=80, decoding_method=args.decoding_method, max_active_paths=args.max_active_paths, + lm=args.lm, + lm_scale=args.lm_scale, hotwords_file=args.hotwords_file, hotwords_score=args.hotwords_score, ) diff --git a/sherpa-onnx/csrc/hypothesis.cc b/sherpa-onnx/csrc/hypothesis.cc index 55d2492b0..ea332bcb5 100644 --- a/sherpa-onnx/csrc/hypothesis.cc +++ b/sherpa-onnx/csrc/hypothesis.cc @@ -17,11 +17,6 @@ void Hypotheses::Add(Hypothesis hyp) { hyps_dict_[key] = std::move(hyp); } else { it->second.log_prob = LogAdd()(it->second.log_prob, hyp.log_prob); - - if (it->second.lm_log_prob != 0 && hyp.lm_log_prob != 0) { - it->second.lm_log_prob = - LogAdd()(it->second.lm_log_prob, hyp.lm_log_prob); - } } } diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index 68e97b60a..9cfce8456 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -37,6 +37,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { py::arg("hotwords_score") = 0) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) + .def_readwrite("lm_config", &PyClass::lm_config) .def_readwrite("endpoint_config", &PyClass::endpoint_config) .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) .def_readwrite("decoding_method", &PyClass::decoding_method) diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index eabf99ec8..c547c3166 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -5,6 +5,7 @@ from _sherpa_onnx import ( EndpointConfig, FeatureExtractorConfig, + OnlineLMConfig, OnlineModelConfig, OnlineParaformerModelConfig, OnlineRecognizer as _Recognizer, @@ -46,6 +47,8 @@ def from_transducer( hotwords_file: str = "", provider: str = "cpu", model_type: str = "", + lm: str = "", + lm_scale: float = 0.1, ): """ Please refer to @@ -137,10 +140,22 @@ def from_transducer( "Please use --decoding-method=modified_beam_search when using " f"--hotwords-file. Currently given: {decoding_method}" ) + + if lm and decoding_method != "modified_beam_search": + raise ValueError( + "Please use --decoding-method=modified_beam_search when using " + f"--lm. Currently given: {decoding_method}" + ) + + lm_config = OnlineLMConfig( + model=lm, + scale=lm_scale, + ) recognizer_config = OnlineRecognizerConfig( feat_config=feat_config, model_config=model_config, + lm_config=lm_config, endpoint_config=endpoint_config, enable_endpoint=enable_endpoint_detection, decoding_method=decoding_method,