Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Feb 19, 2024
1 parent 7d91e8b commit 8090385
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions egs/wenetspeech/KWS/zipformer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def decode_one_batch(
model: nn.Module,
lexicon: Lexicon,
batch: dict,
kws_graph: ContextGraph,
keywords_graph: ContextGraph,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
Expand Down Expand Up @@ -272,7 +272,7 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
keywords_graph=kws_graph,
keywords_graph=keywords_graph,
beam=params.beam_size,
num_tailing_blanks=8,
)
Expand All @@ -297,7 +297,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
kws_graph: ContextGraph,
keywords_graph: ContextGraph,
keywords: Set[str],
test_only_keywords: bool,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
Expand Down Expand Up @@ -343,7 +343,7 @@ def decode_dataset(
params=params,
model=model,
lexicon=lexicon,
kws_graph=kws_graph,
keywords_graph=keywords_graph,
batch=batch,
)

Expand Down Expand Up @@ -561,10 +561,10 @@ def main():
keywords_thresholds.append(threshold)
params.keywords_config = "".join(keywords_config)

kws_graph = ContextGraph(
keywords_graph = ContextGraph(
context_score=params.keywords_score, ac_threshold=params.keywords_threshold
)
kws_graph.build(
keywords_graph.build(
token_ids=token_ids,
phrases=phrases,
scores=keywords_scores,
Expand Down Expand Up @@ -697,8 +697,8 @@ def remove_short_utt(c: Cut):
test_sets = []
test_dls = []
if params.test_set == "large":
test_sets.append("cn_commands_large")
test_dls.append(cn_commands_large_dl)
test_sets += ["cn_commands_large", "test_net"]
test_dls += [cn_commands_large_dl, test_net_dl]
else:
assert params.test_set == "small", params.test_set
test_sets += [
Expand All @@ -722,7 +722,7 @@ def remove_short_utt(c: Cut):
params=params,
model=model,
lexicon=lexicon,
kws_graph=kws_graph,
keywords_graph=keywords_graph,
keywords=keywords,
test_only_keywords="test_net" not in test_set,
)
Expand Down

0 comments on commit 8090385

Please sign in to comment.