diff --git a/egs/wenetspeech/KWS/zipformer/decode.py b/egs/wenetspeech/KWS/zipformer/decode.py index 84f55ac693..50316b4027 100755 --- a/egs/wenetspeech/KWS/zipformer/decode.py +++ b/egs/wenetspeech/KWS/zipformer/decode.py @@ -211,7 +211,7 @@ def decode_one_batch( model: nn.Module, lexicon: Lexicon, batch: dict, - kws_graph: ContextGraph, + keywords_graph: ContextGraph, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -272,7 +272,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - keywords_graph=kws_graph, + keywords_graph=keywords_graph, beam=params.beam_size, num_tailing_blanks=8, ) @@ -297,7 +297,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, lexicon: Lexicon, - kws_graph: ContextGraph, + keywords_graph: ContextGraph, keywords: Set[str], test_only_keywords: bool, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: @@ -343,7 +343,7 @@ def decode_dataset( params=params, model=model, lexicon=lexicon, - kws_graph=kws_graph, + keywords_graph=keywords_graph, batch=batch, ) @@ -561,10 +561,10 @@ def main(): keywords_thresholds.append(threshold) params.keywords_config = "".join(keywords_config) - kws_graph = ContextGraph( + keywords_graph = ContextGraph( context_score=params.keywords_score, ac_threshold=params.keywords_threshold ) - kws_graph.build( + keywords_graph.build( token_ids=token_ids, phrases=phrases, scores=keywords_scores, @@ -697,8 +697,8 @@ def remove_short_utt(c: Cut): test_sets = [] test_dls = [] if params.test_set == "large": - test_sets.append("cn_commands_large") - test_dls.append(cn_commands_large_dl) + test_sets += ["cn_commands_large", "test_net"] + test_dls += [cn_commands_large_dl, test_net_dl] else: assert params.test_set == "small", params.test_set test_sets += [ @@ -722,7 +722,7 @@ def remove_short_utt(c: Cut): params=params, model=model, lexicon=lexicon, - kws_graph=kws_graph, + keywords_graph=keywords_graph, keywords=keywords, test_only_keywords="test_net" not in test_set, )