diff --git a/python-api-examples/non_streaming_server.py b/python-api-examples/non_streaming_server.py index b2b804d4c..814bde81a 100755 --- a/python-api-examples/non_streaming_server.py +++ b/python-api-examples/non_streaming_server.py @@ -327,7 +327,7 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser): ) -def add_contexts_args(parser: argparse.ArgumentParser): +def add_hotwords_args(parser: argparse.ArgumentParser): parser.add_argument( "--bpe-model", type=str, @@ -337,25 +337,36 @@ def add_contexts_args(parser: argparse.ArgumentParser): Used only when --decoding-method=modified_beam_search """, ) + parser.add_argument( + "--tokens_type", + type=str, + default="cjkchar", + help=""" + The type of tokens (i.e the modeling unit). + Valid values are bpe, cjkchar+bpe, cjkchar. + """, + ) parser.add_argument( - "--modeling-unit", + "--hotwords-file", type=str, - default="char", + default="", help=""" - The type of modeling unit. - Valid values are bpe, bpe+char, char. - Note: the char here means characters in CJK languages. + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. For example: + + HELLO WORLD + 你 好 世 界 """, ) parser.add_argument( - "--context-score", + "--hotwords-score", type=float, default=1.5, help=""" - The context score of each token for biasing word/phrase. Used only if - --contexts is given. + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. 
""", ) @@ -376,7 +387,7 @@ def check_args(args): assert Path(args.decoder).is_file(), args.decoder assert Path(args.joiner).is_file(), args.joiner - if args.contexts != "": + if args.hotwords_file != "": assert args.decoding_method == "modified_beam_search", args.decoding_method @@ -388,7 +399,7 @@ def get_args(): add_model_args(parser) add_feature_config_args(parser) add_decoding_args(parser) - add_contexts_args(parser) + add_hotwords_args(parser) parser.add_argument( "--port", @@ -808,24 +819,6 @@ def assert_file_exists(filename: str): ) -def encode_contexts(args, contexts: List[str]) -> List[List[int]]: - sp = None - if "bpe" in args.modeling_unit: - assert_file_exists(args.bpe_model) - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - tokens = {} - with open(args.tokens, "r", encoding="utf-8") as f: - for line in f: - toks = line.strip().split() - assert len(toks) == 2, len(toks) - assert toks[0] not in tokens, f"Duplicate token: {toks} " - tokens[toks[0]] = int(toks[1]) - return sherpa_onnx.encode_contexts( - modeling_unit=args.modeling_unit, contexts=contexts, sp=sp, tokens_table=tokens - ) - - def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: if args.encoder: assert len(args.paraformer) == 0, args.paraformer @@ -848,7 +841,10 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: feature_dim=args.feat_dim, decoding_method=args.decoding_method, max_active_paths=args.max_active_paths, - context_score=args.context_score, + tokens_type=args.tokens_type, + bpe_model=args.bpe_model, + hotwords_file=args.hotwords_file, + hotwords_score=args.hotwords_score, ) elif args.paraformer: assert len(args.nemo_ctc) == 0, args.nemo_ctc diff --git a/python-api-examples/offline-decode-files.py b/python-api-examples/offline-decode-files.py index a3f07531f..7227dfd81 100755 --- a/python-api-examples/offline-decode-files.py +++ b/python-api-examples/offline-decode-files.py @@ -134,8 +134,8 @@ def get_args(): type=float, default=1.5, help=""" - The 
context score of each token for biasing word/phrase. Used only if - --contexts is given. + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. """, ) @@ -327,7 +327,7 @@ def main(): tokens_type=args.tokens_type, bpe_model=args.bpe_model, hotwords_file=args.hotwords_file, - hotwords_score=args.context_score, + hotwords_score=args.hotwords_score, debug=args.debug, ) elif args.paraformer: diff --git a/python-api-examples/online-decode-files.py b/python-api-examples/online-decode-files.py index 1a884cab6..38ff2b6ee 100755 --- a/python-api-examples/online-decode-files.py +++ b/python-api-examples/online-decode-files.py @@ -134,36 +134,35 @@ def get_args(): ) parser.add_argument( - "--modeling-unit", + "--tokens_type", type=str, - default="char", + default="cjkchar", help=""" - The type of modeling unit, it will be used to tokenize contexts biasing phrases. - Valid values are bpe, bpe+char, char. - Note: the char here means characters in CJK languages. - Used only when --decoding-method=modified_beam_search + The type of tokens (i.e the modeling unit). + Valid values are bpe, cjkchar+bpe, cjkchar. """, ) parser.add_argument( - "--contexts", + "--hotwords-file", type=str, default="", help=""" - The context list, it is a string containing some words/phrases separated - with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY". - Used only when --decoding-method=modified_beam_search + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. For example: + + HELLO WORLD + 你 好 世 界 """, ) parser.add_argument( - "--context-score", + "--hotwords-score", type=float, default=1.5, help=""" - The context score of each token for biasing word/phrase. Used only if - --contexts is given. - Used only when --decoding-method=modified_beam_search + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. 
""", ) @@ -214,27 +213,6 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: return samples_float32, f.getframerate() -def encode_contexts(args, contexts: List[str]) -> List[List[int]]: - sp = None - if "bpe" in args.modeling_unit: - assert_file_exists(args.bpe_model) - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - tokens = {} - with open(args.tokens, "r", encoding="utf-8") as f: - for line in f: - toks = line.strip().split() - assert len(toks) == 2, len(toks) - assert toks[0] not in tokens, f"Duplicate token: {toks} " - tokens[toks[0]] = int(toks[1]) - return sherpa_onnx.encode_contexts( - modeling_unit=args.modeling_unit, - contexts=contexts, - sp=sp, - tokens_table=tokens, - ) - - def main(): args = get_args() assert_file_exists(args.tokens) @@ -258,7 +236,10 @@ def main(): feature_dim=80, decoding_method=args.decoding_method, max_active_paths=args.max_active_paths, - hotwords_score=args.context_score, + tokens_type=args.tokens_type, + bpe_model=args.bpe_model, + hotwords_file=args.hotwords_file, + hotwords_score=args.hotwords_score, ) elif args.paraformer_encoder: recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer( @@ -277,12 +258,6 @@ def main(): print("Started!") start_time = time.time() - contexts_list = [] - contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()] - if contexts: - print(f"Contexts list: {contexts}") - contexts_list = encode_contexts(args, contexts) - streams = [] total_duration = 0 for wave_filename in args.sound_files: @@ -291,10 +266,7 @@ def main(): duration = len(samples) / sample_rate total_duration += duration - if contexts_list: - s = recognizer.create_stream(contexts_list=contexts_list) - else: - s = recognizer.create_stream() + s = recognizer.create_stream() s.accept_waveform(sample_rate, samples) diff --git a/python-api-examples/streaming_server.py b/python-api-examples/streaming_server.py index 33d4e5ee0..66d701b01 100755 --- a/python-api-examples/streaming_server.py +++ 
b/python-api-examples/streaming_server.py @@ -187,6 +187,51 @@ def add_decoding_args(parser: argparse.ArgumentParser): add_modified_beam_search_args(parser) +def add_hotwords_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--bpe-model", + type=str, + default="", + help=""" + Path to bpe.model, + Used only when --decoding-method=modified_beam_search + """, + ) + parser.add_argument( + "--tokens_type", + type=str, + default="cjkchar", + help=""" + The type of tokens (i.e the modeling unit). + Valid values are bpe, cjkchar+bpe, cjkchar. + """, + ) + + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. For example: + + HELLO WORLD + 你 好 世 界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. 
+ """, ) + + + def add_modified_beam_search_args(parser: argparse.ArgumentParser): parser.add_argument( "--num-active-paths", @@ -239,6 +284,7 @@ def get_args(): add_model_args(parser) add_decoding_args(parser) add_endpointing_args(parser) + add_hotwords_args(parser) parser.add_argument( "--port", @@ -343,6 +389,10 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: feature_dim=args.feat_dim, decoding_method=args.decoding_method, max_active_paths=args.num_active_paths, + tokens_type=args.tokens_type, + bpe_model=args.bpe_model, + hotwords_score=args.hotwords_score, + hotwords_file=args.hotwords_file, enable_endpoint_detection=args.use_endpoint != 0, rule1_min_trailing_silence=args.rule1_min_trailing_silence, rule2_min_trailing_silence=args.rule2_min_trailing_silence, diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index c040b5ec6..773c4859a 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -93,17 +93,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { const std::string &hotwords) const override { auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); std::istringstream is(hws); - int32_t default_hws_num = hotwords_.size(); - std::vector> tmp; + std::vector> current; if (!EncodeHotwords(is, config_.model_config.tokens_type, symbol_table_, - bpe_processor_, &tmp)) { + bpe_processor_, &current)) { SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", hotwords.c_str()); } - hotwords_.insert(hotwords_.end(), tmp.begin(), tmp.end()); + current.insert(current.end(), hotwords_.begin(), hotwords_.end()); + auto context_graph = - std::make_shared(hotwords_, config_.hotwords_score); - hotwords_.resize(default_hws_num); + std::make_shared(current, config_.hotwords_score); return std::make_unique(config_.feat_config, context_graph); } diff --git 
a/sherpa-onnx/csrc/online-recognizer-impl.h b/sherpa-onnx/csrc/online-recognizer-impl.h index 515c9d9e8..199f8440f 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-impl.h @@ -29,7 +29,7 @@ class OnlineRecognizerImpl { virtual std::unique_ptr CreateStream() const = 0; virtual std::unique_ptr CreateStream( - const std::vector> &contexts) const { + const std::string &hotwords) const { SHERPA_ONNX_LOGE("Only transducer models support contextual biasing."); exit(-1); } diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index c40202b5f..13528aee3 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -117,15 +117,18 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } std::unique_ptr CreateStream( - const std::vector> &contexts) const override { - // We create context_graph at this level, because we might have default - // context_graph(will be added later if needed) that belongs to the whole - // model rather than each stream. 
- std::vector> hotwords; - hotwords.insert(hotwords.end(), hotwords_.begin(), hotwords_.end()); - hotwords.insert(hotwords.end(), contexts.begin(), contexts.end()); + const std::string &hotwords) const override { + auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); + std::istringstream is(hws); + std::vector> current; + if (!EncodeHotwords(is, config_.model_config.tokens_type, sym_, + bpe_processor_, &current)) { + SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", + hotwords.c_str()); + } + current.insert(current.end(), hotwords_.begin(), hotwords_.end()); auto context_graph = - std::make_shared(hotwords, config_.hotwords_score); + std::make_shared(current, config_.hotwords_score); auto stream = std::make_unique(config_.feat_config, context_graph); InitOnlineStream(stream.get()); diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 4fbf3ae80..30121176c 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -113,8 +113,8 @@ std::unique_ptr OnlineRecognizer::CreateStream() const { } std::unique_ptr OnlineRecognizer::CreateStream( - const std::vector> &context_list) const { - return impl_->CreateStream(context_list); + const std::string &hotwords) const { + return impl_->CreateStream(hotwords); } bool OnlineRecognizer::IsReady(OnlineStream *s) const { diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index e7d600064..26b6b8d58 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -124,8 +124,7 @@ class OnlineRecognizer { std::unique_ptr CreateStream() const; // Create a stream with context phrases - std::unique_ptr CreateStream( - const std::vector> &context_list) const; + std::unique_ptr CreateStream(const std::string &hotwords) const; /** * Return true if the given stream has enough frames for decoding. 
diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index 0962e30a1..5f5621376 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -57,11 +57,10 @@ void PybindOnlineRecognizer(py::module *m) { [](const PyClass &self) { return self.CreateStream(); }) .def( "create_stream", - [](PyClass &self, - const std::vector> &contexts_list) { - return self.CreateStream(contexts_list); + [](PyClass &self, const std::string &hotwords) { + return self.CreateStream(hotwords); }, - py::arg("contexts_list")) + py::arg("hotwords")) .def("is_ready", &PyClass::IsReady) .def("decode_stream", &PyClass::DecodeStream) .def("decode_streams", diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index 20bc40e6a..6f0306e70 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -385,11 +385,11 @@ def from_tdnn_ctc( self.config = recognizer_config return self - def create_stream(self, contexts_list: Optional[List[List[int]]] = None): - if contexts_list is None: + def create_stream(self, hotwords: Optional[str] = None): + if hotwords is None: return self.recognizer.create_stream() else: - return self.recognizer.create_stream(contexts_list) + return self.recognizer.create_stream(hotwords) def decode_stream(self, s: OfflineStream): self.recognizer.decode_stream(s) diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 5c8ba8223..a1eb0d61c 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -45,6 +45,7 @@ def from_transducer( tokens_type: str = "cjkchar", bpe_model: str = "", hotwords_score: float = 1.5, + hotwords_file: str = "", provider: str = "cpu", model_type: str = "", ): @@ -253,11 +254,11 @@ def 
from_paraformer( self.config = recognizer_config return self - def create_stream(self, contexts_list: Optional[List[List[int]]] = None): - if contexts_list is None: + def create_stream(self, hotwords: Optional[str] = None): + if hotwords is None: return self.recognizer.create_stream() else: - return self.recognizer.create_stream(contexts_list) + return self.recognizer.create_stream(hotwords) def decode_stream(self, s: OnlineStream): self.recognizer.decode_stream(s)