diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index 79f154699..c402163fe 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -42,6 +42,8 @@ static void PybindOnlineRecognizerResult(py::module *m) { "segment", [](PyClass &self) -> int32_t { return self.segment; }) .def_property_readonly( "is_final", [](PyClass &self) -> bool { return self.is_final; }) + .def("__str__", &PyClass::AsJsonString, + py::call_guard()) .def("as_json_string", &PyClass::AsJsonString, py::call_guard()); } @@ -50,29 +52,17 @@ static void PybindOnlineRecognizerConfig(py::module *m) { using PyClass = OnlineRecognizerConfig; py::class_(*m, "OnlineRecognizerConfig") .def( - py::init(), - py::arg("feat_config"), - py::arg("model_config"), + py::init(), + py::arg("feat_config"), py::arg("model_config"), py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config") = EndpointConfig(), py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(), - py::arg("enable_endpoint"), - py::arg("decoding_method"), - py::arg("max_active_paths") = 4, - py::arg("hotwords_file") = "", - py::arg("hotwords_score") = 0, - py::arg("blank_penalty") = 0.0, + py::arg("enable_endpoint"), py::arg("decoding_method"), + py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", + py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, py::arg("temperature_scale") = 2.0) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 2272fbd4c..36fb66826 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -12,9 +12,11 @@ from _sherpa_onnx import OnlineRecognizer as _Recognizer from _sherpa_onnx import ( OnlineRecognizerConfig, + OnlineRecognizerResult, OnlineStream, OnlineTransducerModelConfig, OnlineWenetCtcModelConfig, + OnlineNeMoCtcModelConfig, OnlineZipformer2CtcModelConfig, OnlineCtcFstDecoderConfig, ) @@ -645,6 +647,9 @@ def decode_streams(self, ss: List[OnlineStream]): def is_ready(self, s: OnlineStream) -> bool: return self.recognizer.is_ready(s) + def get_result_all(self, s: OnlineStream) -> OnlineRecognizerResult: + return self.recognizer.get_result(s) + def get_result(self, s: OnlineStream) -> str: return self.recognizer.get_result(s).text.strip()