diff --git a/egs/gigaspeech/KWS/prepare.sh b/egs/gigaspeech/KWS/prepare.sh
index d4f7445149..0b098190d5 100755
--- a/egs/gigaspeech/KWS/prepare.sh
+++ b/egs/gigaspeech/KWS/prepare.sh
@@ -49,6 +49,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
     pushd data
     ln -svf $(realpath ../ASR/data/lang_bpe_500) .
     popd
+    touch data/fbank/.gigaspeech.done
   else
     log "Gigaspeech dataset already exists, skipping."
   fi
@@ -63,7 +64,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
     ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt
     ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt
     pushd open-commands
-    ./script/prepare.sh --stage 3 --stop-stage 3
+    ./script/prepare.sh --stage 2 --stop-stage 2
     ./script/prepare.sh --stage 6 --stop-stage 6
     popd
     popd
diff --git a/egs/gigaspeech/KWS/zipformer/decode.py b/egs/gigaspeech/KWS/zipformer/decode.py
index 3c743b1537..98b0039370 100755
--- a/egs/gigaspeech/KWS/zipformer/decode.py
+++ b/egs/gigaspeech/KWS/zipformer/decode.py
@@ -186,13 +186,6 @@ def get_parser():
         help="The default threshold (probability) to trigger the keyword.",
     )
 
-    parser.add_argument(
-        "--keywords-version",
-        type=str,
-        default="",
-        help="The keywords configuration version, just to save results to different files.",
-    )
-
     parser.add_argument(
         "--num-tailing-blanks",
         type=int,
@@ -222,7 +215,7 @@ def decode_one_batch(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     batch: dict,
-    kws_graph: Optional[ContextGraph] = None,
+    keywords_graph: Optional[ContextGraph] = None,
 ) -> List[List[Tuple[str, Tuple[int, int]]]]:
     """Decode one batch and return the result in a list.
 
@@ -242,7 +235,7 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
-      kws_graph:
+      keywords_graph:
         The graph containing keywords.
     Returns:
       Return the decoding result. See above description for the format of
@@ -274,7 +267,7 @@ def decode_one_batch(
         model=model,
         encoder_out=encoder_out,
         encoder_out_lens=encoder_out_lens,
-        keywords_graph=kws_graph,
+        keywords_graph=keywords_graph,
         beam=params.beam,
         num_tailing_blanks=params.num_tailing_blanks,
         blank_penalty=params.blank_penalty,
@@ -295,7 +288,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
-    kws_graph: ContextGraph,
+    keywords_graph: ContextGraph,
     keywords: Set[str],
     test_only_keywords: bool,
 ) -> Tuple[List[Tuple[str, List[str], List[str]]], KwMetric]:
@@ -310,7 +303,7 @@ def decode_dataset(
       The neural model.
     sp:
       The BPE model.
-    kws_graph:
+    keywords_graph:
       The graph containing keywords.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
@@ -341,7 +334,7 @@ def decode_dataset(
             params=params,
             model=model,
             sp=sp,
-            kws_graph=kws_graph,
+            keywords_graph=keywords_graph,
             batch=batch,
         )
 
@@ -459,7 +452,7 @@ def save_results(
         s += f"\tPrecision: {precision:.3f}\n"
         s += f"\tRecall(PPR): {recall:.3f}\n"
         s += f"\tFPR: {fpr:.3f}\n"
-        s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n"
+        s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n"
         if key != "all":
             s += f"\tTP list: {' # '.join(item.TP_list)}\n"
             s += f"\tFP list: {' # '.join(item.FP_list)}\n"
@@ -505,7 +498,7 @@ def main():
         params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
     if params.blank_penalty != 0:
         params.suffix += f"-blank-penalty-{params.blank_penalty}"
-    params.suffix += f"-version-{params.keywords_version}"
+    params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}"
 
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
@@ -555,10 +548,10 @@ def main():
 
     params.keywords_config = "".join(keywords_config)
 
-    kws_graph = ContextGraph(
+    keywords_graph = ContextGraph(
         context_score=params.keywords_score, ac_threshold=params.keywords_threshold
     )
-    kws_graph.build(
+    keywords_graph.build(
         token_ids=token_ids,
         phrases=phrases,
         scores=keywords_scores,
@@ -677,7 +670,7 @@ def main():
             params=params,
             model=model,
             sp=sp,
-            kws_graph=kws_graph,
+            keywords_graph=keywords_graph,
             keywords=keywords,
             test_only_keywords="fsc" in test_set,
         )
diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py
index 9bcb09e966..e7387dd39f 100755
--- a/egs/gigaspeech/KWS/zipformer/train.py
+++ b/egs/gigaspeech/KWS/zipformer/train.py
@@ -276,7 +276,7 @@ def get_parser():
     return parser
 
 
-def add_model_arguments(parser: argparse.ArgumentParser):
+def add_training_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--world-size",
         type=int,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index d900b14b9d..66c84b2a94 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -993,7 +993,7 @@ def keywords_search(
     """
     assert encoder_out.ndim == 3, encoder_out.shape
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
-    assert context_graph is not None
+    assert keywords_graph is not None
 
     packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
         input=encoder_out,
@@ -1018,7 +1018,7 @@ def keywords_search(
             Hypothesis(
                 ys=[-1] * (context_size - 1) + [blank_id],
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
-                context_state=context_graph.root,
+                context_state=keywords_graph.root,
                 timestamp=[],
                 ac_probs=[],
             )
@@ -1125,7 +1125,7 @@ def keywords_search(
                     context_score,
                     new_context_state,
                     _,
-                ) = context_graph.forward_one_step(hyp.context_state, new_token)
+                ) = keywords_graph.forward_one_step(hyp.context_state, new_token)
                 new_num_tailing_blanks = 0
                 if new_context_state.token == -1:  # root
                     new_ys[-context_size:] = [-1] * (context_size - 1) + [blank_id]
@@ -1143,7 +1143,7 @@ def keywords_search(
                 B[i].add(new_hyp)
 
         top_hyp = B[i].get_most_probable(length_norm=True)
-        matched, matched_state = context_graph.is_matched(top_hyp.context_state)
+        matched, matched_state = keywords_graph.is_matched(top_hyp.context_state)
         if matched:
             ac_prob = (
                 sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
@@ -1164,7 +1164,7 @@ def keywords_search(
             Hypothesis(
                 ys=[-1] * (context_size - 1) + [blank_id],
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
-                context_state=context_graph.root,
+                context_state=keywords_graph.root,
                 timestamp=[],
                 ac_probs=[],
             )
@@ -1174,7 +1174,7 @@ def keywords_search(
     for i, hyps in enumerate(B):
         top_hyp = hyps.get_most_probable(length_norm=True)
-        matched, matched_state = context_graph.is_matched(top_hyp.context_state)
+        matched, matched_state = keywords_graph.is_matched(top_hyp.context_state)
         if matched:
             ac_prob = (
                 sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh
index 29dee97b06..543d19ce0c 100755
--- a/egs/wenetspeech/ASR/prepare.sh
+++ b/egs/wenetspeech/ASR/prepare.sh
@@ -376,5 +376,6 @@ if [ $stage -le 22 ] && [ $stop_stage -ge 22 ]; then
       --token-type $token \
      --lang-dir $lang_dir
   fi
+  python ./local/compile_lg.py --lang-dir $lang_dir
 done
 fi
diff --git a/egs/wenetspeech/KWS/prepare.sh b/egs/wenetspeech/KWS/prepare.sh
index d4f7445149..dcc65fab49 100755
--- a/egs/wenetspeech/KWS/prepare.sh
+++ b/egs/wenetspeech/KWS/prepare.sh
@@ -22,63 +22,69 @@ log() {
 }
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "Stage 0: Prepare gigaspeech dataset."
+  log "Stage 0: Prepare WenetSpeech dataset."
   mkdir -p data/fbank
-  if [ ! -e data/fbank/.gigaspeech.done ]; then
+  if [ ! -e data/fbank/.wenetspeech.done ]; then
     pushd ../ASR
-    ./prepare.sh --stage 0 --stop-stage 9
-    ./prepare.sh --stage 11 --stop-stage 11
+    ./prepare.sh --stage 0 --stop-stage 17
+    ./prepare.sh --stage 22 --stop-stage 22
     popd
     pushd data/fbank
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_DEV.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_DEV.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_TEST.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_TEST.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_L.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_L.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_M.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_M.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_S.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_S.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_XS.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_XS.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/XL_split) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_DEV.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/feats_DEV.lca) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/feats_TEST_NET.lca) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/feats_TEST_MEETING.lca) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_L.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/L_split_1000) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_M.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/M_split_1000) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_S.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/S_split_1000) .
     ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) .
     ln -svf $(realpath ../ASR/data/fbank/musan_feats) .
     popd
     pushd data
-    ln -svf $(realpath ../ASR/data/lang_bpe_500) .
+    ln -svf $(realpath ../ASR/data/lang_partial_tone) .
     popd
+    touch data/fbank/.wenetspeech.done
   else
-    log "Gigaspeech dataset already exists, skipping."
+    log "WenetSpeech dataset already exists, skipping."
   fi
 fi
 
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "Stage 1: Prepare open commands dataset."
   mkdir -p data/fbank
-  if [ ! -e data/fbank/.fluent_speech_commands.done ]; then
+  if [ ! -e data/fbank/.cn_speech_commands.done ]; then
     pushd data
     git clone https://github.com/pkufool/open-commands.git
-    ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt
-    ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt
+    ln -svf $(realpath ./open-commands/CN/small/commands.txt) commands_small.txt
+    ln -svf $(realpath ./open-commands/CN/large/commands.txt) commands_large.txt
     pushd open-commands
-    ./script/prepare.sh --stage 3 --stop-stage 3
-    ./script/prepare.sh --stage 6 --stop-stage 6
+    ./script/prepare.sh --stage 1 --stop-stage 1
+    ./script/prepare.sh --stage 3 --stop-stage 5
     popd
     popd
     pushd data/fbank
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_large.jsonl.gz) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_large) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_small.jsonl.gz) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_small) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_valid.jsonl.gz) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_valid) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_train.jsonl.gz) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_train) .
+    ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_cuts_large.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_feats_large) .
+    ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_cuts_small.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_feats_small) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_dev.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_dev) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_test.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_test) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_train.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_train) .
+    ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_cuts_clean.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_feats_clean.lca) .
+    ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_cuts_noisy.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_feats_noisy.lca) .
     popd
-    touch data/fbank/.fluent_speech_commands.done
+    touch data/fbank/.cn_speech_commands.done
   else
-    log "Fluent speech commands dataset already exists, skipping."
+    log "CN speech commands dataset already exists, skipping."
   fi
 fi
diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh
index e13a789647..971f54e29b 100644
--- a/egs/wenetspeech/KWS/run.sh
+++ b/egs/wenetspeech/KWS/run.sh
@@ -37,7 +37,7 @@ fi
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "Stage 0: Train a model."
-  if [ ! -e data/fbank/.gigaspeech.done ]; then
+  if [ ! -e data/fbank/.wenetspeech.done ]; then
     log "You need to run the prepare.sh first."
     exit -1
   fi
diff --git a/egs/wenetspeech/KWS/zipformer/decode-asr.py b/egs/wenetspeech/KWS/zipformer/decode-asr.py
index 56a4014d94..6425030eb7 100755
--- a/egs/wenetspeech/KWS/zipformer/decode-asr.py
+++ b/egs/wenetspeech/KWS/zipformer/decode-asr.py
@@ -19,38 +19,7 @@
 # limitations under the License.
 """
 Usage:
-(1) greedy search
-./zipformer/decode.py \
-    --epoch 35 \
-    --avg 15 \
-    --exp-dir ./zipformer/exp \
-    --lang-dir data/lang_char \
-    --max-duration 600 \
-    --decoding-method greedy_search
-
-(2) modified beam search
-./zipformer/decode.py \
-    --epoch 35 \
-    --avg 15 \
-    --exp-dir ./zipformer/exp \
-    --lang-dir data/lang_char \
-    --max-duration 600 \
-    --decoding-method modified_beam_search \
-    --beam-size 4
-
-(3) fast beam search (trivial_graph)
-./zipformer/decode.py \
-    --epoch 35 \
-    --avg 15 \
-    --exp-dir ./zipformer/exp \
-    --lang-dir data/lang_char \
-    --max-duration 600 \
-    --decoding-method fast_beam_search \
-    --beam 20.0 \
-    --max-contexts 8 \
-    --max-states 64
-
-(4) fast beam search (LG)
+(1) fast beam search (LG)
 ./zipformer/decode.py \
     --epoch 30 \
     --avg 15 \
@@ -61,20 +30,6 @@
     --beam 20.0 \
     --max-contexts 8 \
     --max-states 64
-
-(5) fast beam search (nbest oracle WER)
-./zipformer/decode.py \
-    --epoch 35 \
-    --avg 15 \
-    --exp-dir ./zipformer/exp \
-    --lang-dir data/lang_char \
-    --max-duration 600 \
-    --decoding-method fast_beam_search_nbest_oracle \
-    --beam 20.0 \
-    --max-contexts 8 \
-    --max-states 64 \
-    --num-paths 200 \
-    --nbest-scale 0.5
 """
diff --git a/egs/wenetspeech/KWS/zipformer/decode.py b/egs/wenetspeech/KWS/zipformer/decode.py
index 4d30cabc7d..84f55ac693 100755
--- a/egs/wenetspeech/KWS/zipformer/decode.py
+++ b/egs/wenetspeech/KWS/zipformer/decode.py
@@ -17,19 +17,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-Usage:
-(2) modified beam search
-./zipformer/decode.py \
-    --epoch 35 \
-    --avg 15 \
-    --exp-dir ./zipformer/exp \
-    --lang-dir data/lang_char \
-    --max-duration 600 \
-    --decoding-method modified_beam_search \
-    --beam-size 4
-"""
-
 import argparse
 import logging
@@ -207,13 +194,6 @@ def get_parser():
         help="The default threshold (probability) to trigger the keyword.",
     )
 
-    parser.add_argument(
-        "--keywords-version",
-        type=str,
-        default="",
-        help="The keywords configuration version, just to save results to different files.",
-    )
-
     parser.add_argument(
         "--num-tailing-blanks",
         type=int,
@@ -479,7 +459,7 @@ def save_results(
         s += f"\tPrecision: {precision:.3f}\n"
         s += f"\tRecall(PPR): {recall:.3f}\n"
         s += f"\tFPR: {fpr:.3f}\n"
-        s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n"
+        s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n"
         if key != "all":
             s += f"\tTP list: {' # '.join(item.TP_list)}\n"
             s += f"\tFP list: {' # '.join(item.FP_list)}\n"
@@ -525,7 +505,7 @@ def main():
         params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
     if params.blank_penalty != 0:
         params.suffix += f"-blank-penalty-{params.blank_penalty}"
-    params.suffix += f"-version-{params.keywords_version}"
+    params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}"
 
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
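
Note on the F1 guard added to `save_results` in both `decode.py` files: when a keyword never fires (or every firing is wrong), precision and recall can both be 0, and the old expression raised `ZeroDivisionError`. A minimal standalone sketch of the guarded computation; the helper name `f1_score` is ours, not the recipe's:

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall, defined as 0.0 whenever
    either term is 0; this also covers precision + recall == 0, the
    case that previously divided by zero."""
    if precision * recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# A keyword with no true positives now reports F1 = 0.0 instead of crashing.
assert f1_score(0.0, 0.0) == 0.0
assert abs(f1_score(0.8, 0.6) - 0.686) < 1e-3
```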
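Relatedly, dropping `--keywords-version` in favor of deriving the result-file suffix from `--keywords-file` means results for different keyword lists land in different files without the caller remembering to bump a version string. A sketch of the new suffix logic with hypothetical values standing in for `params.suffix` and `params.keywords_file`:

```python
# Hypothetical values; the real ones come from params in decode.py.
suffix = "epoch-30-avg-15"
keywords_file = "data/commands_small.txt"

# decode.py now appends the basename of --keywords-file instead of a
# free-form --keywords-version string.
suffix += f"-keywords-{keywords_file.split('/')[-1]}"
assert suffix == "epoch-30-avg-15-keywords-commands_small.txt"
```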
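For reviewers tracing the `keywords_search` hunks in `beam_search.py`: once the top hypothesis reaches a terminal state of `keywords_graph`, the trigger decision averages the per-token acoustic probabilities of just the matched keyword (`matched_state.level` is its length in tokens) and compares the average against the acoustic threshold. A toy, self-contained illustration with made-up numbers; the threshold value is hypothetical, not the recipe's default:

```python
# Hypothetical per-token acoustic probabilities for one hypothesis; the
# last `level` entries belong to the matched keyword.
ac_probs = [0.90, 0.20, 0.95, 0.85, 0.80]
level = 3  # token length of the matched keyword (matched_state.level)
keywords_threshold = 0.75  # hypothetical value for --keywords-threshold

# Same averaging as
# sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
ac_prob = sum(ac_probs[-level:]) / level
assert abs(ac_prob - 0.8667) < 1e-3
triggered = ac_prob >= keywords_threshold  # True: the keyword fires
```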