Commit
Fix wenetspeech prepare.sh
pkufool committed Feb 18, 2024
1 parent afe3b18 commit 7d91e8b
Showing 9 changed files with 65 additions and 129 deletions.
3 changes: 2 additions & 1 deletion egs/gigaspeech/KWS/prepare.sh
@@ -49,6 +49,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
pushd data
ln -svf $(realpath ../ASR/data/lang_bpe_500) .
popd
+ touch data/fbank/.gigaspeech.done
else
log "Gigaspeech dataset already exists, skipping."
fi
@@ -63,7 +64,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt
ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt
pushd open-commands
- ./script/prepare.sh --stage 3 --stop-stage 3
+ ./script/prepare.sh --stage 2 --stop-stage 2
./script/prepare.sh --stage 6 --stop-stage 6
popd
popd
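The added touch line records a completion marker so a re-run of the stage skips work that is already done, matching the "already exists, skipping" branch below it. A minimal sketch of the same idempotency pattern in Python (paths mirror the script; the preparation step itself is elided):

    from pathlib import Path

    done_marker = Path("data/fbank/.gigaspeech.done")
    if not done_marker.exists():
        # ... run the expensive dataset preparation here ...
        done_marker.touch()  # record completion so the next run skips this stage
    else:
        print("Gigaspeech dataset already exists, skipping.")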
29 changes: 11 additions & 18 deletions egs/gigaspeech/KWS/zipformer/decode.py
@@ -186,13 +186,6 @@ def get_parser():
help="The default threshold (probability) to trigger the keyword.",
)

- parser.add_argument(
-     "--keywords-version",
-     type=str,
-     default="",
-     help="The keywords configuration version, just to save results to different files.",
- )

parser.add_argument(
"--num-tailing-blanks",
type=int,
@@ -222,7 +215,7 @@ def decode_one_batch(
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
- kws_graph: Optional[ContextGraph] = None,
+ keywords_graph: Optional[ContextGraph] = None,
) -> List[List[Tuple[str, Tuple[int, int]]]]:
"""Decode one batch and return the result in a list.
@@ -242,7 +235,7 @@
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
- kws_graph:
+ keywords_graph:
The graph containing keywords.
Returns:
Return the decoding result. See above description for the format of
@@ -274,7 +267,7 @@
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
- keywords_graph=kws_graph,
+ keywords_graph=keywords_graph,
beam=params.beam,
num_tailing_blanks=params.num_tailing_blanks,
blank_penalty=params.blank_penalty,
@@ -295,7 +288,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
- kws_graph: ContextGraph,
+ keywords_graph: ContextGraph,
keywords: Set[str],
test_only_keywords: bool,
) -> Tuple[List[Tuple[str, List[str], List[str]]], KwMetric]:
@@ -310,7 +303,7 @@
The neural model.
sp:
The BPE model.
- kws_graph:
+ keywords_graph:
The graph containing keywords.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
@@ -341,7 +334,7 @@
params=params,
model=model,
sp=sp,
- kws_graph=kws_graph,
+ keywords_graph=keywords_graph,
batch=batch,
)

@@ -459,7 +452,7 @@ def save_results(
s += f"\tPrecision: {precision:.3f}\n"
s += f"\tRecall(PPR): {recall:.3f}\n"
s += f"\tFPR: {fpr:.3f}\n"
- s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n"
+ s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n"
if key != "all":
s += f"\tTP list: {' # '.join(item.TP_list)}\n"
s += f"\tFP list: {' # '.join(item.FP_list)}\n"
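The new F1 line guards against division by zero when a keyword has zero precision or zero recall. A standalone sketch of the same metric logic (the counts below are illustrative, not from the recipe):

    def f1_score(tp: int, fp: int, fn: int) -> float:
        # Harmonic mean of precision and recall, with the same zero guard
        # as the patched line above.
        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        recall = tp / (tp + fn) if tp + fn > 0 else 0.0
        if precision * recall == 0:
            return 0.0
        return 2 * precision * recall / (precision + recall)

    # A keyword that is never detected now yields 0.0 instead of crashing:
    assert f1_score(tp=0, fp=3, fn=5) == 0.0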
@@ -505,7 +498,7 @@ def main():
params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
if params.blank_penalty != 0:
params.suffix += f"-blank-penalty-{params.blank_penalty}"
- params.suffix += f"-version-{params.keywords_version}"
+ params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}"

setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@@ -555,10 +548,10 @@ def main():

params.keywords_config = "".join(keywords_config)

- kws_graph = ContextGraph(
+ keywords_graph = ContextGraph(
context_score=params.keywords_score, ac_threshold=params.keywords_threshold
)
- kws_graph.build(
+ keywords_graph.build(
token_ids=token_ids,
phrases=phrases,
scores=keywords_scores,
Expand Down Expand Up @@ -677,7 +670,7 @@ def main():
params=params,
model=model,
sp=sp,
- kws_graph=kws_graph,
+ keywords_graph=keywords_graph,
keywords=keywords,
test_only_keywords="fsc" in test_set,
)
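Throughout this file the commit renames kws_graph to keywords_graph so the local variable matches the keywords_graph parameter of keywords_search. The object itself is icefall's ContextGraph built from keyword token sequences; as a toy, self-contained illustration of the underlying keyword-trie idea (this is not icefall's implementation, and the scores/thresholds seen in the build call above are omitted):

    from typing import Dict, List, Optional

    class TrieNode:
        def __init__(self) -> None:
            self.children: Dict[int, "TrieNode"] = {}
            self.phrase: Optional[str] = None  # set on a keyword's final token

    def build_keyword_trie(phrases: List[str], token_ids: List[List[int]]) -> TrieNode:
        # Index each keyword's token sequence so decoding can advance
        # through the trie one emitted token at a time.
        root = TrieNode()
        for phrase, ids in zip(phrases, token_ids):
            node = root
            for t in ids:
                node = node.children.setdefault(t, TrieNode())
            node.phrase = phrase
        return root

    # Toy usage with made-up token ids for two keywords:
    root = build_keyword_trie(["HEY SNIPS", "OK GO"], [[7, 3], [9, 2]])
    node, matches = root, []
    for token in [7, 3]:  # tokens produced by the decoder
        node = node.children.get(token, root)
        if node.phrase is not None:
            matches.append(node.phrase)
    print(matches)  # ['HEY SNIPS']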
2 changes: 1 addition & 1 deletion egs/gigaspeech/KWS/zipformer/train.py
@@ -276,7 +276,7 @@ def get_parser():
return parser


- def add_model_arguments(parser: argparse.ArgumentParser):
+ def add_training_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--world-size",
type=int,
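The renamed function registers options that control the training run, such as --world-size, so add_training_arguments describes it better than add_model_arguments. A minimal sketch of the pattern (only --world-size appears in the diff; the second option is an assumed example):

    import argparse

    def add_training_arguments(parser: argparse.ArgumentParser) -> None:
        # Options that configure the training run rather than the model topology.
        parser.add_argument("--world-size", type=int, default=1,
                            help="Number of GPUs for DDP training.")
        parser.add_argument("--num-epochs", type=int, default=30,  # assumed example
                            help="Number of epochs to train.")

    parser = argparse.ArgumentParser()
    add_training_arguments(parser)
    print(parser.parse_args(["--world-size", "4"]).world_size)  # 4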
12 changes: 6 additions & 6 deletions egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -993,7 +993,7 @@ def keywords_search(
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
- assert context_graph is not None
+ assert keywords_graph is not None

packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
@@ -1018,7 +1018,7 @@ def keywords_search(
Hypothesis(
ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
- context_state=context_graph.root,
+ context_state=keywords_graph.root,
timestamp=[],
ac_probs=[],
)
@@ -1125,7 +1125,7 @@ def keywords_search(
context_score,
new_context_state,
_,
- ) = context_graph.forward_one_step(hyp.context_state, new_token)
+ ) = keywords_graph.forward_one_step(hyp.context_state, new_token)
new_num_tailing_blanks = 0
if new_context_state.token == -1: # root
new_ys[-context_size:] = [-1] * (context_size - 1) + [blank_id]
@@ -1143,7 +1143,7 @@ def keywords_search(
B[i].add(new_hyp)

top_hyp = B[i].get_most_probable(length_norm=True)
- matched, matched_state = context_graph.is_matched(top_hyp.context_state)
+ matched, matched_state = keywords_graph.is_matched(top_hyp.context_state)
if matched:
ac_prob = (
sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
@@ -1164,7 +1164,7 @@ def keywords_search(
Hypothesis(
ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
- context_state=context_graph.root,
+ context_state=keywords_graph.root,
timestamp=[],
ac_probs=[],
)
@@ -1174,7 +1174,7 @@

for i, hyps in enumerate(B):
top_hyp = hyps.get_most_probable(length_norm=True)
- matched, matched_state = context_graph.is_matched(top_hyp.context_state)
+ matched, matched_state = keywords_graph.is_matched(top_hyp.context_state)
if matched:
ac_prob = (
sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
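The surrounding code shows how a match is accepted: when the top hypothesis reaches a keyword's final state, the mean per-token acoustic probability over that keyword's matched_state.level tokens must clear a threshold. A standalone sketch of that gating (values are illustrative):

    from typing import List

    def passes_ac_threshold(ac_probs: List[float], level: int, threshold: float) -> bool:
        # Mirrors sum(top_hyp.ac_probs[-matched_state.level:]) / matched_state.level:
        # average the probabilities of the matched keyword's tokens only.
        avg = sum(ac_probs[-level:]) / level
        return avg >= threshold

    # A 3-token keyword whose tokens averaged 0.82, against a 0.75 threshold:
    print(passes_ac_threshold([0.4, 0.9, 0.8, 0.76], level=3, threshold=0.75))  # True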
1 change: 1 addition & 0 deletions egs/wenetspeech/ASR/prepare.sh
@@ -376,5 +376,6 @@ if [ $stage -le 22 ] && [ $stop_stage -ge 22 ]; then
--token-type $token \
--lang-dir $lang_dir
fi
+ python ./local/compile_lg.py --lang-dir $lang_dir
done
fi
74 changes: 40 additions & 34 deletions egs/wenetspeech/KWS/prepare.sh
@@ -22,63 +22,69 @@ log() {
}

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
- log "Stage 0: Prepare gigaspeech dataset."
+ log "Stage 0: Prepare wenetspeech dataset."
mkdir -p data/fbank
- if [ ! -e data/fbank/.gigaspeech.done ]; then
+ if [ ! -e data/fbank/.wenetspeech.done ]; then
pushd ../ASR
- ./prepare.sh --stage 0 --stop-stage 9
- ./prepare.sh --stage 11 --stop-stage 11
+ ./prepare.sh --stage 0 --stop-stage 17
+ ./prepare.sh --stage 22 --stop-stage 22
popd
pushd data/fbank
- ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_DEV.jsonl.gz) .
- ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_DEV.lca) .
- ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_TEST.jsonl.gz) .
- ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_TEST.lca) .
- ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_L.jsonl.gz) .
- ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_L.lca) .
- ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_M.jsonl.gz) .
- ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_M.lca) .
- ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_S.jsonl.gz) .
- ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_S.lca) .
- ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_XS.jsonl.gz) .
- ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_XS.lca) .
- ln -svf $(realpath ../ASR/data/fbank/XL_split) .
+ ln -svf $(realpath ../ASR/data/fbank/cuts_DEV.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/feats_DEV.lca) .
+ ln -svf $(realpath ../ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/feats_TEST_NET.lca) .
+ ln -svf $(realpath ../ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/feats_TEST_MEETING.lca) .
+ ln -svf $(realpath ../ASR/data/fbank/cuts_L.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/L_split_1000) .
+ ln -svf $(realpath ../ASR/data/fbank/cuts_M.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/M_split_1000) .
+ ln -svf $(realpath ../ASR/data/fbank/cuts_S.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/S_split_1000) .
+ ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/musan_feats) .
popd
pushd data
- ln -svf $(realpath ../ASR/data/lang_bpe_500) .
+ ln -svf $(realpath ../ASR/data/lang_partial_tone) .
popd
+ touch data/fbank/.wenetspeech.done
else
- log "Gigaspeech dataset already exists, skipping."
+ log "WenetSpeech dataset already exists, skipping."
fi
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare open commands dataset."
mkdir -p data/fbank
- if [ ! -e data/fbank/.fluent_speech_commands.done ]; then
+ if [ ! -e data/fbank/.cn_speech_commands.done ]; then
pushd data
git clone https://github.com/pkufool/open-commands.git
- ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt
- ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt
+ ln -svf $(realpath ./open-commands/CN/small/commands.txt) commands_small.txt
+ ln -svf $(realpath ./open-commands/CN/large/commands.txt) commands_large.txt
pushd open-commands
- ./script/prepare.sh --stage 3 --stop-stage 3
- ./script/prepare.sh --stage 6 --stop-stage 6
+ ./script/prepare.sh --stage 1 --stop-stage 1
+ ./script/prepare.sh --stage 3 --stop-stage 5
popd
popd
pushd data/fbank
- ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_large.jsonl.gz) .
- ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_large) .
- ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_small.jsonl.gz) .
- ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_small) .
- ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_valid.jsonl.gz) .
- ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_valid) .
- ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_train.jsonl.gz) .
- ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_train) .
+ ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_cuts_large.jsonl.gz) .
+ ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_feats_large) .
+ ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_cuts_small.jsonl.gz) .
+ ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_feats_small) .
+ ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_dev.jsonl.gz) .
+ ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_dev) .
+ ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_test.jsonl.gz) .
+ ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_test) .
+ ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_train.jsonl.gz) .
+ ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_train) .
+ ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_cuts_clean.jsonl.gz) .
+ ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_feats_clean.lca) .
+ ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_cuts_noisy.jsonl.gz) .
+ ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_feats_noisy.lca) .
popd
- touch data/fbank/.fluent_speech_commands.done
+ touch data/fbank/.cn_speech_commands.done
else
- log "Fluent speech commands dataset already exists, skipping."
+ log "CN speech commands dataset already exists, skipping."
fi
fi
2 changes: 1 addition & 1 deletion egs/wenetspeech/KWS/run.sh
@@ -37,7 +37,7 @@ fi

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Train a model."
- if [ ! -e data/fbank/.gigaspeech.done ]; then
+ if [ ! -e data/fbank/.wenetspeech.done ]; then
log "You need to run the prepare.sh first."
exit -1
fi
Expand Down
47 changes: 1 addition & 46 deletions egs/wenetspeech/KWS/zipformer/decode-asr.py
@@ -19,38 +19,7 @@
# limitations under the License.
"""
Usage:
- (1) greedy search
- ./zipformer/decode.py \
-     --epoch 35 \
-     --avg 15 \
-     --exp-dir ./zipformer/exp \
-     --lang-dir data/lang_char \
-     --max-duration 600 \
-     --decoding-method greedy_search
- (2) modified beam search
- ./zipformer/decode.py \
-     --epoch 35 \
-     --avg 15 \
-     --exp-dir ./zipformer/exp \
-     --lang-dir data/lang_char \
-     --max-duration 600 \
-     --decoding-method modified_beam_search \
-     --beam-size 4
- (3) fast beam search (trivial_graph)
- ./zipformer/decode.py \
-     --epoch 35 \
-     --avg 15 \
-     --exp-dir ./zipformer/exp \
-     --lang-dir data/lang_char \
-     --max-duration 600 \
-     --decoding-method fast_beam_search \
-     --beam 20.0 \
-     --max-contexts 8 \
-     --max-states 64
- (4) fast beam search (LG)
+ (1) fast beam search (LG)
./zipformer/decode.py \
--epoch 30 \
--avg 15 \
@@ -61,20 +30,6 @@
--beam 20.0 \
--max-contexts 8 \
--max-states 64
- (5) fast beam search (nbest oracle WER)
- ./zipformer/decode.py \
-     --epoch 35 \
-     --avg 15 \
-     --exp-dir ./zipformer/exp \
-     --lang-dir data/lang_char \
-     --max-duration 600 \
-     --decoding-method fast_beam_search_nbest_oracle \
-     --beam 20.0 \
-     --max-contexts 8 \
-     --max-states 64 \
-     --num-paths 200 \
-     --nbest-scale 0.5
"""


(The ninth changed file did not render on this page.)