From e53eae7f014a4561afddbd844501c9df850a39c5 Mon Sep 17 00:00:00 2001 From: JinZr Date: Mon, 26 Jun 2023 17:22:02 +0800 Subject: [PATCH 01/60] Init commit for swbd --- egs/swbd/ASR/README.md | 29 + egs/swbd/ASR/RESULTS.md | 50 + egs/swbd/ASR/conformer_ctc/__init__.py | 0 egs/swbd/ASR/conformer_ctc/ali.py | 395 ++++++++ egs/swbd/ASR/conformer_ctc/asr_datamodule.py | 410 ++++++++ egs/swbd/ASR/conformer_ctc/conformer.py | 910 +++++++++++++++++ egs/swbd/ASR/conformer_ctc/decode.py | 815 +++++++++++++++ egs/swbd/ASR/conformer_ctc/export.py | 163 +++ egs/swbd/ASR/conformer_ctc/label_smoothing.py | 109 ++ egs/swbd/ASR/conformer_ctc/pretrained.py | 430 ++++++++ egs/swbd/ASR/conformer_ctc/subsampling.py | 153 +++ .../ASR/conformer_ctc/test_label_smoothing.py | 52 + .../ASR/conformer_ctc/test_subsampling.py | 48 + .../ASR/conformer_ctc/test_transformer.py | 104 ++ egs/swbd/ASR/conformer_ctc/train.py | 813 +++++++++++++++ egs/swbd/ASR/conformer_ctc/transformer.py | 928 ++++++++++++++++++ egs/swbd/ASR/local/compile_hlg.py | 167 ++++ egs/swbd/ASR/local/compile_lg.py | 147 +++ egs/swbd/ASR/local/compute_fbank_swbd.py | 158 +++ .../convert_transcript_words_to_tokens.py | 103 ++ egs/swbd/ASR/local/dict.patch | 380 +++++++ .../ASR/local/display_manifest_statistics.py | 211 ++++ egs/swbd/ASR/local/eval2000_data_prep.sh | 120 +++ egs/swbd/ASR/local/extend_segments.pl | 99 ++ egs/swbd/ASR/local/filter_cuts.py | 160 +++ egs/swbd/ASR/local/filter_empty_text.py | 71 ++ egs/swbd/ASR/local/format_acronyms_dict.py | 118 +++ egs/swbd/ASR/local/generate_unique_lexicon.py | 98 ++ .../ASR/local/map_acronyms_transcripts.py | 60 ++ egs/swbd/ASR/local/prepare_lang.py | 413 ++++++++ egs/swbd/ASR/local/prepare_lang_bpe.py | 277 ++++++ .../ASR/local/prepare_lm_training_data.py | 167 ++++ egs/swbd/ASR/local/prepare_swbd_testsets.py | 161 +++ egs/swbd/ASR/local/rt03_data_prep.sh | 107 ++ egs/swbd/ASR/local/sort_lm_training_data.py | 141 +++ egs/swbd/ASR/local/swbd1_data_prep.sh | 132 +++ egs/swbd/ASR/local/swbd1_map_words.pl | 52 + egs/swbd/ASR/local/swbd1_prepare_dict.sh | 92 ++ egs/swbd/ASR/local/train_bpe_model.py | 100 ++ egs/swbd/ASR/local/validate_bpe_lexicon.py | 77 ++ egs/swbd/ASR/prepare.sh | 403 ++++++++ egs/swbd/ASR/shared | 1 + egs/swbd/ASR/utils/filter_scp.pl | 87 ++ egs/swbd/ASR/utils/fix_data_dir.sh | 197 ++++ egs/swbd/ASR/utils/parse_options.sh | 97 ++ egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl | 27 + egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl | 38 + 47 files changed, 9870 insertions(+) create mode 100644 egs/swbd/ASR/README.md create mode 100644 egs/swbd/ASR/RESULTS.md create mode 100644 egs/swbd/ASR/conformer_ctc/__init__.py create mode 100755 egs/swbd/ASR/conformer_ctc/ali.py create mode 100644 egs/swbd/ASR/conformer_ctc/asr_datamodule.py create mode 100644 egs/swbd/ASR/conformer_ctc/conformer.py create mode 100755 egs/swbd/ASR/conformer_ctc/decode.py create mode 100755 egs/swbd/ASR/conformer_ctc/export.py create mode 100644 egs/swbd/ASR/conformer_ctc/label_smoothing.py create mode 100755 egs/swbd/ASR/conformer_ctc/pretrained.py create mode 100644 egs/swbd/ASR/conformer_ctc/subsampling.py create mode 100755 egs/swbd/ASR/conformer_ctc/test_label_smoothing.py create mode 100755 egs/swbd/ASR/conformer_ctc/test_subsampling.py create mode 100644 egs/swbd/ASR/conformer_ctc/test_transformer.py create mode 100755 egs/swbd/ASR/conformer_ctc/train.py create mode 100644 egs/swbd/ASR/conformer_ctc/transformer.py create mode 100755 egs/swbd/ASR/local/compile_hlg.py create mode 100755 egs/swbd/ASR/local/compile_lg.py create mode 100755 egs/swbd/ASR/local/compute_fbank_swbd.py create mode 100755 egs/swbd/ASR/local/convert_transcript_words_to_tokens.py create mode 100644 egs/swbd/ASR/local/dict.patch create mode 100755 egs/swbd/ASR/local/display_manifest_statistics.py create mode 100755 egs/swbd/ASR/local/eval2000_data_prep.sh create mode 100755 egs/swbd/ASR/local/extend_segments.pl create mode 100755 egs/swbd/ASR/local/filter_cuts.py create mode 100755 egs/swbd/ASR/local/filter_empty_text.py create mode 100755 egs/swbd/ASR/local/format_acronyms_dict.py create mode 100755 egs/swbd/ASR/local/generate_unique_lexicon.py create mode 100755 egs/swbd/ASR/local/map_acronyms_transcripts.py create mode 100755 egs/swbd/ASR/local/prepare_lang.py create mode 100755 egs/swbd/ASR/local/prepare_lang_bpe.py create mode 100755 egs/swbd/ASR/local/prepare_lm_training_data.py create mode 100755 egs/swbd/ASR/local/prepare_swbd_testsets.py create mode 100755 egs/swbd/ASR/local/rt03_data_prep.sh create mode 100755 egs/swbd/ASR/local/sort_lm_training_data.py create mode 100755 egs/swbd/ASR/local/swbd1_data_prep.sh create mode 100755 egs/swbd/ASR/local/swbd1_map_words.pl create mode 100755 egs/swbd/ASR/local/swbd1_prepare_dict.sh create mode 100755 egs/swbd/ASR/local/train_bpe_model.py create mode 100755 egs/swbd/ASR/local/validate_bpe_lexicon.py create mode 100755 egs/swbd/ASR/prepare.sh create mode 120000 egs/swbd/ASR/shared create mode 100755 egs/swbd/ASR/utils/filter_scp.pl create mode 100755 egs/swbd/ASR/utils/fix_data_dir.sh create mode 100755 egs/swbd/ASR/utils/parse_options.sh create mode 100755 egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl create mode 100755 egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl diff --git a/egs/swbd/ASR/README.md b/egs/swbd/ASR/README.md new file mode 100644 index 0000000000..b9bbab787a --- /dev/null +++ b/egs/swbd/ASR/README.md @@ -0,0 +1,29 @@ +# Switchboard + +The Switchboard-1 Telephone Speech Corpus (LDC97S62) consists of approximately 260 hours of speech and was originally collected by Texas Instruments in 1990-1, under DARPA sponsorship. The first release of the corpus was published by NIST and distributed by the LDC in 1992-3. Since that release, a number of corrections have been made to the data files as presented on the original CD-ROM set and all copies of the first pressing have been distributed. + +Switchboard is a collection of about 2,400 two-sided telephone conversations among 543 speakers (302 male, 241 female) from all areas of the United States. A computer-driven robot operator system handled the calls, giving the caller appropriate recorded prompts, selecting and dialing another person (the callee) to take part in a conversation, introducing a topic for discussion and recording the speech from the two subjects into separate channels until the conversation was finished. About 70 topics were provided, of which about 50 were used frequently. Selection of topics and callees was constrained so that: (1) no two speakers would converse together more than once and (2) no one spoke more than once on a given topic. + +(The above introduction is from the [LDC Switchboard-1 Release 2 webpage](https://catalog.ldc.upenn.edu/LDC97S62).) + +**Caution**: The `conformer_ctc` recipe for Switchboard is currently very rough and has a high Word Error Rate, requiring more improvement and refinement. The TODO list for this recipe is as follows. + +## TODO List +- [ ] Incorporate Lhotse for dataprocessing +- [ ] Refer to Global Mapping Rules when computing Word Error Rate +- [ ] Detailed Word Error Rate summary for eval2000 (callhome, swbd) and rt03 (fsh, swbd) testset +- [ ] Switchboard transcript train/dev split for LM training +- [ ] Fisher corpus LDC2004T19 LDC2005T19 LDC2004S13 LDC2005S13 for LM training + +## Performance Record +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 33.37 | 35.06 | + +See [RESULTS](/egs/swbd/ASR/RESULTS.md) for details. + +## Credit + +The training script for `conformer_ctc` comes from the LibriSpeech `conformer_ctc` recipe in icefall. + +A lot of the scripts for data processing are from the first-gen Kaldi and the ESPNet project, tailored to incorporate with Lhotse and icefall. diff --git a/egs/swbd/ASR/RESULTS.md b/egs/swbd/ASR/RESULTS.md new file mode 100644 index 0000000000..ef8216f036 --- /dev/null +++ b/egs/swbd/ASR/RESULTS.md @@ -0,0 +1,50 @@ +## Results +### Switchboard BPE training results (Conformer-CTC) + +#### 2023-06-26 + +The best WER, as of 2023-06-26, for the Switchboard is below + +Results using HLG decoding + n-gram LM rescoring + attention decoder rescoring: + +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 33.37 | 35.06 | + +Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: + +##### eval2000 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.3 | 2.5 | + +##### rt03 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.7 | 1.3 | + +To reproduce the above result, use the following commands for training: + +```bash +cd egs/swbd/ASR +./prepare.sh +export CUDA_VISIBLE_DEVICES="0,1" +./conformer_ctc/train.py \ + --max-duration 120 \ + --num-workers 8 \ + --enable-musan False \ + --world-size 2 \ +``` + +and the following command for decoding: + +```bash +./conformer_ctc/decode.py \ + --epoch 55 \ + --avg 1 \ + --max-duration 50 +``` + + diff --git a/egs/swbd/ASR/conformer_ctc/__init__.py b/egs/swbd/ASR/conformer_ctc/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/egs/swbd/ASR/conformer_ctc/ali.py b/egs/swbd/ASR/conformer_ctc/ali.py new file mode 100755 index 0000000000..42e14abac9 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/ali.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + ./conformer_ctc/ali.py \ + --exp-dir ./conformer_ctc/exp \ + --lang-dir ./data/lang_bpe_500 \ + --epoch 20 \ + --avg 10 \ + --max-duration 300 \ + --dataset train-clean-100 \ + --out-dir data/ali +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import numpy as np +import torch +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from lhotse import CutSet +from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.decode import one_best_decoding +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + encode_supervisions, + get_alignments, + setup_logger, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=34, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--out-dir", + type=str, + required=True, + help="""Output directory. + It contains 3 generated files: + + - labels_xxx.h5 + - aux_labels_xxx.h5 + - librispeech_cuts_xxx.jsonl.gz + + where xxx is the value of `--dataset`. For instance, if + `--dataset` is `train-clean-100`, it will contain 3 files: + + - `labels_train-clean-100.h5` + - `aux_labels_train-clean-100.h5` + - `librispeech_cuts_train-clean-100.jsonl.gz` + + Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise + alignment. The difference is that labels_xxx.h5 contains repeats. + """, + ) + + parser.add_argument( + "--dataset", + type=str, + required=True, + help="""The name of the dataset to compute alignments for. + Possible values are: + - test-clean. + - test-other + - train-clean-100 + - train-clean-360 + - train-other-500 + - dev-clean + - dev-other + """, + ) + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "lm_dir": Path("data/lm"), + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "subsampling_factor": 4, + # Set it to 0 since attention decoder + # is not used for computing alignments + "num_decoder_layers": 0, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "output_beam": 10, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def compute_alignments( + model: torch.nn.Module, + dl: torch.utils.data.DataLoader, + labels_writer: FeaturesWriter, + aux_labels_writer: FeaturesWriter, + params: AttributeDict, + graph_compiler: BpeCtcTrainingGraphCompiler, +) -> CutSet: + """Compute the framewise alignments of a dataset. + + Args: + model: + The neural network model. + dl: + Dataloader containing the dataset. + params: + Parameters for computing alignments. + graph_compiler: + It converts token IDs to decoding graphs. + Returns: + Return a CutSet. Each cut has two custom fields: labels_alignment + and aux_labels_alignment, containing framewise alignments information. + Both are of type `lhotse.array.TemporalArray`. The difference between + the two alignments is that `labels_alignment` contain repeats. + """ + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + num_cuts = 0 + + device = graph_compiler.device + cuts = [] + for batch_idx, batch in enumerate(dl): + feature = batch["inputs"] + + # at entry, feature is [N, T, C] + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cut_list = supervisions["cut"] + + for cut in cut_list: + assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}" + + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is [N, T, C] + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + # we need also to sort cut_ids as encode_supervisions() + # reorders "texts". + # In general, new2old is an identity map since lhotse sorts the returned + # cuts by duration in descending order + new2old = supervision_segments[:, 0].tolist() + + cut_list = [cut_list[i] for i in new2old] + + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + lattice = k2.intersect_dense( + decoding_graph, + dense_fsa_vec, + params.output_beam, + ) + + best_path = one_best_decoding( + lattice=lattice, + use_double_scores=params.use_double_scores, + ) + + labels_ali = get_alignments(best_path, kind="labels") + aux_labels_ali = get_alignments(best_path, kind="aux_labels") + assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) + for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali): + cut.labels_alignment = labels_writer.store_array( + key=cut.id, + value=np.asarray(labels, dtype=np.int32), + # frame shift is 0.01s, subsampling_factor is 4 + frame_shift=0.04, + temporal_dim=0, + start=0, + ) + cut.aux_labels_alignment = aux_labels_writer.store_array( + key=cut.id, + value=np.asarray(aux_labels, dtype=np.int32), + # frame shift is 0.01s, subsampling_factor is 4 + frame_shift=0.04, + temporal_dim=0, + start=0, + ) + + cuts += cut_list + + num_cuts += len(cut_list) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return CutSet.from_cuts(cuts) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + args.enable_spec_aug = False + args.enable_musan = False + args.return_cuts = True + args.concatenate_cuts = False + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-ali") + + logging.info(f"Computing alignments for {params.dataset} - started") + logging.info(params) + + out_dir = Path(params.out_dir) + out_dir.mkdir(exist_ok=True) + + out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" + out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" + out_manifest_filename = out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" + + for f in ( + out_labels_ali_filename, + out_aux_labels_ali_filename, + out_manifest_filename, + ): + if f.exists(): + logging.info(f"{f} exists - skipping") + return + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + model.to(device) + + if params.avg == 1: + load_checkpoint( + f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False + ) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + + model.eval() + + librispeech = LibriSpeechAsrDataModule(args) + if params.dataset == "test-clean": + test_clean_cuts = librispeech.test_clean_cuts() + dl = librispeech.test_dataloaders(test_clean_cuts) + elif params.dataset == "test-other": + test_other_cuts = librispeech.test_other_cuts() + dl = librispeech.test_dataloaders(test_other_cuts) + elif params.dataset == "train-clean-100": + train_clean_100_cuts = librispeech.train_clean_100_cuts() + dl = librispeech.train_dataloaders(train_clean_100_cuts) + elif params.dataset == "train-clean-360": + train_clean_360_cuts = librispeech.train_clean_360_cuts() + dl = librispeech.train_dataloaders(train_clean_360_cuts) + elif params.dataset == "train-other-500": + train_other_500_cuts = librispeech.train_other_500_cuts() + dl = librispeech.train_dataloaders(train_other_500_cuts) + elif params.dataset == "dev-clean": + dev_clean_cuts = librispeech.dev_clean_cuts() + dl = librispeech.valid_dataloaders(dev_clean_cuts) + else: + assert params.dataset == "dev-other", f"{params.dataset}" + dev_other_cuts = librispeech.dev_other_cuts() + dl = librispeech.valid_dataloaders(dev_other_cuts) + + logging.info(f"Processing {params.dataset}") + with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer: + with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer: + cut_set = compute_alignments( + model=model, + dl=dl, + labels_writer=labels_writer, + aux_labels_writer=aux_labels_writer, + params=params, + graph_compiler=graph_compiler, + ) + + cut_set.to_file(out_manifest_filename) + + logging.info( + f"For dataset {params.dataset}, its alignments with repeats are " + f"saved to {out_labels_ali_filename}, the alignments without repeats " + f"are saved to {out_aux_labels_ali_filename}, and the cut manifest " + f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}" + ) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py new file mode 100644 index 0000000000..c36b8727f5 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -0,0 +1,410 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class SwitchBoardAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train dataloader, + but there can be multiple test dataloaders (e.g. SwitchBoard rt03 + and eval2000). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_all_cuts(self) -> CutSet: + logging.info("switchboard: About to get train cuts") + return load_manifest_lazy(self.args.manifest_dir / "swbd_cuts_all.jsonl.gz").subset(last=2388).trim_to_supervisions(keep_all_channels=True) + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("switchboard: About to get dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "swbd_cuts_all.jsonl.gz").subset(first=50).trim_to_supervisions(keep_all_channels=True) + + @lru_cache() + def test_eval2000_cuts(self) -> CutSet: + logging.info("switchboard: About to get eval2000 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "swbd_cuts_eval2000.jsonl.gz" + ) + + @lru_cache() + def test_rt03_cuts(self) -> CutSet: + logging.info("switchboard: About to get rt03 cuts") + return load_manifest_lazy(self.args.manifest_dir / "swbd_cuts_rt03.jsonl.gz") diff --git a/egs/swbd/ASR/conformer_ctc/conformer.py b/egs/swbd/ASR/conformer_ctc/conformer.py new file mode 100644 index 0000000000..a1cfe6e75b --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/conformer.py @@ -0,0 +1,910 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from typing import Optional, Tuple, Union + +import torch +from torch import Tensor, nn +from transformer import Supervisions, Transformer, encoder_padding_mask + + +class Conformer(Transformer): + """ + Args: + num_features (int): Number of input features + num_classes (int): Number of output classes + subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) + d_model (int): attention dimension + nhead (int): number of head + dim_feedforward (int): feedforward dimention + num_encoder_layers (int): number of encoder layers + num_decoder_layers (int): number of decoder layers + dropout (float): dropout rate + cnn_module_kernel (int): Kernel size of convolution module + normalize_before (bool): whether to use layer_norm before the first block. + vgg_frontend (bool): whether to use vgg frontend. + """ + + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + vgg_frontend: bool = False, + use_feat_batchnorm: Union[float, bool] = 0.1, + ) -> None: + super(Conformer, self).__init__( + num_features=num_features, + num_classes=num_classes, + subsampling_factor=subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dropout=dropout, + normalize_before=normalize_before, + vgg_frontend=vgg_frontend, + use_feat_batchnorm=use_feat_batchnorm, + ) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + use_conv_batchnorm = True + if isinstance(use_feat_batchnorm, float): + use_conv_batchnorm = False + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + cnn_module_kernel, + normalize_before, + use_conv_batchnorm, + ) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self.normalize_before = normalize_before + if self.normalize_before: + self.after_norm = nn.LayerNorm(d_model) + else: + # Note: TorchScript detects that self.after_norm could be used inside forward() + # and throws an error without this change. + self.after_norm = identity + + def run_encoder( + self, x: Tensor, supervisions: Optional[Supervisions] = None + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. + + Returns: + Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). + Tensor: Mask tensor of dimension (batch_size, input_length) + """ + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + mask = encoder_padding_mask(x.size(0), supervisions) + if mask is not None: + mask = mask.to(x.device) + x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) + + if self.normalize_before: + x = self.after_norm(x) + + return x, mask + + +class ConformerEncoderLayer(nn.Module): + """ + ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + normalize_before: whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + use_conv_batchnorm: bool = False, + ) -> None: + super(ConformerEncoderLayer, self).__init__() + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.feed_forward_macaron = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm + ) + + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + + self.ff_scale = 0.5 + + self.norm_conv = nn.LayerNorm(d_model) # for the CNN module + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + + self.dropout = nn.Dropout(dropout) + + self.normalize_before = normalize_before + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + + # macaron style feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff_macaron(src) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + if not self.normalize_before: + src = self.norm_ff_macaron(src) + + # multi-headed self-attention module + residual = src + if self.normalize_before: + src = self.norm_mha(src) + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout(src_att) + if not self.normalize_before: + src = self.norm_mha(src) + + # convolution module + residual = src + if self.normalize_before: + src = self.norm_conv(src) + src = residual + self.dropout( + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + ) + if not self.normalize_before: + src = self.norm_conv(src) + + # feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff(src) + src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) + if not self.normalize_before: + src = self.norm_ff(src) + + if self.normalize_before: + src = self.norm_final(src) + + return src + + +class ConformerEncoder(nn.TransformerEncoder): + r"""ConformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ConformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = conformer_encoder(src, pos_emb) + """ + + def __init__( + self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None + ) -> None: + super(ConformerEncoder, self).__init__( + encoder_layer=encoder_layer, num_layers=num_layers, norm=norm + ) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + for mod in self.layers: + output = mod( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor) -> None: + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vector and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x.size(1) + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(1), + ] + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + # linear transformation for positional encoding. + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.weight, + self.in_proj.bias, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + ) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + assert n == 2 * time1 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time1), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + q_with_bias_u = (q + self.pos_bias_u).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self.pos_bias_v).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p.transpose(-2, -1) + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd) + + attn_output_weights = ( + matrix_ac + matrix_bd + ) * scaling # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + bias: bool = True, + use_batchnorm: bool = False, + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + self.use_batchnorm = use_batchnorm + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + if self.use_batchnorm: + self.norm = nn.BatchNorm1d(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + x = self.depthwise_conv(x) + if self.use_batchnorm: + x = self.norm(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def identity(x): + return x diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py new file mode 100755 index 0000000000..a97936af84 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/decode.py @@ -0,0 +1,815 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import SwitchBoardAsrDataModule +from conformer import Conformer + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_rnn_lm, + rescore_with_whole_lattice, +) +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel +from icefall.utils import ( + AttributeDict, + get_texts, + load_averaged_model, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=77, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=55, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--method", + type=str, + default="attention-decoder", + help="""Decoding method. + Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (2) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (3) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + - (5) attention-decoder. Extract n paths from the LM rescored + lattice, the path with the highest score is the decoding result. + - (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume + you have trained an RNN LM using ./rnn_lm/train.py + - (7) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir", + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + parser.add_argument( + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is rnn-lm. + It specifies the path to RNN LM exp dir. + """, + ) + + parser.add_argument( + "--rnn-lm-epoch", + type=int, + default=7, + help="""Used only when --method is rnn-lm. + It specifies the checkpoint to use. + """, + ) + + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is rnn-lm. + It specifies the number of checkpoints to average. + """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_decoder_layers": 6, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + rnn_lm_model: Optional[nn.Module], + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding without LM rescoring. + - params.method is "nbest", it uses nbest decoding without LM rescoring. + - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + rnn_lm_model: + The neural model for RNN LM. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + + nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.method in ["1best", "nbest"]: + if params.method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + elif params.method == "attention-decoder": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + nbest_scale=params.nbest_scale, + ) + elif params.method == "rnn-lm": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_rnn_lm( + lattice=rescored_lattice, + num_paths=params.num_paths, + rnn_lm_model=rnn_lm_model, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + blank_id=0, + nbest_scale=params.nbest_scale, + ) + else: + assert False, f"Unsupported decoding method: {params.method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + rnn_lm_model: Optional[nn.Module], + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + rnn_lm_model: + The neural model for RNN LM. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + rnn_lm_model=rnn_lm_model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + if hyps_dict is not None: + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + else: + assert len(results) > 0, "It should not decode to empty in the first batch!" + this_batch = [] + hyp_words = [] + for ref_text in texts: + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + for lm_scale in results.keys(): + results[lm_scale].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[int], List[int]]]], +): + if params.method in ("attention-decoder", "rnn-lm"): + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + SwitchBoardAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + params.num_classes = num_classes + params.sos_id = sos_id + params.eos_id = eos_id + + if params.method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ]: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + model = load_averaged_model( + params.exp_dir, model, params.epoch, params.avg, device + ) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + rnn_lm_model = None + if params.method == "rnn-lm": + rnn_lm_model = RnnLmModel( + vocab_size=params.num_classes, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + if params.rnn_lm_avg == 1: + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + else: + rnn_lm_model = load_averaged_model( + params.rnn_lm_exp_dir, + rnn_lm_model, + params.rnn_lm_epoch, + params.rnn_lm_avg, + device, + ) + rnn_lm_model.eval() + + # we need cut ids to display recognition results. + args.return_cuts = True + switchboard = SwitchBoardAsrDataModule(args) + + test_eval2000_cuts = switchboard.test_eval2000_cuts().trim_to_supervisions(keep_all_channels=True) + test_rt03_cuts = switchboard.test_rt03_cuts().trim_to_supervisions(keep_all_channels=True) + + test_eval2000_dl = switchboard.test_dataloaders(test_eval2000_cuts) + test_rt03_dl = switchboard.test_dataloaders(test_rt03_cuts) + + test_sets = ["test-eval2000", "test-rt03"] + test_dl = [test_eval2000_dl, test_rt03_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + rnn_lm_model=rnn_lm_model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/conformer_ctc/export.py b/egs/swbd/ASR/conformer_ctc/export.py new file mode 100755 index 0000000000..fbcbd7b29a --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/export.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +import argparse +import logging +from pathlib import Path + +import torch +from conformer import Conformer + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=34, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""It contains language related input files such as "lexicon.txt" + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=True, + help="""True to save a model after applying torch.jit.script. + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + "num_decoder_layers": 6, + } + ) + return params + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/conformer_ctc/label_smoothing.py b/egs/swbd/ASR/conformer_ctc/label_smoothing.py new file mode 100644 index 0000000000..52d2eda3bb --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/label_smoothing.py @@ -0,0 +1,109 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +class LabelSmoothingLoss(torch.nn.Module): + """ + Implement the LabelSmoothingLoss proposed in the following paper + https://arxiv.org/pdf/1512.00567.pdf + (Rethinking the Inception Architecture for Computer Vision) + + """ + + def __init__( + self, + ignore_index: int = -1, + label_smoothing: float = 0.1, + reduction: str = "sum", + ) -> None: + """ + Args: + ignore_index: + ignored class id + label_smoothing: + smoothing rate (0.0 means the conventional cross entropy loss) + reduction: + It has the same meaning as the reduction in + `torch.nn.CrossEntropyLoss`. It can be one of the following three + values: (1) "none": No reduction will be applied. (2) "mean": the + mean of the output is taken. (3) "sum": the output will be summed. + """ + super().__init__() + assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}" + assert reduction in ("none", "sum", "mean"), reduction + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.reduction = reduction + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute loss between x and target. + + Args: + x: + prediction of dimension + (batch_size, input_length, number_of_classes). + target: + target masked with self.ignore_index of + dimension (batch_size, input_length). + + Returns: + A scalar tensor containing the loss without normalization. + """ + assert x.ndim == 3 + assert target.ndim == 2 + assert x.shape[:2] == target.shape + num_classes = x.size(-1) + x = x.reshape(-1, num_classes) + # Now x is of shape (N*T, C) + + # We don't want to change target in-place below, + # so we make a copy of it here + target = target.clone().reshape(-1) + + ignored = target == self.ignore_index + + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use target[ignored] = 0 here + target = torch.where(ignored, torch.zeros_like(target), target) + + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) + + true_dist = ( + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes + ) + + # Set the value of ignored indexes to 0 + # + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use true_dist[ignored] = 0 here + true_dist = torch.where( + ignored.unsqueeze(1).repeat(1, true_dist.shape[1]), + torch.zeros_like(true_dist), + true_dist, + ) + + loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) + if self.reduction == "sum": + return loss.sum() + elif self.reduction == "mean": + return loss.sum() / (~ignored).sum() + else: + return loss.sum(dim=-1) diff --git a/egs/swbd/ASR/conformer_ctc/pretrained.py b/egs/swbd/ASR/conformer_ctc/pretrained.py new file mode 100755 index 0000000000..30def9c404 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/pretrained.py @@ -0,0 +1,430 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from conformer import Conformer +from torch.nn.utils.rnn import pad_sequence + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_whole_lattice, +) +from icefall.utils import AttributeDict, get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + (3) attention-decoder - Extract n paths from the rescored + lattice and use the transformer attention decoder for + rescoring. + We call it HLG decoding + n-gram LM rescoring + attention + decoder rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or attention-decoder. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and attention-decoder. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--attention-decoder-scale", + type=float, + default=1.2, + help=""" + Used only when method is attention-decoder. + It specifies the scale for attention decoder scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is attention-decoder. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--sos-id", + type=int, + default=1, + help=""" + Used only when method is attention-decoder. + It specifies ID for the SOS token. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--eos-id", + type=int, + default=1, + help=""" + Used only when method is attention-decoder. + It specifies ID for the EOS token. + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "sample_rate": 16000, + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_decoder_layers": 6, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + if args.method != "attention-decoder": + # to save memory as the attention decoder + # will not be used + params.num_decoder_layers = 0 + + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=params.num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + # Note: We don't use key padding mask for attention during decoding + with torch.no_grad(): + nnet_output, memory, memory_key_padding_mask = model(features) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=params.num_classes > 500, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "whole-lattice-rescoring", + "attention-decoder", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = G.to(device) + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "attention-decoder": + logging.info("Use HLG + LM rescoring + attention decoder rescoring") + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None + ) + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=params.sos_id, + eos_id=params.eos_id, + nbest_scale=params.nbest_scale, + ngram_lm_scale=params.ngram_lm_scale, + attention_scale=params.attention_decoder_scale, + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/conformer_ctc/subsampling.py b/egs/swbd/ASR/conformer_ctc/subsampling.py new file mode 100644 index 0000000000..8e0f73d05b --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/subsampling.py @@ -0,0 +1,153 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__(self, idim: int, odim: int) -> None: + """ + Args: + idim: + Input dim. The input shape is (N, T, idim). + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) + """ + assert idim >= 7 + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), + nn.ReLU(), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), + nn.ReLU(), + ) + self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + return x + + +class VggSubsampling(nn.Module): + """Trying to follow the setup described in the following paper: + https://arxiv.org/pdf/1910.09799.pdf + + This paper is not 100% explicit so I am guessing to some extent, + and trying to compare with other VGG implementations. + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 + """ + + def __init__(self, idim: int, odim: int) -> None: + """Construct a VggSubsampling object. + + This uses 2 VGG blocks with 2 Conv2d layers each, + subsampling its input by a factor of 4 in the time dimensions. + + Args: + idim: + Input dim. The input shape is (N, T, idim). + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) + """ + super().__init__() + + cur_channels = 1 + layers = [] + block_dims = [32, 64] + + # The decision to use padding=1 for the 1st convolution, then padding=0 + # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by + # a back-compatibility concern so that the number of frames at the + # output would be equal to: + # (((T-1)//2)-1)//2. + # We can consider changing this by using padding=1 on the + # 2nd convolution, so the num-frames at the output would be T//4. + for block_dim in block_dims: + layers.append( + torch.nn.Conv2d( + in_channels=cur_channels, + out_channels=block_dim, + kernel_size=3, + padding=1, + stride=1, + ) + ) + layers.append(torch.nn.ReLU()) + layers.append( + torch.nn.Conv2d( + in_channels=block_dim, + out_channels=block_dim, + kernel_size=3, + padding=0, + stride=1, + ) + ) + layers.append( + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) + ) + cur_channels = block_dim + + self.layers = nn.Sequential(*layers) + + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + x = x.unsqueeze(1) + x = self.layers(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + return x diff --git a/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py b/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py new file mode 100755 index 0000000000..5d4438fd1f --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distutils.version import LooseVersion + +import torch +from label_smoothing import LabelSmoothingLoss + +torch_ver = LooseVersion(torch.__version__) + + +def test_with_torch_label_smoothing_loss(): + if torch_ver < LooseVersion("1.10.0"): + print(f"Current torch version: {torch_ver}") + print("Please use torch >= 1.10 to run this test - skipping") + return + torch.manual_seed(20211105) + x = torch.rand(20, 30, 5000) + tgt = torch.randint(low=-1, high=x.size(-1), size=x.shape[:2]) + for reduction in ["none", "sum", "mean"]: + custom_loss_func = LabelSmoothingLoss( + ignore_index=-1, label_smoothing=0.1, reduction=reduction + ) + custom_loss = custom_loss_func(x, tgt) + + torch_loss_func = torch.nn.CrossEntropyLoss( + ignore_index=-1, reduction=reduction, label_smoothing=0.1 + ) + torch_loss = torch_loss_func(x.reshape(-1, x.size(-1)), tgt.reshape(-1)) + assert torch.allclose(custom_loss, torch_loss) + + +def main(): + test_with_torch_label_smoothing_loss() + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/conformer_ctc/test_subsampling.py b/egs/swbd/ASR/conformer_ctc/test_subsampling.py new file mode 100755 index 0000000000..81fa234dd5 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/test_subsampling.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from subsampling import Conv2dSubsampling, VggSubsampling + + +def test_conv2d_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = Conv2dSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim + + +def test_vgg_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = VggSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim diff --git a/egs/swbd/ASR/conformer_ctc/test_transformer.py b/egs/swbd/ASR/conformer_ctc/test_transformer.py new file mode 100644 index 0000000000..667057c513 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/test_transformer.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from torch.nn.utils.rnn import pad_sequence +from transformer import ( + Transformer, + add_eos, + add_sos, + decoder_padding_mask, + encoder_padding_mask, + generate_square_subsequent_mask, +) + + +def test_encoder_padding_mask(): + supervisions = { + "sequence_idx": torch.tensor([0, 1, 2]), + "start_frame": torch.tensor([0, 0, 0]), + "num_frames": torch.tensor([18, 7, 13]), + } + + max_len = ((18 - 1) // 2 - 1) // 2 + mask = encoder_padding_mask(max_len, supervisions) + expected_mask = torch.tensor( + [ + [False, False, False], # ((18 - 1)//2 - 1)//2 = 3, + [False, True, True], # ((7 - 1)//2 - 1)//2 = 1, + [False, False, True], # ((13 - 1)//2 - 1)//2 = 2, + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_transformer(): + num_features = 40 + num_classes = 87 + model = Transformer(num_features=num_features, num_classes=num_classes) + + N = 31 + + for T in range(7, 30): + x = torch.rand(N, T, num_features) + y, _, _ = model(x) + assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes) + + +def test_generate_square_subsequent_mask(): + s = 5 + mask = generate_square_subsequent_mask(s) + inf = float("inf") + expected_mask = torch.tensor( + [ + [0.0, -inf, -inf, -inf, -inf], + [0.0, 0.0, -inf, -inf, -inf], + [0.0, 0.0, 0.0, -inf, -inf], + [0.0, 0.0, 0.0, 0.0, -inf], + [0.0, 0.0, 0.0, 0.0, 0.0], + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_decoder_padding_mask(): + x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])] + y = pad_sequence(x, batch_first=True, padding_value=-1) + mask = decoder_padding_mask(y, ignore_id=-1) + expected_mask = torch.tensor( + [ + [False, False, True], + [False, True, True], + [False, False, False], + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_add_sos(): + x = [[1, 2], [3], [2, 5, 8]] + y = add_sos(x, sos_id=0) + expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]] + assert y == expected_y + + +def test_add_eos(): + x = [[1, 2], [3], [2, 5, 8]] + y = add_eos(x, eos_id=0) + expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]] + assert y == expected_y diff --git a/egs/swbd/ASR/conformer_ctc/train.py b/egs/swbd/ASR/conformer_ctc/train.py new file mode 100755 index 0000000000..0c795868e3 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/train.py @@ -0,0 +1,813 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./conformer_ctc/train.py \ + --exp-dir ./conformer_ctc/exp \ + --world-size 4 \ + --max-duration 200 \ + --num-epochs 20 +""" + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import SwitchBoardAsrDataModule +from conformer import Conformer +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.graph_compiler import CtcTrainingGraphCompiler +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=78, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + conformer_ctc/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.8, + help="""The attention rate. + The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=6, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=5.0, + help="The lr_factor for Noam optimizer", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Normalization for the input features, can be a + boolean indicating whether to do batch + normalization, or a float which means just scaling + the input features with this float value. + If given a float value, we will remove batchnorm + layer in `ConvolutionModule` as well. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + + - weight_decay: The weight_decay for the optimizer. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for Noam + "weight_decay": 1e-6, + "warm_step": 80000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: BpeCtcTrainingGraphCompiler, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler): + # Works with a BPE model + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + elif isinstance(graph_compiler, CtcTrainingGraphCompiler): + # Works with a phone lexicon + decoding_graph = graph_compiler.compile(texts) + else: + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + if params.att_rate != 0.0: + with torch.set_grad_enabled(is_training): + mmodel = model.module if hasattr(model, "module") else model + # Note: We need to generate an unsorted version of token_ids + # `encode_supervisions()` called above sorts text, but + # encoder_memory and memory_mask are not sorted, so we + # use an unsorted version `supervisions["text"]` to regenerate + # the token_ids + # + # See https://github.com/k2-fsa/icefall/issues/97 + # for more details + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + att_loss = mmodel.decoder_forward( + encoder_memory, + memory_mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss + else: + loss = ctc_loss + att_loss = torch.tensor([0]) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + info["loss"] = loss.detach().cpu().item() + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = supervisions["num_frames"].sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() + ) + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: BpeCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if batch_idx % params.log_interval == 0: + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + if "lang_bpe" in str(params.lang_dir): + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + elif "lang_phone" in str(params.lang_dir): + assert params.att_rate == 0, ( + "Attention decoder training does not support phone lang dirs " + "at this time due to a missing symbol. Set --att-rate=0 " + "for pure CTC training when using a phone-based lang dir." + ) + assert params.num_decoder_layers == 0, ( + "Attention decoder training does not support phone lang dirs " + "at this time due to a missing symbol. " + "Set --num-decoder-layers=0 for pure CTC training when using " + "a phone-based lang dir." + ) + graph_compiler = CtcTrainingGraphCompiler( + lexicon, + device=device, + ) + # Manually add the sos/eos ID with their default values + # from the BPE recipe which we're adapting here. + graph_compiler.sos_id = 1 + graph_compiler.eos_id = 1 + else: + raise ValueError( + f"Unsupported type of lang dir (we expected it to have " + f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" + ) + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + weight_decay=params.weight_decay, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + switchboard = SwitchBoardAsrDataModule(args) + + train_cuts = switchboard.train_all_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 600.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_dl = switchboard.train_dataloaders(train_cuts) + + valid_cuts = switchboard.dev_cuts() + valid_dl = switchboard.valid_dataloaders(valid_cuts) + + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + SwitchBoardAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/conformer_ctc/transformer.py b/egs/swbd/ASR/conformer_ctc/transformer.py new file mode 100644 index 0000000000..0566cfc81c --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/transformer.py @@ -0,0 +1,928 @@ +# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from label_smoothing import LabelSmoothingLoss +from subsampling import Conv2dSubsampling, VggSubsampling +from torch.nn.utils.rnn import pad_sequence + +# Note: TorchScript requires Dict/List/etc. to be fully typed. +Supervisions = Dict[str, torch.Tensor] + + +class Transformer(nn.Module): + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + normalize_before: bool = True, + vgg_frontend: bool = False, + use_feat_batchnorm: Union[float, bool] = 0.1, + ) -> None: + """ + Args: + num_features: + The input dimension of the model. + num_classes: + The output dimension of the model. + subsampling_factor: + Number of output frames is num_in_frames // subsampling_factor. + Currently, subsampling_factor MUST be 4. + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model // nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder/decoder. + num_encoder_layers: + Number of encoder layers. + num_decoder_layers: + Number of decoder layers. + dropout: + Dropout in encoder/decoder. + normalize_before: + If True, use pre-layer norm; False to use post-layer norm. + vgg_frontend: + True to use vgg style frontend for subsampling. + use_feat_batchnorm: + True to use batchnorm for the input layer. + Float value to scale the input layer. + False to do nothing. + """ + super().__init__() + self.use_feat_batchnorm = use_feat_batchnorm + assert isinstance(use_feat_batchnorm, (float, bool)) + if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm: + self.feat_batchnorm = nn.BatchNorm1d(num_features) + + self.num_features = num_features + self.num_classes = num_classes + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_classes) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_classes -> d_model + if vgg_frontend: + self.encoder_embed = VggSubsampling(num_features, d_model) + else: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_pos = PositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + encoder_norm = nn.LayerNorm(d_model) + else: + encoder_norm = None + + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + norm=encoder_norm, + ) + + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) + ) + + if num_decoder_layers > 0: + self.decoder_num_class = ( + self.num_classes + ) # bpe model already has sos/eos symbol + + self.decoder_embed = nn.Embedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model + ) + self.decoder_pos = PositionalEncoding(d_model, dropout) + + decoder_layer = TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + decoder_norm = nn.LayerNorm(d_model) + else: + decoder_norm = None + + self.decoder = nn.TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=num_decoder_layers, + norm=decoder_norm, + ) + + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + + self.decoder_criterion = LabelSmoothingLoss() + else: + self.decoder_criterion = None + + def forward( + self, x: torch.Tensor, supervision: Optional[Supervisions] = None + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (N, T, C). + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + + Returns: + Return a tuple containing 3 tensors: + - CTC output for ctc decoding. Its shape is (N, T, C) + - Encoder output with shape (T, N, C). It can be used as key and + value for the decoder. + - Encoder output padding mask. It can be used as + memory_key_padding_mask for the decoder. Its shape is (N, T). + It is None if `supervision` is None. + """ + if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: + x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) + x = self.feat_batchnorm(x) + x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) + if isinstance(self.use_feat_batchnorm, float): + x *= self.use_feat_batchnorm + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) + x = self.ctc_output(encoder_memory) + return x, encoder_memory, memory_key_padding_mask + + def run_encoder( + self, x: torch.Tensor, supervisions: Optional[Supervisions] = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Run the transformer encoder. + + Args: + x: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute the encoder padding mask, which is used as memory key + padding mask for the decoder. + Returns: + Return a tuple with two tensors: + - The encoder output, with shape (T, N, C) + - encoder padding mask, with shape (N, T). + The mask is None if `supervisions` is None. + It is used as memory key padding mask in the decoder. + """ + x = self.encoder_embed(x) + x = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + mask = encoder_padding_mask(x.size(0), supervisions) + mask = mask.to(x.device) if mask is not None else None + x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) + + return x, mask + + def ctc_output(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + The output tensor from the transformer encoder. + Its shape is (T, N, C) + + Returns: + Return a tensor that can be used for CTC decoding. + Its shape is (N, T, C) + """ + x = self.encoder_output_layer(x) + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) + return x + + @torch.jit.export + def decoder_forward( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[List[int]], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape (T, N, C) + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs. Each sublist contains IDs for an utterance. + The IDs can be either phone IDs or word piece IDs. + sos_id: + sos token id + eos_id: + eos token id + + Returns: + A scalar, the **sum** of label smoothing loss over utterances + in the batch without any normalization. + """ + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + + device = memory.device + ys_in_pad = ys_in_pad.to(device) + ys_out_pad = ys_out_pad.to(device) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, N, C) + pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) + + decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) + + return decoder_loss + + @torch.jit.export + def decoder_nll( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[torch.Tensor], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape (T, N, C) + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs (e.g., word piece IDs). + Each sublist represents an utterance. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + Returns: + A 2-D tensor of shape (len(token_ids), max_token_length) + representing the cross entropy loss (i.e., negative log-likelihood). + """ + # The common part between this function and decoder_forward could be + # extracted as a separate function. + if isinstance(token_ids[0], torch.Tensor): + # This branch is executed by torchscript in C++. + # See https://github.com/k2-fsa/k2/pull/870 + # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 + token_ids = [tolist(t) for t in token_ids] + + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + + device = memory.device + ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, B, F) + pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) + pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + pred_pad.view(-1, self.decoder_num_class), + ys_out_pad.view(-1), + ignore_index=-1, + reduction="none", + ) + + nll = nll.view(pred_pad.shape[0], -1) + + return nll + + +class TransformerEncoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerEncoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + normalize_before: + whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerEncoderLayer, self).__setstate__(state) + + def forward( + self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional) + + Shape: + src: (S, N, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = src + if self.normalize_before: + src = self.norm1(src) + src2 = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout1(src2) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src2) + if not self.normalize_before: + src = self.norm2(src) + return src + + +class TransformerDecoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerDecoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerDecoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerDecoderLayer, self).__setstate__(state) + + def forward( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: + the sequence to the decoder layer (required). + memory: + the sequence from the last layer of the encoder (required). + tgt_mask: + the mask for the tgt sequence (optional). + memory_mask: + the mask for the memory sequence (optional). + tgt_key_padding_mask: + the mask for the tgt keys per batch (optional). + memory_key_padding_mask: + the mask for the memory keys per batch (optional). + + Shape: + tgt: (T, N, E). + memory: (S, N, E). + tgt_mask: (T, T). + memory_mask: (T, S). + tgt_key_padding_mask: (N, T). + memory_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + tgt2 = self.self_attn( + tgt, + tgt, + tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask, + )[0] + tgt = residual + self.dropout1(tgt2) + if not self.normalize_before: + tgt = self.norm1(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm2(tgt) + tgt2 = self.src_attn( + tgt, + memory, + memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = residual + self.dropout2(tgt2) + if not self.normalize_before: + tgt = self.norm2(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = residual + self.dropout3(tgt2) + if not self.normalize_before: + tgt = self.norm3(tgt) + return tgt + + +def _get_activation_fn(activation: str): + if activation == "relu": + return nn.functional.relu + elif activation == "gelu": + return nn.functional.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + + +class PositionalEncoding(nn.Module): + """This class implements the positional encoding + proposed in the following paper: + + - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf + + PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) + PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) + + Note:: + + 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) + = exp(-1* 2i / d_model * log(100000)) + = exp(2i * -(log(10000) / d_model)) + """ + + def __init__(self, d_model: int, dropout: float = 0.1) -> None: + """ + Args: + d_model: + Embedding dimension. + dropout: + Dropout probability to be applied to the output of this module. + """ + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = nn.Dropout(p=dropout) + # not doing: self.pe = None because of errors thrown by torchscript + self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) + + def extend_pe(self, x: torch.Tensor) -> None: + """Extend the time t in the positional encoding if required. + + The shape of `self.pe` is (1, T1, d_model). The shape of the input x + is (N, T, d_model). If T > T1, then we change the shape of self.pe + to (N, T, d_model). Otherwise, nothing is done. + + Args: + x: + It is a tensor of shape (N, T, C). + Returns: + Return None. + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + # Now pe is of shape (1, T, d_model), where T is x.size(1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional encoding. + + Args: + x: + Its shape is (N, T, C) + + Returns: + Return a tensor of shape (N, T, C) + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1), :] + return self.dropout(x) + + +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) + + +def encoder_padding_mask( + max_len: int, supervisions: Optional[Supervisions] = None +) -> Optional[torch.Tensor]: + """Make mask tensor containing indexes of padded part. + + TODO:: + This function **assumes** that the model uses + a subsampling factor of 4. We should remove that + assumption later. + + Args: + max_len: + Maximum length of input features. + CAUTION: It is the length after subsampling. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + + Returns: + Tensor: Mask tensor of dimension (batch_size, input_length), + True denote the masked indices. + """ + if supervisions is None: + return None + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"], + supervisions["num_frames"], + ), + 1, + ).to(torch.int32) + + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + for idx in range(supervision_segments.size(0)): + # Note: TorchScript doesn't allow to unpack tensors as tuples + sequence_idx = supervision_segments[idx, 0].item() + start_frame = supervision_segments[idx, 1].item() + num_frames = supervision_segments[idx, 2].item() + lengths[sequence_idx] = start_frame + num_frames + + lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] + bs = int(len(lengths)) + seq_range = torch.arange(0, max_len, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) + # Note: TorchScript doesn't implement Tensor.new() + seq_length_expand = torch.tensor( + lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype + ).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + return mask + + +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: + """Generate a length mask for input. + + The masked position are filled with True, + Unmasked positions are filled with False. + + Args: + ys_pad: + padded tensor of dimension (batch_size, input_length). + ignore_id: + the ignored number (the padding number) in ys_pad + + Returns: + Tensor: + a bool tensor of the same shape as the input tensor. + """ + ys_mask = ys_pad == ignore_id + return ys_mask + + +def generate_square_subsequent_mask(sz: int) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). + The mask can be used for masked self-attention. + + For instance, if sz is 3, it returns:: + + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0]]) + + Args: + sz: mask size + + Returns: + A square mask of dimension (sz, sz) + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask + + +def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: + """Prepend sos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + sos_id: + The ID of the SOS token. + + Return: + Return a new list-of-list, where each sublist starts + with SOS ID. + """ + return [[sos_id] + utt for utt in token_ids] + + +def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: + """Append eos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + eos_id: + The ID of the EOS token. + + Return: + Return a new list-of-list, where each sublist ends + with EOS ID. + """ + return [utt + [eos_id] for utt in token_ids] + + +def tolist(t: torch.Tensor) -> List[int]: + """Used by jit""" + return torch.jit.annotate(List[int], t.tolist()) diff --git a/egs/swbd/ASR/local/compile_hlg.py b/egs/swbd/ASR/local/compile_hlg.py new file mode 100755 index 0000000000..d19d50ae63 --- /dev/null +++ b/egs/swbd/ASR/local/compile_hlg.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input lang_dir and generates HLG from + + - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt + - L, the lexicon, built from lang_dir/L_disambig.pt + + Caution: We use a lexicon that contains disambiguation symbols + + - G, the LM, built from data/lm/G_n_gram.fst.txt + +The generated HLG is saved in $lang_dir/HLG.pt +""" +import argparse +import logging +from pathlib import Path + +import k2 +import torch + +from icefall.lexicon import Lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lm", + type=str, + default="G_3_gram", + help="""Stem name for LM used in HLG compiling. + """, + ) + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + + return parser.parse_args() + + +def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: + """ + Args: + lang_dir: + The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + lm: + The language stem base name. + + Return: + An FSA representing HLG. + """ + lexicon = Lexicon(lang_dir) + max_token_id = max(lexicon.tokens) + logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") + H = k2.ctc_topo(max_token_id) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + + if Path(f"data/lm/{lm}.pt").is_file(): + logging.info(f"Loading pre-compiled {lm}") + d = torch.load(f"data/lm/{lm}.pt") + G = k2.Fsa.from_dict(d) + else: + logging.info(f"Loading {lm}.fst.txt") + with open(f"data/lm/{lm}.fst.txt") as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + torch.save(G.as_dict(), f"data/lm/{lm}.pt") + + first_token_disambig_id = lexicon.token_table["#0"] + first_word_disambig_id = lexicon.word_table["#0"] + + L = k2.arc_sort(L) + G = k2.arc_sort(G) + + logging.info("Intersecting L and G") + LG = k2.compose(L, G) + logging.info(f"LG shape: {LG.shape}") + + logging.info("Connecting LG") + LG = k2.connect(LG) + logging.info(f"LG shape after k2.connect: {LG.shape}") + + logging.info(type(LG.aux_labels)) + logging.info("Determinizing LG") + + LG = k2.determinize(LG) + logging.info(type(LG.aux_labels)) + + logging.info("Connecting LG after k2.determinize") + LG = k2.connect(LG) + + logging.info("Removing disambiguation symbols on LG") + + # LG.labels[LG.labels >= first_token_disambig_id] = 0 + # see https://github.com/k2-fsa/k2/pull/1140 + labels = LG.labels + labels[labels >= first_token_disambig_id] = 0 + LG.labels = labels + + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 + + LG = k2.remove_epsilon(LG) + logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") + + LG = k2.connect(LG) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) + + logging.info("Composing H and LG") + # CAUTION: The name of the inner_labels is fixed + # to `tokens`. If you want to change it, please + # also change other places in icefall that are using + # it. + HLG = k2.compose(H, LG, inner_labels="tokens") + + logging.info("Connecting LG") + HLG = k2.connect(HLG) + + logging.info("Arc sorting LG") + HLG = k2.arc_sort(HLG) + logging.info(f"HLG.shape: {HLG.shape}") + + return HLG + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + if (lang_dir / "HLG.pt").is_file(): + logging.info(f"{lang_dir}/HLG.pt already exists - skipping") + return + + logging.info(f"Processing {lang_dir}") + + HLG = compile_HLG(lang_dir, args.lm) + logging.info(f"Saving HLG.pt to {lang_dir}") + torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/swbd/ASR/local/compile_lg.py b/egs/swbd/ASR/local/compile_lg.py new file mode 100755 index 0000000000..709b140702 --- /dev/null +++ b/egs/swbd/ASR/local/compile_lg.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input lang_dir and generates LG from + + - L, the lexicon, built from lang_dir/L_disambig.pt + + Caution: We use a lexicon that contains disambiguation symbols + + - G, the LM, built from data/lm/G_3_gram.fst.txt + +The generated LG is saved in $lang_dir/LG.pt +""" +import argparse +import logging +from pathlib import Path + +import k2 +import torch + +from icefall.lexicon import Lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + parser.add_argument( + "--lm", + type=str, + default="G_3_gram", + help="""Stem name for LM used in HLG compiling. + """, + ) + + return parser.parse_args() + + +def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: + """ + Args: + lang_dir: + The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + + Return: + An FSA representing LG. + """ + lexicon = Lexicon(lang_dir) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + + if Path(f"data/lm/{lm}.pt").is_file(): + logging.info(f"Loading pre-compiled {lm}") + d = torch.load(f"data/lm/{lm}.pt") + G = k2.Fsa.from_dict(d) + else: + logging.info(f"Loading {lm}.fst.txt") + with open(f"data/lm/{lm}.fst.txt") as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + torch.save(G.as_dict(), f"data/lm/{lm}.pt") + + first_token_disambig_id = lexicon.token_table["#0"] + first_word_disambig_id = lexicon.word_table["#0"] + + L = k2.arc_sort(L) + G = k2.arc_sort(G) + + logging.info("Intersecting L and G") + LG = k2.compose(L, G) + logging.info(f"LG shape: {LG.shape}") + + logging.info("Connecting LG") + LG = k2.connect(LG) + logging.info(f"LG shape after k2.connect: {LG.shape}") + + logging.info(type(LG.aux_labels)) + logging.info("Determinizing LG") + + LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing) + logging.info(type(LG.aux_labels)) + + logging.info("Connecting LG after k2.determinize") + LG = k2.connect(LG) + + logging.info("Removing disambiguation symbols on LG") + + # LG.labels[LG.labels >= first_token_disambig_id] = 0 + # see https://github.com/k2-fsa/k2/pull/1140 + labels = LG.labels + labels[labels >= first_token_disambig_id] = 0 + LG.labels = labels + + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 + + LG = k2.remove_epsilon(LG) + logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") + + LG = k2.connect(LG) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) + + return LG + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + if (lang_dir / "LG.pt").is_file(): + logging.info(f"{lang_dir}/LG.pt already exists - skipping") + return + + logging.info(f"Processing {lang_dir}") + + LG = compile_LG(lang_dir, args.lm) + logging.info(f"Saving LG.pt to {lang_dir}") + torch.save(LG.as_dict(), f"{lang_dir}/LG.pt") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/swbd/ASR/local/compute_fbank_swbd.py b/egs/swbd/ASR/local/compute_fbank_swbd.py new file mode 100755 index 0000000000..a270f52ee4 --- /dev/null +++ b/egs/swbd/ASR/local/compute_fbank_swbd.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# Modified 2023 The Chinese University of Hong Kong (author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the SwitchBoard dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import sentencepiece as spm +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. If not None, we will remove short and + long utterances before extracting features""", + ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=True, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + + return parser.parse_args() + + +def compute_fbank_switchboard( + bpe_model: Optional[str] = None, + dataset: Optional[str] = None, + perturb_speed: Optional[bool] = True, +): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + num_jobs = min(15, os.cpu_count()) + num_mel_bins = 80 + + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + if dataset is None: + dataset_parts = ("all", "eval2000", "rt03") + else: + dataset_parts = dataset.split(" ", -1) + + prefix = "swbd" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"].resample(16000), + supervisions=m["supervisions"], + ) + + if "train" in partition: + if bpe_model: + cut_set = filter_cuts(cut_set, sp) + if perturb_speed: + logging.info(f"Doing speed perturb") + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ).resample(16000) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_switchboard( + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py b/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py new file mode 100755 index 0000000000..a8d5117c97 --- /dev/null +++ b/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +""" +Convert a transcript file containing words to a corpus file containing tokens +for LM training with the help of a lexicon. + +If the lexicon contains phones, the resulting LM will be a phone LM; If the +lexicon contains word pieces, the resulting LM will be a word piece LM. + +If a word has multiple pronunciations, the one that appears first in the lexicon +is kept; others are removed. + +If the input transcript is: + + hello zoo world hello + world zoo + foo zoo world hellO + +and if the lexicon is + + SPN + hello h e l l o 2 + hello h e l l o + world w o r l d + zoo z o o + +Then the output is + + h e l l o 2 z o o w o r l d h e l l o 2 + w o r l d z o o + SPN z o o w o r l d SPN +""" + +import argparse +from pathlib import Path +from typing import Dict, List + +from generate_unique_lexicon import filter_multiple_pronunications + +from icefall.lexicon import read_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transcript", + type=str, + help="The input transcript file." + "We assume that the transcript file consists of " + "lines. Each line consists of space separated words.", + ) + parser.add_argument("--lexicon", type=str, help="The input lexicon file.") + parser.add_argument("--oov", type=str, default="", help="The OOV word.") + + return parser.parse_args() + + +def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: + """ + Args: + lexicon: + A dict containing pronunciations. Its keys are words and values + are pronunciations (i.e., tokens). + line: + A line of transcript consisting of space(s) separated words. + oov_token: + The pronunciation of the oov word if a word in `line` is not present + in the lexicon. + Returns: + Return None. + """ + s = "" + words = line.strip().split() + for i, w in enumerate(words): + tokens = lexicon.get(w, oov_token) + s += " ".join(tokens) + s += " " + print(s.strip()) + + +def main(): + args = get_args() + assert Path(args.lexicon).is_file() + assert Path(args.transcript).is_file() + assert len(args.oov) > 0 + + # Only the first pronunciation of a word is kept + lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon)) + + lexicon = dict(lexicon) + + assert args.oov in lexicon + + oov_token = lexicon[args.oov] + + with open(args.transcript) as f: + for line in f: + process_line(lexicon=lexicon, line=line, oov_token=oov_token) + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/dict.patch b/egs/swbd/ASR/local/dict.patch new file mode 100644 index 0000000000..12c63d6127 --- /dev/null +++ b/egs/swbd/ASR/local/dict.patch @@ -0,0 +1,380 @@ +1d0 +< file: $SWB/data/dictionary/sw-ms98-dict.text +8645a8646 +> uh-hum ah m hh ah m +9006c9007 +< April ey p r ih l +--- +> April ey p r ax l +9144d9144 +< B ay zh aa n iy z +9261c9261 +< Battle b ae t el +--- +> Battle b ae t ax l +10014a10015 +> Chevy sh eh v iy +10211a10213 +> Colorado k ao l ax r aa d ow +10212a10215 +> Colorado' k ao l ax r aa d ow z +10370c10373 +< Creek k r ih k +--- +> Creek k r iy k +10889a10893 +> Eleven ax l eh v ih n +10951c10955 +< Erie ih r iy +--- +> Erie iy r iy +11183c11187 +< Forever f ax r eh v er +--- +> Forever f er eh v er +11231a11236 +> Friday f r ay d iy +11744a11750 +> History hh ih s t r iy +12004a12011,12012 +> Israel ih z r ih l +> Israel's ih z r ih l z +12573a12582 +> Lincoln l ih ng k ih n +12574a12584 +> Lincolns l ih ng k ih n z +13268c13278 +< NAACP eh ey ey s iy p iy +--- +> NAACP eh n ey ey s iy p iy +13286c13296 +< NIT eh ay t iy +--- +> NIT eh n ay t iy +13292c13302 +< NTSC eh t iy eh s s iy +--- +> NTSC eh n t iy eh s s iy +14058a14069 +> Quarter k ow r t er +14059a14071 +> Quarterback k ow r t er b ae k +14060a14073 +> Quarters k ow r t er z +14569a14583 +> Science s ay n s +15087a15102 +> Sunday s ah n d iy +15088a15104 +> Sunday's s ah n d iy z +15089a15106 +> Sundays s ah n d iy z +15290,15291c15307,15308 +< Texan t eh k sh ih n +< Texan's t eh k sh ih n s +--- +> Texan t eh k s ih n +> Texan's t eh k s ih n s +15335a15353 +> Thousands th aw z ih n z +15739c15757 +< Waco w ae k ow +--- +> Waco w ey k ow +15841a15860 +> Weekends w iy k eh n z +16782a16802 +> acceptable eh k s eh p ax b ax l +16833a16854 +> accounting ax k aw n ih ng +16948a16970 +> address ax d r eh s +17281a17304 +> already aa r d iy +17315a17339 +> am m +17709a17734 +> asked ae s t +17847a17873 +> attorney ih t er n iy +17919a17946 +> autopilot ao t ow p ay l ih t +17960a17988 +> awfully ao f l iy +18221a18250 +> basketball b ae s k ax b ao l +18222a18252 +> basketball's b ae s k ax b ao l z +18302a18333 +> become b ah k ah m +18303a18335 +> becomes b iy k ah m z +18344a18377 +> began b ax g en n +18817c18850 +< bottle b aa t el +--- +> bottle b aa t ax l +19332,19333c19365,19367 +< camera's k ae m ax r ax z +< cameras k ae m ax r ax z +--- +> camera k ae m r ax +> camera's k ae m r ax z +> cameras k ae m r ax z +19411a19446 +> capital k ae p ax l +19505a19541 +> carrying k ae r ih ng +20316a20353,20354 +> combination k aa m ih n ey sh ih n +> combinations k aa m ih n ey sh ih n z +20831a20870 +> contracts k aa n t r ae k s +21010a21050 +> costs k ao s +21062a21103 +> county k aw n iy +21371a21413 +> cultural k ao l ch ax r ax l +21372a21415 +> culturally k ao l ch ax r ax l iy +21373a21417 +> culture k ao l ch er +21375a21420 +> cultures k ao l ch er z +21543a21589 +> data d ey t ax +22097a22144 +> differently d ih f ax r ih n t l iy +22972a23020 +> effects ax f eh k t s +23016a23065 +> election ax l eh k sh ih n +23018a23068 +> elections ax l eh k sh ih n z +23052a23103 +> eleven ax l eh v ih n +23242a23294 +> enjoyable ae n jh oy ax b ax l +23248a23301 +> enjoys ae n jh oy z +23293a23347 +> entire ih n t ay r +23295a23350,23351 +> entirely ih n t ay r l iy +> entirety ih n t ay r t iy +23745a23802 +> extra eh k s t er +23818a23876 +> facts f ae k s +24508c24566 +< forever f ax r eh v er +--- +> forever f er eh v er +24514c24572 +< forget f ow r g eh t +--- +> forget f er r g eh t +24521a24580 +> forgot f er r g aa t +24522a24582 +> forgotten f er r g aa t ax n +24563a24624 +> forward f ow er d +24680a24742 +> frightening f r ay t n ih ng +24742a24805 +> full-time f ax l t ay m +24862a24926 +> garage g r aa jh +25218a25283 +> grandmother g r ae m ah dh er +25790a25856 +> heavily hh eh v ax l iy +25949a26016 +> history hh ih s t r iy +26038a26106 +> honestly aa n ax s t l iy +26039a26108 +> honesty aa n ax s t iy +26099a26169 +> horror hh ow r +26155a26226 +> houses hh aw z ih z +26184c26255 +< huh-uh hh ah hh ah +--- +> huh-uh ah hh ah +26189c26260 +< hum-um hh m hh m +--- +> hum-um ah m hh ah m +26236a26308 +> hunting hh ah n ih ng +26307a26380,26381 +> ideal ay d iy l +> idealist ay d iy l ih s t +26369a26444 +> imagine m ae jh ih n +26628a26704 +> individuals ih n d ih v ih jh ax l z +26968a27045 +> interest ih n t r ih s t +27184a27262 +> it'd ih d +27702a27781 +> lead l iy d +28378a28458 +> mandatory m ae n d ih t ow r iy +28885a28966 +> minute m ih n ih t +29167a29249 +> mountains m aw t n z +29317a29400 +> mysteries m ih s t r iy z +29318a29402 +> mystery m ih s t r iy +29470a29555 +> nervous n er v ih s +29578,29580c29663,29665 +< nobody n ow b aa d iy +< nobody'll n ow b aa d iy l +< nobody's n ow b aa d iy z +--- +> nobody n ow b ah d iy +> nobody'll n ow b ah d iy l +> nobody's n ow b ah d iy z +29712a29798 +> nuclear n uw k l iy r +29938a30025 +> onto aa n t ax +30051a30139 +> originally ax r ih jh ax l iy +30507a30596 +> particularly p er t ih k y ax l iy +30755a30845 +> perfectly p er f ih k l iy +30820a30911 +> personally p er s n ax l iy +30915a31007 +> physically f ih z ih k l iy +30986a31079 +> pilot p ay l ih t +30987a31081 +> pilot's p ay l ih t s +31227a31322 +> police p l iy s +31513a31609 +> prefer p er f er +31553a31650 +> prepare p r ax p ey r +31578a31676 +> prescription p er s k r ih p sh ih n +31579a31678 +> prescriptions p er s k r ih p sh ih n z +31770a31870 +> products p r aa d ax k s +31821a31922 +> projects p r aa jh eh k s +31908a32010 +> protect p er t eh k t +31909a32012 +> protected p er t eh k t ih d +31911a32015 +> protection p er t eh k sh ih n +31914a32019 +> protection p er t eh k t ih v +32149a32255 +> quarter k ow r t er +32414a32521 +> read r iy d +32785a32893 +> rehabilitation r iy ax b ih l ih t ey sh ih n +33150a33259 +> resource r ih s ow r s +33151a33261 +> resources r iy s ow r s ih z +33539c33649 +< roots r uh t s +--- +> roots r uw t s +33929a34040 +> science s ay n s +34315a34427 +> seventy s eh v ih n iy +34319,34320c34431,34432 +< severe s ax v iy r +< severely s ax v iy r l iy +--- +> severe s ih v iy r +> severely s ih v iy r l iy +35060a35173 +> software s ao f w ey r +35083a35197 +> solid s ao l ih d +35084a35199 +> solidly s ao l ih d l iy +35750a35866 +> stood s t ih d +35854a35971 +> strictly s t r ih k l iy +35889c36006 +< stronger s t r ao ng er +--- +> stronger s t r ao ng g er +36192a36310,36311 +> supposed s p ow z +> supposed s p ow s +36510a36630 +> tastes t ey s +36856a36977 +> thoroughly th er r l iy +36866a36988 +> thousands th aw z ih n z +37081c37203 +< toots t uh t s +--- +> toots t uw t s +37157a37280 +> toward t w ow r d +37158a37282 +> towards t w ow r d z +37564a37689 +> twenties t w eh n iy z +37565a37691 +> twentieth t w eh n iy ih th +37637a37764 +> unacceptable ah n ae k s eh p ax b ax l +37728a37856 +> understand ah n d er s t ae n +37860a37989 +> unless ih n l eh s +38040a38170 +> use y uw z +38049a38180 +> uses y uw z ih z +38125a38257 +> various v ah r iy ih s +38202a38335 +> versus v er s ih z +38381c38514 +< wacko w ae k ow +--- +> wacko w ey k ow +38455c38588 +< wanna w aa n ax +--- +> wanna w ah n ax +38675c38808 +< whatnot w ah t n aa t +--- +> whatnot w aa t n aa t +38676a38810 +> whatsoever w aa t s ow eh v er +38890c39024 +< wok w aa k +--- +> wok w ao k +38910a39045 +> wondering w ah n d r ih ng diff --git a/egs/swbd/ASR/local/display_manifest_statistics.py b/egs/swbd/ASR/local/display_manifest_statistics.py new file mode 100755 index 0000000000..a6eb1e2b25 --- /dev/null +++ b/egs/swbd/ASR/local/display_manifest_statistics.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in transducer/train.py +for usage. +""" + + +from lhotse import load_manifest_lazy + + +def main(): + # path = "./data/fbank/swbd_cuts_rt03.jsonl.gz" + # path = "./data/fbank/swbd_cuts_eval2000.jsonl.gz" + path = "./data/fbank/swbd_cuts_all.jsonl.gz" + + cuts = load_manifest_lazy(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +## train-clean-100 +Cuts count: 85617 +Total duration (hours): 303.8 +Speech duration (hours): 303.8 (100.0%) +*** +Duration statistics (seconds): +mean 12.8 +std 3.8 +min 1.3 +0.1% 1.9 +0.5% 2.2 +1% 2.5 +5% 4.2 +10% 6.4 +25% 11.4 +50% 13.8 +75% 15.3 +90% 16.7 +95% 17.3 +99% 18.1 +99.5% 18.4 +99.9% 18.8 +max 27.2 + +## train-clean-360 +Cuts count: 312042 +Total duration (hours): 1098.2 +Speech duration (hours): 1098.2 (100.0%) +*** +Duration statistics (seconds): +mean 12.7 +std 3.8 +min 1.0 +0.1% 1.8 +0.5% 2.2 +1% 2.5 +5% 4.2 +10% 6.2 +25% 11.2 +50% 13.7 +75% 15.3 +90% 16.6 +95% 17.3 +99% 18.1 +99.5% 18.4 +99.9% 18.8 +max 33.0 + +## train-other 500 +Cuts count: 446064 +Total duration (hours): 1500.6 +Speech duration (hours): 1500.6 (100.0%) +*** +Duration statistics (seconds): +mean 12.1 +std 4.2 +min 0.8 +0.1% 1.7 +0.5% 2.1 +1% 2.3 +5% 3.5 +10% 5.0 +25% 9.8 +50% 13.4 +75% 15.1 +90% 16.5 +95% 17.2 +99% 18.1 +99.5% 18.4 +99.9% 18.9 +max 31.0 + +## dev-clean +Cuts count: 2703 +Total duration (hours): 5.4 +Speech duration (hours): 5.4 (100.0%) +*** +Duration statistics (seconds): +mean 7.2 +std 4.7 +min 1.4 +0.1% 1.6 +0.5% 1.8 +1% 1.9 +5% 2.4 +10% 2.7 +25% 3.8 +50% 5.9 +75% 9.3 +90% 13.3 +95% 16.4 +99% 23.8 +99.5% 28.5 +99.9% 32.3 +max 32.6 + +## dev-other +Cuts count: 2864 +Total duration (hours): 5.1 +Speech duration (hours): 5.1 (100.0%) +*** +Duration statistics (seconds): +mean 6.4 +std 4.3 +min 1.1 +0.1% 1.3 +0.5% 1.7 +1% 1.8 +5% 2.2 +10% 2.6 +25% 3.5 +50% 5.3 +75% 7.9 +90% 12.0 +95% 15.0 +99% 22.2 +99.5% 27.1 +99.9% 32.4 +max 35.2 + +## test-clean +Cuts count: 2620 +Total duration (hours): 5.4 +Speech duration (hours): 5.4 (100.0%) +*** +Duration statistics (seconds): +mean 7.4 +std 5.2 +min 1.3 +0.1% 1.6 +0.5% 1.8 +1% 2.0 +5% 2.3 +10% 2.7 +25% 3.7 +50% 5.8 +75% 9.6 +90% 14.6 +95% 17.8 +99% 25.5 +99.5% 28.4 +99.9% 32.8 +max 35.0 + +## test-other +Cuts count: 2939 +Total duration (hours): 5.3 +Speech duration (hours): 5.3 (100.0%) +*** +Duration statistics (seconds): +mean 6.5 +std 4.4 +min 1.2 +0.1% 1.5 +0.5% 1.8 +1% 1.9 +5% 2.3 +10% 2.6 +25% 3.4 +50% 5.2 +75% 8.2 +90% 12.6 +95% 15.8 +99% 21.4 +99.5% 23.8 +99.9% 33.5 +max 34.5 +""" diff --git a/egs/swbd/ASR/local/eval2000_data_prep.sh b/egs/swbd/ASR/local/eval2000_data_prep.sh new file mode 100755 index 0000000000..f881983286 --- /dev/null +++ b/egs/swbd/ASR/local/eval2000_data_prep.sh @@ -0,0 +1,120 @@ +#!/bin/bash + +# Hub-5 Eval 2000 data preparation +# Author: Arnab Ghoshal (Jan 2013) + +# To be run from one directory above this script. + +# The input is two directory names (possibly the same) containing the +# 2000 Hub5 english evaluation test set and transcripts, which are +# respectively: LDC2002S09 LDC2002T43 +# e.g. see +# http://www.ldc.upenn.edu/Catalog/catalogEntry.jsp?catalogId=LDC2002S09 +# http://www.ldc.upenn.edu/Catalog/CatalogEntry.jsp?catalogId=LDC2002T43 +# +# Example usage: +# local/eval2000_data_prep_edin.sh /exports/work/inf_hcrc_cstr_general/corpora/hub5/2000 /exports/work/inf_hcrc_cstr_general/corpora/hub5/2000/transcr +# The first directory ($sdir) contains the speech data, and the directory +# $sdir/english/ must exist. +# The second directory ($tdir) contains the transcripts, and the directory +# $tdir/reference must exist; in particular we need the file +# $tdir/reference/hub5e00.english.000405.stm + +if [ $# -ne 2 ]; then + echo "Usage: "`basename $0`" " + echo "See comments in the script for more details" + exit 1 +fi + +sdir=$1 +tdir=$2 +[ ! -d $sdir/english ] \ + && echo Expecting directory $sdir/english to be present && exit 1; + [ -d $tdir/2000_hub5_eng_eval_tr ] \ + && tdir=$tdir/2000_hub5_eng_eval_tr + [ ! -d $tdir/reference ] \ + && echo Expecting directory $tdir/reference to be present && exit 1; + + + dir=data/local/eval2000 + mkdir -p $dir + + find $sdir/english -iname '*.sph' | sort > $dir/sph.flist + sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ + > $dir/sph.scp + + sph2pipe=sph2pipe + [ ! -x $sph2pipe ] \ + && echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1; + + awk -v sph2pipe=$sph2pipe '{ + printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); + printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); + }' < $dir/sph.scp | sort > $dir/wav.scp || exit 1; + #side A - channel 1, side B - channel 2 + +# Get segments file... +# segments file format is: utt-id side-id start-time end-time, e.g.: +# sw02001-A_000098-001156 sw02001-A 0.98 11.56 +pem=$sdir/english/hub5e_00.pem +[ ! -f $pem ] && echo "$0: No such file $pem" && exit 1; +# pem file has lines like: +# en_4156 A unknown_speaker 301.85 302.48 + +# we ignore the warnings below for now, although they seem to indicate some problems +# with the data. +grep -v ';;' $pem \ + | awk '{ + spk=$1"-"$2; + utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); + print utt,spk,$4,$5;}' \ + | sort -u | local/extend_segments.pl 0.1 > $dir/segments + +# stm file has lines like: +# en_4156 A en_4156_A 357.64 359.64 HE IS A POLICE OFFICER +# TODO(arnab): We should really be lowercasing this since the Edinburgh +# recipe uses lowercase. This is not used in the actual scoring. +grep -v ';;' $tdir/reference/hub5e00.english.000405.stm \ + | awk '{ + spk=$1"-"$2; + utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); + printf utt; for(n=7;n<=NF;n++) printf(" %s", $n); print ""; }' \ + | sort > $dir/text.all + +# We'll use the stm file for sclite scoring. There seem to be various errors +# in the stm file that upset hubscr.pl, and we fix them here. +sed -e 's:((:(:' -e 's:::g' -e 's:::g' \ + $tdir/reference/hub5e00.english.000405.stm > $dir/stm + cp $tdir/reference/en20000405_hub5.glm $dir/glm + +# next line uses command substitution +# Just checking that the segments are the same in pem vs. stm. +! cmp <(awk '{print $1}' $dir/text.all) <(awk '{print $1}' $dir/segments) && \ + echo "Segments from pem file and stm file do not match." && exit 1; + +grep -v IGNORE_TIME_SEGMENT_ $dir/text.all > $dir/text + +# create an utt2spk file that assumes each conversation side is +# a separate speaker. +awk '{print $1,$2;}' $dir/segments > $dir/utt2spk +utils/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt + +# cp $dir/segments $dir/segments.tmp +# awk '{x=$3-0.05; if (x<0.0) x=0.0; y=$4+0.05; print $1, $2, x, y; }' \ +# $dir/segments.tmp > $dir/segments + +awk '{print $1}' $dir/wav.scp \ + | perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; + print "$1-$2 $1 $2\n"; ' \ + > $dir/reco2file_and_channel || exit 1; + + + echo Data preparation and formatting completed for Eval 2000 + echo "(but not MFCC extraction)" + + utils/fix_data_dir.sh $dir + + if [ $(wc -l < $dir/wav.scp) -ne 80 ]; then + echo "$0: error: expected 80 lines in wav.scp, got $(wc -l < $dir/wav.scp)" + exit 1; + fi diff --git a/egs/swbd/ASR/local/extend_segments.pl b/egs/swbd/ASR/local/extend_segments.pl new file mode 100755 index 0000000000..e8b4894d5f --- /dev/null +++ b/egs/swbd/ASR/local/extend_segments.pl @@ -0,0 +1,99 @@ +#!/usr/bin/env perl +use warnings; #sed replacement for -w perl parameter + +if (@ARGV != 1 || !($ARGV[0] =~ m/^-?\d+\.?\d*$/ && $ARGV[0] >= 0)) { + print STDERR "Usage: extend_segments.pl time-in-seconds segments.extended \n" . + "e.g. extend_segments.pl 0.25 segments.2\n" . + "This command modifies a segments file, with lines like\n" . + " \n" . + "by extending the beginning and end of each segment by a certain\n" . + "length of time. This script makes sure the output segments do not\n" . + "overlap as a result of this time-extension, and that there are no\n" . + "negative times in the output.\n"; + exit 1; +} + +$extend = $ARGV[0]; + +@all_lines = (); + +while () { + chop; + @A = split(" ", $_); + if (@A != 4) { + die "invalid line in segments file: $_"; + } + $line = @all_lines; # current number of lines. + ($utt_id, $reco_id, $start_time, $end_time) = @A; + + push @all_lines, [ $utt_id, $reco_id, $start_time, $end_time ]; # anonymous array. + if (! defined $lines_for_reco{$reco_id}) { + $lines_for_reco{$reco_id} = [ ]; # push new anonymous array. + } + push @{$lines_for_reco{$reco_id}}, $line; +} + +foreach $reco_id (keys %lines_for_reco) { + $ref = $lines_for_reco{$reco_id}; + @line_numbers = sort { ${$all_lines[$a]}[2] <=> ${$all_lines[$b]}[2] } @$ref; + + + { + # handle start of earliest segment as a special case. + $l0 = $line_numbers[0]; + $tstart = ${$all_lines[$l0]}[2] - $extend; + if ($tstart < 0.0) { $tstart = 0.0; } + ${$all_lines[$l0]}[2] = $tstart; + } + { + # handle end of latest segment as a special case. + $lN = $line_numbers[$#line_numbers]; + $tend = ${$all_lines[$lN]}[3] + $extend; + ${$all_lines[$lN]}[3] = $tend; + } + for ($i = 0; $i < $#line_numbers; $i++) { + $ln = $line_numbers[$i]; + $ln1 = $line_numbers[$i+1]; + $tend = ${$all_lines[$ln]}[3]; # end of earlier segment. + $tstart = ${$all_lines[$ln1]}[2]; # start of later segment. + if ($tend > $tstart) { + $utt1 = ${$all_lines[$ln]}[0]; + $utt2 = ${$all_lines[$ln1]}[0]; + print STDERR "Warning: for utterances $utt1 and $utt2, segments " . + "already overlap; leaving these times unchanged.\n"; + } else { + $my_extend = $extend; + $max_extend = 0.5 * ($tstart - $tend); + if ($my_extend > $max_extend) { $my_extend = $max_extend; } + $tend += $my_extend; + $tstart -= $my_extend; + ${$all_lines[$ln]}[3] = $tend; + ${$all_lines[$ln1]}[2] = $tstart; + } + } +} + +# leave the numbering of the lines unchanged. +for ($l = 0; $l < @all_lines; $l++) { + $ref = $all_lines[$l]; + ($utt_id, $reco_id, $start_time, $end_time) = @$ref; + printf("%s %s %.2f %.2f\n", $utt_id, $reco_id, $start_time, $end_time); +} + +__END__ + +# testing below. + +# ( echo a1 A 0 1; echo a2 A 3 4; echo b1 B 0 1; echo b2 B 2 3 ) | local/extend_segments.pl 1.0 +a1 A 0.00 2.00 +a2 A 2.00 5.00 +b1 B 0.00 1.50 +b2 B 1.50 4.00 +# ( echo a1 A 0 2; echo a2 A 1 3 ) | local/extend_segments.pl 1.0 +Warning: for utterances a1 and a2, segments already overlap; leaving these times unchanged. +a1 A 0.00 2.00 +a2 A 1.00 4.00 +# ( echo a1 A 0 2; echo a2 A 5 6; echo a3 A 3 4 ) | local/extend_segments.pl 1.0 +a1 A 0.00 2.50 +a2 A 4.50 7.00 +a3 A 2.50 4.50 diff --git a/egs/swbd/ASR/local/filter_cuts.py b/egs/swbd/ASR/local/filter_cuts.py new file mode 100755 index 0000000000..fbcc9e24af --- /dev/null +++ b/egs/swbd/ASR/local/filter_cuts.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script removes short and long utterances from a cutset. + +Caution: + You may need to tune the thresholds for your own dataset. + +Usage example: + + python3 ./local/filter_cuts.py \ + --bpe-model data/lang_bpe_500/bpe.model \ + --in-cuts data/fbank/librispeech_cuts_test-clean.jsonl.gz \ + --out-cuts data/fbank-filtered/librispeech_cuts_test-clean.jsonl.gz +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +from lhotse import CutSet, load_manifest_lazy +from lhotse.cut import Cut + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=Path, + help="Path to the bpe.model", + ) + + parser.add_argument( + "--in-cuts", + type=Path, + help="Path to the input cutset", + ) + + parser.add_argument( + "--out-cuts", + type=Path, + help="Path to the output cutset", + ) + + return parser.parse_args() + + +def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): + total = 0 # number of total utterances before removal + removed = 0 # number of removed utterances + + def remove_short_and_long_utterances(c: Cut): + """Return False to exclude the input cut""" + nonlocal removed, total + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ./display_manifest_statistics.py + # + # You should use ./display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + total += 1 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + removed += 1 + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./pruned_transducer_stateless2/conformer.py, the + # conv module uses the following expression + # for subsampling + if c.num_frames is None: + num_frames = c.duration * 100 # approximate + else: + num_frames = c.num_frames + + T = ((num_frames - 1) // 2 - 1) // 2 + # Note: for ./lstm_transducer_stateless/lstm.py, the formula is + # T = ((num_frames - 3) // 2 - 1) // 2 + + # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is + # T = ((num_frames - 7) // 2 + 1) // 2 + + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + removed += 1 + return False + + return True + + # We use to_eager() here so that we can print out the value of total + # and removed below. + ans = cut_set.filter(remove_short_and_long_utterances).to_eager() + ratio = removed / total * 100 + logging.info( + f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." + ) + return ans + + +def main(): + args = get_args() + logging.info(vars(args)) + + if args.out_cuts.is_file(): + logging.info(f"{args.out_cuts} already exists - skipping") + return + + assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist" + assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist" + + sp = spm.SentencePieceProcessor() + sp.load(str(args.bpe_model)) + + cut_set = load_manifest_lazy(args.in_cuts) + assert isinstance(cut_set, CutSet) + + cut_set = filter_cuts(cut_set, sp) + logging.info(f"Saving to {args.out_cuts}") + args.out_cuts.parent.mkdir(parents=True, exist_ok=True) + cut_set.to_file(args.out_cuts) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/swbd/ASR/local/filter_empty_text.py b/egs/swbd/ASR/local/filter_empty_text.py new file mode 100755 index 0000000000..d335abb464 --- /dev/null +++ b/egs/swbd/ASR/local/filter_empty_text.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +# Copyright 2023 The Chinese University of Hong Kong (author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from pathlib import Path +import logging +from typing import List + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--kaldi-data-dir", + type=Path, + required=True, + help="Path to the kaldi data dir", + ) + + return parser.parse_args() + + +def load_segments(path: Path): + segments = {} + with open(path, "r") as f: + lines = f.readlines() + for line in lines: + line = line.strip() + utt_id, rec_id, start, end = line.split() + segments[utt_id] = line + return segments + + +def filter_text(path: Path): + with open(path, "r") as f: + lines = f.readlines() + return list(filter(lambda x: len(x.strip().split()) > 1, lines)) + + +def write_segments(path: Path, texts: List[str]): + with open(path, "w") as f: + f.writelines(texts) + + +def main(): + args = get_args() + orig_text_dict = filter_text(args.kaldi_data_dir / "text") + write_segments(args.kaldi_data_dir / "text", orig_text_dict) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() + + logging.info("Empty lines filtered") diff --git a/egs/swbd/ASR/local/format_acronyms_dict.py b/egs/swbd/ASR/local/format_acronyms_dict.py new file mode 100755 index 0000000000..fa598dd03c --- /dev/null +++ b/egs/swbd/ASR/local/format_acronyms_dict.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 + +# Copyright 2015 Minhua Wu +# Apache 2.0 + +# convert acronyms in swbd dict to fisher convention +# IBM to i._b._m. +# BBC to b._b._c. +# BBCs to b._b._c.s +# BBC's to b._b._c.'s + +import argparse +import re + +__author__ = "Minhua Wu" + +parser = argparse.ArgumentParser(description="format acronyms to a._b._c.") +parser.add_argument("-i", "--input", help="Input lexicon", required=True) +parser.add_argument("-o", "--output", help="Output lexicon", required=True) +parser.add_argument( + "-L", "--Letter", help="Input single letter pronunciation", required=True +) +parser.add_argument("-M", "--Map", help="Output acronyms mapping", required=True) +args = parser.parse_args() + + +fin_lex = open(args.input, "r") +fin_Letter = open(args.Letter, "r") +fout_lex = open(args.output, "w") +fout_map = open(args.Map, "w") + +# Initialise single letter dictionary +dict_letter = {} +for single_letter_lex in fin_Letter: + items = single_letter_lex.split() + dict_letter[items[0]] = single_letter_lex[len(items[0]) + 1 :].strip() +fin_Letter.close() +# print dict_letter + +for lex in fin_lex: + items = lex.split() + word = items[0] + lexicon = lex[len(items[0]) + 1 :].strip() + # find acronyms from words with only letters and ' + pre_match = re.match(r"^[A-Za-z]+$|^[A-Za-z]+\'s$|^[A-Za-z]+s$", word) + if pre_match: + # find if words in the form of xxx's is acronym + if word[-2:] == "'s" and (lexicon[-1] == "s" or lexicon[-1] == "z"): + actual_word = word[:-2] + actual_lexicon = lexicon[:-2] + acronym_lexicon = "" + for w in actual_word: + acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + if acronym_lexicon.strip() == actual_lexicon: + acronym_mapped = "" + acronym_mapped_back = "" + for w in actual_word[:-1]: + acronym_mapped = acronym_mapped + w.lower() + "._" + acronym_mapped_back = acronym_mapped_back + w.lower() + " " + acronym_mapped = acronym_mapped + actual_word[-1].lower() + ".'s" + acronym_mapped_back = ( + acronym_mapped_back + actual_word[-1].lower() + "'s" + ) + fout_map.write( + word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" + ) + fout_lex.write(acronym_mapped + " " + lexicon + "\n") + else: + fout_lex.write(lex) + + # find if words in the form of xxxs is acronym + elif word[-1] == "s" and (lexicon[-1] == "s" or lexicon[-1] == "z"): + actual_word = word[:-1] + actual_lexicon = lexicon[:-2] + acronym_lexicon = "" + for w in actual_word: + acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + if acronym_lexicon.strip() == actual_lexicon: + acronym_mapped = "" + acronym_mapped_back = "" + for w in actual_word[:-1]: + acronym_mapped = acronym_mapped + w.lower() + "._" + acronym_mapped_back = acronym_mapped_back + w.lower() + " " + acronym_mapped = acronym_mapped + actual_word[-1].lower() + ".s" + acronym_mapped_back = ( + acronym_mapped_back + actual_word[-1].lower() + "'s" + ) + fout_map.write( + word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" + ) + fout_lex.write(acronym_mapped + " " + lexicon + "\n") + else: + fout_lex.write(lex) + + # find if words in the form of xxx (not ended with 's or s) is acronym + elif word.find("'") == -1 and word[-1] != "s": + acronym_lexicon = "" + for w in word: + acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + if acronym_lexicon.strip() == lexicon: + acronym_mapped = "" + acronym_mapped_back = "" + for w in word[:-1]: + acronym_mapped = acronym_mapped + w.lower() + "._" + acronym_mapped_back = acronym_mapped_back + w.lower() + " " + acronym_mapped = acronym_mapped + word[-1].lower() + "." + acronym_mapped_back = acronym_mapped_back + word[-1].lower() + fout_map.write( + word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" + ) + fout_lex.write(acronym_mapped + " " + lexicon + "\n") + else: + fout_lex.write(lex) + else: + fout_lex.write(lex) + + else: + fout_lex.write(lex) diff --git a/egs/swbd/ASR/local/generate_unique_lexicon.py b/egs/swbd/ASR/local/generate_unique_lexicon.py new file mode 100755 index 0000000000..3459c2f5ad --- /dev/null +++ b/egs/swbd/ASR/local/generate_unique_lexicon.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file takes as input a lexicon.txt and output a new lexicon, +in which each word has a unique pronunciation. + +The way to do this is to keep only the first pronunciation of a word +in lexicon.txt. +""" + + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +from icefall.lexicon import read_lexicon, write_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + This file will generate a new file uniq_lexicon.txt + in it. + """, + ) + + return parser.parse_args() + + +def filter_multiple_pronunications( + lexicon: List[Tuple[str, List[str]]] +) -> List[Tuple[str, List[str]]]: + """Remove multiple pronunciations of words from a lexicon. + + If a word has more than one pronunciation in the lexicon, only + the first one is kept, while other pronunciations are removed + from the lexicon. + + Args: + lexicon: + The input lexicon, containing a list of (word, [p1, p2, ..., pn]), + where "p1, p2, ..., pn" are the pronunciations of the "word". + Returns: + Return a new lexicon where each word has a unique pronunciation. + """ + seen = set() + ans = [] + + for word, tokens in lexicon: + if word in seen: + continue + seen.add(word) + ans.append((word, tokens)) + return ans + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + lexicon_filename = lang_dir / "lexicon.txt" + + in_lexicon = read_lexicon(lexicon_filename) + + out_lexicon = filter_multiple_pronunications(in_lexicon) + + write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon) + + logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}") + logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/swbd/ASR/local/map_acronyms_transcripts.py b/egs/swbd/ASR/local/map_acronyms_transcripts.py new file mode 100755 index 0000000000..ba02aaec34 --- /dev/null +++ b/egs/swbd/ASR/local/map_acronyms_transcripts.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +# Copyright 2015 Minhua Wu +# Apache 2.0 + +# convert acronyms in swbd transcript to fisher convention +# according to first two columns in the input acronyms mapping + +import argparse +import re + +__author__ = "Minhua Wu" + +parser = argparse.ArgumentParser(description="format acronyms to a._b._c.") +parser.add_argument("-i", "--input", help="Input transcripts", required=True) +parser.add_argument("-o", "--output", help="Output transcripts", required=True) +parser.add_argument("-M", "--Map", help="Input acronyms mapping", required=True) +args = parser.parse_args() + +fin_map = open(args.Map, "r") +dict_acronym = {} +dict_acronym_noi = {} # Mapping of acronyms without I, i +for pair in fin_map: + items = pair.split("\t") + dict_acronym[items[0]] = items[1] + dict_acronym_noi[items[0]] = items[1] +fin_map.close() +del dict_acronym_noi["I"] +del dict_acronym_noi["i"] + + +fin_trans = open(args.input, "r") +fout_trans = open(args.output, "w") +for line in fin_trans: + items = line.split() + L = len(items) + # First pass mapping to map I as part of acronym + for i in range(L): + if items[i] == "I": + x = 0 + while i - 1 - x >= 0 and re.match(r"^[A-Z]$", items[i - 1 - x]): + x += 1 + + y = 0 + while i + 1 + y < L and re.match(r"^[A-Z]$", items[i + 1 + y]): + y += 1 + + if x + y > 0: + for bias in range(-x, y + 1): + items[i + bias] = dict_acronym[items[i + bias]] + + # Second pass mapping (not mapping 'i' and 'I') + for i in range(len(items)): + if items[i] in dict_acronym_noi.keys(): + items[i] = dict_acronym_noi[items[i]] + sentence = " ".join(items[1:]) + fout_trans.write(items[0] + " " + sentence.lower() + "\n") + +fin_trans.close() +fout_trans.close() diff --git a/egs/swbd/ASR/local/prepare_lang.py b/egs/swbd/ASR/local/prepare_lang.py new file mode 100755 index 0000000000..f7d4609a95 --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lang.py @@ -0,0 +1,413 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input a lexicon file "data/lang_phone/lexicon.txt" +consisting of words and tokens (i.e., phones) and does the following: + +1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt + +2. Generate tokens.txt, the token table mapping a token to a unique integer. + +3. Generate words.txt, the word table mapping a word to a unique integer. + +4. Generate L.pt, in k2 format. It can be loaded by + + d = torch.load("L.pt") + lexicon = k2.Fsa.from_dict(d) + +5. Generate L_disambig.pt, in k2 format. +""" +import argparse +import math +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import k2 +import torch + +from icefall.lexicon import read_lexicon, write_lexicon +from icefall.utils import str2bool + +Lexicon = List[Tuple[str, List[str]]] + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + Generated files by this script are saved into this directory. + """, + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + """, + ) + + return parser.parse_args() + + +def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: + """Write a symbol to ID mapping to a file. + + Note: + No need to implement `read_mapping` as it can be done + through :func:`k2.SymbolTable.from_file`. + + Args: + filename: + Filename to save the mapping. + sym2id: + A dict mapping symbols to IDs. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf-8") as f: + for sym, i in sym2id.items(): + f.write(f"{sym} {i}\n") + + +def get_tokens(lexicon: Lexicon) -> List[str]: + """Get tokens from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique tokens. + """ + ans = set() + for _, tokens in lexicon: + ans.update(tokens) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def get_words(lexicon: Lexicon) -> List[str]: + """Get words from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique words. + """ + ans = set() + for word, _ in lexicon: + ans.add(word) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: + """It adds pseudo-token disambiguation symbols #1, #2 and so on + at the ends of tokens to ensure that all pronunciations are different, + and that none is a prefix of another. + + See also add_lex_disambig.pl from kaldi. + + Args: + lexicon: + It is returned by :func:`read_lexicon`. + Returns: + Return a tuple with two elements: + + - The output lexicon with disambiguation symbols + - The ID of the max disambiguation symbol that appears + in the lexicon + """ + + # (1) Work out the count of each token-sequence in the + # lexicon. + count = defaultdict(int) + for _, tokens in lexicon: + count[" ".join(tokens)] += 1 + + # (2) For each left sub-sequence of each token-sequence, note down + # that it exists (for identifying prefixes of longer strings). + issubseq = defaultdict(int) + for _, tokens in lexicon: + tokens = tokens.copy() + tokens.pop() + while tokens: + issubseq[" ".join(tokens)] = 1 + tokens.pop() + + # (3) For each entry in the lexicon: + # if the token sequence is unique and is not a + # prefix of another word, no disambig symbol. + # Else output #1, or #2, #3, ... if the same token-seq + # has already been assigned a disambig symbol. + ans = [] + + # We start with #1 since #0 has its own purpose + first_allowed_disambig = 1 + max_disambig = first_allowed_disambig - 1 + last_used_disambig_symbol_of = defaultdict(int) + + for word, tokens in lexicon: + tokenseq = " ".join(tokens) + assert tokenseq != "" + if issubseq[tokenseq] == 0 and count[tokenseq] == 1: + ans.append((word, tokens)) + continue + + cur_disambig = last_used_disambig_symbol_of[tokenseq] + if cur_disambig == 0: + cur_disambig = first_allowed_disambig + else: + cur_disambig += 1 + + if cur_disambig > max_disambig: + max_disambig = cur_disambig + last_used_disambig_symbol_of[tokenseq] = cur_disambig + tokenseq += f" #{cur_disambig}" + ans.append((word, tokenseq.split())) + return ans, max_disambig + + +def generate_id_map(symbols: List[str]) -> Dict[str, int]: + """Generate ID maps, i.e., map a symbol to a unique ID. + + Args: + symbols: + A list of unique symbols. + Returns: + A dict containing the mapping between symbols and IDs. + """ + return {sym: i for i, sym in enumerate(symbols)} + + +def add_self_loops( + arcs: List[List[Any]], disambig_token: int, disambig_word: int +) -> List[List[Any]]: + """Adds self-loops to states of an FST to propagate disambiguation symbols + through it. They are added on each state with non-epsilon output symbols + on at least one arc out of the state. + + See also fstaddselfloops.pl from Kaldi. One difference is that + Kaldi uses OpenFst style FSTs and it has multiple final states. + This function uses k2 style FSTs and it does not need to add self-loops + to the final state. + + The input label of a self-loop is `disambig_token`, while the output + label is `disambig_word`. + + Args: + arcs: + A list-of-list. The sublist contains + `[src_state, dest_state, label, aux_label, score]` + disambig_token: + It is the token ID of the symbol `#0`. + disambig_word: + It is the word ID of the symbol `#0`. + + Return: + Return new `arcs` containing self-loops. + """ + states_needs_self_loops = set() + for arc in arcs: + src, dst, ilabel, olabel, score = arc + if olabel != 0: + states_needs_self_loops.add(src) + + ans = [] + for s in states_needs_self_loops: + ans.append([s, s, disambig_token, disambig_word, 0]) + + return arcs + ans + + +def lexicon_to_fst( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + sil_token: str = "sil", + sil_prob: float = 0.5, + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format) with optional silence at + the beginning and end of each word. + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + sil_token: + The silence token. + sil_prob: + The probability for adding a silence at the beginning and end + of the word. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + assert sil_prob > 0.0 and sil_prob < 1.0 + # CAUTION: we use score, i.e, negative cost. + sil_score = math.log(sil_prob) + no_sil_score = math.log(1.0 - sil_prob) + + start_state = 0 + loop_state = 1 # words enter and leave from here + sil_state = 2 # words terminate here when followed by silence; this state + # has a silence transition to loop_state. + next_state = 3 # the next un-allocated state, will be incremented as we go. + arcs = [] + + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + sil_token = token2id[sil_token] + + arcs.append([start_state, loop_state, eps, eps, no_sil_score]) + arcs.append([start_state, sil_state, eps, eps, sil_score]) + arcs.append([sil_state, loop_state, sil_token, eps, 0]) + + for word, tokens in lexicon: + assert len(tokens) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + tokens = [token2id[i] for i in tokens] + + for i in range(len(tokens) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, tokens[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last token of this word + # It has two out-going arcs, one to the loop state, + # the other one to the sil_state. + i = len(tokens) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) + arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + lexicon_filename = lang_dir / "lexicon.txt" + sil_token = "sil" + sil_prob = 0.5 + + lexicon = read_lexicon(lexicon_filename) + tokens = get_tokens(lexicon) + words = get_words(lexicon) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in tokens + tokens.append(f"#{i}") + + assert "" not in tokens + tokens = [""] + tokens + + assert "" not in words + assert "#0" not in words + assert "" not in words + assert "" not in words + + words = [""] + words + ["#0", "", ""] + + token2id = generate_id_map(tokens) + word2id = generate_id_map(words) + + write_mapping(lang_dir / "tokens.txt", token2id) + write_mapping(lang_dir / "words.txt", word2id) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst( + lexicon, + token2id=token2id, + word2id=word2id, + sil_token=sil_token, + sil_prob=sil_prob, + ) + + L_disambig = lexicon_to_fst( + lexicon_disambig, + token2id=token2id, + word2id=word2id, + sil_token=sil_token, + sil_prob=sil_prob, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/prepare_lang_bpe.py b/egs/swbd/ASR/local/prepare_lang_bpe.py new file mode 100755 index 0000000000..f5b3a25b34 --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lang_bpe.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +""" + +This script takes as input `lang_dir`, which should contain:: + + - lang_dir/bpe.model, + - lang_dir/words.txt + +and generates the following files in the directory `lang_dir`: + + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - tokens.txt +""" + +import argparse +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import sentencepiece as spm +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, + write_mapping, +) + +from icefall.utils import str2bool + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + loop_state = 0 # words enter and leave from here + next_state = 1 # the next un-allocated state, will be incremented as we go + + arcs = [] + + # The blank symbol is defined in local/train_bpe_model.py + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + for word, pieces in lexicon: + assert len(pieces) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + pieces = [token2id[i] for i in pieces] + + for i in range(len(pieces) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, pieces[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last piece of this word + i = len(pieces) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, pieces[i], w, 0]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def generate_lexicon( + model_file: str, words: List[str], oov: str +) -> Tuple[Lexicon, Dict[str, int]]: + """Generate a lexicon from a BPE model. + + Args: + model_file: + Path to a sentencepiece model. + words: + A list of strings representing words. + oov: + The out of vocabulary word in lexicon. + Returns: + Return a tuple with two elements: + - A dict whose keys are words and values are the corresponding + word pieces. + - A dict representing the token symbol, mapping from tokens to IDs. + """ + sp = spm.SentencePieceProcessor() + sp.load(str(model_file)) + + # Convert word to word piece IDs instead of word piece strings + # to avoid OOV tokens. + words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) + + # Now convert word piece IDs back to word piece strings. + words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] + + lexicon = [] + for word, pieces in zip(words, words_pieces): + lexicon.append((word, pieces)) + + lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())])) + + token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} + + return lexicon, token2id + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + parser.add_argument( + "--oov", + type=str, + default="", + help="The out of vocabulary word in lexicon.", + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + + See "test/test_bpe_lexicon.py" for usage. + """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + model_file = lang_dir / "bpe.model" + + word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + + words = word_sym_table.symbols + + excluded = [ + "", + "!sil", + "", + args.oov, + "#0", + "", + "", + "[vocalized-noise]", + "[noise]", + "[laughter]", + ] + + for w in excluded: + if w in words: + words.remove(w) + + lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + next_token_id = max(token_sym_table.values()) + 1 + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table[disambig] = next_token_id + next_token_id += 1 + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + write_mapping(lang_dir / "tokens.txt", token_sym_table) + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/prepare_lm_training_data.py b/egs/swbd/ASR/local/prepare_lm_training_data.py new file mode 100755 index 0000000000..70343fef7c --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lm_training_data.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey +# Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script takes a `bpe.model` and a text file such as +./download/lm/librispeech-lm-norm.txt +and outputs the LM training data to a supplied directory such +as data/lm_training_bpe_500. The format is as follows: + +It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a +representation of a dict with the following format: + + 'words' -> a k2.RaggedTensor of two axes [word][token] with dtype torch.int32 + containing the BPE representations of each word, indexed by + integer word ID. (These integer word IDS are present in + 'lm_data'). The sentencepiece object can be used to turn the + words and BPE units into string form. + 'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype + torch.int32 containing all the sentences, as word-ids (we don't + output the string form of this directly but it can be worked out + together with 'words' and the bpe.model). + 'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing + number of BPE tokens of each sentence. +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import sentencepiece as spm +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--bpe-model", + type=str, + help="Input BPE model, e.g. data/bpe_500/bpe.model", + ) + parser.add_argument( + "--lm-data", + type=str, + help="""Input LM training data as text, e.g. + download/pb.train.txt""", + ) + parser.add_argument( + "--lm-archive", + type=str, + help="""Path to output archive, e.g. data/bpe_500/lm_data.pt; + look at the source of this script to see the format.""", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + if Path(args.lm_archive).exists(): + logging.warning(f"{args.lm_archive} exists - skipping") + return + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + # word2index is a dictionary from words to integer ids. No need to reserve + # space for epsilon, etc.; the words are just used as a convenient way to + # compress the sequences of BPE pieces. + word2index = dict() + + word2bpe = [] # Will be a list-of-list-of-int, representing BPE pieces. + sentences = [] # Will be a list-of-list-of-int, representing word-ids. + + if "librispeech-lm-norm" in args.lm_data: + num_lines_in_total = 40418261.0 + step = 5000000 + elif "valid" in args.lm_data: + num_lines_in_total = 5567.0 + step = 3000 + elif "test" in args.lm_data: + num_lines_in_total = 5559.0 + step = 3000 + else: + num_lines_in_total = None + step = None + + processed = 0 + + with open(args.lm_data) as f: + while True: + line = f.readline() + if line == "": + break + + if step and processed % step == 0: + logging.info( + f"Processed number of lines: {processed} " + f"({processed/num_lines_in_total*100: .3f}%)" + ) + processed += 1 + + line_words = line.split() + for w in line_words: + if w not in word2index: + w_bpe = sp.encode(w) + word2index[w] = len(word2bpe) + word2bpe.append(w_bpe) + sentences.append([word2index[w] for w in line_words]) + + logging.info("Constructing ragged tensors") + words = k2.ragged.RaggedTensor(word2bpe) + sentences = k2.ragged.RaggedTensor(sentences) + + output = dict(words=words, sentences=sentences) + + num_sentences = sentences.dim0 + logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}") + sentence_lengths = [0] * num_sentences + for i in range(num_sentences): + if step and i % step == 0: + logging.info( + f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)" + ) + + word_ids = sentences[i] + + # NOTE: If word_ids is a tensor with only 1 entry, + # token_ids is a torch.Tensor + token_ids = words[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + + # token_ids is a 1-D tensor containing the BPE tokens + # of the current sentence + + sentence_lengths[i] = token_ids.numel() + + output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) + + torch.save(output, args.lm_archive) + logging.info(f"Saved to {args.lm_archive}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/swbd/ASR/local/prepare_swbd_testsets.py b/egs/swbd/ASR/local/prepare_swbd_testsets.py new file mode 100755 index 0000000000..62ad0fda82 --- /dev/null +++ b/egs/swbd/ASR/local/prepare_swbd_testsets.py @@ -0,0 +1,161 @@ +import argparse +import logging +import os +from pathlib import Path +import tarfile +from itertools import chain +from typing import Dict, Optional, Union + +from lhotse import fix_manifests, validate_recordings_and_supervisions +from lhotse.audio import Recording, RecordingSet +from lhotse.supervision import SupervisionSegment, SupervisionSet +from lhotse.utils import Pathlike, check_and_rglob, resumable_download, safe_extract + +import sentencepiece as spm +from filter_cuts import filter_cuts +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + + +def prepare_switchboard( + audio_dir: Pathlike, + transcripts_dir: Optional[Pathlike] = None, + sentiment_dir: Optional[Pathlike] = None, + output_dir: Optional[Pathlike] = None, + omit_silence: bool = True, + absolute_paths: bool = False, +) -> Dict[str, Union[RecordingSet, SupervisionSet]]: + """ + Prepare manifests for the Switchboard corpus. + We create two manifests: one with recordings, and the other one with text supervisions. + When ``sentiment_dir`` is provided, we create another supervision manifest with sentiment annotations. + :param audio_dir: Path to ``LDC97S62`` package. + :param transcripts_dir: Path to the transcripts directory (typically named "swb_ms98_transcriptions"). + If not provided, the transcripts will be downloaded. + :param sentiment_dir: Optional path to ``LDC2020T14`` package which contains sentiment annotations + for SWBD segments. + :param output_dir: Directory where the manifests should be written. Can be omitted to avoid writing. + :param omit_silence: Whether supervision segments with ``[silence]`` token should be removed or kept. + :param absolute_paths: Whether to return absolute or relative (to the corpus dir) paths for recordings. + :return: A dict with manifests. The keys are: ``{'recordings', 'supervisions'}``. + """ + audio_paths = check_and_rglob(audio_dir, "*.sph") + text_paths = check_and_rglob(transcripts_dir, "*trans.text") + + groups = [] + name_to_text = {p.stem.split("-")[0]: p for p in text_paths} + for ap in audio_paths: + name = ap.stem.replace("sw0", "sw") + groups.append( + { + "audio": ap, + "text-0": name_to_text[f"{name}A"], + "text-1": name_to_text[f"{name}B"], + } + ) + + recordings = RecordingSet.from_recordings( + Recording.from_file( + group["audio"], relative_path_depth=None if absolute_paths else 3 + ) + for group in groups + ) + supervisions = SupervisionSet.from_segments( + chain.from_iterable( + make_segments( + transcript_path=group[f"text-{channel}"], + recording=recording, + channel=channel, + omit_silence=omit_silence, + ) + for group, recording in zip(groups, recordings) + for channel in [0, 1] + ) + ) + + recordings, supervisions = fix_manifests(recordings, supervisions) + validate_recordings_and_supervisions(recordings, supervisions) + + if sentiment_dir is not None: + parse_and_add_sentiment_labels(sentiment_dir, supervisions) + + if output_dir is not None: + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + recordings.to_file(output_dir / "swbd_recordings_all.jsonl.gz") + supervisions.to_file(output_dir / "swbd_supervisions_all.jsonl.gz") + return {"recordings": recordings, "supervisions": supervisions} + + +def make_segments( + transcript_path: Path, recording: Recording, channel: int, omit_silence: bool = True +): + lines = transcript_path.read_text().splitlines() + return [ + SupervisionSegment( + id=segment_id, + recording_id=recording.id, + start=float(start), + duration=round(float(end) - float(start), ndigits=8), + channel=channel, + text=" ".join(words), + language="English", + speaker=f"{recording.id}A", + ) + for segment_id, start, end, *words in map(str.split, lines) + if words[0] != "[silence]" or not omit_silence + ] + + +def download_and_untar( + target_dir: Pathlike = ".", force_download: bool = False, url: str = SWBD_TEXT_URL +) -> Path: + target_dir = Path(target_dir) + transcript_dir = target_dir / "swb_ms98_transcriptions" + if transcript_dir.is_dir(): + return transcript_dir + target_dir.mkdir(parents=True, exist_ok=True) + tar_name = "switchboard_word_alignments.tar.gz" + tar_path = target_dir / tar_name + resumable_download(url, filename=tar_path, force_download=force_download) + with tarfile.open(tar_path) as tar: + safe_extract(tar, path=target_dir) + return transcript_dir + + +def parse_and_add_sentiment_labels( + sentiment_dir: Pathlike, supervisions: SupervisionSet +): + """Read 'LDC2020T14' sentiment annotations and add then to the supervision segments.""" + import pandas as pd + + # Sanity checks + sentiment_dir = Path(sentiment_dir) + labels = sentiment_dir / "data" / "sentiment_labels.tsv" + assert sentiment_dir.is_dir() and labels.is_file() + # Read the TSV as a dataframe + df = pd.read_csv(labels, delimiter="\t", names=["id", "start", "end", "sentiment"]) + # We are going to match the segments in LDC2020T14 with the ones we already + # parsed from ISIP transcripts. We simply look which of the existing segments + # fall into a sentiment-annotated time span. When doing it this way, we use + # 51773 out of 52293 available sentiment annotations, which should be okay. + for _, row in df.iterrows(): + call_id = row["id"].split("_")[0] + matches = list( + supervisions.find( + recording_id=call_id, + start_after=row["start"] - 1e-2, + end_before=row["end"] + 1e-2, + ) + ) + if not matches: + continue + labels = row["sentiment"].split("#") + # SupervisionSegments returned from .find() are references to the ones in the + # SupervisionSet, so we can just modify them. We use the "custom" field + # to add the sentiment label. Since there were multiple annotators, + # we add all available labels and leave it up to the user to disambiguate them. + for segment in matches: + segment.custom = {f"sentiment{i}": label for i, label in enumerate(labels)} diff --git a/egs/swbd/ASR/local/rt03_data_prep.sh b/egs/swbd/ASR/local/rt03_data_prep.sh new file mode 100755 index 0000000000..1545b7809a --- /dev/null +++ b/egs/swbd/ASR/local/rt03_data_prep.sh @@ -0,0 +1,107 @@ +#!/usr/bin/env bash + +# RT-03 data preparation (conversational telephone speech part only) +# Adapted from Arnab Ghoshal's script for Hub-5 Eval 2000 by Peng Qi + +# To be run from one directory above this script. + +# Expects the standard directory layout for RT-03 + +if [ $# -ne 1 ]; then + echo "Usage: $0 " + echo "e.g.: $0 /export/corpora/LDC/LDC2007S10" + echo "See comments in the script for more details" + exit 1 +fi + +sdir=$1 +[ ! -d $sdir/data/audio/eval03/english/cts ] \ + && echo Expecting directory $sdir/data/audio/eval03/english/cts to be present && exit 1; + [ ! -d $sdir/data/references/eval03/english/cts ] \ + && echo Expecting directory $tdir/data/references/eval03/english/cts to be present && exit 1; + + dir=data/local/rt03 + mkdir -p $dir + + rtroot=$sdir + tdir=$sdir/data/references/eval03/english/cts + sdir=$sdir/data/audio/eval03/english/cts + + find -L $sdir -iname '*.sph' | sort > $dir/sph.flist + sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ + > $dir/sph.scp + + sph2pipe=sph2pipe + ! command -v "${sph2pipe}" &> /dev/null \ + && echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1; + + awk -v sph2pipe=$sph2pipe '{ + printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); + printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); +}' < $dir/sph.scp | sort > $dir/wav.scp || exit 1; +#side A - channel 1, side B - channel 2 + +# Get segments file... +# segments file format is: utt-id side-id start-time end-time, e.g.: +# sw02001-A_000098-001156 sw02001-A 0.98 11.56 +#pem=$sdir/english/hub5e_00.pem +#[ ! -f $pem ] && echo "No such file $pem" && exit 1; +# pem file has lines like: +# en_4156 A unknown_speaker 301.85 302.48 + +#grep -v ';;' $pem \ +cat $tdir/*.stm | grep -v ';;' | grep -v inter_segment_gap \ + | awk '{ + spk=$1"-"(($2==1)?"A":"B"); + utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); + print utt,spk,$4,$5;}' \ + | sort -u > $dir/segments + +# stm file has lines like: +# en_4156 A en_4156_A 357.64 359.64 HE IS A POLICE OFFICER +# TODO(arnab): We should really be lowercasing this since the Edinburgh +# recipe uses lowercase. This is not used in the actual scoring. +#grep -v ';;' $tdir/reference/hub5e00.english.000405.stm \ +cat $tdir/*.stm | grep -v ';;' | grep -v inter_segment_gap \ + | awk '{ + spk=$1"-"(($2==1)?"A":"B"); + utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); + printf utt; for(n=7;n<=NF;n++) printf(" %s", $n); print ""; }' \ + | sort > $dir/text.all + +# We'll use the stm file for sclite scoring. There seem to be various errors +# in the stm file that upset hubscr.pl, and we fix them here. +cat $tdir/*.stm | \ + sed -e 's:((:(:' -e 's:::g' -e 's:::g' | \ + grep -v inter_segment_gap | \ + awk '{ + printf $1; if ($1==";;") printf(" %s",$2); else printf(($2==1)?" A":" B"); for(n=3;n<=NF;n++) printf(" %s", $n); print ""; }'\ + > $dir/stm + #$tdir/reference/hub5e00.english.000405.stm > $dir/stm + cp $rtroot/data/trans_rules/en20030506.glm $dir/glm + +# next line uses command substitution +# Just checking that the segments are the same in pem vs. stm. +! cmp <(awk '{print $1}' $dir/text.all) <(awk '{print $1}' $dir/segments) && \ + echo "Segments from pem file and stm file do not match." && exit 1; + +grep -v IGNORE_TIME_SEGMENT_ $dir/text.all > $dir/text + +# create an utt2spk file that assumes each conversation side is +# a separate speaker. +awk '{print $1,$2;}' $dir/segments > $dir/utt2spk +utils/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt + +# cp $dir/segments $dir/segments.tmp +# awk '{x=$3-0.05; if (x<0.0) x=0.0; y=$4+0.05; print $1, $2, x, y; }' \ +# $dir/segments.tmp > $dir/segments + +awk '{print $1}' $dir/wav.scp \ + | perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; + print "$1-$2 $1 $2\n"; ' \ + > $dir/reco2file_and_channel || exit 1; + + ./utils/fix_data_dir.sh $dir + + echo Data preparation and formatting completed for RT-03 + echo "(but not MFCC extraction)" diff --git a/egs/swbd/ASR/local/sort_lm_training_data.py b/egs/swbd/ASR/local/sort_lm_training_data.py new file mode 100755 index 0000000000..bed3856e44 --- /dev/null +++ b/egs/swbd/ASR/local/sort_lm_training_data.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file takes as input the filename of LM training data +generated by ./local/prepare_lm_training_data.py and sorts +it by sentence length. + +Sentence length equals to the number of BPE tokens in a sentence. +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import numpy as np +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--in-lm-data", + type=str, + help="Input LM training data, e.g., data/bpe_500/lm_data.pt", + ) + + parser.add_argument( + "--out-lm-data", + type=str, + help="Input LM training data, e.g., data/bpe_500/sorted_lm_data.pt", + ) + + parser.add_argument( + "--out-statistics", + type=str, + help="Statistics about LM training data., data/bpe_500/statistics.txt", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + in_lm_data = Path(args.in_lm_data) + out_lm_data = Path(args.out_lm_data) + assert in_lm_data.is_file(), f"{in_lm_data}" + if out_lm_data.is_file(): + logging.warning(f"{out_lm_data} exists - skipping") + return + data = torch.load(in_lm_data) + words2bpe = data["words"] + sentences = data["sentences"] + sentence_lengths = data["sentence_lengths"] + + num_sentences = sentences.dim0 + assert num_sentences == sentence_lengths.numel(), ( + num_sentences, + sentence_lengths.numel(), + ) + + indices = torch.argsort(sentence_lengths, descending=True) + + sorted_sentences = sentences[indices.to(torch.int32)] + sorted_sentence_lengths = sentence_lengths[indices] + + # Check that sentences are ordered by length + assert num_sentences == sorted_sentences.dim0, ( + num_sentences, + sorted_sentences.dim0, + ) + + cur = None + for i in range(num_sentences): + word_ids = sorted_sentences[i] + token_ids = words2bpe[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + if cur is not None: + assert cur >= token_ids.numel(), (cur, token_ids.numel()) + + cur = token_ids.numel() + assert cur == sorted_sentence_lengths[i] + + data["sentences"] = sorted_sentences + data["sentence_lengths"] = sorted_sentence_lengths + torch.save(data, args.out_lm_data) + logging.info(f"Saved to {args.out_lm_data}") + + statistics = Path(args.out_statistics) + + # Write statistics + num_words = sorted_sentences.numel() + num_tokens = sentence_lengths.sum().item() + max_sentence_length = sentence_lengths[indices[0]] + min_sentence_length = sentence_lengths[indices[-1]] + + step = 10 + hist, bins = np.histogram( + sentence_lengths.numpy(), + bins=np.arange(1, max_sentence_length + step, step), + ) + + histogram = np.stack((bins[:-1], hist)).transpose() + + with open(statistics, "w") as f: + f.write(f"num_sentences: {num_sentences}\n") + f.write(f"num_words: {num_words}\n") + f.write(f"num_tokens: {num_tokens}\n") + f.write(f"max_sentence_length: {max_sentence_length}\n") + f.write(f"min_sentence_length: {min_sentence_length}\n") + f.write("histogram:\n") + f.write(" bin count percent\n") + for row in histogram: + f.write( + f"{int(row[0]):>5} {int(row[1]):>5} " + f"{100.*row[1]/num_sentences:.3f}%\n" + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/local/swbd1_data_prep.sh b/egs/swbd/ASR/local/swbd1_data_prep.sh new file mode 100755 index 0000000000..0027403540 --- /dev/null +++ b/egs/swbd/ASR/local/swbd1_data_prep.sh @@ -0,0 +1,132 @@ +#!/usr/bin/env bash + +# Switchboard-1 training data preparation customized for Edinburgh +# Author: Arnab Ghoshal (Jan 2013) + +# To be run from one directory above this script. + +## The input is some directory containing the switchboard-1 release 2 +## corpus (LDC97S62). Note: we don't make many assumptions about how +## you unpacked this. We are just doing a "find" command to locate +## the .sph files. + +## The second input is optional, which should point to a directory containing +## Switchboard transcriptions/documentations (specifically, the conv.tab file). +## If specified, the script will try to use the actual speaker PINs provided +## with the corpus instead of the conversation side ID (Kaldi default). We +## will be using "find" to locate this file so we don't make any assumptions +## on the directory structure. (Peng Qi, Aug 2014) + + +#check existing directories +if [ $# != 1 -a $# != 2 ]; then + echo "Usage: swbd1_data_prep.sh /path/to/SWBD [/path/to/SWBD_DOC]" + exit 1; +fi + +SWBD_DIR=$1 + +dir=data/local/train +mkdir -p $dir + + +# Audio data directory check +if [ ! -d $SWBD_DIR ]; then + echo "Error: run.sh requires a directory argument" + exit 1; +fi + +sph2pipe=sph2pipe +! command -v "${sph2pipe}" &> /dev/null \ + && echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1; + +# Option A: SWBD dictionary file check +[ ! -f ./swb_ms98_transcriptions/sw-ms98-dict.text ] && \ + echo "SWBD dictionary file does not exist" && exit 1; + +# find sph audio files +find -L $SWBD_DIR -iname '*.sph' | sort > $dir/sph.flist + +n=`cat $dir/sph.flist | wc -l` +[ $n -ne 2435 ] && [ $n -ne 2438 ] && \ + echo Warning: expected 2435 or 2438 data data files, found $n + + +# (1a) Transcriptions preparation +# make basic transcription file (add segments info) +# **NOTE: In the default Kaldi recipe, everything is made uppercase, while we +# make everything lowercase here. This is because we will be using SRILM which +# can optionally make everything lowercase (but not uppercase) when mapping +# LM vocabs. +awk '{ +name=substr($1,1,6); gsub("^sw","sw0",name); side=substr($1,7,1); +stime=$2; etime=$3; +printf("%s-%s_%06.0f-%06.0f", +name, side, int(100*stime+0.5), int(100*etime+0.5)); +for(i=4;i<=NF;i++) printf(" %s", $i); printf "\n" +}' ./swb_ms98_transcriptions/*/*/*-trans.text > $dir/transcripts1.txt + +# test if trans. file is sorted +export LC_ALL=C; +sort -c $dir/transcripts1.txt || exit 1; # check it's sorted. + +# Remove SILENCE, and . + +# Note: we have [NOISE], [VOCALIZED-NOISE], [LAUGHTER], [SILENCE]. +# removing [SILENCE], and the and markers that mark +# speech to somone; we will give phones to the other three (NSN, SPN, LAU). +# There will also be a silence phone, SIL. +# **NOTE: modified the pattern matches to make them case insensitive +cat $dir/transcripts1.txt \ + | perl -ane 's:\s\[SILENCE\](\s|$):$1:gi; + s///gi; + s///gi; + print;' \ + | awk '{if(NF > 1) { print; } } ' > $dir/transcripts2.txt + + +# **NOTE: swbd1_map_words.pl has been modified to make the pattern matches +# case insensitive +local/swbd1_map_words.pl -f 2- $dir/transcripts2.txt > $dir/text + +# format acronyms in text +python3 local/map_acronyms_transcripts.py -i $dir/text -o $dir/text_map \ + -M data/local/dict_nosp/acronyms.map + mv $dir/text_map $dir/text + +# (1c) Make segment files from transcript +#segments file format is: utt-id side-id start-time end-time, e.g.: +#sw02001-A_000098-001156 sw02001-A 0.98 11.56 +awk '{ +segment=$1; +split(segment,S,"[_-]"); +side=S[2]; audioname=S[1]; startf=S[3]; endf=S[4]; +print segment " " audioname "-" side " " startf/100 " " endf/100 +}' < $dir/text > $dir/segments + +sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ + > $dir/sph.scp + +awk -v sph2pipe=$sph2pipe '{ +printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); +printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); +}' < $dir/sph.scp | sort > $dir/wav.scp || exit 1; +#side A - channel 1, side B - channel 2 + +# this file reco2file_and_channel maps recording-id (e.g. sw02001-A) +# to the file name sw02001 and the A, e.g. +# sw02001-A sw02001 A +# In this case it's trivial, but in other corpora the information might +# be less obvious. Later it will be needed for ctm scoring. +awk '{print $1}' $dir/wav.scp \ + | perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; + print "$1-$2 $1 $2\n"; ' \ + > $dir/reco2file_and_channel || exit 1; + + awk '{spk=substr($1,1,9); print $1 " " spk}' $dir/segments > $dir/utt2spk \ + || exit 1; + sort -k 2 $dir/utt2spk | utils/utt2spk_to_spk2utt.pl > $dir/spk2utt || exit 1; + + echo Switchboard-1 data preparation succeeded. + + utils/fix_data_dir.sh data/local/train diff --git a/egs/swbd/ASR/local/swbd1_map_words.pl b/egs/swbd/ASR/local/swbd1_map_words.pl new file mode 100755 index 0000000000..4fb8d4ffe7 --- /dev/null +++ b/egs/swbd/ASR/local/swbd1_map_words.pl @@ -0,0 +1,52 @@ +#!/usr/bin/env perl + +# Modified from swbd_map_words.pl in Kaldi s5 recipe to make pattern +# matches case-insensitive --Arnab (Jan 2013) + +if ($ARGV[0] eq "-f") { + shift @ARGV; + $field_spec = shift @ARGV; + if ($field_spec =~ m/^\d+$/) { + $field_begin = $field_spec - 1; $field_end = $field_spec - 1; + } + if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesy (properly, 1-10) + if ($1 ne "") { + $field_begin = $1 - 1; # Change to zero-based indexing. + } + if ($2 ne "") { + $field_end = $2 - 1; # Change to zero-based indexing. + } + } + if (!defined $field_begin && !defined $field_end) { + die "Bad argument to -f option: $field_spec"; + } +} + + +while (<>) { + @A = split(" ", $_); + for ($n = 0; $n < @A; $n++) { + $a = $A[$n]; + if ( (!defined $field_begin || $n >= $field_begin) + && (!defined $field_end || $n <= $field_end)) { + # e.g. [LAUGHTER-STORY] -> STORY; + $a =~ s:(|\-)^\[LAUGHTER-(.+)\](|\-)$:$1$2$3:i; + # $1 and $3 relate to preserving trailing "-" + $a =~ s:^\[(.+)/.+\](|\-)$:$1$2:; # e.g. [IT'N/ISN'T] -> IT'N ... note, + # 1st part may include partial-word stuff, which we process further below, + # e.g. [LEM[GUINI]-/LINGUINI] + # the (|\_) at the end is to accept and preserve trailing -'s. + $a =~ s:^(|\-)\[[^][]+\](.+)$:-$2:; # e.g. -[AN]Y , note \047 is quote; + # let the leading - be optional on input, as sometimes omitted. + $a =~ s:^(.+)\[[^][]+\](|\-)$:$1-:; # e.g. AB[SOLUTE]- -> AB-; + # let the trailing - be optional on input, as sometimes omitted. + $a =~ s:([^][]+)\[.+\]$:$1:; # e.g. EX[SPECIALLY]-/ESPECIALLY] -> EX- + # which is a mistake in the input. + $a =~ s:^\{(.+)\}$:$1:; # e.g. {YUPPIEDOM} -> YUPPIEDOM + $a =~ s:[A-Z]\[([^][])+\][A-Z]:$1-$3:i; # e.g. AMMU[N]IT- -> AMMU-IT- + $a =~ s:_\d$::; # e.g. THEM_1 -> THEM + } + $A[$n] = $a; + } + print join(" ", @A) . "\n"; +} diff --git a/egs/swbd/ASR/local/swbd1_prepare_dict.sh b/egs/swbd/ASR/local/swbd1_prepare_dict.sh new file mode 100755 index 0000000000..0c38a72dce --- /dev/null +++ b/egs/swbd/ASR/local/swbd1_prepare_dict.sh @@ -0,0 +1,92 @@ +#!/usr/bin/env bash + +# Formatting the Mississippi State dictionary for use in Edinburgh. Differs +# from the one in Kaldi s5 recipe in that it uses lower-case --Arnab (Jan 2013) + +# To be run from one directory above this script. + + +#check existing directories +[ $# != 0 ] && echo "Usage: local/swbd1_data_prep.sh" && exit 1; + +srcdir=. # This is where we downloaded some stuff.. +dir=./data/local/dict_nosp +mkdir -p $dir +srcdict=$srcdir/swb_ms98_transcriptions/sw-ms98-dict.text + +# assume swbd_p1_data_prep.sh was done already. +[ ! -f "$srcdict" ] && echo "$0: No such file $srcdict" && exit 1; + +cp $srcdict $dir/lexicon0.txt || exit 1; +chmod a+w $dir/lexicon0.txt +patch 0' | sort > $dir/lexicon1.txt || exit 1; + +cat $dir/lexicon1.txt | awk '{ for(n=2;n<=NF;n++){ phones[$n] = 1; }} END{for (p in phones) print p;}' | \ + grep -v sil > $dir/nonsilence_phones.txt || exit 1; + +( echo sil; echo spn; echo nsn; echo lau ) > $dir/silence_phones.txt + +echo sil > $dir/optional_silence.txt + +# No "extra questions" in the input to this setup, as we don't +# have stress or tone. +echo -n >$dir/extra_questions.txt + +cp local/MSU_single_letter.txt $dir/ +# Add to the lexicon the silences, noises etc. +# Add single letter lexicon +# The original swbd lexicon does not have precise single letter lexicion +# e.g. it does not have entry of W +( echo '!sil sil'; echo '[vocalized-noise] spn'; echo '[noise] nsn'; \ + echo '[laughter] lau'; echo ' spn' ) \ + | cat - $dir/lexicon1.txt $dir/MSU_single_letter.txt > $dir/lexicon2.txt || exit 1; + +# Map the words in the lexicon. That is-- for each word in the lexicon, we map it +# to a new written form. The transformations we do are: +# remove laughter markings, e.g. +# [LAUGHTER-STORY] -> STORY +# Remove partial-words, e.g. +# -[40]1K W AH N K EY +# becomes -1K +# and +# -[AN]Y IY +# becomes +# -Y +# -[A]B[OUT]- B +# becomes +# -B- +# Also, curly braces, which appear to be used for "nonstandard" +# words or non-words, are removed, e.g. +# {WOLMANIZED} W OW L M AX N AY Z D +# -> WOLMANIZED +# Also, mispronounced words, e.g. +# [YEAM/YEAH] Y AE M +# are changed to just e.g. YEAM, i.e. the orthography +# of the mispronounced version. +# Note-- this is only really to be used in training. The main practical +# reason is to avoid having tons of disambiguation symbols, which +# we otherwise would get because there are many partial words with +# the same phone sequences (most problematic: S). +# Also, map +# THEM_1 EH M -> THEM +# so that multiple pronunciations just have alternate entries +# in the lexicon. + +local/swbd1_map_words.pl -f 1 $dir/lexicon2.txt | sort -u \ + > $dir/lexicon3.txt || exit 1; + +python3 local/format_acronyms_dict.py -i $dir/lexicon3.txt -o $dir/lexicon4.txt \ + -L $dir/MSU_single_letter.txt -M $dir/acronyms_raw.map + cat $dir/acronyms_raw.map | sort -u > $dir/acronyms.map + + ( echo 'i ay' )| cat - $dir/lexicon4.txt | tr '[A-Z]' '[a-z]' | sort -u > $dir/lexicon5.txt + + pushd $dir >&/dev/null + ln -sf lexicon5.txt lexicon.txt # This is the final lexicon. + popd >&/dev/null + rm $dir/lexiconp.txt 2>/dev/null + echo Prepared input dictionary and phone-sets for Switchboard phase 1. diff --git a/egs/swbd/ASR/local/train_bpe_model.py b/egs/swbd/ASR/local/train_bpe_model.py new file mode 100755 index 0000000000..43142aee4c --- /dev/null +++ b/egs/swbd/ASR/local/train_bpe_model.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +import argparse +import shutil +from pathlib import Path + +import sentencepiece as spm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + The generated bpe.model is saved to this directory. + """, + ) + + parser.add_argument( + "--transcript", + type=str, + help="Training transcript.", + ) + + parser.add_argument( + "--vocab-size", + type=int, + help="Vocabulary size for BPE training", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + vocab_size = args.vocab_size + lang_dir = Path(args.lang_dir) + + model_type = "unigram" + + model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + train_text = args.transcript + character_coverage = 1.0 + input_sentence_size = 100000000 + + user_defined_symbols = ["", ""] + unk_id = len(user_defined_symbols) + # Note: unk_id is fixed to 2. + # If you change it, you should also change other + # places that are using it. + + model_file = Path(model_prefix + ".model") + if not model_file.is_file(): + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=character_coverage, + user_defined_symbols=user_defined_symbols, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) + else: + print(f"{model_file} exists - skipping") + return + + shutil.copyfile(model_file, f"{lang_dir}/bpe.model") + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/validate_bpe_lexicon.py b/egs/swbd/ASR/local/validate_bpe_lexicon.py new file mode 100755 index 0000000000..c542f2fabd --- /dev/null +++ b/egs/swbd/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks that there are no OOV tokens in the BPE-based lexicon. + +Usage example: + + python3 ./local/validate_bpe_lexicon.py \ + --lexicon /path/to/lexicon.txt \ + --bpe-model /path/to/bpe.model +""" + +import argparse +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm + +from icefall.lexicon import read_lexicon + +# Map word to word pieces +Lexicon = List[Tuple[str, List[str]]] + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--lexicon", + required=True, + type=Path, + help="Path to lexicon.txt", + ) + + parser.add_argument( + "--bpe-model", + required=True, + type=Path, + help="Path to bpe.model", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + assert args.lexicon.is_file(), args.lexicon + assert args.bpe_model.is_file(), args.bpe_model + + lexicon = read_lexicon(args.lexicon) + + sp = spm.SentencePieceProcessor() + sp.load(str(args.bpe_model)) + + word_pieces = set(sp.id_to_piece(list(range(sp.vocab_size())))) + for word, pieces in lexicon: + for p in pieces: + if p not in word_pieces: + raise ValueError(f"The word {word} contains an OOV token {p}") + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh new file mode 100755 index 0000000000..ae0b278b15 --- /dev/null +++ b/egs/swbd/ASR/prepare.sh @@ -0,0 +1,403 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=15 +stage=-1 +stop_stage=100 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +dl_dir=./download +swbd1_dir="/export/corpora3/LDC/LDC97S62" +eval2000_dir="/export/corpora2/LDC/LDC2002S09/hub5e_00" +eval2000_ref_dir="/export/corpora2/LDC/LDC2002T43" +rt03_dir="/export/corpora/LDC/LDC2007S10" + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + 5000 + 2000 + 1000 + 500 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "swbd1_dir: $swbd1_dir" +log "eval2000_dir: $eval2000_dir $eval2000_ref_dir" +log "rt03_dir: $rt03_dir" + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare SwitchBoard manifest" + # We assume that you have downloaded the SwitchBoard corpus + # to respective dirs + mkdir -p data/manifests + if [ ! -e data/manifests/.swbd.done ]; then + lhotse prepare switchboard --absolute-paths True $swbd1_dir data/manifests_train + ./local/swbd1_prepare_dict.sh + ./local/swbd1_data_prep.sh $swbd1_dir + lhotse kaldi import data/local/train 8000 data/manifests_train + mv data/manifests_train/recordings.jsonl.gz data/manifests_train/swbd_recordings_all.jsonl.gz + mv data/manifests_train/supervisions.jsonl.gz data/manifests_train/swbd_supervisions_all.jsonl.gz + + ./local/eval2000_data_prep.sh $eval2000_dir $eval2000_ref_dir + ./local/rt03_data_prep.sh $rt03_dir + + # normalize eval2000 and rt03 texts by + # 1) convert upper to lower + # 2) remove tags (%AH) (%HESITATION) (%UH) + # 3) remove + # 4) remove "(" or ")" + for x in eval2000 rt03; do + cp data/local/${x}/text data/local/${x}/text.org + paste -d "" \ + <(cut -f 1 -d" " data/local/${x}/text.org) \ + <(awk '{$1=""; print tolower($0)}' data/local/${x}/text.org | perl -pe 's| \(\%.*\)||g' | perl -pe 's| \<.*\>||g' | sed -e "s/(//g" -e "s/)//g") \ + | sed -e 's/\s\+/ /g' > data/local/${x}/text + rm data/local/${x}/text.org + done + + python ./local/filter_empty_text.py --kaldi-data-dir data/local/eval2000 + ./utils/fix_data_dir.sh data/local/eval2000 + lhotse kaldi import data/local/eval2000 8000 data/manifests_eval2000 + mv data/manifests_eval2000/recordings.jsonl.gz data/manifests_eval2000/swbd_recordings_eval2000.jsonl.gz + mv data/manifests_eval2000/supervisions.jsonl.gz data/manifests_eval2000/swbd_supervisions_eval2000.jsonl.gz + + python ./local/filter_empty_text.py --kaldi-data-dir data/local/rt03 + ./utils/fix_data_dir.sh data/local/rt03 + lhotse kaldi import data/local/rt03 8000 data/manifests_rt03 + mv data/manifests_rt03/recordings.jsonl.gz data/manifests_rt03/swbd_recordings_rt03.jsonl.gz + mv data/manifests_rt03/supervisions.jsonl.gz data/manifests_rt03/swbd_supervisions_rt03.jsonl.gz + + lhotse fix data/manifests_train/swbd_recordings_all.jsonl.gz data/manifests_train/swbd_supervisions_all.jsonl.gz data/manifests + lhotse fix data/manifests_eval2000/swbd_recordings_eval2000.jsonl.gz data/manifests_eval2000/swbd_supervisions_eval2000.jsonl.gz data/manifests + lhotse fix data/manifests_rt03/swbd_recordings_rt03.jsonl.gz data/manifests_rt03/swbd_supervisions_rt03.jsonl.gz data/manifests + + touch data/manifests/.swbd.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to $dl_dir/musan + mkdir -p data/manifests + if [ ! -e data/manifests/.musan.done ]; then + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute fbank for switchboard" + mkdir -p data/fbank + if [ ! -e data/fbank/.swbd.done ]; then + ./local/compute_fbank_swbd.py + touch data/fbank/.swbd.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + mkdir -p data/fbank + if [ ! -e data/fbank/.musan.done ]; then + ./local/compute_fbank_musan.py + touch data/fbank/.musan.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare phone based lang" + lang_dir=data/lang_phone + mkdir -p $lang_dir + + log "prepare dict" + cut -f 2- -d" " data/local/train/text >${lang_dir}/input.txt + # [noise] nsn + # !sil sil + # spn + cat data/local/dict_nosp/lexicon.txt | + sort | uniq >$lang_dir/lexicon.txt + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang.py --lang-dir $lang_dir + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt $lang_dir + + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + + cat ./data/local/train/text | cut -d " " -f 2- >$lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Prepare bigram token-level P for MMI training" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/transcript_tokens.txt ]; then + ./local/convert_transcript_words_to_tokens.py \ + --lexicon $lang_dir/lexicon.txt \ + --transcript $lang_dir/transcript_words.txt \ + --oov "" \ + >$lang_dir/transcript_tokens.txt + fi + + if [ ! -f $lang_dir/P.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order 2 \ + -text $lang_dir/transcript_tokens.txt \ + -lm $lang_dir/P.arpa + fi + + if [ ! -f $lang_dir/P.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=2 \ + $lang_dir/P.arpa >$lang_dir/P.fst.txt + fi + done +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Prepare G" + lang_dir=data/lang_phone + # We assume you have install kaldilm, if not, please install + # it using: pip install kaldilm + + mkdir -p data/lm + if [ ! -f data/lm/G_3_gram.fst.txt ]; then + # It is used in building HLG + ./shared/make_kn_lm.py \ + -ngram-order 3 \ + -text ${lang_dir}/input.txt \ + -lm data/lm/3-gram.arpa + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/lm/3-gram.arpa >data/lm/G_3_gram.fst.txt + fi + + if [ ! -f data/lm/G_4_gram.fst.txt ]; then + # It is used for LM rescoring + ./shared/make_kn_lm.py \ + -ngram-order 4 \ + -text ${lang_dir}/input.txt \ + -lm data/lm/4-gram.arpa + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + data/lm/4-gram.arpa >data/lm/G_4_gram.fst.txt + fi +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Compile HLG" + ./local/compile_hlg.py --lang-dir data/lang_phone + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir + done +fi + +# Compile LG for RNN-T fast_beam_search decoding +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Compile LG" + ./local/compile_lg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done +fi + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Generate LM training data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + lang_dir=data/lang_bpe_${vocab_size} + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data data/lang_phone/input.txt \ + --lm-archive $out_dir/lm_data.pt + done +fi + +# if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then +# log "Stage 12: Generate LM validation data" + +# for vocab_size in ${vocab_sizes[@]}; do +# log "Processing vocab_size == ${vocab_size}" +# out_dir=data/lm_training_bpe_${vocab_size} +# mkdir -p $out_dir + +# if [ ! -f $out_dir/valid.txt ]; then +# files=$( +# find "$dl_dir/LibriSpeech/dev-clean" -name "*.trans.txt" +# find "$dl_dir/LibriSpeech/dev-other" -name "*.trans.txt" +# ) +# for f in ${files[@]}; do +# cat $f | cut -d " " -f 2- +# done >$out_dir/valid.txt +# fi + +# lang_dir=data/lang_bpe_${vocab_size} +# ./local/prepare_lm_training_data.py \ +# --bpe-model $lang_dir/bpe.model \ +# --lm-data $out_dir/valid.txt \ +# --lm-archive $out_dir/lm_data-valid.pt +# done +# fi + +if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then + log "Stage 13: Generate LM test data" + testsets=(eval2000 rt03) + + for testset in ${testsets[@]}; do + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/test.txt ]; then + cat data/local/${testset}/text | cut -d " " -f 2- >$out_dir/${testset}.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/${testset}.txt \ + --lm-archive $out_dir/lm_data-${testset}.pt + done + done +fi + +if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then + log "Stage 14: Sort LM training data" + testsets=(eval2000 rt03) + # Sort LM training data by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of BPE tokens + # in a sentence. + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + for testset in ${testsets[@]}; do + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-${testset}.pt \ + --out-lm-data $out_dir/sorted_lm_data-${testset}.pt \ + --out-statistics $out_dir/statistics-test-${testset}.txt + done + done +fi diff --git a/egs/swbd/ASR/shared b/egs/swbd/ASR/shared new file mode 120000 index 0000000000..4c5e91438c --- /dev/null +++ b/egs/swbd/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/swbd/ASR/utils/filter_scp.pl b/egs/swbd/ASR/utils/filter_scp.pl new file mode 100755 index 0000000000..b76d37f41b --- /dev/null +++ b/egs/swbd/ASR/utils/filter_scp.pl @@ -0,0 +1,87 @@ +#!/usr/bin/env perl +# Copyright 2010-2012 Microsoft Corporation +# Johns Hopkins University (author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This script takes a list of utterance-ids or any file whose first field +# of each line is an utterance-id, and filters an scp +# file (or any file whose "n-th" field is an utterance id), printing +# out only those lines whose "n-th" field is in id_list. The index of +# the "n-th" field is 1, by default, but can be changed by using +# the -f switch + +$exclude = 0; +$field = 1; +$shifted = 0; + +do { + $shifted=0; + if ($ARGV[0] eq "--exclude") { + $exclude = 1; + shift @ARGV; + $shifted=1; + } + if ($ARGV[0] eq "-f") { + $field = $ARGV[1]; + shift @ARGV; shift @ARGV; + $shifted=1 + } +} while ($shifted); + +if(@ARGV < 1 || @ARGV > 2) { + die "Usage: filter_scp.pl [--exclude] [-f ] id_list [in.scp] > out.scp \n" . + "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . + "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . + "only the lines that were *not* in id_list.\n" . + "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . + "If your older scripts (written before Oct 2014) stopped working and you used the\n" . + "-f option, add 1 to the argument.\n" . + "See also: utils/filter_scp.pl .\n"; +} + + +$idlist = shift @ARGV; +open(F, "<$idlist") || die "Could not open id-list file $idlist"; +while() { + @A = split; + @A>=1 || die "Invalid id-list file line $_"; + $seen{$A[0]} = 1; +} + +if ($field == 1) { # Treat this as special case, since it is common. + while(<>) { + $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; + # $1 is what we filter on. + if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { + print $_; + } + } +} else { + while(<>) { + @A = split; + @A > 0 || die "Invalid scp file line $_"; + @A >= $field || die "Invalid scp file line $_"; + if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { + print $_; + } + } +} + +# tests: +# the following should print "foo 1" +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) +# the following should print "bar 2". +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) diff --git a/egs/swbd/ASR/utils/fix_data_dir.sh b/egs/swbd/ASR/utils/fix_data_dir.sh new file mode 100755 index 0000000000..ca0972ca85 --- /dev/null +++ b/egs/swbd/ASR/utils/fix_data_dir.sh @@ -0,0 +1,197 @@ +#!/bin/bash + +# This script makes sure that only the segments present in +# all of "feats.scp", "wav.scp" [if present], segments [if present] +# text, and utt2spk are present in any of them. +# It puts the original contents of data-dir into +# data-dir/.backup + +cmd="$@" + +utt_extra_files= +spk_extra_files= + +. utils/parse_options.sh + +if [ $# != 1 ]; then + echo "Usage: utils/data/fix_data_dir.sh " + echo "e.g.: utils/data/fix_data_dir.sh data/train" + echo "This script helps ensure that the various files in a data directory" + echo "are correctly sorted and filtered, for example removing utterances" + echo "that have no features (if feats.scp is present)" + exit 1 +fi + +data=$1 + +if [ -f $data/images.scp ]; then + image/fix_data_dir.sh $cmd + exit $? +fi + +mkdir -p $data/.backup + +[ ! -d $data ] && echo "$0: no such directory $data" && exit 1; + +[ ! -f $data/utt2spk ] && echo "$0: no such file $data/utt2spk" && exit 1; + +set -e -o pipefail -u + +tmpdir=$(mktemp -d /tmp/kaldi.XXXX); +trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM + +export LC_ALL=C + +function check_sorted { + file=$1 + sort -k1,1 -u <$file >$file.tmp + if ! cmp -s $file $file.tmp; then + echo "$0: file $1 is not in sorted order or not unique, sorting it" + mv $file.tmp $file + else + rm $file.tmp + fi +} + +for x in utt2spk spk2utt feats.scp text segments wav.scp cmvn.scp vad.scp \ + reco2file_and_channel spk2gender utt2lang utt2uniq utt2dur reco2dur utt2num_frames; do + if [ -f $data/$x ]; then + cp $data/$x $data/.backup/$x + check_sorted $data/$x + fi +done + + +function filter_file { + filter=$1 + file_to_filter=$2 + cp $file_to_filter ${file_to_filter}.tmp + utils/filter_scp.pl $filter ${file_to_filter}.tmp > $file_to_filter + if ! cmp ${file_to_filter}.tmp $file_to_filter >&/dev/null; then + length1=$(cat ${file_to_filter}.tmp | wc -l) + length2=$(cat ${file_to_filter} | wc -l) + if [ $length1 -ne $length2 ]; then + echo "$0: filtered $file_to_filter from $length1 to $length2 lines based on filter $filter." + fi + fi + rm $file_to_filter.tmp +} + +function filter_recordings { + # We call this once before the stage when we filter on utterance-id, and once + # after. + + if [ -f $data/segments ]; then + # We have a segments file -> we need to filter this and the file wav.scp, and + # reco2file_and_utt, if it exists, to make sure they have the same list of + # recording-ids. + + if [ ! -f $data/wav.scp ]; then + echo "$0: $data/segments exists but not $data/wav.scp" + exit 1; + fi + awk '{print $2}' < $data/segments | sort | uniq > $tmpdir/recordings + n1=$(cat $tmpdir/recordings | wc -l) + [ ! -s $tmpdir/recordings ] && \ + echo "Empty list of recordings (bad file $data/segments)?" && exit 1; + utils/filter_scp.pl $data/wav.scp $tmpdir/recordings > $tmpdir/recordings.tmp + mv $tmpdir/recordings.tmp $tmpdir/recordings + + + cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments + filter_file $tmpdir/recordings $data/segments + cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments + rm $data/segments.tmp + + filter_file $tmpdir/recordings $data/wav.scp + [ -f $data/reco2file_and_channel ] && filter_file $tmpdir/recordings $data/reco2file_and_channel + [ -f $data/reco2dur ] && filter_file $tmpdir/recordings $data/reco2dur + true + fi +} + +function filter_speakers { + # throughout this program, we regard utt2spk as primary and spk2utt as derived, so... + utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt + + cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers + for s in cmvn.scp spk2gender; do + f=$data/$s + if [ -f $f ]; then + filter_file $f $tmpdir/speakers + fi + done + + filter_file $tmpdir/speakers $data/spk2utt + utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk + + for s in cmvn.scp spk2gender $spk_extra_files; do + f=$data/$s + if [ -f $f ]; then + filter_file $tmpdir/speakers $f + fi + done +} + +function filter_utts { + cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts + + ! cat $data/utt2spk | sort | cmp - $data/utt2spk && \ + echo "utt2spk is not in sorted order (fix this yourself)" && exit 1; + + ! cat $data/utt2spk | sort -k2 | cmp - $data/utt2spk && \ + echo "utt2spk is not in sorted order when sorted first on speaker-id " && \ + echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1; + + ! cat $data/spk2utt | sort | cmp - $data/spk2utt && \ + echo "spk2utt is not in sorted order (fix this yourself)" && exit 1; + + if [ -f $data/utt2uniq ]; then + ! cat $data/utt2uniq | sort | cmp - $data/utt2uniq && \ + echo "utt2uniq is not in sorted order (fix this yourself)" && exit 1; + fi + + maybe_wav= + maybe_reco2dur= + [ ! -f $data/segments ] && maybe_wav=wav.scp # wav indexed by utts only if segments does not exist. + [ -s $data/reco2dur ] && [ ! -f $data/segments ] && maybe_reco2dur=reco2dur # reco2dur indexed by utts + for x in feats.scp text segments utt2lang $maybe_wav; do + if [ -f $data/$x ]; then + utils/filter_scp.pl $data/$x $tmpdir/utts > $tmpdir/utts.tmp + mv $tmpdir/utts.tmp $tmpdir/utts + fi + done + [ ! -s $tmpdir/utts ] && echo "fix_data_dir.sh: no utterances remained: not proceeding further." && \ + rm $tmpdir/utts && exit 1; + + + if [ -f $data/utt2spk ]; then + new_nutts=$(cat $tmpdir/utts | wc -l) + old_nutts=$(cat $data/utt2spk | wc -l) + if [ $new_nutts -ne $old_nutts ]; then + echo "fix_data_dir.sh: kept $new_nutts utterances out of $old_nutts" + else + echo "fix_data_dir.sh: kept all $old_nutts utterances." + fi + fi + + for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $maybe_reco2dur $utt_extra_files; do + if [ -f $data/$x ]; then + cp $data/$x $data/.backup/$x + if ! cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) ; then + utils/filter_scp.pl $tmpdir/utts $data/.backup/$x > $data/$x + fi + fi + done + +} + +filter_recordings +filter_speakers +filter_utts +filter_speakers +filter_recordings + +utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt + +echo "fix_data_dir.sh: old files are kept in $data/.backup" diff --git a/egs/swbd/ASR/utils/parse_options.sh b/egs/swbd/ASR/utils/parse_options.sh new file mode 100755 index 0000000000..34476fdb37 --- /dev/null +++ b/egs/swbd/ASR/utils/parse_options.sh @@ -0,0 +1,97 @@ +#!/bin/bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### No we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. diff --git a/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl b/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl new file mode 100755 index 0000000000..23992f25de --- /dev/null +++ b/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl @@ -0,0 +1,27 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +while(<>){ + @A = split(" ", $_); + @A > 1 || die "Invalid line in spk2utt file: $_"; + $s = shift @A; + foreach $u ( @A ) { + print "$u $s\n"; + } +} + + diff --git a/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl b/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl new file mode 100755 index 0000000000..6e0e438ca6 --- /dev/null +++ b/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl @@ -0,0 +1,38 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# converts an utt2spk file to a spk2utt file. +# Takes input from the stdin or from a file argument; +# output goes to the standard out. + +if ( @ARGV > 1 ) { + die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt"; +} + +while(<>){ + @A = split(" ", $_); + @A == 2 || die "Invalid line in utt2spk file: $_"; + ($u,$s) = @A; + if(!$seen_spk{$s}) { + $seen_spk{$s} = 1; + push @spklist, $s; + } + push (@{$spk_hash{$s}}, "$u"); +} +foreach $s (@spklist) { + $l = join(' ',@{$spk_hash{$s}}); + print "$s $l\n"; +} From 35984ca22420d43cb0582a682424c8780f439456 Mon Sep 17 00:00:00 2001 From: JinZr Date: Mon, 26 Jun 2023 17:27:42 +0800 Subject: [PATCH 02/60] removed unsed scripts --- egs/swbd/ASR/local/prepare_swbd_testsets.py | 161 -------------------- 1 file changed, 161 deletions(-) delete mode 100755 egs/swbd/ASR/local/prepare_swbd_testsets.py diff --git a/egs/swbd/ASR/local/prepare_swbd_testsets.py b/egs/swbd/ASR/local/prepare_swbd_testsets.py deleted file mode 100755 index 62ad0fda82..0000000000 --- a/egs/swbd/ASR/local/prepare_swbd_testsets.py +++ /dev/null @@ -1,161 +0,0 @@ -import argparse -import logging -import os -from pathlib import Path -import tarfile -from itertools import chain -from typing import Dict, Optional, Union - -from lhotse import fix_manifests, validate_recordings_and_supervisions -from lhotse.audio import Recording, RecordingSet -from lhotse.supervision import SupervisionSegment, SupervisionSet -from lhotse.utils import Pathlike, check_and_rglob, resumable_download, safe_extract - -import sentencepiece as spm -from filter_cuts import filter_cuts -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - - -def prepare_switchboard( - audio_dir: Pathlike, - transcripts_dir: Optional[Pathlike] = None, - sentiment_dir: Optional[Pathlike] = None, - output_dir: Optional[Pathlike] = None, - omit_silence: bool = True, - absolute_paths: bool = False, -) -> Dict[str, Union[RecordingSet, SupervisionSet]]: - """ - Prepare manifests for the Switchboard corpus. - We create two manifests: one with recordings, and the other one with text supervisions. - When ``sentiment_dir`` is provided, we create another supervision manifest with sentiment annotations. - :param audio_dir: Path to ``LDC97S62`` package. - :param transcripts_dir: Path to the transcripts directory (typically named "swb_ms98_transcriptions"). - If not provided, the transcripts will be downloaded. - :param sentiment_dir: Optional path to ``LDC2020T14`` package which contains sentiment annotations - for SWBD segments. - :param output_dir: Directory where the manifests should be written. Can be omitted to avoid writing. - :param omit_silence: Whether supervision segments with ``[silence]`` token should be removed or kept. - :param absolute_paths: Whether to return absolute or relative (to the corpus dir) paths for recordings. - :return: A dict with manifests. The keys are: ``{'recordings', 'supervisions'}``. - """ - audio_paths = check_and_rglob(audio_dir, "*.sph") - text_paths = check_and_rglob(transcripts_dir, "*trans.text") - - groups = [] - name_to_text = {p.stem.split("-")[0]: p for p in text_paths} - for ap in audio_paths: - name = ap.stem.replace("sw0", "sw") - groups.append( - { - "audio": ap, - "text-0": name_to_text[f"{name}A"], - "text-1": name_to_text[f"{name}B"], - } - ) - - recordings = RecordingSet.from_recordings( - Recording.from_file( - group["audio"], relative_path_depth=None if absolute_paths else 3 - ) - for group in groups - ) - supervisions = SupervisionSet.from_segments( - chain.from_iterable( - make_segments( - transcript_path=group[f"text-{channel}"], - recording=recording, - channel=channel, - omit_silence=omit_silence, - ) - for group, recording in zip(groups, recordings) - for channel in [0, 1] - ) - ) - - recordings, supervisions = fix_manifests(recordings, supervisions) - validate_recordings_and_supervisions(recordings, supervisions) - - if sentiment_dir is not None: - parse_and_add_sentiment_labels(sentiment_dir, supervisions) - - if output_dir is not None: - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - recordings.to_file(output_dir / "swbd_recordings_all.jsonl.gz") - supervisions.to_file(output_dir / "swbd_supervisions_all.jsonl.gz") - return {"recordings": recordings, "supervisions": supervisions} - - -def make_segments( - transcript_path: Path, recording: Recording, channel: int, omit_silence: bool = True -): - lines = transcript_path.read_text().splitlines() - return [ - SupervisionSegment( - id=segment_id, - recording_id=recording.id, - start=float(start), - duration=round(float(end) - float(start), ndigits=8), - channel=channel, - text=" ".join(words), - language="English", - speaker=f"{recording.id}A", - ) - for segment_id, start, end, *words in map(str.split, lines) - if words[0] != "[silence]" or not omit_silence - ] - - -def download_and_untar( - target_dir: Pathlike = ".", force_download: bool = False, url: str = SWBD_TEXT_URL -) -> Path: - target_dir = Path(target_dir) - transcript_dir = target_dir / "swb_ms98_transcriptions" - if transcript_dir.is_dir(): - return transcript_dir - target_dir.mkdir(parents=True, exist_ok=True) - tar_name = "switchboard_word_alignments.tar.gz" - tar_path = target_dir / tar_name - resumable_download(url, filename=tar_path, force_download=force_download) - with tarfile.open(tar_path) as tar: - safe_extract(tar, path=target_dir) - return transcript_dir - - -def parse_and_add_sentiment_labels( - sentiment_dir: Pathlike, supervisions: SupervisionSet -): - """Read 'LDC2020T14' sentiment annotations and add then to the supervision segments.""" - import pandas as pd - - # Sanity checks - sentiment_dir = Path(sentiment_dir) - labels = sentiment_dir / "data" / "sentiment_labels.tsv" - assert sentiment_dir.is_dir() and labels.is_file() - # Read the TSV as a dataframe - df = pd.read_csv(labels, delimiter="\t", names=["id", "start", "end", "sentiment"]) - # We are going to match the segments in LDC2020T14 with the ones we already - # parsed from ISIP transcripts. We simply look which of the existing segments - # fall into a sentiment-annotated time span. When doing it this way, we use - # 51773 out of 52293 available sentiment annotations, which should be okay. - for _, row in df.iterrows(): - call_id = row["id"].split("_")[0] - matches = list( - supervisions.find( - recording_id=call_id, - start_after=row["start"] - 1e-2, - end_before=row["end"] + 1e-2, - ) - ) - if not matches: - continue - labels = row["sentiment"].split("#") - # SupervisionSegments returned from .find() are references to the ones in the - # SupervisionSet, so we can just modify them. We use the "custom" field - # to add the sentiment label. Since there were multiple annotators, - # we add all available labels and leave it up to the user to disambiguate them. - for segment in matches: - segment.custom = {f"sentiment{i}": label for i, label in enumerate(labels)} From 96738b538aef89ecabbe69d155863fb1903c7f53 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Mon, 26 Jun 2023 19:21:46 +0800 Subject: [PATCH 03/60] fixed formatting issues --- egs/swbd/ASR/conformer_ctc/ali.py | 395 ------------------- egs/swbd/ASR/conformer_ctc/asr_datamodule.py | 13 +- egs/swbd/ASR/conformer_ctc/decode.py | 9 +- egs/swbd/ASR/conformer_ctc/train.py | 1 + egs/swbd/ASR/local/filter_empty_text.py | 1 + 5 files changed, 20 insertions(+), 399 deletions(-) delete mode 100755 egs/swbd/ASR/conformer_ctc/ali.py diff --git a/egs/swbd/ASR/conformer_ctc/ali.py b/egs/swbd/ASR/conformer_ctc/ali.py deleted file mode 100755 index 42e14abac9..0000000000 --- a/egs/swbd/ASR/conformer_ctc/ali.py +++ /dev/null @@ -1,395 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: - ./conformer_ctc/ali.py \ - --exp-dir ./conformer_ctc/exp \ - --lang-dir ./data/lang_bpe_500 \ - --epoch 20 \ - --avg 10 \ - --max-duration 300 \ - --dataset train-clean-100 \ - --out-dir data/ali -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import numpy as np -import torch -from asr_datamodule import LibriSpeechAsrDataModule -from conformer import Conformer -from lhotse import CutSet -from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer - -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.decode import one_best_decoding -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - encode_supervisions, - get_alignments, - setup_logger, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=34, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_bpe_500", - help="The lang dir", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_ctc/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--out-dir", - type=str, - required=True, - help="""Output directory. - It contains 3 generated files: - - - labels_xxx.h5 - - aux_labels_xxx.h5 - - librispeech_cuts_xxx.jsonl.gz - - where xxx is the value of `--dataset`. For instance, if - `--dataset` is `train-clean-100`, it will contain 3 files: - - - `labels_train-clean-100.h5` - - `aux_labels_train-clean-100.h5` - - `librispeech_cuts_train-clean-100.jsonl.gz` - - Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise - alignment. The difference is that labels_xxx.h5 contains repeats. - """, - ) - - parser.add_argument( - "--dataset", - type=str, - required=True, - help="""The name of the dataset to compute alignments for. - Possible values are: - - test-clean. - - test-other - - train-clean-100 - - train-clean-360 - - train-other-500 - - dev-clean - - dev-other - """, - ) - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "lm_dir": Path("data/lm"), - "feature_dim": 80, - "nhead": 8, - "attention_dim": 512, - "subsampling_factor": 4, - # Set it to 0 since attention decoder - # is not used for computing alignments - "num_decoder_layers": 0, - "vgg_frontend": False, - "use_feat_batchnorm": True, - "output_beam": 10, - "use_double_scores": True, - "env_info": get_env_info(), - } - ) - return params - - -def compute_alignments( - model: torch.nn.Module, - dl: torch.utils.data.DataLoader, - labels_writer: FeaturesWriter, - aux_labels_writer: FeaturesWriter, - params: AttributeDict, - graph_compiler: BpeCtcTrainingGraphCompiler, -) -> CutSet: - """Compute the framewise alignments of a dataset. - - Args: - model: - The neural network model. - dl: - Dataloader containing the dataset. - params: - Parameters for computing alignments. - graph_compiler: - It converts token IDs to decoding graphs. - Returns: - Return a CutSet. Each cut has two custom fields: labels_alignment - and aux_labels_alignment, containing framewise alignments information. - Both are of type `lhotse.array.TemporalArray`. The difference between - the two alignments is that `labels_alignment` contain repeats. - """ - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - num_cuts = 0 - - device = graph_compiler.device - cuts = [] - for batch_idx, batch in enumerate(dl): - feature = batch["inputs"] - - # at entry, feature is [N, T, C] - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - cut_list = supervisions["cut"] - - for cut in cut_list: - assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}" - - nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - # nnet_output is [N, T, C] - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - # we need also to sort cut_ids as encode_supervisions() - # reorders "texts". - # In general, new2old is an identity map since lhotse sorts the returned - # cuts by duration in descending order - new2old = supervision_segments[:, 0].tolist() - - cut_list = [cut_list[i] for i in new2old] - - token_ids = graph_compiler.texts_to_ids(texts) - decoding_graph = graph_compiler.compile(token_ids) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - lattice = k2.intersect_dense( - decoding_graph, - dense_fsa_vec, - params.output_beam, - ) - - best_path = one_best_decoding( - lattice=lattice, - use_double_scores=params.use_double_scores, - ) - - labels_ali = get_alignments(best_path, kind="labels") - aux_labels_ali = get_alignments(best_path, kind="aux_labels") - assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) - for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali): - cut.labels_alignment = labels_writer.store_array( - key=cut.id, - value=np.asarray(labels, dtype=np.int32), - # frame shift is 0.01s, subsampling_factor is 4 - frame_shift=0.04, - temporal_dim=0, - start=0, - ) - cut.aux_labels_alignment = aux_labels_writer.store_array( - key=cut.id, - value=np.asarray(aux_labels, dtype=np.int32), - # frame shift is 0.01s, subsampling_factor is 4 - frame_shift=0.04, - temporal_dim=0, - start=0, - ) - - cuts += cut_list - - num_cuts += len(cut_list) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - - return CutSet.from_cuts(cuts) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - args.enable_spec_aug = False - args.enable_musan = False - args.return_cuts = True - args.concatenate_cuts = False - - params = get_params() - params.update(vars(args)) - - setup_logger(f"{params.exp_dir}/log-ali") - - logging.info(f"Computing alignments for {params.dataset} - started") - logging.info(params) - - out_dir = Path(params.out_dir) - out_dir.mkdir(exist_ok=True) - - out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" - out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" - out_manifest_filename = out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" - - for f in ( - out_labels_ali_filename, - out_aux_labels_ali_filename, - out_manifest_filename, - ): - if f.exists(): - logging.info(f"{f} exists - skipping") - return - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - logging.info(f"device: {device}") - - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) - - logging.info("About to create model") - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - model.to(device) - - if params.avg == 1: - load_checkpoint( - f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False - ) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - - model.eval() - - librispeech = LibriSpeechAsrDataModule(args) - if params.dataset == "test-clean": - test_clean_cuts = librispeech.test_clean_cuts() - dl = librispeech.test_dataloaders(test_clean_cuts) - elif params.dataset == "test-other": - test_other_cuts = librispeech.test_other_cuts() - dl = librispeech.test_dataloaders(test_other_cuts) - elif params.dataset == "train-clean-100": - train_clean_100_cuts = librispeech.train_clean_100_cuts() - dl = librispeech.train_dataloaders(train_clean_100_cuts) - elif params.dataset == "train-clean-360": - train_clean_360_cuts = librispeech.train_clean_360_cuts() - dl = librispeech.train_dataloaders(train_clean_360_cuts) - elif params.dataset == "train-other-500": - train_other_500_cuts = librispeech.train_other_500_cuts() - dl = librispeech.train_dataloaders(train_other_500_cuts) - elif params.dataset == "dev-clean": - dev_clean_cuts = librispeech.dev_clean_cuts() - dl = librispeech.valid_dataloaders(dev_clean_cuts) - else: - assert params.dataset == "dev-other", f"{params.dataset}" - dev_other_cuts = librispeech.dev_other_cuts() - dl = librispeech.valid_dataloaders(dev_other_cuts) - - logging.info(f"Processing {params.dataset}") - with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer: - with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer: - cut_set = compute_alignments( - model=model, - dl=dl, - labels_writer=labels_writer, - aux_labels_writer=aux_labels_writer, - params=params, - graph_compiler=graph_compiler, - ) - - cut_set.to_file(out_manifest_filename) - - logging.info( - f"For dataset {params.dataset}, its alignments with repeats are " - f"saved to {out_labels_ali_filename}, the alignments without repeats " - f"are saved to {out_aux_labels_ali_filename}, and the cut manifest " - f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}" - ) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py index c36b8727f5..c45f2cbd0b 100644 --- a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -1,5 +1,6 @@ # Copyright 2021 Piotr Żelasko # Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Modified by Zengrui Jin for the SwitchBoard corpus # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -390,12 +391,20 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: @lru_cache() def train_all_cuts(self) -> CutSet: logging.info("switchboard: About to get train cuts") - return load_manifest_lazy(self.args.manifest_dir / "swbd_cuts_all.jsonl.gz").subset(last=2388).trim_to_supervisions(keep_all_channels=True) + return ( + load_manifest_lazy(self.args.manifest_dir / "swbd_cuts_all.jsonl.gz") + .subset(last=2388) + .trim_to_supervisions(keep_all_channels=True) + ) @lru_cache() def dev_cuts(self) -> CutSet: logging.info("switchboard: About to get dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "swbd_cuts_all.jsonl.gz").subset(first=50).trim_to_supervisions(keep_all_channels=True) + return ( + load_manifest_lazy(self.args.manifest_dir / "swbd_cuts_all.jsonl.gz") + .subset(first=50) + .trim_to_supervisions(keep_all_channels=True) + ) @lru_cache() def test_eval2000_cuts(self) -> CutSet: diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py index a97936af84..f805becd02 100755 --- a/egs/swbd/ASR/conformer_ctc/decode.py +++ b/egs/swbd/ASR/conformer_ctc/decode.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) +# Modified by Zengrui Jin for the SwitchBoard corpus # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -779,8 +780,12 @@ def main(): args.return_cuts = True switchboard = SwitchBoardAsrDataModule(args) - test_eval2000_cuts = switchboard.test_eval2000_cuts().trim_to_supervisions(keep_all_channels=True) - test_rt03_cuts = switchboard.test_rt03_cuts().trim_to_supervisions(keep_all_channels=True) + test_eval2000_cuts = switchboard.test_eval2000_cuts().trim_to_supervisions( + keep_all_channels=True + ) + test_rt03_cuts = switchboard.test_rt03_cuts().trim_to_supervisions( + keep_all_channels=True + ) test_eval2000_dl = switchboard.test_dataloaders(test_eval2000_cuts) test_rt03_dl = switchboard.test_dataloaders(test_rt03_cuts) diff --git a/egs/swbd/ASR/conformer_ctc/train.py b/egs/swbd/ASR/conformer_ctc/train.py index 0c795868e3..370fa0f774 100755 --- a/egs/swbd/ASR/conformer_ctc/train.py +++ b/egs/swbd/ASR/conformer_ctc/train.py @@ -2,6 +2,7 @@ # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, # Wei Kang # Mingshuang Luo) +# Modified by Zengrui Jin for the SwitchBoard corpus # # See ../../../../LICENSE for clarification regarding multiple authors # diff --git a/egs/swbd/ASR/local/filter_empty_text.py b/egs/swbd/ASR/local/filter_empty_text.py index d335abb464..6b3316800f 100755 --- a/egs/swbd/ASR/local/filter_empty_text.py +++ b/egs/swbd/ASR/local/filter_empty_text.py @@ -20,6 +20,7 @@ import logging from typing import List + def get_args(): parser = argparse.ArgumentParser() From 439855e3f32068fc245cf57b67f168368e2a7ad6 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Mon, 26 Jun 2023 19:32:50 +0800 Subject: [PATCH 04/60] Fixed formatting issues in bash scripts. --- egs/swbd/ASR/local/eval2000_data_prep.sh | 90 ++++++++++++------------ egs/swbd/ASR/local/rt03_data_prep.sh | 86 +++++++++++----------- egs/swbd/ASR/local/swbd1_data_prep.sh | 62 ++++++++-------- egs/swbd/ASR/local/swbd1_prepare_dict.sh | 53 ++++++++------ egs/swbd/ASR/prepare.sh | 82 ++++++++++----------- 5 files changed, 188 insertions(+), 185 deletions(-) diff --git a/egs/swbd/ASR/local/eval2000_data_prep.sh b/egs/swbd/ASR/local/eval2000_data_prep.sh index f881983286..a5bd3ebcf9 100755 --- a/egs/swbd/ASR/local/eval2000_data_prep.sh +++ b/egs/swbd/ASR/local/eval2000_data_prep.sh @@ -21,100 +21,98 @@ # $tdir/reference/hub5e00.english.000405.stm if [ $# -ne 2 ]; then - echo "Usage: "`basename $0`" " + echo "Usage: "$(basename $0)" " echo "See comments in the script for more details" exit 1 fi sdir=$1 tdir=$2 -[ ! -d $sdir/english ] \ - && echo Expecting directory $sdir/english to be present && exit 1; - [ -d $tdir/2000_hub5_eng_eval_tr ] \ - && tdir=$tdir/2000_hub5_eng_eval_tr - [ ! -d $tdir/reference ] \ - && echo Expecting directory $tdir/reference to be present && exit 1; +[ ! -d $sdir/english ] && + echo Expecting directory $sdir/english to be present && exit 1 +[ -d $tdir/2000_hub5_eng_eval_tr ] && + tdir=$tdir/2000_hub5_eng_eval_tr +[ ! -d $tdir/reference ] && + echo Expecting directory $tdir/reference to be present && exit 1 +dir=data/local/eval2000 +mkdir -p $dir - dir=data/local/eval2000 - mkdir -p $dir +find $sdir/english -iname '*.sph' | sort >$dir/sph.flist +sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ + >$dir/sph.scp - find $sdir/english -iname '*.sph' | sort > $dir/sph.flist - sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ - > $dir/sph.scp +sph2pipe=sph2pipe +[ ! -x $sph2pipe ] && + echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1 - sph2pipe=sph2pipe - [ ! -x $sph2pipe ] \ - && echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1; - - awk -v sph2pipe=$sph2pipe '{ +awk -v sph2pipe=$sph2pipe '{ printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); - }' < $dir/sph.scp | sort > $dir/wav.scp || exit 1; - #side A - channel 1, side B - channel 2 + }' <$dir/sph.scp | sort >$dir/wav.scp || exit 1 +#side A - channel 1, side B - channel 2 # Get segments file... # segments file format is: utt-id side-id start-time end-time, e.g.: # sw02001-A_000098-001156 sw02001-A 0.98 11.56 pem=$sdir/english/hub5e_00.pem -[ ! -f $pem ] && echo "$0: No such file $pem" && exit 1; +[ ! -f $pem ] && echo "$0: No such file $pem" && exit 1 # pem file has lines like: # en_4156 A unknown_speaker 301.85 302.48 # we ignore the warnings below for now, although they seem to indicate some problems # with the data. -grep -v ';;' $pem \ - | awk '{ +grep -v ';;' $pem | + awk '{ spk=$1"-"$2; utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); - print utt,spk,$4,$5;}' \ - | sort -u | local/extend_segments.pl 0.1 > $dir/segments + print utt,spk,$4,$5;}' | + sort -u | local/extend_segments.pl 0.1 >$dir/segments # stm file has lines like: # en_4156 A en_4156_A 357.64 359.64 HE IS A POLICE OFFICER # TODO(arnab): We should really be lowercasing this since the Edinburgh # recipe uses lowercase. This is not used in the actual scoring. -grep -v ';;' $tdir/reference/hub5e00.english.000405.stm \ - | awk '{ +grep -v ';;' $tdir/reference/hub5e00.english.000405.stm | + awk '{ spk=$1"-"$2; utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); - printf utt; for(n=7;n<=NF;n++) printf(" %s", $n); print ""; }' \ - | sort > $dir/text.all + printf utt; for(n=7;n<=NF;n++) printf(" %s", $n); print ""; }' | + sort >$dir/text.all # We'll use the stm file for sclite scoring. There seem to be various errors # in the stm file that upset hubscr.pl, and we fix them here. sed -e 's:((:(:' -e 's:::g' -e 's:::g' \ - $tdir/reference/hub5e00.english.000405.stm > $dir/stm - cp $tdir/reference/en20000405_hub5.glm $dir/glm + $tdir/reference/hub5e00.english.000405.stm >$dir/stm +cp $tdir/reference/en20000405_hub5.glm $dir/glm # next line uses command substitution # Just checking that the segments are the same in pem vs. stm. -! cmp <(awk '{print $1}' $dir/text.all) <(awk '{print $1}' $dir/segments) && \ - echo "Segments from pem file and stm file do not match." && exit 1; +! cmp <(awk '{print $1}' $dir/text.all) <(awk '{print $1}' $dir/segments) && + echo "Segments from pem file and stm file do not match." && exit 1 -grep -v IGNORE_TIME_SEGMENT_ $dir/text.all > $dir/text +grep -v IGNORE_TIME_SEGMENT_ $dir/text.all >$dir/text # create an utt2spk file that assumes each conversation side is # a separate speaker. -awk '{print $1,$2;}' $dir/segments > $dir/utt2spk -utils/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt +awk '{print $1,$2;}' $dir/segments >$dir/utt2spk +utils/utt2spk_to_spk2utt.pl $dir/utt2spk >$dir/spk2utt # cp $dir/segments $dir/segments.tmp # awk '{x=$3-0.05; if (x<0.0) x=0.0; y=$4+0.05; print $1, $2, x, y; }' \ # $dir/segments.tmp > $dir/segments -awk '{print $1}' $dir/wav.scp \ - | perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; +awk '{print $1}' $dir/wav.scp | + perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; print "$1-$2 $1 $2\n"; ' \ - > $dir/reco2file_and_channel || exit 1; - + >$dir/reco2file_and_channel || exit 1 - echo Data preparation and formatting completed for Eval 2000 - echo "(but not MFCC extraction)" +echo Data preparation and formatting completed for Eval 2000 +echo "(but not MFCC extraction)" - utils/fix_data_dir.sh $dir +utils/fix_data_dir.sh $dir - if [ $(wc -l < $dir/wav.scp) -ne 80 ]; then - echo "$0: error: expected 80 lines in wav.scp, got $(wc -l < $dir/wav.scp)" - exit 1; - fi +if [ $(wc -l <$dir/wav.scp) -ne 80 ]; then + echo "$0: error: expected 80 lines in wav.scp, got $(wc -l <$dir/wav.scp)" + exit 1 +fi diff --git a/egs/swbd/ASR/local/rt03_data_prep.sh b/egs/swbd/ASR/local/rt03_data_prep.sh index 1545b7809a..8a5f64324f 100755 --- a/egs/swbd/ASR/local/rt03_data_prep.sh +++ b/egs/swbd/ASR/local/rt03_data_prep.sh @@ -15,30 +15,30 @@ if [ $# -ne 1 ]; then fi sdir=$1 -[ ! -d $sdir/data/audio/eval03/english/cts ] \ - && echo Expecting directory $sdir/data/audio/eval03/english/cts to be present && exit 1; - [ ! -d $sdir/data/references/eval03/english/cts ] \ - && echo Expecting directory $tdir/data/references/eval03/english/cts to be present && exit 1; +[ ! -d $sdir/data/audio/eval03/english/cts ] && + echo Expecting directory $sdir/data/audio/eval03/english/cts to be present && exit 1 +[ ! -d $sdir/data/references/eval03/english/cts ] && + echo Expecting directory $tdir/data/references/eval03/english/cts to be present && exit 1 - dir=data/local/rt03 - mkdir -p $dir +dir=data/local/rt03 +mkdir -p $dir - rtroot=$sdir - tdir=$sdir/data/references/eval03/english/cts - sdir=$sdir/data/audio/eval03/english/cts +rtroot=$sdir +tdir=$sdir/data/references/eval03/english/cts +sdir=$sdir/data/audio/eval03/english/cts - find -L $sdir -iname '*.sph' | sort > $dir/sph.flist - sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ - > $dir/sph.scp +find -L $sdir -iname '*.sph' | sort >$dir/sph.flist +sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ + >$dir/sph.scp - sph2pipe=sph2pipe - ! command -v "${sph2pipe}" &> /dev/null \ - && echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1; +sph2pipe=sph2pipe +! command -v "${sph2pipe}" &>/dev/null && + echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1 - awk -v sph2pipe=$sph2pipe '{ +awk -v sph2pipe=$sph2pipe '{ printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); -}' < $dir/sph.scp | sort > $dir/wav.scp || exit 1; +}' <$dir/sph.scp | sort >$dir/wav.scp || exit 1 #side A - channel 1, side B - channel 2 # Get segments file... @@ -50,58 +50,58 @@ sdir=$1 # en_4156 A unknown_speaker 301.85 302.48 #grep -v ';;' $pem \ -cat $tdir/*.stm | grep -v ';;' | grep -v inter_segment_gap \ - | awk '{ +cat $tdir/*.stm | grep -v ';;' | grep -v inter_segment_gap | + awk '{ spk=$1"-"(($2==1)?"A":"B"); utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); - print utt,spk,$4,$5;}' \ - | sort -u > $dir/segments + print utt,spk,$4,$5;}' | + sort -u >$dir/segments # stm file has lines like: # en_4156 A en_4156_A 357.64 359.64 HE IS A POLICE OFFICER # TODO(arnab): We should really be lowercasing this since the Edinburgh # recipe uses lowercase. This is not used in the actual scoring. #grep -v ';;' $tdir/reference/hub5e00.english.000405.stm \ -cat $tdir/*.stm | grep -v ';;' | grep -v inter_segment_gap \ - | awk '{ +cat $tdir/*.stm | grep -v ';;' | grep -v inter_segment_gap | + awk '{ spk=$1"-"(($2==1)?"A":"B"); utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); - printf utt; for(n=7;n<=NF;n++) printf(" %s", $n); print ""; }' \ - | sort > $dir/text.all + printf utt; for(n=7;n<=NF;n++) printf(" %s", $n); print ""; }' | + sort >$dir/text.all # We'll use the stm file for sclite scoring. There seem to be various errors # in the stm file that upset hubscr.pl, and we fix them here. -cat $tdir/*.stm | \ - sed -e 's:((:(:' -e 's:::g' -e 's:::g' | \ - grep -v inter_segment_gap | \ +cat $tdir/*.stm | + sed -e 's:((:(:' -e 's:::g' -e 's:::g' | + grep -v inter_segment_gap | awk '{ - printf $1; if ($1==";;") printf(" %s",$2); else printf(($2==1)?" A":" B"); for(n=3;n<=NF;n++) printf(" %s", $n); print ""; }'\ - > $dir/stm - #$tdir/reference/hub5e00.english.000405.stm > $dir/stm - cp $rtroot/data/trans_rules/en20030506.glm $dir/glm + printf $1; if ($1==";;") printf(" %s",$2); else printf(($2==1)?" A":" B"); for(n=3;n<=NF;n++) printf(" %s", $n); print ""; }' \ + >$dir/stm +#$tdir/reference/hub5e00.english.000405.stm > $dir/stm +cp $rtroot/data/trans_rules/en20030506.glm $dir/glm # next line uses command substitution # Just checking that the segments are the same in pem vs. stm. -! cmp <(awk '{print $1}' $dir/text.all) <(awk '{print $1}' $dir/segments) && \ - echo "Segments from pem file and stm file do not match." && exit 1; +! cmp <(awk '{print $1}' $dir/text.all) <(awk '{print $1}' $dir/segments) && + echo "Segments from pem file and stm file do not match." && exit 1 -grep -v IGNORE_TIME_SEGMENT_ $dir/text.all > $dir/text +grep -v IGNORE_TIME_SEGMENT_ $dir/text.all >$dir/text # create an utt2spk file that assumes each conversation side is # a separate speaker. -awk '{print $1,$2;}' $dir/segments > $dir/utt2spk -utils/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt +awk '{print $1,$2;}' $dir/segments >$dir/utt2spk +utils/utt2spk_to_spk2utt.pl $dir/utt2spk >$dir/spk2utt # cp $dir/segments $dir/segments.tmp # awk '{x=$3-0.05; if (x<0.0) x=0.0; y=$4+0.05; print $1, $2, x, y; }' \ # $dir/segments.tmp > $dir/segments -awk '{print $1}' $dir/wav.scp \ - | perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; +awk '{print $1}' $dir/wav.scp | + perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; print "$1-$2 $1 $2\n"; ' \ - > $dir/reco2file_and_channel || exit 1; + >$dir/reco2file_and_channel || exit 1 - ./utils/fix_data_dir.sh $dir +./utils/fix_data_dir.sh $dir - echo Data preparation and formatting completed for RT-03 - echo "(but not MFCC extraction)" +echo Data preparation and formatting completed for RT-03 +echo "(but not MFCC extraction)" diff --git a/egs/swbd/ASR/local/swbd1_data_prep.sh b/egs/swbd/ASR/local/swbd1_data_prep.sh index 0027403540..1593594915 100755 --- a/egs/swbd/ASR/local/swbd1_data_prep.sh +++ b/egs/swbd/ASR/local/swbd1_data_prep.sh @@ -17,11 +17,10 @@ ## will be using "find" to locate this file so we don't make any assumptions ## on the directory structure. (Peng Qi, Aug 2014) - #check existing directories if [ $# != 1 -a $# != 2 ]; then echo "Usage: swbd1_data_prep.sh /path/to/SWBD [/path/to/SWBD_DOC]" - exit 1; + exit 1 fi SWBD_DIR=$1 @@ -29,29 +28,27 @@ SWBD_DIR=$1 dir=data/local/train mkdir -p $dir - # Audio data directory check if [ ! -d $SWBD_DIR ]; then echo "Error: run.sh requires a directory argument" - exit 1; + exit 1 fi sph2pipe=sph2pipe -! command -v "${sph2pipe}" &> /dev/null \ - && echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1; +! command -v "${sph2pipe}" &>/dev/null && + echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1 # Option A: SWBD dictionary file check -[ ! -f ./swb_ms98_transcriptions/sw-ms98-dict.text ] && \ - echo "SWBD dictionary file does not exist" && exit 1; +[ ! -f ./swb_ms98_transcriptions/sw-ms98-dict.text ] && + echo "SWBD dictionary file does not exist" && exit 1 # find sph audio files -find -L $SWBD_DIR -iname '*.sph' | sort > $dir/sph.flist +find -L $SWBD_DIR -iname '*.sph' | sort >$dir/sph.flist -n=`cat $dir/sph.flist | wc -l` -[ $n -ne 2435 ] && [ $n -ne 2438 ] && \ +n=$(cat $dir/sph.flist | wc -l) +[ $n -ne 2435 ] && [ $n -ne 2438 ] && echo Warning: expected 2435 or 2438 data data files, found $n - # (1a) Transcriptions preparation # make basic transcription file (add segments info) # **NOTE: In the default Kaldi recipe, everything is made uppercase, while we @@ -64,11 +61,11 @@ stime=$2; etime=$3; printf("%s-%s_%06.0f-%06.0f", name, side, int(100*stime+0.5), int(100*etime+0.5)); for(i=4;i<=NF;i++) printf(" %s", $i); printf "\n" -}' ./swb_ms98_transcriptions/*/*/*-trans.text > $dir/transcripts1.txt +}' ./swb_ms98_transcriptions/*/*/*-trans.text >$dir/transcripts1.txt # test if trans. file is sorted -export LC_ALL=C; -sort -c $dir/transcripts1.txt || exit 1; # check it's sorted. +export LC_ALL=C +sort -c $dir/transcripts1.txt || exit 1 # check it's sorted. # Remove SILENCE, and . @@ -77,22 +74,21 @@ sort -c $dir/transcripts1.txt || exit 1; # check it's sorted. # speech to somone; we will give phones to the other three (NSN, SPN, LAU). # There will also be a silence phone, SIL. # **NOTE: modified the pattern matches to make them case insensitive -cat $dir/transcripts1.txt \ - | perl -ane 's:\s\[SILENCE\](\s|$):$1:gi; +cat $dir/transcripts1.txt | + perl -ane 's:\s\[SILENCE\](\s|$):$1:gi; s///gi; s///gi; - print;' \ - | awk '{if(NF > 1) { print; } } ' > $dir/transcripts2.txt - + print;' | + awk '{if(NF > 1) { print; } } ' >$dir/transcripts2.txt # **NOTE: swbd1_map_words.pl has been modified to make the pattern matches # case insensitive -local/swbd1_map_words.pl -f 2- $dir/transcripts2.txt > $dir/text +local/swbd1_map_words.pl -f 2- $dir/transcripts2.txt >$dir/text # format acronyms in text python3 local/map_acronyms_transcripts.py -i $dir/text -o $dir/text_map \ -M data/local/dict_nosp/acronyms.map - mv $dir/text_map $dir/text +mv $dir/text_map $dir/text # (1c) Make segment files from transcript #segments file format is: utt-id side-id start-time end-time, e.g.: @@ -102,15 +98,15 @@ segment=$1; split(segment,S,"[_-]"); side=S[2]; audioname=S[1]; startf=S[3]; endf=S[4]; print segment " " audioname "-" side " " startf/100 " " endf/100 -}' < $dir/text > $dir/segments +}' <$dir/text >$dir/segments sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ - > $dir/sph.scp + >$dir/sph.scp awk -v sph2pipe=$sph2pipe '{ printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); -}' < $dir/sph.scp | sort > $dir/wav.scp || exit 1; +}' <$dir/sph.scp | sort >$dir/wav.scp || exit 1 #side A - channel 1, side B - channel 2 # this file reco2file_and_channel maps recording-id (e.g. sw02001-A) @@ -118,15 +114,15 @@ printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); # sw02001-A sw02001 A # In this case it's trivial, but in other corpora the information might # be less obvious. Later it will be needed for ctm scoring. -awk '{print $1}' $dir/wav.scp \ - | perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; +awk '{print $1}' $dir/wav.scp | + perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; print "$1-$2 $1 $2\n"; ' \ - > $dir/reco2file_and_channel || exit 1; + >$dir/reco2file_and_channel || exit 1 - awk '{spk=substr($1,1,9); print $1 " " spk}' $dir/segments > $dir/utt2spk \ - || exit 1; - sort -k 2 $dir/utt2spk | utils/utt2spk_to_spk2utt.pl > $dir/spk2utt || exit 1; +awk '{spk=substr($1,1,9); print $1 " " spk}' $dir/segments >$dir/utt2spk || + exit 1 +sort -k 2 $dir/utt2spk | utils/utt2spk_to_spk2utt.pl >$dir/spk2utt || exit 1 - echo Switchboard-1 data preparation succeeded. +echo Switchboard-1 data preparation succeeded. - utils/fix_data_dir.sh data/local/train +utils/fix_data_dir.sh data/local/train diff --git a/egs/swbd/ASR/local/swbd1_prepare_dict.sh b/egs/swbd/ASR/local/swbd1_prepare_dict.sh index 0c38a72dce..0bb98903fc 100755 --- a/egs/swbd/ASR/local/swbd1_prepare_dict.sh +++ b/egs/swbd/ASR/local/swbd1_prepare_dict.sh @@ -5,32 +5,36 @@ # To be run from one directory above this script. - #check existing directories -[ $# != 0 ] && echo "Usage: local/swbd1_data_prep.sh" && exit 1; +[ $# != 0 ] && echo "Usage: local/swbd1_data_prep.sh" && exit 1 -srcdir=. # This is where we downloaded some stuff.. +srcdir=. # This is where we downloaded some stuff.. dir=./data/local/dict_nosp mkdir -p $dir srcdict=$srcdir/swb_ms98_transcriptions/sw-ms98-dict.text # assume swbd_p1_data_prep.sh was done already. -[ ! -f "$srcdict" ] && echo "$0: No such file $srcdict" && exit 1; +[ ! -f "$srcdict" ] && echo "$0: No such file $srcdict" && exit 1 -cp $srcdict $dir/lexicon0.txt || exit 1; +cp $srcdict $dir/lexicon0.txt || exit 1 chmod a+w $dir/lexicon0.txt -patch 0' | sort > $dir/lexicon1.txt || exit 1; +grep -v '^#' $dir/lexicon0.txt | awk 'NF>0' | sort >$dir/lexicon1.txt || exit 1 -cat $dir/lexicon1.txt | awk '{ for(n=2;n<=NF;n++){ phones[$n] = 1; }} END{for (p in phones) print p;}' | \ - grep -v sil > $dir/nonsilence_phones.txt || exit 1; +cat $dir/lexicon1.txt | awk '{ for(n=2;n<=NF;n++){ phones[$n] = 1; }} END{for (p in phones) print p;}' | + grep -v sil >$dir/nonsilence_phones.txt || exit 1 -( echo sil; echo spn; echo nsn; echo lau ) > $dir/silence_phones.txt +( + echo sil + echo spn + echo nsn + echo lau +) >$dir/silence_phones.txt -echo sil > $dir/optional_silence.txt +echo sil >$dir/optional_silence.txt # No "extra questions" in the input to this setup, as we don't # have stress or tone. @@ -41,9 +45,14 @@ cp local/MSU_single_letter.txt $dir/ # Add single letter lexicon # The original swbd lexicon does not have precise single letter lexicion # e.g. it does not have entry of W -( echo '!sil sil'; echo '[vocalized-noise] spn'; echo '[noise] nsn'; \ - echo '[laughter] lau'; echo ' spn' ) \ - | cat - $dir/lexicon1.txt $dir/MSU_single_letter.txt > $dir/lexicon2.txt || exit 1; +( + echo '!sil sil' + echo '[vocalized-noise] spn' + echo '[noise] nsn' + echo '[laughter] lau' + echo ' spn' +) | + cat - $dir/lexicon1.txt $dir/MSU_single_letter.txt >$dir/lexicon2.txt || exit 1 # Map the words in the lexicon. That is-- for each word in the lexicon, we map it # to a new written form. The transformations we do are: @@ -77,16 +86,16 @@ cp local/MSU_single_letter.txt $dir/ # in the lexicon. local/swbd1_map_words.pl -f 1 $dir/lexicon2.txt | sort -u \ - > $dir/lexicon3.txt || exit 1; + >$dir/lexicon3.txt || exit 1 python3 local/format_acronyms_dict.py -i $dir/lexicon3.txt -o $dir/lexicon4.txt \ -L $dir/MSU_single_letter.txt -M $dir/acronyms_raw.map - cat $dir/acronyms_raw.map | sort -u > $dir/acronyms.map +cat $dir/acronyms_raw.map | sort -u >$dir/acronyms.map - ( echo 'i ay' )| cat - $dir/lexicon4.txt | tr '[A-Z]' '[a-z]' | sort -u > $dir/lexicon5.txt +(echo 'i ay') | cat - $dir/lexicon4.txt | tr '[A-Z]' '[a-z]' | sort -u >$dir/lexicon5.txt - pushd $dir >&/dev/null - ln -sf lexicon5.txt lexicon.txt # This is the final lexicon. - popd >&/dev/null - rm $dir/lexiconp.txt 2>/dev/null - echo Prepared input dictionary and phone-sets for Switchboard phase 1. +pushd $dir >&/dev/null +ln -sf lexicon5.txt lexicon.txt # This is the final lexicon. +popd >&/dev/null +rm $dir/lexiconp.txt 2>/dev/null +echo Prepared input dictionary and phone-sets for Switchboard phase 1. diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index ae0b278b15..6ac7dbf9ef 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -78,28 +78,28 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then cp data/local/${x}/text data/local/${x}/text.org paste -d "" \ <(cut -f 1 -d" " data/local/${x}/text.org) \ - <(awk '{$1=""; print tolower($0)}' data/local/${x}/text.org | perl -pe 's| \(\%.*\)||g' | perl -pe 's| \<.*\>||g' | sed -e "s/(//g" -e "s/)//g") \ - | sed -e 's/\s\+/ /g' > data/local/${x}/text - rm data/local/${x}/text.org - done - - python ./local/filter_empty_text.py --kaldi-data-dir data/local/eval2000 - ./utils/fix_data_dir.sh data/local/eval2000 - lhotse kaldi import data/local/eval2000 8000 data/manifests_eval2000 - mv data/manifests_eval2000/recordings.jsonl.gz data/manifests_eval2000/swbd_recordings_eval2000.jsonl.gz - mv data/manifests_eval2000/supervisions.jsonl.gz data/manifests_eval2000/swbd_supervisions_eval2000.jsonl.gz - - python ./local/filter_empty_text.py --kaldi-data-dir data/local/rt03 - ./utils/fix_data_dir.sh data/local/rt03 - lhotse kaldi import data/local/rt03 8000 data/manifests_rt03 - mv data/manifests_rt03/recordings.jsonl.gz data/manifests_rt03/swbd_recordings_rt03.jsonl.gz - mv data/manifests_rt03/supervisions.jsonl.gz data/manifests_rt03/swbd_supervisions_rt03.jsonl.gz - - lhotse fix data/manifests_train/swbd_recordings_all.jsonl.gz data/manifests_train/swbd_supervisions_all.jsonl.gz data/manifests - lhotse fix data/manifests_eval2000/swbd_recordings_eval2000.jsonl.gz data/manifests_eval2000/swbd_supervisions_eval2000.jsonl.gz data/manifests - lhotse fix data/manifests_rt03/swbd_recordings_rt03.jsonl.gz data/manifests_rt03/swbd_supervisions_rt03.jsonl.gz data/manifests - - touch data/manifests/.swbd.done + <(awk '{$1=""; print tolower($0)}' data/local/${x}/text.org | perl -pe 's| \(\%.*\)||g' | perl -pe 's| \<.*\>||g' | sed -e "s/(//g" -e "s/)//g") | + sed -e 's/\s\+/ /g' >data/local/${x}/text + rm data/local/${x}/text.org + done + + python ./local/filter_empty_text.py --kaldi-data-dir data/local/eval2000 + ./utils/fix_data_dir.sh data/local/eval2000 + lhotse kaldi import data/local/eval2000 8000 data/manifests_eval2000 + mv data/manifests_eval2000/recordings.jsonl.gz data/manifests_eval2000/swbd_recordings_eval2000.jsonl.gz + mv data/manifests_eval2000/supervisions.jsonl.gz data/manifests_eval2000/swbd_supervisions_eval2000.jsonl.gz + + python ./local/filter_empty_text.py --kaldi-data-dir data/local/rt03 + ./utils/fix_data_dir.sh data/local/rt03 + lhotse kaldi import data/local/rt03 8000 data/manifests_rt03 + mv data/manifests_rt03/recordings.jsonl.gz data/manifests_rt03/swbd_recordings_rt03.jsonl.gz + mv data/manifests_rt03/supervisions.jsonl.gz data/manifests_rt03/swbd_supervisions_rt03.jsonl.gz + + lhotse fix data/manifests_train/swbd_recordings_all.jsonl.gz data/manifests_train/swbd_supervisions_all.jsonl.gz data/manifests + lhotse fix data/manifests_eval2000/swbd_recordings_eval2000.jsonl.gz data/manifests_eval2000/swbd_supervisions_eval2000.jsonl.gz data/manifests + lhotse fix data/manifests_rt03/swbd_recordings_rt03.jsonl.gz data/manifests_rt03/swbd_supervisions_rt03.jsonl.gz data/manifests + + touch data/manifests/.swbd.done fi fi @@ -260,11 +260,11 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then -ngram-order 3 \ -text ${lang_dir}/input.txt \ -lm data/lm/3-gram.arpa - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=3 \ - data/lm/3-gram.arpa >data/lm/G_3_gram.fst.txt + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/lm/3-gram.arpa >data/lm/G_3_gram.fst.txt fi if [ ! -f data/lm/G_4_gram.fst.txt ]; then @@ -273,11 +273,11 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then -ngram-order 4 \ -text ${lang_dir}/input.txt \ -lm data/lm/4-gram.arpa - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=4 \ - data/lm/4-gram.arpa >data/lm/G_4_gram.fst.txt + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + data/lm/4-gram.arpa >data/lm/G_4_gram.fst.txt fi fi @@ -325,7 +325,7 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then --bpe-model $lang_dir/bpe.model \ --lm-data data/lang_phone/input.txt \ --lm-archive $out_dir/lm_data.pt - done + done fi # if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then @@ -373,8 +373,8 @@ if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then --bpe-model $lang_dir/bpe.model \ --lm-data $out_dir/${testset}.txt \ --lm-archive $out_dir/lm_data-${testset}.pt - done done + done fi if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then @@ -393,11 +393,11 @@ if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then --in-lm-data $out_dir/lm_data.pt \ --out-lm-data $out_dir/sorted_lm_data.pt \ --out-statistics $out_dir/statistics.txt - for testset in ${testsets[@]}; do - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data-${testset}.pt \ - --out-lm-data $out_dir/sorted_lm_data-${testset}.pt \ - --out-statistics $out_dir/statistics-test-${testset}.txt - done - done + for testset in ${testsets[@]}; do + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-${testset}.pt \ + --out-lm-data $out_dir/sorted_lm_data-${testset}.pt \ + --out-statistics $out_dir/statistics-test-${testset}.txt + done + done fi From 70c7576e0a89ff9cd8805ece75b8367b1c06d737 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Mon, 26 Jun 2023 19:49:00 +0800 Subject: [PATCH 05/60] Update README.md typo fixed --- egs/swbd/ASR/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/swbd/ASR/README.md b/egs/swbd/ASR/README.md index b9bbab787a..bf28e19e09 100644 --- a/egs/swbd/ASR/README.md +++ b/egs/swbd/ASR/README.md @@ -9,7 +9,7 @@ Switchboard is a collection of about 2,400 two-sided telephone conversations amo **Caution**: The `conformer_ctc` recipe for Switchboard is currently very rough and has a high Word Error Rate, requiring more improvement and refinement. The TODO list for this recipe is as follows. ## TODO List -- [ ] Incorporate Lhotse for dataprocessing +- [ ] Incorporate Lhotse for data processing - [ ] Refer to Global Mapping Rules when computing Word Error Rate - [ ] Detailed Word Error Rate summary for eval2000 (callhome, swbd) and rt03 (fsh, swbd) testset - [ ] Switchboard transcript train/dev split for LM training From abbd0d92a7fc5bb0df8e4251cdca532bf9fe5e74 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Mon, 26 Jun 2023 20:09:06 +0800 Subject: [PATCH 06/60] Update prepare.sh removed commented scripts --- egs/swbd/ASR/prepare.sh | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index 6ac7dbf9ef..1490886e87 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -337,13 +337,7 @@ fi # mkdir -p $out_dir # if [ ! -f $out_dir/valid.txt ]; then -# files=$( -# find "$dl_dir/LibriSpeech/dev-clean" -name "*.trans.txt" -# find "$dl_dir/LibriSpeech/dev-other" -name "*.trans.txt" -# ) -# for f in ${files[@]}; do -# cat $f | cut -d " " -f 2- -# done >$out_dir/valid.txt +# TODO: generate valid.txt # fi # lang_dir=data/lang_bpe_${vocab_size} From 354e40925b511356c9db542a992a246d73bf47a4 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Mon, 26 Jun 2023 21:12:56 +0800 Subject: [PATCH 07/60] Update decode.py fixed a KeyError occurred during the computation of WER for nbest-oracle --- egs/swbd/ASR/conformer_ctc/decode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py index f805becd02..f89ad17e00 100755 --- a/egs/swbd/ASR/conformer_ctc/decode.py +++ b/egs/swbd/ASR/conformer_ctc/decode.py @@ -366,7 +366,7 @@ def decode_one_batch( ref_texts=supervisions["text"], word_table=word_table, nbest_scale=params.nbest_scale, - oov="", + oov="", ) hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] From f2a0a104820e88c81660d5ec4ceebc8fe754391a Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Mon, 26 Jun 2023 21:15:38 +0800 Subject: [PATCH 08/60] Update RESULTS.md Updated nbest-oracle WERs --- egs/swbd/ASR/RESULTS.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/egs/swbd/ASR/RESULTS.md b/egs/swbd/ASR/RESULTS.md index ef8216f036..a7c50879be 100644 --- a/egs/swbd/ASR/RESULTS.md +++ b/egs/swbd/ASR/RESULTS.md @@ -47,4 +47,8 @@ and the following command for decoding: --max-duration 50 ``` +For your reference, the nbest oracle WERs are: +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 25.64 | 26.84 | From 11faddc83007ab7baacb96cfade00b0bf3689384 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Tue, 27 Jun 2023 20:23:33 +0800 Subject: [PATCH 09/60] Update RESULTS.md Lower WERs reported --- egs/swbd/ASR/RESULTS.md | 47 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/egs/swbd/ASR/RESULTS.md b/egs/swbd/ASR/RESULTS.md index a7c50879be..68a843c014 100644 --- a/egs/swbd/ASR/RESULTS.md +++ b/egs/swbd/ASR/RESULTS.md @@ -1,6 +1,53 @@ ## Results ### Switchboard BPE training results (Conformer-CTC) +#### 2023-06-27 + +The best WER, as of 2023-06-27, for the Switchboard is below + +Results using HLG decoding + n-gram LM rescoring + attention decoder rescoring: + +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 30.80 | 32.29 | + +Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: + +##### eval2000 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.9 | 1.1 | + +##### rt03 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.9 | 1.9 | + +To reproduce the above result, use the following commands for training: + +```bash +cd egs/swbd/ASR +./prepare.sh +export CUDA_VISIBLE_DEVICES="0,1" +./conformer_ctc/train.py \ + --max-duration 120 \ + --num-workers 8 \ + --enable-musan False \ + --world-size 2 \ + --num-epochs 100 +``` + +and the following command for decoding: + +```bash +./conformer_ctc/decode.py \ + --epoch 99 \ + --avg 10 \ + --max-duration 50 +``` + #### 2023-06-26 The best WER, as of 2023-06-26, for the Switchboard is below From f85b95e73b31b6ef000010ed70373244484338e7 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Wed, 28 Jun 2023 12:44:41 +0800 Subject: [PATCH 10/60] Updated decode.py to obtain WERs for subsets. --- egs/swbd/ASR/conformer_ctc/decode.py | 75 ++++++++++++++++------------ 1 file changed, 44 insertions(+), 31 deletions(-) diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py index f89ad17e00..05d07f44fe 100755 --- a/egs/swbd/ASR/conformer_ctc/decode.py +++ b/egs/swbd/ASR/conformer_ctc/decode.py @@ -581,39 +581,52 @@ def save_results( enable_log = False else: enable_log = True - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=enable_log + if test_set_name == "test-eval2000": + subsets = {"callhome": "en_", "swbd": "sw_", "avg": "*"} + elif test_set_name == "test-rt03": + subsets = {"fisher": "fsh_", "swbd": "sw_", "avg": "*"} + else: + raise NotImplementedError(f"No implementation for testset {test_set_name}") + for subset, prefix in subsets.items(): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{subset}-{key}.txt" + results = ( + sorted(list(filter(lambda x: x[0].startswith(prefix), results))) + if subset != "avg" + else sorted(results) ) - test_set_wers[key] = wer - - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.exp_dir / f"errs-{test_set_name}-{subset}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{subset}-{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}-{subset}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}-{}, WER of different settings are:\n".format( + test_set_name, subset + ) + note = "\tbest for {}".format(test_set_name) for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) @torch.no_grad() From 1f85c6a3d8953eae2119ab510e49f0d65eb77061 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Fri, 7 Jul 2023 12:49:30 +0800 Subject: [PATCH 11/60] Update README.md --- egs/swbd/ASR/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/swbd/ASR/README.md b/egs/swbd/ASR/README.md index bf28e19e09..3b93a64cbe 100644 --- a/egs/swbd/ASR/README.md +++ b/egs/swbd/ASR/README.md @@ -11,7 +11,7 @@ Switchboard is a collection of about 2,400 two-sided telephone conversations amo ## TODO List - [ ] Incorporate Lhotse for data processing - [ ] Refer to Global Mapping Rules when computing Word Error Rate -- [ ] Detailed Word Error Rate summary for eval2000 (callhome, swbd) and rt03 (fsh, swbd) testset +- [x] Detailed Word Error Rate summary for eval2000 (callhome, swbd) and rt03 (fsh, swbd) testset - [ ] Switchboard transcript train/dev split for LM training - [ ] Fisher corpus LDC2004T19 LDC2005T19 LDC2004S13 LDC2005S13 for LM training From 301c3541cc9d4bf8214bba07bad9450415e15012 Mon Sep 17 00:00:00 2001 From: zr_jin <60612200+JinZr@users.noreply.github.com> Date: Thu, 27 Jul 2023 15:19:45 +0800 Subject: [PATCH 12/60] Update prepare.sh --- egs/swbd/ASR/prepare.sh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index 1490886e87..6940d7186d 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -10,8 +10,9 @@ stage=-1 stop_stage=100 # We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. +# directories and files. Most of them can't be downloaded automatically +# as they are not publically available and require a license purchased +# from the LDC. # # - $dl_dir/musan # This directory contains the following directories downloaded from @@ -26,6 +27,7 @@ swbd1_dir="/export/corpora3/LDC/LDC97S62" eval2000_dir="/export/corpora2/LDC/LDC2002S09/hub5e_00" eval2000_ref_dir="/export/corpora2/LDC/LDC2002T43" rt03_dir="/export/corpora/LDC/LDC2007S10" +fisher_dir="/export/corpora3/LDC/LDC2004T19" . shared/parse_options.sh || exit 1 From e0d06f1b4dc68e3bac6ce3c62bca1653c2f1e5d4 Mon Sep 17 00:00:00 2001 From: zr_jin <60612200+JinZr@users.noreply.github.com> Date: Tue, 1 Aug 2023 16:48:48 +0800 Subject: [PATCH 13/60] added normalization for eval2000 --- egs/swbd/ASR/local/normalize_eval2000.py | 230 +++++++++++++++++++++++ 1 file changed, 230 insertions(+) create mode 100644 egs/swbd/ASR/local/normalize_eval2000.py diff --git a/egs/swbd/ASR/local/normalize_eval2000.py b/egs/swbd/ASR/local/normalize_eval2000.py new file mode 100644 index 0000000000..de19aa4362 --- /dev/null +++ b/egs/swbd/ASR/local/normalize_eval2000.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Nagendra Goel https://github.com/ngoel17) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import re +from typing import Tuple + +from lhotse import SupervisionSegment, SupervisionSet +from lhotse.serialization import load_manifest_lazy_or_eager +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input_sups") + parser.add_argument("output_sups") + return parser.parse_args() + + +def remove_punctutation_and_other_symbol(text: str) -> str: + text = text.replace("--", " ") + text = text.replace("//", " ") + text = text.replace(".", " ") + text = text.replace("?", " ") + text = text.replace("~", " ") + text = text.replace(",", " ") + text = text.replace(";", " ") + text = text.replace("(", " ") + text = text.replace(")", " ") + text = text.replace("&", " ") + text = text.replace("%", " ") + text = text.replace("*", " ") + text = text.replace("{", " ") + text = text.replace("}", " ") + return text + + +def eval2000_clean_eform(text: str, eform_count) -> str: + string_to_remove = [] + piece = text.split('">') + for i in range(0, len(piece)): + s = piece[i] + '">' + res = re.search(r"", s) + if res is not None: + res_rm = res.group(1) + string_to_remove.append(res_rm) + for p in string_to_remove: + eform_string = p + text = text.replace(eform_string, " ") + eform_1 = " str: + text = text.replace("[/BABY CRYING]", " ") + text = text.replace("[/CHILD]", " ") + text = text.replace("[[DISTORTED]]", " ") + text = text.replace("[/DISTORTION]", " ") + text = text.replace("[[DRAWN OUT]]", " ") + text = text.replace("[[DRAWN-OUT]]", " ") + text = text.replace("[[FAINT]]", " ") + text = text.replace("[SMACK]", " ") + text = text.replace("[[MUMBLES]]", " ") + text = text.replace("[[HIGH PITCHED SQUEAKY VOICE]]", " ") + text = text.replace("[[IN THE LAUGH]]", "[LAUGHTER]") + text = text.replace("[[LAST WORD SPOKEN WITH A LAUGH]]", "[LAUGHTER]") + text = text.replace( + "[[PART OF FIRST SYLLABLE OF PREVIOUS WORD CUT OFF]]", " " + ) + text = text.replace("[[PREVIOUS WORD SPOKEN WITH A LAUGH]]", " ") + text = text.replace("[[PREVIOUS TWO WORDS SPOKEN WHILE LAUGHING]]", " ") + text = text.replace("[[PROLONGED]]", " ") + text = text.replace("[/RUNNING WATER]", " ") + text = text.replace("[[SAYS LAUGHING]]", "[LAUGHTER]") + text = text.replace("[[SINGING]]", " ") + text = text.replace("[[SPOKEN WHILE LAUGHING]]", "[LAUGHTER]") + text = text.replace("[/STATIC]", " ") + text = text.replace("['THIRTIETH' DRAWN OUT]", " ") + text = text.replace("[/VOICES]", " ") + text = text.replace("[[WHISPERED]]", " ") + text = text.replace("[DISTORTION]", " ") + text = text.replace("[DISTORTION, HIGH VOLUME ON WAVES]", " ") + text = text.replace("[BACKGROUND LAUGHTER]", "[LAUGHTER]") + text = text.replace("[CHILD'S VOICE]", " ") + text = text.replace("[CHILD SCREAMS]", " ") + text = text.replace("[CHILD VOICE]", " ") + text = text.replace("[CHILD YELLING]", " ") + text = text.replace("[CHILD SCREAMING]", " ") + text = text.replace("[CHILD'S VOICE IN BACKGROUND]", " ") + text = text.replace("[CHANNEL NOISE]", " ") + text = text.replace("[CHANNEL ECHO]", " ") + text = text.replace("[ECHO FROM OTHER CHANNEL]", " ") + text = text.replace("[ECHO OF OTHER CHANNEL]", " ") + text = text.replace("[CLICK]", " ") + text = text.replace("[DISTORTED]", " ") + text = text.replace("[BABY CRYING]", " ") + text = text.replace("[METALLIC KNOCKING SOUND]", " ") + text = text.replace("[METALLIC SOUND]", " ") + + text = text.replace("[PHONE JIGGLING]", " ") + text = text.replace("[BACKGROUND SOUND]", " ") + text = text.replace("[BACKGROUND VOICE]", " ") + text = text.replace("[BACKGROUND VOICES]", " ") + text = text.replace("[BACKGROUND NOISE]", " ") + text = text.replace("[CAR HORNS IN BACKGROUND]", " ") + text = text.replace("[CAR HORNS]", " ") + text = text.replace("[CARNATING]", " ") + text = text.replace("[CRYING CHILD]", " ") + text = text.replace("[CHOPPING SOUND]", " ") + text = text.replace("[BANGING]", " ") + text = text.replace("[CLICKING NOISE]", " ") + text = text.replace("[CLATTERING]", " ") + text = text.replace("[ECHO]", " ") + text = text.replace("[KNOCK]", " ") + text = text.replace("[NOISE-GOOD]", "[NOISE]") + text = text.replace("[RIGHT]", " ") + text = text.replace("[SOUND]", " ") + text = text.replace("[SQUEAK]", " ") + text = text.replace("[STATIC]", " ") + text = text.replace("[[SAYS WITH HIGH-PITCHED SCREAMING LAUGHTER]]", " ") + text = text.replace("[UH]", "UH") + text = text.replace("[MN]", "[VOCALIZED-NOISE]") + text = text.replace("[VOICES]", " ") + text = text.replace("[WATER RUNNING]", " ") + text = text.replace("[SOUND OF TWISTING PHONE CORD]", " ") + text = text.replace("[SOUND OF SOMETHING FALLING]", " ") + text = text.replace("[SOUND]", " ") + text = text.replace("[NOISE OF MOVING PHONE]", " ") + text = text.replace("[SOUND OF RUNNING WATER]", " ") + text = text.replace("[CHANNEL]", " ") + text = text.replace("-[W]HERE", "WHERE") + text = text.replace("Y[OU]I-", "YOU I") + text = text.replace("-[A]ND", "AND") + text = text.replace("JU[ST]", "JUST") + text = text.replace("{BREATH}", " ") + text = text.replace("{BREATHY}", " ") + text = text.replace("{CHANNEL NOISE}", " ") + text = text.replace("{CLEAR THROAT}", " ") + + text = text.replace("{CLEARING THROAT}", " ") + text = text.replace("{CLEARS THROAT}", " ") + text = text.replace("{COUGH}", " ") + text = text.replace("{DRAWN OUT}", " ") + text = text.replace("{EXHALATION}", " ") + text = text.replace("{EXHALE}", " ") + text = text.replace("{GASP}", " ") + text = text.replace("{HIGH SQUEAL}", " ") + text = text.replace("{INHALE}", " ") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LIPSMACK}", " ") + text = text.replace("{LIPSMACK}", " ") + + text = text.replace("{NOISE OF DISGUST}", " ") + text = text.replace("{SIGH}", " ") + text = text.replace("{SNIFF}", " ") + text = text.replace("{SNORT}", " ") + text = text.replace("{SHARP EXHALATION}", " ") + text = text.replace("{BREATH LAUGH}", " ") + return text + + +def remove_languagetag(text: str) -> str: + langtag = re.findall(r"<(.*?)>", text) + for t in langtag: + text = text.replace(t, " ") + text = text.replace("<", " ") + text = text.replace(">", " ") + return text + + +def eval2000_normalizer(text: str) -> str: + # print("TEXT original: ",text) + eform_count = text.count("contraction e_form") + # print("eform corunt:", eform_count) + if eform_count > 0: + text = eval2000_clean_eform(text, eform_count) + text = text.upper() + text = remove_languagetag(text) + text = replace_silphone(text) + text = remove_punctutation_and_other_symbol(text) + text = text.replace("IGNORE_TIME_SEGMENT_IN_SCORING", " ") + text = text.replace("IGNORE_TIME_SEGMENT_SCORING", " ") + spaces = re.findall(r"\s+", text) + for sp in spaces: + text = text.replace(sp, " ") + text = text.strip() + # text = self.whitespace_regexp.sub(" ", text).strip() + # print(text) + return text + + +def main(): + args = get_args() + sups = load_manifest_lazy_or_eager(args.input_sups) + assert isinstance(sups, SupervisionSet) + + tot, skip = 0, 0 + with SupervisionSet.open_writer(args.output_sups) as writer: + for sup in tqdm(sups, desc="Normalizing supervisions"): + tot += 1 + sup.text = eval2000_normalizer(sup.text) + if not sup.text: + skip += 1 + continue + writer.write(sup) + + +if __name__ == "__main__": + main() \ No newline at end of file From 11fe0004f4e5fd6f39610cd4d01da94fd6d75dba Mon Sep 17 00:00:00 2001 From: zr_jin <60612200+JinZr@users.noreply.github.com> Date: Tue, 1 Aug 2023 17:15:01 +0800 Subject: [PATCH 14/60] minor updates --- egs/swbd/ASR/local/eval2000_data_prep.sh | 118 ----------------------- egs/swbd/ASR/local/normalize_eval2000.py | 4 + egs/swbd/ASR/prepare.sh | 21 +++- 3 files changed, 20 insertions(+), 123 deletions(-) delete mode 100755 egs/swbd/ASR/local/eval2000_data_prep.sh diff --git a/egs/swbd/ASR/local/eval2000_data_prep.sh b/egs/swbd/ASR/local/eval2000_data_prep.sh deleted file mode 100755 index a5bd3ebcf9..0000000000 --- a/egs/swbd/ASR/local/eval2000_data_prep.sh +++ /dev/null @@ -1,118 +0,0 @@ -#!/bin/bash - -# Hub-5 Eval 2000 data preparation -# Author: Arnab Ghoshal (Jan 2013) - -# To be run from one directory above this script. - -# The input is two directory names (possibly the same) containing the -# 2000 Hub5 english evaluation test set and transcripts, which are -# respectively: LDC2002S09 LDC2002T43 -# e.g. see -# http://www.ldc.upenn.edu/Catalog/catalogEntry.jsp?catalogId=LDC2002S09 -# http://www.ldc.upenn.edu/Catalog/CatalogEntry.jsp?catalogId=LDC2002T43 -# -# Example usage: -# local/eval2000_data_prep_edin.sh /exports/work/inf_hcrc_cstr_general/corpora/hub5/2000 /exports/work/inf_hcrc_cstr_general/corpora/hub5/2000/transcr -# The first directory ($sdir) contains the speech data, and the directory -# $sdir/english/ must exist. -# The second directory ($tdir) contains the transcripts, and the directory -# $tdir/reference must exist; in particular we need the file -# $tdir/reference/hub5e00.english.000405.stm - -if [ $# -ne 2 ]; then - echo "Usage: "$(basename $0)" " - echo "See comments in the script for more details" - exit 1 -fi - -sdir=$1 -tdir=$2 -[ ! -d $sdir/english ] && - echo Expecting directory $sdir/english to be present && exit 1 -[ -d $tdir/2000_hub5_eng_eval_tr ] && - tdir=$tdir/2000_hub5_eng_eval_tr -[ ! -d $tdir/reference ] && - echo Expecting directory $tdir/reference to be present && exit 1 - -dir=data/local/eval2000 -mkdir -p $dir - -find $sdir/english -iname '*.sph' | sort >$dir/sph.flist -sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ - >$dir/sph.scp - -sph2pipe=sph2pipe -[ ! -x $sph2pipe ] && - echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1 - -awk -v sph2pipe=$sph2pipe '{ - printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); - printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); - }' <$dir/sph.scp | sort >$dir/wav.scp || exit 1 -#side A - channel 1, side B - channel 2 - -# Get segments file... -# segments file format is: utt-id side-id start-time end-time, e.g.: -# sw02001-A_000098-001156 sw02001-A 0.98 11.56 -pem=$sdir/english/hub5e_00.pem -[ ! -f $pem ] && echo "$0: No such file $pem" && exit 1 -# pem file has lines like: -# en_4156 A unknown_speaker 301.85 302.48 - -# we ignore the warnings below for now, although they seem to indicate some problems -# with the data. -grep -v ';;' $pem | - awk '{ - spk=$1"-"$2; - utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); - print utt,spk,$4,$5;}' | - sort -u | local/extend_segments.pl 0.1 >$dir/segments - -# stm file has lines like: -# en_4156 A en_4156_A 357.64 359.64 HE IS A POLICE OFFICER -# TODO(arnab): We should really be lowercasing this since the Edinburgh -# recipe uses lowercase. This is not used in the actual scoring. -grep -v ';;' $tdir/reference/hub5e00.english.000405.stm | - awk '{ - spk=$1"-"$2; - utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); - printf utt; for(n=7;n<=NF;n++) printf(" %s", $n); print ""; }' | - sort >$dir/text.all - -# We'll use the stm file for sclite scoring. There seem to be various errors -# in the stm file that upset hubscr.pl, and we fix them here. -sed -e 's:((:(:' -e 's:::g' -e 's:::g' \ - $tdir/reference/hub5e00.english.000405.stm >$dir/stm -cp $tdir/reference/en20000405_hub5.glm $dir/glm - -# next line uses command substitution -# Just checking that the segments are the same in pem vs. stm. -! cmp <(awk '{print $1}' $dir/text.all) <(awk '{print $1}' $dir/segments) && - echo "Segments from pem file and stm file do not match." && exit 1 - -grep -v IGNORE_TIME_SEGMENT_ $dir/text.all >$dir/text - -# create an utt2spk file that assumes each conversation side is -# a separate speaker. -awk '{print $1,$2;}' $dir/segments >$dir/utt2spk -utils/utt2spk_to_spk2utt.pl $dir/utt2spk >$dir/spk2utt - -# cp $dir/segments $dir/segments.tmp -# awk '{x=$3-0.05; if (x<0.0) x=0.0; y=$4+0.05; print $1, $2, x, y; }' \ -# $dir/segments.tmp > $dir/segments - -awk '{print $1}' $dir/wav.scp | - perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; - print "$1-$2 $1 $2\n"; ' \ - >$dir/reco2file_and_channel || exit 1 - -echo Data preparation and formatting completed for Eval 2000 -echo "(but not MFCC extraction)" - -utils/fix_data_dir.sh $dir - -if [ $(wc -l <$dir/wav.scp) -ne 80 ]; then - echo "$0: error: expected 80 lines in wav.scp, got $(wc -l <$dir/wav.scp)" - exit 1 -fi diff --git a/egs/swbd/ASR/local/normalize_eval2000.py b/egs/swbd/ASR/local/normalize_eval2000.py index de19aa4362..4de3ab0f82 100644 --- a/egs/swbd/ASR/local/normalize_eval2000.py +++ b/egs/swbd/ASR/local/normalize_eval2000.py @@ -177,6 +177,10 @@ def replace_silphone(text: str) -> str: text = text.replace("{SNORT}", " ") text = text.replace("{SHARP EXHALATION}", " ") text = text.replace("{BREATH LAUGH}", " ") + + text = text.replace("[LAUGHTER]", " ") + text = text.replace("[NOISE]", " ") + text = text.replace("[VOCALIZED-NOISE]", " ") return text diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index 6940d7186d..7c9156abe1 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -24,8 +24,15 @@ stop_stage=100 dl_dir=./download swbd1_dir="/export/corpora3/LDC/LDC97S62" -eval2000_dir="/export/corpora2/LDC/LDC2002S09/hub5e_00" -eval2000_ref_dir="/export/corpora2/LDC/LDC2002T43" + +# eval2000_dir contains the following files and directories +# downloaded from LDC website: +# - LDC2002S09 +# - hub5e_00 +# - LDC2002T43 +# - 2000_hub5_eng_eval_tr +eval2000_dir="/export/corpora2/LDC/eval2000" + rt03_dir="/export/corpora/LDC/LDC2007S10" fisher_dir="/export/corpora3/LDC/LDC2004T19" @@ -52,7 +59,7 @@ log() { } log "swbd1_dir: $swbd1_dir" -log "eval2000_dir: $eval2000_dir $eval2000_ref_dir" +log "eval2000_dir: $eval2000_dir" log "rt03_dir: $rt03_dir" if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then @@ -68,7 +75,11 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then mv data/manifests_train/recordings.jsonl.gz data/manifests_train/swbd_recordings_all.jsonl.gz mv data/manifests_train/supervisions.jsonl.gz data/manifests_train/swbd_supervisions_all.jsonl.gz - ./local/eval2000_data_prep.sh $eval2000_dir $eval2000_ref_dir + lhotse prepare $eval2000_dir data/manifests_eval2000 + ./local/normalize_eval2000.py \ + data/manifests_eval2000/eval2000_supervisions_unnorm.jsonl.gz \ + data/manifests_eval2000/eval2000_supervisions.jsonl.gz + ./local/rt03_data_prep.sh $rt03_dir # normalize eval2000 and rt03 texts by @@ -76,7 +87,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # 2) remove tags (%AH) (%HESITATION) (%UH) # 3) remove # 4) remove "(" or ")" - for x in eval2000 rt03; do + for x in rt03; do cp data/local/${x}/text data/local/${x}/text.org paste -d "" \ <(cut -f 1 -d" " data/local/${x}/text.org) \ From 57e6808e8dcd2c4036b39b185ca382c11b97d132 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Tue, 1 Aug 2023 17:15:35 +0800 Subject: [PATCH 15/60] file permission changed --- egs/swbd/ASR/local/normalize_eval2000.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 egs/swbd/ASR/local/normalize_eval2000.py diff --git a/egs/swbd/ASR/local/normalize_eval2000.py b/egs/swbd/ASR/local/normalize_eval2000.py old mode 100644 new mode 100755 From 099e789ba02751fefbd6752643e192918c847fc8 Mon Sep 17 00:00:00 2001 From: zr_jin <60612200+JinZr@users.noreply.github.com> Date: Tue, 1 Aug 2023 17:24:50 +0800 Subject: [PATCH 16/60] minor updates --- .../normalize_and_filter_supervisions.py | 228 ++++++++++++++++++ egs/swbd/ASR/local/normalize_eval2000.py | 2 +- egs/swbd/ASR/prepare.sh | 9 +- 3 files changed, 236 insertions(+), 3 deletions(-) create mode 100644 egs/swbd/ASR/local/normalize_and_filter_supervisions.py diff --git a/egs/swbd/ASR/local/normalize_and_filter_supervisions.py b/egs/swbd/ASR/local/normalize_and_filter_supervisions.py new file mode 100644 index 0000000000..62f7efc68a --- /dev/null +++ b/egs/swbd/ASR/local/normalize_and_filter_supervisions.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Nagendra Goel https://github.com/ngoel17) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import re +from typing import Tuple + +from lhotse import SupervisionSegment, SupervisionSet +from lhotse.serialization import load_manifest_lazy_or_eager +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input_sups") + parser.add_argument("output_sups") + return parser.parse_args() + + +# replacement function to convert lowercase letter to uppercase +def to_upper(match_obj): + if match_obj.group() is not None: + return match_obj.group().upper() + + +def insert_groups_and_capitalize_3(match): + return f"{match.group(1)} {match.group(2)} {match.group(3)}".upper() + + +def insert_groups_and_capitalize_2(match): + return f"{match.group(1)} {match.group(2)}".upper() + + +def insert_groups_and_capitalize_1(match): + return f"{match.group(1)}".upper() + + +def insert_groups_and_capitalize_1s(match): + return f"{match.group(1)}".upper() + "'s" + + +# fmt: off +class FisherSwbdNormalizer: + """Note: the functions "normalize" and "keep" implement the logic + similar to Kaldi's data prep scripts for Fisher and SWBD: One + notable difference is that we don't change [cough], [lipsmack], + etc. to [noise]. We also don't implement all the edge cases of + normalization from Kaldi (hopefully won't make too much + difference). + """ + def __init__(self) -> None: + + self.remove_regexp_before = re.compile( + r"|".join([ + # special symbols + r"\[\[skip.*\]\]", + r"\[skip.*\]", + r"\[pause.*\]", + r"\[silence\]", + r"", + r"", + ]) + ) + + # tuples of (pattern, replacement) + # note: Kaldi replaces sighs, coughs, etc with [noise]. + # We don't do that here. + # We also lowercase the text as the first operation. + self.replace_regexps: Tuple[re.Pattern, str] = [ + # SWBD: + # [LAUGHTER-STORY] -> STORY + (re.compile(r"\[laughter-(.*?)\]"), r"\1"), + # [WEA[SONABLE]-/REASONABLE] + (re.compile(r"\[\S+/(\S+)\]"), r"\1"), + # -[ADV]AN[TAGE]- -> AN + (re.compile(r"-?\[.*?\](\w+)\[.*?\]-?"), r"\1-"), + # ABSOLUTE[LY]- -> ABSOLUTE- + (re.compile(r"(\w+)\[.*?\]-?"), r"\1-"), + # [AN]Y- -> Y- + # -[AN]Y- -> Y- + (re.compile(r"-?\[.*?\](\w+)-?"), r"\1-"), + # special tokens + (re.compile(r"\[laugh.*?\]"), r"[laughter]"), + (re.compile(r"\[sigh.*?\]"), r"[sigh]"), + (re.compile(r"\[cough.*?\]"), r"[cough]"), + (re.compile(r"\[mn.*?\]"), r"[vocalized-noise]"), + (re.compile(r"\[breath.*?\]"), r"[breath]"), + (re.compile(r"\[lipsmack.*?\]"), r"[lipsmack]"), + (re.compile(r"\[sneeze.*?\]"), r"[sneeze]"), + # abbreviations + (re.compile(r"(\w)\.(\w)\.(\w)",), insert_groups_and_capitalize_3), + (re.compile(r"(\w)\.(\w)",), insert_groups_and_capitalize_2), + (re.compile(r"([a-h,j-z])\.",), insert_groups_and_capitalize_1), + (re.compile(r"\._",), r" "), + (re.compile(r"_(\w)",), insert_groups_and_capitalize_1), + (re.compile(r"(\w)\.s",), insert_groups_and_capitalize_1s), + (re.compile(r"([A-Z])\'s",), insert_groups_and_capitalize_1s), + (re.compile(r"(\s\w\b|^\w\b)",), insert_groups_and_capitalize_1), + # words between apostrophes + (re.compile(r"'(\S*?)'"), r"\1"), + # dangling dashes (2 passes) + (re.compile(r"\s-\s"), r" "), + (re.compile(r"\s-\s"), r" "), + # special symbol with trailing dash + (re.compile(r"(\[.*?\])-"), r"\1"), + # Just remove all dashes + (re.compile(r"-"), r" "), + ] + + # unwanted symbols in the transcripts + self.remove_regexp_after = re.compile( + r"|".join([ + # remaining punctuation + r"\.", + r",", + r"\?", + r"{", + r"}", + r"~", + r"_\d", + ]) + ) + + self.whitespace_regexp = re.compile(r"\s+") + + def normalize(self, text: str) -> str: + text = text.lower() + + # first remove + text = self.remove_regexp_before.sub("", text) + + # then replace + for pattern, sub in self.replace_regexps: + text = pattern.sub(sub, text) + + # then remove + text = self.remove_regexp_after.sub("", text) + + # then clean up whitespace + text = self.whitespace_regexp.sub(" ", text).strip() + + return text +# fmt: on + + +def keep(sup: SupervisionSegment) -> bool: + if "((" in sup.text: + return False + + if " yes", + "[laugh] oh this is [laught] this is great [silence] yes", + "i don't kn- - know A.B.C's", + "so x. corp is good?", + "'absolutely yes", + "absolutely' yes", + "'absolutely' yes", + "'absolutely' yes 'aight", + "ABSOLUTE[LY]", + "ABSOLUTE[LY]-", + "[AN]Y", + "[AN]Y-", + "[ADV]AN[TAGE]", + "[ADV]AN[TAGE]-", + "-[ADV]AN[TAGE]", + "-[ADV]AN[TAGE]-", + "[WEA[SONABLE]-/REASONABLE]", + "[VOCALIZED-NOISE]-", + "~BULL", + "Frank E Peretti P E R E T T I", + "yeah yeah like Double O Seven he’s supposed to do it", + "P A P E R paper", + ]: + print(text) + print(normalizer.normalize(text)) + print() + + +if __name__ == "__main__": + test() +# main() diff --git a/egs/swbd/ASR/local/normalize_eval2000.py b/egs/swbd/ASR/local/normalize_eval2000.py index 4de3ab0f82..2e4a3c07fc 100644 --- a/egs/swbd/ASR/local/normalize_eval2000.py +++ b/egs/swbd/ASR/local/normalize_eval2000.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2023 (authors: Nagendra Goel https://github.com/ngoel17) +# Copyright 2023 (authors: Nagendra Goel https://github.com/ngoel17) # # See ../../../../LICENSE for clarification regarding multiple authors # diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index 7c9156abe1..ab8f907c9d 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -68,7 +68,12 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # to respective dirs mkdir -p data/manifests if [ ! -e data/manifests/.swbd.done ]; then - lhotse prepare switchboard --absolute-paths True $swbd1_dir data/manifests_train + lhotse prepare switchboard --absolute-paths 1 --omit-silence $swbd1_dir data/manifests/swbd + ./local/normalize_and_filter_supervisions.py \ + data/manifests/swbd/swbd_supervisions.jsonl \ + data/manifests/swbd/swbd_supervisions_norm.jsonl + cp data/manifests/swbd/swbd_recordings.jsonl data/manifests/recordings_swbd.jsonl + ./local/swbd1_prepare_dict.sh ./local/swbd1_data_prep.sh $swbd1_dir lhotse kaldi import data/local/train 8000 data/manifests_train @@ -78,7 +83,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then lhotse prepare $eval2000_dir data/manifests_eval2000 ./local/normalize_eval2000.py \ data/manifests_eval2000/eval2000_supervisions_unnorm.jsonl.gz \ - data/manifests_eval2000/eval2000_supervisions.jsonl.gz + data/manifests_eval2000/eval2000_supervisions_norm.jsonl.gz ./local/rt03_data_prep.sh $rt03_dir From 675816509931070c0cfafbaf1a507d8c9b397b51 Mon Sep 17 00:00:00 2001 From: zr_jin <60612200+JinZr@users.noreply.github.com> Date: Tue, 1 Aug 2023 19:11:56 +0800 Subject: [PATCH 17/60] minor updates --- egs/swbd/ASR/local/compute_fbank_swbd.py | 21 ++++++++++++--------- egs/swbd/ASR/prepare.sh | 5 ----- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/egs/swbd/ASR/local/compute_fbank_swbd.py b/egs/swbd/ASR/local/compute_fbank_swbd.py index a270f52ee4..556ef9e5a1 100755 --- a/egs/swbd/ASR/local/compute_fbank_swbd.py +++ b/egs/swbd/ASR/local/compute_fbank_swbd.py @@ -74,12 +74,13 @@ def get_args(): def compute_fbank_switchboard( + dir_name: str, bpe_model: Optional[str] = None, dataset: Optional[str] = None, perturb_speed: Optional[bool] = True, ): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") + src_dir = Path(f"data/manifests/{dir_name}") + output_dir = Path(f"data/fbank/{dir_name}") num_jobs = min(15, os.cpu_count()) num_mel_bins = 80 @@ -89,11 +90,11 @@ def compute_fbank_switchboard( sp.load(bpe_model) if dataset is None: - dataset_parts = ("all", "eval2000", "rt03") + dataset_parts = ("all") else: dataset_parts = dataset.split(" ", -1) - prefix = "swbd" + prefix = dir_name suffix = "jsonl.gz" manifests = read_manifests_if_cached( dataset_parts=dataset_parts, @@ -151,8 +152,10 @@ def compute_fbank_switchboard( logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() logging.info(vars(args)) - compute_fbank_switchboard( - bpe_model=args.bpe_model, - dataset=args.dataset, - perturb_speed=args.perturb_speed, - ) + for dir_name in ["swbd"]: + compute_fbank_switchboard( + dir_name=dir_name, + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index ab8f907c9d..d099e34baf 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -74,11 +74,6 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then data/manifests/swbd/swbd_supervisions_norm.jsonl cp data/manifests/swbd/swbd_recordings.jsonl data/manifests/recordings_swbd.jsonl - ./local/swbd1_prepare_dict.sh - ./local/swbd1_data_prep.sh $swbd1_dir - lhotse kaldi import data/local/train 8000 data/manifests_train - mv data/manifests_train/recordings.jsonl.gz data/manifests_train/swbd_recordings_all.jsonl.gz - mv data/manifests_train/supervisions.jsonl.gz data/manifests_train/swbd_supervisions_all.jsonl.gz lhotse prepare $eval2000_dir data/manifests_eval2000 ./local/normalize_eval2000.py \ From 5533c6278d146e59557bc15f142d523a75203058 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Tue, 8 Aug 2023 19:27:18 +0800 Subject: [PATCH 18/60] updated --- egs/swbd/ASR/conformer_ctc/asr_datamodule.py | 10 ++-- egs/swbd/ASR/local/compute_fbank_swbd.py | 6 +- .../normalize_and_filter_supervisions.py | 10 +++- egs/swbd/ASR/local/prepare_lang.py | 4 +- egs/swbd/ASR/local/prepare_lang_bpe.py | 2 +- egs/swbd/ASR/prepare.sh | 56 ++++++++++--------- 6 files changed, 47 insertions(+), 41 deletions(-) mode change 100644 => 100755 egs/swbd/ASR/local/normalize_and_filter_supervisions.py diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py index c45f2cbd0b..41da5ab3df 100644 --- a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -225,6 +225,8 @@ def train_dataloaders( else: logging.info("Disable MUSAN") + cuts_train = cuts_train.trim_to_supervisions(keep_overlapping=False) + if self.args.concatenate_cuts: logging.info( f"Using cut concatenation with duration factor " @@ -392,25 +394,23 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: def train_all_cuts(self) -> CutSet: logging.info("switchboard: About to get train cuts") return ( - load_manifest_lazy(self.args.manifest_dir / "swbd_cuts_all.jsonl.gz") + load_manifest_lazy(self.args.manifest_dir / "swbd" / "swbd_cuts_all.jsonl.gz") .subset(last=2388) - .trim_to_supervisions(keep_all_channels=True) ) @lru_cache() def dev_cuts(self) -> CutSet: logging.info("switchboard: About to get dev cuts") return ( - load_manifest_lazy(self.args.manifest_dir / "swbd_cuts_all.jsonl.gz") + load_manifest_lazy(self.args.manifest_dir / "swbd" / "swbd_cuts_all.jsonl.gz") .subset(first=50) - .trim_to_supervisions(keep_all_channels=True) ) @lru_cache() def test_eval2000_cuts(self) -> CutSet: logging.info("switchboard: About to get eval2000 cuts") return load_manifest_lazy( - self.args.manifest_dir / "swbd_cuts_eval2000.jsonl.gz" + self.args.manifest_dir / "eval2000" / "eval2000_cuts_all.jsonl.gz" ) @lru_cache() diff --git a/egs/swbd/ASR/local/compute_fbank_swbd.py b/egs/swbd/ASR/local/compute_fbank_swbd.py index 556ef9e5a1..cab5164466 100755 --- a/egs/swbd/ASR/local/compute_fbank_swbd.py +++ b/egs/swbd/ASR/local/compute_fbank_swbd.py @@ -66,7 +66,7 @@ def get_args(): parser.add_argument( "--perturb-speed", type=str2bool, - default=True, + default=False, help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", ) @@ -90,7 +90,7 @@ def compute_fbank_switchboard( sp.load(bpe_model) if dataset is None: - dataset_parts = ("all") + dataset_parts = ("all",) else: dataset_parts = dataset.split(" ", -1) @@ -152,7 +152,7 @@ def compute_fbank_switchboard( logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() logging.info(vars(args)) - for dir_name in ["swbd"]: + for dir_name in ["swbd", "eval2000"]: compute_fbank_switchboard( dir_name=dir_name, bpe_model=args.bpe_model, diff --git a/egs/swbd/ASR/local/normalize_and_filter_supervisions.py b/egs/swbd/ASR/local/normalize_and_filter_supervisions.py old mode 100644 new mode 100755 index 62f7efc68a..9970e112a6 --- a/egs/swbd/ASR/local/normalize_and_filter_supervisions.py +++ b/egs/swbd/ASR/local/normalize_and_filter_supervisions.py @@ -119,6 +119,9 @@ def __init__(self) -> None: (re.compile(r"(\[.*?\])-"), r"\1"), # Just remove all dashes (re.compile(r"-"), r" "), + + # Fix an issue related to [vocalized-noise] + (re.compile(r"\[vocalized noise\]"), r"\[vocalized-noise\]"), ] # unwanted symbols in the transcripts @@ -153,7 +156,7 @@ def normalize(self, text: str) -> str: # then clean up whitespace text = self.whitespace_regexp.sub(" ", text).strip() - return text + return text.upper() # fmt: on @@ -189,6 +192,7 @@ def main(): continue writer.write(sup) + print(f"tot: {tot}, skip: {skip}") def test(): @@ -224,5 +228,5 @@ def test(): if __name__ == "__main__": - test() -# main() + # test(); exit() + main() diff --git a/egs/swbd/ASR/local/prepare_lang.py b/egs/swbd/ASR/local/prepare_lang.py index f7d4609a95..d913756a19 100755 --- a/egs/swbd/ASR/local/prepare_lang.py +++ b/egs/swbd/ASR/local/prepare_lang.py @@ -249,7 +249,7 @@ def lexicon_to_fst( lexicon: Lexicon, token2id: Dict[str, int], word2id: Dict[str, int], - sil_token: str = "sil", + sil_token: str = "SIL", sil_prob: float = 0.5, need_self_loops: bool = False, ) -> k2.Fsa: @@ -346,7 +346,7 @@ def main(): args = get_args() lang_dir = Path(args.lang_dir) lexicon_filename = lang_dir / "lexicon.txt" - sil_token = "sil" + sil_token = "SIL" sil_prob = 0.5 lexicon = read_lexicon(lexicon_filename) diff --git a/egs/swbd/ASR/local/prepare_lang_bpe.py b/egs/swbd/ASR/local/prepare_lang_bpe.py index f5b3a25b34..db5e42b055 100755 --- a/egs/swbd/ASR/local/prepare_lang_bpe.py +++ b/egs/swbd/ASR/local/prepare_lang_bpe.py @@ -178,7 +178,7 @@ def get_args(): parser.add_argument( "--oov", type=str, - default="", + default="", help="The out of vocabulary word in lexicon.", ) diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index d099e34baf..0d64e6814c 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -23,7 +23,8 @@ stop_stage=100 # - speech dl_dir=./download -swbd1_dir="/export/corpora3/LDC/LDC97S62" +# swbd1_dir="/export/corpora3/LDC/LDC97S62" +swbd1_dir=./download/LDC97S62/ # eval2000_dir contains the following files and directories # downloaded from LDC website: @@ -70,15 +71,14 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then if [ ! -e data/manifests/.swbd.done ]; then lhotse prepare switchboard --absolute-paths 1 --omit-silence $swbd1_dir data/manifests/swbd ./local/normalize_and_filter_supervisions.py \ - data/manifests/swbd/swbd_supervisions.jsonl \ - data/manifests/swbd/swbd_supervisions_norm.jsonl - cp data/manifests/swbd/swbd_recordings.jsonl data/manifests/recordings_swbd.jsonl - + data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ + data/manifests/swbd/swbd_supervisions_all_norm.jsonl.gz + mv data/manifests/swbd/swbd_supervisions_all_norm.jsonl.gz data/manifests/swbd/swbd_supervisions_all.jsonl.gz - lhotse prepare $eval2000_dir data/manifests_eval2000 + lhotse prepare eval2000 --absolute-paths 1 $eval2000_dir data/manifests/eval2000 ./local/normalize_eval2000.py \ - data/manifests_eval2000/eval2000_supervisions_unnorm.jsonl.gz \ - data/manifests_eval2000/eval2000_supervisions_norm.jsonl.gz + data/manifests/eval2000/eval2000_supervisions_unnorm.jsonl.gz \ + data/manifests/eval2000/eval2000_supervisions_all.jsonl.gz ./local/rt03_data_prep.sh $rt03_dir @@ -96,20 +96,6 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then rm data/local/${x}/text.org done - python ./local/filter_empty_text.py --kaldi-data-dir data/local/eval2000 - ./utils/fix_data_dir.sh data/local/eval2000 - lhotse kaldi import data/local/eval2000 8000 data/manifests_eval2000 - mv data/manifests_eval2000/recordings.jsonl.gz data/manifests_eval2000/swbd_recordings_eval2000.jsonl.gz - mv data/manifests_eval2000/supervisions.jsonl.gz data/manifests_eval2000/swbd_supervisions_eval2000.jsonl.gz - - python ./local/filter_empty_text.py --kaldi-data-dir data/local/rt03 - ./utils/fix_data_dir.sh data/local/rt03 - lhotse kaldi import data/local/rt03 8000 data/manifests_rt03 - mv data/manifests_rt03/recordings.jsonl.gz data/manifests_rt03/swbd_recordings_rt03.jsonl.gz - mv data/manifests_rt03/supervisions.jsonl.gz data/manifests_rt03/swbd_supervisions_rt03.jsonl.gz - - lhotse fix data/manifests_train/swbd_recordings_all.jsonl.gz data/manifests_train/swbd_supervisions_all.jsonl.gz data/manifests - lhotse fix data/manifests_eval2000/swbd_recordings_eval2000.jsonl.gz data/manifests_eval2000/swbd_supervisions_eval2000.jsonl.gz data/manifests lhotse fix data/manifests_rt03/swbd_recordings_rt03.jsonl.gz data/manifests_rt03/swbd_supervisions_rt03.jsonl.gz data/manifests touch data/manifests/.swbd.done @@ -128,7 +114,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compute fbank for switchboard" + log "Stage 3: Compute fbank for SwitchBoard" mkdir -p data/fbank if [ ! -e data/fbank/.swbd.done ]; then ./local/compute_fbank_swbd.py @@ -150,13 +136,29 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then lang_dir=data/lang_phone mkdir -p $lang_dir + if ! which jq; then + echo "This script is intended to be used with jq but you have not installed jq + Note: in Linux, you can install jq with the following command: + 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 + 2. chmod +x ./jq + 3. cp jq /usr/bin" && exit 1 + fi + if [ ! -f $lang_dir/text ] || [ ! -s $lang_dir/text ]; then + log "Prepare text." + gunzip -c data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ + | jq '.text' | sed 's/"//g' > $lang_dir/text + fi + log "prepare dict" - cut -f 2- -d" " data/local/train/text >${lang_dir}/input.txt + ./local/swbd1_prepare_dict.sh $swbd1_dir + cut -f 2- -d" " $lang_dir/text >${lang_dir}/input.txt # [noise] nsn # !sil sil # spn cat data/local/dict_nosp/lexicon.txt | - sort | uniq >$lang_dir/lexicon.txt + sort | uniq >$lang_dir/lexicon_lower.txt + + cat $lang_dir/lexicon_lower.txt | tr a-z A-Z > $lang_dir/lexicon.txt if [ ! -f $lang_dir/L_disambig.pt ]; then ./local/prepare_lang.py --lang-dir $lang_dir @@ -192,7 +194,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then if [ ! -f $lang_dir/transcript_words.txt ]; then log "Generate data for BPE training" - cat ./data/local/train/text | cut -d " " -f 2- >$lang_dir/transcript_words.txt + cat data/lang_phone/text | cut -d " " -f 2- >$lang_dir/transcript_words.txt fi if [ ! -f $lang_dir/bpe.model ]; then @@ -239,7 +241,7 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then ./local/convert_transcript_words_to_tokens.py \ --lexicon $lang_dir/lexicon.txt \ --transcript $lang_dir/transcript_words.txt \ - --oov "" \ + --oov "" \ >$lang_dir/transcript_tokens.txt fi From e0ee8dd428e0bb3f2508182aa183145689f5c71b Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Fri, 11 Aug 2023 10:31:08 +0800 Subject: [PATCH 19/60] minor updates --- egs/swbd/ASR/local/compute_fbank_swbd.py | 6 +++--- egs/swbd/ASR/prepare.sh | 22 +++++++++++----------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/egs/swbd/ASR/local/compute_fbank_swbd.py b/egs/swbd/ASR/local/compute_fbank_swbd.py index cab5164466..a89130baec 100755 --- a/egs/swbd/ASR/local/compute_fbank_swbd.py +++ b/egs/swbd/ASR/local/compute_fbank_swbd.py @@ -111,7 +111,7 @@ def compute_fbank_switchboard( dataset_parts, ) - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=8000)) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): @@ -121,7 +121,7 @@ def compute_fbank_switchboard( continue logging.info(f"Processing {partition}") cut_set = CutSet.from_manifests( - recordings=m["recordings"].resample(16000), + recordings=m["recordings"], supervisions=m["supervisions"], ) @@ -134,7 +134,7 @@ def compute_fbank_switchboard( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ).resample(16000) + ) cut_set = cut_set.compute_and_store_features( extractor=extractor, storage_path=f"{output_dir}/{prefix}_feats_{partition}", diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index 0d64e6814c..4609f329a4 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -80,23 +80,23 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then data/manifests/eval2000/eval2000_supervisions_unnorm.jsonl.gz \ data/manifests/eval2000/eval2000_supervisions_all.jsonl.gz - ./local/rt03_data_prep.sh $rt03_dir + # ./local/rt03_data_prep.sh $rt03_dir # normalize eval2000 and rt03 texts by # 1) convert upper to lower # 2) remove tags (%AH) (%HESITATION) (%UH) # 3) remove # 4) remove "(" or ")" - for x in rt03; do - cp data/local/${x}/text data/local/${x}/text.org - paste -d "" \ - <(cut -f 1 -d" " data/local/${x}/text.org) \ - <(awk '{$1=""; print tolower($0)}' data/local/${x}/text.org | perl -pe 's| \(\%.*\)||g' | perl -pe 's| \<.*\>||g' | sed -e "s/(//g" -e "s/)//g") | - sed -e 's/\s\+/ /g' >data/local/${x}/text - rm data/local/${x}/text.org - done - - lhotse fix data/manifests_rt03/swbd_recordings_rt03.jsonl.gz data/manifests_rt03/swbd_supervisions_rt03.jsonl.gz data/manifests + # for x in rt03; do + # cp data/local/${x}/text data/local/${x}/text.org + # paste -d "" \ + # <(cut -f 1 -d" " data/local/${x}/text.org) \ + # <(awk '{$1=""; print tolower($0)}' data/local/${x}/text.org | perl -pe 's| \(\%.*\)||g' | perl -pe 's| \<.*\>||g' | sed -e "s/(//g" -e "s/)//g") | + # sed -e 's/\s\+/ /g' >data/local/${x}/text + # rm data/local/${x}/text.org + # done + + # lhotse fix data/manifests_rt03/swbd_recordings_rt03.jsonl.gz data/manifests_rt03/swbd_supervisions_rt03.jsonl.gz data/manifests touch data/manifests/.swbd.done fi From ab07e58613ed7eef57715ca7fd182e057ea4680b Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Fri, 18 Aug 2023 11:19:08 +0800 Subject: [PATCH 20/60] minor updates --- .../normalize_and_filter_supervisions.py | 123 +++++++++++++----- egs/swbd/ASR/local/normalize_eval2000.py | 7 +- egs/swbd/ASR/local/prepare_lang_bpe.py | 10 +- egs/swbd/ASR/local/swbd1_prepare_dict.sh | 10 +- egs/swbd/ASR/prepare.sh | 56 ++++---- 5 files changed, 131 insertions(+), 75 deletions(-) diff --git a/egs/swbd/ASR/local/normalize_and_filter_supervisions.py b/egs/swbd/ASR/local/normalize_and_filter_supervisions.py index 9970e112a6..20ab90caf3 100755 --- a/egs/swbd/ASR/local/normalize_and_filter_supervisions.py +++ b/egs/swbd/ASR/local/normalize_and_filter_supervisions.py @@ -53,7 +53,6 @@ def insert_groups_and_capitalize_1s(match): return f"{match.group(1)}".upper() + "'s" -# fmt: off class FisherSwbdNormalizer: """Note: the functions "normalize" and "keep" implement the logic similar to Kaldi's data prep scripts for Fisher and SWBD: One @@ -62,18 +61,21 @@ class FisherSwbdNormalizer: normalization from Kaldi (hopefully won't make too much difference). """ - def __init__(self) -> None: + def __init__(self) -> None: self.remove_regexp_before = re.compile( - r"|".join([ - # special symbols - r"\[\[skip.*\]\]", - r"\[skip.*\]", - r"\[pause.*\]", - r"\[silence\]", - r"", - r"", - ]) + r"|".join( + [ + # special symbols + r"\[\[skip.*\]\]", + r"\[skip.*\]", + r"\[pause.*\]", + r"\[silence\]", + r"", + r"", + r"_1", + ] + ) ) # tuples of (pattern, replacement) @@ -102,14 +104,54 @@ def __init__(self) -> None: (re.compile(r"\[lipsmack.*?\]"), r"[lipsmack]"), (re.compile(r"\[sneeze.*?\]"), r"[sneeze]"), # abbreviations - (re.compile(r"(\w)\.(\w)\.(\w)",), insert_groups_and_capitalize_3), - (re.compile(r"(\w)\.(\w)",), insert_groups_and_capitalize_2), - (re.compile(r"([a-h,j-z])\.",), insert_groups_and_capitalize_1), - (re.compile(r"\._",), r" "), - (re.compile(r"_(\w)",), insert_groups_and_capitalize_1), - (re.compile(r"(\w)\.s",), insert_groups_and_capitalize_1s), - (re.compile(r"([A-Z])\'s",), insert_groups_and_capitalize_1s), - (re.compile(r"(\s\w\b|^\w\b)",), insert_groups_and_capitalize_1), + ( + re.compile( + r"(\w)\.(\w)\.(\w)", + ), + insert_groups_and_capitalize_3, + ), + ( + re.compile( + r"(\w)\.(\w)", + ), + insert_groups_and_capitalize_2, + ), + ( + re.compile( + r"([a-h,j-z])\.", + ), + insert_groups_and_capitalize_1, + ), + ( + re.compile( + r"\._", + ), + r" ", + ), + ( + re.compile( + r"_(\w)", + ), + insert_groups_and_capitalize_1, + ), + ( + re.compile( + r"(\w)\.s", + ), + insert_groups_and_capitalize_1s, + ), + ( + re.compile( + r"([A-Z])\'s", + ), + insert_groups_and_capitalize_1s, + ), + ( + re.compile( + r"(\s\w\b|^\w\b)", + ), + insert_groups_and_capitalize_1, + ), # words between apostrophes (re.compile(r"'(\S*?)'"), r"\1"), # dangling dashes (2 passes) @@ -119,25 +161,29 @@ def __init__(self) -> None: (re.compile(r"(\[.*?\])-"), r"\1"), # Just remove all dashes (re.compile(r"-"), r" "), - - # Fix an issue related to [vocalized-noise] - (re.compile(r"\[vocalized noise\]"), r"\[vocalized-noise\]"), ] # unwanted symbols in the transcripts self.remove_regexp_after = re.compile( - r"|".join([ - # remaining punctuation - r"\.", - r",", - r"\?", - r"{", - r"}", - r"~", - r"_\d", - ]) + r"|".join( + [ + # remaining punctuation + r"\.", + r",", + r"\?", + r"{", + r"}", + r"~", + r"_\d", + ] + ) ) + self.post_fixes = [ + # Fix an issue related to [VOCALIZED NOISE] after dash removal + (re.compile(r"\[vocalized noise\]"), "[vocalized-noise]"), + ] + self.whitespace_regexp = re.compile(r"\s+") def normalize(self, text: str) -> str: @@ -153,11 +199,14 @@ def normalize(self, text: str) -> str: # then remove text = self.remove_regexp_after.sub("", text) + # post fixes + for pattern, sub in self.post_fixes: + text = pattern.sub(sub, text) + # then clean up whitespace text = self.whitespace_regexp.sub(" ", text).strip() return text.upper() -# fmt: on def keep(sup: SupervisionSegment) -> bool: @@ -186,7 +235,7 @@ def main(): skip += 1 continue - sup.text = normalizer.normalize(sup.text) + sup.text = normalizer.normalize(sup.text).upper() if not sup.text: skip += 1 continue @@ -219,8 +268,9 @@ def test(): "[VOCALIZED-NOISE]-", "~BULL", "Frank E Peretti P E R E T T I", - "yeah yeah like Double O Seven he’s supposed to do it", + "yeah yeah like Double O Seven he's supposed to do it", "P A P E R paper", + "[noise] okay_1 um let me see [laughter] i've been sitting here awhile", ]: print(text) print(normalizer.normalize(text)) @@ -228,5 +278,6 @@ def test(): if __name__ == "__main__": - # test(); exit() + test() + # exit() main() diff --git a/egs/swbd/ASR/local/normalize_eval2000.py b/egs/swbd/ASR/local/normalize_eval2000.py index 2e4a3c07fc..38451f5063 100755 --- a/egs/swbd/ASR/local/normalize_eval2000.py +++ b/egs/swbd/ASR/local/normalize_eval2000.py @@ -83,9 +83,7 @@ def replace_silphone(text: str) -> str: text = text.replace("[[HIGH PITCHED SQUEAKY VOICE]]", " ") text = text.replace("[[IN THE LAUGH]]", "[LAUGHTER]") text = text.replace("[[LAST WORD SPOKEN WITH A LAUGH]]", "[LAUGHTER]") - text = text.replace( - "[[PART OF FIRST SYLLABLE OF PREVIOUS WORD CUT OFF]]", " " - ) + text = text.replace("[[PART OF FIRST SYLLABLE OF PREVIOUS WORD CUT OFF]]", " ") text = text.replace("[[PREVIOUS WORD SPOKEN WITH A LAUGH]]", " ") text = text.replace("[[PREVIOUS TWO WORDS SPOKEN WHILE LAUGHING]]", " ") text = text.replace("[[PROLONGED]]", " ") @@ -181,6 +179,7 @@ def replace_silphone(text: str) -> str: text = text.replace("[LAUGHTER]", " ") text = text.replace("[NOISE]", " ") text = text.replace("[VOCALIZED-NOISE]", " ") + text = text.replace("-", " ") return text @@ -231,4 +230,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/egs/swbd/ASR/local/prepare_lang_bpe.py b/egs/swbd/ASR/local/prepare_lang_bpe.py index db5e42b055..0fd937bc93 100755 --- a/egs/swbd/ASR/local/prepare_lang_bpe.py +++ b/egs/swbd/ASR/local/prepare_lang_bpe.py @@ -210,15 +210,15 @@ def main(): excluded = [ "", - "!sil", - "", + "!SIL", + "", args.oov, "#0", "", "", - "[vocalized-noise]", - "[noise]", - "[laughter]", + "[VOCALIZED-NOISE]", + "[NOISE]", + "[LAUGHTER]", ] for w in excluded: diff --git a/egs/swbd/ASR/local/swbd1_prepare_dict.sh b/egs/swbd/ASR/local/swbd1_prepare_dict.sh index 0bb98903fc..eff5fb5f1b 100755 --- a/egs/swbd/ASR/local/swbd1_prepare_dict.sh +++ b/egs/swbd/ASR/local/swbd1_prepare_dict.sh @@ -46,11 +46,11 @@ cp local/MSU_single_letter.txt $dir/ # The original swbd lexicon does not have precise single letter lexicion # e.g. it does not have entry of W ( - echo '!sil sil' - echo '[vocalized-noise] spn' - echo '[noise] nsn' - echo '[laughter] lau' - echo ' spn' + echo '!SIL SIL' + echo '[VOCALIZED-NOISE] spn' + echo '[NOISE] nsn' + echo '[LAUGHTER] lau' + echo ' spn' ) | cat - $dir/lexicon1.txt $dir/MSU_single_letter.txt >$dir/lexicon2.txt || exit 1 diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index 4609f329a4..1ac5e8a1ec 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -43,9 +43,9 @@ fisher_dir="/export/corpora3/LDC/LDC2004T19" # It will generate data/lang_bpe_xxx, # data/lang_bpe_yyy if the array contains xxx, yyy vocab_sizes=( - 5000 - 2000 - 1000 + # 5000 + # 2000 + # 1000 500 ) @@ -73,6 +73,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then ./local/normalize_and_filter_supervisions.py \ data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ data/manifests/swbd/swbd_supervisions_all_norm.jsonl.gz + mv data/manifests/swbd/swbd_supervisions_all.jsonl.gz data/manifests/swbd/swbd_supervisions_orig.jsonl.gz mv data/manifests/swbd/swbd_supervisions_all_norm.jsonl.gz data/manifests/swbd/swbd_supervisions_all.jsonl.gz lhotse prepare eval2000 --absolute-paths 1 $eval2000_dir data/manifests/eval2000 @@ -149,8 +150,8 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then | jq '.text' | sed 's/"//g' > $lang_dir/text fi - log "prepare dict" - ./local/swbd1_prepare_dict.sh $swbd1_dir + log "Prepare dict" + ./local/swbd1_prepare_dict.sh cut -f 2- -d" " $lang_dir/text >${lang_dir}/input.txt # [noise] nsn # !sil sil @@ -336,6 +337,10 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then out_dir=data/lm_training_bpe_${vocab_size} mkdir -p $out_dir + if [ ! -f $out_dir/train.txt ]; then + tail -n 250000 data/lang_phone/input.txt > $out_dir/train.txt + fi + ./local/prepare_lm_training_data.py \ --bpe-model $lang_dir/bpe.model \ --lm-data data/lang_phone/input.txt \ @@ -343,29 +348,29 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then done fi -# if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then -# log "Stage 12: Generate LM validation data" +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then + log "Stage 12: Generate LM validation data" -# for vocab_size in ${vocab_sizes[@]}; do -# log "Processing vocab_size == ${vocab_size}" -# out_dir=data/lm_training_bpe_${vocab_size} -# mkdir -p $out_dir + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir -# if [ ! -f $out_dir/valid.txt ]; then -# TODO: generate valid.txt -# fi + if [ ! -f $out_dir/valid.txt ]; then + head -n 14332 data/lang_phone/input.txt > $out_dir/valid.txt + fi -# lang_dir=data/lang_bpe_${vocab_size} -# ./local/prepare_lm_training_data.py \ -# --bpe-model $lang_dir/bpe.model \ -# --lm-data $out_dir/valid.txt \ -# --lm-archive $out_dir/lm_data-valid.pt -# done -# fi + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/valid.txt \ + --lm-archive $out_dir/lm_data-valid.pt + done +fi if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then log "Stage 13: Generate LM test data" - testsets=(eval2000 rt03) + testsets=(eval2000) for testset in ${testsets[@]}; do for vocab_size in ${vocab_sizes[@]}; do @@ -373,8 +378,9 @@ if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then out_dir=data/lm_training_bpe_${vocab_size} mkdir -p $out_dir - if [ ! -f $out_dir/test.txt ]; then - cat data/local/${testset}/text | cut -d " " -f 2- >$out_dir/${testset}.txt + if [ ! -f $out_dir/${testset}.txt ]; then + gunzip -c data/manifests/${testset}/eval2000_supervisions_all.jsonl.gz \ + | jq '.text' | sed 's/"//g' > $out_dir/${testset}.txt fi lang_dir=data/lang_bpe_${vocab_size} @@ -388,7 +394,7 @@ fi if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then log "Stage 14: Sort LM training data" - testsets=(eval2000 rt03) + testsets=(eval2000) # Sort LM training data by sentence length in descending order # for ease of training. # From afe2f2bbcd75441d66b40ad82fdaadd52e14c006 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Fri, 18 Aug 2023 11:22:20 +0800 Subject: [PATCH 21/60] added gitignore --- egs/swbd/ASR/.gitignore | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 egs/swbd/ASR/.gitignore diff --git a/egs/swbd/ASR/.gitignore b/egs/swbd/ASR/.gitignore new file mode 100644 index 0000000000..35ceb0a7fc --- /dev/null +++ b/egs/swbd/ASR/.gitignore @@ -0,0 +1,2 @@ +switchboard_word_alignments.tar.gz +./swb_ms98_transcriptions/* From 9d848b16eb63cebfd21ce6a77cfd1f76c0ffad41 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Fri, 18 Aug 2023 11:24:11 +0800 Subject: [PATCH 22/60] updated gitignore --- egs/swbd/ASR/.gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/swbd/ASR/.gitignore b/egs/swbd/ASR/.gitignore index 35ceb0a7fc..088ab9c361 100644 --- a/egs/swbd/ASR/.gitignore +++ b/egs/swbd/ASR/.gitignore @@ -1,2 +1,2 @@ switchboard_word_alignments.tar.gz -./swb_ms98_transcriptions/* +egs/swbd/ASR/swb_ms98_transcriptions/ From e13b01a31345a6eabc7e9b8ce3d571d5ac7aa5f6 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Fri, 18 Aug 2023 11:25:46 +0800 Subject: [PATCH 23/60] updated gitignore --- egs/swbd/ASR/.gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/swbd/ASR/.gitignore b/egs/swbd/ASR/.gitignore index 088ab9c361..11d6749227 100644 --- a/egs/swbd/ASR/.gitignore +++ b/egs/swbd/ASR/.gitignore @@ -1,2 +1,2 @@ switchboard_word_alignments.tar.gz -egs/swbd/ASR/swb_ms98_transcriptions/ +./swb_ms98_transcriptions/ From f7fac705d5ae1340b2720de6faf1ff3e0bbd239f Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Sat, 19 Aug 2023 21:22:28 +0800 Subject: [PATCH 24/60] minor updates --- egs/swbd/ASR/conformer_ctc/asr_datamodule.py | 16 +-- egs/swbd/ASR/conformer_ctc/train.py | 2 +- egs/swbd/ASR/local/compute_fbank_eval2000.py | 138 +++++++++++++++++++ egs/swbd/ASR/local/compute_fbank_swbd.py | 106 +++++++------- egs/swbd/ASR/prepare.sh | 45 +++++- 5 files changed, 242 insertions(+), 65 deletions(-) create mode 100755 egs/swbd/ASR/local/compute_fbank_eval2000.py diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py index 41da5ab3df..7fb4515e11 100644 --- a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -225,8 +225,6 @@ def train_dataloaders( else: logging.info("Disable MUSAN") - cuts_train = cuts_train.trim_to_supervisions(keep_overlapping=False) - if self.args.concatenate_cuts: logging.info( f"Using cut concatenation with duration factor " @@ -393,18 +391,16 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: @lru_cache() def train_all_cuts(self) -> CutSet: logging.info("switchboard: About to get train cuts") - return ( - load_manifest_lazy(self.args.manifest_dir / "swbd" / "swbd_cuts_all.jsonl.gz") - .subset(last=2388) - ) + return load_manifest_lazy( + self.args.manifest_dir / "swbd_cuts_all.jsonl.gz" + ).subset(last=2388) @lru_cache() def dev_cuts(self) -> CutSet: logging.info("switchboard: About to get dev cuts") - return ( - load_manifest_lazy(self.args.manifest_dir / "swbd" / "swbd_cuts_all.jsonl.gz") - .subset(first=50) - ) + return load_manifest_lazy( + self.args.manifest_dir / "swbd_cuts_all.jsonl.gz" + ).subset(first=50) @lru_cache() def test_eval2000_cuts(self) -> CutSet: diff --git a/egs/swbd/ASR/conformer_ctc/train.py b/egs/swbd/ASR/conformer_ctc/train.py index 370fa0f774..8ed837aaf2 100755 --- a/egs/swbd/ASR/conformer_ctc/train.py +++ b/egs/swbd/ASR/conformer_ctc/train.py @@ -697,7 +697,7 @@ def remove_short_and_long_utt(c: Cut): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 600.0 + return 1.0 <= c.duration train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/swbd/ASR/local/compute_fbank_eval2000.py b/egs/swbd/ASR/local/compute_fbank_eval2000.py new file mode 100755 index 0000000000..c6e2890fff --- /dev/null +++ b/egs/swbd/ASR/local/compute_fbank_eval2000.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# Modified 2023 The Chinese University of Hong Kong (author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the SwitchBoard dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import sentencepiece as spm +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. If not None, we will remove short and + long utterances before extracting features""", + ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + + return parser.parse_args() + + +def compute_fbank_switchboard( + dir_name: str, + bpe_model: Optional[str] = None, + dataset: Optional[str] = None, + perturb_speed: Optional[bool] = True, +): + src_dir = Path(f"data/manifests/{dir_name}") + output_dir = Path(f"data/fbank/{dir_name}") + num_jobs = min(1, os.cpu_count()) + num_mel_bins = 80 + + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + if dataset is None: + dataset_parts = ("all",) + else: + dataset_parts = dataset.split(" ", -1) + + prefix = dir_name + suffix = "jsonl.gz" + manifests = { + "eval2000": "data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz", + } + assert manifests is not None + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=16000)) + + with get_executor() as ex: # Initialize the executor only once. + partition = "all" + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + print(cuts_filename) + if (output_dir / cuts_filename).is_file(): + logging.info(f"{prefix} already exists - skipping.") + return + logging.info(f"Processing {prefix}") + cut_set = CutSet.from_file(manifests[prefix]).resample(16000) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_switchboard( + dir_name="eval2000", + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/swbd/ASR/local/compute_fbank_swbd.py b/egs/swbd/ASR/local/compute_fbank_swbd.py index a89130baec..dd82220c04 100755 --- a/egs/swbd/ASR/local/compute_fbank_swbd.py +++ b/egs/swbd/ASR/local/compute_fbank_swbd.py @@ -70,18 +70,25 @@ def get_args(): help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", ) + parser.add_argument( + "--split-index", + type=int, + required=True, + ) + return parser.parse_args() def compute_fbank_switchboard( dir_name: str, + split_index: int, bpe_model: Optional[str] = None, dataset: Optional[str] = None, perturb_speed: Optional[bool] = True, ): src_dir = Path(f"data/manifests/{dir_name}") - output_dir = Path(f"data/fbank/{dir_name}") - num_jobs = min(15, os.cpu_count()) + output_dir = Path(f"data/fbank/{dir_name}_split16") + num_jobs = min(1, os.cpu_count()) num_mel_bins = 80 if bpe_model: @@ -96,54 +103,48 @@ def compute_fbank_switchboard( prefix = dir_name suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None + split_dir = Path("data/manifests/swbd_split16/") - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=8000)) + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=16000)) with get_executor() as ex: # Initialize the executor only once. - for partition, m in manifests.items(): - cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" - if (output_dir / cuts_filename).is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], + partition = "all" + cuts_filename = ( + f"{prefix}_cuts_{partition}.{str(split_index).zfill(2)}.{suffix}" + ) + print(cuts_filename) + if (output_dir / cuts_filename).is_file(): + logging.info(f"{prefix} already exists - skipping.") + return + logging.info(f"Processing {prefix}") + cut_set = ( + CutSet.from_file( + split_dir + / f"swbd_train_all_trimmed.{str(split_index).zfill(2)}.jsonl.gz" ) + .resample(16000) + .to_eager() + .filter(lambda c: c.duration > 2.0) + ) - if "train" in partition: - if bpe_model: - cut_set = filter_cuts(cut_set, sp) - if perturb_speed: - logging.info(f"Doing speed perturb") - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / cuts_filename) + if bpe_model: + cut_set = filter_cuts(cut_set, sp) + if perturb_speed: + logging.info(f"Doing speed perturb") + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}_{str(split_index).zfill(2)}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, + min_duration=None, + ) + cut_set.to_file(output_dir / cuts_filename) if __name__ == "__main__": @@ -152,10 +153,11 @@ def compute_fbank_switchboard( logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() logging.info(vars(args)) - for dir_name in ["swbd", "eval2000"]: - compute_fbank_switchboard( - dir_name=dir_name, - bpe_model=args.bpe_model, - dataset=args.dataset, - perturb_speed=args.perturb_speed, - ) + + compute_fbank_switchboard( + dir_name="swbd", + split_index=args.split_index, + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index 1ac5e8a1ec..50af972253 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -76,10 +76,37 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then mv data/manifests/swbd/swbd_supervisions_all.jsonl.gz data/manifests/swbd/swbd_supervisions_orig.jsonl.gz mv data/manifests/swbd/swbd_supervisions_all_norm.jsonl.gz data/manifests/swbd/swbd_supervisions_all.jsonl.gz + lhotse cut simple \ + -r data/manifests/swbd/swbd_recordings_all.jsonl.gz \ + -s data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ + data/manifests/swbd/swbd_train_all.jsonl.gz + lhotse cut trim-to-supervisions \ + --discard-overlapping \ + --discard-extra-channels \ + data/manifests/swbd/swbd_train_all.jsonl.gz \ + data/manifests/swbd/swbd_train_all_trimmed.jsonl.gz + + num_splits=16 + mkdir -p data/manifests/swbd_split${num_splits} + lhotse split ${num_splits} \ + data/manifests/swbd/swbd_train_all_trimmed.jsonl.gz \ + data/manifests/swbd_split${num_splits} + lhotse prepare eval2000 --absolute-paths 1 $eval2000_dir data/manifests/eval2000 ./local/normalize_eval2000.py \ data/manifests/eval2000/eval2000_supervisions_unnorm.jsonl.gz \ data/manifests/eval2000/eval2000_supervisions_all.jsonl.gz + + lhotse cut simple \ + -r data/manifests/eval2000/eval2000_recordings_all.jsonl.gz \ + -s data/manifests/eval2000/eval2000_supervisions_all.jsonl.gz \ + data/manifests/eval2000/eval2000_cuts_all.jsonl.gz + + lhotse cut trim-to-supervisions \ + --discard-overlapping \ + --discard-extra-channels \ + data/manifests/eval2000/eval2000_cuts_all.jsonl.gz \ + data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz # ./local/rt03_data_prep.sh $rt03_dir @@ -116,13 +143,27 @@ fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Compute fbank for SwitchBoard" - mkdir -p data/fbank if [ ! -e data/fbank/.swbd.done ]; then - ./local/compute_fbank_swbd.py + mkdir -p data/fbank/swbd_split${num_splits}/ + for index in $(seq 1 16); do + ./local/compute_fbank_swbd.py --split-index ${index} & + done + wait + pieces=$(find data/fbank/swbd_split${num_splits} -name "swbd_cuts_all.*.jsonl.gz") + lhotse combine $pieces data/fbank/swbd_cuts_all.jsonl.gz touch data/fbank/.swbd.done fi fi +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute fbank for eval2000" + if [ ! -e data/fbank/.eval2000.done ]; then + mkdir -p data/fbank/eval2000/ + ./local/compute_fbank_eval2000.py + touch data/fbank/.eval2000.done + fi +fi + if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Compute fbank for musan" mkdir -p data/fbank From 15f6dcff9a14afea651fb36b9931ace04ea5608c Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Sat, 19 Aug 2023 21:36:47 +0800 Subject: [PATCH 25/60] minor updates --- egs/swbd/ASR/conformer_ctc/asr_datamodule.py | 5 +- .../ASR/local/display_manifest_statistics.py | 207 ++++-------------- 2 files changed, 43 insertions(+), 169 deletions(-) diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py index 7fb4515e11..d2d2e52eb7 100644 --- a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -99,7 +99,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): group.add_argument( "--bucketing-sampler", type=str2bool, - default=True, + default=False, help="When enabled, the batches will come from buckets of " "similar duration (saves padding frames).", ) @@ -259,7 +259,7 @@ def train_dataloaders( num_frame_masks=num_frame_masks, features_mask_size=27, num_feature_masks=2, - frames_mask_size=100, + frames_mask_size=50, ) ) else: @@ -299,6 +299,7 @@ def train_dataloaders( shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, drop_last=self.args.drop_last, + buffer_size=50000, ) else: logging.info("Using SingleCutSampler.") diff --git a/egs/swbd/ASR/local/display_manifest_statistics.py b/egs/swbd/ASR/local/display_manifest_statistics.py index a6eb1e2b25..3a96dc9185 100755 --- a/egs/swbd/ASR/local/display_manifest_statistics.py +++ b/egs/swbd/ASR/local/display_manifest_statistics.py @@ -41,171 +41,44 @@ def main(): main() """ -## train-clean-100 -Cuts count: 85617 -Total duration (hours): 303.8 -Speech duration (hours): 303.8 (100.0%) -*** -Duration statistics (seconds): -mean 12.8 -std 3.8 -min 1.3 -0.1% 1.9 -0.5% 2.2 -1% 2.5 -5% 4.2 -10% 6.4 -25% 11.4 -50% 13.8 -75% 15.3 -90% 16.7 -95% 17.3 -99% 18.1 -99.5% 18.4 -99.9% 18.8 -max 27.2 - -## train-clean-360 -Cuts count: 312042 -Total duration (hours): 1098.2 -Speech duration (hours): 1098.2 (100.0%) -*** -Duration statistics (seconds): -mean 12.7 -std 3.8 -min 1.0 -0.1% 1.8 -0.5% 2.2 -1% 2.5 -5% 4.2 -10% 6.2 -25% 11.2 -50% 13.7 -75% 15.3 -90% 16.6 -95% 17.3 -99% 18.1 -99.5% 18.4 -99.9% 18.8 -max 33.0 - -## train-other 500 -Cuts count: 446064 -Total duration (hours): 1500.6 -Speech duration (hours): 1500.6 (100.0%) -*** -Duration statistics (seconds): -mean 12.1 -std 4.2 -min 0.8 -0.1% 1.7 -0.5% 2.1 -1% 2.3 -5% 3.5 -10% 5.0 -25% 9.8 -50% 13.4 -75% 15.1 -90% 16.5 -95% 17.2 -99% 18.1 -99.5% 18.4 -99.9% 18.9 -max 31.0 - -## dev-clean -Cuts count: 2703 -Total duration (hours): 5.4 -Speech duration (hours): 5.4 (100.0%) -*** -Duration statistics (seconds): -mean 7.2 -std 4.7 -min 1.4 -0.1% 1.6 -0.5% 1.8 -1% 1.9 -5% 2.4 -10% 2.7 -25% 3.8 -50% 5.9 -75% 9.3 -90% 13.3 -95% 16.4 -99% 23.8 -99.5% 28.5 -99.9% 32.3 -max 32.6 - -## dev-other -Cuts count: 2864 -Total duration (hours): 5.1 -Speech duration (hours): 5.1 (100.0%) -*** -Duration statistics (seconds): -mean 6.4 -std 4.3 -min 1.1 -0.1% 1.3 -0.5% 1.7 -1% 1.8 -5% 2.2 -10% 2.6 -25% 3.5 -50% 5.3 -75% 7.9 -90% 12.0 -95% 15.0 -99% 22.2 -99.5% 27.1 -99.9% 32.4 -max 35.2 - -## test-clean -Cuts count: 2620 -Total duration (hours): 5.4 -Speech duration (hours): 5.4 (100.0%) -*** -Duration statistics (seconds): -mean 7.4 -std 5.2 -min 1.3 -0.1% 1.6 -0.5% 1.8 -1% 2.0 -5% 2.3 -10% 2.7 -25% 3.7 -50% 5.8 -75% 9.6 -90% 14.6 -95% 17.8 -99% 25.5 -99.5% 28.4 -99.9% 32.8 -max 35.0 - -## test-other -Cuts count: 2939 -Total duration (hours): 5.3 -Speech duration (hours): 5.3 (100.0%) -*** -Duration statistics (seconds): -mean 6.5 -std 4.4 -min 1.2 -0.1% 1.5 -0.5% 1.8 -1% 1.9 -5% 2.3 -10% 2.6 -25% 3.4 -50% 5.2 -75% 8.2 -90% 12.6 -95% 15.8 -99% 21.4 -99.5% 23.8 -99.9% 33.5 -max 34.5 +Cut statistics: +╒═══════════════════════════╤═══════════╕ +│ Cuts count: │ 167244 │ +├───────────────────────────┼───────────┤ +│ Total duration (hh:mm:ss) │ 281:01:26 │ +├───────────────────────────┼───────────┤ +│ mean │ 6.0 │ +├───────────────────────────┼───────────┤ +│ std │ 3.3 │ +├───────────────────────────┼───────────┤ +│ min │ 2.0 │ +├───────────────────────────┼───────────┤ +│ 25% │ 3.2 │ +├───────────────────────────┼───────────┤ +│ 50% │ 5.2 │ +├───────────────────────────┼───────────┤ +│ 75% │ 8.3 │ +├───────────────────────────┼───────────┤ +│ 99% │ 14.4 │ +├───────────────────────────┼───────────┤ +│ 99.5% │ 14.7 │ +├───────────────────────────┼───────────┤ +│ 99.9% │ 15.0 │ +├───────────────────────────┼───────────┤ +│ max │ 57.5 │ +├───────────────────────────┼───────────┤ +│ Recordings available: │ 167244 │ +├───────────────────────────┼───────────┤ +│ Features available: │ 167244 │ +├───────────────────────────┼───────────┤ +│ Supervisions available: │ 167244 │ +╘═══════════════════════════╧═══════════╛ +Speech duration statistics: +╒══════════════════════════════╤═══════════╤══════════════════════╕ +│ Total speech duration │ 281:01:26 │ 100.00% of recording │ +├──────────────────────────────┼───────────┼──────────────────────┤ +│ Total speaking time duration │ 281:01:26 │ 100.00% of recording │ +├──────────────────────────────┼───────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧═══════════╧══════════════════════╛ """ From 9594efd782193810bbc5e9656b3e08c8a4c7660f Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Sat, 19 Aug 2023 21:44:16 +0800 Subject: [PATCH 26/60] minor fixes --- egs/swbd/ASR/conformer_ctc/asr_datamodule.py | 12 ++++++------ egs/swbd/ASR/prepare.sh | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py index d2d2e52eb7..fc011bb5fc 100644 --- a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -391,26 +391,26 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: @lru_cache() def train_all_cuts(self) -> CutSet: - logging.info("switchboard: About to get train cuts") + logging.info("SwitchBoard: About to get train cuts") return load_manifest_lazy( self.args.manifest_dir / "swbd_cuts_all.jsonl.gz" - ).subset(last=2388) + ).subset(last=166844) @lru_cache() def dev_cuts(self) -> CutSet: - logging.info("switchboard: About to get dev cuts") + logging.info("SwitchBoard: About to get dev cuts") return load_manifest_lazy( self.args.manifest_dir / "swbd_cuts_all.jsonl.gz" - ).subset(first=50) + ).subset(first=300) @lru_cache() def test_eval2000_cuts(self) -> CutSet: - logging.info("switchboard: About to get eval2000 cuts") + logging.info("SwitchBoard: About to get eval2000 cuts") return load_manifest_lazy( self.args.manifest_dir / "eval2000" / "eval2000_cuts_all.jsonl.gz" ) @lru_cache() def test_rt03_cuts(self) -> CutSet: - logging.info("switchboard: About to get rt03 cuts") + logging.info("SwitchBoard: About to get rt03 cuts") return load_manifest_lazy(self.args.manifest_dir / "swbd_cuts_rt03.jsonl.gz") diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index 50af972253..72eeb002f7 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -142,7 +142,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compute fbank for SwitchBoard" + log "Stage 3 I: Compute fbank for SwitchBoard" if [ ! -e data/fbank/.swbd.done ]; then mkdir -p data/fbank/swbd_split${num_splits}/ for index in $(seq 1 16); do @@ -156,7 +156,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Compute fbank for eval2000" + log "Stage 3 II: Compute fbank for eval2000" if [ ! -e data/fbank/.eval2000.done ]; then mkdir -p data/fbank/eval2000/ ./local/compute_fbank_eval2000.py From 60e974f41b06601f5dec82846d0191518a0d9e3a Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Sat, 19 Aug 2023 21:48:14 +0800 Subject: [PATCH 27/60] minor updates --- egs/swbd/ASR/README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/egs/swbd/ASR/README.md b/egs/swbd/ASR/README.md index 3b93a64cbe..0e0b2675fd 100644 --- a/egs/swbd/ASR/README.md +++ b/egs/swbd/ASR/README.md @@ -6,10 +6,11 @@ Switchboard is a collection of about 2,400 two-sided telephone conversations amo (The above introduction is from the [LDC Switchboard-1 Release 2 webpage](https://catalog.ldc.upenn.edu/LDC97S62).) -**Caution**: The `conformer_ctc` recipe for Switchboard is currently very rough and has a high Word Error Rate, requiring more improvement and refinement. The TODO list for this recipe is as follows. +**Caution**: The `conformer_ctc` recipe for Switchboard is currently very rough and produces a high Word Error Rate, requiring more improvement and refinement. The TODO list for this recipe is as follows. ## TODO List -- [ ] Incorporate Lhotse for data processing +- [x] Incorporate Lhotse for data processing +- [x] Further text normalization - [ ] Refer to Global Mapping Rules when computing Word Error Rate - [x] Detailed Word Error Rate summary for eval2000 (callhome, swbd) and rt03 (fsh, swbd) testset - [ ] Switchboard transcript train/dev split for LM training @@ -27,3 +28,5 @@ See [RESULTS](/egs/swbd/ASR/RESULTS.md) for details. The training script for `conformer_ctc` comes from the LibriSpeech `conformer_ctc` recipe in icefall. A lot of the scripts for data processing are from the first-gen Kaldi and the ESPNet project, tailored to incorporate with Lhotse and icefall. + +Some of the scripts for text normalization are from stale pull requests of [Piotr Żelasko](https://github.com/pzelasko) and [Nagendra Goel](https://github.com/ngoel17). From 7feaa6185d7e56a74c253eda80b563946b515447 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Wed, 23 Aug 2023 14:45:10 +0800 Subject: [PATCH 28/60] minor updates --- egs/swbd/ASR/conformer_ctc/decode.py | 14 +++--- egs/swbd/ASR/local/compute_fbank_eval2000.py | 10 +++- .../ASR/local/display_manifest_statistics.py | 47 +++++++++++++++++-- egs/swbd/ASR/local/prepare_lang_bpe.py | 3 -- egs/swbd/ASR/local/train_bpe_model.py | 2 + egs/swbd/ASR/prepare.sh | 4 +- 6 files changed, 64 insertions(+), 16 deletions(-) diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py index 05d07f44fe..68b84a5f89 100755 --- a/egs/swbd/ASR/conformer_ctc/decode.py +++ b/egs/swbd/ASR/conformer_ctc/decode.py @@ -796,15 +796,17 @@ def main(): test_eval2000_cuts = switchboard.test_eval2000_cuts().trim_to_supervisions( keep_all_channels=True ) - test_rt03_cuts = switchboard.test_rt03_cuts().trim_to_supervisions( - keep_all_channels=True - ) + # test_rt03_cuts = switchboard.test_rt03_cuts().trim_to_supervisions( + # keep_all_channels=True + # ) test_eval2000_dl = switchboard.test_dataloaders(test_eval2000_cuts) - test_rt03_dl = switchboard.test_dataloaders(test_rt03_cuts) + # test_rt03_dl = switchboard.test_dataloaders(test_rt03_cuts) - test_sets = ["test-eval2000", "test-rt03"] - test_dl = [test_eval2000_dl, test_rt03_dl] + # test_sets = ["test-eval2000", "test-rt03"] + # test_dl = [test_eval2000_dl, test_rt03_dl] + test_sets = ["test-eval2000"] + test_dl = [test_eval2000_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( diff --git a/egs/swbd/ASR/local/compute_fbank_eval2000.py b/egs/swbd/ASR/local/compute_fbank_eval2000.py index c6e2890fff..fa23a9628e 100755 --- a/egs/swbd/ASR/local/compute_fbank_eval2000.py +++ b/egs/swbd/ASR/local/compute_fbank_eval2000.py @@ -97,7 +97,7 @@ def compute_fbank_switchboard( prefix = dir_name suffix = "jsonl.gz" manifests = { - "eval2000": "data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz", + "eval2000": "data/manifests/eval2000/eval2000_cuts_all.jsonl.gz", } assert manifests is not None @@ -111,7 +111,12 @@ def compute_fbank_switchboard( logging.info(f"{prefix} already exists - skipping.") return logging.info(f"Processing {prefix}") - cut_set = CutSet.from_file(manifests[prefix]).resample(16000) + cut_set = ( + CutSet.from_file(manifests[prefix]) + .resample(16000) + .to_eager() + .filter(lambda c: c.duration > 0.5) + ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -121,6 +126,7 @@ def compute_fbank_switchboard( executor=ex, storage_type=LilcomChunkyWriter, ) + cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) cut_set.to_file(output_dir / cuts_filename) diff --git a/egs/swbd/ASR/local/display_manifest_statistics.py b/egs/swbd/ASR/local/display_manifest_statistics.py index 3a96dc9185..48b7c40343 100755 --- a/egs/swbd/ASR/local/display_manifest_statistics.py +++ b/egs/swbd/ASR/local/display_manifest_statistics.py @@ -30,8 +30,8 @@ def main(): # path = "./data/fbank/swbd_cuts_rt03.jsonl.gz" - # path = "./data/fbank/swbd_cuts_eval2000.jsonl.gz" - path = "./data/fbank/swbd_cuts_all.jsonl.gz" + path = "./data/fbank/eval2000/eval2000_cuts_all.jsonl.gz" + # path = "./data/fbank/swbd_cuts_all.jsonl.gz" cuts = load_manifest_lazy(path) cuts.describe() @@ -41,7 +41,7 @@ def main(): main() """ -Cut statistics: +Training Cut statistics: ╒═══════════════════════════╤═══════════╕ │ Cuts count: │ 167244 │ ├───────────────────────────┼───────────┤ @@ -81,4 +81,45 @@ def main(): ├──────────────────────────────┼───────────┼──────────────────────┤ │ Total silence duration │ 00:00:00 │ 0.00% of recording │ ╘══════════════════════════════╧═══════════╧══════════════════════╛ + +Eval2000 Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 2709 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:39:19 │ +├───────────────────────────┼──────────┤ +│ mean │ 2.2 │ +├───────────────────────────┼──────────┤ +│ std │ 1.8 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 0.7 │ +├───────────────────────────┼──────────┤ +│ 50% │ 1.7 │ +├───────────────────────────┼──────────┤ +│ 75% │ 3.1 │ +├───────────────────────────┼──────────┤ +│ 99% │ 8.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 8.3 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 11.3 │ +├───────────────────────────┼──────────┤ +│ max │ 14.1 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 2709 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 2709 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:39:19 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:39:19 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ """ diff --git a/egs/swbd/ASR/local/prepare_lang_bpe.py b/egs/swbd/ASR/local/prepare_lang_bpe.py index 0fd937bc93..d82a085ece 100755 --- a/egs/swbd/ASR/local/prepare_lang_bpe.py +++ b/egs/swbd/ASR/local/prepare_lang_bpe.py @@ -216,9 +216,6 @@ def main(): "#0", "", "", - "[VOCALIZED-NOISE]", - "[NOISE]", - "[LAUGHTER]", ] for w in excluded: diff --git a/egs/swbd/ASR/local/train_bpe_model.py b/egs/swbd/ASR/local/train_bpe_model.py index 43142aee4c..9b4e286359 100755 --- a/egs/swbd/ASR/local/train_bpe_model.py +++ b/egs/swbd/ASR/local/train_bpe_model.py @@ -75,6 +75,8 @@ def main(): # If you change it, you should also change other # places that are using it. + user_defined_symbols += ["[LAUGHTER]", "[NOISE]", "[VOCALIZED-NOISE]"] + model_file = Path(model_prefix + ".model") if not model_file.is_file(): spm.SentencePieceTrainer.train( diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index 72eeb002f7..b07260f5b6 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -45,7 +45,7 @@ fisher_dir="/export/corpora3/LDC/LDC2004T19" vocab_sizes=( # 5000 # 2000 - # 1000 + 1000 500 ) @@ -197,7 +197,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then # [noise] nsn # !sil sil # spn - cat data/local/dict_nosp/lexicon.txt | + cat data/local/dict_nosp/lexicon.txt | sed 's/-//g' | sed 's/\[vocalizednoise\]/\[vocalized-noise\]/g' | sort | uniq >$lang_dir/lexicon_lower.txt cat $lang_dir/lexicon_lower.txt | tr a-z A-Z > $lang_dir/lexicon.txt From 3672d23631ad94925f7321e0ee405652910e7532 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Thu, 24 Aug 2023 13:19:48 +0800 Subject: [PATCH 29/60] updated --- egs/swbd/ASR/local/compute_fbank_eval2000.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/egs/swbd/ASR/local/compute_fbank_eval2000.py b/egs/swbd/ASR/local/compute_fbank_eval2000.py index fa23a9628e..d446e8ff3a 100755 --- a/egs/swbd/ASR/local/compute_fbank_eval2000.py +++ b/egs/swbd/ASR/local/compute_fbank_eval2000.py @@ -97,7 +97,7 @@ def compute_fbank_switchboard( prefix = dir_name suffix = "jsonl.gz" manifests = { - "eval2000": "data/manifests/eval2000/eval2000_cuts_all.jsonl.gz", + "eval2000": "data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz", } assert manifests is not None @@ -111,12 +111,7 @@ def compute_fbank_switchboard( logging.info(f"{prefix} already exists - skipping.") return logging.info(f"Processing {prefix}") - cut_set = ( - CutSet.from_file(manifests[prefix]) - .resample(16000) - .to_eager() - .filter(lambda c: c.duration > 0.5) - ) + cut_set = CutSet.from_file(manifests[prefix]).resample(16000) cut_set = cut_set.compute_and_store_features( extractor=extractor, From d335152fd2cc38f21a7d9eda5895a16f8a2c9c17 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Fri, 25 Aug 2023 19:30:05 +0800 Subject: [PATCH 30/60] minor updates --- egs/swbd/ASR/README.md | 2 +- .../ASR/local/display_manifest_statistics.py | 32 +++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/egs/swbd/ASR/README.md b/egs/swbd/ASR/README.md index 0e0b2675fd..e391860b95 100644 --- a/egs/swbd/ASR/README.md +++ b/egs/swbd/ASR/README.md @@ -27,6 +27,6 @@ See [RESULTS](/egs/swbd/ASR/RESULTS.md) for details. The training script for `conformer_ctc` comes from the LibriSpeech `conformer_ctc` recipe in icefall. -A lot of the scripts for data processing are from the first-gen Kaldi and the ESPNet project, tailored to incorporate with Lhotse and icefall. +A lot of the scripts for data processing are from the first-gen Kaldi and the ESPNet project, tailored by myself to incorporate with Lhotse and Icefall. Some of the scripts for text normalization are from stale pull requests of [Piotr Żelasko](https://github.com/pzelasko) and [Nagendra Goel](https://github.com/ngoel17). diff --git a/egs/swbd/ASR/local/display_manifest_statistics.py b/egs/swbd/ASR/local/display_manifest_statistics.py index 48b7c40343..9aa2048632 100755 --- a/egs/swbd/ASR/local/display_manifest_statistics.py +++ b/egs/swbd/ASR/local/display_manifest_statistics.py @@ -84,41 +84,41 @@ def main(): Eval2000 Cut statistics: ╒═══════════════════════════╤══════════╕ -│ Cuts count: │ 2709 │ +│ Cuts count: │ 4473 │ ├───────────────────────────┼──────────┤ -│ Total duration (hh:mm:ss) │ 01:39:19 │ +│ Total duration (hh:mm:ss) │ 03:37:13 │ ├───────────────────────────┼──────────┤ -│ mean │ 2.2 │ +│ mean │ 2.9 │ ├───────────────────────────┼──────────┤ -│ std │ 1.8 │ +│ std │ 2.6 │ ├───────────────────────────┼──────────┤ │ min │ 0.1 │ ├───────────────────────────┼──────────┤ -│ 25% │ 0.7 │ +│ 25% │ 1.2 │ ├───────────────────────────┼──────────┤ -│ 50% │ 1.7 │ +│ 50% │ 2.1 │ ├───────────────────────────┼──────────┤ -│ 75% │ 3.1 │ +│ 75% │ 4.0 │ ├───────────────────────────┼──────────┤ -│ 99% │ 8.0 │ +│ 99% │ 12.6 │ ├───────────────────────────┼──────────┤ -│ 99.5% │ 8.3 │ +│ 99.5% │ 13.7 │ ├───────────────────────────┼──────────┤ -│ 99.9% │ 11.3 │ +│ 99.9% │ 14.7 │ ├───────────────────────────┼──────────┤ -│ max │ 14.1 │ +│ max │ 15.5 │ ├───────────────────────────┼──────────┤ -│ Recordings available: │ 2709 │ +│ Recordings available: │ 4473 │ ├───────────────────────────┼──────────┤ -│ Features available: │ 0 │ +│ Features available: │ 4473 │ ├───────────────────────────┼──────────┤ -│ Supervisions available: │ 2709 │ +│ Supervisions available: │ 4473 │ ╘═══════════════════════════╧══════════╛ Speech duration statistics: ╒══════════════════════════════╤══════════╤══════════════════════╕ -│ Total speech duration │ 01:39:19 │ 100.00% of recording │ +│ Total speech duration │ 03:37:13 │ 100.00% of recording │ ├──────────────────────────────┼──────────┼──────────────────────┤ -│ Total speaking time duration │ 01:39:19 │ 100.00% of recording │ +│ Total speaking time duration │ 03:37:13 │ 100.00% of recording │ ├──────────────────────────────┼──────────┼──────────────────────┤ │ Total silence duration │ 00:00:00 │ 0.00% of recording │ ╘══════════════════════════════╧══════════╧══════════════════════╛ From 1715567f57a343b85345f260151e40b54488b0fa Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Sat, 2 Sep 2023 01:41:29 +0800 Subject: [PATCH 31/60] Update train_bpe_model.py --- egs/swbd/ASR/local/train_bpe_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/swbd/ASR/local/train_bpe_model.py b/egs/swbd/ASR/local/train_bpe_model.py index 9b4e286359..350b61b42d 100755 --- a/egs/swbd/ASR/local/train_bpe_model.py +++ b/egs/swbd/ASR/local/train_bpe_model.py @@ -75,7 +75,7 @@ def main(): # If you change it, you should also change other # places that are using it. - user_defined_symbols += ["[LAUGHTER]", "[NOISE]", "[VOCALIZED-NOISE]"] + user_defined_symbols += ["▁[LAUGHTER]", "▁[NOISE]", "▁[VOCALIZED-NOISE]"] model_file = Path(model_prefix + ".model") if not model_file.is_file(): From fae87b3009e0b56ea46fce675c82864511d2280d Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Mon, 4 Sep 2023 11:09:32 +0800 Subject: [PATCH 32/60] minor updates --- egs/swbd/ASR/README.md | 6 +- egs/swbd/ASR/conformer_ctc/decode.py | 20 +++- egs/swbd/ASR/conformer_ctc/sclite_scoring.py | 111 +++++++++++++++++++ egs/swbd/ASR/prepare.sh | 4 + 4 files changed, 137 insertions(+), 4 deletions(-) create mode 100755 egs/swbd/ASR/conformer_ctc/sclite_scoring.py diff --git a/egs/swbd/ASR/README.md b/egs/swbd/ASR/README.md index e391860b95..10e38c501c 100644 --- a/egs/swbd/ASR/README.md +++ b/egs/swbd/ASR/README.md @@ -11,10 +11,8 @@ Switchboard is a collection of about 2,400 two-sided telephone conversations amo ## TODO List - [x] Incorporate Lhotse for data processing - [x] Further text normalization -- [ ] Refer to Global Mapping Rules when computing Word Error Rate - [x] Detailed Word Error Rate summary for eval2000 (callhome, swbd) and rt03 (fsh, swbd) testset -- [ ] Switchboard transcript train/dev split for LM training -- [ ] Fisher corpus LDC2004T19 LDC2005T19 LDC2004S13 LDC2005S13 for LM training +- [x] Switchboard transcript train/dev split for LM training ## Performance Record | | eval2000 | rt03 | @@ -30,3 +28,5 @@ The training script for `conformer_ctc` comes from the LibriSpeech `conformer_ct A lot of the scripts for data processing are from the first-gen Kaldi and the ESPNet project, tailored by myself to incorporate with Lhotse and Icefall. Some of the scripts for text normalization are from stale pull requests of [Piotr Żelasko](https://github.com/pzelasko) and [Nagendra Goel](https://github.com/ngoel17). + +The `sclite_scoring.py` is from the GigaSpeech recipe for post processing and glm-like scoring, which is definitely not an elegant stuff to do. diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py index 68b84a5f89..396a264a28 100755 --- a/egs/swbd/ASR/conformer_ctc/decode.py +++ b/egs/swbd/ASR/conformer_ctc/decode.py @@ -30,6 +30,8 @@ from asr_datamodule import SwitchBoardAsrDataModule from conformer import Conformer +from sclite_scoring import asr_text_post_processing + from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint from icefall.decode import ( @@ -233,6 +235,17 @@ def get_params() -> AttributeDict: return params +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -591,6 +604,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{subset}-{key}.txt" + results = post_processing(results) results = ( sorted(list(filter(lambda x: x[0].startswith(prefix), results))) if subset != "avg" @@ -605,7 +619,11 @@ def save_results( errs_filename = params.exp_dir / f"errs-{test_set_name}-{subset}-{key}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{subset}-{key}", results, enable_log=enable_log + f, + f"{test_set_name}-{subset}-{key}", + results, + enable_log=enable_log, + sclite_mode=True, ) test_set_wers[key] = wer diff --git a/egs/swbd/ASR/conformer_ctc/sclite_scoring.py b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py new file mode 100755 index 0000000000..0ad1fc2e92 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +# Copyright 2021 Jiayu Du +# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os + +conversational_filler = [ + "UH", + "UHH", + "UM", + "EH", + "MM", + "HM", + "AH", + "HUH", + "HA", + "ER", + "OOF", + "HEE", + "ACH", + "EEE", + "EW", + "MHM", + "HUM", + "AW", + "OH", +] +unk_tags = ["", ""] +switchboard_garbage_utterance_tags = [ + "[LAUGHTER]", + "[NOISE]", + "[VOCALIZED-NOISE]", + "[SILENCE]" +] +non_scoring_words = ( + conversational_filler + unk_tags + switchboard_garbage_utterance_tags +) + + +def asr_text_post_processing(text: str) -> str: + # 1. convert to uppercase + text = text.upper() + + # 2. remove non-scoring words from evaluation + remaining_words = [] + for word in text.split(): + if word in non_scoring_words: + continue + remaining_words.append(word) + + return " ".join(remaining_words) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="This script evaluates GigaSpeech ASR result via" + "SCTK's tool sclite" + ) + parser.add_argument( + "ref", + type=str, + help="sclite's standard transcription(trn) reference file", + ) + parser.add_argument( + "hyp", + type=str, + help="sclite's standard transcription(trn) hypothesis file", + ) + parser.add_argument( + "work_dir", + type=str, + help="working dir", + ) + args = parser.parse_args() + + if not os.path.isdir(args.work_dir): + os.mkdir(args.work_dir) + + REF = os.path.join(args.work_dir, "REF") + HYP = os.path.join(args.work_dir, "HYP") + RESULT = os.path.join(args.work_dir, "RESULT") + + for io in [(args.ref, REF), (args.hyp, HYP)]: + with open(io[0], "r", encoding="utf8") as fi: + with open(io[1], "w+", encoding="utf8") as fo: + for line in fi: + line = line.strip() + if line: + cols = line.split() + text = asr_text_post_processing(" ".join(cols[0:-1])) + uttid_field = cols[-1] + print(f"{text} {uttid_field}", file=fo) + + # GigaSpeech's uttid comforms to swb + os.system(f"sclite -r {REF} trn -h {HYP} trn -i swb | tee {RESULT}") diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index b07260f5b6..f3299353ba 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -108,6 +108,10 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then data/manifests/eval2000/eval2000_cuts_all.jsonl.gz \ data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz + sed -e 's:((:(:' -e 's:::g' -e 's:::g' \ + $eval2000_dir/LDC2002T43/reference/hub5e00.english.000405.stm > data/manifests/eval2000/stm + cp $eval2000_dir/LDC2002T43/reference/en20000405_hub5.glm $dir/glm + # ./local/rt03_data_prep.sh $rt03_dir # normalize eval2000 and rt03 texts by From b45d83fa5346d08ac6086183303afd6b5f842544 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Mon, 4 Sep 2023 11:17:46 +0800 Subject: [PATCH 33/60] minor update on text norm --- egs/swbd/ASR/local/normalize_eval2000.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/swbd/ASR/local/normalize_eval2000.py b/egs/swbd/ASR/local/normalize_eval2000.py index 38451f5063..7316193d03 100755 --- a/egs/swbd/ASR/local/normalize_eval2000.py +++ b/egs/swbd/ASR/local/normalize_eval2000.py @@ -145,6 +145,7 @@ def replace_silphone(text: str) -> str: text = text.replace("[NOISE OF MOVING PHONE]", " ") text = text.replace("[SOUND OF RUNNING WATER]", " ") text = text.replace("[CHANNEL]", " ") + text = text.replace("[SILENCE]", " ") text = text.replace("-[W]HERE", "WHERE") text = text.replace("Y[OU]I-", "YOU I") text = text.replace("-[A]ND", "AND") From 675194a6511afb212321114a4ac0c48d5099005f Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Mon, 4 Sep 2023 11:27:30 +0800 Subject: [PATCH 34/60] fixed a formatting issue --- egs/swbd/ASR/conformer_ctc/sclite_scoring.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/swbd/ASR/conformer_ctc/sclite_scoring.py b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py index 0ad1fc2e92..0d7770ddda 100755 --- a/egs/swbd/ASR/conformer_ctc/sclite_scoring.py +++ b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py @@ -46,7 +46,7 @@ "[LAUGHTER]", "[NOISE]", "[VOCALIZED-NOISE]", - "[SILENCE]" + "[SILENCE]", ] non_scoring_words = ( conversational_filler + unk_tags + switchboard_garbage_utterance_tags From 7c44c3a2c09cf011c56804a04f6c9882a8667e9b Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Mon, 4 Sep 2023 12:42:49 +0800 Subject: [PATCH 35/60] default params updated --- egs/swbd/ASR/conformer_ctc/decode.py | 2 +- egs/swbd/ASR/conformer_ctc/export.py | 4 ++-- egs/swbd/ASR/conformer_ctc/train.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py index 396a264a28..2bbade3744 100755 --- a/egs/swbd/ASR/conformer_ctc/decode.py +++ b/egs/swbd/ASR/conformer_ctc/decode.py @@ -66,7 +66,7 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=77, + default=98, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) diff --git a/egs/swbd/ASR/conformer_ctc/export.py b/egs/swbd/ASR/conformer_ctc/export.py index fbcbd7b29a..1bb6277ad9 100755 --- a/egs/swbd/ASR/conformer_ctc/export.py +++ b/egs/swbd/ASR/conformer_ctc/export.py @@ -39,7 +39,7 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=34, + default=98, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) @@ -47,7 +47,7 @@ def get_parser(): parser.add_argument( "--avg", type=int, - default=20, + default=55, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " "'--epoch'. ", diff --git a/egs/swbd/ASR/conformer_ctc/train.py b/egs/swbd/ASR/conformer_ctc/train.py index 8ed837aaf2..7f1eebbcf2 100755 --- a/egs/swbd/ASR/conformer_ctc/train.py +++ b/egs/swbd/ASR/conformer_ctc/train.py @@ -93,7 +93,7 @@ def get_parser(): parser.add_argument( "--num-epochs", type=int, - default=78, + default=98, help="Number of epochs to train.", ) From 93e22a96a495e5b676414326e1ae3514a5332abb Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Mon, 4 Sep 2023 17:44:31 +0800 Subject: [PATCH 36/60] Update RESULTS.md --- egs/swbd/ASR/RESULTS.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/egs/swbd/ASR/RESULTS.md b/egs/swbd/ASR/RESULTS.md index 68a843c014..d285261acc 100644 --- a/egs/swbd/ASR/RESULTS.md +++ b/egs/swbd/ASR/RESULTS.md @@ -1,6 +1,18 @@ ## Results ### Switchboard BPE training results (Conformer-CTC) +#### 2023-09-04 + +The best WER, as of 2023-09-04, for the Switchboard is below + +Results using attention decoder is given as: + +| | eval2000-swbd | eval2000-callhome | eval2000-avg | +|--------------------------------|-----------------|---------------------|--------------| +| `conformer_ctc` | 9.48 | 17.73 | 13.67 | + +Decoding results and models can be found here: +https://huggingface.co/zrjin/icefall-asr-swbd-conformer-ctc-2023-8-26 #### 2023-06-27 The best WER, as of 2023-06-27, for the Switchboard is below From c61b3ac9147c694b1743071b54d2ac6c32ea202d Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Wed, 6 Sep 2023 12:09:02 +0800 Subject: [PATCH 37/60] Update prepare.sh --- egs/swbd/ASR/prepare.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index f3299353ba..47d12613bf 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -31,7 +31,7 @@ swbd1_dir=./download/LDC97S62/ # - LDC2002S09 # - hub5e_00 # - LDC2002T43 -# - 2000_hub5_eng_eval_tr +# - reference eval2000_dir="/export/corpora2/LDC/eval2000" rt03_dir="/export/corpora/LDC/LDC2007S10" From 098a70428d32da40af42f0cf080d90aeafa674c7 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Wed, 13 Sep 2023 01:39:02 +0800 Subject: [PATCH 38/60] Update train_bpe_model.py --- egs/swbd/ASR/local/train_bpe_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/swbd/ASR/local/train_bpe_model.py b/egs/swbd/ASR/local/train_bpe_model.py index 350b61b42d..9b4e286359 100755 --- a/egs/swbd/ASR/local/train_bpe_model.py +++ b/egs/swbd/ASR/local/train_bpe_model.py @@ -75,7 +75,7 @@ def main(): # If you change it, you should also change other # places that are using it. - user_defined_symbols += ["▁[LAUGHTER]", "▁[NOISE]", "▁[VOCALIZED-NOISE]"] + user_defined_symbols += ["[LAUGHTER]", "[NOISE]", "[VOCALIZED-NOISE]"] model_file = Path(model_prefix + ".model") if not model_file.is_file(): From 841a153a01f107ea887ec8c535bfbdf272106636 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Fri, 15 Sep 2023 10:48:44 +0800 Subject: [PATCH 39/60] zipformer recipe for swbd --- egs/swbd/ASR/conformer_ctc/asr_datamodule.py | 2 +- egs/swbd/ASR/conformer_ctc/sclite_scoring.py | 29 +- egs/swbd/ASR/zipformer/.gitignore | 1 + egs/swbd/ASR/zipformer/asr_datamodule.py | 1 + egs/swbd/ASR/zipformer/beam_search.py | 1 + egs/swbd/ASR/zipformer/ctc_decode.py | 847 +++++++ egs/swbd/ASR/zipformer/decode.py | 1045 ++++++++ egs/swbd/ASR/zipformer/decode_stream.py | 148 ++ egs/swbd/ASR/zipformer/decoder.py | 122 + egs/swbd/ASR/zipformer/encoder_interface.py | 1 + .../ASR/zipformer/export-onnx-streaming.py | 774 ++++++ egs/swbd/ASR/zipformer/export-onnx.py | 619 +++++ egs/swbd/ASR/zipformer/export.py | 526 ++++ .../ASR/zipformer/generate_averaged_model.py | 193 ++ egs/swbd/ASR/zipformer/jit_pretrained.py | 280 ++ egs/swbd/ASR/zipformer/jit_pretrained_ctc.py | 436 ++++ .../ASR/zipformer/jit_pretrained_streaming.py | 273 ++ egs/swbd/ASR/zipformer/joiner.py | 66 + egs/swbd/ASR/zipformer/model.py | 358 +++ egs/swbd/ASR/zipformer/onnx_check.py | 240 ++ egs/swbd/ASR/zipformer/onnx_decode.py | 323 +++ .../zipformer/onnx_pretrained-streaming.py | 543 ++++ egs/swbd/ASR/zipformer/onnx_pretrained.py | 418 +++ egs/swbd/ASR/zipformer/optim.py | 1173 +++++++++ egs/swbd/ASR/zipformer/pretrained.py | 381 +++ egs/swbd/ASR/zipformer/pretrained_ctc.py | 455 ++++ egs/swbd/ASR/zipformer/profile.py | 176 ++ egs/swbd/ASR/zipformer/scaling.py | 1840 ++++++++++++++ egs/swbd/ASR/zipformer/scaling_converter.py | 104 + egs/swbd/ASR/zipformer/sclite_scoring.py | 138 + .../ASR/zipformer/streaming_beam_search.py | 295 +++ egs/swbd/ASR/zipformer/streaming_decode.py | 876 +++++++ egs/swbd/ASR/zipformer/subsampling.py | 410 +++ egs/swbd/ASR/zipformer/test_scaling.py | 82 + egs/swbd/ASR/zipformer/test_subsampling.py | 152 ++ egs/swbd/ASR/zipformer/train.py | 1383 ++++++++++ egs/swbd/ASR/zipformer/zipformer.py | 2254 +++++++++++++++++ 37 files changed, 16963 insertions(+), 2 deletions(-) create mode 100644 egs/swbd/ASR/zipformer/.gitignore create mode 120000 egs/swbd/ASR/zipformer/asr_datamodule.py create mode 120000 egs/swbd/ASR/zipformer/beam_search.py create mode 100755 egs/swbd/ASR/zipformer/ctc_decode.py create mode 100755 egs/swbd/ASR/zipformer/decode.py create mode 100644 egs/swbd/ASR/zipformer/decode_stream.py create mode 100644 egs/swbd/ASR/zipformer/decoder.py create mode 120000 egs/swbd/ASR/zipformer/encoder_interface.py create mode 100755 egs/swbd/ASR/zipformer/export-onnx-streaming.py create mode 100755 egs/swbd/ASR/zipformer/export-onnx.py create mode 100755 egs/swbd/ASR/zipformer/export.py create mode 100755 egs/swbd/ASR/zipformer/generate_averaged_model.py create mode 100755 egs/swbd/ASR/zipformer/jit_pretrained.py create mode 100755 egs/swbd/ASR/zipformer/jit_pretrained_ctc.py create mode 100755 egs/swbd/ASR/zipformer/jit_pretrained_streaming.py create mode 100644 egs/swbd/ASR/zipformer/joiner.py create mode 100644 egs/swbd/ASR/zipformer/model.py create mode 100755 egs/swbd/ASR/zipformer/onnx_check.py create mode 100755 egs/swbd/ASR/zipformer/onnx_decode.py create mode 100755 egs/swbd/ASR/zipformer/onnx_pretrained-streaming.py create mode 100755 egs/swbd/ASR/zipformer/onnx_pretrained.py create mode 100644 egs/swbd/ASR/zipformer/optim.py create mode 100755 egs/swbd/ASR/zipformer/pretrained.py create mode 100755 egs/swbd/ASR/zipformer/pretrained_ctc.py create mode 100755 egs/swbd/ASR/zipformer/profile.py create mode 100644 egs/swbd/ASR/zipformer/scaling.py create mode 100644 egs/swbd/ASR/zipformer/scaling_converter.py create mode 100755 egs/swbd/ASR/zipformer/sclite_scoring.py create mode 100644 egs/swbd/ASR/zipformer/streaming_beam_search.py create mode 100755 egs/swbd/ASR/zipformer/streaming_decode.py create mode 100644 egs/swbd/ASR/zipformer/subsampling.py create mode 100755 egs/swbd/ASR/zipformer/test_scaling.py create mode 100755 egs/swbd/ASR/zipformer/test_subsampling.py create mode 100755 egs/swbd/ASR/zipformer/train.py create mode 100644 egs/swbd/ASR/zipformer/zipformer.py diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py index fc011bb5fc..59d73c660d 100644 --- a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -99,7 +99,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): group.add_argument( "--bucketing-sampler", type=str2bool, - default=False, + default=True, help="When enabled, the batches will come from buckets of " "similar duration (saves padding frames).", ) diff --git a/egs/swbd/ASR/conformer_ctc/sclite_scoring.py b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py index 0d7770ddda..fe0fbbeeea 100755 --- a/egs/swbd/ASR/conformer_ctc/sclite_scoring.py +++ b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py @@ -40,6 +40,8 @@ "HUM", "AW", "OH", + "HMM", + "UMM", ] unk_tags = ["", ""] switchboard_garbage_utterance_tags = [ @@ -59,9 +61,34 @@ def asr_text_post_processing(text: str) -> str: # 2. remove non-scoring words from evaluation remaining_words = [] - for word in text.split(): + text_split = text.split() + word_to_skip = 0 + for idx, word in enumerate(text_split): + if word_to_skip > 0: + word_to_skip -= 1 + continue if word in non_scoring_words: continue + elif word == "CANCELLED": + remaining_words.append("CANCELED") + continue + elif word == "AIRFLOW": + remaining_words.append("AIR") + remaining_words.append("FLOW") + continue + elif word == "PHD": + remaining_words.append("P") + remaining_words.append("H") + remaining_words.append("D") + continue + elif word == "ONTO": + remaining_words.append("ON") + remaining_words.append("TO") + continue + elif word == "DAY" and text_split[idx + 1] == "CARE": + remaining_words.append("DAYCARE") + word_to_skip = 1 + continue remaining_words.append(word) return " ".join(remaining_words) diff --git a/egs/swbd/ASR/zipformer/.gitignore b/egs/swbd/ASR/zipformer/.gitignore new file mode 100644 index 0000000000..e47ac15828 --- /dev/null +++ b/egs/swbd/ASR/zipformer/.gitignore @@ -0,0 +1 @@ +swoosh.pdf diff --git a/egs/swbd/ASR/zipformer/asr_datamodule.py b/egs/swbd/ASR/zipformer/asr_datamodule.py new file mode 120000 index 0000000000..a73848de92 --- /dev/null +++ b/egs/swbd/ASR/zipformer/asr_datamodule.py @@ -0,0 +1 @@ +../conformer_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/swbd/ASR/zipformer/beam_search.py b/egs/swbd/ASR/zipformer/beam_search.py new file mode 120000 index 0000000000..e24eca39f2 --- /dev/null +++ b/egs/swbd/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/swbd/ASR/zipformer/ctc_decode.py b/egs/swbd/ASR/zipformer/ctc_decode.py new file mode 100755 index 0000000000..4db50b9813 --- /dev/null +++ b/egs/swbd/ASR/zipformer/ctc_decode.py @@ -0,0 +1,847 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +(1) ctc-decoding +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-decoding + +(2) 1best +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method 1best + +(3) nbest +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method nbest + +(4) nbest-rescoring +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method nbest-rescoring + +(5) whole-lattice-rescoring +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method whole-lattice-rescoring +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_params, get_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (2) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (3) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (4) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + you have trained an RNN LM using ./rnn_lm/train.py + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--hlg-scale", + type=float, + default=0.6, + help="""The scale to be applied to `hlg.scores`. + """, + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + params.subsampling_factor, + rounding_mode="floor", + ), + torch.div( + supervisions["num_frames"], + params.subsampling_factor, + rounding_mode="floor", + ), + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.decoding_method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.decoding_method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.decoding_method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}-{key}", results) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ( + "ctc-decoding", + "1best", + "nbest", + "nbest-rescoring", + "whole-lattice-rescoring", + "nbest-oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + params.vocab_size = num_classes + # and are defined in local/train_bpe_model.py + params.blank_id = 0 + + if params.decoding_method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + HLG.scores *= params.hlg_scale + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.decoding_method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.decoding_method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/zipformer/decode.py b/egs/swbd/ASR/zipformer/decode.py new file mode 100755 index 0000000000..10de28dfcc --- /dev/null +++ b/egs/swbd/ASR/zipformer/decode.py @@ -0,0 +1,1045 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import SwitchBoardAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from train import add_model_arguments, get_model, get_params +from sclite_scoring import asr_text_post_processing + + +from lhotse.cut import Cut +from icefall import LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"beam_size_{params.beam_size}_{key}"] = hyps + return ans + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + if test_set_name == "test-eval2000": + subsets = {"callhome": "en_", "swbd": "sw_", "avg": "*"} + elif test_set_name == "test-rt03": + subsets = {"fisher": "fsh_", "swbd": "sw_", "avg": "*"} + else: + raise NotImplementedError(f"No implementation for testset {test_set_name}") + for subset, prefix in subsets.items(): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir + / f"recogs-{test_set_name}-{subset}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + results = ( + sorted(list(filter(lambda x: x[0].startswith(prefix), results))) + if subset != "avg" + else sorted(results) + ) + results = post_processing(results) + + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir + / f"errs-{test_set_name}-{subset}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, + f"{test_set_name}-{subset}-{key}", + results, + enable_log=True, + sclite_mode=True, + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{subset}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + SwitchBoardAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + switchboard = SwitchBoardAsrDataModule(args) + + test_eval2000_cuts = ( + switchboard.test_eval2000_cuts() + .trim_to_supervisions(keep_all_channels=True) + .filter(remove_short_utt) + ) + # test_rt03_cuts = switchboard.test_rt03_cuts().trim_to_supervisions( + # keep_all_channels=True + # ) + + test_eval2000_dl = switchboard.test_dataloaders(test_eval2000_cuts) + # test_rt03_dl = switchboard.test_dataloaders(test_rt03_cuts) + + # test_sets = ["test-eval2000", "test-rt03"] + # test_dl = [test_eval2000_dl, test_rt03_dl] + test_dl = [test_eval2000_dl] + test_sets = ["test-eval2000"] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/zipformer/decode_stream.py b/egs/swbd/ASR/zipformer/decode_stream.py new file mode 100644 index 0000000000..d6918bf328 --- /dev/null +++ b/egs/swbd/ASR/zipformer/decode_stream.py @@ -0,0 +1,148 @@ +# Copyright 2022 Xiaomi Corp. (authors: Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple + +import k2 +import torch +from beam_search import Hypothesis, HypothesisList + +from icefall.utils import AttributeDict + + +class DecodeStream(object): + def __init__( + self, + params: AttributeDict, + cut_id: str, + initial_states: List[torch.Tensor], + decoding_graph: Optional[k2.Fsa] = None, + device: torch.device = torch.device("cpu"), + ) -> None: + """ + Args: + initial_states: + Initial decode states of the model, e.g. the return value of + `get_init_state` in conformer.py + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + Used only when decoding_method is fast_beam_search. + device: + The device to run this stream. + """ + if params.decoding_method == "fast_beam_search": + assert decoding_graph is not None + assert device == decoding_graph.device + + self.params = params + self.cut_id = cut_id + self.LOG_EPS = math.log(1e-10) + + self.states = initial_states + + # It contains a 2-D tensors representing the feature frames. + self.features: torch.Tensor = None + + self.num_frames: int = 0 + # how many frames have been processed. (before subsampling). + # we only modify this value in `func:get_feature_frames`. + self.num_processed_frames: int = 0 + + self._done: bool = False + + # The transcript of current utterance. + self.ground_truth: str = "" + + # The decoding result (partial or final) of current utterance. + self.hyp: List = [] + + # how many frames have been processed, at encoder output + self.done_frames: int = 0 + + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + self.pad_length = 7 + 2 * 3 + + if params.decoding_method == "greedy_search": + self.hyp = [-1] * (params.context_size - 1) + [params.blank_id] + elif params.decoding_method == "modified_beam_search": + self.hyps = HypothesisList() + self.hyps.add( + Hypothesis( + ys=[-1] * (params.context_size - 1) + [params.blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + elif params.decoding_method == "fast_beam_search": + # The rnnt_decoding_stream for fast_beam_search. + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + @property + def done(self) -> bool: + """Return True if all the features are processed.""" + return self._done + + @property + def id(self) -> str: + return self.cut_id + + def set_features( + self, + features: torch.Tensor, + tail_pad_len: int = 0, + ) -> None: + """Set features tensor of current utterance.""" + assert features.dim() == 2, features.dim() + self.features = torch.nn.functional.pad( + features, + (0, 0, 0, self.pad_length + tail_pad_len), + mode="constant", + value=self.LOG_EPS, + ) + self.num_frames = self.features.size(0) + + def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: + """Consume chunk_size frames of features""" + chunk_length = chunk_size + self.pad_length + + ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) + + ret_features = self.features[ + self.num_processed_frames : self.num_processed_frames + ret_length # noqa + ] + + self.num_processed_frames += chunk_size + if self.num_processed_frames >= self.num_frames: + self._done = True + + return ret_features, ret_length + + def decoding_result(self) -> List[int]: + """Obtain current decoding result.""" + if self.params.decoding_method == "greedy_search": + return self.hyp[self.params.context_size :] # noqa + elif self.params.decoding_method == "modified_beam_search": + best_hyp = self.hyps.get_most_probable(length_norm=True) + return best_hyp.ys[self.params.context_size :] # noqa + else: + assert self.params.decoding_method == "fast_beam_search" + return self.hyp diff --git a/egs/swbd/ASR/zipformer/decoder.py b/egs/swbd/ASR/zipformer/decoder.py new file mode 100644 index 0000000000..e8db988f6e --- /dev/null +++ b/egs/swbd/ASR/zipformer/decoder.py @@ -0,0 +1,122 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from scaling import Balancer + + +class Decoder(nn.Module): + """This class modifies the stateless decoder from the following paper: + + RNN-transducer with stateless prediction network + https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 + + It removes the recurrent connection from the decoder, i.e., the prediction + network. Different from the above paper, it adds an extra Conv1d + right after the embedding layer. + + TODO: Implement https://arxiv.org/pdf/2109.07513.pdf + """ + + def __init__( + self, + vocab_size: int, + decoder_dim: int, + blank_id: int, + context_size: int, + ): + """ + Args: + vocab_size: + Number of tokens of the modeling unit including blank. + decoder_dim: + Dimension of the input embedding, and of the decoder output. + blank_id: + The ID of the blank symbol. + context_size: + Number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. + """ + super().__init__() + + self.embedding = nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=decoder_dim, + ) + # the balancers are to avoid any drift in the magnitude of the + # embeddings, which would interact badly with parameter averaging. + self.balancer = Balancer(decoder_dim, channel_dim=-1, + min_positive=0.0, max_positive=1.0, + min_abs=0.5, max_abs=1.0, + prob=0.05) + + self.blank_id = blank_id + + assert context_size >= 1, context_size + self.context_size = context_size + self.vocab_size = vocab_size + + if context_size > 1: + self.conv = nn.Conv1d( + in_channels=decoder_dim, + out_channels=decoder_dim, + kernel_size=context_size, + padding=0, + groups=decoder_dim // 4, # group size == 4 + bias=False, + ) + self.balancer2 = Balancer(decoder_dim, channel_dim=-1, + min_positive=0.0, max_positive=1.0, + min_abs=0.5, max_abs=1.0, + prob=0.05) + + def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + need_pad: + True to left pad the input. Should be True during training. + False to not pad the input. Should be False during inference. + Returns: + Return a tensor of shape (N, U, decoder_dim). + """ + y = y.to(torch.int64) + # this stuff about clamp() is a temporary fix for a mismatch + # at utterance start, we use negative ids in beam_search.py + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) + + embedding_out = self.balancer(embedding_out) + + if self.context_size > 1: + embedding_out = embedding_out.permute(0, 2, 1) + if need_pad is True: + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embedding_out.size(-1) == self.context_size + embedding_out = self.conv(embedding_out) + embedding_out = embedding_out.permute(0, 2, 1) + embedding_out = F.relu(embedding_out) + embedding_out = self.balancer2(embedding_out) + + return embedding_out diff --git a/egs/swbd/ASR/zipformer/encoder_interface.py b/egs/swbd/ASR/zipformer/encoder_interface.py new file mode 120000 index 0000000000..653c5b09af --- /dev/null +++ b/egs/swbd/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/swbd/ASR/zipformer/export-onnx-streaming.py b/egs/swbd/ASR/zipformer/export-onnx-streaming.py new file mode 100755 index 0000000000..a951aeef35 --- /dev/null +++ b/egs/swbd/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1,774 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx-streaming.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal True \ + --chunk-size 16 \ + --left-context-frames 64 + +The --chunk-size in training is "16,32,64,-1", so we select one of them +(excluding -1) during streaming export. The same applies to `--left-context`, +whose value is "64,128,256,-1". + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1-chunk-16-left-64.onnx + - decoder-epoch-99-avg-1-chunk-16-left-64.onnx + - joiner-epoch-99-avg-1-chunk-16-left-64.onnx + +See ./onnx_pretrained-streaming.py for how to use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_model, get_params +from zipformer import Zipformer2 + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__( + self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear + ): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.encoder_proj = encoder_proj + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + self.pad_length = 7 + 2 * 3 + + def forward( + self, + x: torch.Tensor, + states: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + N = x.size(0) + T = self.chunk_size * 2 + self.pad_length + x_lens = torch.tensor([T] * N, device=x.device) + left_context_len = self.left_context_len + + cached_embed_left_pad = states[-2] + x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + x=x, + x_lens=x_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size) + + src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) + encoder_states = states[:-2] + logging.info(f"len_encoder_states={len(encoder_states)}") + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + + return encoder_out, new_states + + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = self.encoder.get_init_states(batch_size, device) + + embed_states = self.encoder_embed.get_init_states(batch_size, device) + + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device) + states.append(processed_lens) + + return states + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + encoder_model.encoder.__class__.forward = ( + encoder_model.encoder.__class__.streaming_forward + ) + + decode_chunk_len = encoder_model.chunk_size * 2 + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + T = decode_chunk_len + encoder_model.pad_length + + x = torch.rand(1, T, 80, dtype=torch.float32) + init_state = encoder_model.get_init_states() + num_encoders = len(encoder_model.encoder.encoder_dim) + logging.info(f"num_encoders: {num_encoders}") + logging.info(f"len(init_state): {len(init_state)}") + + inputs = {} + input_names = ["x"] + + outputs = {} + output_names = ["encoder_out"] + + def build_inputs_outputs(tensors, i): + assert len(tensors) == 6, len(tensors) + + # (downsample_left, batch_size, key_dim) + name = f"cached_key_{i}" + logging.info(f"{name}.shape: {tensors[0].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (1, batch_size, downsample_left, nonlin_attn_head_dim) + name = f"cached_nonlin_attn_{i}" + logging.info(f"{name}.shape: {tensors[1].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val1_{i}" + logging.info(f"{name}.shape: {tensors[2].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val2_{i}" + logging.info(f"{name}.shape: {tensors[3].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv1_{i}" + logging.info(f"{name}.shape: {tensors[4].shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv2_{i}" + logging.info(f"{name}.shape: {tensors[5].shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers)) + encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim)) + cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernel)) + ds = encoder_model.encoder.downsampling_factor + left_context_len = encoder_model.left_context_len + left_context_len = [left_context_len // k for k in ds] + left_context_len = ",".join(map(str, left_context_len)) + query_head_dims = ",".join(map(str, encoder_model.encoder.query_head_dim)) + value_head_dims = ",".join(map(str, encoder_model.encoder.value_head_dim)) + num_heads = ",".join(map(str, encoder_model.encoder.num_heads)) + + meta_data = { + "model_type": "zipformer2", + "version": "1", + "model_author": "k2-fsa", + "comment": "streaming zipformer2", + "decode_chunk_len": str(decode_chunk_len), # 32 + "T": str(T), # 32+7+2*3=45 + "num_encoder_layers": num_encoder_layers, + "encoder_dims": encoder_dims, + "cnn_module_kernels": cnn_module_kernels, + "left_context_len": left_context_len, + "query_head_dims": query_head_dims, + "value_head_dims": value_head_dims, + "num_heads": num_heads, + } + logging.info(f"meta_data: {meta_data}") + + for i in range(len(init_state[:-2]) // 6): + build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i) + + # (batch_size, channels, left_pad, freq) + embed_states = init_state[-2] + name = "embed_states" + logging.info(f"{name}.shape: {embed_states.shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (batch_size,) + processed_lens = init_state[-1] + name = "processed_lens" + logging.info(f"{name}.shape: {processed_lens.shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + logging.info(inputs) + logging.info(outputs) + logging.info(input_names) + logging.info(output_names) + + torch.onnx.export( + encoder_model, + (x, init_state), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "x": {0: "N"}, + "encoder_out": {0: "N"}, + **inputs, + **outputs, + }, + ) + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + suffix += f"-chunk-{params.chunk_size}" + suffix += f"-left-{params.left_context_frames}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul", "Gather"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/zipformer/export-onnx.py b/egs/swbd/ASR/zipformer/export-onnx.py new file mode 100755 index 0000000000..e0d664009c --- /dev/null +++ b/egs/swbd/ASR/zipformer/export-onnx.py @@ -0,0 +1,619 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal False \ + --chunk-size "16,32,64,-1" \ + --left-context-frames "64,128,256,-1" + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_model, get_params +from zipformer import Zipformer2 + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__( + self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear + ): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "zipformer2", + "version": "1", + "model_author": "k2-fsa", + "comment": "non-streaming zipformer2", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul", "Gather"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/zipformer/export.py b/egs/swbd/ASR/zipformer/export.py new file mode 100755 index 0000000000..2b8d1aaf36 --- /dev/null +++ b/egs/swbd/ASR/zipformer/export.py @@ -0,0 +1,526 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +(1) Export to torchscript model using torch.jit.script() + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("jit_script.pt")`. + +Check ./jit_pretrained.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. +You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. + +Check ./jit_pretrained_streaming.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +- For non-streaming model: + +To use the generated file with `zipformer/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +- For streaming model: + +To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + + # simulated streaming decoding + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + + # chunk-wise streaming decoding + ./zipformer/streaming_decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +- non-streaming model: +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 + +- streaming model: +https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 + # You will find the pre-trained models in exp dir +""" + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor, nn +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class EncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Args: + features: (N, T, C) + feature_lengths: (N,) + """ + x, x_lens = self.encoder_embed(features, feature_lengths) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens + + +class StreamingEncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + assert len(encoder.chunk_size) == 1, encoder.chunk_size + assert len(encoder.left_context_frames) == 1, encoder.left_context_frames + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + self.pad_length = 7 + 2 * 3 + + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """Streaming forward for encoder_embed and encoder. + + Args: + features: (N, T, C) + feature_lengths: (N,) + states: a list of Tensors + + Returns encoder outputs, output lengths, and updated states. + """ + chunk_size = self.chunk_size + left_context_len = self.left_context_len + + cached_embed_left_pad = states[-2] + x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lengths, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = self.encoder.get_init_states(batch_size, device) + + embed_states = self.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + # if torch.cuda.is_available(): + # device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + + # Wrap encoder and encoder_embed as a module + if params.causal: + model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) + chunk_size = model.encoder.chunk_size + left_context_len = model.encoder.left_context_len + filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" + else: + model.encoder = EncoderModel(model.encoder, model.encoder_embed) + filename = "jit_script.pt" + + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + model.save(str(params.exp_dir / filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/zipformer/generate_averaged_model.py b/egs/swbd/ASR/zipformer/generate_averaged_model.py new file mode 100755 index 0000000000..68111fad71 --- /dev/null +++ b/egs/swbd/ASR/zipformer/generate_averaged_model.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) use the checkpoint exp_dir/epoch-xxx.pt +./zipformer/generate_averaged_model.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp + +It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15.pt")`. + +(2) use the checkpoint exp_dir/checkpoint-iter.pt +./zipformer/generate_averaged_model.py \ + --iter 22000 \ + --avg 5 \ + --exp-dir ./zipformer/exp + +It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. +You can later load it by `torch.load("iter-22000-avg-5.pt")`. +""" + + +import argparse +from pathlib import Path + +import k2 +import torch +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + print("Script started") + + device = torch.device("cpu") + print(f"Device: {device}") + + symbol_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = symbol_table[""] + params.unk_id = symbol_table[""] + params.vocab_size = len(symbol_table) + + print("About to create model") + model = get_model(params) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + print( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + print( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/zipformer/jit_pretrained.py b/egs/swbd/ASR/zipformer/jit_pretrained.py new file mode 100755 index 0000000000..a41fbc1c97 --- /dev/null +++ b/egs/swbd/ASR/zipformer/jit_pretrained.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +Usage of this script: + +./zipformer/jit_pretrained.py \ + --nn-model-filename ./zipformer/exp/cpu_jit.pt \ + --tokens ./data/lang_bpe_500/tokens.txt \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = model.decoder.blank_id + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + features=features, + feature_lengths=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + + s = "\n" + + token_table = k2.SymbolTable.from_file(args.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + for filename, hyp in zip(args.sound_files, hyps): + words = token_ids_to_words(hyp) + s += f"{filename}:\n{words}\n" + + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/zipformer/jit_pretrained_ctc.py b/egs/swbd/ASR/zipformer/jit_pretrained_ctc.py new file mode 100755 index 0000000000..660a4bfc60 --- /dev/null +++ b/egs/swbd/ASR/zipformer/jit_pretrained_ctc.py @@ -0,0 +1,436 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +Usage of this script: + +(1) ctc-decoding +./zipformer/jit_pretrained_ctc.py \ + --model-filename ./zipformer/exp/jit_script.pt \ + --tokens data/lang_bpe_500/tokens.txt \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./zipformer/jit_pretrained_ctc.py \ + --model-filename ./zipformer/exp/jit_script.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) nbest-rescoring +./zipformer/jit_pretrained_ctc.py \ + --model-filename ./zipformer/exp/jit_script.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) whole-lattice-rescoring +./zipformer/jit_pretrained_ctc.py \ + --model-filename ./zipformer/exp/jit_script.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from ctc_decode import get_decoding_params +from export import num_tokens +from torch.nn.utils.rnn import pad_sequence +from train import get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the torchscript model.", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a token table, + i.e., lang_dir/token.txt, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + nbest n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + whole-lattice n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + token_table = k2.SymbolTable.from_file(params.tokens) + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.model_filename) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(features, feature_lengths) + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + batch_size = ctc_output.shape[0] + supervision_segments = torch.tensor( + [ + [i, 0, feature_lengths[i].item() // params.subsampling_factor] + for i in range(batch_size) + ], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + max_token_id = params.vocab_size - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = [[token_table[i] for i in ids] for ids in token_ids] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + if params.method == "ctc-decoding": + for filename, hyp in zip(params.sound_files, hyps): + words = "".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/zipformer/jit_pretrained_streaming.py b/egs/swbd/ASR/zipformer/jit_pretrained_streaming.py new file mode 100755 index 0000000000..d4ceacefd3 --- /dev/null +++ b/egs/swbd/ASR/zipformer/jit_pretrained_streaming.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python3 +# flake8: noqa +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +Usage of this script: + +./zipformer/jit_pretrained_streaming.py \ + --nn-model-filename ./zipformer/exp-causal/jit_script_chunk_16_left_128.pt \ + --tokens ./data/lang_bpe_500/tokens.txt \ + /path/to/foo.wav \ +""" + +import argparse +import logging +import math +from typing import List, Optional + +import k2 +import kaldifeat +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model jit_script.pt", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, + device: torch.device = torch.device("cpu"), +): + assert encoder_out.ndim == 2 + context_size = decoder.context_size + blank_id = decoder.blank_id + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32, device=device).unsqueeze(0) + # decoder_input.shape (1,, 1 context_size) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + else: + assert decoder_out.ndim == 2 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for i in range(T): + cur_encoder_out = encoder_out[i : i + 1] + joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + + decoder_input = torch.tensor( + decoder_input, dtype=torch.int32, device=device + ).unsqueeze(0) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + + return hyp, decoder_out + + +def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + model.eval() + model.to(device) + + encoder = model.encoder + decoder = model.decoder + joiner = model.joiner + + token_table = k2.SymbolTable.from_file(args.tokens) + context_size = decoder.context_size + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor(args.sample_rate) + + logging.info(f"Reading sound files: {args.sound_file}") + wave_samples = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=args.sample_rate, + )[0] + logging.info(wave_samples.shape) + + logging.info("Decoding started") + + chunk_length = encoder.chunk_size * 2 + T = chunk_length + encoder.pad_length + + logging.info(f"chunk_length: {chunk_length}") + logging.info(f"T: {T}") + + states = encoder.get_init_states(device=device) + + tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + chunk = int(0.25 * args.sample_rate) # 0.2 second + num_processed_frames = 0 + + hyp = None + decoder_out = None + + start = 0 + while start < wave_samples.numel(): + logging.info(f"{start}/{wave_samples.numel()}") + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + online_fbank.accept_waveform( + sampling_rate=args.sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= T: + frames = [] + for i in range(T): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + frames = torch.cat(frames, dim=0).to(device).unsqueeze(0) + x_lens = torch.tensor([T], dtype=torch.int32, device=device) + encoder_out, out_lens, states = encoder( + features=frames, + feature_lengths=x_lens, + states=states, + ) + num_processed_frames += chunk_length + + hyp, decoder_out = greedy_search( + decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device + ) + + text = "" + for i in hyp[context_size:]: + text += token_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +torch.set_num_threads(4) +torch.set_num_interop_threads(1) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/zipformer/joiner.py b/egs/swbd/ASR/zipformer/joiner.py new file mode 100644 index 0000000000..f03cc930ec --- /dev/null +++ b/egs/swbd/ASR/zipformer/joiner.py @@ -0,0 +1,66 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from scaling import ScaledLinear + + +class Joiner(nn.Module): + def __init__( + self, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + super().__init__() + + self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) + self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) + self.output_linear = nn.Linear(joiner_dim, vocab_size) + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + project_input: bool = True, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, s_range, C). + decoder_out: + Output from the decoder. Its shape is (N, T, s_range, C). + project_input: + If true, apply input projections encoder_proj and decoder_proj. + If this is false, it is the user's responsibility to do this + manually. + Returns: + Return a tensor of shape (N, T, s_range, C). + """ + assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape) + + if project_input: + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) + else: + logit = encoder_out + decoder_out + + logit = self.output_linear(torch.tanh(logit)) + + return logit diff --git a/egs/swbd/ASR/zipformer/model.py b/egs/swbd/ASR/zipformer/model.py new file mode 100644 index 0000000000..f2f86af47a --- /dev/null +++ b/egs/swbd/ASR/zipformer/model.py @@ -0,0 +1,358 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface + +from icefall.utils import add_sos, make_pad_mask +from scaling import ScaledLinear + + +class AsrModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: EncoderInterface, + decoder: Optional[nn.Module] = None, + joiner: Optional[nn.Module] = None, + encoder_dim: int = 384, + decoder_dim: int = 512, + vocab_size: int = 500, + use_transducer: bool = True, + use_ctc: bool = False, + ): + """A joint CTC & Transducer ASR model. + + - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) + - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) + - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + It is used when use_transducer is True. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + It is used when use_transducer is True. + use_transducer: + Whether use transducer head. Default: True. + use_ctc: + Whether use CTC head. Default: False. + """ + super().__init__() + + assert ( + use_transducer or use_ctc + ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}" + + assert isinstance(encoder, EncoderInterface), type(encoder) + + self.encoder_embed = encoder_embed + self.encoder = encoder + + self.use_transducer = use_transducer + if use_transducer: + # Modules for Transducer head + assert decoder is not None + assert hasattr(decoder, "blank_id") + assert joiner is not None + + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_scale=0.25 + ) + self.simple_lm_proj = ScaledLinear( + decoder_dim, vocab_size, initial_scale=0.25 + ) + else: + assert decoder is None + assert joiner is None + + self.use_ctc = use_ctc + if use_ctc: + # Modules for CTC head + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + nn.Linear(encoder_dim, vocab_size), + nn.LogSoftmax(dim=-1), + ) + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + return encoder_out, encoder_out_lens + + def forward_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + targets: + Target Tensor of shape (sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + """ + # Compute CTC log-prob + ctc_output = self.ctc_output(encoder_out) # (N, T, C) + + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) + targets=targets, + input_lengths=encoder_out_lens, + target_lengths=target_lengths, + reduction="sum", + ) + return ctc_loss + + def forward_transducer( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + y: k2.RaggedTensor, + y_lens: torch.Tensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute Transducer loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + """ + # Now for the decoder, i.e., the prediction network + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = encoder_out_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + # if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + # if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return simple_loss, pruned_loss + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer losses and CTC loss, + in form of (simple_loss, pruned_loss, ctc_loss) + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) + + # Compute encoder outputs + encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) + + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + if self.use_transducer: + # Compute transducer loss + simple_loss, pruned_loss = self.forward_transducer( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(x.device), + y_lens=y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) + else: + simple_loss = torch.empty(0) + pruned_loss = torch.empty(0) + + if self.use_ctc: + # Compute CTC loss + targets = y.values + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + else: + ctc_loss = torch.empty(0) + + return simple_loss, pruned_loss, ctc_loss diff --git a/egs/swbd/ASR/zipformer/onnx_check.py b/egs/swbd/ASR/zipformer/onnx_check.py new file mode 100755 index 0000000000..93bd3a211c --- /dev/null +++ b/egs/swbd/ASR/zipformer/onnx_check.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script checks that exported onnx models produce the same output +with the given torchscript model for the same input. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model via torchscript (torch.jit.script()) + +./zipformer/export.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --jit 1 + +It will generate the following file in $repo/exp: + - jit_script.pt + +3. Export the model to ONNX + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +4. Run this file + +./zipformer/onnx_check.py \ + --jit-filename $repo/exp/jit_script.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx +""" + +import argparse +import logging + +import torch +from onnx_pretrained import OnnxModel + +from icefall import is_module_available + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-filename", + required=True, + type=str, + help="Path to the torchscript model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the onnx encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the onnx decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the onnx joiner model", + ) + + return parser + + +def test_encoder( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + C = 80 + for i in range(3): + N = torch.randint(low=1, high=20, size=(1,)).item() + T = torch.randint(low=30, high=50, size=(1,)).item() + logging.info(f"test_encoder: iter {i}, N={N}, T={T}") + + x = torch.rand(N, T, C) + x_lens = torch.randint(low=30, high=T + 1, size=(N,)) + x_lens[0] = T + + torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens) + torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out) + + onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens) + + assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), ( + (torch_encoder_out - onnx_encoder_out).abs().max() + ) + + +def test_decoder( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + context_size = onnx_model.context_size + vocab_size = onnx_model.vocab_size + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_decoder: iter {i}, N={N}") + x = torch.randint( + low=1, + high=vocab_size, + size=(N, context_size), + dtype=torch.int64, + ) + torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False])) + torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out) + torch_decoder_out = torch_decoder_out.squeeze(1) + + onnx_decoder_out = onnx_model.run_decoder(x) + assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( + (torch_decoder_out - onnx_decoder_out).abs().max() + ) + + +def test_joiner( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1] + decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1] + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_joiner: iter {i}, N={N}") + encoder_out = torch.rand(N, encoder_dim) + decoder_out = torch.rand(N, decoder_dim) + + projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out) + projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out) + + torch_joiner_out = torch_model.joiner(encoder_out, decoder_out) + onnx_joiner_out = onnx_model.run_joiner( + projected_encoder_out, projected_decoder_out + ) + + assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( + (torch_joiner_out - onnx_joiner_out).abs().max() + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_model = torch.jit.load(args.jit_filename) + + onnx_model = OnnxModel( + encoder_model_filename=args.onnx_encoder_filename, + decoder_model_filename=args.onnx_decoder_filename, + joiner_model_filename=args.onnx_joiner_filename, + ) + + logging.info("Test encoder") + test_encoder(torch_model, onnx_model) + + logging.info("Test decoder") + test_decoder(torch_model, onnx_model) + + logging.info("Test joiner") + test_joiner(torch_model, onnx_model) + logging.info("Finished checking ONNX models") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# See https://github.com/pytorch/pytorch/issues/38342 +# and https://github.com/pytorch/pytorch/issues/33354 +# +# If we don't do this, the delay increases whenever there is +# a new request that changes the actual batch size. +# If you use `py-spy dump --pid --native`, you will +# see a lot of time is spent in re-compiling the torch script model. +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + torch.manual_seed(20220727) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/zipformer/onnx_decode.py b/egs/swbd/ASR/zipformer/onnx_decode.py new file mode 100755 index 0000000000..2aca36ca94 --- /dev/null +++ b/egs/swbd/ASR/zipformer/onnx_decode.py @@ -0,0 +1,323 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal False + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +2. Run this file + +./zipformer/onnx_decode.py \ + --exp-dir $repo/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule + +from onnx_pretrained import greedy_search, OnnxModel + +from icefall.utils import setup_logger, store_transcripts, write_error_stats +from k2 import SymbolTable + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, token_table: SymbolTable, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + token_table: + The token table. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + hyps = [token_ids_to_words(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + token_table: SymbolTable, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + token_table: + The token table. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." + res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + token_table = SymbolTable.from_file(args.tokens) + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/zipformer/onnx_pretrained-streaming.py b/egs/swbd/ASR/zipformer/onnx_pretrained-streaming.py new file mode 100755 index 0000000000..500b2cd097 --- /dev/null +++ b/egs/swbd/ASR/zipformer/onnx_pretrained-streaming.py @@ -0,0 +1,543 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script loads ONNX models exported by ./export-onnx-streaming.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx-streaming.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal True \ + --chunk-size 16 \ + --left-context-frames 64 + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file with the exported ONNX models + +./zipformer/onnx_pretrained-streaming.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. +""" + +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + self.init_encoder_states() + + def init_encoder_states(self, batch_size: int = 1): + encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + logging.info(f"encoder_meta={encoder_meta}") + + model_type = encoder_meta["model_type"] + assert model_type == "zipformer2", model_type + + decode_chunk_len = int(encoder_meta["decode_chunk_len"]) + T = int(encoder_meta["T"]) + + num_encoder_layers = encoder_meta["num_encoder_layers"] + encoder_dims = encoder_meta["encoder_dims"] + cnn_module_kernels = encoder_meta["cnn_module_kernels"] + left_context_len = encoder_meta["left_context_len"] + query_head_dims = encoder_meta["query_head_dims"] + value_head_dims = encoder_meta["value_head_dims"] + num_heads = encoder_meta["num_heads"] + + def to_int_list(s): + return list(map(int, s.split(","))) + + num_encoder_layers = to_int_list(num_encoder_layers) + encoder_dims = to_int_list(encoder_dims) + cnn_module_kernels = to_int_list(cnn_module_kernels) + left_context_len = to_int_list(left_context_len) + query_head_dims = to_int_list(query_head_dims) + value_head_dims = to_int_list(value_head_dims) + num_heads = to_int_list(num_heads) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"encoder_dims: {encoder_dims}") + logging.info(f"cnn_module_kernels: {cnn_module_kernels}") + logging.info(f"left_context_len: {left_context_len}") + logging.info(f"query_head_dims: {query_head_dims}") + logging.info(f"value_head_dims: {value_head_dims}") + logging.info(f"num_heads: {num_heads}") + + num_encoders = len(num_encoder_layers) + + self.states = [] + for i in range(num_encoders): + num_layers = num_encoder_layers[i] + key_dim = query_head_dims[i] * num_heads[i] + embed_dim = encoder_dims[i] + nonlin_attn_head_dim = 3 * embed_dim // 4 + value_dim = value_head_dims[i] * num_heads[i] + conv_left_pad = cnn_module_kernels[i] // 2 + + for layer in range(num_layers): + cached_key = torch.zeros( + left_context_len[i], batch_size, key_dim + ).numpy() + cached_nonlin_attn = torch.zeros( + 1, batch_size, left_context_len[i], nonlin_attn_head_dim + ).numpy() + cached_val1 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_val2 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + self.states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + embed_states = torch.zeros(batch_size, 128, 3, 19).numpy() + self.states.append(embed_states) + processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy() + self.states.append(processed_lens) + + self.num_encoders = num_encoders + + self.segment = T + self.offset = decode_chunk_len + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def _build_encoder_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + encoder_input = {"x": x.numpy()} + encoder_output = ["encoder_out"] + + def build_inputs_outputs(tensors, i): + assert len(tensors) == 6, len(tensors) + + # (downsample_left, batch_size, key_dim) + name = f"cached_key_{i}" + encoder_input[name] = tensors[0] + encoder_output.append(f"new_{name}") + + # (1, batch_size, downsample_left, nonlin_attn_head_dim) + name = f"cached_nonlin_attn_{i}" + encoder_input[name] = tensors[1] + encoder_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val1_{i}" + encoder_input[name] = tensors[2] + encoder_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val2_{i}" + encoder_input[name] = tensors[3] + encoder_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv1_{i}" + encoder_input[name] = tensors[4] + encoder_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv2_{i}" + encoder_input[name] = tensors[5] + encoder_output.append(f"new_{name}") + + for i in range(len(self.states[:-2]) // 6): + build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i) + + # (batch_size, channels, left_pad, freq) + name = "embed_states" + embed_states = self.states[-2] + encoder_input[name] = embed_states + encoder_output.append(f"new_{name}") + + # (batch_size,) + name = "processed_lens" + processed_lens = self.states[-1] + encoder_input[name] = processed_lens + encoder_output.append(f"new_{name}") + + return encoder_input, encoder_output + + def _update_states(self, states: List[np.ndarray]): + self.states = states + + def run_encoder(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor of shape (N, T', joiner_dim) where + T' is usually equal to ((T-7)//2+1)//2 + """ + encoder_input, encoder_output_names = self._build_encoder_input_output(x) + + out = self.encoder.run(encoder_output_names, encoder_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + context_size: int, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +) -> List[int]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (1, T, joiner_dim) + context_size: + The context size of the decoder model. + decoder_out: + Optional. Decoder output of the previous chunk. + hyp: + Decoding results for previous chunks. + Returns: + Return the decoded results so far. + """ + + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor([hyp], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + else: + assert hyp is not None, hyp + + encoder_out = encoder_out.squeeze(0) + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t : t + 1] + joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor([decoder_input], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + + return hyp, decoder_out + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + context_size = model.context_size + hyp = None + decoder_out = None + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + encoder_out = model.run_encoder(frames) + hyp, decoder_out = greedy_search( + model, + encoder_out, + context_size, + decoder_out, + hyp, + ) + + token_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp[context_size:]: + text += token_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/zipformer/onnx_pretrained.py b/egs/swbd/ASR/zipformer/onnx_pretrained.py new file mode 100755 index 0000000000..032b07721a --- /dev/null +++ b/egs/swbd/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1,418 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. +You can use the following command to get the exported models: + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal False + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file + +./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +import onnxruntime as ort +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def run_encoder( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 2-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, its shape is (N, T', joiner_dim) + - encoder_out_lens, its shape is (N,) + """ + out = self.encoder.run( + [ + self.encoder.get_outputs()[0].name, + self.encoder.get_outputs()[1].name, + ], + { + self.encoder.get_inputs()[0].name: x.numpy(), + self.encoder.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, joiner_dim) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.run_decoder(decoder_input) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + # current_encoder_out's shape: (batch_size, joiner_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + logits = model.run_joiner(current_encoder_out, decoder_out) + + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + dtype=torch.int64, + ) + decoder_out = model.run_decoder(decoder_input) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + + token_table = k2.SymbolTable.from_file(args.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + for filename, hyp in zip(args.sound_files, hyps): + words = token_ids_to_words(hyp) + s += f"{filename}:\n{words}\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/zipformer/optim.py b/egs/swbd/ASR/zipformer/optim.py new file mode 100644 index 0000000000..abfb2092cd --- /dev/null +++ b/egs/swbd/ASR/zipformer/optim.py @@ -0,0 +1,1173 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import logging +import random +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import torch +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.optim import Optimizer + + +class BatchedOptimizer(Optimizer): + """ + This class adds to class Optimizer the capability to optimize parameters in batches: + it will stack the parameters and their grads for you so the optimizer can work + on tensors with an extra leading dimension. This is intended for speed with GPUs, + as it reduces the number of kernels launched in the optimizer. + + Args: + params: + """ + + def __init__(self, params, defaults): + super(BatchedOptimizer, self).__init__(params, defaults) + + @contextlib.contextmanager + def batched_params(self, param_group, group_params_names): + """ + This function returns (technically, yields) a list of + of tuples (p, state), where + p is a `fake` parameter that is stacked (over axis 0) from real parameters + that share the same shape, and its gradient is also stacked; + `state` is the state corresponding to this batch of parameters + (it will be physically located in the "state" for one of the real + parameters, the last one that has any particular shape and dtype). + + This function is decorated as a context manager so that it can + write parameters back to their "real" locations. + + The idea is, instead of doing: + + for p in group["params"]: + state = self.state[p] + ... + + you can do: + + with self.batched_params(group["params"]) as batches: + for p, state, p_names in batches: + ... + + + Args: + group: a parameter group, which is a list of parameters; should be + one of self.param_groups. + group_params_names: name for each parameter in group, + which is List[str]. + """ + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches_names = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str + + assert len(param_group) == len(group_params_names) + for p, named_p in zip(param_group, group_params_names): + key = (str(p.dtype), *p.shape) + batches[key].append(p) + batches_names[key].append(named_p) + + batches_names_keys = list(batches_names.keys()) + sorted_idx = sorted( + range(len(batches_names)), key=lambda i: batches_names_keys[i] + ) + batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] + batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] + + stacked_params_dict = dict() + + # turn batches into a list, in deterministic order. + # tuples will contain tuples of (stacked_param, state, stacked_params_names), + # one for each batch in `batches`. + tuples = [] + + for batch, batch_names in zip(batches, batches_names): + p = batch[0] + # we arbitrarily store the state in the + # state corresponding to the 1st parameter in the + # group. class Optimizer will take care of saving/loading state. + state = self.state[p] + p_stacked = torch.stack(batch) + grad = torch.stack( + [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] + ) + p_stacked.grad = grad + stacked_params_dict[key] = p_stacked + tuples.append((p_stacked, state, batch_names)) + + yield tuples # <-- calling code will do the actual optimization here! + + for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for i, p in enumerate(batch): # batch is list of Parameter + p.copy_(stacked_params[i]) + + +class ScaledAdam(BatchedOptimizer): + """ + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) + + + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + Unlike common optimizers, which accept model.parameters() or groups of parameters(), + this optimizer could accept model.named_parameters() or groups of named_parameters(). + See comments of function _get_names_of_parameters for its 4 possible cases. + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period + """ + + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + ): + + defaults = dict( + lr=lr, + clipping_scale=clipping_scale, + betas=betas, + scalar_lr_scale=scalar_lr_scale, + eps=eps, + param_min_rms=param_min_rms, + param_max_rms=param_max_rms, + scalar_max=scalar_max, + size_update_period=size_update_period, + clipping_update_period=clipping_update_period, + ) + + # If params only contains parameters or group of parameters, + # i.e when parameter names are not given, + # this flag will be set to False in funciton _get_names_of_parameters. + self.show_dominant_parameters = True + param_groups, parameters_names = self._get_names_of_parameters(params) + super(ScaledAdam, self).__init__(param_groups, defaults) + assert len(self.param_groups) == len(parameters_names) + self.parameters_names = parameters_names + + def _get_names_of_parameters( + self, params_or_named_params + ) -> Tuple[List[Dict], List[List[str]]]: + """ + Args: + params_or_named_params: according to the way ScaledAdam is initialized in train.py, + this argument could be one of following 4 cases, + case 1, a generator of parameter, e.g.: + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0) + + case 2, a list of parameter groups with different config, e.g.: + model_param_groups = [ + {'params': model.encoder.parameters(), 'lr': 0.05}, + {'params': model.decoder.parameters(), 'lr': 0.01}, + {'params': model.joiner.parameters(), 'lr': 0.03}, + ] + optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0) + + case 3, a generator of named_parameter, e.g.: + optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0) + + case 4, a list of named_parameter groups with different config, e.g.: + model_named_param_groups = [ + {'named_params': model.encoder.named_parameters(), 'lr': 0.05}, + {'named_params': model.decoder.named_parameters(), 'lr': 0.01}, + {'named_params': model.joiner.named_parameters(), 'lr': 0.03}, + ] + optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0) + + For case 1 and case 2, input params is used to initialize the underlying torch.optimizer. + For case 3 and case 4, firstly, names and params are extracted from input named_params, + then, these extracted params are used to initialize the underlying torch.optimizer, + and these extracted names are mainly used by function + `_show_gradient_dominating_parameter` + + Returns: + Returns a tuple containing 2 elements: + - `param_groups` with type List[Dict], each Dict element is a parameter group. + An example of `param_groups` could be: + [ + {'params': `one iterable of Parameter`, 'lr': 0.05}, + {'params': `another iterable of Parameter`, 'lr': 0.08}, + {'params': `a third iterable of Parameter`, 'lr': 0.1}, + ] + - `param_gruops_names` with type List[List[str]], + each `List[str]` is for a group['params'] in param_groups, + and each `str` is the name of a parameter. + A dummy name "foo" is related to each parameter, + if input are params without names, i.e. case 1 or case 2. + """ + # variable naming convention in this function: + # p is short for param. + # np is short for named_param. + # p_or_np is short for param_or_named_param. + # cur is short for current. + # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}. + # groups is a List[group] + + iterable_or_groups = list(params_or_named_params) + if len(iterable_or_groups) == 0: + raise ValueError("optimizer got an empty parameter list") + + # The first value of returned tuple. A list of dicts containing at + # least 'params' as a key. + param_groups = [] + + # The second value of returned tuple, + # a List[List[str]], each sub-List is for a group. + param_groups_names = [] + + if not isinstance(iterable_or_groups[0], dict): + # case 1 or case 3, + # the input is an iterable of parameter or named parameter. + param_iterable_cur_group = [] + param_names_cur_group = [] + for p_or_np in iterable_or_groups: + if isinstance(p_or_np, tuple): + # case 3 + name, param = p_or_np + else: + # case 1 + assert isinstance(p_or_np, torch.Tensor) + param = p_or_np + # Assign a dummy name as a placeholder + name = "foo" + self.show_dominant_parameters = False + param_iterable_cur_group.append(param) + param_names_cur_group.append(name) + param_groups.append({"params": param_iterable_cur_group}) + param_groups_names.append(param_names_cur_group) + else: + # case 2 or case 4 + # the input is groups of parameter or named parameter. + for cur_group in iterable_or_groups: + assert "named_params" in cur_group + name_list = [ x[0] for x in cur_group["named_params"] ] + p_list = [ x[1] for x in cur_group["named_params"] ] + del cur_group["named_params"] + cur_group["params"] = p_list + param_groups.append(cur_group) + param_groups_names.append(name_list) + + return param_groups, param_groups_names + + def __setstate__(self, state): + super(ScaledAdam, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + batch = True + + for group, group_params_names in zip(self.param_groups, self.parameters_names): + + with self.batched_params(group["params"], group_params_names) as batches: + + # batches is list of pairs (stacked_param, state). stacked_param is like + # a regular parameter, and will have a .grad, but the 1st dim corresponds to + # a stacking dim, it is not a real dim. + + if ( + len(batches[0][1]) == 0 + ): # if len(first state) == 0: not yet initialized + clipping_scale = 1 + else: + clipping_scale = self._get_clipping_scale(group, batches) + + for p, state, _ in batches: + # Perform optimization step. + # grad is not going to be None, we handled that when creating the batches. + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + # State initialization + if len(state) == 0: + self._init_state(group, p, state) + + self._step_one_batch(group, p, state, clipping_scale) + + return loss + + def _init_state(self, group: dict, p: Tensor, state: dict): + """ + Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p + is actually the batch dimension, corresponding to batched-together + parameters of a given shape. + + + Args: + group: Dict to look up configuration values. + p: The parameter that we are initializing the state for + state: Dict from string to whatever state we are initializing + """ + size_update_period = group["size_update_period"] + + state["step"] = 0 + + kwargs = {"device": p.device, "dtype": p.dtype} + + # 'delta' implements conventional momentum. There are + # several different kinds of update going on, so rather than + # compute "exp_avg" like in Adam, we store and decay a + # parameter-change "delta", which combines all forms of + # update. this is equivalent to how it's done in Adam, + # except for the first few steps. + state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + batch_size = p.shape[0] + numel = p.numel() // batch_size + + if numel > 1: + # "param_rms" just periodically records the scalar root-mean-square value of + # the parameter tensor. + # it has a shape like (batch_size, 1, 1, 1, 1) + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + state["param_rms"] = param_rms + + state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) + state["scale_grads"] = torch.zeros( + size_update_period, *param_rms.shape, **kwargs + ) + + # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + def _get_clipping_scale( + self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] + ) -> float: + """ + Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients + by this amount before applying the rest of the update. + + Args: + group: the parameter group, an item in self.param_groups + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + """ + assert len(tuples) >= 1 + clipping_scale = group["clipping_scale"] + (first_p, first_state, _) = tuples[0] + step = first_state["step"] + if clipping_scale is None or step == 0: + # no clipping. return early on step == 0 because the other + # parameters' state won't have been initialized yet. + return 1.0 + clipping_update_period = group["clipping_update_period"] + + tot_sumsq = torch.tensor(0.0, device=first_p.device) + for (p, state, param_names) in tuples: + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + if p.numel() == p.shape[0]: # a batch of scalars + tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] + else: + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() + + tot_norm = tot_sumsq.sqrt() + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros( + clipping_update_period, device=p.device + ) + first_state["model_norms"][step % clipping_update_period] = tot_norm + + if step % clipping_update_period == 0: + # Print some stats. + # We don't reach here if step == 0 because we would have returned + # above. + sorted_norms = first_state["model_norms"].sort()[0].to("cpu") + quartiles = [] + for n in range(0, 5): + index = min( + clipping_update_period - 1, (clipping_update_period // 4) * n + ) + quartiles.append(sorted_norms[index].item()) + + median = quartiles[2] + threshold = clipping_scale * median + first_state["model_norm_threshold"] = threshold + percent_clipped = ( + first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state + else 0.0 + ) + first_state["num_clipped"] = 0 + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.info( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) + + if step < clipping_update_period: + return 1.0 # We have not yet estimated a norm to clip to. + else: + try: + model_norm_threshold = first_state["model_norm_threshold"] + except KeyError: + logging.info( + "Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?" + ) + return 1.0 + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) + if ans < 1.0: + first_state["num_clipped"] += 1 + if ans < 0.1: + logging.warn( + f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" + ) + if self.show_dominant_parameters: + assert p.shape[0] == len(param_names) + self._show_gradient_dominating_parameter(tuples, tot_sumsq) + return ans + + def _show_gradient_dominating_parameter( + self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor + ): + """ + Show information of parameter which dominates tot_sumsq. + + Args: + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + tot_sumsq: sumsq of all parameters. Though it's could be calculated + from tuples, we still pass it to save some time. + """ + all_sumsq_orig = {} + for (p, state, batch_param_names) in tuples: + # p is a stacked batch parameters. + batch_grad = p.grad + if p.numel() == p.shape[0]: # a batch of scalars + batch_sumsq_orig = batch_grad**2 + # Dummy values used by following `zip` statement. + batch_rms_orig = torch.ones(p.shape[0]) + else: + batch_rms_orig = state["param_rms"] + batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( + dim=list(range(1, batch_grad.ndim)) + ) + + for name, sumsq_orig, rms, grad in zip( + batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad + ): + + proportion_orig = sumsq_orig / tot_sumsq + all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) + + assert torch.isclose( + sum([value[0] for value in all_sumsq_orig.values()]).cpu(), + torch.tensor(1.0), + ) + sorted_by_proportion = { + k: v + for k, v in sorted( + all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True + ) + } + dominant_param_name = next(iter(sorted_by_proportion)) + ( + dominant_proportion, + dominant_sumsq, + dominant_rms, + dominant_grad, + ) = sorted_by_proportion[dominant_param_name] + logging.info( + f"Parameter dominating tot_sumsq {dominant_param_name}" + f" with proportion {dominant_proportion:.2f}," + f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" + f"={dominant_sumsq:.3e}," + f" grad_sumsq={(dominant_grad**2).sum():.3e}," + f" orig_rms_sq={(dominant_rms**2).item():.3e}" + ) + + def _step_one_batch( + self, group: dict, p: Tensor, state: dict, clipping_scale: float + ): + """ + Do the step for one parameter, which is actually going to be a batch of + `real` parameters, with dim 0 as the batch dim. + Args: + group: dict to look up configuration values + p: parameter to update (actually multiple parameters stacked together + as a batch) + state: state-dict for p, to look up the optimizer state + """ + lr = group["lr"] + size_update_period = group["size_update_period"] + beta1 = group["betas"][0] + + grad = p.grad + if clipping_scale != 1.0: + grad = grad * clipping_scale + step = state["step"] + delta = state["delta"] + + delta.mul_(beta1) + batch_size = p.shape[0] + numel = p.numel() // batch_size + if numel > 1: + # Update the size/scale of p, and set param_rms + scale_grads = state["scale_grads"] + scale_grads[step % size_update_period] = (p * grad).sum( + dim=list(range(1, p.ndim)), keepdim=True + ) + if step % size_update_period == size_update_period - 1: + param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) + param_rms.copy_( + (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) + if step > 0: + # self._size_update() learns the overall scale on the + # parameter, by shrinking or expanding it. + self._size_update(group, scale_grads, p, state) + + if numel == 1: + # For parameters with 1 element we just use regular Adam. + # Updates delta. + self._step_scalar(group, p, state) + else: + self._step(group, p, state) + + state["step"] = step + 1 + + def _size_update( + self, group: dict, scale_grads: Tensor, p: Tensor, state: dict + ) -> None: + """ + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. + + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. + p: The parameter to update + state: The state-dict of p + """ + + param_rms = state["param_rms"] + beta1, beta2 = group["betas"] + size_lr = group["lr"] * group["scalar_lr_scale"] + param_min_rms = group["param_min_rms"] + param_max_rms = group["param_max_rms"] + eps = group["eps"] + step = state["step"] + batch_size = p.shape[0] + + size_update_period = scale_grads.shape[0] + # correct beta2 for the size update period: we will have + # faster decay at this level. + beta2_corr = beta2**size_update_period + + scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) + scale_exp_avg_sq.mul_(beta2_corr).add_( + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) + + # The 1st time we reach here is when size_step == 1. + size_step = (step + 1) // size_update_period + bias_correction2 = 1 - beta2_corr**size_step + # we don't bother with bias_correction1; this will help prevent divergence + # at the start of training. + + denom = scale_exp_avg_sq.sqrt() + eps + + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) + + is_too_small = param_rms < param_min_rms + + # when the param gets too small, just don't shrink it any further. + scale_step.masked_fill_(is_too_small, 0.0) + + # and ensure the parameter rms after update never exceeds param_max_rms. + # We have to look at the trained model for parameters at or around the + # param_max_rms, because sometimes they can indicate a problem with the + # topology or settings. + scale_step = torch.minimum(scale_step, + (param_max_rms - param_rms) / param_rms) + + delta = state["delta"] + # the factor of (1-beta1) relates to momentum. + delta.add_(p * scale_step, alpha=(1 - beta1)) + + def _step(self, group: dict, p: Tensor, state: dict): + """ + This function does the core update of self.step(), in the case where the members of + the batch have more than 1 element. + + Args: + group: A dict which will be used to look up configuration values + p: The parameter to be updated + grad: The grad of p + state: The state-dict corresponding to parameter p + + This function modifies p. + """ + grad = p.grad + lr = group["lr"] + beta1, beta2 = group["betas"] + eps = group["eps"] + param_min_rms = group["param_min_rms"] + step = state["step"] + + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + + this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) + bias_correction2 = 1 - beta2 ** (this_step + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + + denom = exp_avg_sq.sqrt() + denom += eps + grad = grad / denom + + alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) + + delta = state["delta"] + delta.add_(grad * alpha) + p.add_(delta) + + def _step_scalar(self, group: dict, p: Tensor, state: dict): + """ + A simplified form of the core update for scalar tensors, where we cannot get a good + estimate of the parameter rms. + """ + beta1, beta2 = group["betas"] + scalar_max = group["scalar_max"] + eps = group["eps"] + lr = group["lr"] * group["scalar_lr_scale"] + grad = p.grad + + exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # bias_correction2 is like in Adam. Don't bother with bias_correction1; + # slower update at the start will help stability anyway. + bias_correction2 = 1 - beta2 ** (state["step"] + 1) + denom = (exp_avg_sq / bias_correction2).sqrt() + eps + + delta = state["delta"] + delta.add_(grad / denom, alpha=-lr * (1 - beta1)) + p.clamp_(min=-scalar_max, max=scalar_max) + p.add_(delta) + + +class LRScheduler(object): + """ + Base-class for learning rate schedulers where the learning-rate depends on both the + batch and the epoch. + """ + + def __init__(self, optimizer: Optimizer, verbose: bool = False): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + self.optimizer = optimizer + self.verbose = verbose + + for group in optimizer.param_groups: + group.setdefault("base_lr", group["lr"]) + + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] + + self.epoch = 0 + self.batch = 0 + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + "base_lrs": self.base_lrs, + "epoch": self.epoch, + "batch": self.batch, + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler. Will be a list of float.""" + return self._last_lr + + def get_lr(self): + # Compute list of learning rates from self.epoch and self.batch and + # self.base_lrs; this must be overloaded by the user. + # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] + raise NotImplementedError + + def step_batch(self, batch: Optional[int] = None) -> None: + # Step the batch index, or just set it. If `batch` is specified, it + # must be the batch index from the start of training, i.e. summed over + # all epochs. + # You can call this in any order; if you don't provide 'batch', it should + # of course be called once per batch. + if batch is not None: + self.batch = batch + else: + self.batch = self.batch + 1 + self._set_lrs() + + def step_epoch(self, epoch: Optional[int] = None): + # Step the epoch index, or just set it. If you provide the 'epoch' arg, + # you should call this at the start of the epoch; if you don't provide the 'epoch' + # arg, you should call it at the end of the epoch. + if epoch is not None: + self.epoch = epoch + else: + self.epoch = self.epoch + 1 + self._set_lrs() + + def _set_lrs(self): + values = self.get_lr() + assert len(values) == len(self.optimizer.param_groups) + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def print_lr(self, is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + logging.info( + f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" + f" of group {group} to {lr:.4e}." + ) + + +class Eden(LRScheduler): + """ + Eden scheduler. + The basic formula (before warmup) is: + lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * + (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup + where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches + and then stays constant at 1. + + + E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + lr_epochs: the number of epochs after which we start significantly + decreasing the learning rate, suggest 6 if you plan to do e.g. + 20 to 40 epochs, but may need smaller number if dataset is huge + and you will do few epochs. + """ + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + lr_epochs: Union[int, float], + warmup_batches: Union[int, float] = 500.0, + warmup_start: float = 0.5, + verbose: bool = False, + ): + super(Eden, self).__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.lr_epochs = lr_epochs + self.warmup_batches = warmup_batches + + assert 0.0 <= warmup_start <= 1.0, warmup_start + self.warmup_start = warmup_start + + def get_lr(self): + factor = ( + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + ) ** -0.25 * ( + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ) + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) + # else 0.5 + 0.5 * (self.batch / self.warmup_batches) + ) + + return [x * factor * warmup_factor for x in self.base_lrs] + + +def _test_eden(): + m = torch.nn.Linear(100, 100) + optim = ScaledAdam(m.parameters(), lr=0.03) + + scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(20): + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + + logging.info(f"last lr = {scheduler.get_last_lr()}") + logging.info(f"state dict = {scheduler.state_dict()}") + + +# This is included mostly as a baseline for ScaledAdam. +class Eve(Optimizer): + """ + Implements Eve algorithm. This is a modified version of AdamW with a special + way of setting the weight-decay / shrinkage-factor, which is designed to make the + rms of the parameters approach a particular target_rms (default: 0.1). This is + for use with networks with 'scaled' versions of modules (see scaling.py), which + will be close to invariant to the absolute scale on the parameter matrix. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Eve is unpublished so far. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 3e-4; + this value means that the weight would decay significantly after + about 3k minibatches. Is not multiplied by learning rate, but + is conditional on RMS-value of parameter being > target_rms. + target_rms (float, optional): target root-mean-square value of + parameters, if they fall below this we will stop applying weight decay. + + + .. _Adam: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.98), + eps=1e-8, + weight_decay=1e-3, + target_rms=0.1, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0 <= weight_decay <= 0.1: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0 < target_rms <= 10.0: + raise ValueError("Invalid target_rms value: {}".format(target_rms)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + target_rms=target_rms, + ) + super(Eve, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Eve, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError("AdamW does not support sparse gradients") + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + beta1, beta2 = group["betas"] + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + group["eps"] + ) + + step_size = group["lr"] / bias_correction1 + target_rms = group["target_rms"] + weight_decay = group["weight_decay"] + + if p.numel() > 1: + # avoid applying this weight-decay on "scaling factors" + # (which are scalar). + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + p.mul_(1 - (weight_decay * is_above_target_rms)) + + p.addcdiv_(exp_avg, denom, value=-step_size) + + if random.random() < 0.0005: + step = (exp_avg / denom) * step_size + logging.info( + f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" + ) + + return loss + + +def _test_scaled_adam(hidden_dim: int): + import timeit + + from scaling import ScaledLinear + + E = 100 + B = 4 + T = 2 + logging.info("in test_eve_cain") + # device = torch.device('cuda') + device = torch.device("cpu") + dtype = torch.float32 + + fix_random_seed(42) + # these input_magnitudes and output_magnitudes are to test that + # Abel is working as we expect and is able to adjust scales of + # different dims differently. + input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + + for iter in [1, 0]: + fix_random_seed(42) + Linear = torch.nn.Linear if iter == 0 else ScaledLinear + + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] + + if iter == 0: + optim = Eve(m.parameters(), lr=0.003) + elif iter == 1: + optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) + + start = timeit.default_timer() + avg_loss = 0.0 + for epoch in range(180): + scheduler.step_epoch() + # if epoch == 100 and iter in [2,3]: + # optim.reset_speedup() # check it doesn't crash. + + # if epoch == 130: + # opts = diagnostics.TensorDiagnosticOptions( + # 2 ** 22 + # ) # allow 4 megabytes per sub-module + # diagnostic = diagnostics.attach_diagnostics(m, opts) + + for n, (x, y) in enumerate(train_pairs): + y_out = m(x) + loss = ((y_out - y) ** 2).mean() * 100.0 + if epoch == 0 and n == 0: + avg_loss = loss.item() + else: + avg_loss = 0.98 * avg_loss + 0.02 * loss.item() + if n == 0 and epoch % 5 == 0: + # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + lr = scheduler.get_last_lr()[0] + logging.info( + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" + ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + loss.log().backward() + optim.step() + optim.zero_grad() + scheduler.step_batch() + + # diagnostic.print_diagnostics() + + stop = timeit.default_timer() + logging.info(f"Iter={iter}, Time taken: {stop - start}") + + logging.info(f"last lr = {scheduler.get_last_lr()}") + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) + logging.info(f"input_magnitudes = {input_magnitudes}") + logging.info(f"output_magnitudes = {output_magnitudes}") + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + logging.getLogger().setLevel(logging.INFO) + import subprocess + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) + logging.info(s) + import sys + + if len(sys.argv) > 1: + hidden_dim = int(sys.argv[1]) + else: + hidden_dim = 200 + + _test_scaled_adam(hidden_dim) + _test_eden() diff --git a/egs/swbd/ASR/zipformer/pretrained.py b/egs/swbd/ASR/zipformer/pretrained.py new file mode 100755 index 0000000000..3104b60847 --- /dev/null +++ b/egs/swbd/ASR/zipformer/pretrained.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +Usage of this script: + +- For non-streaming model: + +(1) greedy search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens data/lang_bpe_500/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +- For streaming model: + +(1) greedy search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + + +You can also use `./zipformer/exp/epoch-xx.pt`. + +Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from beam_search import ( + fast_beam_search_one_best, + greedy_search_batch, + modified_beam_search, +) +from export import num_tokens +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.utils import make_pad_mask + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + token_table = k2.SymbolTable.from_file(params.tokens) + + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + + logging.info("Creating model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + # model forward + encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) + + hyps = [] + msg = f"Using {params.method}" + logging.info(msg) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + else: + raise ValueError(f"Unsupported method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + s += f"{filename}:\n{hyp}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/zipformer/pretrained_ctc.py b/egs/swbd/ASR/zipformer/pretrained_ctc.py new file mode 100755 index 0000000000..9dff2e6fc9 --- /dev/null +++ b/egs/swbd/ASR/zipformer/pretrained_ctc.py @@ -0,0 +1,455 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +Usage of this script: + +(1) ctc-decoding +./zipformer/pretrained_ctc.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens data/lang_bpe_500/tokens.txt \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./zipformer/pretrained_ctc.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) nbest-rescoring +./zipformer/pretrained_ctc.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring +./zipformer/pretrained_ctc.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from ctc_decode import get_decoding_params +from export import num_tokens +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a token table, + i.e., lang_dir/tokens.txt, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + nbest n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + whole-lattice n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + token_table = k2.SymbolTable.from_file(params.tokens) + params.vocab_size = num_tokens(token_table) + 1 # +1 for blank + params.blank_id = token_table[""] + assert params.blank_id == 0 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + batch_size = ctc_output.shape[0] + supervision_segments = torch.tensor( + [ + [i, 0, feature_lengths[i].item() // params.subsampling_factor] + for i in range(batch_size) + ], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + max_token_id = params.vocab_size - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = [[token_table[i] for i in ids] for ids in token_ids] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + if params.method == "ctc-decoding": + for filename, hyp in zip(params.sound_files, hyps): + words = "".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/zipformer/profile.py b/egs/swbd/ASR/zipformer/profile.py new file mode 100755 index 0000000000..b460b53389 --- /dev/null +++ b/egs/swbd/ASR/zipformer/profile.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: ./zipformer/profile.py +""" + +import argparse +import logging +import sentencepiece as spm +import torch + +from typing import Tuple +from torch import Tensor, nn + +from icefall.utils import make_pad_mask +from icefall.profiler import get_model_profile +from scaling import BiasNorm +from train import ( + get_encoder_embed, + get_encoder_model, + get_joiner_model, + add_model_arguments, + get_params, +) +from zipformer import BypassModule + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + add_model_arguments(parser) + + return parser + + +def _bias_norm_flops_compute(module, input, output): + assert len(input) == 1, len(input) + # estimate as layer_norm, see icefall/profiler.py + flops = input[0].numel() * 5 + module.__flops__ += int(flops) + + +def _swoosh_module_flops_compute(module, input, output): + # For SwooshL and SwooshR modules + assert len(input) == 1, len(input) + # estimate as swish/silu, see icefall/profiler.py + flops = input[0].numel() + module.__flops__ += int(flops) + + +def _bypass_module_flops_compute(module, input, output): + # For Bypass module + assert len(input) == 2, len(input) + flops = input[0].numel() * 2 + module.__flops__ += int(flops) + + +MODULE_HOOK_MAPPING = { + BiasNorm: _bias_norm_flops_compute, + BypassModule: _bypass_module_flops_compute, +} + + +class Model(nn.Module): + """A Wrapper for encoder, encoder_embed, and encoder_proj""" + + def __init__( + self, + encoder: nn.Module, + encoder_embed: nn.Module, + encoder_proj: nn.Module, + ) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.encoder_proj = encoder_proj + + def forward( + self, feature: Tensor, feature_lens: Tensor + ) -> Tuple[Tensor, Tensor]: + x, x_lens = self.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder( + x, x_lens, src_key_padding_mask + ) + + encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + logits = self.encoder_proj(encoder_out) + + return logits, encoder_out_lens + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + + # We only profile the encoder part + model = Model( + encoder=get_encoder_model(params), + encoder_embed=get_encoder_embed(params), + encoder_proj=get_joiner_model(params).encoder_proj, + ) + model.eval() + model.to(device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # for 30-second input + B, T, D = 1, 3000, 80 + feature = torch.ones(B, T, D, dtype=torch.float32).to(device) + feature_lens = torch.full((B,), T, dtype=torch.int64).to(device) + + flops, params = get_model_profile( + model=model, + args=(feature, feature_lens), + module_hoop_mapping=MODULE_HOOK_MAPPING, + ) + logging.info(f"For the encoder part, params: {params}, flops: {flops}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/swbd/ASR/zipformer/scaling.py b/egs/swbd/ASR/zipformer/scaling.py new file mode 100644 index 0000000000..7c98ef0456 --- /dev/null +++ b/egs/swbd/ASR/zipformer/scaling.py @@ -0,0 +1,1840 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional, Tuple, Union +import logging +import k2 +from torch.cuda.amp import custom_fwd, custom_bwd +import random +import torch +import math +import torch.nn as nn +from torch import Tensor + +def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: + max_value = torch.max(x, y) + diff = torch.abs(x - y) + return max_value + torch.log1p(torch.exp(-diff)) + + +# RuntimeError: Exporting the operator logaddexp to ONNX opset version +# 14 is not supported. Please feel free to request support or submit +# a pull request on PyTorch GitHub. +# +# The following function is to solve the above error when exporting +# models to ONNX via torch.jit.trace() +def logaddexp(x: Tensor, y: Tensor) -> Tensor: + # Caution(fangjun): Put torch.jit.is_scripting() before + # torch.onnx.is_in_onnx_export(); + # otherwise, it will cause errors for torch.jit.script(). + # + # torch.logaddexp() works for both torch.jit.script() and + # torch.jit.trace() but it causes errors for ONNX export. + # + if torch.jit.is_scripting(): + # Note: We cannot use torch.jit.is_tracing() here as it also + # matches torch.onnx.export(). + return torch.logaddexp(x, y) + elif torch.onnx.is_in_onnx_export(): + return logaddexp_onnx(x, y) + else: + # for torch.jit.trace() + return torch.logaddexp(x, y) + +class PiecewiseLinear(object): + """ + Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with + the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] + respectively. + """ + def __init__(self, *args): + assert len(args) >= 1, len(args) + if len(args) == 1 and isinstance(args[0], PiecewiseLinear): + self.pairs = list(args[0].pairs) + else: + self.pairs = [ (float(x), float(y)) for x,y in args ] + for (x,y) in self.pairs: + assert isinstance(x, (float, int)), type(x) + assert isinstance(y, (float, int)), type(y) + + for i in range(len(self.pairs) - 1): + assert self.pairs[i + 1][0] > self.pairs[i][0], (i, self.pairs[i], self.pairs[i + 1]) + + def __str__(self): + # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' + return f'PiecewiseLinear({str(self.pairs)[1:-1]})' + + def __call__(self, x): + if x <= self.pairs[0][0]: + return self.pairs[0][1] + elif x >= self.pairs[-1][0]: + return self.pairs[-1][1] + else: + cur_x, cur_y = self.pairs[0] + for i in range(1, len(self.pairs)): + next_x, next_y = self.pairs[i] + if x >= cur_x and x <= next_x: + return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) + cur_x, cur_y = next_x, next_y + assert False + + def __mul__(self, alpha): + return PiecewiseLinear( + * [(x, y * alpha) for x, y in self.pairs]) + + def __add__(self, x): + if isinstance(x, (float, int)): + return PiecewiseLinear( + * [(p[0], p[1] + x) for p in self.pairs]) + s, x = self.get_common_basis(x) + return PiecewiseLinear( + * [(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]) + + def max(self, x): + if isinstance(x, (float, int)): + x = PiecewiseLinear( (0, x) ) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear( + * [(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) + + def min(self, x): + if isinstance(x, float) or isinstance(x, int): + x = PiecewiseLinear( (0, x) ) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear( + * [ (sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) + + def __eq__(self, other): + return self.pairs == other.pairs + + def get_common_basis(self, + p: 'PiecewiseLinear', + include_crossings: bool = False): + """ + Returns (self_mod, p_mod) which are equivalent piecewise linear + functions to self and p, but with the same x values. + + p: the other piecewise linear function + include_crossings: if true, include in the x values positions + where the functions indicate by this and p crosss. + """ + assert isinstance(p, PiecewiseLinear), type(p) + + # get sorted x-values without repetition. + x_vals = sorted(set([ x for x, _ in self.pairs ] + [ x for x, _ in p.pairs ])) + y_vals1 = [ self(x) for x in x_vals ] + y_vals2 = [ p(x) for x in x_vals ] + + if include_crossings: + extra_x_vals = [] + for i in range(len(x_vals) - 1): + if (y_vals1[i] > y_vals2[i]) != (y_vals1[i+1] > y_vals2[i+1]): + # if the two lines in this subsegment potentially cross each other.. + diff_cur = abs(y_vals1[i] - y_vals2[i]) + diff_next = abs(y_vals1[i+1] - y_vals2[i+1]) + # `pos`, between 0 and 1, gives the relative x position, + # with 0 being x_vals[i] and 1 being x_vals[i+1]. + pos = diff_cur / (diff_cur + diff_next) + extra_x_val = x_vals[i] + pos * (x_vals[i+1] - x_vals[i]) + extra_x_vals.append(extra_x_val) + if len(extra_x_vals) > 0: + x_vals = sorted(set(x_vals + extra_x_vals)) + y_vals1 = [ self(x) for x in x_vals ] + y_vals2 = [ p(x) for x in x_vals ] + return ( PiecewiseLinear(* zip(x_vals, y_vals1)), + PiecewiseLinear(* zip(x_vals, y_vals2)) ) + + +class ScheduledFloat(torch.nn.Module): + """ + This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); + it does not have a working forward() function. You are supposed to cast it to float, as + in, float(parent_module.whatever), and use it as something like a dropout prob. + + It is a floating point value whose value changes depending on the batch count of the + training loop. It is a piecewise linear function where you specify the (x,y) pairs + in sorted order on x; x corresponds to the batch index. For batch-index values before the + first x or after the last x, we just use the first or last y value. + + Example: + self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) + + `default` is used when self.batch_count is not set or not in training mode or in + torch.jit scripting mode. + """ + def __init__(self, + *args, + default: float = 0.0): + super().__init__() + # self.batch_count and self.name will be written to in the training loop. + self.batch_count = None + self.name = None + self.default = default + self.schedule = PiecewiseLinear(*args) + + def extra_repr(self) -> str: + return f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}' + + def __float__(self): + batch_count = self.batch_count + if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): + return float(self.default) + else: + ans = self.schedule(self.batch_count) + if random.random() < 0.0002: + logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}") + return ans + + def __add__(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule + x, + default=self.default) + else: + return ScheduledFloat(self.schedule + x.schedule, + default=self.default+x.default) + + def max(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule.max(x), + default=self.default) + else: + return ScheduledFloat(self.schedule.max(x.schedule), + default=max(self.default, x.default)) + + +FloatLike = Union[float, ScheduledFloat] + + +def random_cast_to_half(x: Tensor, + min_abs: float = 5.0e-06) -> Tensor: + """ + A randomized way of casting a floating point value to half precision. + """ + if x.dtype == torch.float16: + return x + x_abs = x.abs() + is_too_small = (x_abs < min_abs) + # for elements where is_too_small is true, random_val will contain +-min_abs with + # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, + # for those elements]. + random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) + return torch.where(is_too_small, random_val, x).to(torch.float16) + + +class CutoffEstimator: + """ + Estimates cutoffs of an arbitrary numerical quantity such that a specified + proportion of items will be above the cutoff on average. + + p is the proportion of items that should be above the cutoff. + """ + def __init__(self, p: float): + self.p = p + # total count of items + self.count = 0 + # total count of items that were above the cutoff + self.count_above = 0 + # initial cutoff value + self.cutoff = 0 + + def __call__(self, x: float) -> bool: + """ + Returns true if x is above the cutoff. + """ + ans = (x > self.cutoff) + self.count += 1 + if ans: + self.count_above += 1 + cur_p = self.count_above / self.count + delta_p = cur_p - self.p + if (delta_p > 0) == ans: + q = abs(delta_p) + self.cutoff = x * q + self.cutoff * (1-q) + return ans + + +class SoftmaxFunction(torch.autograd.Function): + """ + Tries to handle half-precision derivatives in a randomized way that should + be more accurate for training than the default behavior. + """ + @staticmethod + def forward(ctx, x: Tensor, dim: int): + ans = x.softmax(dim=dim) + # if x dtype is float16, x.softmax() returns a float32 because + # (presumably) that op does not support float16, and autocast + # is enabled. + if torch.is_autocast_enabled(): + ans = ans.to(torch.float16) + ctx.save_for_backward(ans) + ctx.x_dtype = x.dtype + ctx.dim = dim + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor): + ans, = ctx.saved_tensors + with torch.cuda.amp.autocast(enabled=False): + ans_grad = ans_grad.to(torch.float32) + ans = ans.to(torch.float32) + x_grad = ans_grad * ans + x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) + return x_grad, None + + +def softmax(x: Tensor, dim: int): + if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing(): + return x.softmax(dim=dim) + + return SoftmaxFunction.apply(x, dim) + + +class MaxEigLimiterFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float) -> Tensor: + ctx.channel_dim = channel_dim + ctx.grad_scale = grad_scale + ctx.save_for_backward(x.detach(), + coeffs.detach(), + direction.detach()) + return x + + @staticmethod + def backward(ctx, x_grad, *args): + with torch.enable_grad(): + (x_orig, coeffs, new_direction) = ctx.saved_tensors + x_orig.requires_grad = True + num_channels = x_orig.shape[ctx.channel_dim] + x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) + new_direction.requires_grad = False + x = x - x.mean(dim=0) + x_var = (x ** 2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual ** 2).mean() + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. This is to be minimized. + variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) + variance_proportion.backward() + x_orig_grad = x_orig.grad + x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) + return x_grad + x_extra_grad.detach(), None, None, None, None + + +class BiasNormFunction(torch.autograd.Function): + # This computes: + # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() + # return x * scales + # (after unsqueezing the bias), but it does it in a memory-efficient way so that + # it can just store the returned value (chances are, this will also be needed for + # some other reason, related to the next operation, so we can save memory). + @staticmethod + def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int, + store_output_for_backprop: bool) -> Tensor: + assert bias.ndim == 1 + if channel_dim < 0: + channel_dim = channel_dim + x.ndim + ctx.store_output_for_backprop = store_output_for_backprop + ctx.channel_dim = channel_dim + for _ in range(channel_dim + 1, x.ndim): + bias = bias.unsqueeze(-1) + scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp() + ans = x * scales + ctx.save_for_backward(ans.detach() if store_output_for_backprop else x, + scales.detach(), bias.detach(), log_scale.detach()) + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tensor: + ans_or_x, scales, bias, log_scale = ctx.saved_tensors + if ctx.store_output_for_backprop: + x = ans_or_x / scales + else: + x = ans_or_x + x = x.detach() + x.requires_grad = True + bias.requires_grad = True + log_scale.requires_grad = True + with torch.enable_grad(): + # recompute scales from x, bias and log_scale. + scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp() + ans = x * scales + ans.backward(gradient=ans_grad) + return x.grad, bias.grad.flatten(), log_scale.grad, None, None + + +class BiasNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + Instead, we give the BiasNorm a trainable bias that it can use when + computing the scale for normalization. We also give it a (scalar) + trainable scale on the output. + + + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interpreted as an offset from the input's ndim if negative. + This is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + log_scale: the initial log-scale that we multiply the output by; this + is learnable. + log_scale_min: FloatLike, minimum allowed value of log_scale + log_scale_max: FloatLike, maximum allowed value of log_scale + store_output_for_backprop: only possibly affects memory use; recommend + to set to True if you think the output of this module is more likely + than the input of this module to be required to be stored for the + backprop. + """ + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + log_scale: float = 1.0, + log_scale_min: float = -1.5, + log_scale_max: float = 1.5, + store_output_for_backprop: bool = False + ) -> None: + super(BiasNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.log_scale = nn.Parameter(torch.tensor(log_scale)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + + self.log_scale_min = log_scale_min + self.log_scale_max = log_scale_max + + self.store_output_for_backprop = store_output_for_backprop + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + channel_dim = self.channel_dim + if channel_dim < 0: + channel_dim += x.ndim + bias = self.bias + for _ in range(channel_dim + 1, x.ndim): + bias = bias.unsqueeze(-1) + scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * + self.log_scale.exp()) + return x * scales + + log_scale = limit_param_value(self.log_scale, + min=float(self.log_scale_min), + max=float(self.log_scale_max), + training=self.training) + + return BiasNormFunction.apply(x, self.bias, log_scale, + self.channel_dim, + self.store_output_for_backprop) + + +def ScaledLinear(*args, + initial_scale: float = 1.0, + **kwargs) -> nn.Linear: + """ + Behaves like a constructor of a modified version of nn.Linear + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Linear(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) + return ans + + +def ScaledConv1d(*args, + initial_scale: float = 1.0, + **kwargs) -> nn.Conv1d: + """ + Behaves like a constructor of a modified version of nn.Conv1d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv1d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) + return ans + + +def ScaledConv2d(*args, + initial_scale: float = 1.0, + **kwargs) -> nn.Conv2d: + """ + Behaves like a constructor of a modified version of nn.Conv2d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False, but: + NO PADDING-RELATED ARGS. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv2d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) + return ans + + +class ChunkCausalDepthwiseConv1d(torch.nn.Module): + """ + Behaves like a depthwise 1d convolution, except that it is causal in + a chunkwise way, as if we had a block-triangular attention mask. + The chunk size is provided at test time (it should probably be + kept in sync with the attention mask). + + This has a little more than twice the parameters of a conventional + depthwise conv1d module: we implement it by having one + depthwise convolution, of half the width, that is causal (via + right-padding); and one depthwise convolution that is applied only + within chunks, that we multiply by a scaling factor which depends + on the position within the chunk. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + def __init__(self, + channels: int, + kernel_size: int, + initial_scale: float = 1.0, + bias: bool = True): + super().__init__() + assert kernel_size % 2 == 1 + + half_kernel_size = (kernel_size + 1) // 2 + # will pad manually, on one side. + self.causal_conv = nn.Conv1d(in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=half_kernel_size, + padding=0, + bias=True) + + self.chunkwise_conv = nn.Conv1d(in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + bias=bias) + + # first row is correction factors added to the scale near the left edge of the chunk, + # second row is correction factors added to the scale near the right edge of the chunk, + # both of these are added to a default scale of 1.0. + self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) + self.kernel_size = kernel_size + + with torch.no_grad(): + self.causal_conv.weight[:] *= initial_scale + self.chunkwise_conv.weight[:] *= initial_scale + if bias: + torch.nn.init.uniform_(self.causal_conv.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) + + def forward(self, + x: Tensor, + chunk_size: int = -1) -> Tensor: + """ + Forward function. Args: + x: a Tensor of shape (batch_size, channels, seq_len) + chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. + """ + (batch_size, num_channels, seq_len) = x.shape + + # half_kernel_size = self.kernel_size + 1 // 2 + # left_pad is half_kernel_size - 1 where half_kernel_size is the size used + # in the causal conv. It's the amount by which we must pad on the left, + # to make the convolution causal. + left_pad = self.kernel_size // 2 + + if chunk_size < 0 or chunk_size > seq_len: + chunk_size = seq_len + right_pad = -seq_len % chunk_size + + x = torch.nn.functional.pad(x, (left_pad, right_pad)) + + x_causal = self.causal_conv(x[..., :left_pad + seq_len]) + assert x_causal.shape == (batch_size, num_channels, seq_len) + + x_chunk = x[..., left_pad:] + num_chunks = x_chunk.shape[2] // chunk_size + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) + x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(batch_size * num_chunks, + num_channels, chunk_size) + x_chunk = self.chunkwise_conv(x_chunk) # does not change shape + + chunk_scale = self._get_chunk_scale(chunk_size) + + x_chunk = x_chunk * chunk_scale + x_chunk = x_chunk.reshape(batch_size, num_chunks, + num_channels, chunk_size).permute(0, 2, 1, 3) + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[..., :seq_len] + + return x_chunk + x_causal + + def _get_chunk_scale(self, chunk_size: int): + """Returns tensor of shape (num_channels, chunk_size) that will be used to + scale the output of self.chunkwise_conv.""" + left_edge = self.chunkwise_conv_scale[0] + right_edge = self.chunkwise_conv_scale[1] + if chunk_size < self.kernel_size: + left_edge = left_edge[:, :chunk_size] + right_edge = right_edge[:, -chunk_size:] + else: + t = chunk_size - self.kernel_size + channels = left_edge.shape[0] + pad = torch.zeros(channels, t, + device=left_edge.device, + dtype=left_edge.dtype) + left_edge = torch.cat((left_edge, pad), dim=-1) + right_edge = torch.cat((pad, right_edge), dim=-1) + return 1.0 + (left_edge + right_edge) + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Streaming Forward function. + + Args: + x: a Tensor of shape (batch_size, channels, seq_len) + cache: cached left context of shape (batch_size, channels, left_pad) + """ + (batch_size, num_channels, seq_len) = x.shape + + # left_pad is half_kernel_size - 1 where half_kernel_size is the size used + # in the causal conv. It's the amount by which we must pad on the left, + # to make the convolution causal. + left_pad = self.kernel_size // 2 + + # Pad cache + assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad) + x = torch.cat([cache, x], dim=2) + # Update cache + cache = x[..., -left_pad:] + + x_causal = self.causal_conv(x) + assert x_causal.shape == (batch_size, num_channels, seq_len) + + x_chunk = x[..., left_pad:] + x_chunk = self.chunkwise_conv(x_chunk) # does not change shape + + chunk_scale = self._get_chunk_scale(chunk_size=seq_len) + x_chunk = x_chunk * chunk_scale + + return x_chunk + x_causal, cache + + +class BalancerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + min_mean: float, + max_mean: float, + min_rms: float, + max_rms: float, + grad_scale: float, + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + ctx.save_for_backward(x) + ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) + return x + + @staticmethod + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None, None]: + x, = ctx.saved_tensors + (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config + + try: + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x = x.to(torch.float32) + x = x.detach() + x.requires_grad = True + mean_dims = [ i for i in range(x.ndim) if i != channel_dim ] + uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True) + mean = x.mean(dim=mean_dims, keepdim=True) + stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() + rms = uncentered_var.clamp(min=1.0e-20).sqrt() + + m = mean / stddev + # part of loss that relates to mean / stddev + m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() + + # put a much larger scale on the RMS-max-limit loss, so that if both it and the + # m_loss are violated we fix the RMS loss first. + rms_clamped = rms.clamp(min=min_rms, max=max_rms) + r_loss = (rms_clamped / rms).log().abs() + + loss = (m_loss + r_loss) + + loss.backward(gradient=torch.ones_like(loss)) + loss_grad = x.grad + loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20) + + loss_grad = loss_grad * (grad_scale / loss_grad_rms) + + x_grad_float = x_grad.to(torch.float32) + # scale each element of loss_grad by the absolute value of the corresponding + # element of x_grad, which we view as a noisy estimate of its magnitude for that + # (frame and dimension). later we can consider factored versions. + x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) + x_grad = x_grad_mod.to(x_grad.dtype) + except Exception as e: + logging.info(f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue.") + + return x_grad, None, None, None, None, None, None + + +class Balancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), above which we start to modify the derivatives. + scale_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_abs and max_abs + are violated. + min_abs: the minimum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + max_abs: the maximum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + prob: determines the minimum probability with which we modify the + gradients for the {min,max}_positive and {min,max}_abs constraints, + on each forward(). This is done randomly to prevent all layers + from doing it at the same time. + """ + def __init__( + self, + num_channels: int, + channel_dim: int, + min_positive: FloatLike = 0.05, + max_positive: FloatLike = 0.95, + min_abs: FloatLike = 0.2, + max_abs: FloatLike = 100.0, + grad_scale: FloatLike = 0.04, + prob: Optional[FloatLike] = None, + ): + super().__init__() + + if prob is None: + prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) + self.prob = prob + # 5% of the time we will return and do nothing because memory usage is + # too high. + self.mem_cutoff = CutoffEstimator(0.05) + + # actually self.num_channels is no longer needed except for an assertion. + self.num_channels = num_channels + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.min_abs = min_abs + self.max_abs = max_abs + self.grad_scale = grad_scale + + def forward(self, x: Tensor) -> Tensor: + if (torch.jit.is_scripting() or not x.requires_grad or + (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))): + return _no_op(x) + + prob = float(self.prob) + if random.random() < prob: + # The following inner-functions convert from the way we historically specified + # these limitations, as limits on the absolute value and the proportion of positive + # values, to limits on the RMS value and the (mean / stddev). + def _abs_to_rms(x): + # for normally distributed data, if the expected absolute value is x, the + # expected rms value will be sqrt(pi/2) * x. + return 1.25331413732 * x + + def _proportion_positive_to_mean(x): + def _atanh(x): + eps = 1.0e-10 + # eps is to prevent crashes if x is exactly 0 or 1. + # we'll just end up returning a fairly large value. + return (math.log (1+x+eps) - math.log (1-x+eps)) / 2. + + def _approx_inverse_erf(x): + # 1 / (sqrt(pi) * ln(2)), + # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions + # this approximation is extremely crude and gets progressively worse for + # x very close to -1 or +1, but we mostly care about the "middle" region + # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772, + # and math.erf(0.0407316414078772) = 0.045935330944660666, + # which is pretty close to 0.05. + return 0.8139535143 * _atanh(x) + # first convert x from the range 0..1 to the range -1..1 which the error + # function returns + x = -1 + (2 * x) + return _approx_inverse_erf(x) + + min_mean = _proportion_positive_to_mean(float(self.min_positive)) + max_mean = _proportion_positive_to_mean(float(self.max_positive)) + min_rms = _abs_to_rms(float(self.min_abs)) + max_rms = _abs_to_rms(float(self.max_abs)) + grad_scale = float(self.grad_scale) + + assert x.shape[self.channel_dim] == self.num_channels + + return BalancerFunction.apply( + x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim + ) + else: + return _no_op(x) + + +def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float, + name: str = None) -> Tensor: + """ + Returns x unmodified, but in backprop will put a penalty for the excess of + the absolute values of elements of x over the limit "limit". E.g. if + limit == 10.0, then if x has any values over 10 it will get a penalty. + + Caution: the value of this penalty will be affected by grad scaling used + in automatic mixed precision training. For this reasons we use this, + it shouldn't really matter, or may even be helpful; we just use this + to disallow really implausible values of scores to be given to softmax. + + The name is for randomly printed debug info. + """ + x_sign = x.sign() + over_limit = (x.abs() - limit) > 0 + # The following is a memory efficient way to penalize the absolute values of + # x that's over the limit. (The memory efficiency comes when you think + # about which items torch needs to cache for the autograd, and which ones it + # can throw away). The numerical value of aux_loss as computed here will + # actually be larger than it should be, by limit * over_limit.sum(), but it + # has the same derivative as the real aux_loss which is penalty * (x.abs() - + # limit).relu(). + aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) + # note: we don't do sum() here on aux)_loss, but it's as if we had done + # sum() due to how with_loss() works. + x = with_loss(x, aux_loss, name) + # you must use x for something, or this will be ineffective. + return x + + +def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. + if x.ndim == 2: + return x.diag() + else: + (batch, dim, dim) = x.shape + x = x.reshape(batch, dim * dim) + x = x[:, ::dim+1] + assert x.shape == (batch, dim) + return x + + +def _whitening_metric(x: Tensor, + num_groups: int): + """ + Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of + of the centered feature covariance are the same within each group's covariance matrix + and also between groups. + Args: + x: a Tensor of shape (*, num_channels) + num_groups: the number of groups of channels, a number >=1 that divides num_channels + Returns: + Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and + greater than 1.0 otherwise. + """ + assert x.dtype != torch.float16 + x = x.reshape(-1, x.shape[-1]) + (num_frames, num_channels) = x.shape + assert num_channels % num_groups == 0 + channels_per_group = num_channels // num_groups + x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) + # x now has shape (num_groups, num_frames, channels_per_group) + # subtract the mean so we use the centered, not uncentered, covariance. + # My experience has been that when we "mess with the gradients" like this, + # it's better not do anything that tries to move the mean around, because + # that can easily cause instability. + x = x - x.mean(dim=1, keepdim=True) + # x_covar: (num_groups, channels_per_group, channels_per_group) + x_covar = torch.matmul(x.transpose(1, 2), x) + x_covar_mean_diag = _diag(x_covar).mean() + # the following expression is what we'd get if we took the matrix product + # of each covariance and measured the mean of its trace, i.e. + # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). + x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) + # this metric will be >= 1.0; the larger it is, the less 'white' the data was. + metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + return metric + + +class WhiteningPenaltyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, + x: Tensor, + module: nn.Module) -> Tensor: + ctx.save_for_backward(x) + ctx.module = module + return x + + @staticmethod + def backward(ctx, + x_grad: Tensor): + x_orig, = ctx.saved_tensors + w = ctx.module + + try: + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x_detached = x_orig.to(torch.float32).detach() + x_detached.requires_grad = True + + metric = _whitening_metric(x_detached, w.num_groups) + + if random.random() < 0.005 or __name__ == "__main__": + logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}") + + if metric < float(w.whitening_limit): + w.prob = w.min_prob + return x_grad, None + else: + w.prob = w.max_prob + metric.backward() + penalty_grad = x_detached.grad + scale = w.grad_scale * (x_grad.to(torch.float32).norm() / + (penalty_grad.norm() + 1.0e-20)) + penalty_grad = penalty_grad * scale + return x_grad + penalty_grad.to(x_grad.dtype), None + except Exception as e: + logging.info(f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue.") + return x_grad, None + + +class Whiten(nn.Module): + def __init__( + self, + num_groups: int, + whitening_limit: FloatLike, + prob: Union[float, Tuple[float,float]], + grad_scale: FloatLike): + """ + Args: + num_groups: the number of groups to divide the channel dim into before + whitening. We will attempt to make the feature covariance + within each group, after mean subtraction, as "white" as possible, + while having the same trace across all groups. + whitening_limit: a value greater than 1.0, that dictates how much + freedom we have to violate the constraints. 1.0 would mean perfectly + white, with exactly the same trace across groups; larger values + give more freedom. E.g. 2.0. + prob: the probability with which we apply the gradient modification + (also affects the grad scale). May be supplied as a float, + or as a pair (min_prob, max_prob) + + grad_scale: determines the scale on the gradient term from this object, + relative to the rest of the gradient on the attention weights. + E.g. 0.02 (you may want to use smaller values than this if prob is large) + """ + super(Whiten, self).__init__() + assert num_groups >= 1 + assert float(whitening_limit) >= 1 + assert grad_scale >= 0 + self.num_groups = num_groups + self.whitening_limit = whitening_limit + self.grad_scale = grad_scale + + if isinstance(prob, float): + prob = (prob, prob) + (self.min_prob, self.max_prob) = prob + assert 0 < self.min_prob <= self.max_prob <= 1 + self.prob = self.max_prob + self.name = None # will be set in training loop + + def forward(self, + x: Tensor) -> Tensor: + """ + In the forward pass, this function just returns the input unmodified. + In the backward pass, it will modify the gradients to ensure that the + distribution in each group has close to (lambda times I) as the covariance + after mean subtraction, with the same lambda across groups. + For whitening_limit > 1, there will be more freedom to violate this + constraint. + + Args: + x: the input of shape (*, num_channels) + + Returns: + x, unmodified. You should make sure + you use the returned value, or the graph will be freed + and nothing will happen in backprop. + """ + grad_scale = float(self.grad_scale) + if not x.requires_grad or random.random() > self.prob or grad_scale == 0: + return _no_op(x) + else: + return WhiteningPenaltyFunction.apply(x, self) + + +class WithLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, y: Tensor, name: str): + ctx.y_shape = y.shape + if random.random() < 0.002 and name is not None: + loss_sum = y.sum().item() + logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor): + return ans_grad, torch.ones(ctx.y_shape, + dtype=ans_grad.dtype, + device=ans_grad.device), None + + +def with_loss(x, y, name): + # returns x but adds y.sum() to the loss function. + return WithLoss.apply(x, y, name) + + +class ScaleGradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, alpha: float) -> Tensor: + ctx.alpha = alpha + return x + + @staticmethod + def backward(ctx, grad: Tensor): + return grad * ctx.alpha, None + + +def scale_grad(x: Tensor, alpha: float): + return ScaleGradFunction.apply(x, alpha) + + +class ScaleGrad(nn.Module): + def __init__(self, alpha: float): + super().__init__() + self.alpha = alpha + + def forward(self, x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return x + return scale_grad(x, self.alpha) + + +class LimitParamValue(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, min: float, max: float): + ctx.save_for_backward(x) + assert max >= min + ctx.min = min + ctx.max = max + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + x, = ctx.saved_tensors + # where x < ctx.min, ensure all grads are negative (this will tend to make + # x more positive). + x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0) + # where x > ctx.max, ensure all grads are positive (this will tend to make + # x more negative). + x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) + return x_grad, None, None + + +def limit_param_value(x: Tensor, + min: float, max: float, + prob: float = 0.6, + training: bool = True): + # You apply this to (typically) an nn.Parameter during training to ensure that its + # (elements mostly) stays within a supplied range. This is done by modifying the + # gradients in backprop. + # It's not necessary to do this on every batch: do it only some of the time, + # to save a little time. + if training and random.random() < prob: + return LimitParamValue.apply(x, min, max) + else: + return x + + +def _no_op(x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x + else: + # a no-op function that will have a node in the autograd graph, + # to avoid certain bugs relating to backward hooks + return x.chunk(1, dim=-1)[0] + + +class Identity(torch.nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return _no_op(x) + + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + if x.dtype == torch.float16: + x = x.to(torch.float32) + + s = torch.sigmoid(x - 1.0) + y = x * s + + if requires_grad: + deriv = (y * (1 - s) + s) + + # notes on derivative of x * sigmoid(x - 1): + # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 + # min \simeq -0.043638. Take floor as -0.044 so it's a lower bund + # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. + # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which + # floors), should be expectation-preserving. + floor = -0.044 + ceil = 1.2 + d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + d, = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.043637 + ceil = 1.2 + + d = (d * ((ceil - floor) / 255.0) + floor) + return y_grad * d + + +class DoubleSwish(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x * torch.sigmoid(x - 1.0) + return DoubleSwishFunction.apply(x) + + +# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates. +class Dropout2(nn.Module): + def __init__(self, p: FloatLike): + super().__init__() + self.p = p + + def forward(self, x: Tensor) -> Tensor: + return torch.nn.functional.dropout(x, + p=float(self.p), + training=self.training) + + +class MulForDropout3(torch.autograd.Function): + # returns (x * y * alpha) where alpha is a float and y doesn't require + # grad and is zero-or-one. + @staticmethod + @custom_fwd + def forward(ctx, x, y, alpha): + assert not y.requires_grad + ans = x * y * alpha + ctx.save_for_backward(ans) + ctx.alpha = alpha + return ans + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad): + ans, = ctx.saved_tensors + x_grad = ctx.alpha * ans_grad * (ans != 0) + return x_grad, None, None + + +# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates, +# and it lets you choose one dimension to share the dropout mask over +class Dropout3(nn.Module): + def __init__(self, p: FloatLike, shared_dim: int): + super().__init__() + self.p = p + self.shared_dim = shared_dim + + def forward(self, x: Tensor) -> Tensor: + p = float(self.p) + if not self.training or p == 0: + return _no_op(x) + scale = 1.0 / (1 - p) + rand_shape = list(x.shape) + rand_shape[self.shared_dim] = 1 + mask = torch.rand(*rand_shape, device=x.device) > p + ans = MulForDropout3.apply(x, mask, scale) + return ans + + +class SwooshLFunction(torch.autograd.Function): + """ + swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + if x.dtype == torch.float16: + x = x.to(torch.float32) + + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + + coeff = -0.08 + + with torch.cuda.amp.autocast(enabled=False): + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 + + if not requires_grad: + return y + + y.backward(gradient = torch.ones_like(y)) + + grad = x.grad + floor = coeff + ceil = 1.0 + coeff + 0.005 + + d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + d, = ctx.saved_tensors + # the same constants as used in forward pass. + + coeff = -0.08 + floor = coeff + ceil = 1.0 + coeff + 0.005 + d = (d * ((ceil - floor) / 255.0) + floor) + return (y_grad * d) + + +class SwooshL(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-L activation. + """ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 + if not x.requires_grad: + return k2.swoosh_l_forward(x) + else: + return k2.swoosh_l(x) + # return SwooshLFunction.apply(x) + +class SwooshLOnnx(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-L activation. + """ + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035 + + +class SwooshRFunction(torch.autograd.Function): + """ + swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 + + derivatives are between -0.08 and 0.92. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + + if x.dtype == torch.float16: + x = x.to(torch.float32) + + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + + with torch.cuda.amp.autocast(enabled=False): + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + y = torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 + + if not requires_grad: + return y + y.backward(gradient = torch.ones_like(y)) + + grad = x.grad + floor = -0.08 + ceil = 0.925 + + d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + d, = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.08 + ceil = 0.925 + d = (d * ((ceil - floor) / 255.0) + floor) + return (y_grad * d) + + +class SwooshR(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-R activation. + """ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 + if not x.requires_grad: + return k2.swoosh_r_forward(x) + else: + return k2.swoosh_r(x) + # return SwooshRFunction.apply(x) + +class SwooshROnnx(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-R activation. + """ + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp_onnx(zero, x - 1.) - 0.08 * x - 0.313261687 + + +# simple version of SwooshL that does not redefine the backprop, used in +# ActivationDropoutAndLinearFunction. +def SwooshLForward(x: Tensor): + x_offset = x - 4.0 + log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) + log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) + return log_sum - 0.08 * x - 0.035 + + +# simple version of SwooshR that does not redefine the backprop, used in +# ActivationDropoutAndLinearFunction. +def SwooshRForward(x: Tensor): + x_offset = x - 1.0 + log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) + log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) + return log_sum - 0.08 * x - 0.313261687 + + +class ActivationDropoutAndLinearFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + activation: str, + dropout_p: float, + dropout_shared_dim: Optional[int]): + if dropout_p != 0.0: + dropout_shape = list(x.shape) + if dropout_shared_dim is not None: + dropout_shape[dropout_shared_dim] = 1 + # else it won't be very memory efficient. + dropout_mask = ((1.0 / (1.0 - dropout_p)) * + (torch.rand(*dropout_shape, + device=x.device, dtype=x.dtype) > dropout_p)) + else: + dropout_mask = None + + ctx.save_for_backward(x, weight, bias, dropout_mask) + + ctx.activation = activation + + forward_activation_dict = { + 'SwooshL': k2.swoosh_l_forward, + 'SwooshR': k2.swoosh_r_forward + } + # it will raise a KeyError if this fails. This will be an error. We let it + # propagate to the user. + activation_func = forward_activation_dict[activation] + x = activation_func(x) + if dropout_mask is not None: + x = x * dropout_mask + x = torch.nn.functional.linear(x, weight, bias) + return x + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad: Tensor): + saved = ctx.saved_tensors + (x, weight, bias, dropout_mask) = saved + + forward_and_deriv_activation_dict = { + 'SwooshL': k2.swoosh_l_forward_and_deriv, + 'SwooshR': k2.swoosh_r_forward_and_deriv + } + # the following lines a KeyError if the activation is unrecognized. + # This will be an error. We let it propagate to the user. + func = forward_and_deriv_activation_dict[ctx.activation] + + y, func_deriv = func(x) + if dropout_mask is not None: + y = y * dropout_mask + # now compute derivative of y w.r.t. weight and bias.. + # y: (..., in_channels), ans_grad: (..., out_channels), + (out_channels, in_channels) = weight.shape + + in_channels = y.shape[-1] + g = ans_grad.reshape(-1, out_channels) + weight_deriv = torch.matmul(g.t(), + y.reshape(-1, in_channels)) + y_deriv = torch.matmul(ans_grad, weight) + bias_deriv = None if bias is None else g.sum(dim=0) + x_deriv = y_deriv * func_deriv + if dropout_mask is not None: + # order versus func_deriv does not matter + x_deriv = x_deriv * dropout_mask + + return x_deriv, weight_deriv, bias_deriv, None, None, None + + +class ActivationDropoutAndLinear(torch.nn.Module): + """ + This merges an activation function followed by dropout and then a nn.Linear module; + it does so in a memory efficient way so that it only stores the input to the whole + module. If activation == SwooshL and dropout_shared_dim != None, this will be + equivalent to: + nn.Sequential(SwooshL(), + Dropout3(dropout_p, shared_dim=dropout_shared_dim), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=initial_scale)) + If dropout_shared_dim is None, the dropout would be equivalent to + Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout + mask is smaller. + + Args: + in_channels: number of input channels, e.g. 256 + out_channels: number of output channels, e.g. 256 + bias: if true, have a bias + activation: the activation function, for now just support SwooshL. + dropout_p: the dropout probability or schedule (happens after nonlinearity). + dropout_shared_dim: the dimension, if any, across which the dropout mask is + shared (e.g. the time dimension). If None, this may be less memory + efficient if there are modules before this one that cache the input + for their backprop (e.g. Balancer or Whiten). + """ + def __init__(self, + in_channels: int, + out_channels: int, + bias: bool = True, + activation: str = 'SwooshL', + dropout_p: FloatLike = 0.0, + dropout_shared_dim: Optional[int] = -1, + initial_scale: float = 1.0): + super().__init__() + # create a temporary module of nn.Linear that we'll steal the + # weights and bias from + l = ScaledLinear(in_channels, out_channels, + bias=bias, + initial_scale=initial_scale) + + self.weight = l.weight + # register_parameter properly handles making it a parameter when l.bias + # is None. I think there is some reason for doing it this way rather + # than just setting it to None but I don't know what it is, maybe + # something to do with exporting the module.. + self.register_parameter('bias', l.bias) + + self.activation = activation + self.dropout_p = dropout_p + self.dropout_shared_dim = dropout_shared_dim + + def forward(self, + x: Tensor): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + if self.activation == 'SwooshL': + x = SwooshLForward(x) + elif self.activation == "SwooshR": + x = SwooshRForward(x) + else: + assert False, self.activation + return torch.nn.functional.linear(x, + self.weight, + self.bias) + + return ActivationDropoutAndLinearFunction.apply( + x, self.weight, self.bias, self.activation, + float(self.dropout_p), self.dropout_shared_dim) + + +def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: + if num_channels <= x.shape[-1]: + return x[..., :num_channels] + else: + shape = list(x.shape) + shape[-1] = num_channels - shape[-1] + zeros = torch.zeros(shape, dtype=x.dtype, device=x.device) + return torch.cat((x, zeros), dim=-1) + + +def _test_whiten(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"_test_whiten(): proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + m = Whiten(1, # num_groups + 5.0, # whitening_limit, + prob=1.0, + grad_scale=0.1) # grad_scale + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + +def _test_balancer_sign(): + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) + x = x.detach() + x.requires_grad = True + m = Balancer( + probs.numel(), + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + min_abs=0.0, + prob=1.0, + ) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_balancer_sign: x = ", x) + print("_test_balancer_sign: y grad = ", y_grad) + print("_test_balancer_sign: x grad = ", x.grad) + + +def _test_balancer_magnitude(): + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) + x = x.detach() + x.requires_grad = True + m = Balancer( + magnitudes.numel(), + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + min_abs=0.2, + max_abs=0.7, + prob=1.0, + ) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_balancer_magnitude: x = ", x) + print("_test_balancer_magnitude: y grad = ", y_grad) + print("_test_balancer_magnitude: x grad = ", x.grad) + + +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = DoubleSwish() + + tol = ((1.2-(-0.043637))/255.0) + torch.autograd.gradcheck(m, x, atol=tol) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swooshl_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwooshL() + + tol = (1.0 / 255.0) + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swooshr_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwooshR() + + tol = (1.0 / 255.0) + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_softmax(): + a = torch.randn(2, 10, dtype=torch.float64) + b = a.clone() + a.requires_grad = True + b.requires_grad = True + a.softmax(dim=1)[:,0].sum().backward() + print("a grad = ", a.grad) + softmax(b, dim=1)[:,0].sum().backward() + print("b grad = ", b.grad) + assert torch.allclose(a.grad, b.grad) + + +def _test_piecewise_linear(): + p = PiecewiseLinear( (0, 10.0) ) + for x in [-100, 0, 100]: + assert p(x) == 10.0 + p = PiecewiseLinear( (0, 10.0), (1, 0.0) ) + for x, y in [ (-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0) ]: + print("x, y = ", x, y) + assert p(x) == y, (x, p(x), y) + + q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) + x_vals = [ -1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0 ] + pq = p.max(q) + for x in x_vals: + y1 = max(p(x), q(x)) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + pq = p.min(q) + for x in x_vals: + y1 = min(p(x), q(x)) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + pq = p + q + for x in x_vals: + y1 = p(x) + q(x) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + + +def _test_activation_dropout_and_linear(): + in_channels = 20 + out_channels = 30 + + for bias in [True, False]: + # actually we don't test for dropout_p != 0.0 because forward functions will give + # different answers. This is because we are using the k2 implementation of + # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() + # internally, messing up the random state. + for dropout_p in [0.0]: + for activation in ['SwooshL', 'SwooshR']: + m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(), + Dropout3(p=dropout_p, shared_dim=-1), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=0.5)) + m2 = ActivationDropoutAndLinear(in_channels, out_channels, + bias=bias, initial_scale=0.5, + activation=activation, + dropout_p=dropout_p) + with torch.no_grad(): + m2.weight[:] = m1[2].weight + if bias: + m2.bias[:] = m1[2].bias + # make sure forward gives same result. + x1 = torch.randn(10, in_channels) + x1.requires_grad = True + + # TEMP. + assert torch.allclose(SwooshRFunction.apply(x1), + SwooshRForward(x1), + atol=1.0e-03) + + x2 = x1.clone().detach() + x2.requires_grad = True + seed = 10 + torch.manual_seed(seed) + y1 = m1(x1) + y_grad = torch.randn_like(y1) + y1.backward(gradient=y_grad) + torch.manual_seed(seed) + y2 = m2(x2) + y2.backward(gradient=y_grad) + + print(f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}") + print("y1 = ", y1) + print("y2 = ", y2) + assert torch.allclose(y1, y2, atol=0.02) + assert torch.allclose(m1[2].weight.grad, m2.weight.grad, + atol=1.0e-05) + if bias: + assert torch.allclose(m1[2].bias.grad, m2.bias.grad, + atol=1.0e-05) + print("x1.grad = ", x1.grad) + print("x2.grad = ", x2.grad) + + def isclose(a, b): + # return true if cosine similarity is > 0.9. + return (a * b).sum() > 0.9 * ((a**2).sum() * (b**2).sum()).sqrt() + # the SwooshL() implementation has a noisy gradient due to 1-byte + # storage of it. + assert isclose(x1.grad, x2.grad) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_piecewise_linear() + _test_softmax() + _test_whiten() + _test_balancer_sign() + _test_balancer_magnitude() + _test_double_swish_deriv() + _test_swooshr_deriv() + _test_swooshl_deriv() + _test_activation_dropout_and_linear() diff --git a/egs/swbd/ASR/zipformer/scaling_converter.py b/egs/swbd/ASR/zipformer/scaling_converter.py new file mode 100644 index 0000000000..76622fa129 --- /dev/null +++ b/egs/swbd/ASR/zipformer/scaling_converter.py @@ -0,0 +1,104 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file replaces various modules in a model. +Specifically, ActivationBalancer is replaced with an identity operator; +Whiten is also replaced with an identity operator; +BasicNorm is replaced by a module with `exp` removed. +""" + +import copy +from typing import List, Tuple + +import torch +import torch.nn as nn +from scaling import ( + Balancer, + Dropout3, + ScaleGrad, + SwooshL, + SwooshLOnnx, + SwooshR, + SwooshROnnx, + Whiten, +) +from zipformer import CompactRelPositionalEncoding + + +# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa +# get_submodule was added to nn.Module at v1.9.0 +def get_submodule(model, target): + if target == "": + return model + atoms: List[str] = target.split(".") + mod: torch.nn.Module = model + for item in atoms: + if not hasattr(mod, item): + raise AttributeError( + mod._get_name() + " has no " "attribute `" + item + "`" + ) + mod = getattr(mod, item) + if not isinstance(mod, torch.nn.Module): + raise AttributeError("`" + item + "` is not " "an nn.Module") + return mod + + +def convert_scaled_to_non_scaled( + model: nn.Module, + inplace: bool = False, + is_pnnx: bool = False, + is_onnx: bool = False, +): + """ + Args: + model: + The model to be converted. + inplace: + If True, the input model is modified inplace. + If False, the input model is copied and we modify the copied version. + is_pnnx: + True if we are going to export the model for PNNX. + is_onnx: + True if we are going to export the model for ONNX. + Return: + Return a model without scaled layers. + """ + if not inplace: + model = copy.deepcopy(model) + + d = {} + for name, m in model.named_modules(): + if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)): + d[name] = nn.Identity() + elif is_onnx and isinstance(m, SwooshR): + d[name] = SwooshROnnx() + elif is_onnx and isinstance(m, SwooshL): + d[name] = SwooshLOnnx() + elif is_onnx and isinstance(m, CompactRelPositionalEncoding): + # We want to recreate the positional encoding vector when + # the input changes, so we have to use torch.jit.script() + # to replace torch.jit.trace() + d[name] = torch.jit.script(m) + + for k, v in d.items(): + if "." in k: + parent, child = k.rsplit(".", maxsplit=1) + setattr(get_submodule(model, parent), child, v) + else: + setattr(model, k, v) + + return model diff --git a/egs/swbd/ASR/zipformer/sclite_scoring.py b/egs/swbd/ASR/zipformer/sclite_scoring.py new file mode 100755 index 0000000000..fe0fbbeeea --- /dev/null +++ b/egs/swbd/ASR/zipformer/sclite_scoring.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +# Copyright 2021 Jiayu Du +# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os + +conversational_filler = [ + "UH", + "UHH", + "UM", + "EH", + "MM", + "HM", + "AH", + "HUH", + "HA", + "ER", + "OOF", + "HEE", + "ACH", + "EEE", + "EW", + "MHM", + "HUM", + "AW", + "OH", + "HMM", + "UMM", +] +unk_tags = ["", ""] +switchboard_garbage_utterance_tags = [ + "[LAUGHTER]", + "[NOISE]", + "[VOCALIZED-NOISE]", + "[SILENCE]", +] +non_scoring_words = ( + conversational_filler + unk_tags + switchboard_garbage_utterance_tags +) + + +def asr_text_post_processing(text: str) -> str: + # 1. convert to uppercase + text = text.upper() + + # 2. remove non-scoring words from evaluation + remaining_words = [] + text_split = text.split() + word_to_skip = 0 + for idx, word in enumerate(text_split): + if word_to_skip > 0: + word_to_skip -= 1 + continue + if word in non_scoring_words: + continue + elif word == "CANCELLED": + remaining_words.append("CANCELED") + continue + elif word == "AIRFLOW": + remaining_words.append("AIR") + remaining_words.append("FLOW") + continue + elif word == "PHD": + remaining_words.append("P") + remaining_words.append("H") + remaining_words.append("D") + continue + elif word == "ONTO": + remaining_words.append("ON") + remaining_words.append("TO") + continue + elif word == "DAY" and text_split[idx + 1] == "CARE": + remaining_words.append("DAYCARE") + word_to_skip = 1 + continue + remaining_words.append(word) + + return " ".join(remaining_words) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="This script evaluates GigaSpeech ASR result via" + "SCTK's tool sclite" + ) + parser.add_argument( + "ref", + type=str, + help="sclite's standard transcription(trn) reference file", + ) + parser.add_argument( + "hyp", + type=str, + help="sclite's standard transcription(trn) hypothesis file", + ) + parser.add_argument( + "work_dir", + type=str, + help="working dir", + ) + args = parser.parse_args() + + if not os.path.isdir(args.work_dir): + os.mkdir(args.work_dir) + + REF = os.path.join(args.work_dir, "REF") + HYP = os.path.join(args.work_dir, "HYP") + RESULT = os.path.join(args.work_dir, "RESULT") + + for io in [(args.ref, REF), (args.hyp, HYP)]: + with open(io[0], "r", encoding="utf8") as fi: + with open(io[1], "w+", encoding="utf8") as fo: + for line in fi: + line = line.strip() + if line: + cols = line.split() + text = asr_text_post_processing(" ".join(cols[0:-1])) + uttid_field = cols[-1] + print(f"{text} {uttid_field}", file=fo) + + # GigaSpeech's uttid comforms to swb + os.system(f"sclite -r {REF} trn -h {HYP} trn -i swb | tee {RESULT}") diff --git a/egs/swbd/ASR/zipformer/streaming_beam_search.py b/egs/swbd/ASR/zipformer/streaming_beam_search.py new file mode 100644 index 0000000000..3c8565b330 --- /dev/null +++ b/egs/swbd/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1,295 @@ +# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List + +import k2 +import torch +import torch.nn as nn +from beam_search import Hypothesis, HypothesisList, get_hyps_shape +from decode_stream import DecodeStream + +from icefall.decode import one_best_decoding +from icefall.utils import get_texts + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], + blank_penalty: float = 0.0, +) -> None: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + streams: + A list of Stream objects. + """ + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + T = encoder_out.size(1) + + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (N, 1, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + for t in range(T): + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + if blank_penalty != 0.0: + logits[:, 0] -= blank_penalty + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) + decoder_out = model.joiner.decoder_proj(decoder_out) + + +def modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], + num_active_paths: int = 4, + blank_penalty: float = 0.0, +) -> None: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The RNN-T model. + encoder_out: + A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of + the encoder model. + streams: + A list of stream objects. + num_active_paths: + Number of active paths during the beam search. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert len(streams) == encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + batch_size = len(streams) + T = encoder_out.size(1) + + B = [stream.hyps for stream in streams] + + for t in range(T): + current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.stack( + [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, encoder_out_dim) + + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + # logits is of shape (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) + + if blank_penalty != 0.0: + logits[:, 0] -= blank_penalty + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + for i in range(batch_size): + streams[i].hyps = B[i] + + +def fast_beam_search_one_best( + model: nn.Module, + encoder_out: torch.Tensor, + processed_lens: torch.Tensor, + streams: List[DecodeStream], + beam: float, + max_states: int, + max_contexts: int, + blank_penalty: float = 0.0, +) -> None: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first generated by Fsa-based beam search, then we get the + recognition by applying shortest path on the lattice. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + processed_lens: + A tensor of shape (N,) containing the number of processed frames + in `encoder_out` before padding. + streams: + A list of stream objects. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + """ + assert encoder_out.ndim == 3 + B, T, C = encoder_out.shape + assert B == len(streams) + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(streams[i].rnnt_decoding_stream) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + + if blank_penalty != 0.0: + logits[:, 0] -= blank_penalty + + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + lattice = decoding_streams.format_output(processed_lens.tolist()) + best_path = one_best_decoding(lattice) + hyp_tokens = get_texts(best_path) + + for i in range(B): + streams[i].hyp = hyp_tokens[i] diff --git a/egs/swbd/ASR/zipformer/streaming_decode.py b/egs/swbd/ASR/zipformer/streaming_decode.py new file mode 100755 index 0000000000..44ff392a3a --- /dev/null +++ b/egs/swbd/ASR/zipformer/streaming_decode.py @@ -0,0 +1,876 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +from asr_datamodule import LibriSpeechAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. + """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat( + [state_list[i][-1] for i in range(batch_size)], dim=0 + ) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. + """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk( + chunks=batch_size, dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. + """ + cached_embed_left_pad = states[-2] + ( + x, + x_lens, + new_cached_embed_left_pad, + ) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat( + [processed_mask, src_key_padding_mask], dim=1 + ) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. + initial_states = get_init_states( + model=model, batch_size=1, device=device + ) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + assert params.causal, params.causal + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/zipformer/subsampling.py b/egs/swbd/ASR/zipformer/subsampling.py new file mode 100644 index 0000000000..6532ddccb6 --- /dev/null +++ b/egs/swbd/ASR/zipformer/subsampling.py @@ -0,0 +1,410 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple +import warnings + +import torch +from torch import Tensor, nn +from scaling import ( + Balancer, + BiasNorm, + Dropout3, + FloatLike, + Optional, + ScaledConv2d, + ScaleGrad, + ScheduledFloat, + SwooshL, + SwooshR, + Whiten, +) + + +class ConvNeXt(nn.Module): + """ + Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf + """ + + def __init__( + self, + channels: int, + hidden_ratio: int = 3, + kernel_size: Tuple[int, int] = (7, 7), + layerdrop_rate: FloatLike = None, + ): + super().__init__() + self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) + hidden_channels = channels * hidden_ratio + if layerdrop_rate is None: + layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015)) + self.layerdrop_rate = layerdrop_rate + + self.depthwise_conv = nn.Conv2d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=self.padding, + ) + + self.pointwise_conv1 = nn.Conv2d( + in_channels=channels, out_channels=hidden_channels, kernel_size=1 + ) + + self.hidden_balancer = Balancer( + hidden_channels, + channel_dim=1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0, + ) + + self.activation = SwooshL() + self.pointwise_conv2 = ScaledConv2d( + in_channels=hidden_channels, + out_channels=channels, + kernel_size=1, + initial_scale=0.01, + ) + + self.out_balancer = Balancer( + channels, + channel_dim=1, + min_positive=0.4, + max_positive=0.6, + min_abs=1.0, + max_abs=6.0, + ) + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=5.0, + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward(self, x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return self.forward_internal(x) + layerdrop_rate = float(self.layerdrop_rate) + + if layerdrop_rate != 0.0: + batch_size = x.shape[0] + mask = ( + torch.rand( + (batch_size, 1, 1, 1), dtype=x.dtype, device=x.device + ) + > layerdrop_rate + ) + else: + mask = None + # turns out this caching idea does not work with --world-size > 1 + # return caching_eval(self.forward_internal, x, mask) + return self.forward_internal(x, mask) + + def forward_internal( + self, x: Tensor, layer_skip_mask: Optional[Tensor] = None + ) -> Tensor: + """ + x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) + + The returned value has the same shape as x. + """ + bypass = x + x = self.depthwise_conv(x) + x = self.pointwise_conv1(x) + x = self.hidden_balancer(x) + x = self.activation(x) + x = self.pointwise_conv2(x) + + if layer_skip_mask is not None: + x = x * layer_skip_mask + + x = bypass + x + x = self.out_balancer(x) + + if x.requires_grad: + x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last + x = self.out_whiten(x) + x = x.transpose(1, 3) # (N, C, H, W) + + return x + + def streaming_forward( + self, + x: Tensor, + cached_left_pad: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) + cached_left_pad: (batch_size, num_channels, left_pad, num_freqs) + + Returns: + - The returned value has the same shape as x. + - Updated cached_left_pad. + """ + padding = self.padding + + # The length without right padding for depth-wise conv + T = x.size(2) - padding[0] + + bypass = x[:, :, :T, :] + + # Pad left side + assert cached_left_pad.size(2) == padding[0], ( + cached_left_pad.size(2), + padding[0], + ) + x = torch.cat([cached_left_pad, x], dim=2) + # Update cached left padding + cached_left_pad = x[:, :, T : padding[0] + T, :] + + # depthwise_conv + x = torch.nn.functional.conv2d( + x, + weight=self.depthwise_conv.weight, + bias=self.depthwise_conv.bias, + padding=(0, padding[1]), + groups=self.depthwise_conv.groups, + ) + x = self.pointwise_conv1(x) + x = self.hidden_balancer(x) + x = self.activation(x) + x = self.pointwise_conv2(x) + + x = bypass + x + return x, cached_left_pad + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/2 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = (T-3)//2 - 2 == (T-7)//2 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + dropout: FloatLike = 0.1, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, (T-3)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + bottleneck: + bottleneck dimension for 1d squeeze-excite + """ + assert in_channels >= 7 + super().__init__() + + # The ScaleGrad module is there to prevent the gradients + # w.r.t. the weight or bias of the first Conv2d module in self.conv from + # exceeding the range of fp16 when using automatic mixed precision (amp) + # training. (The second one is necessary to stop its bias from getting + # a too-large gradient). + + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=(0, 1), # (time, freq) + ), + ScaleGrad(0.2), + Balancer(layer1_channels, channel_dim=1, max_abs=1.0), + SwooshR(), + nn.Conv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + padding=0, + ), + Balancer(layer2_channels, channel_dim=1, max_abs=4.0), + SwooshR(), + nn.Conv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=(1, 2), # (time, freq) + ), + Balancer(layer3_channels, channel_dim=1, max_abs=4.0), + SwooshR(), + ) + + # just one convnext layer + self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7)) + + # (in_channels-3)//4 + self.out_width = (((in_channels - 1) // 2) - 1) // 2 + self.layer3_channels = layer3_channels + + self.out = nn.Linear(self.out_width * layer3_channels, out_channels) + # use a larger than normal grad_scale on this whitening module; there is + # only one such module, so there is not a concern about adding together + # many copies of this extra gradient term. + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=ScheduledFloat( + (0.0, 4.0), (20000.0, 8.0), default=4.0 + ), + prob=(0.025, 0.25), + grad_scale=0.02, + ) + + # max_log_eps=0.0 is to prevent both eps and the output of self.out from + # getting large, there is an unnecessary degree of freedom. + self.out_norm = BiasNorm(out_channels) + self.dropout = Dropout3(dropout, shared_dim=1) + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + + Returns: + - a tensor of shape (N, (T-7)//2, odim) + - output lengths, of shape (batch_size,) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision) + # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite + # gradients. + x = self.conv(x) + x = self.convnext(x) + + # Now x is of shape (N, odim, (T-7)//2, (idim-3)//4) + b, c, t, f = x.size() + + x = x.transpose(1, 2).reshape(b, t, c * f) + # now x: (N, (T-7)//2, out_width * layer3_channels)) + + x = self.out(x) + # Now x is of shape (N, (T-7)//2, odim) + x = self.out_whiten(x) + x = self.out_norm(x) + x = self.dropout(x) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + x_lens = (x_lens - 7) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + x_lens = (x_lens - 7) // 2 + assert x.size(1) == x_lens.max().item() , (x.size(1), x_lens.max()) + + return x, x_lens + + def streaming_forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + cached_left_pad: Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + + Returns: + - a tensor of shape (N, (T-7)//2, odim) + - output lengths, of shape (batch_size,) + - updated cache + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + + # T' = (T-7)//2 + x = self.conv(x) + + # T' = (T-7)//2-3 + x, cached_left_pad = self.convnext.streaming_forward( + x, cached_left_pad=cached_left_pad + ) + + # Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + + x = x.transpose(1, 2).reshape(b, t, c * f) + # now x: (N, T', out_width * layer3_channels)) + + x = self.out(x) + # Now x is of shape (N, T', odim) + x = self.out_norm(x) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert self.convnext.padding[0] == 3 + # The ConvNeXt module needs 3 frames of right padding after subsampling + x_lens = (x_lens - 7) // 2 - 3 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # The ConvNeXt module needs 3 frames of right padding after subsampling + assert self.convnext.padding[0] == 3 + x_lens = (x_lens - 7) // 2 - 3 + + assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max()) + + return x, x_lens, cached_left_pad + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> Tensor: + """Get initial states for Conv2dSubsampling module. + It is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + """ + left_pad = self.convnext.padding[0] + freq = self.out_width + channels = self.layer3_channels + cached_embed_left_pad = torch.zeros( + batch_size, channels, left_pad, freq + ).to(device) + + return cached_embed_left_pad diff --git a/egs/swbd/ASR/zipformer/test_scaling.py b/egs/swbd/ASR/zipformer/test_scaling.py new file mode 100755 index 0000000000..5c04291e73 --- /dev/null +++ b/egs/swbd/ASR/zipformer/test_scaling.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 + +import matplotlib.pyplot as plt +import torch +from scaling import PiecewiseLinear, ScheduledFloat, SwooshL, SwooshR + + +def test_piecewise_linear(): + # An identity map in the range [0, 1]. + # 1 - identity map in the range [1, 2] + # x1=0, y1=0 + # x2=1, y2=1 + # x3=2, y3=0 + pl = PiecewiseLinear((0, 0), (1, 1), (2, 0)) + assert pl(0.25) == 0.25, pl(0.25) + assert pl(0.625) == 0.625, pl(0.625) + assert pl(1.25) == 0.75, pl(1.25) + + assert pl(-10) == pl(0), pl(-10) # out of range + assert pl(10) == pl(2), pl(10) # out of range + + # multiplication + pl10 = pl * 10 + assert pl10(1) == 10 * pl(1) + assert pl10(0.5) == 10 * pl(0.5) + + +def test_scheduled_float(): + # Initial value is 0.2 and it decreases linearly towards 0 at 4000 + dropout = ScheduledFloat((0, 0.2), (4000, 0.0), default=0.0) + dropout.batch_count = 0 + assert float(dropout) == 0.2, (float(dropout), dropout.batch_count) + + dropout.batch_count = 1000 + assert abs(float(dropout) - 0.15) < 1e-5, (float(dropout), dropout.batch_count) + + dropout.batch_count = 2000 + assert float(dropout) == 0.1, (float(dropout), dropout.batch_count) + + dropout.batch_count = 3000 + assert abs(float(dropout) - 0.05) < 1e-5, (float(dropout), dropout.batch_count) + + dropout.batch_count = 4000 + assert float(dropout) == 0.0, (float(dropout), dropout.batch_count) + + dropout.batch_count = 5000 # out of range + assert float(dropout) == 0.0, (float(dropout), dropout.batch_count) + + +def test_swoosh(): + x1 = torch.linspace(start=-10, end=0, steps=100, dtype=torch.float32) + x2 = torch.linspace(start=0, end=10, steps=100, dtype=torch.float32) + x = torch.cat([x1, x2[1:]]) + + left = SwooshL()(x) + r = SwooshR()(x) + + relu = torch.nn.functional.relu(x) + print(left[x == 0], r[x == 0]) + plt.plot(x, left, "k") + plt.plot(x, r, "r") + plt.plot(x, relu, "b") + plt.axis([-10, 10, -1, 10]) # [xmin, xmax, ymin, ymax] + plt.legend( + [ + "SwooshL(x) = log(1 + exp(x-4)) - 0.08x - 0.035 ", + "SwooshR(x) = log(1 + exp(x-1)) - 0.08x - 0.313261687", + "ReLU(x) = max(0, x)", + ] + ) + plt.grid() + plt.savefig("swoosh.pdf") + + +def main(): + test_piecewise_linear() + test_scheduled_float() + test_swoosh() + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/zipformer/test_subsampling.py b/egs/swbd/ASR/zipformer/test_subsampling.py new file mode 100755 index 0000000000..078227fb68 --- /dev/null +++ b/egs/swbd/ASR/zipformer/test_subsampling.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 + +import torch +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling + + +def test_conv2d_subsampling(): + layer1_channels = 8 + layer2_channels = 32 + layer3_channels = 128 + + out_channels = 192 + encoder_embed = Conv2dSubsampling( + in_channels=80, + out_channels=out_channels, + layer1_channels=layer1_channels, + layer2_channels=layer2_channels, + layer3_channels=layer3_channels, + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + N = 2 + T = 200 + num_features = 80 + x = torch.rand(N, T, num_features) + x_copy = x.clone() + + x = x.unsqueeze(1) # (N, 1, T, num_features) + + x = encoder_embed.conv[0](x) # conv2d, in 1, out 8, kernel 3, padding (0,1) + assert x.shape == (N, layer1_channels, T - 2, num_features) + # (2, 8, 198, 80) + + x = encoder_embed.conv[1](x) # scale grad + x = encoder_embed.conv[2](x) # balancer + x = encoder_embed.conv[3](x) # swooshR + + x = encoder_embed.conv[4](x) # conv2d, in 8, out 32, kernel 3, stride 2 + assert x.shape == ( + N, + layer2_channels, + ((T - 2) - 3) // 2 + 1, + (num_features - 3) // 2 + 1, + ) + # (2, 32, 98, 39) + + x = encoder_embed.conv[5](x) # balancer + x = encoder_embed.conv[6](x) # swooshR + + # conv2d: + # in 32, out 128, kernel 3, stride (1, 2) + x = encoder_embed.conv[7](x) + assert x.shape == ( + N, + layer3_channels, + (((T - 2) - 3) // 2 + 1) - 2, + (((num_features - 3) // 2 + 1) - 3) // 2 + 1, + ) + # (2, 128, 96, 19) + + x = encoder_embed.conv[8](x) # balancer + x = encoder_embed.conv[9](x) # swooshR + + # (((T - 2) - 3) // 2 + 1) - 2 + # = (T - 2) - 3) // 2 + 1 - 2 + # = ((T - 2) - 3) // 2 - 1 + # = (T - 2 - 3) // 2 - 1 + # = (T - 5) // 2 - 1 + # = (T - 7) // 2 + assert x.shape[2] == (x_copy.shape[1] - 7) // 2 + + # (((num_features - 3) // 2 + 1) - 3) // 2 + 1, + # = ((num_features - 3) // 2 + 1 - 3) // 2 + 1, + # = ((num_features - 3) // 2 - 2) // 2 + 1, + # = (num_features - 3 - 4) // 2 // 2 + 1, + # = (num_features - 7) // 2 // 2 + 1, + # = (num_features - 7) // 4 + 1, + # = (num_features - 3) // 4 + assert x.shape[3] == (x_copy.shape[2] - 3) // 4 + + assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) + + # Input shape to convnext is + # + # (N, layer3_channels, (T-7)//2, (num_features - 3)//4) + + # conv2d: in layer3_channels, out layer3_channels, groups layer3_channels + # kernel_size 7, padding 3 + x = encoder_embed.convnext.depthwise_conv(x) + assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) + + # conv2d: in layer3_channels, out hidden_ratio * layer3_channels, kernel_size 1 + x = encoder_embed.convnext.pointwise_conv1(x) + assert x.shape == (N, layer3_channels * 3, (T - 7) // 2, (num_features - 3) // 4) + + x = encoder_embed.convnext.hidden_balancer(x) # balancer + x = encoder_embed.convnext.activation(x) # swooshL + + # conv2d: in hidden_ratio * layer3_channels, out layer3_channels, kernel 1 + x = encoder_embed.convnext.pointwise_conv2(x) + assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) + + # bypass and layer drop, omitted here. + x = encoder_embed.convnext.out_balancer(x) + + # Note: the input and output shape of ConvNeXt are the same + + x = x.transpose(1, 2).reshape(N, (T - 7) // 2, -1) + assert x.shape == (N, (T - 7) // 2, layer3_channels * ((num_features - 3) // 4)) + + x = encoder_embed.out(x) + assert x.shape == (N, (T - 7) // 2, out_channels) + + x = encoder_embed.out_whiten(x) + x = encoder_embed.out_norm(x) + # final layer is dropout + + # test streaming forward + + subsampling_factor = 2 + cached_left_padding = encoder_embed.get_init_states(batch_size=N) + depthwise_conv_kernel_size = 7 + pad_size = (depthwise_conv_kernel_size - 1) // 2 + + assert cached_left_padding.shape == ( + N, + layer3_channels, + pad_size, + (num_features - 3) // 4, + ) + + chunk_size = 16 + right_padding = pad_size * subsampling_factor + T = chunk_size * subsampling_factor + 7 + right_padding + x = torch.rand(N, T, num_features) + x_lens = torch.tensor([T] * N) + y, y_lens, next_cached_left_padding = encoder_embed.streaming_forward( + x, x_lens, cached_left_padding + ) + + assert y.shape == (N, chunk_size, out_channels), y.shape + assert next_cached_left_padding.shape == cached_left_padding.shape + + assert y.shape[1] == y_lens[0] == y_lens[1] + + +def main(): + test_conv2d_subsampling() + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/zipformer/train.py b/egs/swbd/ASR/zipformer/train.py new file mode 100755 index 0000000000..a4cac01142 --- /dev/null +++ b/egs/swbd/ASR/zipformer/train.py @@ -0,0 +1,1383 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import SwitchBoardAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + switchboard = SwitchBoardAsrDataModule(args) + + train_cuts = switchboard.train_all_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = switchboard.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = switchboard.dev_cuts() + valid_dl = switchboard.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + SwitchBoardAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/zipformer/zipformer.py b/egs/swbd/ASR/zipformer/zipformer.py new file mode 100644 index 0000000000..b39af02b8a --- /dev/null +++ b/egs/swbd/ASR/zipformer/zipformer.py @@ -0,0 +1,2254 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +import warnings +from typing import List, Optional, Tuple, Union +import logging +import torch +import random +from encoder_interface import EncoderInterface +from scaling import ( + Balancer, + BiasNorm, + Dropout2, + ChunkCausalDepthwiseConv1d, + ActivationDropoutAndLinear, + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + Whiten, + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. + penalize_abs_values_gt, + softmax, + ScheduledFloat, + FloatLike, + limit_param_value, + convert_num_channels, +) +from torch import Tensor, nn + + +class Zipformer2(EncoderInterface): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + output_downsampling_factor (int): how much to downsample at the output. Note: + we also downsample by a factor of 2 in the Conv2dSubsampling encoder. + You should probably leave this at 2. + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. + num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of + the encoder stacks for purposes of per-frame dropout (recommend 256 for + now). + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per + attention head + value_head_dim (int or Tuple[int]): dimension of value in each attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + + pos_dim (int): the dimension of each positional-encoding vector prior to projection, + e.g. 128. + + dropout (float): dropout rate + warmup_batches (float): number of batches to warm up over; this controls + dropout of encoder layers. + causal (bool): if True, support chunkwise causal convolution. This should + not hurt WER as no modeling power is lost, but the convolution modules will be + slightly slower and use more memory. Enables use of the chunk_size and + left_context_chunks options in forward(), which simulates streaming + decoding. + chunk_size: (list of int): only set this to other than [-1] if causal; + the chunk size will be randomly chosen from this list. -1 means no chunking. + left_context_frames: (list of int): determines the number of left- + context chunks for causal training; will be rounded to a number of + chunks. Must not be less than cnn_module_kernel (after factoring in + rounding and downsampling); an error will be thrown if this is violated. + """ + def __init__( + self, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + encoder_unmasked_dim: Union[int, Tuple[int]] = 256, + query_head_dim: Union[int, Tuple[int]] = 24, + pos_head_dim: Union[int, Tuple[int]] = 4, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_dim: Union[int, Tuple[int]] = 1536, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + pos_dim: int = 192, + dropout: FloatLike = None, # see code below for default + warmup_batches: float = 4000.0, + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], + ) -> None: + super(Zipformer2, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), + (20000.0, 0.1)) + + def _to_tuple(x): + """ Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + self.output_downsampling_factor = output_downsampling_factor # int + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) + self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) + pos_head_dim = _to_tuple(pos_head_dim) + self.num_heads = num_heads = _to_tuple(num_heads) + feedforward_dim = _to_tuple(feedforward_dim) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + + self.causal = causal + self.chunk_size = chunk_size + self.left_context_frames = left_context_frames + + for u,d in zip(encoder_unmasked_dim, encoder_dim): + assert u <= d + + # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder + encoders = [] + + num_encoders = len(downsampling_factor) + for i in range(num_encoders): + + encoder_layer = Zipformer2EncoderLayer( + embed_dim=encoder_dim[i], + pos_dim=pos_dim, + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + pos_head_dim=pos_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_dim=feedforward_dim[i], + dropout=dropout, + cnn_module_kernel=cnn_module_kernel[i], + causal=causal, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. + encoder = Zipformer2Encoder( + encoder_layer, + num_encoder_layers[i], + pos_dim=pos_dim, + dropout=dropout, + warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), + ) + + if downsampling_factor[i] != 1: + encoder = DownsampledZipformer2Encoder( + encoder, + dim=encoder_dim[i], + downsample=downsampling_factor[i], + dropout=dropout, + ) + + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + + self.downsample_output = SimpleDownsample(max(encoder_dim), + downsample=output_downsampling_factor, + dropout=dropout) + + def get_feature_masks( + self, + x: Tensor) -> Union[List[float], List[Tensor]]: + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all enocder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoer dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_subsampling_factor times. + + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (1, batch_size, encoder_dims0) + """ + num_encoders = len(self.encoder_dim) + if not self.training: + return [ 1.0 ] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + assert self.encoder_dim[0] == _encoder_dims0, (self.encoder_dim[0], _encoder_dims0) + + feature_mask_dropout_prob = 0.125 + + # mask1 shape: (1, batch_size, 1) + mask1 = (torch.rand(1, batch_size, 1, + device=x.device) > + feature_mask_dropout_prob).to(x.dtype) + + # mask2 has additional sequences masked, about twice the number. + mask2 = torch.logical_and(mask1, + (torch.rand(1, batch_size, 1, + device=x.device) > + feature_mask_dropout_prob).to(x.dtype)) + + # dim: (1, batch_size, 2) + mask = torch.cat((mask1, mask2), dim=-1) + + feature_masks = [] + for i in range(num_encoders): + channels = self.encoder_dim[i] + feature_mask = torch.ones(1, batch_size, channels, + dtype=x.dtype, device=x.device) + u1 = self.encoder_unmasked_dim[i] + u2 = u1 + (channels - u1) // 2 + + feature_mask[:, :, u1:u2] *= mask[..., 0:1] + feature_mask[:, :, u2:] *= mask[..., 1:2] + + feature_masks.append(feature_mask) + + return feature_masks + + def get_chunk_info(self) -> Tuple[int, int]: + """ + Returns chunk_size and left_context_chunks. + """ + if not self.causal: + return -1, -1 + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.chunk_size) == 1, self.chunk_size + chunk_size = self.chunk_size[0] + else: + chunk_size = random.choice(self.chunk_size) + + if chunk_size == -1: + left_context_chunks = -1 + else: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.left_context_frames) == 1, self.left_context_frames + left_context_frames = self.left_context_frames[0] + else: + left_context_frames = random.choice(self.left_context_frames) + # Note: in Python, -1 // n == -1 for n > 0 + left_context_chunks = left_context_frames // chunk_size + if left_context_chunks == 0: + left_context_chunks = 1 + + return chunk_size, left_context_chunks + + def forward( + self, x: Tensor, + x_lens: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + outputs = [] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + feature_masks = [1.0] * len(self.encoder_dim) + else: + feature_masks = self.get_feature_masks(x) + + chunk_size, left_context_chunks = self.get_chunk_info() + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # Not support exporting a model for simulating streaming decoding + attn_mask = None + else: + attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) + + for i, module in enumerate(self.encoders): + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x = module(x, + chunk_size=chunk_size, + feature_mask=feature_masks[i], + src_key_padding_mask=(None if src_key_padding_mask is None + else src_key_padding_mask[...,::ds]), + attn_mask=attn_mask) + outputs.append(x) + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths + + def _get_attn_mask( + self, x: Tensor, + chunk_size: int, + left_context_chunks: int + ) -> Optional[Tensor]: + """ + Return None if chunk_size == -1, else return attention mask of shape + (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True + means a masked position. + Args: + x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). + chunk_size: chunk size, must divide + """ + if chunk_size <= 0: + return None + assert all(chunk_size % d == 0 for d in self.downsampling_factor) + if left_context_chunks >= 0: + num_encoders = len(self.encoder_dim) + assert all (chunk_size * left_context_chunks >= + (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + for i in range(num_encoders)) + else: + left_context_chunks = 1000000 + + seq_len = x.shape[0] + + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + # c is chunk index for each frame, shape (seq_len,) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + c = t // chunk_size + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + c = t // chunk_size + src_c = c + tgt_c = c.unsqueeze(-1) + + attn_mask = torch.logical_or(src_c > tgt_c, + src_c < tgt_c - left_context_chunks) + if __name__ == "__main__": + logging.info(f"attn_mask = {attn_mask}") + return attn_mask + + def _get_full_dim_output(self, outputs: List[Tensor]): + num_encoders = len(self.encoder_dim) + assert len(outputs) == num_encoders + output_dim = max(self.encoder_dim) + output_pieces = [ outputs[-1] ] + cur_dim = self.encoder_dim[-1] + for i in range(num_encoders - 2, -1, -1): + d = self.encoder_dim[i] + if d > cur_dim: + this_output = outputs[i] + output_pieces.append(this_output[..., cur_dim:d]) + cur_dim = d + assert cur_dim == output_dim + return torch.cat(output_pieces, dim=-1) + + def streaming_forward( + self, + x: Tensor, + x_lens: Tensor, + states: List[Tensor], + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated states + """ + outputs = [] + new_states = [] + layer_offset = 0 + + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x, new_layer_states = module.streaming_forward( + x, + states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + left_context_len=self.left_context_frames[0] // ds, + src_key_padding_mask=src_key_padding_mask[..., ::ds], + ) + layer_offset += num_layers + outputs.append(x) + new_states += new_layer_states + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2 + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + + A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + """ + states = [] + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + embed_dim = self.encoder_dim[i] + ds = self.downsampling_factor[i] + num_heads = self.num_heads[i] + key_dim = self.query_head_dim[i] * num_heads + value_dim = self.value_head_dim[i] * num_heads + downsample_left = self.left_context_frames[0] // ds + nonlin_attn_head_dim = 3 * embed_dim // 4 + conv_left_pad = self.cnn_module_kernel[i] // 2 + for layer in range(num_layers): + cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(device) + cached_nonlin_attn = torch.zeros(1, batch_size, downsample_left, nonlin_attn_head_dim).to(device) + cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(device) + cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(device) + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(device) + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(device) + states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2] + + return states + + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), + (20000.0, ratio * x), + default=x) + + +def _balancer_schedule(min_prob: float): + return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) + + +class Zipformer2EncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_dim: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + causal: bool = False, + attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), + conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), + const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0), + ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), + ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), + bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0), + ) -> None: + super(Zipformer2EncoderLayer, self).__init__() + self.embed_dim = embed_dim + + # self.bypass implements layer skipping as well as bypass; see its default values. + self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate, + straight_through_rate=0) + # bypass_mid is bypass used in the middle of the layer. + self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) + + # skip probability for dynamic modules (meaning: anything but feedforward). + self.attention_skip_rate = copy.deepcopy(attention_skip_rate) + # an additional skip probability that applies to ConvModule to stop it from + # contributing too much early on. + self.conv_skip_rate = copy.deepcopy(conv_skip_rate) + + # ff2_skip_rate is to prevent the ff2 module from having output that's too big + # compared to its residual. + self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) + self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) + + self.const_attention_rate = copy.deepcopy(const_attention_rate) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, pos_dim=pos_dim, num_heads=num_heads, + query_head_dim=query_head_dim, pos_head_dim=pos_head_dim, + dropout=0.0, + ) + + self.self_attn1 = SelfAttention(embed_dim, num_heads, + value_head_dim) + + self.self_attn2 = SelfAttention(embed_dim, num_heads, + value_head_dim) + + self.feed_forward1 = FeedforwardModule(embed_dim, + (feedforward_dim * 3) // 4, + dropout) + + self.feed_forward2 = FeedforwardModule(embed_dim, + feedforward_dim, + dropout) + + self.feed_forward3 = FeedforwardModule(embed_dim, + (feedforward_dim * 5) // 4, + dropout) + + self.nonlin_attention = NonlinAttention(embed_dim, + hidden_channels=3 * embed_dim // 4) + + self.conv_module1 = ConvolutionModule(embed_dim, + cnn_module_kernel, + causal=causal) + + self.conv_module2 = ConvolutionModule(embed_dim, + cnn_module_kernel, + causal=causal) + + # TODO: remove it + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + + self.norm = BiasNorm(embed_dim) + + self.balancer1 = Balancer( + embed_dim, channel_dim=-1, + min_positive=0.45, max_positive=0.55, + min_abs=0.2, max_abs=4.0, + ) + + # balancer for output of NonlinAttentionModule + self.balancer_na = Balancer( + embed_dim, channel_dim=-1, + min_positive=0.3, max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), + prob=0.05, # out of concern for memory usage + ) + + # balancer for output of feedforward2, prevent it from staying too + # small. give this a very small probability, even at the start of + # training, it's to fix a rare problem and it's OK to fix it slowly. + self.balancer_ff2 = Balancer( + embed_dim, channel_dim=-1, + min_positive=0.3, max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), + max_abs=2.0, + prob=0.05, + ) + + self.balancer_ff3 = Balancer( + embed_dim, channel_dim=-1, + min_positive=0.3, max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), + max_abs=4.0, + prob=0.05, + ) + + self.whiten = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(4.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01) + + self.balancer2 = Balancer( + embed_dim, channel_dim=-1, + min_positive=0.45, max_positive=0.55, + min_abs=0.1, max_abs=4.0, + ) + + def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]: + if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): + return None + batch_size = x.shape[1] + mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) + return mask + + def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: + """ + Apply sequence-level dropout to x. + x shape: (seq_len, batch_size, embed_dim) + """ + dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) + if dropout_mask is None: + return x + else: + return x * dropout_mask + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # dropout rate for non-feedforward submodules + if torch.jit.is_scripting() or torch.jit.is_tracing(): + attention_skip_rate = 0.0 + else: + attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0 + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate) + + selected_attn_weights = attn_weights[0:1] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif not self.training and random.random() < float(self.const_attention_rate): + # Make attention weights constant. The intention is to + # encourage these modules to do something similar to an + # averaging-over-time operation. + # only need the mask, can just use the 1st one and expand later + selected_attn_weights = selected_attn_weights[0:1] + selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype) + selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)) + + na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) + + src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask) + + self_attn = self.self_attn1(src, attn_weights) + + src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size, + src_key_padding_mask=src_key_padding_mask), + conv_skip_rate) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff2_skip_rate = 0.0 + else: + ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)), + ff2_skip_rate) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn = self.self_attn2(src, attn_weights) + + src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size, + src_key_padding_mask=src_key_padding_mask), + conv_skip_rate) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff3_skip_rate = 0.0 + else: + ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)), + ff3_skip_rate) + + src = self.balancer1(src) + src = self.norm(src) + + src = self.bypass(src_orig, src) + + src = self.balancer2(src) + src = self.whiten(src) + + return src + + def streaming_forward( + self, + src: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_nonlin_attn: Tensor, + cached_val1: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Pass the input through the encoder layer in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or + (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + cached_val1: cached left context for the first attention module, + of shape (left_context_len, batch_size, value_dim) + cached_val2: cached left context for the second attention module, + of shape (left_context_len, batch_size, value_dim) + cached_conv1: cached left context for the first convolution module, + of shape (batch_size, channels, left_pad) + cached_conv2: cached left context for the second convolution module, + of shape (batch_size, channels, left_pad) + left_context_len: number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - x, with the same shape as src + - updated cached_key + - updated cached_nonlin_attn + - updated cached_val1 + - updated cached_val2 + - updated cached_conv1 + - updated cached_conv2 + """ + src_orig = src + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights, cached_key = self.self_attn_weights.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + left_context_len=left_context_len, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( + src, + attn_weights[0:1], + cached_x=cached_nonlin_attn, + left_context_len=left_context_len, + ) + src = src + na + + self_attn, cached_val1 = self.self_attn1.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val1, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward2(src) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn, cached_val2 = self.self_attn2.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val2, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm(src) + + src = self.bypass(src_orig, src) + + return ( + src, + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class Zipformer2Encoder(nn.Module): + r"""Zipformer2Encoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + pos_dim: the dimension for the relative positional encoding + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + pos_dim: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + initial_layerdrop_rate: float = 0.5, + final_layerdrop_rate: float = 0.05, + ) -> None: + super().__init__() + self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15, + length_factor=1.0) + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + assert 0 <= warmup_begin <= warmup_end + + delta = (1. / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin # interpreted as a training batch index + for i in range(num_layers): + cur_end = cur_begin + delta + self.layers[i].bypass.skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate), + (cur_end, final_layerdrop_rate), + default=0.0) + cur_begin = cur_end + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + pos_emb = self.encoder_pos(src) + output = src + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + output = output * feature_mask + + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + chunk_size=chunk_size, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + output = output * feature_mask + + return output + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + pos_emb = self.encoder_pos(src, left_context_len) + output = src + + new_states = [] + for i, mod in enumerate(self.layers): + ( + cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2 + ) = states[i * 6: (i + 1) * 6] + ( + output, + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2 + ) = mod.streaming_forward( + output, + pos_emb, + cached_key=cached_key, + cached_nonlin_attn=cached_nonlin_attn, + cached_val1=cached_val1, + cached_val2=cached_val2, + cached_conv1=cached_conv1, + cached_conv2=cached_conv2, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + new_states += [ + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ] + + return output, new_states + + +class BypassModule(nn.Module): + """ + An nn.Module that implements a learnable bypass scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. + """ + def __init__( + self, + embed_dim: int, + skip_rate: FloatLike = 0.0, + straight_through_rate: FloatLike = 0.0, + scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + scale_max: FloatLike = 1.0): + super().__init__() + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.skip_rate = copy.deepcopy(skip_rate) + self.straight_through_rate = copy.deepcopy(straight_through_rate) + self.scale_min = copy.deepcopy(scale_min) + self.scale_max = copy.deepcopy(scale_max) + + def _get_bypass_scale(self, batch_size: int): + # returns bypass-scale of shape (num_channels,), + # or (batch_size, num_channels,). This is actually the + # scale on the non-residual term, so 0 correponds to bypassing + # this module. + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return self.bypass_scale + else: + ans = limit_param_value(self.bypass_scale, + min=float(self.scale_min), + max=float(self.scale_max)) + skip_rate = float(self.skip_rate) + if skip_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate + ans = ans * mask + # now ans is of shape (batch_size, num_channels), and is zero for sequences + # on which we have randomly chosen to do layer-skipping. + straight_through_rate = float(self.straight_through_rate) + if straight_through_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate + ans = torch.maximum(ans, mask.to(ans.dtype)) + return ans + + def forward(self, + src_orig: Tensor, + src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + bypass_scale = self._get_bypass_scale(src.shape[1]) + return src_orig + (src - src_orig) * bypass_scale + + +class DownsampledZipformer2Encoder(nn.Module): + r""" + DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. + """ + def __init__(self, + encoder: nn.Module, + dim: int, + downsample: int, + dropout: FloatLike): + super(DownsampledZipformer2Encoder, self).__init__() + self.downsample_factor = downsample + self.downsample = SimpleDownsample(dim, + downsample, dropout) + self.num_layers = encoder.num_layers + self.encoder = encoder + self.upsample = SimpleUpsample(dim, downsample) + self.out_combiner = BypassModule(dim, straight_through_rate=0) + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + src_orig = src + src = self.downsample(src) + ds = self.downsample_factor + if attn_mask is not None: + attn_mask = attn_mask[::ds,::ds] + + src = self.encoder( + src, + chunk_size=chunk_size // ds, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[:src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Downsample, go through encoder, upsample, in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len); + True means masked position. May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + src_orig = src + src = self.downsample(src) + + src, new_states = self.encoder.streaming_forward( + src, + states=states, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[:src_orig.shape[0]] + + return self.out_combiner(src_orig, src), new_states + + +class SimpleDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + def __init__(self, + channels: int, + downsample: int, + dropout: FloatLike): + super(SimpleDownsample, self).__init__() + + self.bias = nn.Parameter(torch.zeros(downsample)) + + self.name = None # will be set from training code + self.dropout = copy.deepcopy(dropout) + + self.downsample = downsample + + def forward(self, + src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + + weights = self.bias.softmax(dim=0) + # weights: (downsample, 1, 1) + weights = weights.unsqueeze(-1).unsqueeze(-1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. + """ + def __init__(self, + num_channels: int, + upsample: int): + super(SimpleUpsample, self).__init__() + self.upsample = upsample + + def forward(self, + src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.upsample + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. + + Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the fourier transform of that fixed interval. The + atan() function would compress the "long tails" too small, + making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic + function to compress large offsets to a smaller range before applying atan(). + Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long + as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) + + + Args: + embed_dim: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length: just a heuristic for initialization. + length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. + """ + def __init__( + self, embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, + ) -> None: + """Construct a CompactRelPositionalEncoding object.""" + super(CompactRelPositionalEncoding, self).__init__() + self.embed_dim = embed_dim + assert embed_dim % 2 == 0 + self.dropout = Dropout2(dropout_rate) + self.pe = None + assert length_factor >= 1.0 + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + T = x.size(0) + left_context_len + + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= T * 2 - 1: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T-1), T, + device=x.device).to(torch.float32).unsqueeze(1) + + freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) + + # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution + # for small time offsets but less resolution for large time offsets. + compression_length = (self.embed_dim ** 0.5) + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; + # but it does so more slowly than T for large absolute values of T. + # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which + # is important. + x_compressed = compression_length * x.sign() * ((x.abs() + compression_length).log() - math.log(compression_length)) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. + + self.pe = pe.to(dtype=x.dtype) + + def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Create positional encoding. + + Args: + x (Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + """ + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + # length of positive side: x.size(0) + left_context_len + # length of negative side: x.size(0) + pos_emb = self.pe[ + self.pe.size(0) // 2 + - x_size_left + + 1 : self.pe.size(0) // 2 # noqa E203 + + x.size(0), + : + ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + pos_dim: dimension of the positional encoding vectors, e.g. 128. + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. + dropout: dropout probability for attn_output_weights. Default: 0.0. + pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on + any given call to forward(), in training time. + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), + (4000.0, 0.0)) + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.pos_head_dim = pos_head_dim + self.dropout = dropout + self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) + self.name = None # will be overwritten in training code; for diagnostics. + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, + initial_scale=query_head_dim**-0.25) + + self.whiten_keys = Whiten(num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025) + + # add a balancer for the keys that runs with very small probability, and + # tries to enforce that all dimensions have mean around zero. The + # weights produced by this module are invariant to adding a constant to + # the keys, so the derivative of the bias is mathematically zero; but + # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero + # bias because the small numerical roundoff tends to have a non-random + # sign. This module is intended to prevent that. Use a very small + # probability; that should be suffixient to fix the problem. + self.balance_keys = Balancer(key_head_dim * num_heads, + channel_dim=-1, + min_positive=0.4, + max_positive=0.6, + min_abs=0.0, + max_abs=100.0, + prob=0.025) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear(pos_dim, + num_heads * pos_head_dim, + bias=False, + initial_scale=0.05) + + # the following are for diagnosics only, see --print-diagnostics option + self.copy_pos_query = Identity() + self.copy_query = Identity() + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[...,0:query_dim] + k = x[...,query_dim:2*query_dim] + # p is the position-encoding query + p = x[...,2*query_dim:] + assert p.shape[-1] == num_heads * pos_head_dim + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + use_pos_scores = False + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # We can't put random.random() in the same line + use_pos_scores = True + elif not self.training or random.random() >= float(self.pos_emb_skip_rate): + use_pos_scores = True + + if use_pos_scores: + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) + else: + pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len), + (pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2)-pos_scores.stride(3), + pos_scores.stride(3)), + storage_offset=pos_scores.stride(3) * (seq_len - 1)) + + attn_scores = attn_scores + pos_scores + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif self.training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 50.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt(attn_scores, + limit=25.0, + penalty=1.0e-04, + name=self.name) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001 and not self.training: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + left_context_len: int, + key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + left_context_len: number of left context frames. + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + + Returns: + - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + - updated cached attention key tensor of left context. + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[...,0:query_dim] + k = x[...,query_dim:2*query_dim] + # p is the position-encoding query + p = x[...,2*query_dim:] + assert p.shape[-1] == num_heads * pos_head_dim + + # Pad cached left contexts + assert cached_key.shape[0] == left_context_len, (cached_key.shape[0], left_context_len) + k = torch.cat([cached_key, k], dim=0) + # Update cached left contexts + cached_key = k[-left_context_len:, ...] + + # The length of key + k_len = k.shape[0] + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(k_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + left_context_len + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(k_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + else: + pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len), + (pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2)-pos_scores.stride(3), + pos_scores.stride(3)), + storage_offset=pos_scores.stride(3) * (seq_len - 1)) + + attn_scores = attn_scores + pos_scores + + assert attn_scores.shape == (num_heads, batch_size, seq_len, k_len), attn_scores.shape + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + attn_weights = attn_scores.softmax(dim=-1) + + return attn_weights, cached_key + + def _print_attn_entropy( + self, + attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( + dim=-1).mean(dim=(1,2)) + logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}") + + +class SelfAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = nn.Linear(embed_dim, + num_heads * value_head_dim, + bias=True) + + self.out_proj = ScaledLinear(num_heads * value_head_dim, + embed_dim, bias=True, + initial_scale=0.05) + + self.whiten = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + Returns: + a tensor with the same shape as x. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = x.permute(2, 1, 0, 3).contiguous().view( + seq_len, batch_size, num_heads * value_head_dim) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + cached_val: cached attention value tensor of left context, + of shape (left_context_len, batch_size, value_dim) + left_context_len: number of left context frames. + + Returns: + - attention weighted output, a tensor with the same shape as x. + - updated cached attention value tensor of left context. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + seq_len2 = seq_len + left_context_len + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + + # Pad cached left contexts + assert cached_val.shape[0] == left_context_len, (cached_val.shape[0], left_context_len) + x = torch.cat([cached_val, x], dim=0) + # Update cached left contexts + cached_val = x[-left_context_len:, ...] + + x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = x.permute(2, 1, 0, 3).contiguous().view( + seq_len, batch_size, num_heads * value_head_dim) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + + return x, cached_val + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer2 model. + """ + def __init__(self, + embed_dim: int, + feedforward_dim: int, + dropout: FloatLike): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(embed_dim, feedforward_dim) + + self.hidden_balancer = Balancer(feedforward_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0) + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear(feedforward_dim, embed_dim, + activation='SwooshL', + dropout_p=dropout, + dropout_shared_dim=0, bias=True, + initial_scale=0.1) + + self.out_whiten = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.hidden_balancer(x) + # out_proj contains SwooshL activation, then dropout, then linear. + x = self.out_proj(x) + x = self.out_whiten(x) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, + # because we noticed that well-trained instances of this module have abs-value before the sigmoid + # starting from about 3, and poorly-trained instances of the module have smaller abs values + # before the sigmoid. + self.balancer = Balancer( + hidden_channels, channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), + max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), + min_abs=0.5, + max_abs=5.0, + ) + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear(hidden_channels, channels, + bias=True, + initial_scale=0.05) + + self.whiten1 = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01) + + self.whiten2 = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) +attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=-1) + + # s will go through tanh. + + s = self.balancer(s) + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_x: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + cached_x: left context, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + left_context_len: number of left context frames. + Returns: + - a Tensor with the same shape as x + - updated left context with same shape as cached_x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=-1) + + # s will go through tanh. + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = x * s + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, left_context_len + seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + + # Pad cached tensor + assert cached_x.shape[2] == left_context_len, (cached_x.shape[2], left_context_len) + x_pad = torch.cat([cached_x, x], dim=2) + # Update cached tensor + cached_x = x_pad[:, :, -left_context_len:, :] + + x = torch.matmul(attn_weights, x_pad) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + x = x * y + + x = self.out_proj(x) + return x, cached_x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer2 model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + def __init__( + self, channels: int, kernel_size: int, causal: bool, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + self.causal = causal + + self.in_proj = nn.Linear( + channels, 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + # after in_proj we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.balancer1 = Balancer( + bottleneck_dim, channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), + max_positive=1.0, + min_abs=1.5, + max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), + ) + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = ChunkCausalDepthwiseConv1d( + channels=bottleneck_dim, + kernel_size=kernel_size) if causal else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2) + + self.balancer2 = Balancer( + bottleneck_dim, channel_dim=1, + min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), + max_positive=1.0, + min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), + max_abs=10.0, + ) + + self.whiten = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01) + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, channels, activation='SwooshR', + dropout_p=0.0, initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=-1) + s = self.balancer1(s) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing() and chunk_size >= 0: + # Not support exporting a model for simulated streaming decoding + assert self.causal, "Must initialize model with causal=True if you use chunk_size" + x = self.depthwise_conv(x, chunk_size=chunk_size) + else: + x = self.depthwise_conv(x) + + x = self.balancer2(x) + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.whiten(x) # (time, batch, channels) + x = self.out_proj(x) # (time, batch, channels) + + return x + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module in streaming forward mode. + + Args: + x: Input tensor (#time, batch, channels). + cache: cached left context for depthwise_conv of shape + (#batch, channels, left_pad) + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + - Updated cache (#batch, channels, left_pad) + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = x * s + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.out_proj(x) # (time, batch, channels) + + return x, cache + + +class ScalarMultiply(nn.Module): + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, x): + return x * self.scale + + +def _test_zipformer_main(causal: bool = False): + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + + c = Zipformer2( + encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4), + causal=causal, + chunk_size=(4,) if causal else (-1,), + left_context_frames=(64,) + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main(False) + _test_zipformer_main(True) From 4ab83db8176185f1df1920ac7d517e966298d638 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Fri, 15 Sep 2023 10:52:10 +0800 Subject: [PATCH 40/60] minor fixes --- egs/swbd/ASR/zipformer/sclite_scoring.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/egs/swbd/ASR/zipformer/sclite_scoring.py b/egs/swbd/ASR/zipformer/sclite_scoring.py index fe0fbbeeea..239204221e 100755 --- a/egs/swbd/ASR/zipformer/sclite_scoring.py +++ b/egs/swbd/ASR/zipformer/sclite_scoring.py @@ -85,9 +85,13 @@ def asr_text_post_processing(text: str) -> str: remaining_words.append("ON") remaining_words.append("TO") continue - elif word == "DAY" and text_split[idx + 1] == "CARE": - remaining_words.append("DAYCARE") - word_to_skip = 1 + elif word == "DAY": + try: + if text_split[idx + 1] == "CARE": + remaining_words.append("DAYCARE") + word_to_skip = 1 + except: + remaining_words.append(word) continue remaining_words.append(word) From c0f2abd2a74bd3814d809a5eaa1d3ccfd2e94f47 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Fri, 15 Sep 2023 10:58:30 +0800 Subject: [PATCH 41/60] minor fix --- egs/swbd/ASR/conformer_ctc/sclite_scoring.py | 16 +++++++++++++--- egs/swbd/ASR/zipformer/sclite_scoring.py | 6 ++++++ icefall/checkpoint.py | 2 ++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/egs/swbd/ASR/conformer_ctc/sclite_scoring.py b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py index fe0fbbeeea..0383c4d71b 100755 --- a/egs/swbd/ASR/conformer_ctc/sclite_scoring.py +++ b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py @@ -81,13 +81,23 @@ def asr_text_post_processing(text: str) -> str: remaining_words.append("H") remaining_words.append("D") continue + elif word == "UCLA": + remaining_words.append("U") + remaining_words.append("C") + remaining_words.append("L") + remaining_words.append("A") + continue elif word == "ONTO": remaining_words.append("ON") remaining_words.append("TO") continue - elif word == "DAY" and text_split[idx + 1] == "CARE": - remaining_words.append("DAYCARE") - word_to_skip = 1 + elif word == "DAY": + try: + if text_split[idx + 1] == "CARE": + remaining_words.append("DAYCARE") + word_to_skip = 1 + except: + remaining_words.append(word) continue remaining_words.append(word) diff --git a/egs/swbd/ASR/zipformer/sclite_scoring.py b/egs/swbd/ASR/zipformer/sclite_scoring.py index 239204221e..0383c4d71b 100755 --- a/egs/swbd/ASR/zipformer/sclite_scoring.py +++ b/egs/swbd/ASR/zipformer/sclite_scoring.py @@ -81,6 +81,12 @@ def asr_text_post_processing(text: str) -> str: remaining_words.append("H") remaining_words.append("D") continue + elif word == "UCLA": + remaining_words.append("U") + remaining_words.append("C") + remaining_words.append("L") + remaining_words.append("A") + continue elif word == "ONTO": remaining_words.append("ON") remaining_words.append("TO") diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index c83c56a53b..f7712628b3 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -76,7 +76,9 @@ def save_checkpoint( if isinstance(model, DDP): model = model.module + # import pdb + # pdb.set_trace() checkpoint = { "model": model.state_dict(), "optimizer": optimizer.state_dict() if optimizer is not None else None, From bd053ca90aa67ddc482709722de0971014d3f635 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Fri, 15 Sep 2023 11:00:20 +0800 Subject: [PATCH 42/60] minor fix --- icefall/checkpoint.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index f7712628b3..ea50b32a9c 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -76,9 +76,7 @@ def save_checkpoint( if isinstance(model, DDP): model = model.module - # import pdb - - # pdb.set_trace() + checkpoint = { "model": model.state_dict(), "optimizer": optimizer.state_dict() if optimizer is not None else None, From c3dc8f005b72eef6f0283e07450d05a222b153a3 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Fri, 15 Sep 2023 11:02:20 +0800 Subject: [PATCH 43/60] minor fix --- icefall/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index ea50b32a9c..c83c56a53b 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -76,7 +76,7 @@ def save_checkpoint( if isinstance(model, DDP): model = model.module - + checkpoint = { "model": model.state_dict(), "optimizer": optimizer.state_dict() if optimizer is not None else None, From e13497c907580dd6fd4ae99b90d5d24a6de4a830 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Tue, 26 Sep 2023 16:43:30 +0800 Subject: [PATCH 44/60] removed `zipformer` recipe for now --- egs/swbd/ASR/zipformer/.gitignore | 1 - egs/swbd/ASR/zipformer/asr_datamodule.py | 1 - egs/swbd/ASR/zipformer/beam_search.py | 1 - egs/swbd/ASR/zipformer/ctc_decode.py | 847 ------- egs/swbd/ASR/zipformer/decode.py | 1045 -------- egs/swbd/ASR/zipformer/decode_stream.py | 148 -- egs/swbd/ASR/zipformer/decoder.py | 122 - egs/swbd/ASR/zipformer/encoder_interface.py | 1 - .../ASR/zipformer/export-onnx-streaming.py | 774 ------ egs/swbd/ASR/zipformer/export-onnx.py | 619 ----- egs/swbd/ASR/zipformer/export.py | 526 ---- .../ASR/zipformer/generate_averaged_model.py | 193 -- egs/swbd/ASR/zipformer/jit_pretrained.py | 280 -- egs/swbd/ASR/zipformer/jit_pretrained_ctc.py | 436 ---- .../ASR/zipformer/jit_pretrained_streaming.py | 273 -- egs/swbd/ASR/zipformer/joiner.py | 66 - egs/swbd/ASR/zipformer/model.py | 358 --- egs/swbd/ASR/zipformer/onnx_check.py | 240 -- egs/swbd/ASR/zipformer/onnx_decode.py | 323 --- .../zipformer/onnx_pretrained-streaming.py | 543 ---- egs/swbd/ASR/zipformer/onnx_pretrained.py | 418 --- egs/swbd/ASR/zipformer/optim.py | 1173 --------- egs/swbd/ASR/zipformer/pretrained.py | 381 --- egs/swbd/ASR/zipformer/pretrained_ctc.py | 455 ---- egs/swbd/ASR/zipformer/profile.py | 176 -- egs/swbd/ASR/zipformer/scaling.py | 1840 -------------- egs/swbd/ASR/zipformer/scaling_converter.py | 104 - egs/swbd/ASR/zipformer/sclite_scoring.py | 148 -- .../ASR/zipformer/streaming_beam_search.py | 295 --- egs/swbd/ASR/zipformer/streaming_decode.py | 876 ------- egs/swbd/ASR/zipformer/subsampling.py | 410 --- egs/swbd/ASR/zipformer/test_scaling.py | 82 - egs/swbd/ASR/zipformer/test_subsampling.py | 152 -- egs/swbd/ASR/zipformer/train.py | 1383 ---------- egs/swbd/ASR/zipformer/zipformer.py | 2254 ----------------- 35 files changed, 16944 deletions(-) delete mode 100644 egs/swbd/ASR/zipformer/.gitignore delete mode 120000 egs/swbd/ASR/zipformer/asr_datamodule.py delete mode 120000 egs/swbd/ASR/zipformer/beam_search.py delete mode 100755 egs/swbd/ASR/zipformer/ctc_decode.py delete mode 100755 egs/swbd/ASR/zipformer/decode.py delete mode 100644 egs/swbd/ASR/zipformer/decode_stream.py delete mode 100644 egs/swbd/ASR/zipformer/decoder.py delete mode 120000 egs/swbd/ASR/zipformer/encoder_interface.py delete mode 100755 egs/swbd/ASR/zipformer/export-onnx-streaming.py delete mode 100755 egs/swbd/ASR/zipformer/export-onnx.py delete mode 100755 egs/swbd/ASR/zipformer/export.py delete mode 100755 egs/swbd/ASR/zipformer/generate_averaged_model.py delete mode 100755 egs/swbd/ASR/zipformer/jit_pretrained.py delete mode 100755 egs/swbd/ASR/zipformer/jit_pretrained_ctc.py delete mode 100755 egs/swbd/ASR/zipformer/jit_pretrained_streaming.py delete mode 100644 egs/swbd/ASR/zipformer/joiner.py delete mode 100644 egs/swbd/ASR/zipformer/model.py delete mode 100755 egs/swbd/ASR/zipformer/onnx_check.py delete mode 100755 egs/swbd/ASR/zipformer/onnx_decode.py delete mode 100755 egs/swbd/ASR/zipformer/onnx_pretrained-streaming.py delete mode 100755 egs/swbd/ASR/zipformer/onnx_pretrained.py delete mode 100644 egs/swbd/ASR/zipformer/optim.py delete mode 100755 egs/swbd/ASR/zipformer/pretrained.py delete mode 100755 egs/swbd/ASR/zipformer/pretrained_ctc.py delete mode 100755 egs/swbd/ASR/zipformer/profile.py delete mode 100644 egs/swbd/ASR/zipformer/scaling.py delete mode 100644 egs/swbd/ASR/zipformer/scaling_converter.py delete mode 100755 egs/swbd/ASR/zipformer/sclite_scoring.py delete mode 100644 egs/swbd/ASR/zipformer/streaming_beam_search.py delete mode 100755 egs/swbd/ASR/zipformer/streaming_decode.py delete mode 100644 egs/swbd/ASR/zipformer/subsampling.py delete mode 100755 egs/swbd/ASR/zipformer/test_scaling.py delete mode 100755 egs/swbd/ASR/zipformer/test_subsampling.py delete mode 100755 egs/swbd/ASR/zipformer/train.py delete mode 100644 egs/swbd/ASR/zipformer/zipformer.py diff --git a/egs/swbd/ASR/zipformer/.gitignore b/egs/swbd/ASR/zipformer/.gitignore deleted file mode 100644 index e47ac15828..0000000000 --- a/egs/swbd/ASR/zipformer/.gitignore +++ /dev/null @@ -1 +0,0 @@ -swoosh.pdf diff --git a/egs/swbd/ASR/zipformer/asr_datamodule.py b/egs/swbd/ASR/zipformer/asr_datamodule.py deleted file mode 120000 index a73848de92..0000000000 --- a/egs/swbd/ASR/zipformer/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/swbd/ASR/zipformer/beam_search.py b/egs/swbd/ASR/zipformer/beam_search.py deleted file mode 120000 index e24eca39f2..0000000000 --- a/egs/swbd/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/swbd/ASR/zipformer/ctc_decode.py b/egs/swbd/ASR/zipformer/ctc_decode.py deleted file mode 100755 index 4db50b9813..0000000000 --- a/egs/swbd/ASR/zipformer/ctc_decode.py +++ /dev/null @@ -1,847 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Liyong Guo, -# Quandong Wang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -(1) ctc-decoding -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --max-duration 600 \ - --decoding-method ctc-decoding - -(2) 1best -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --max-duration 600 \ - --hlg-scale 0.6 \ - --decoding-method 1best - -(3) nbest -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --max-duration 600 \ - --hlg-scale 0.6 \ - --decoding-method nbest - -(4) nbest-rescoring -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --max-duration 600 \ - --hlg-scale 0.6 \ - --nbest-scale 1.0 \ - --lm-dir data/lm \ - --decoding-method nbest-rescoring - -(5) whole-lattice-rescoring -./zipformer/ctc_decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --max-duration 600 \ - --hlg-scale 0.6 \ - --nbest-scale 1.0 \ - --lm-dir data/lm \ - --decoding-method whole-lattice-rescoring -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from train import add_model_arguments, get_params, get_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.decode import ( - get_lattice, - nbest_decoding, - nbest_oracle, - one_best_decoding, - rescore_with_n_best_list, - rescore_with_whole_lattice, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="ctc-decoding", - help="""Decoding method. - Supported values are: - - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece - model, i.e., lang_dir/bpe.model, to convert word pieces to words. - It needs neither a lexicon nor an n-gram LM. - - (2) 1best. Extract the best path from the decoding lattice as the - decoding result. - - (3) nbest. Extract n paths from the decoding lattice; the path - with the highest score is the decoding result. - - (4) nbest-rescoring. Extract n paths from the decoding lattice, - rescore them with an n-gram LM (e.g., a 4-gram LM), the path with - the highest score is the decoding result. - - (5) whole-lattice-rescoring. Rescore the decoding lattice with an - n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice - is the decoding result. - you have trained an RNN LM using ./rnn_lm/train.py - - (6) nbest-oracle. Its WER is the lower bound of any n-best - rescoring method can achieve. Useful for debugging n-best - rescoring method. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help="""Number of paths for n-best based decoding method. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, and nbest-oracle - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=1.0, - help="""The scale to be applied to `lattice.scores`. - It's needed if you use any kinds of n-best based rescoring. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, and nbest-oracle - A smaller value results in more unique paths. - """, - ) - - parser.add_argument( - "--hlg-scale", - type=float, - default=0.6, - help="""The scale to be applied to `hlg.scores`. - """, - ) - - parser.add_argument( - "--lm-dir", - type=str, - default="data/lm", - help="""The n-gram LM dir. - It should contain either G_4_gram.pt or G_4_gram.fst.txt - """, - ) - - add_model_arguments(parser) - - return parser - - -def get_decoding_params() -> AttributeDict: - """Parameters for decoding.""" - params = AttributeDict( - { - "frame_shift_ms": 10, - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - batch: dict, - word_table: k2.SymbolTable, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - key: It indicates the setting used for decoding. For example, - if no rescoring is used, the key is the string `no_rescore`. - If LM rescoring is used, the key is the string `lm_scale_xxx`, - where `xxx` is the value of `lm_scale`. An example key is - `lm_scale_0.7` - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - - Args: - params: - It's the return value of :func:`get_params`. - - - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. - - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. - - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. - - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM - rescoring. - - model: - The neural model. - HLG: - The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.decoding_method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.decoding_method is ctc-decoding. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - G: - An LM. It is not None when params.decoding_method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return the decoding result. See above description for the format of - the returned dict. Note: If it decodes to nothing, then return None. - """ - if HLG is not None: - device = HLG.device - else: - device = H.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - ctc_output = model.ctc_output(encoder_out) # (N, T, C) - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - torch.div( - supervisions["start_frame"], - params.subsampling_factor, - rounding_mode="floor", - ), - torch.div( - supervisions["num_frames"], - params.subsampling_factor, - rounding_mode="floor", - ), - ), - 1, - ).to(torch.int32) - - if H is None: - assert HLG is not None - decoding_graph = HLG - else: - assert HLG is None - assert bpe_model is not None - decoding_graph = H - - lattice = get_lattice( - nnet_output=ctc_output, - decoding_graph=decoding_graph, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.decoding_method == "ctc-decoding": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - token_ids = get_texts(best_path) - - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(token_ids) - - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] - hyps = [s.split() for s in hyps] - key = "ctc-decoding" - return {key: hyps} - - if params.decoding_method == "nbest-oracle": - # Note: You can also pass rescored lattices to it. - # We choose the HLG decoded lattice for speed reasons - # as HLG decoding is faster and the oracle WER - # is only slightly worse than that of rescored lattices. - best_path = nbest_oracle( - lattice=lattice, - num_paths=params.num_paths, - ref_texts=supervisions["text"], - word_table=word_table, - nbest_scale=params.nbest_scale, - oov="", - ) - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa - return {key: hyps} - - if params.decoding_method in ["1best", "nbest"]: - if params.decoding_method == "1best": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - key = "no_rescore" - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - nbest_scale=params.nbest_scale, - ) - key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa - - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - return {key: hyps} - - assert params.decoding_method in [ - "nbest-rescoring", - "whole-lattice-rescoring", - ] - - lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] - lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] - - if params.decoding_method == "nbest-rescoring": - best_path_dict = rescore_with_n_best_list( - lattice=lattice, - G=G, - num_paths=params.num_paths, - lm_scale_list=lm_scale_list, - nbest_scale=params.nbest_scale, - ) - elif params.decoding_method == "whole-lattice-rescoring": - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=lm_scale_list, - ) - else: - assert False, f"Unsupported decoding method: {params.decoding_method}" - - ans = dict() - if best_path_dict is not None: - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps - else: - ans = None - return ans - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - word_table: k2.SymbolTable, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.decoding_method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.decoding_method is ctc-decoding. - word_table: - It is the word symbol table. - G: - An LM. It is not None when params.decoding_method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return a dict, whose key may be "no-rescore" if no LM rescoring - is used, or it may be "lm_scale_0.7" if LM rescoring is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - batch=batch, - word_table=word_table, - G=G, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - args.lm_dir = Path(args.lm_dir) - - params = get_params() - # add decoding params - params.update(get_decoding_params()) - params.update(vars(args)) - - assert params.decoding_method in ( - "ctc-decoding", - "1best", - "nbest", - "nbest-rescoring", - "whole-lattice-rescoring", - "nbest-oracle", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - params.vocab_size = num_classes - # and are defined in local/train_bpe_model.py - params.blank_id = 0 - - if params.decoding_method == "ctc-decoding": - HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(str(params.lang_dir / "bpe.model")) - else: - H = None - bpe_model = None - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) - ) - assert HLG.requires_grad is False - - HLG.scores *= params.hlg_scale - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() - - if params.decoding_method in ( - "nbest-rescoring", - "whole-lattice-rescoring", - ): - if not (params.lm_dir / "G_4_gram.pt").is_file(): - logging.info("Loading G_4_gram.fst.txt") - logging.warning("It may take 8 minutes.") - with open(params.lm_dir / "G_4_gram.fst.txt") as f: - first_word_disambig_id = lexicon.word_table["#0"] - - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - # G.aux_labels is not needed in later computations, so - # remove it here. - del G.aux_labels - # CAUTION: The following line is crucial. - # Arcs entering the back-off state have label equal to #0. - # We have to change it to 0 here. - G.labels[G.labels >= first_word_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set G.properties to None - G.__dict__["_properties"] = None - G = k2.Fsa.from_fsas([G]).to(device) - G = k2.arc_sort(G) - # Save a dummy value so that it can be loaded in C++. - # See https://github.com/pytorch/pytorch/issues/67902 - # for why we need to do this. - G.dummy = 1 - - torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") - else: - logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) - G = k2.Fsa.from_dict(d) - - if params.decoding_method == "whole-lattice-rescoring": - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G = G.to(device) - - # G.lm_scores is used to replace HLG.lm_scores during - # LM rescoring. - G.lm_scores = G.scores.clone() - else: - G = None - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - word_table=lexicon.word_table, - G=G, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/swbd/ASR/zipformer/decode.py b/egs/swbd/ASR/zipformer/decode.py deleted file mode 100755 index 10de28dfcc..0000000000 --- a/egs/swbd/ASR/zipformer/decode.py +++ /dev/null @@ -1,1045 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import SwitchBoardAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, - modified_beam_search_lm_rescore, - modified_beam_search_lm_rescore_LODR, - modified_beam_search_lm_shallow_fusion, - modified_beam_search_LODR, -) -from train import add_model_arguments, get_model, get_params -from sclite_scoring import asr_text_post_processing - - -from lhotse.cut import Cut -from icefall import LmScorer, NgramLm -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--use-shallow-fusion", - type=str2bool, - default=False, - help="""Use neural network LM for shallow fusion. - If you want to use LODR, you will also need to set this to true - """, - ) - - parser.add_argument( - "--lm-type", - type=str, - default="rnn", - help="Type of NN lm", - choices=["rnn", "transformer"], - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.3, - help="""The scale of the neural network LM - Used only when `--use-shallow-fusion` is set to True. - """, - ) - - parser.add_argument( - "--tokens-ngram", - type=int, - default=2, - help="""The order of the ngram lm. - """, - ) - - parser.add_argument( - "--backoff-id", - type=int, - default=500, - help="ID of the backoff symbol in the ngram LM", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - LM: - A neural network language model. - ngram_lm: - A ngram language model - ngram_lm_scale: - The scale for the ngram language model. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": - hyp_tokens = modified_beam_search_lm_shallow_fusion( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_LODR": - hyp_tokens = modified_beam_search_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LODR_lm=ngram_lm, - LODR_lm_scale=ngram_lm_scale, - LM=LM, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_lm_rescore": - lm_scale_list = [0.01 * i for i in range(10, 50)] - ans_dict = modified_beam_search_lm_rescore( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - lm_scale_list=lm_scale_list, - ) - elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": - lm_scale_list = [0.02 * i for i in range(2, 30)] - ans_dict = modified_beam_search_lm_rescore_LODR( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - LM=LM, - LODR_lm=ngram_lm, - sp=sp, - lm_scale_list=lm_scale_list, - ) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - elif params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ): - ans = dict() - assert ans_dict is not None - for key, hyps in ans_dict.items(): - hyps = [sp.decode(hyp).split() for hyp in hyps] - ans[f"beam_size_{params.beam_size}_{key}"] = hyps - return ans - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - if test_set_name == "test-eval2000": - subsets = {"callhome": "en_", "swbd": "sw_", "avg": "*"} - elif test_set_name == "test-rt03": - subsets = {"fisher": "fsh_", "swbd": "sw_", "avg": "*"} - else: - raise NotImplementedError(f"No implementation for testset {test_set_name}") - for subset, prefix in subsets.items(): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir - / f"recogs-{test_set_name}-{subset}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - results = ( - sorted(list(filter(lambda x: x[0].startswith(prefix), results))) - if subset != "avg" - else sorted(results) - ) - results = post_processing(results) - - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir - / f"errs-{test_set_name}-{subset}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, - f"{test_set_name}-{subset}-{key}", - results, - enable_log=True, - sclite_mode=True, - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{subset}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - SwitchBoardAsrDataModule.add_arguments(parser) - LmScorer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - "modified_beam_search_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_shallow_fusion: - params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" - - if "LODR" in params.decoding_method: - params.suffix += ( - f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" - ) - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - # only load the neural network LM if required - if params.use_shallow_fusion or params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_LODR", - ): - LM = LmScorer( - lm_type=params.lm_type, - params=params, - device=device, - lm_scale=params.lm_scale, - ) - LM.to(device) - LM.eval() - else: - LM = None - - # only load N-gram LM when needed - if params.decoding_method == "modified_beam_search_lm_rescore_LODR": - try: - import kenlm - except ImportError: - print("Please install kenlm first. You can use") - print(" pip install https://github.com/kpu/kenlm/archive/master.zip") - print("to install it") - import sys - - sys.exit(-1) - ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") - logging.info(f"lm filename: {ngram_file_name}") - ngram_lm = kenlm.Model(ngram_file_name) - ngram_lm_scale = None # use a list to search - - elif params.decoding_method == "modified_beam_search_LODR": - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"Loading token level lm: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") - ngram_lm_scale = params.ngram_lm_scale - else: - ngram_lm = None - ngram_lm_scale = None - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." - ) - return T > 0 - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - switchboard = SwitchBoardAsrDataModule(args) - - test_eval2000_cuts = ( - switchboard.test_eval2000_cuts() - .trim_to_supervisions(keep_all_channels=True) - .filter(remove_short_utt) - ) - # test_rt03_cuts = switchboard.test_rt03_cuts().trim_to_supervisions( - # keep_all_channels=True - # ) - - test_eval2000_dl = switchboard.test_dataloaders(test_eval2000_cuts) - # test_rt03_dl = switchboard.test_dataloaders(test_rt03_cuts) - - # test_sets = ["test-eval2000", "test-rt03"] - # test_dl = [test_eval2000_dl, test_rt03_dl] - test_dl = [test_eval2000_dl] - test_sets = ["test-eval2000"] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -def post_processing( - results: List[Tuple[str, List[str], List[str]]], -) -> List[Tuple[str, List[str], List[str]]]: - new_results = [] - for key, ref, hyp in results: - new_ref = asr_text_post_processing(" ".join(ref)).split() - new_hyp = asr_text_post_processing(" ".join(hyp)).split() - new_results.append((key, new_ref, new_hyp)) - return new_results - - -if __name__ == "__main__": - main() diff --git a/egs/swbd/ASR/zipformer/decode_stream.py b/egs/swbd/ASR/zipformer/decode_stream.py deleted file mode 100644 index d6918bf328..0000000000 --- a/egs/swbd/ASR/zipformer/decode_stream.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Wei Kang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import List, Optional, Tuple - -import k2 -import torch -from beam_search import Hypothesis, HypothesisList - -from icefall.utils import AttributeDict - - -class DecodeStream(object): - def __init__( - self, - params: AttributeDict, - cut_id: str, - initial_states: List[torch.Tensor], - decoding_graph: Optional[k2.Fsa] = None, - device: torch.device = torch.device("cpu"), - ) -> None: - """ - Args: - initial_states: - Initial decode states of the model, e.g. the return value of - `get_init_state` in conformer.py - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. - Used only when decoding_method is fast_beam_search. - device: - The device to run this stream. - """ - if params.decoding_method == "fast_beam_search": - assert decoding_graph is not None - assert device == decoding_graph.device - - self.params = params - self.cut_id = cut_id - self.LOG_EPS = math.log(1e-10) - - self.states = initial_states - - # It contains a 2-D tensors representing the feature frames. - self.features: torch.Tensor = None - - self.num_frames: int = 0 - # how many frames have been processed. (before subsampling). - # we only modify this value in `func:get_feature_frames`. - self.num_processed_frames: int = 0 - - self._done: bool = False - - # The transcript of current utterance. - self.ground_truth: str = "" - - # The decoding result (partial or final) of current utterance. - self.hyp: List = [] - - # how many frames have been processed, at encoder output - self.done_frames: int = 0 - - # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - self.pad_length = 7 + 2 * 3 - - if params.decoding_method == "greedy_search": - self.hyp = [-1] * (params.context_size - 1) + [params.blank_id] - elif params.decoding_method == "modified_beam_search": - self.hyps = HypothesisList() - self.hyps.add( - Hypothesis( - ys=[-1] * (params.context_size - 1) + [params.blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - elif params.decoding_method == "fast_beam_search": - # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - @property - def done(self) -> bool: - """Return True if all the features are processed.""" - return self._done - - @property - def id(self) -> str: - return self.cut_id - - def set_features( - self, - features: torch.Tensor, - tail_pad_len: int = 0, - ) -> None: - """Set features tensor of current utterance.""" - assert features.dim() == 2, features.dim() - self.features = torch.nn.functional.pad( - features, - (0, 0, 0, self.pad_length + tail_pad_len), - mode="constant", - value=self.LOG_EPS, - ) - self.num_frames = self.features.size(0) - - def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: - """Consume chunk_size frames of features""" - chunk_length = chunk_size + self.pad_length - - ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) - - ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames + ret_length # noqa - ] - - self.num_processed_frames += chunk_size - if self.num_processed_frames >= self.num_frames: - self._done = True - - return ret_features, ret_length - - def decoding_result(self) -> List[int]: - """Obtain current decoding result.""" - if self.params.decoding_method == "greedy_search": - return self.hyp[self.params.context_size :] # noqa - elif self.params.decoding_method == "modified_beam_search": - best_hyp = self.hyps.get_most_probable(length_norm=True) - return best_hyp.ys[self.params.context_size :] # noqa - else: - assert self.params.decoding_method == "fast_beam_search" - return self.hyp diff --git a/egs/swbd/ASR/zipformer/decoder.py b/egs/swbd/ASR/zipformer/decoder.py deleted file mode 100644 index e8db988f6e..0000000000 --- a/egs/swbd/ASR/zipformer/decoder.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from scaling import Balancer - - -class Decoder(nn.Module): - """This class modifies the stateless decoder from the following paper: - - RNN-transducer with stateless prediction network - https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 - - It removes the recurrent connection from the decoder, i.e., the prediction - network. Different from the above paper, it adds an extra Conv1d - right after the embedding layer. - - TODO: Implement https://arxiv.org/pdf/2109.07513.pdf - """ - - def __init__( - self, - vocab_size: int, - decoder_dim: int, - blank_id: int, - context_size: int, - ): - """ - Args: - vocab_size: - Number of tokens of the modeling unit including blank. - decoder_dim: - Dimension of the input embedding, and of the decoder output. - blank_id: - The ID of the blank symbol. - context_size: - Number of previous words to use to predict the next word. - 1 means bigram; 2 means trigram. n means (n+1)-gram. - """ - super().__init__() - - self.embedding = nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=decoder_dim, - ) - # the balancers are to avoid any drift in the magnitude of the - # embeddings, which would interact badly with parameter averaging. - self.balancer = Balancer(decoder_dim, channel_dim=-1, - min_positive=0.0, max_positive=1.0, - min_abs=0.5, max_abs=1.0, - prob=0.05) - - self.blank_id = blank_id - - assert context_size >= 1, context_size - self.context_size = context_size - self.vocab_size = vocab_size - - if context_size > 1: - self.conv = nn.Conv1d( - in_channels=decoder_dim, - out_channels=decoder_dim, - kernel_size=context_size, - padding=0, - groups=decoder_dim // 4, # group size == 4 - bias=False, - ) - self.balancer2 = Balancer(decoder_dim, channel_dim=-1, - min_positive=0.0, max_positive=1.0, - min_abs=0.5, max_abs=1.0, - prob=0.05) - - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, U). - need_pad: - True to left pad the input. Should be True during training. - False to not pad the input. Should be False during inference. - Returns: - Return a tensor of shape (N, U, decoder_dim). - """ - y = y.to(torch.int64) - # this stuff about clamp() is a temporary fix for a mismatch - # at utterance start, we use negative ids in beam_search.py - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) - - embedding_out = self.balancer(embedding_out) - - if self.context_size > 1: - embedding_out = embedding_out.permute(0, 2, 1) - if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) - else: - # During inference time, there is no need to do extra padding - # as we only need one output - assert embedding_out.size(-1) == self.context_size - embedding_out = self.conv(embedding_out) - embedding_out = embedding_out.permute(0, 2, 1) - embedding_out = F.relu(embedding_out) - embedding_out = self.balancer2(embedding_out) - - return embedding_out diff --git a/egs/swbd/ASR/zipformer/encoder_interface.py b/egs/swbd/ASR/zipformer/encoder_interface.py deleted file mode 120000 index 653c5b09af..0000000000 --- a/egs/swbd/ASR/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/swbd/ASR/zipformer/export-onnx-streaming.py b/egs/swbd/ASR/zipformer/export-onnx-streaming.py deleted file mode 100755 index a951aeef35..0000000000 --- a/egs/swbd/ASR/zipformer/export-onnx-streaming.py +++ /dev/null @@ -1,774 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) -# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) - -""" -This script exports a transducer model from PyTorch to ONNX. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./zipformer/export-onnx-streaming.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --num-encoder-layers "2,2,3,4,3,2" \ - --downsampling-factor "1,2,4,8,4,2" \ - --feedforward-dim "512,768,1024,1536,1024,768" \ - --num-heads "4,4,4,8,4,4" \ - --encoder-dim "192,256,384,512,384,256" \ - --query-head-dim 32 \ - --value-head-dim 12 \ - --pos-head-dim 4 \ - --pos-dim 48 \ - --encoder-unmasked-dim "192,192,256,256,256,192" \ - --cnn-module-kernel "31,31,15,15,15,31" \ - --decoder-dim 512 \ - --joiner-dim 512 \ - --causal True \ - --chunk-size 16 \ - --left-context-frames 64 - -The --chunk-size in training is "16,32,64,-1", so we select one of them -(excluding -1) during streaming export. The same applies to `--left-context`, -whose value is "64,128,256,-1". - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1-chunk-16-left-64.onnx - - decoder-epoch-99-avg-1-chunk-16-left-64.onnx - - joiner-epoch-99-avg-1-chunk-16-left-64.onnx - -See ./onnx_pretrained-streaming.py for how to use the exported ONNX models. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, List, Tuple - -import k2 -import onnx -import torch -import torch.nn as nn -from decoder import Decoder -from onnxruntime.quantization import QuantType, quantize_dynamic -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_model, get_params -from zipformer import Zipformer2 - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = value - - onnx.save(model, filename) - - -class OnnxEncoder(nn.Module): - """A wrapper for Zipformer and the encoder_proj from the joiner""" - - def __init__( - self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear - ): - """ - Args: - encoder: - A Zipformer encoder. - encoder_proj: - The projection layer for encoder from the joiner. - """ - super().__init__() - self.encoder = encoder - self.encoder_embed = encoder_embed - self.encoder_proj = encoder_proj - self.chunk_size = encoder.chunk_size[0] - self.left_context_len = encoder.left_context_frames[0] - self.pad_length = 7 + 2 * 3 - - def forward( - self, - x: torch.Tensor, - states: List[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: - N = x.size(0) - T = self.chunk_size * 2 + self.pad_length - x_lens = torch.tensor([T] * N, device=x.device) - left_context_len = self.left_context_len - - cached_embed_left_pad = states[-2] - x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( - x=x, - x_lens=x_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size) - - src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) - encoder_states = states[:-2] - logging.info(f"len_encoder_states={len(encoder_states)}") - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = self.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) - encoder_out = self.encoder_proj(encoder_out) - # Now encoder_out is of shape (N, T, joiner_dim) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - - return encoder_out, new_states - - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = self.encoder.get_init_states(batch_size, device) - - embed_states = self.encoder_embed.get_init_states(batch_size, device) - - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device) - states.append(processed_lens) - - return states - - -class OnnxDecoder(nn.Module): - """A wrapper for Decoder and the decoder_proj from the joiner""" - - def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): - super().__init__() - self.decoder = decoder - self.decoder_proj = decoder_proj - - def forward(self, y: torch.Tensor) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, context_size). - Returns - Return a 2-D tensor of shape (N, joiner_dim) - """ - need_pad = False - decoder_output = self.decoder(y, need_pad=need_pad) - decoder_output = decoder_output.squeeze(1) - output = self.decoder_proj(decoder_output) - - return output - - -class OnnxJoiner(nn.Module): - """A wrapper for the joiner""" - - def __init__(self, output_linear: nn.Linear): - super().__init__() - self.output_linear = output_linear - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) - return logit - - -def export_encoder_model_onnx( - encoder_model: OnnxEncoder, - encoder_filename: str, - opset_version: int = 11, -) -> None: - encoder_model.encoder.__class__.forward = ( - encoder_model.encoder.__class__.streaming_forward - ) - - decode_chunk_len = encoder_model.chunk_size * 2 - # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - T = decode_chunk_len + encoder_model.pad_length - - x = torch.rand(1, T, 80, dtype=torch.float32) - init_state = encoder_model.get_init_states() - num_encoders = len(encoder_model.encoder.encoder_dim) - logging.info(f"num_encoders: {num_encoders}") - logging.info(f"len(init_state): {len(init_state)}") - - inputs = {} - input_names = ["x"] - - outputs = {} - output_names = ["encoder_out"] - - def build_inputs_outputs(tensors, i): - assert len(tensors) == 6, len(tensors) - - # (downsample_left, batch_size, key_dim) - name = f"cached_key_{i}" - logging.info(f"{name}.shape: {tensors[0].shape}") - inputs[name] = {1: "N"} - outputs[f"new_{name}"] = {1: "N"} - input_names.append(name) - output_names.append(f"new_{name}") - - # (1, batch_size, downsample_left, nonlin_attn_head_dim) - name = f"cached_nonlin_attn_{i}" - logging.info(f"{name}.shape: {tensors[1].shape}") - inputs[name] = {1: "N"} - outputs[f"new_{name}"] = {1: "N"} - input_names.append(name) - output_names.append(f"new_{name}") - - # (downsample_left, batch_size, value_dim) - name = f"cached_val1_{i}" - logging.info(f"{name}.shape: {tensors[2].shape}") - inputs[name] = {1: "N"} - outputs[f"new_{name}"] = {1: "N"} - input_names.append(name) - output_names.append(f"new_{name}") - - # (downsample_left, batch_size, value_dim) - name = f"cached_val2_{i}" - logging.info(f"{name}.shape: {tensors[3].shape}") - inputs[name] = {1: "N"} - outputs[f"new_{name}"] = {1: "N"} - input_names.append(name) - output_names.append(f"new_{name}") - - # (batch_size, embed_dim, conv_left_pad) - name = f"cached_conv1_{i}" - logging.info(f"{name}.shape: {tensors[4].shape}") - inputs[name] = {0: "N"} - outputs[f"new_{name}"] = {0: "N"} - input_names.append(name) - output_names.append(f"new_{name}") - - # (batch_size, embed_dim, conv_left_pad) - name = f"cached_conv2_{i}" - logging.info(f"{name}.shape: {tensors[5].shape}") - inputs[name] = {0: "N"} - outputs[f"new_{name}"] = {0: "N"} - input_names.append(name) - output_names.append(f"new_{name}") - - num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers)) - encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim)) - cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernel)) - ds = encoder_model.encoder.downsampling_factor - left_context_len = encoder_model.left_context_len - left_context_len = [left_context_len // k for k in ds] - left_context_len = ",".join(map(str, left_context_len)) - query_head_dims = ",".join(map(str, encoder_model.encoder.query_head_dim)) - value_head_dims = ",".join(map(str, encoder_model.encoder.value_head_dim)) - num_heads = ",".join(map(str, encoder_model.encoder.num_heads)) - - meta_data = { - "model_type": "zipformer2", - "version": "1", - "model_author": "k2-fsa", - "comment": "streaming zipformer2", - "decode_chunk_len": str(decode_chunk_len), # 32 - "T": str(T), # 32+7+2*3=45 - "num_encoder_layers": num_encoder_layers, - "encoder_dims": encoder_dims, - "cnn_module_kernels": cnn_module_kernels, - "left_context_len": left_context_len, - "query_head_dims": query_head_dims, - "value_head_dims": value_head_dims, - "num_heads": num_heads, - } - logging.info(f"meta_data: {meta_data}") - - for i in range(len(init_state[:-2]) // 6): - build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i) - - # (batch_size, channels, left_pad, freq) - embed_states = init_state[-2] - name = "embed_states" - logging.info(f"{name}.shape: {embed_states.shape}") - inputs[name] = {0: "N"} - outputs[f"new_{name}"] = {0: "N"} - input_names.append(name) - output_names.append(f"new_{name}") - - # (batch_size,) - processed_lens = init_state[-1] - name = "processed_lens" - logging.info(f"{name}.shape: {processed_lens.shape}") - inputs[name] = {0: "N"} - outputs[f"new_{name}"] = {0: "N"} - input_names.append(name) - output_names.append(f"new_{name}") - - logging.info(inputs) - logging.info(outputs) - logging.info(input_names) - logging.info(output_names) - - torch.onnx.export( - encoder_model, - (x, init_state), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=input_names, - output_names=output_names, - dynamic_axes={ - "x": {0: "N"}, - "encoder_out": {0: "N"}, - **inputs, - **outputs, - }, - ) - - add_meta_data(filename=encoder_filename, meta_data=meta_data) - - -def export_decoder_model_onnx( - decoder_model: OnnxDecoder, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - context_size = decoder_model.decoder.context_size - vocab_size = decoder_model.decoder.vocab_size - - y = torch.zeros(10, context_size, dtype=torch.int64) - torch.onnx.export( - decoder_model, - y, - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - - meta_data = { - "context_size": str(context_size), - "vocab_size": str(vocab_size), - } - add_meta_data(filename=decoder_filename, meta_data=meta_data) - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. - The exported joiner model has two inputs: - - - encoder_out: a tensor of shape (N, joiner_dim) - - decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - """ - joiner_dim = joiner_model.output_linear.weight.shape[1] - logging.info(f"joiner dim: {joiner_dim}") - - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - meta_data = { - "joiner_dim": str(joiner_dim), - } - add_meta_data(filename=joiner_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True) - - encoder = OnnxEncoder( - encoder=model.encoder, - encoder_embed=model.encoder_embed, - encoder_proj=model.joiner.encoder_proj, - ) - - decoder = OnnxDecoder( - decoder=model.decoder, - decoder_proj=model.joiner.decoder_proj, - ) - - joiner = OnnxJoiner(output_linear=model.joiner.output_linear) - - encoder_num_param = sum([p.numel() for p in encoder.parameters()]) - decoder_num_param = sum([p.numel() for p in decoder.parameters()]) - joiner_num_param = sum([p.numel() for p in joiner.parameters()]) - total_num_param = encoder_num_param + decoder_num_param + joiner_num_param - logging.info(f"encoder parameters: {encoder_num_param}") - logging.info(f"decoder parameters: {decoder_num_param}") - logging.info(f"joiner parameters: {joiner_num_param}") - logging.info(f"total parameters: {total_num_param}") - - if params.iter > 0: - suffix = f"iter-{params.iter}" - else: - suffix = f"epoch-{params.epoch}" - - suffix += f"-avg-{params.avg}" - suffix += f"-chunk-{params.chunk_size}" - suffix += f"-left-{params.left_context_frames}" - - opset_version = 13 - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" - export_encoder_model_onnx( - encoder, - encoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported encoder to {encoder_filename}") - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" - export_decoder_model_onnx( - decoder, - decoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported decoder to {decoder_filename}") - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" - export_joiner_model_onnx( - joiner, - joiner_filename, - opset_version=opset_version, - ) - logging.info(f"Exported joiner to {joiner_filename}") - - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=encoder_filename, - model_output=encoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=decoder_filename, - model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul", "Gather"], - weight_type=QuantType.QInt8, - ) - - joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" - quantize_dynamic( - model_input=joiner_filename, - model_output=joiner_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/swbd/ASR/zipformer/export-onnx.py b/egs/swbd/ASR/zipformer/export-onnx.py deleted file mode 100755 index e0d664009c..0000000000 --- a/egs/swbd/ASR/zipformer/export-onnx.py +++ /dev/null @@ -1,619 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) -# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) - -""" -This script exports a transducer model from PyTorch to ONNX. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./zipformer/export-onnx.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --num-encoder-layers "2,2,3,4,3,2" \ - --downsampling-factor "1,2,4,8,4,2" \ - --feedforward-dim "512,768,1024,1536,1024,768" \ - --num-heads "4,4,4,8,4,4" \ - --encoder-dim "192,256,384,512,384,256" \ - --query-head-dim 32 \ - --value-head-dim 12 \ - --pos-head-dim 4 \ - --pos-dim 48 \ - --encoder-unmasked-dim "192,192,256,256,256,192" \ - --cnn-module-kernel "31,31,15,15,15,31" \ - --decoder-dim 512 \ - --joiner-dim 512 \ - --causal False \ - --chunk-size "16,32,64,-1" \ - --left-context-frames "64,128,256,-1" - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -See ./onnx_pretrained.py and ./onnx_check.py for how to -use the exported ONNX models. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, Tuple - -import k2 -import onnx -import torch -import torch.nn as nn -from decoder import Decoder -from onnxruntime.quantization import QuantType, quantize_dynamic -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_model, get_params -from zipformer import Zipformer2 - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import make_pad_mask, num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = value - - onnx.save(model, filename) - - -class OnnxEncoder(nn.Module): - """A wrapper for Zipformer and the encoder_proj from the joiner""" - - def __init__( - self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear - ): - """ - Args: - encoder: - A Zipformer encoder. - encoder_proj: - The projection layer for encoder from the joiner. - """ - super().__init__() - self.encoder = encoder - self.encoder_embed = encoder_embed - self.encoder_proj = encoder_proj - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Please see the help information of Zipformer.forward - - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 1-D tensor of shape (N,). Its dtype is torch.int64 - Returns: - Return a tuple containing: - - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) - - encoder_out_lens, A 1-D tensor of shape (N,) - """ - x, x_lens = self.encoder_embed(x, x_lens) - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) - encoder_out = self.encoder_proj(encoder_out) - # Now encoder_out is of shape (N, T, joiner_dim) - - return encoder_out, encoder_out_lens - - -class OnnxDecoder(nn.Module): - """A wrapper for Decoder and the decoder_proj from the joiner""" - - def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): - super().__init__() - self.decoder = decoder - self.decoder_proj = decoder_proj - - def forward(self, y: torch.Tensor) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, context_size). - Returns - Return a 2-D tensor of shape (N, joiner_dim) - """ - need_pad = False - decoder_output = self.decoder(y, need_pad=need_pad) - decoder_output = decoder_output.squeeze(1) - output = self.decoder_proj(decoder_output) - - return output - - -class OnnxJoiner(nn.Module): - """A wrapper for the joiner""" - - def __init__(self, output_linear: nn.Linear): - super().__init__() - self.output_linear = output_linear - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) - return logit - - -def export_encoder_model_onnx( - encoder_model: OnnxEncoder, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T', joiner_dim) - - encoder_out_lens, a tensor of shape (N,) - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) - - encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) - - torch.onnx.export( - encoder_model, - (x, x_lens), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens"], - output_names=["encoder_out", "encoder_out_lens"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - }, - ) - - meta_data = { - "model_type": "zipformer2", - "version": "1", - "model_author": "k2-fsa", - "comment": "non-streaming zipformer2", - } - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=encoder_filename, meta_data=meta_data) - - -def export_decoder_model_onnx( - decoder_model: OnnxDecoder, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - context_size = decoder_model.decoder.context_size - vocab_size = decoder_model.decoder.vocab_size - - y = torch.zeros(10, context_size, dtype=torch.int64) - torch.onnx.export( - decoder_model, - y, - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - - meta_data = { - "context_size": str(context_size), - "vocab_size": str(vocab_size), - } - add_meta_data(filename=decoder_filename, meta_data=meta_data) - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. - The exported joiner model has two inputs: - - - encoder_out: a tensor of shape (N, joiner_dim) - - decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - """ - joiner_dim = joiner_model.output_linear.weight.shape[1] - logging.info(f"joiner dim: {joiner_dim}") - - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - meta_data = { - "joiner_dim": str(joiner_dim), - } - add_meta_data(filename=joiner_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) - - encoder = OnnxEncoder( - encoder=model.encoder, - encoder_embed=model.encoder_embed, - encoder_proj=model.joiner.encoder_proj, - ) - - decoder = OnnxDecoder( - decoder=model.decoder, - decoder_proj=model.joiner.decoder_proj, - ) - - joiner = OnnxJoiner(output_linear=model.joiner.output_linear) - - encoder_num_param = sum([p.numel() for p in encoder.parameters()]) - decoder_num_param = sum([p.numel() for p in decoder.parameters()]) - joiner_num_param = sum([p.numel() for p in joiner.parameters()]) - total_num_param = encoder_num_param + decoder_num_param + joiner_num_param - logging.info(f"encoder parameters: {encoder_num_param}") - logging.info(f"decoder parameters: {decoder_num_param}") - logging.info(f"joiner parameters: {joiner_num_param}") - logging.info(f"total parameters: {total_num_param}") - - if params.iter > 0: - suffix = f"iter-{params.iter}" - else: - suffix = f"epoch-{params.epoch}" - - suffix += f"-avg-{params.avg}" - - opset_version = 13 - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" - export_encoder_model_onnx( - encoder, - encoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported encoder to {encoder_filename}") - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" - export_decoder_model_onnx( - decoder, - decoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported decoder to {decoder_filename}") - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" - export_joiner_model_onnx( - joiner, - joiner_filename, - opset_version=opset_version, - ) - logging.info(f"Exported joiner to {joiner_filename}") - - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=encoder_filename, - model_output=encoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=decoder_filename, - model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul", "Gather"], - weight_type=QuantType.QInt8, - ) - - joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" - quantize_dynamic( - model_input=joiner_filename, - model_output=joiner_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/swbd/ASR/zipformer/export.py b/egs/swbd/ASR/zipformer/export.py deleted file mode 100755 index 2b8d1aaf36..0000000000 --- a/egs/swbd/ASR/zipformer/export.py +++ /dev/null @@ -1,526 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -Note: This is a example for librispeech dataset, if you are using different -dataset, you should change the argument values according to your dataset. - -(1) Export to torchscript model using torch.jit.script() - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `jit_script.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("jit_script.pt")`. - -Check ./jit_pretrained.py for its usage. - -Check https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. -You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. - -Check ./jit_pretrained_streaming.py for its usage. - -Check https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -- For non-streaming model: - -To use the generated file with `zipformer/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./zipformer/decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - -- For streaming model: - -To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - - # simulated streaming decoding - ./zipformer/decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - - # chunk-wise streaming decoding - ./zipformer/streaming_decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -- non-streaming model: -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 - -- streaming model: -https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 - git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 - # You will find the pre-trained models in exp dir -""" - -import argparse -import logging -from pathlib import Path -from typing import List, Tuple - -import k2 -import torch -from scaling_converter import convert_scaled_to_non_scaled -from torch import Tensor, nn -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import make_pad_mask, num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named jit_script.pt. - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -class EncoderModel(nn.Module): - """A wrapper for encoder and encoder_embed""" - - def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: - super().__init__() - self.encoder = encoder - self.encoder_embed = encoder_embed - - def forward( - self, features: Tensor, feature_lengths: Tensor - ) -> Tuple[Tensor, Tensor]: - """ - Args: - features: (N, T, C) - feature_lengths: (N,) - """ - x, x_lens = self.encoder_embed(features, feature_lengths) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return encoder_out, encoder_out_lens - - -class StreamingEncoderModel(nn.Module): - """A wrapper for encoder and encoder_embed""" - - def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: - super().__init__() - assert len(encoder.chunk_size) == 1, encoder.chunk_size - assert len(encoder.left_context_frames) == 1, encoder.left_context_frames - self.chunk_size = encoder.chunk_size[0] - self.left_context_len = encoder.left_context_frames[0] - - # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - self.pad_length = 7 + 2 * 3 - - self.encoder = encoder - self.encoder_embed = encoder_embed - - def forward( - self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """Streaming forward for encoder_embed and encoder. - - Args: - features: (N, T, C) - feature_lengths: (N,) - states: a list of Tensors - - Returns encoder outputs, output lengths, and updated states. - """ - chunk_size = self.chunk_size - left_context_len = self.left_context_len - - cached_embed_left_pad = states[-2] - x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lengths, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = self.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = self.encoder.get_init_states(batch_size, device) - - embed_states = self.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - # if torch.cuda.is_available(): - # device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - - # Wrap encoder and encoder_embed as a module - if params.causal: - model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) - chunk_size = model.encoder.chunk_size - left_context_len = model.encoder.left_context_len - filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" - else: - model.encoder = EncoderModel(model.encoder, model.encoder_embed) - filename = "jit_script.pt" - - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - model.save(str(params.exp_dir / filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/swbd/ASR/zipformer/generate_averaged_model.py b/egs/swbd/ASR/zipformer/generate_averaged_model.py deleted file mode 100755 index 68111fad71..0000000000 --- a/egs/swbd/ASR/zipformer/generate_averaged_model.py +++ /dev/null @@ -1,193 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) use the checkpoint exp_dir/epoch-xxx.pt -./zipformer/generate_averaged_model.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp - -It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. -You can later load it by `torch.load("epoch-28-avg-15.pt")`. - -(2) use the checkpoint exp_dir/checkpoint-iter.pt -./zipformer/generate_averaged_model.py \ - --iter 22000 \ - --avg 5 \ - --exp-dir ./zipformer/exp - -It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. -You can later load it by `torch.load("iter-22000-avg-5.pt")`. -""" - - -import argparse -from pathlib import Path - -import k2 -import torch -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - print("Script started") - - device = torch.device("cpu") - print(f"Device: {device}") - - symbol_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = symbol_table[""] - params.unk_id = symbol_table[""] - params.vocab_size = len(symbol_table) - - print("About to create model") - model = get_model(params) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - print( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" - torch.save({"model": model.state_dict()}, filename) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - print( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save({"model": model.state_dict()}, filename) - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - print("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/swbd/ASR/zipformer/jit_pretrained.py b/egs/swbd/ASR/zipformer/jit_pretrained.py deleted file mode 100755 index a41fbc1c97..0000000000 --- a/egs/swbd/ASR/zipformer/jit_pretrained.py +++ /dev/null @@ -1,280 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models, exported by `torch.jit.script()` -and uses them to decode waves. -You can use the following command to get the exported models: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -Usage of this script: - -./zipformer/jit_pretrained.py \ - --nn-model-filename ./zipformer/exp/cpu_jit.pt \ - --tokens ./data/lang_bpe_500/tokens.txt \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model-filename", - type=str, - required=True, - help="Path to the torchscript model cpu_jit.pt", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float = 16000 -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -def greedy_search( - model: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (N, T, C) - encoder_out_lens: - A 1-D tensor of shape (N,). - Returns: - Return the decoded results for each utterance. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = encoder_out.device - blank_id = model.decoder.blank_id - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - context_size = model.decoder.context_size - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder( - decoder_input, - need_pad=torch.tensor([False]), - ).squeeze(1) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - current_encoder_out = current_encoder_out - # current_encoder_out's shape: (batch_size, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, - decoder_out, - ) - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=torch.tensor([False]), - ) - decoder_out = decoder_out.squeeze(1) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = torch.jit.load(args.nn_model_filename) - - model.eval() - - model.to(device) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder( - features=features, - feature_lengths=feature_lengths, - ) - - hyps = greedy_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - - s = "\n" - - token_table = k2.SymbolTable.from_file(args.tokens) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += token_table[i] - return text.replace("▁", " ").strip() - - for filename, hyp in zip(args.sound_files, hyps): - words = token_ids_to_words(hyp) - s += f"{filename}:\n{words}\n" - - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/swbd/ASR/zipformer/jit_pretrained_ctc.py b/egs/swbd/ASR/zipformer/jit_pretrained_ctc.py deleted file mode 100755 index 660a4bfc60..0000000000 --- a/egs/swbd/ASR/zipformer/jit_pretrained_ctc.py +++ /dev/null @@ -1,436 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --causal 1 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -Usage of this script: - -(1) ctc-decoding -./zipformer/jit_pretrained_ctc.py \ - --model-filename ./zipformer/exp/jit_script.pt \ - --tokens data/lang_bpe_500/tokens.txt \ - --method ctc-decoding \ - --sample-rate 16000 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) 1best -./zipformer/jit_pretrained_ctc.py \ - --model-filename ./zipformer/exp/jit_script.pt \ - --HLG data/lang_bpe_500/HLG.pt \ - --words-file data/lang_bpe_500/words.txt \ - --method 1best \ - --sample-rate 16000 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) nbest-rescoring -./zipformer/jit_pretrained_ctc.py \ - --model-filename ./zipformer/exp/jit_script.pt \ - --HLG data/lang_bpe_500/HLG.pt \ - --words-file data/lang_bpe_500/words.txt \ - --G data/lm/G_4_gram.pt \ - --method nbest-rescoring \ - --sample-rate 16000 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) whole-lattice-rescoring -./zipformer/jit_pretrained_ctc.py \ - --model-filename ./zipformer/exp/jit_script.pt \ - --HLG data/lang_bpe_500/HLG.pt \ - --words-file data/lang_bpe_500/words.txt \ - --G data/lm/G_4_gram.pt \ - --method whole-lattice-rescoring \ - --sample-rate 16000 \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from ctc_decode import get_decoding_params -from export import num_tokens -from torch.nn.utils.rnn import pad_sequence -from train import get_params - -from icefall.decode import ( - get_lattice, - one_best_decoding, - rescore_with_n_best_list, - rescore_with_whole_lattice, -) -from icefall.utils import get_texts - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--model-filename", - type=str, - required=True, - help="Path to the torchscript model.", - ) - - parser.add_argument( - "--words-file", - type=str, - help="""Path to words.txt. - Used only when method is not ctc-decoding. - """, - ) - - parser.add_argument( - "--HLG", - type=str, - help="""Path to HLG.pt. - Used only when method is not ctc-decoding. - """, - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt. - Used only when method is ctc-decoding. - """, - ) - - parser.add_argument( - "--method", - type=str, - default="1best", - help="""Decoding method. - Possible values are: - (0) ctc-decoding - Use CTC decoding. It uses a token table, - i.e., lang_dir/token.txt, to convert - word pieces to words. It needs neither a lexicon - nor an n-gram LM. - (1) 1best - Use the best path as decoding output. Only - the transformer encoder output is used for decoding. - We call it HLG decoding. - (2) nbest-rescoring. Extract n paths from the decoding lattice, - rescore them with an LM, the path with - the highest score is the decoding result. - We call it HLG decoding + nbest n-gram LM rescoring. - (3) whole-lattice-rescoring - Use an LM to rescore the - decoding lattice and then use 1best to decode the - rescored lattice. - We call it HLG decoding + whole-lattice n-gram LM rescoring. - """, - ) - - parser.add_argument( - "--G", - type=str, - help="""An LM for rescoring. - Used only when method is - whole-lattice-rescoring or nbest-rescoring. - It's usually a 4-gram LM. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help=""" - Used only when method is attention-decoder. - It specifies the size of n-best list.""", - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=1.3, - help=""" - Used only when method is whole-lattice-rescoring and nbest-rescoring. - It specifies the scale for n-gram LM scores. - (Note: You need to tune it on a dataset.) - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=1.0, - help=""" - Used only when method is nbest-rescoring. - It specifies the scale for lattice.scores when - extracting n-best lists. A smaller value results in - more unique number of paths with the risk of missing - the best path. - """, - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float = 16000 -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - # add decoding params - params.update(get_decoding_params()) - params.update(vars(args)) - - token_table = k2.SymbolTable.from_file(params.tokens) - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = torch.jit.load(args.model_filename) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(features, feature_lengths) - ctc_output = model.ctc_output(encoder_out) # (N, T, C) - - batch_size = ctc_output.shape[0] - supervision_segments = torch.tensor( - [ - [i, 0, feature_lengths[i].item() // params.subsampling_factor] - for i in range(batch_size) - ], - dtype=torch.int32, - ) - - if params.method == "ctc-decoding": - logging.info("Use CTC decoding") - max_token_id = params.vocab_size - 1 - - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) - - lattice = get_lattice( - nnet_output=ctc_output, - decoding_graph=H, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - token_ids = get_texts(best_path) - hyps = [[token_table[i] for i in ids] for ids in token_ids] - elif params.method in [ - "1best", - "nbest-rescoring", - "whole-lattice-rescoring", - ]: - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - if params.method in [ - "nbest-rescoring", - "whole-lattice-rescoring", - ]: - logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) - G = G.to(device) - if params.method == "whole-lattice-rescoring": - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - - # G.lm_scores is used to replace HLG.lm_scores during - # LM rescoring. - G.lm_scores = G.scores.clone() - - lattice = get_lattice( - nnet_output=ctc_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.method == "1best": - logging.info("Use HLG decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - if params.method == "nbest-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_n_best_list( - lattice=lattice, - G=G, - num_paths=params.num_paths, - lm_scale_list=[params.ngram_lm_scale], - nbest_scale=params.nbest_scale, - ) - best_path = next(iter(best_path_dict.values())) - elif params.method == "whole-lattice-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=[params.ngram_lm_scale], - ) - best_path = next(iter(best_path_dict.values())) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - else: - raise ValueError(f"Unsupported decoding method: {params.method}") - - s = "\n" - if params.method == "ctc-decoding": - for filename, hyp in zip(params.sound_files, hyps): - words = "".join(hyp) - words = words.replace("▁", " ").strip() - s += f"{filename}:\n{words}\n\n" - elif params.method in [ - "1best", - "nbest-rescoring", - "whole-lattice-rescoring", - ]: - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - words = words.replace("▁", " ").strip() - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/swbd/ASR/zipformer/jit_pretrained_streaming.py b/egs/swbd/ASR/zipformer/jit_pretrained_streaming.py deleted file mode 100755 index d4ceacefd3..0000000000 --- a/egs/swbd/ASR/zipformer/jit_pretrained_streaming.py +++ /dev/null @@ -1,273 +0,0 @@ -#!/usr/bin/env python3 -# flake8: noqa -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models exported by `torch.jit.script()` -and uses them to decode waves. -You can use the following command to get the exported models: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -Usage of this script: - -./zipformer/jit_pretrained_streaming.py \ - --nn-model-filename ./zipformer/exp-causal/jit_script_chunk_16_left_128.pt \ - --tokens ./data/lang_bpe_500/tokens.txt \ - /path/to/foo.wav \ -""" - -import argparse -import logging -import math -from typing import List, Optional - -import k2 -import kaldifeat -import torch -import torchaudio -from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model-filename", - type=str, - required=True, - help="Path to the torchscript model jit_script.pt", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "sound_file", - type=str, - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - decoder: torch.jit.ScriptModule, - joiner: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - decoder_out: Optional[torch.Tensor] = None, - hyp: Optional[List[int]] = None, - device: torch.device = torch.device("cpu"), -): - assert encoder_out.ndim == 2 - context_size = decoder.context_size - blank_id = decoder.blank_id - - if decoder_out is None: - assert hyp is None, hyp - hyp = [blank_id] * context_size - decoder_input = torch.tensor(hyp, dtype=torch.int32, device=device).unsqueeze(0) - # decoder_input.shape (1,, 1 context_size) - decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) - else: - assert decoder_out.ndim == 2 - assert hyp is not None, hyp - - T = encoder_out.size(0) - for i in range(T): - cur_encoder_out = encoder_out[i : i + 1] - joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) - y = joiner_out.argmax(dim=0).item() - - if y != blank_id: - hyp.append(y) - decoder_input = hyp[-context_size:] - - decoder_input = torch.tensor( - decoder_input, dtype=torch.int32, device=device - ).unsqueeze(0) - decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) - - return hyp, decoder_out - - -def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: - """Create a CPU streaming feature extractor. - - At present, we assume it returns a fbank feature extractor with - fixed options. In the future, we will support passing in the options - from outside. - - Returns: - Return a CPU streaming feature extractor. - """ - opts = FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = sample_rate - opts.mel_opts.num_bins = 80 - return OnlineFbank(opts) - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = torch.jit.load(args.nn_model_filename) - model.eval() - model.to(device) - - encoder = model.encoder - decoder = model.decoder - joiner = model.joiner - - token_table = k2.SymbolTable.from_file(args.tokens) - context_size = decoder.context_size - - logging.info("Constructing Fbank computer") - online_fbank = create_streaming_feature_extractor(args.sample_rate) - - logging.info(f"Reading sound files: {args.sound_file}") - wave_samples = read_sound_files( - filenames=[args.sound_file], - expected_sample_rate=args.sample_rate, - )[0] - logging.info(wave_samples.shape) - - logging.info("Decoding started") - - chunk_length = encoder.chunk_size * 2 - T = chunk_length + encoder.pad_length - - logging.info(f"chunk_length: {chunk_length}") - logging.info(f"T: {T}") - - states = encoder.get_init_states(device=device) - - tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) - - wave_samples = torch.cat([wave_samples, tail_padding]) - - chunk = int(0.25 * args.sample_rate) # 0.2 second - num_processed_frames = 0 - - hyp = None - decoder_out = None - - start = 0 - while start < wave_samples.numel(): - logging.info(f"{start}/{wave_samples.numel()}") - end = min(start + chunk, wave_samples.numel()) - samples = wave_samples[start:end] - start += chunk - online_fbank.accept_waveform( - sampling_rate=args.sample_rate, - waveform=samples, - ) - while online_fbank.num_frames_ready - num_processed_frames >= T: - frames = [] - for i in range(T): - frames.append(online_fbank.get_frame(num_processed_frames + i)) - frames = torch.cat(frames, dim=0).to(device).unsqueeze(0) - x_lens = torch.tensor([T], dtype=torch.int32, device=device) - encoder_out, out_lens, states = encoder( - features=frames, - feature_lengths=x_lens, - states=states, - ) - num_processed_frames += chunk_length - - hyp, decoder_out = greedy_search( - decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device - ) - - text = "" - for i in hyp[context_size:]: - text += token_table[i] - text = text.replace("▁", " ").strip() - - logging.info(args.sound_file) - logging.info(text) - - logging.info("Decoding Done") - - -torch.set_num_threads(4) -torch.set_num_interop_threads(1) -torch._C._jit_set_profiling_executor(False) -torch._C._jit_set_profiling_mode(False) -torch._C._set_graph_executor_optimize(False) -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/swbd/ASR/zipformer/joiner.py b/egs/swbd/ASR/zipformer/joiner.py deleted file mode 100644 index f03cc930ec..0000000000 --- a/egs/swbd/ASR/zipformer/joiner.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -from scaling import ScaledLinear - - -class Joiner(nn.Module): - def __init__( - self, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - super().__init__() - - self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) - self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) - self.output_linear = nn.Linear(joiner_dim, vocab_size) - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - project_input: bool = True, - ) -> torch.Tensor: - """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, s_range, C). - decoder_out: - Output from the decoder. Its shape is (N, T, s_range, C). - project_input: - If true, apply input projections encoder_proj and decoder_proj. - If this is false, it is the user's responsibility to do this - manually. - Returns: - Return a tensor of shape (N, T, s_range, C). - """ - assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape) - - if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) - else: - logit = encoder_out + decoder_out - - logit = self.output_linear(torch.tanh(logit)) - - return logit diff --git a/egs/swbd/ASR/zipformer/model.py b/egs/swbd/ASR/zipformer/model.py deleted file mode 100644 index f2f86af47a..0000000000 --- a/egs/swbd/ASR/zipformer/model.py +++ /dev/null @@ -1,358 +0,0 @@ -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface - -from icefall.utils import add_sos, make_pad_mask -from scaling import ScaledLinear - - -class AsrModel(nn.Module): - def __init__( - self, - encoder_embed: nn.Module, - encoder: EncoderInterface, - decoder: Optional[nn.Module] = None, - joiner: Optional[nn.Module] = None, - encoder_dim: int = 384, - decoder_dim: int = 512, - vocab_size: int = 500, - use_transducer: bool = True, - use_ctc: bool = False, - ): - """A joint CTC & Transducer ASR model. - - - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) - - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) - - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) - - Args: - encoder_embed: - It is a Convolutional 2D subsampling module. It converts - an input of shape (N, T, idim) to an output of of shape - (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, encoder_dim) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - It is used when use_transducer is True. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output contains - unnormalized probs, i.e., not processed by log-softmax. - It is used when use_transducer is True. - use_transducer: - Whether use transducer head. Default: True. - use_ctc: - Whether use CTC head. Default: False. - """ - super().__init__() - - assert ( - use_transducer or use_ctc - ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}" - - assert isinstance(encoder, EncoderInterface), type(encoder) - - self.encoder_embed = encoder_embed - self.encoder = encoder - - self.use_transducer = use_transducer - if use_transducer: - # Modules for Transducer head - assert decoder is not None - assert hasattr(decoder, "blank_id") - assert joiner is not None - - self.decoder = decoder - self.joiner = joiner - - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_scale=0.25 - ) - self.simple_lm_proj = ScaledLinear( - decoder_dim, vocab_size, initial_scale=0.25 - ) - else: - assert decoder is None - assert joiner is None - - self.use_ctc = use_ctc - if use_ctc: - # Modules for CTC head - self.ctc_output = nn.Sequential( - nn.Dropout(p=0.1), - nn.Linear(encoder_dim, vocab_size), - nn.LogSoftmax(dim=-1), - ) - - def forward_encoder( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute encoder outputs. - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - - Returns: - encoder_out: - Encoder output, of shape (N, T, C). - encoder_out_lens: - Encoder output lengths, of shape (N,). - """ - # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") - x, x_lens = self.encoder_embed(x, x_lens) - # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) - - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) - - return encoder_out, encoder_out_lens - - def forward_ctc( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - targets: torch.Tensor, - target_lengths: torch.Tensor, - ) -> torch.Tensor: - """Compute CTC loss. - Args: - encoder_out: - Encoder output, of shape (N, T, C). - encoder_out_lens: - Encoder output lengths, of shape (N,). - targets: - Target Tensor of shape (sum(target_lengths)). The targets are assumed - to be un-padded and concatenated within 1 dimension. - """ - # Compute CTC log-prob - ctc_output = self.ctc_output(encoder_out) # (N, T, C) - - ctc_loss = torch.nn.functional.ctc_loss( - log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) - targets=targets, - input_lengths=encoder_out_lens, - target_lengths=target_lengths, - reduction="sum", - ) - return ctc_loss - - def forward_transducer( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - y: k2.RaggedTensor, - y_lens: torch.Tensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute Transducer loss. - Args: - encoder_out: - Encoder output, of shape (N, T, C). - encoder_out_lens: - Encoder output lengths, of shape (N,). - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - """ - # Now for the decoder, i.e., the prediction network - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - # sos_y_padded: [B, S + 1], start with SOS. - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (encoder_out.size(0), 4), - dtype=torch.int64, - device=encoder_out.device, - ) - boundary[:, 2] = y_lens - boundary[:, 3] = encoder_out_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - # if self.training and random.random() < 0.25: - # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - # if self.training and random.random() < 0.25: - # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction="sum", - return_grad=True, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction="sum", - ) - - return simple_loss, pruned_loss - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - Returns: - Return the transducer losses and CTC loss, - in form of (simple_loss, pruned_loss, ctc_loss) - - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) - - # Compute encoder outputs - encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) - - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - if self.use_transducer: - # Compute transducer loss - simple_loss, pruned_loss = self.forward_transducer( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - y=y.to(x.device), - y_lens=y_lens, - prune_range=prune_range, - am_scale=am_scale, - lm_scale=lm_scale, - ) - else: - simple_loss = torch.empty(0) - pruned_loss = torch.empty(0) - - if self.use_ctc: - # Compute CTC loss - targets = y.values - ctc_loss = self.forward_ctc( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - targets=targets, - target_lengths=y_lens, - ) - else: - ctc_loss = torch.empty(0) - - return simple_loss, pruned_loss, ctc_loss diff --git a/egs/swbd/ASR/zipformer/onnx_check.py b/egs/swbd/ASR/zipformer/onnx_check.py deleted file mode 100755 index 93bd3a211c..0000000000 --- a/egs/swbd/ASR/zipformer/onnx_check.py +++ /dev/null @@ -1,240 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script checks that exported onnx models produce the same output -with the given torchscript model for the same input. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model via torchscript (torch.jit.script()) - -./zipformer/export.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp/ \ - --jit 1 - -It will generate the following file in $repo/exp: - - jit_script.pt - -3. Export the model to ONNX - -./zipformer/export-onnx.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp/ - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -4. Run this file - -./zipformer/onnx_check.py \ - --jit-filename $repo/exp/jit_script.pt \ - --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx -""" - -import argparse -import logging - -import torch -from onnx_pretrained import OnnxModel - -from icefall import is_module_available - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--jit-filename", - required=True, - type=str, - help="Path to the torchscript model", - ) - - parser.add_argument( - "--onnx-encoder-filename", - required=True, - type=str, - help="Path to the onnx encoder model", - ) - - parser.add_argument( - "--onnx-decoder-filename", - required=True, - type=str, - help="Path to the onnx decoder model", - ) - - parser.add_argument( - "--onnx-joiner-filename", - required=True, - type=str, - help="Path to the onnx joiner model", - ) - - return parser - - -def test_encoder( - torch_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - C = 80 - for i in range(3): - N = torch.randint(low=1, high=20, size=(1,)).item() - T = torch.randint(low=30, high=50, size=(1,)).item() - logging.info(f"test_encoder: iter {i}, N={N}, T={T}") - - x = torch.rand(N, T, C) - x_lens = torch.randint(low=30, high=T + 1, size=(N,)) - x_lens[0] = T - - torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens) - torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out) - - onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens) - - assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), ( - (torch_encoder_out - onnx_encoder_out).abs().max() - ) - - -def test_decoder( - torch_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - context_size = onnx_model.context_size - vocab_size = onnx_model.vocab_size - for i in range(10): - N = torch.randint(1, 100, size=(1,)).item() - logging.info(f"test_decoder: iter {i}, N={N}") - x = torch.randint( - low=1, - high=vocab_size, - size=(N, context_size), - dtype=torch.int64, - ) - torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False])) - torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out) - torch_decoder_out = torch_decoder_out.squeeze(1) - - onnx_decoder_out = onnx_model.run_decoder(x) - assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( - (torch_decoder_out - onnx_decoder_out).abs().max() - ) - - -def test_joiner( - torch_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1] - decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1] - for i in range(10): - N = torch.randint(1, 100, size=(1,)).item() - logging.info(f"test_joiner: iter {i}, N={N}") - encoder_out = torch.rand(N, encoder_dim) - decoder_out = torch.rand(N, decoder_dim) - - projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out) - projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out) - - torch_joiner_out = torch_model.joiner(encoder_out, decoder_out) - onnx_joiner_out = onnx_model.run_joiner( - projected_encoder_out, projected_decoder_out - ) - - assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( - (torch_joiner_out - onnx_joiner_out).abs().max() - ) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - logging.info(vars(args)) - - torch_model = torch.jit.load(args.jit_filename) - - onnx_model = OnnxModel( - encoder_model_filename=args.onnx_encoder_filename, - decoder_model_filename=args.onnx_decoder_filename, - joiner_model_filename=args.onnx_joiner_filename, - ) - - logging.info("Test encoder") - test_encoder(torch_model, onnx_model) - - logging.info("Test decoder") - test_decoder(torch_model, onnx_model) - - logging.info("Test joiner") - test_joiner(torch_model, onnx_model) - logging.info("Finished checking ONNX models") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -# See https://github.com/pytorch/pytorch/issues/38342 -# and https://github.com/pytorch/pytorch/issues/33354 -# -# If we don't do this, the delay increases whenever there is -# a new request that changes the actual batch size. -# If you use `py-spy dump --pid --native`, you will -# see a lot of time is spent in re-compiling the torch script model. -torch._C._jit_set_profiling_executor(False) -torch._C._jit_set_profiling_mode(False) -torch._C._set_graph_executor_optimize(False) -if __name__ == "__main__": - torch.manual_seed(20220727) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/swbd/ASR/zipformer/onnx_decode.py b/egs/swbd/ASR/zipformer/onnx_decode.py deleted file mode 100755 index 2aca36ca94..0000000000 --- a/egs/swbd/ASR/zipformer/onnx_decode.py +++ /dev/null @@ -1,323 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX exported models and uses them to decode the test sets. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./zipformer/export-onnx.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --causal False - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -2. Run this file - -./zipformer/onnx_decode.py \ - --exp-dir $repo/exp \ - --max-duration 600 \ - --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ -""" - - -import argparse -import logging -import time -from pathlib import Path -from typing import List, Tuple - -import torch -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule - -from onnx_pretrained import greedy_search, OnnxModel - -from icefall.utils import setup_logger, store_transcripts, write_error_stats -from k2 import SymbolTable - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="Valid values are greedy_search and modified_beam_search", - ) - - return parser - - -def decode_one_batch( - model: OnnxModel, token_table: SymbolTable, batch: dict -) -> List[List[str]]: - """Decode one batch and return the result. - Currently it only greedy_search is supported. - - Args: - model: - The neural model. - token_table: - The token table. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - - Returns: - Return the decoded results for each utterance. - """ - feature = batch["inputs"] - assert feature.ndim == 3 - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(dtype=torch.int64) - - encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) - - hyps = greedy_search( - model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens - ) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += token_table[i] - return text.replace("▁", " ").strip() - - hyps = [token_ids_to_words(h).split() for h in hyps] - return hyps - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - model: nn.Module, - token_table: SymbolTable, -) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - model: - The neural model. - token_table: - The token table. - - Returns: - - A list of tuples. Each tuple contains three elements: - - cut_id, - - reference transcript, - - predicted result. - - The total duration (in seconds) of the dataset. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - log_interval = 10 - total_duration = 0 - - results = [] - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) - - hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) - - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results.extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - - return results, total_duration - - -def save_results( - res_dir: Path, - test_set_name: str, - results: List[Tuple[str, List[str], List[str]]], -): - recog_path = res_dir / f"recogs-{test_set_name}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = res_dir / f"errs-{test_set_name}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - errs_info = res_dir / f"wer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("WER", file=f) - print(wer, file=f) - - s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - assert ( - args.decoding_method == "greedy_search" - ), "Only supports greedy_search currently." - res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" - - setup_logger(f"{res_dir}/log-decode") - logging.info("Decoding started") - - device = torch.device("cpu") - logging.info(f"Device: {device}") - - token_table = SymbolTable.from_file(args.tokens) - - logging.info(vars(args)) - - logging.info("About to create model") - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - # we need cut ids to display recognition results. - args.return_cuts = True - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - start_time = time.time() - results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table) - end_time = time.time() - elapsed_seconds = end_time - start_time - rtf = elapsed_seconds / total_duration - - logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") - logging.info(f"Wave duration: {total_duration:.3f} s") - logging.info( - f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" - ) - - save_results(res_dir=res_dir, test_set_name=test_set, results=results) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/swbd/ASR/zipformer/onnx_pretrained-streaming.py b/egs/swbd/ASR/zipformer/onnx_pretrained-streaming.py deleted file mode 100755 index 500b2cd097..0000000000 --- a/egs/swbd/ASR/zipformer/onnx_pretrained-streaming.py +++ /dev/null @@ -1,543 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) -# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) - -""" -This script loads ONNX models exported by ./export-onnx-streaming.py -and uses them to decode waves. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./zipformer/export-onnx-streaming.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --num-encoder-layers "2,2,3,4,3,2" \ - --downsampling-factor "1,2,4,8,4,2" \ - --feedforward-dim "512,768,1024,1536,1024,768" \ - --num-heads "4,4,4,8,4,4" \ - --encoder-dim "192,256,384,512,384,256" \ - --query-head-dim 32 \ - --value-head-dim 12 \ - --pos-head-dim 4 \ - --pos-dim 48 \ - --encoder-unmasked-dim "192,192,256,256,256,192" \ - --cnn-module-kernel "31,31,15,15,15,31" \ - --decoder-dim 512 \ - --joiner-dim 512 \ - --causal True \ - --chunk-size 16 \ - --left-context-frames 64 - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -3. Run this file with the exported ONNX models - -./zipformer/onnx_pretrained-streaming.py \ - --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav - -Note: Even though this script only supports decoding a single file, -the exported ONNX models do support batch processing. -""" - -import argparse -import logging -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import onnxruntime as ort -import torch -import torchaudio -from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "sound_file", - type=str, - help="The input sound file to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - encoder_model_filename: str, - decoder_model_filename: str, - joiner_model_filename: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - - self.init_encoder(encoder_model_filename) - self.init_decoder(decoder_model_filename) - self.init_joiner(joiner_model_filename) - - def init_encoder(self, encoder_model_filename: str): - self.encoder = ort.InferenceSession( - encoder_model_filename, - sess_options=self.session_opts, - ) - self.init_encoder_states() - - def init_encoder_states(self, batch_size: int = 1): - encoder_meta = self.encoder.get_modelmeta().custom_metadata_map - logging.info(f"encoder_meta={encoder_meta}") - - model_type = encoder_meta["model_type"] - assert model_type == "zipformer2", model_type - - decode_chunk_len = int(encoder_meta["decode_chunk_len"]) - T = int(encoder_meta["T"]) - - num_encoder_layers = encoder_meta["num_encoder_layers"] - encoder_dims = encoder_meta["encoder_dims"] - cnn_module_kernels = encoder_meta["cnn_module_kernels"] - left_context_len = encoder_meta["left_context_len"] - query_head_dims = encoder_meta["query_head_dims"] - value_head_dims = encoder_meta["value_head_dims"] - num_heads = encoder_meta["num_heads"] - - def to_int_list(s): - return list(map(int, s.split(","))) - - num_encoder_layers = to_int_list(num_encoder_layers) - encoder_dims = to_int_list(encoder_dims) - cnn_module_kernels = to_int_list(cnn_module_kernels) - left_context_len = to_int_list(left_context_len) - query_head_dims = to_int_list(query_head_dims) - value_head_dims = to_int_list(value_head_dims) - num_heads = to_int_list(num_heads) - - logging.info(f"decode_chunk_len: {decode_chunk_len}") - logging.info(f"T: {T}") - logging.info(f"num_encoder_layers: {num_encoder_layers}") - logging.info(f"encoder_dims: {encoder_dims}") - logging.info(f"cnn_module_kernels: {cnn_module_kernels}") - logging.info(f"left_context_len: {left_context_len}") - logging.info(f"query_head_dims: {query_head_dims}") - logging.info(f"value_head_dims: {value_head_dims}") - logging.info(f"num_heads: {num_heads}") - - num_encoders = len(num_encoder_layers) - - self.states = [] - for i in range(num_encoders): - num_layers = num_encoder_layers[i] - key_dim = query_head_dims[i] * num_heads[i] - embed_dim = encoder_dims[i] - nonlin_attn_head_dim = 3 * embed_dim // 4 - value_dim = value_head_dims[i] * num_heads[i] - conv_left_pad = cnn_module_kernels[i] // 2 - - for layer in range(num_layers): - cached_key = torch.zeros( - left_context_len[i], batch_size, key_dim - ).numpy() - cached_nonlin_attn = torch.zeros( - 1, batch_size, left_context_len[i], nonlin_attn_head_dim - ).numpy() - cached_val1 = torch.zeros( - left_context_len[i], batch_size, value_dim - ).numpy() - cached_val2 = torch.zeros( - left_context_len[i], batch_size, value_dim - ).numpy() - cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() - cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() - self.states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - embed_states = torch.zeros(batch_size, 128, 3, 19).numpy() - self.states.append(embed_states) - processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy() - self.states.append(processed_lens) - - self.num_encoders = num_encoders - - self.segment = T - self.offset = decode_chunk_len - - def init_decoder(self, decoder_model_filename: str): - self.decoder = ort.InferenceSession( - decoder_model_filename, - sess_options=self.session_opts, - ) - - decoder_meta = self.decoder.get_modelmeta().custom_metadata_map - self.context_size = int(decoder_meta["context_size"]) - self.vocab_size = int(decoder_meta["vocab_size"]) - - logging.info(f"context_size: {self.context_size}") - logging.info(f"vocab_size: {self.vocab_size}") - - def init_joiner(self, joiner_model_filename: str): - self.joiner = ort.InferenceSession( - joiner_model_filename, - sess_options=self.session_opts, - ) - - joiner_meta = self.joiner.get_modelmeta().custom_metadata_map - self.joiner_dim = int(joiner_meta["joiner_dim"]) - - logging.info(f"joiner_dim: {self.joiner_dim}") - - def _build_encoder_input_output( - self, - x: torch.Tensor, - ) -> Tuple[Dict[str, np.ndarray], List[str]]: - encoder_input = {"x": x.numpy()} - encoder_output = ["encoder_out"] - - def build_inputs_outputs(tensors, i): - assert len(tensors) == 6, len(tensors) - - # (downsample_left, batch_size, key_dim) - name = f"cached_key_{i}" - encoder_input[name] = tensors[0] - encoder_output.append(f"new_{name}") - - # (1, batch_size, downsample_left, nonlin_attn_head_dim) - name = f"cached_nonlin_attn_{i}" - encoder_input[name] = tensors[1] - encoder_output.append(f"new_{name}") - - # (downsample_left, batch_size, value_dim) - name = f"cached_val1_{i}" - encoder_input[name] = tensors[2] - encoder_output.append(f"new_{name}") - - # (downsample_left, batch_size, value_dim) - name = f"cached_val2_{i}" - encoder_input[name] = tensors[3] - encoder_output.append(f"new_{name}") - - # (batch_size, embed_dim, conv_left_pad) - name = f"cached_conv1_{i}" - encoder_input[name] = tensors[4] - encoder_output.append(f"new_{name}") - - # (batch_size, embed_dim, conv_left_pad) - name = f"cached_conv2_{i}" - encoder_input[name] = tensors[5] - encoder_output.append(f"new_{name}") - - for i in range(len(self.states[:-2]) // 6): - build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i) - - # (batch_size, channels, left_pad, freq) - name = "embed_states" - embed_states = self.states[-2] - encoder_input[name] = embed_states - encoder_output.append(f"new_{name}") - - # (batch_size,) - name = "processed_lens" - processed_lens = self.states[-1] - encoder_input[name] = processed_lens - encoder_output.append(f"new_{name}") - - return encoder_input, encoder_output - - def _update_states(self, states: List[np.ndarray]): - self.states = states - - def run_encoder(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C) - Returns: - Return a 3-D tensor of shape (N, T', joiner_dim) where - T' is usually equal to ((T-7)//2+1)//2 - """ - encoder_input, encoder_output_names = self._build_encoder_input_output(x) - - out = self.encoder.run(encoder_output_names, encoder_input) - - self._update_states(out[1:]) - - return torch.from_numpy(out[0]) - - def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: - """ - Args: - decoder_input: - A 2-D tensor of shape (N, context_size) - Returns: - Return a 2-D tensor of shape (N, joiner_dim) - """ - out = self.decoder.run( - [self.decoder.get_outputs()[0].name], - {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, - )[0] - - return torch.from_numpy(out) - - def run_joiner( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - out = self.joiner.run( - [self.joiner.get_outputs()[0].name], - { - self.joiner.get_inputs()[0].name: encoder_out.numpy(), - self.joiner.get_inputs()[1].name: decoder_out.numpy(), - }, - )[0] - - return torch.from_numpy(out) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -def create_streaming_feature_extractor() -> OnlineFeature: - """Create a CPU streaming feature extractor. - - At present, we assume it returns a fbank feature extractor with - fixed options. In the future, we will support passing in the options - from outside. - - Returns: - Return a CPU streaming feature extractor. - """ - opts = FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - return OnlineFbank(opts) - - -def greedy_search( - model: OnnxModel, - encoder_out: torch.Tensor, - context_size: int, - decoder_out: Optional[torch.Tensor] = None, - hyp: Optional[List[int]] = None, -) -> List[int]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (1, T, joiner_dim) - context_size: - The context size of the decoder model. - decoder_out: - Optional. Decoder output of the previous chunk. - hyp: - Decoding results for previous chunks. - Returns: - Return the decoded results so far. - """ - - blank_id = 0 - - if decoder_out is None: - assert hyp is None, hyp - hyp = [blank_id] * context_size - decoder_input = torch.tensor([hyp], dtype=torch.int64) - decoder_out = model.run_decoder(decoder_input) - else: - assert hyp is not None, hyp - - encoder_out = encoder_out.squeeze(0) - T = encoder_out.size(0) - for t in range(T): - cur_encoder_out = encoder_out[t : t + 1] - joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) - y = joiner_out.argmax(dim=0).item() - if y != blank_id: - hyp.append(y) - decoder_input = hyp[-context_size:] - decoder_input = torch.tensor([decoder_input], dtype=torch.int64) - decoder_out = model.run_decoder(decoder_input) - - return hyp, decoder_out - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - sample_rate = 16000 - - logging.info("Constructing Fbank computer") - online_fbank = create_streaming_feature_extractor() - - logging.info(f"Reading sound files: {args.sound_file}") - waves = read_sound_files( - filenames=[args.sound_file], - expected_sample_rate=sample_rate, - )[0] - - tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) - wave_samples = torch.cat([waves, tail_padding]) - - num_processed_frames = 0 - segment = model.segment - offset = model.offset - - context_size = model.context_size - hyp = None - decoder_out = None - - chunk = int(1 * sample_rate) # 1 second - start = 0 - while start < wave_samples.numel(): - end = min(start + chunk, wave_samples.numel()) - samples = wave_samples[start:end] - start += chunk - - online_fbank.accept_waveform( - sampling_rate=sample_rate, - waveform=samples, - ) - - while online_fbank.num_frames_ready - num_processed_frames >= segment: - frames = [] - for i in range(segment): - frames.append(online_fbank.get_frame(num_processed_frames + i)) - num_processed_frames += offset - frames = torch.cat(frames, dim=0) - frames = frames.unsqueeze(0) - encoder_out = model.run_encoder(frames) - hyp, decoder_out = greedy_search( - model, - encoder_out, - context_size, - decoder_out, - hyp, - ) - - token_table = k2.SymbolTable.from_file(args.tokens) - - text = "" - for i in hyp[context_size:]: - text += token_table[i] - text = text.replace("▁", " ").strip() - - logging.info(args.sound_file) - logging.info(text) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/swbd/ASR/zipformer/onnx_pretrained.py b/egs/swbd/ASR/zipformer/onnx_pretrained.py deleted file mode 100755 index 032b07721a..0000000000 --- a/egs/swbd/ASR/zipformer/onnx_pretrained.py +++ /dev/null @@ -1,418 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX models and uses them to decode waves. -You can use the following command to get the exported models: - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./zipformer/export-onnx.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --causal False - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -3. Run this file - -./zipformer/onnx_pretrained.py \ - --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -""" - -import argparse -import logging -import math -from typing import List, Tuple - -import k2 -import kaldifeat -import onnxruntime as ort -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - encoder_model_filename: str, - decoder_model_filename: str, - joiner_model_filename: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 4 - - self.session_opts = session_opts - - self.init_encoder(encoder_model_filename) - self.init_decoder(decoder_model_filename) - self.init_joiner(joiner_model_filename) - - def init_encoder(self, encoder_model_filename: str): - self.encoder = ort.InferenceSession( - encoder_model_filename, - sess_options=self.session_opts, - ) - - def init_decoder(self, decoder_model_filename: str): - self.decoder = ort.InferenceSession( - decoder_model_filename, - sess_options=self.session_opts, - ) - - decoder_meta = self.decoder.get_modelmeta().custom_metadata_map - self.context_size = int(decoder_meta["context_size"]) - self.vocab_size = int(decoder_meta["vocab_size"]) - - logging.info(f"context_size: {self.context_size}") - logging.info(f"vocab_size: {self.vocab_size}") - - def init_joiner(self, joiner_model_filename: str): - self.joiner = ort.InferenceSession( - joiner_model_filename, - sess_options=self.session_opts, - ) - - joiner_meta = self.joiner.get_modelmeta().custom_metadata_map - self.joiner_dim = int(joiner_meta["joiner_dim"]) - - logging.info(f"joiner_dim: {self.joiner_dim}") - - def run_encoder( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 2-D tensor of shape (N,). Its dtype is torch.int64 - Returns: - Return a tuple containing: - - encoder_out, its shape is (N, T', joiner_dim) - - encoder_out_lens, its shape is (N,) - """ - out = self.encoder.run( - [ - self.encoder.get_outputs()[0].name, - self.encoder.get_outputs()[1].name, - ], - { - self.encoder.get_inputs()[0].name: x.numpy(), - self.encoder.get_inputs()[1].name: x_lens.numpy(), - }, - ) - return torch.from_numpy(out[0]), torch.from_numpy(out[1]) - - def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: - """ - Args: - decoder_input: - A 2-D tensor of shape (N, context_size) - Returns: - Return a 2-D tensor of shape (N, joiner_dim) - """ - out = self.decoder.run( - [self.decoder.get_outputs()[0].name], - {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, - )[0] - - return torch.from_numpy(out) - - def run_joiner( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - out = self.joiner.run( - [self.joiner.get_outputs()[0].name], - { - self.joiner.get_inputs()[0].name: encoder_out.numpy(), - self.joiner.get_inputs()[1].name: decoder_out.numpy(), - }, - )[0] - - return torch.from_numpy(out) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - model: OnnxModel, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (N, T, joiner_dim) - encoder_out_lens: - A 1-D tensor of shape (N,). - Returns: - Return the decoded results for each utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - context_size = model.context_size - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.run_decoder(decoder_input) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - # current_encoder_out's shape: (batch_size, joiner_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - logits = model.run_joiner(current_encoder_out, decoder_out) - - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - dtype=torch.int64, - ) - decoder_out = model.run_decoder(decoder_input) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - expected_sample_rate=args.sample_rate, - ) - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) - encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths) - - hyps = greedy_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - s = "\n" - - token_table = k2.SymbolTable.from_file(args.tokens) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += token_table[i] - return text.replace("▁", " ").strip() - - for filename, hyp in zip(args.sound_files, hyps): - words = token_ids_to_words(hyp) - s += f"{filename}:\n{words}\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/swbd/ASR/zipformer/optim.py b/egs/swbd/ASR/zipformer/optim.py deleted file mode 100644 index abfb2092cd..0000000000 --- a/egs/swbd/ASR/zipformer/optim.py +++ /dev/null @@ -1,1173 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import logging -import random -from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union - -import torch -from lhotse.utils import fix_random_seed -from torch import Tensor -from torch.optim import Optimizer - - -class BatchedOptimizer(Optimizer): - """ - This class adds to class Optimizer the capability to optimize parameters in batches: - it will stack the parameters and their grads for you so the optimizer can work - on tensors with an extra leading dimension. This is intended for speed with GPUs, - as it reduces the number of kernels launched in the optimizer. - - Args: - params: - """ - - def __init__(self, params, defaults): - super(BatchedOptimizer, self).__init__(params, defaults) - - @contextlib.contextmanager - def batched_params(self, param_group, group_params_names): - """ - This function returns (technically, yields) a list of - of tuples (p, state), where - p is a `fake` parameter that is stacked (over axis 0) from real parameters - that share the same shape, and its gradient is also stacked; - `state` is the state corresponding to this batch of parameters - (it will be physically located in the "state" for one of the real - parameters, the last one that has any particular shape and dtype). - - This function is decorated as a context manager so that it can - write parameters back to their "real" locations. - - The idea is, instead of doing: - - for p in group["params"]: - state = self.state[p] - ... - - you can do: - - with self.batched_params(group["params"]) as batches: - for p, state, p_names in batches: - ... - - - Args: - group: a parameter group, which is a list of parameters; should be - one of self.param_groups. - group_params_names: name for each parameter in group, - which is List[str]. - """ - batches = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter - batches_names = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str - - assert len(param_group) == len(group_params_names) - for p, named_p in zip(param_group, group_params_names): - key = (str(p.dtype), *p.shape) - batches[key].append(p) - batches_names[key].append(named_p) - - batches_names_keys = list(batches_names.keys()) - sorted_idx = sorted( - range(len(batches_names)), key=lambda i: batches_names_keys[i] - ) - batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] - batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] - - stacked_params_dict = dict() - - # turn batches into a list, in deterministic order. - # tuples will contain tuples of (stacked_param, state, stacked_params_names), - # one for each batch in `batches`. - tuples = [] - - for batch, batch_names in zip(batches, batches_names): - p = batch[0] - # we arbitrarily store the state in the - # state corresponding to the 1st parameter in the - # group. class Optimizer will take care of saving/loading state. - state = self.state[p] - p_stacked = torch.stack(batch) - grad = torch.stack( - [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] - ) - p_stacked.grad = grad - stacked_params_dict[key] = p_stacked - tuples.append((p_stacked, state, batch_names)) - - yield tuples # <-- calling code will do the actual optimization here! - - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): - for i, p in enumerate(batch): # batch is list of Parameter - p.copy_(stacked_params[i]) - - -class ScaledAdam(BatchedOptimizer): - """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) - - - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - Unlike common optimizers, which accept model.parameters() or groups of parameters(), - this optimizer could accept model.named_parameters() or groups of named_parameters(). - See comments of function _get_names_of_parameters for its 4 possible cases. - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period - """ - - def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, - ): - - defaults = dict( - lr=lr, - clipping_scale=clipping_scale, - betas=betas, - scalar_lr_scale=scalar_lr_scale, - eps=eps, - param_min_rms=param_min_rms, - param_max_rms=param_max_rms, - scalar_max=scalar_max, - size_update_period=size_update_period, - clipping_update_period=clipping_update_period, - ) - - # If params only contains parameters or group of parameters, - # i.e when parameter names are not given, - # this flag will be set to False in funciton _get_names_of_parameters. - self.show_dominant_parameters = True - param_groups, parameters_names = self._get_names_of_parameters(params) - super(ScaledAdam, self).__init__(param_groups, defaults) - assert len(self.param_groups) == len(parameters_names) - self.parameters_names = parameters_names - - def _get_names_of_parameters( - self, params_or_named_params - ) -> Tuple[List[Dict], List[List[str]]]: - """ - Args: - params_or_named_params: according to the way ScaledAdam is initialized in train.py, - this argument could be one of following 4 cases, - case 1, a generator of parameter, e.g.: - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0) - - case 2, a list of parameter groups with different config, e.g.: - model_param_groups = [ - {'params': model.encoder.parameters(), 'lr': 0.05}, - {'params': model.decoder.parameters(), 'lr': 0.01}, - {'params': model.joiner.parameters(), 'lr': 0.03}, - ] - optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0) - - case 3, a generator of named_parameter, e.g.: - optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0) - - case 4, a list of named_parameter groups with different config, e.g.: - model_named_param_groups = [ - {'named_params': model.encoder.named_parameters(), 'lr': 0.05}, - {'named_params': model.decoder.named_parameters(), 'lr': 0.01}, - {'named_params': model.joiner.named_parameters(), 'lr': 0.03}, - ] - optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0) - - For case 1 and case 2, input params is used to initialize the underlying torch.optimizer. - For case 3 and case 4, firstly, names and params are extracted from input named_params, - then, these extracted params are used to initialize the underlying torch.optimizer, - and these extracted names are mainly used by function - `_show_gradient_dominating_parameter` - - Returns: - Returns a tuple containing 2 elements: - - `param_groups` with type List[Dict], each Dict element is a parameter group. - An example of `param_groups` could be: - [ - {'params': `one iterable of Parameter`, 'lr': 0.05}, - {'params': `another iterable of Parameter`, 'lr': 0.08}, - {'params': `a third iterable of Parameter`, 'lr': 0.1}, - ] - - `param_gruops_names` with type List[List[str]], - each `List[str]` is for a group['params'] in param_groups, - and each `str` is the name of a parameter. - A dummy name "foo" is related to each parameter, - if input are params without names, i.e. case 1 or case 2. - """ - # variable naming convention in this function: - # p is short for param. - # np is short for named_param. - # p_or_np is short for param_or_named_param. - # cur is short for current. - # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}. - # groups is a List[group] - - iterable_or_groups = list(params_or_named_params) - if len(iterable_or_groups) == 0: - raise ValueError("optimizer got an empty parameter list") - - # The first value of returned tuple. A list of dicts containing at - # least 'params' as a key. - param_groups = [] - - # The second value of returned tuple, - # a List[List[str]], each sub-List is for a group. - param_groups_names = [] - - if not isinstance(iterable_or_groups[0], dict): - # case 1 or case 3, - # the input is an iterable of parameter or named parameter. - param_iterable_cur_group = [] - param_names_cur_group = [] - for p_or_np in iterable_or_groups: - if isinstance(p_or_np, tuple): - # case 3 - name, param = p_or_np - else: - # case 1 - assert isinstance(p_or_np, torch.Tensor) - param = p_or_np - # Assign a dummy name as a placeholder - name = "foo" - self.show_dominant_parameters = False - param_iterable_cur_group.append(param) - param_names_cur_group.append(name) - param_groups.append({"params": param_iterable_cur_group}) - param_groups_names.append(param_names_cur_group) - else: - # case 2 or case 4 - # the input is groups of parameter or named parameter. - for cur_group in iterable_or_groups: - assert "named_params" in cur_group - name_list = [ x[0] for x in cur_group["named_params"] ] - p_list = [ x[1] for x in cur_group["named_params"] ] - del cur_group["named_params"] - cur_group["params"] = p_list - param_groups.append(cur_group) - param_groups_names.append(name_list) - - return param_groups, param_groups_names - - def __setstate__(self, state): - super(ScaledAdam, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - batch = True - - for group, group_params_names in zip(self.param_groups, self.parameters_names): - - with self.batched_params(group["params"], group_params_names) as batches: - - # batches is list of pairs (stacked_param, state). stacked_param is like - # a regular parameter, and will have a .grad, but the 1st dim corresponds to - # a stacking dim, it is not a real dim. - - if ( - len(batches[0][1]) == 0 - ): # if len(first state) == 0: not yet initialized - clipping_scale = 1 - else: - clipping_scale = self._get_clipping_scale(group, batches) - - for p, state, _ in batches: - # Perform optimization step. - # grad is not going to be None, we handled that when creating the batches. - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "ScaledAdam optimizer does not support sparse gradients" - ) - # State initialization - if len(state) == 0: - self._init_state(group, p, state) - - self._step_one_batch(group, p, state, clipping_scale) - - return loss - - def _init_state(self, group: dict, p: Tensor, state: dict): - """ - Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p - is actually the batch dimension, corresponding to batched-together - parameters of a given shape. - - - Args: - group: Dict to look up configuration values. - p: The parameter that we are initializing the state for - state: Dict from string to whatever state we are initializing - """ - size_update_period = group["size_update_period"] - - state["step"] = 0 - - kwargs = {"device": p.device, "dtype": p.dtype} - - # 'delta' implements conventional momentum. There are - # several different kinds of update going on, so rather than - # compute "exp_avg" like in Adam, we store and decay a - # parameter-change "delta", which combines all forms of - # update. this is equivalent to how it's done in Adam, - # except for the first few steps. - state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) - - batch_size = p.shape[0] - numel = p.numel() // batch_size - - if numel > 1: - # "param_rms" just periodically records the scalar root-mean-square value of - # the parameter tensor. - # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - state["param_rms"] = param_rms - - state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros( - size_update_period, *param_rms.shape, **kwargs - ) - - # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) - - def _get_clipping_scale( - self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] - ) -> float: - """ - Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients - by this amount before applying the rest of the update. - - Args: - group: the parameter group, an item in self.param_groups - tuples: a list of tuples of (param, state, param_names) - where param is a batched set of parameters, - with a .grad (1st dim is batch dim) - and state is the state-dict where optimization parameters are kept. - param_names is a List[str] while each str is name for a parameter - in batched set of parameters "param". - """ - assert len(tuples) >= 1 - clipping_scale = group["clipping_scale"] - (first_p, first_state, _) = tuples[0] - step = first_state["step"] - if clipping_scale is None or step == 0: - # no clipping. return early on step == 0 because the other - # parameters' state won't have been initialized yet. - return 1.0 - clipping_update_period = group["clipping_update_period"] - - tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "ScaledAdam optimizer does not support sparse gradients" - ) - if p.numel() == p.shape[0]: # a batch of scalars - tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] - else: - tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() - - tot_norm = tot_sumsq.sqrt() - if "model_norms" not in first_state: - first_state["model_norms"] = torch.zeros( - clipping_update_period, device=p.device - ) - first_state["model_norms"][step % clipping_update_period] = tot_norm - - if step % clipping_update_period == 0: - # Print some stats. - # We don't reach here if step == 0 because we would have returned - # above. - sorted_norms = first_state["model_norms"].sort()[0].to("cpu") - quartiles = [] - for n in range(0, 5): - index = min( - clipping_update_period - 1, (clipping_update_period // 4) * n - ) - quartiles.append(sorted_norms[index].item()) - - median = quartiles[2] - threshold = clipping_scale * median - first_state["model_norm_threshold"] = threshold - percent_clipped = ( - first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state - else 0.0 - ) - first_state["num_clipped"] = 0 - quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.info( - f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" - ) - - if step < clipping_update_period: - return 1.0 # We have not yet estimated a norm to clip to. - else: - try: - model_norm_threshold = first_state["model_norm_threshold"] - except KeyError: - logging.info( - "Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?" - ) - return 1.0 - ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) - if ans < 1.0: - first_state["num_clipped"] += 1 - if ans < 0.1: - logging.warn( - f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" - ) - if self.show_dominant_parameters: - assert p.shape[0] == len(param_names) - self._show_gradient_dominating_parameter(tuples, tot_sumsq) - return ans - - def _show_gradient_dominating_parameter( - self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor - ): - """ - Show information of parameter which dominates tot_sumsq. - - Args: - tuples: a list of tuples of (param, state, param_names) - where param is a batched set of parameters, - with a .grad (1st dim is batch dim) - and state is the state-dict where optimization parameters are kept. - param_names is a List[str] while each str is name for a parameter - in batched set of parameters "param". - tot_sumsq: sumsq of all parameters. Though it's could be calculated - from tuples, we still pass it to save some time. - """ - all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: - # p is a stacked batch parameters. - batch_grad = p.grad - if p.numel() == p.shape[0]: # a batch of scalars - batch_sumsq_orig = batch_grad**2 - # Dummy values used by following `zip` statement. - batch_rms_orig = torch.ones(p.shape[0]) - else: - batch_rms_orig = state["param_rms"] - batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( - dim=list(range(1, batch_grad.ndim)) - ) - - for name, sumsq_orig, rms, grad in zip( - batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad - ): - - proportion_orig = sumsq_orig / tot_sumsq - all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) - - assert torch.isclose( - sum([value[0] for value in all_sumsq_orig.values()]).cpu(), - torch.tensor(1.0), - ) - sorted_by_proportion = { - k: v - for k, v in sorted( - all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True - ) - } - dominant_param_name = next(iter(sorted_by_proportion)) - ( - dominant_proportion, - dominant_sumsq, - dominant_rms, - dominant_grad, - ) = sorted_by_proportion[dominant_param_name] - logging.info( - f"Parameter dominating tot_sumsq {dominant_param_name}" - f" with proportion {dominant_proportion:.2f}," - f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" - f"={dominant_sumsq:.3e}," - f" grad_sumsq={(dominant_grad**2).sum():.3e}," - f" orig_rms_sq={(dominant_rms**2).item():.3e}" - ) - - def _step_one_batch( - self, group: dict, p: Tensor, state: dict, clipping_scale: float - ): - """ - Do the step for one parameter, which is actually going to be a batch of - `real` parameters, with dim 0 as the batch dim. - Args: - group: dict to look up configuration values - p: parameter to update (actually multiple parameters stacked together - as a batch) - state: state-dict for p, to look up the optimizer state - """ - lr = group["lr"] - size_update_period = group["size_update_period"] - beta1 = group["betas"][0] - - grad = p.grad - if clipping_scale != 1.0: - grad = grad * clipping_scale - step = state["step"] - delta = state["delta"] - - delta.mul_(beta1) - batch_size = p.shape[0] - numel = p.numel() // batch_size - if numel > 1: - # Update the size/scale of p, and set param_rms - scale_grads = state["scale_grads"] - scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True - ) - if step % size_update_period == size_update_period - 1: - param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_( - (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - ) - if step > 0: - # self._size_update() learns the overall scale on the - # parameter, by shrinking or expanding it. - self._size_update(group, scale_grads, p, state) - - if numel == 1: - # For parameters with 1 element we just use regular Adam. - # Updates delta. - self._step_scalar(group, p, state) - else: - self._step(group, p, state) - - state["step"] = step + 1 - - def _size_update( - self, group: dict, scale_grads: Tensor, p: Tensor, state: dict - ) -> None: - """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. - - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p - """ - - param_rms = state["param_rms"] - beta1, beta2 = group["betas"] - size_lr = group["lr"] * group["scalar_lr_scale"] - param_min_rms = group["param_min_rms"] - param_max_rms = group["param_max_rms"] - eps = group["eps"] - step = state["step"] - batch_size = p.shape[0] - - size_update_period = scale_grads.shape[0] - # correct beta2 for the size update period: we will have - # faster decay at this level. - beta2_corr = beta2**size_update_period - - scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) - scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` - alpha=1 - beta2_corr, - ) # shape is (batch_size, 1, 1, ...) - - # The 1st time we reach here is when size_step == 1. - size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr**size_step - # we don't bother with bias_correction1; this will help prevent divergence - # at the start of training. - - denom = scale_exp_avg_sq.sqrt() + eps - - scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom - ) - - is_too_small = param_rms < param_min_rms - - # when the param gets too small, just don't shrink it any further. - scale_step.masked_fill_(is_too_small, 0.0) - - # and ensure the parameter rms after update never exceeds param_max_rms. - # We have to look at the trained model for parameters at or around the - # param_max_rms, because sometimes they can indicate a problem with the - # topology or settings. - scale_step = torch.minimum(scale_step, - (param_max_rms - param_rms) / param_rms) - - delta = state["delta"] - # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1 - beta1)) - - def _step(self, group: dict, p: Tensor, state: dict): - """ - This function does the core update of self.step(), in the case where the members of - the batch have more than 1 element. - - Args: - group: A dict which will be used to look up configuration values - p: The parameter to be updated - grad: The grad of p - state: The state-dict corresponding to parameter p - - This function modifies p. - """ - grad = p.grad - lr = group["lr"] - beta1, beta2 = group["betas"] - eps = group["eps"] - param_min_rms = group["param_min_rms"] - step = state["step"] - - exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) - - this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) - bias_correction2 = 1 - beta2 ** (this_step + 1) - if bias_correction2 < 0.99: - # note: not in-place. - exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) - - denom = exp_avg_sq.sqrt() - denom += eps - grad = grad / denom - - alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) - - delta = state["delta"] - delta.add_(grad * alpha) - p.add_(delta) - - def _step_scalar(self, group: dict, p: Tensor, state: dict): - """ - A simplified form of the core update for scalar tensors, where we cannot get a good - estimate of the parameter rms. - """ - beta1, beta2 = group["betas"] - scalar_max = group["scalar_max"] - eps = group["eps"] - lr = group["lr"] * group["scalar_lr_scale"] - grad = p.grad - - exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # bias_correction2 is like in Adam. Don't bother with bias_correction1; - # slower update at the start will help stability anyway. - bias_correction2 = 1 - beta2 ** (state["step"] + 1) - denom = (exp_avg_sq / bias_correction2).sqrt() + eps - - delta = state["delta"] - delta.add_(grad / denom, alpha=-lr * (1 - beta1)) - p.clamp_(min=-scalar_max, max=scalar_max) - p.add_(delta) - - -class LRScheduler(object): - """ - Base-class for learning rate schedulers where the learning-rate depends on both the - batch and the epoch. - """ - - def __init__(self, optimizer: Optimizer, verbose: bool = False): - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) - self.optimizer = optimizer - self.verbose = verbose - - for group in optimizer.param_groups: - group.setdefault("base_lr", group["lr"]) - - self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] - - self.epoch = 0 - self.batch = 0 - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - """ - return { - "base_lrs": self.base_lrs, - "epoch": self.epoch, - "batch": self.batch, - } - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - self.__dict__.update(state_dict) - - def get_last_lr(self) -> List[float]: - """Return last computed learning rate by current scheduler. Will be a list of float.""" - return self._last_lr - - def get_lr(self): - # Compute list of learning rates from self.epoch and self.batch and - # self.base_lrs; this must be overloaded by the user. - # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] - raise NotImplementedError - - def step_batch(self, batch: Optional[int] = None) -> None: - # Step the batch index, or just set it. If `batch` is specified, it - # must be the batch index from the start of training, i.e. summed over - # all epochs. - # You can call this in any order; if you don't provide 'batch', it should - # of course be called once per batch. - if batch is not None: - self.batch = batch - else: - self.batch = self.batch + 1 - self._set_lrs() - - def step_epoch(self, epoch: Optional[int] = None): - # Step the epoch index, or just set it. If you provide the 'epoch' arg, - # you should call this at the start of the epoch; if you don't provide the 'epoch' - # arg, you should call it at the end of the epoch. - if epoch is not None: - self.epoch = epoch - else: - self.epoch = self.epoch + 1 - self._set_lrs() - - def _set_lrs(self): - values = self.get_lr() - assert len(values) == len(self.optimizer.param_groups) - - for i, data in enumerate(zip(self.optimizer.param_groups, values)): - param_group, lr = data - param_group["lr"] = lr - self.print_lr(self.verbose, i, lr) - self._last_lr = [group["lr"] for group in self.optimizer.param_groups] - - def print_lr(self, is_verbose, group, lr): - """Display the current learning rate.""" - if is_verbose: - logging.info( - f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" - f" of group {group} to {lr:.4e}." - ) - - -class Eden(LRScheduler): - """ - Eden scheduler. - The basic formula (before warmup) is: - lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * - (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup - where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches - and then stays constant at 1. - - - E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam - - Args: - optimizer: the optimizer to change the learning rates on - lr_batches: the number of batches after which we start significantly - decreasing the learning rate, suggest 5000. - lr_epochs: the number of epochs after which we start significantly - decreasing the learning rate, suggest 6 if you plan to do e.g. - 20 to 40 epochs, but may need smaller number if dataset is huge - and you will do few epochs. - """ - - def __init__( - self, - optimizer: Optimizer, - lr_batches: Union[int, float], - lr_epochs: Union[int, float], - warmup_batches: Union[int, float] = 500.0, - warmup_start: float = 0.5, - verbose: bool = False, - ): - super(Eden, self).__init__(optimizer, verbose) - self.lr_batches = lr_batches - self.lr_epochs = lr_epochs - self.warmup_batches = warmup_batches - - assert 0.0 <= warmup_start <= 1.0, warmup_start - self.warmup_start = warmup_start - - def get_lr(self): - factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 - ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 - ) - warmup_factor = ( - 1.0 - if self.batch >= self.warmup_batches - else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) - # else 0.5 + 0.5 * (self.batch / self.warmup_batches) - ) - - return [x * factor * warmup_factor for x in self.base_lrs] - - -def _test_eden(): - m = torch.nn.Linear(100, 100) - optim = ScaledAdam(m.parameters(), lr=0.03) - - scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) - - for epoch in range(10): - scheduler.step_epoch(epoch) # sets epoch to `epoch` - - for step in range(20): - x = torch.randn(200, 100).detach() - x.requires_grad = True - y = m(x) - dy = torch.randn(200, 100).detach() - f = (y * dy).sum() - f.backward() - - optim.step() - scheduler.step_batch() - optim.zero_grad() - - logging.info(f"last lr = {scheduler.get_last_lr()}") - logging.info(f"state dict = {scheduler.state_dict()}") - - -# This is included mostly as a baseline for ScaledAdam. -class Eve(Optimizer): - """ - Implements Eve algorithm. This is a modified version of AdamW with a special - way of setting the weight-decay / shrinkage-factor, which is designed to make the - rms of the parameters approach a particular target_rms (default: 0.1). This is - for use with networks with 'scaled' versions of modules (see scaling.py), which - will be close to invariant to the absolute scale on the parameter matrix. - - The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. - The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. - Eve is unpublished so far. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 3e-4; - this value means that the weight would decay significantly after - about 3k minibatches. Is not multiplied by learning rate, but - is conditional on RMS-value of parameter being > target_rms. - target_rms (float, optional): target root-mean-square value of - parameters, if they fall below this we will stop applying weight decay. - - - .. _Adam: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.98), - eps=1e-8, - weight_decay=1e-3, - target_rms=0.1, - ): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - if not 0 < target_rms <= 10.0: - raise ValueError("Invalid target_rms value: {}".format(target_rms)) - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - target_rms=target_rms, - ) - super(Eve, self).__init__(params, defaults) - - def __setstate__(self, state): - super(Eve, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - - # Perform optimization step - grad = p.grad - if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") - - state = self.state[p] - - # State initialization - if len(state) == 0: - state["step"] = 0 - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - - beta1, beta2 = group["betas"] - - state["step"] += 1 - bias_correction1 = 1 - beta1 ** state["step"] - bias_correction2 = 1 - beta2 ** state["step"] - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( - group["eps"] - ) - - step_size = group["lr"] / bias_correction1 - target_rms = group["target_rms"] - weight_decay = group["weight_decay"] - - if p.numel() > 1: - # avoid applying this weight-decay on "scaling factors" - # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) - p.mul_(1 - (weight_decay * is_above_target_rms)) - - p.addcdiv_(exp_avg, denom, value=-step_size) - - if random.random() < 0.0005: - step = (exp_avg / denom) * step_size - logging.info( - f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" - ) - - return loss - - -def _test_scaled_adam(hidden_dim: int): - import timeit - - from scaling import ScaledLinear - - E = 100 - B = 4 - T = 2 - logging.info("in test_eve_cain") - # device = torch.device('cuda') - device = torch.device("cpu") - dtype = torch.float32 - - fix_random_seed(42) - # these input_magnitudes and output_magnitudes are to test that - # Abel is working as we expect and is able to adjust scales of - # different dims differently. - input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - - for iter in [1, 0]: - fix_random_seed(42) - Linear = torch.nn.Linear if iter == 0 else ScaledLinear - - m = torch.nn.Sequential( - Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) - - train_pairs = [ - ( - 100.0 - * torch.randn(B, T, E, device=device, dtype=dtype) - * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, - ) - for _ in range(20) - ] - - if iter == 0: - optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: - optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) - scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) - - start = timeit.default_timer() - avg_loss = 0.0 - for epoch in range(180): - scheduler.step_epoch() - # if epoch == 100 and iter in [2,3]: - # optim.reset_speedup() # check it doesn't crash. - - # if epoch == 130: - # opts = diagnostics.TensorDiagnosticOptions( - # 2 ** 22 - # ) # allow 4 megabytes per sub-module - # diagnostic = diagnostics.attach_diagnostics(m, opts) - - for n, (x, y) in enumerate(train_pairs): - y_out = m(x) - loss = ((y_out - y) ** 2).mean() * 100.0 - if epoch == 0 and n == 0: - avg_loss = loss.item() - else: - avg_loss = 0.98 * avg_loss + 0.02 * loss.item() - if n == 0 and epoch % 5 == 0: - # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) - lr = scheduler.get_last_lr()[0] - logging.info( - f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" - ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} - loss.log().backward() - optim.step() - optim.zero_grad() - scheduler.step_batch() - - # diagnostic.print_diagnostics() - - stop = timeit.default_timer() - logging.info(f"Iter={iter}, Time taken: {stop - start}") - - logging.info(f"last lr = {scheduler.get_last_lr()}") - # logging.info("state dict = ", scheduler.state_dict()) - # logging.info("optim state_dict = ", optim.state_dict()) - logging.info(f"input_magnitudes = {input_magnitudes}") - logging.info(f"output_magnitudes = {output_magnitudes}") - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - logging.getLogger().setLevel(logging.INFO) - import subprocess - - s = subprocess.check_output( - "git status -uno .; git log -1; git diff HEAD .", shell=True - ) - logging.info(s) - import sys - - if len(sys.argv) > 1: - hidden_dim = int(sys.argv[1]) - else: - hidden_dim = 200 - - _test_scaled_adam(hidden_dim) - _test_eden() diff --git a/egs/swbd/ASR/zipformer/pretrained.py b/egs/swbd/ASR/zipformer/pretrained.py deleted file mode 100755 index 3104b60847..0000000000 --- a/egs/swbd/ASR/zipformer/pretrained.py +++ /dev/null @@ -1,381 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -Note: This is a example for librispeech dataset, if you are using different -dataset, you should change the argument values according to your dataset. - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - -Usage of this script: - -- For non-streaming model: - -(1) greedy search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --tokens data/lang_bpe_500/tokens.txt \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --method modified_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --method fast_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -- For streaming model: - -(1) greedy search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --method modified_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --method fast_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - - -You can also use `./zipformer/exp/epoch-xx.pt`. - -Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - fast_beam_search_one_best, - greedy_search_batch, - modified_beam_search, -) -from export import num_tokens -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall.utils import make_pad_mask - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - token_table = k2.SymbolTable.from_file(params.tokens) - - params.blank_id = token_table[""] - params.unk_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - - logging.info("Creating model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - feature_lengths = torch.tensor(feature_lengths, device=device) - - # model forward - encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) - - hyps = [] - msg = f"Using {params.method}" - logging.info(msg) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += token_table[i] - return text.replace("▁", " ").strip() - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - else: - raise ValueError(f"Unsupported method: {params.method}") - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - s += f"{filename}:\n{hyp}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/swbd/ASR/zipformer/pretrained_ctc.py b/egs/swbd/ASR/zipformer/pretrained_ctc.py deleted file mode 100755 index 9dff2e6fc9..0000000000 --- a/egs/swbd/ASR/zipformer/pretrained_ctc.py +++ /dev/null @@ -1,455 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --causal 1 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - -Usage of this script: - -(1) ctc-decoding -./zipformer/pretrained_ctc.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --tokens data/lang_bpe_500/tokens.txt \ - --method ctc-decoding \ - --sample-rate 16000 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) 1best -./zipformer/pretrained_ctc.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --HLG data/lang_bpe_500/HLG.pt \ - --words-file data/lang_bpe_500/words.txt \ - --method 1best \ - --sample-rate 16000 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) nbest-rescoring -./zipformer/pretrained_ctc.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --HLG data/lang_bpe_500/HLG.pt \ - --words-file data/lang_bpe_500/words.txt \ - --G data/lm/G_4_gram.pt \ - --method nbest-rescoring \ - --sample-rate 16000 \ - /path/to/foo.wav \ - /path/to/bar.wav - - -(4) whole-lattice-rescoring -./zipformer/pretrained_ctc.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --HLG data/lang_bpe_500/HLG.pt \ - --words-file data/lang_bpe_500/words.txt \ - --G data/lm/G_4_gram.pt \ - --method whole-lattice-rescoring \ - --sample-rate 16000 \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from ctc_decode import get_decoding_params -from export import num_tokens -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall.decode import ( - get_lattice, - one_best_decoding, - rescore_with_n_best_list, - rescore_with_whole_lattice, -) -from icefall.utils import get_texts - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--words-file", - type=str, - help="""Path to words.txt. - Used only when method is not ctc-decoding. - """, - ) - - parser.add_argument( - "--HLG", - type=str, - help="""Path to HLG.pt. - Used only when method is not ctc-decoding. - """, - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt. - Used only when method is ctc-decoding. - """, - ) - - parser.add_argument( - "--method", - type=str, - default="1best", - help="""Decoding method. - Possible values are: - (0) ctc-decoding - Use CTC decoding. It uses a token table, - i.e., lang_dir/tokens.txt, to convert - word pieces to words. It needs neither a lexicon - nor an n-gram LM. - (1) 1best - Use the best path as decoding output. Only - the transformer encoder output is used for decoding. - We call it HLG decoding. - (2) nbest-rescoring. Extract n paths from the decoding lattice, - rescore them with an LM, the path with - the highest score is the decoding result. - We call it HLG decoding + nbest n-gram LM rescoring. - (3) whole-lattice-rescoring - Use an LM to rescore the - decoding lattice and then use 1best to decode the - rescored lattice. - We call it HLG decoding + whole-lattice n-gram LM rescoring. - """, - ) - - parser.add_argument( - "--G", - type=str, - help="""An LM for rescoring. - Used only when method is - whole-lattice-rescoring or nbest-rescoring. - It's usually a 4-gram LM. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help=""" - Used only when method is attention-decoder. - It specifies the size of n-best list.""", - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=1.3, - help=""" - Used only when method is whole-lattice-rescoring and nbest-rescoring. - It specifies the scale for n-gram LM scores. - (Note: You need to tune it on a dataset.) - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=1.0, - help=""" - Used only when method is nbest-rescoring. - It specifies the scale for lattice.scores when - extracting n-best lists. A smaller value results in - more unique number of paths with the risk of missing - the best path. - """, - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float = 16000 -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" - ) - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - # add decoding params - params.update(get_decoding_params()) - params.update(vars(args)) - - token_table = k2.SymbolTable.from_file(params.tokens) - params.vocab_size = num_tokens(token_table) + 1 # +1 for blank - params.blank_id = token_table[""] - assert params.blank_id == 0 - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) - ctc_output = model.ctc_output(encoder_out) # (N, T, C) - - batch_size = ctc_output.shape[0] - supervision_segments = torch.tensor( - [ - [i, 0, feature_lengths[i].item() // params.subsampling_factor] - for i in range(batch_size) - ], - dtype=torch.int32, - ) - - if params.method == "ctc-decoding": - logging.info("Use CTC decoding") - max_token_id = params.vocab_size - 1 - - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) - - lattice = get_lattice( - nnet_output=ctc_output, - decoding_graph=H, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - token_ids = get_texts(best_path) - hyps = [[token_table[i] for i in ids] for ids in token_ids] - elif params.method in [ - "1best", - "nbest-rescoring", - "whole-lattice-rescoring", - ]: - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - if params.method in [ - "nbest-rescoring", - "whole-lattice-rescoring", - ]: - logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) - G = G.to(device) - if params.method == "whole-lattice-rescoring": - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - - # G.lm_scores is used to replace HLG.lm_scores during - # LM rescoring. - G.lm_scores = G.scores.clone() - - lattice = get_lattice( - nnet_output=ctc_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.method == "1best": - logging.info("Use HLG decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - if params.method == "nbest-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_n_best_list( - lattice=lattice, - G=G, - num_paths=params.num_paths, - lm_scale_list=[params.ngram_lm_scale], - nbest_scale=params.nbest_scale, - ) - best_path = next(iter(best_path_dict.values())) - elif params.method == "whole-lattice-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=[params.ngram_lm_scale], - ) - best_path = next(iter(best_path_dict.values())) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - else: - raise ValueError(f"Unsupported decoding method: {params.method}") - - s = "\n" - if params.method == "ctc-decoding": - for filename, hyp in zip(params.sound_files, hyps): - words = "".join(hyp) - words = words.replace("▁", " ").strip() - s += f"{filename}:\n{words}\n\n" - elif params.method in [ - "1best", - "nbest-rescoring", - "whole-lattice-rescoring", - ]: - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - words = words.replace("▁", " ").strip() - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/swbd/ASR/zipformer/profile.py b/egs/swbd/ASR/zipformer/profile.py deleted file mode 100755 index b460b53389..0000000000 --- a/egs/swbd/ASR/zipformer/profile.py +++ /dev/null @@ -1,176 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: ./zipformer/profile.py -""" - -import argparse -import logging -import sentencepiece as spm -import torch - -from typing import Tuple -from torch import Tensor, nn - -from icefall.utils import make_pad_mask -from icefall.profiler import get_model_profile -from scaling import BiasNorm -from train import ( - get_encoder_embed, - get_encoder_model, - get_joiner_model, - add_model_arguments, - get_params, -) -from zipformer import BypassModule - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - add_model_arguments(parser) - - return parser - - -def _bias_norm_flops_compute(module, input, output): - assert len(input) == 1, len(input) - # estimate as layer_norm, see icefall/profiler.py - flops = input[0].numel() * 5 - module.__flops__ += int(flops) - - -def _swoosh_module_flops_compute(module, input, output): - # For SwooshL and SwooshR modules - assert len(input) == 1, len(input) - # estimate as swish/silu, see icefall/profiler.py - flops = input[0].numel() - module.__flops__ += int(flops) - - -def _bypass_module_flops_compute(module, input, output): - # For Bypass module - assert len(input) == 2, len(input) - flops = input[0].numel() * 2 - module.__flops__ += int(flops) - - -MODULE_HOOK_MAPPING = { - BiasNorm: _bias_norm_flops_compute, - BypassModule: _bypass_module_flops_compute, -} - - -class Model(nn.Module): - """A Wrapper for encoder, encoder_embed, and encoder_proj""" - - def __init__( - self, - encoder: nn.Module, - encoder_embed: nn.Module, - encoder_proj: nn.Module, - ) -> None: - super().__init__() - self.encoder = encoder - self.encoder_embed = encoder_embed - self.encoder_proj = encoder_proj - - def forward( - self, feature: Tensor, feature_lens: Tensor - ) -> Tuple[Tensor, Tensor]: - x, x_lens = self.encoder_embed(feature, feature_lens) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = self.encoder( - x, x_lens, src_key_padding_mask - ) - - encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - logits = self.encoder_proj(encoder_out) - - return logits, encoder_out_lens - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - - # We only profile the encoder part - model = Model( - encoder=get_encoder_model(params), - encoder_embed=get_encoder_embed(params), - encoder_proj=get_joiner_model(params).encoder_proj, - ) - model.eval() - model.to(device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # for 30-second input - B, T, D = 1, 3000, 80 - feature = torch.ones(B, T, D, dtype=torch.float32).to(device) - feature_lens = torch.full((B,), T, dtype=torch.int64).to(device) - - flops, params = get_model_profile( - model=model, - args=(feature, feature_lens), - module_hoop_mapping=MODULE_HOOK_MAPPING, - ) - logging.info(f"For the encoder part, params: {params}, flops: {flops}") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/swbd/ASR/zipformer/scaling.py b/egs/swbd/ASR/zipformer/scaling.py deleted file mode 100644 index 7c98ef0456..0000000000 --- a/egs/swbd/ASR/zipformer/scaling.py +++ /dev/null @@ -1,1840 +0,0 @@ -# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Optional, Tuple, Union -import logging -import k2 -from torch.cuda.amp import custom_fwd, custom_bwd -import random -import torch -import math -import torch.nn as nn -from torch import Tensor - -def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: - max_value = torch.max(x, y) - diff = torch.abs(x - y) - return max_value + torch.log1p(torch.exp(-diff)) - - -# RuntimeError: Exporting the operator logaddexp to ONNX opset version -# 14 is not supported. Please feel free to request support or submit -# a pull request on PyTorch GitHub. -# -# The following function is to solve the above error when exporting -# models to ONNX via torch.jit.trace() -def logaddexp(x: Tensor, y: Tensor) -> Tensor: - # Caution(fangjun): Put torch.jit.is_scripting() before - # torch.onnx.is_in_onnx_export(); - # otherwise, it will cause errors for torch.jit.script(). - # - # torch.logaddexp() works for both torch.jit.script() and - # torch.jit.trace() but it causes errors for ONNX export. - # - if torch.jit.is_scripting(): - # Note: We cannot use torch.jit.is_tracing() here as it also - # matches torch.onnx.export(). - return torch.logaddexp(x, y) - elif torch.onnx.is_in_onnx_export(): - return logaddexp_onnx(x, y) - else: - # for torch.jit.trace() - return torch.logaddexp(x, y) - -class PiecewiseLinear(object): - """ - Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with - the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] - respectively. - """ - def __init__(self, *args): - assert len(args) >= 1, len(args) - if len(args) == 1 and isinstance(args[0], PiecewiseLinear): - self.pairs = list(args[0].pairs) - else: - self.pairs = [ (float(x), float(y)) for x,y in args ] - for (x,y) in self.pairs: - assert isinstance(x, (float, int)), type(x) - assert isinstance(y, (float, int)), type(y) - - for i in range(len(self.pairs) - 1): - assert self.pairs[i + 1][0] > self.pairs[i][0], (i, self.pairs[i], self.pairs[i + 1]) - - def __str__(self): - # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' - return f'PiecewiseLinear({str(self.pairs)[1:-1]})' - - def __call__(self, x): - if x <= self.pairs[0][0]: - return self.pairs[0][1] - elif x >= self.pairs[-1][0]: - return self.pairs[-1][1] - else: - cur_x, cur_y = self.pairs[0] - for i in range(1, len(self.pairs)): - next_x, next_y = self.pairs[i] - if x >= cur_x and x <= next_x: - return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) - cur_x, cur_y = next_x, next_y - assert False - - def __mul__(self, alpha): - return PiecewiseLinear( - * [(x, y * alpha) for x, y in self.pairs]) - - def __add__(self, x): - if isinstance(x, (float, int)): - return PiecewiseLinear( - * [(p[0], p[1] + x) for p in self.pairs]) - s, x = self.get_common_basis(x) - return PiecewiseLinear( - * [(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]) - - def max(self, x): - if isinstance(x, (float, int)): - x = PiecewiseLinear( (0, x) ) - s, x = self.get_common_basis(x, include_crossings=True) - return PiecewiseLinear( - * [(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) - - def min(self, x): - if isinstance(x, float) or isinstance(x, int): - x = PiecewiseLinear( (0, x) ) - s, x = self.get_common_basis(x, include_crossings=True) - return PiecewiseLinear( - * [ (sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) - - def __eq__(self, other): - return self.pairs == other.pairs - - def get_common_basis(self, - p: 'PiecewiseLinear', - include_crossings: bool = False): - """ - Returns (self_mod, p_mod) which are equivalent piecewise linear - functions to self and p, but with the same x values. - - p: the other piecewise linear function - include_crossings: if true, include in the x values positions - where the functions indicate by this and p crosss. - """ - assert isinstance(p, PiecewiseLinear), type(p) - - # get sorted x-values without repetition. - x_vals = sorted(set([ x for x, _ in self.pairs ] + [ x for x, _ in p.pairs ])) - y_vals1 = [ self(x) for x in x_vals ] - y_vals2 = [ p(x) for x in x_vals ] - - if include_crossings: - extra_x_vals = [] - for i in range(len(x_vals) - 1): - if (y_vals1[i] > y_vals2[i]) != (y_vals1[i+1] > y_vals2[i+1]): - # if the two lines in this subsegment potentially cross each other.. - diff_cur = abs(y_vals1[i] - y_vals2[i]) - diff_next = abs(y_vals1[i+1] - y_vals2[i+1]) - # `pos`, between 0 and 1, gives the relative x position, - # with 0 being x_vals[i] and 1 being x_vals[i+1]. - pos = diff_cur / (diff_cur + diff_next) - extra_x_val = x_vals[i] + pos * (x_vals[i+1] - x_vals[i]) - extra_x_vals.append(extra_x_val) - if len(extra_x_vals) > 0: - x_vals = sorted(set(x_vals + extra_x_vals)) - y_vals1 = [ self(x) for x in x_vals ] - y_vals2 = [ p(x) for x in x_vals ] - return ( PiecewiseLinear(* zip(x_vals, y_vals1)), - PiecewiseLinear(* zip(x_vals, y_vals2)) ) - - -class ScheduledFloat(torch.nn.Module): - """ - This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); - it does not have a working forward() function. You are supposed to cast it to float, as - in, float(parent_module.whatever), and use it as something like a dropout prob. - - It is a floating point value whose value changes depending on the batch count of the - training loop. It is a piecewise linear function where you specify the (x,y) pairs - in sorted order on x; x corresponds to the batch index. For batch-index values before the - first x or after the last x, we just use the first or last y value. - - Example: - self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) - - `default` is used when self.batch_count is not set or not in training mode or in - torch.jit scripting mode. - """ - def __init__(self, - *args, - default: float = 0.0): - super().__init__() - # self.batch_count and self.name will be written to in the training loop. - self.batch_count = None - self.name = None - self.default = default - self.schedule = PiecewiseLinear(*args) - - def extra_repr(self) -> str: - return f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}' - - def __float__(self): - batch_count = self.batch_count - if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): - return float(self.default) - else: - ans = self.schedule(self.batch_count) - if random.random() < 0.0002: - logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}") - return ans - - def __add__(self, x): - if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule + x, - default=self.default) - else: - return ScheduledFloat(self.schedule + x.schedule, - default=self.default+x.default) - - def max(self, x): - if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule.max(x), - default=self.default) - else: - return ScheduledFloat(self.schedule.max(x.schedule), - default=max(self.default, x.default)) - - -FloatLike = Union[float, ScheduledFloat] - - -def random_cast_to_half(x: Tensor, - min_abs: float = 5.0e-06) -> Tensor: - """ - A randomized way of casting a floating point value to half precision. - """ - if x.dtype == torch.float16: - return x - x_abs = x.abs() - is_too_small = (x_abs < min_abs) - # for elements where is_too_small is true, random_val will contain +-min_abs with - # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, - # for those elements]. - random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) - return torch.where(is_too_small, random_val, x).to(torch.float16) - - -class CutoffEstimator: - """ - Estimates cutoffs of an arbitrary numerical quantity such that a specified - proportion of items will be above the cutoff on average. - - p is the proportion of items that should be above the cutoff. - """ - def __init__(self, p: float): - self.p = p - # total count of items - self.count = 0 - # total count of items that were above the cutoff - self.count_above = 0 - # initial cutoff value - self.cutoff = 0 - - def __call__(self, x: float) -> bool: - """ - Returns true if x is above the cutoff. - """ - ans = (x > self.cutoff) - self.count += 1 - if ans: - self.count_above += 1 - cur_p = self.count_above / self.count - delta_p = cur_p - self.p - if (delta_p > 0) == ans: - q = abs(delta_p) - self.cutoff = x * q + self.cutoff * (1-q) - return ans - - -class SoftmaxFunction(torch.autograd.Function): - """ - Tries to handle half-precision derivatives in a randomized way that should - be more accurate for training than the default behavior. - """ - @staticmethod - def forward(ctx, x: Tensor, dim: int): - ans = x.softmax(dim=dim) - # if x dtype is float16, x.softmax() returns a float32 because - # (presumably) that op does not support float16, and autocast - # is enabled. - if torch.is_autocast_enabled(): - ans = ans.to(torch.float16) - ctx.save_for_backward(ans) - ctx.x_dtype = x.dtype - ctx.dim = dim - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor): - ans, = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): - ans_grad = ans_grad.to(torch.float32) - ans = ans.to(torch.float32) - x_grad = ans_grad * ans - x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) - return x_grad, None - - -def softmax(x: Tensor, dim: int): - if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing(): - return x.softmax(dim=dim) - - return SoftmaxFunction.apply(x, dim) - - -class MaxEigLimiterFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float) -> Tensor: - ctx.channel_dim = channel_dim - ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), - coeffs.detach(), - direction.detach()) - return x - - @staticmethod - def backward(ctx, x_grad, *args): - with torch.enable_grad(): - (x_orig, coeffs, new_direction) = ctx.saved_tensors - x_orig.requires_grad = True - num_channels = x_orig.shape[ctx.channel_dim] - x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) - new_direction.requires_grad = False - x = x - x.mean(dim=0) - x_var = (x ** 2).mean() - x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() - # `variance_proportion` is the proportion of the variance accounted for - # by the top eigen-direction. This is to be minimized. - variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) - variance_proportion.backward() - x_orig_grad = x_orig.grad - x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) - return x_grad + x_extra_grad.detach(), None, None, None, None - - -class BiasNormFunction(torch.autograd.Function): - # This computes: - # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() - # return x * scales - # (after unsqueezing the bias), but it does it in a memory-efficient way so that - # it can just store the returned value (chances are, this will also be needed for - # some other reason, related to the next operation, so we can save memory). - @staticmethod - def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int, - store_output_for_backprop: bool) -> Tensor: - assert bias.ndim == 1 - if channel_dim < 0: - channel_dim = channel_dim + x.ndim - ctx.store_output_for_backprop = store_output_for_backprop - ctx.channel_dim = channel_dim - for _ in range(channel_dim + 1, x.ndim): - bias = bias.unsqueeze(-1) - scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp() - ans = x * scales - ctx.save_for_backward(ans.detach() if store_output_for_backprop else x, - scales.detach(), bias.detach(), log_scale.detach()) - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tensor: - ans_or_x, scales, bias, log_scale = ctx.saved_tensors - if ctx.store_output_for_backprop: - x = ans_or_x / scales - else: - x = ans_or_x - x = x.detach() - x.requires_grad = True - bias.requires_grad = True - log_scale.requires_grad = True - with torch.enable_grad(): - # recompute scales from x, bias and log_scale. - scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp() - ans = x * scales - ans.backward(gradient=ans_grad) - return x.grad, bias.grad.flatten(), log_scale.grad, None, None - - -class BiasNorm(torch.nn.Module): - """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. - - Instead, we give the BiasNorm a trainable bias that it can use when - computing the scale for normalization. We also give it a (scalar) - trainable scale on the output. - - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interpreted as an offset from the input's ndim if negative. - This is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. - log_scale: the initial log-scale that we multiply the output by; this - is learnable. - log_scale_min: FloatLike, minimum allowed value of log_scale - log_scale_max: FloatLike, maximum allowed value of log_scale - store_output_for_backprop: only possibly affects memory use; recommend - to set to True if you think the output of this module is more likely - than the input of this module to be required to be stored for the - backprop. - """ - def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - log_scale: float = 1.0, - log_scale_min: float = -1.5, - log_scale_max: float = 1.5, - store_output_for_backprop: bool = False - ) -> None: - super(BiasNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.log_scale = nn.Parameter(torch.tensor(log_scale)) - self.bias = nn.Parameter(torch.zeros(num_channels)) - - self.log_scale_min = log_scale_min - self.log_scale_max = log_scale_max - - self.store_output_for_backprop = store_output_for_backprop - - def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - channel_dim = self.channel_dim - if channel_dim < 0: - channel_dim += x.ndim - bias = self.bias - for _ in range(channel_dim + 1, x.ndim): - bias = bias.unsqueeze(-1) - scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * - self.log_scale.exp()) - return x * scales - - log_scale = limit_param_value(self.log_scale, - min=float(self.log_scale_min), - max=float(self.log_scale_max), - training=self.training) - - return BiasNormFunction.apply(x, self.bias, log_scale, - self.channel_dim, - self.store_output_for_backprop) - - -def ScaledLinear(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Linear: - """ - Behaves like a constructor of a modified version of nn.Linear - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Linear(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) - return ans - - -def ScaledConv1d(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Conv1d: - """ - Behaves like a constructor of a modified version of nn.Conv1d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv1d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) - return ans - - -def ScaledConv2d(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Conv2d: - """ - Behaves like a constructor of a modified version of nn.Conv2d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False, but: - NO PADDING-RELATED ARGS. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv2d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) - return ans - - -class ChunkCausalDepthwiseConv1d(torch.nn.Module): - """ - Behaves like a depthwise 1d convolution, except that it is causal in - a chunkwise way, as if we had a block-triangular attention mask. - The chunk size is provided at test time (it should probably be - kept in sync with the attention mask). - - This has a little more than twice the parameters of a conventional - depthwise conv1d module: we implement it by having one - depthwise convolution, of half the width, that is causal (via - right-padding); and one depthwise convolution that is applied only - within chunks, that we multiply by a scaling factor which depends - on the position within the chunk. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - def __init__(self, - channels: int, - kernel_size: int, - initial_scale: float = 1.0, - bias: bool = True): - super().__init__() - assert kernel_size % 2 == 1 - - half_kernel_size = (kernel_size + 1) // 2 - # will pad manually, on one side. - self.causal_conv = nn.Conv1d(in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=half_kernel_size, - padding=0, - bias=True) - - self.chunkwise_conv = nn.Conv1d(in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=kernel_size // 2, - bias=bias) - - # first row is correction factors added to the scale near the left edge of the chunk, - # second row is correction factors added to the scale near the right edge of the chunk, - # both of these are added to a default scale of 1.0. - self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) - self.kernel_size = kernel_size - - with torch.no_grad(): - self.causal_conv.weight[:] *= initial_scale - self.chunkwise_conv.weight[:] *= initial_scale - if bias: - torch.nn.init.uniform_(self.causal_conv.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) - - def forward(self, - x: Tensor, - chunk_size: int = -1) -> Tensor: - """ - Forward function. Args: - x: a Tensor of shape (batch_size, channels, seq_len) - chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. - """ - (batch_size, num_channels, seq_len) = x.shape - - # half_kernel_size = self.kernel_size + 1 // 2 - # left_pad is half_kernel_size - 1 where half_kernel_size is the size used - # in the causal conv. It's the amount by which we must pad on the left, - # to make the convolution causal. - left_pad = self.kernel_size // 2 - - if chunk_size < 0 or chunk_size > seq_len: - chunk_size = seq_len - right_pad = -seq_len % chunk_size - - x = torch.nn.functional.pad(x, (left_pad, right_pad)) - - x_causal = self.causal_conv(x[..., :left_pad + seq_len]) - assert x_causal.shape == (batch_size, num_channels, seq_len) - - x_chunk = x[..., left_pad:] - num_chunks = x_chunk.shape[2] // chunk_size - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) - x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(batch_size * num_chunks, - num_channels, chunk_size) - x_chunk = self.chunkwise_conv(x_chunk) # does not change shape - - chunk_scale = self._get_chunk_scale(chunk_size) - - x_chunk = x_chunk * chunk_scale - x_chunk = x_chunk.reshape(batch_size, num_chunks, - num_channels, chunk_size).permute(0, 2, 1, 3) - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[..., :seq_len] - - return x_chunk + x_causal - - def _get_chunk_scale(self, chunk_size: int): - """Returns tensor of shape (num_channels, chunk_size) that will be used to - scale the output of self.chunkwise_conv.""" - left_edge = self.chunkwise_conv_scale[0] - right_edge = self.chunkwise_conv_scale[1] - if chunk_size < self.kernel_size: - left_edge = left_edge[:, :chunk_size] - right_edge = right_edge[:, -chunk_size:] - else: - t = chunk_size - self.kernel_size - channels = left_edge.shape[0] - pad = torch.zeros(channels, t, - device=left_edge.device, - dtype=left_edge.dtype) - left_edge = torch.cat((left_edge, pad), dim=-1) - right_edge = torch.cat((pad, right_edge), dim=-1) - return 1.0 + (left_edge + right_edge) - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Streaming Forward function. - - Args: - x: a Tensor of shape (batch_size, channels, seq_len) - cache: cached left context of shape (batch_size, channels, left_pad) - """ - (batch_size, num_channels, seq_len) = x.shape - - # left_pad is half_kernel_size - 1 where half_kernel_size is the size used - # in the causal conv. It's the amount by which we must pad on the left, - # to make the convolution causal. - left_pad = self.kernel_size // 2 - - # Pad cache - assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad) - x = torch.cat([cache, x], dim=2) - # Update cache - cache = x[..., -left_pad:] - - x_causal = self.causal_conv(x) - assert x_causal.shape == (batch_size, num_channels, seq_len) - - x_chunk = x[..., left_pad:] - x_chunk = self.chunkwise_conv(x_chunk) # does not change shape - - chunk_scale = self._get_chunk_scale(chunk_size=seq_len) - x_chunk = x_chunk * chunk_scale - - return x_chunk + x_causal, cache - - -class BalancerFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - min_mean: float, - max_mean: float, - min_rms: float, - max_rms: float, - grad_scale: float, - channel_dim: int, - ) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - ctx.channel_dim = channel_dim - ctx.save_for_backward(x) - ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) - return x - - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None, None, None]: - x, = ctx.saved_tensors - (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config - - try: - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x = x.to(torch.float32) - x = x.detach() - x.requires_grad = True - mean_dims = [ i for i in range(x.ndim) if i != channel_dim ] - uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True) - mean = x.mean(dim=mean_dims, keepdim=True) - stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() - rms = uncentered_var.clamp(min=1.0e-20).sqrt() - - m = mean / stddev - # part of loss that relates to mean / stddev - m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() - - # put a much larger scale on the RMS-max-limit loss, so that if both it and the - # m_loss are violated we fix the RMS loss first. - rms_clamped = rms.clamp(min=min_rms, max=max_rms) - r_loss = (rms_clamped / rms).log().abs() - - loss = (m_loss + r_loss) - - loss.backward(gradient=torch.ones_like(loss)) - loss_grad = x.grad - loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20) - - loss_grad = loss_grad * (grad_scale / loss_grad_rms) - - x_grad_float = x_grad.to(torch.float32) - # scale each element of loss_grad by the absolute value of the corresponding - # element of x_grad, which we view as a noisy estimate of its magnitude for that - # (frame and dimension). later we can consider factored versions. - x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) - x_grad = x_grad_mod.to(x_grad.dtype) - except Exception as e: - logging.info(f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue.") - - return x_grad, None, None, None, None, None, None - - -class Balancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. - - Args: - num_channels: the number of channels - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), above which we start to modify the derivatives. - scale_gain_factor: determines the 'gain' with which we increase the - change in gradient once the constraints on min_abs and max_abs - are violated. - min_abs: the minimum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - max_abs: the maximum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - prob: determines the minimum probability with which we modify the - gradients for the {min,max}_positive and {min,max}_abs constraints, - on each forward(). This is done randomly to prevent all layers - from doing it at the same time. - """ - def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: FloatLike = 0.05, - max_positive: FloatLike = 0.95, - min_abs: FloatLike = 0.2, - max_abs: FloatLike = 100.0, - grad_scale: FloatLike = 0.04, - prob: Optional[FloatLike] = None, - ): - super().__init__() - - if prob is None: - prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) - self.prob = prob - # 5% of the time we will return and do nothing because memory usage is - # too high. - self.mem_cutoff = CutoffEstimator(0.05) - - # actually self.num_channels is no longer needed except for an assertion. - self.num_channels = num_channels - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.min_abs = min_abs - self.max_abs = max_abs - self.grad_scale = grad_scale - - def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or not x.requires_grad or - (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))): - return _no_op(x) - - prob = float(self.prob) - if random.random() < prob: - # The following inner-functions convert from the way we historically specified - # these limitations, as limits on the absolute value and the proportion of positive - # values, to limits on the RMS value and the (mean / stddev). - def _abs_to_rms(x): - # for normally distributed data, if the expected absolute value is x, the - # expected rms value will be sqrt(pi/2) * x. - return 1.25331413732 * x - - def _proportion_positive_to_mean(x): - def _atanh(x): - eps = 1.0e-10 - # eps is to prevent crashes if x is exactly 0 or 1. - # we'll just end up returning a fairly large value. - return (math.log (1+x+eps) - math.log (1-x+eps)) / 2. - - def _approx_inverse_erf(x): - # 1 / (sqrt(pi) * ln(2)), - # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions - # this approximation is extremely crude and gets progressively worse for - # x very close to -1 or +1, but we mostly care about the "middle" region - # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772, - # and math.erf(0.0407316414078772) = 0.045935330944660666, - # which is pretty close to 0.05. - return 0.8139535143 * _atanh(x) - # first convert x from the range 0..1 to the range -1..1 which the error - # function returns - x = -1 + (2 * x) - return _approx_inverse_erf(x) - - min_mean = _proportion_positive_to_mean(float(self.min_positive)) - max_mean = _proportion_positive_to_mean(float(self.max_positive)) - min_rms = _abs_to_rms(float(self.min_abs)) - max_rms = _abs_to_rms(float(self.max_abs)) - grad_scale = float(self.grad_scale) - - assert x.shape[self.channel_dim] == self.num_channels - - return BalancerFunction.apply( - x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim - ) - else: - return _no_op(x) - - -def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float, - name: str = None) -> Tensor: - """ - Returns x unmodified, but in backprop will put a penalty for the excess of - the absolute values of elements of x over the limit "limit". E.g. if - limit == 10.0, then if x has any values over 10 it will get a penalty. - - Caution: the value of this penalty will be affected by grad scaling used - in automatic mixed precision training. For this reasons we use this, - it shouldn't really matter, or may even be helpful; we just use this - to disallow really implausible values of scores to be given to softmax. - - The name is for randomly printed debug info. - """ - x_sign = x.sign() - over_limit = (x.abs() - limit) > 0 - # The following is a memory efficient way to penalize the absolute values of - # x that's over the limit. (The memory efficiency comes when you think - # about which items torch needs to cache for the autograd, and which ones it - # can throw away). The numerical value of aux_loss as computed here will - # actually be larger than it should be, by limit * over_limit.sum(), but it - # has the same derivative as the real aux_loss which is penalty * (x.abs() - - # limit).relu(). - aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) - # note: we don't do sum() here on aux)_loss, but it's as if we had done - # sum() due to how with_loss() works. - x = with_loss(x, aux_loss, name) - # you must use x for something, or this will be ineffective. - return x - - -def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. - if x.ndim == 2: - return x.diag() - else: - (batch, dim, dim) = x.shape - x = x.reshape(batch, dim * dim) - x = x[:, ::dim+1] - assert x.shape == (batch, dim) - return x - - -def _whitening_metric(x: Tensor, - num_groups: int): - """ - Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of - of the centered feature covariance are the same within each group's covariance matrix - and also between groups. - Args: - x: a Tensor of shape (*, num_channels) - num_groups: the number of groups of channels, a number >=1 that divides num_channels - Returns: - Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and - greater than 1.0 otherwise. - """ - assert x.dtype != torch.float16 - x = x.reshape(-1, x.shape[-1]) - (num_frames, num_channels) = x.shape - assert num_channels % num_groups == 0 - channels_per_group = num_channels // num_groups - x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) - # x now has shape (num_groups, num_frames, channels_per_group) - # subtract the mean so we use the centered, not uncentered, covariance. - # My experience has been that when we "mess with the gradients" like this, - # it's better not do anything that tries to move the mean around, because - # that can easily cause instability. - x = x - x.mean(dim=1, keepdim=True) - # x_covar: (num_groups, channels_per_group, channels_per_group) - x_covar = torch.matmul(x.transpose(1, 2), x) - x_covar_mean_diag = _diag(x_covar).mean() - # the following expression is what we'd get if we took the matrix product - # of each covariance and measured the mean of its trace, i.e. - # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) - # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) - return metric - - -class WhiteningPenaltyFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, - x: Tensor, - module: nn.Module) -> Tensor: - ctx.save_for_backward(x) - ctx.module = module - return x - - @staticmethod - def backward(ctx, - x_grad: Tensor): - x_orig, = ctx.saved_tensors - w = ctx.module - - try: - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x_detached = x_orig.to(torch.float32).detach() - x_detached.requires_grad = True - - metric = _whitening_metric(x_detached, w.num_groups) - - if random.random() < 0.005 or __name__ == "__main__": - logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}") - - if metric < float(w.whitening_limit): - w.prob = w.min_prob - return x_grad, None - else: - w.prob = w.max_prob - metric.backward() - penalty_grad = x_detached.grad - scale = w.grad_scale * (x_grad.to(torch.float32).norm() / - (penalty_grad.norm() + 1.0e-20)) - penalty_grad = penalty_grad * scale - return x_grad + penalty_grad.to(x_grad.dtype), None - except Exception as e: - logging.info(f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue.") - return x_grad, None - - -class Whiten(nn.Module): - def __init__( - self, - num_groups: int, - whitening_limit: FloatLike, - prob: Union[float, Tuple[float,float]], - grad_scale: FloatLike): - """ - Args: - num_groups: the number of groups to divide the channel dim into before - whitening. We will attempt to make the feature covariance - within each group, after mean subtraction, as "white" as possible, - while having the same trace across all groups. - whitening_limit: a value greater than 1.0, that dictates how much - freedom we have to violate the constraints. 1.0 would mean perfectly - white, with exactly the same trace across groups; larger values - give more freedom. E.g. 2.0. - prob: the probability with which we apply the gradient modification - (also affects the grad scale). May be supplied as a float, - or as a pair (min_prob, max_prob) - - grad_scale: determines the scale on the gradient term from this object, - relative to the rest of the gradient on the attention weights. - E.g. 0.02 (you may want to use smaller values than this if prob is large) - """ - super(Whiten, self).__init__() - assert num_groups >= 1 - assert float(whitening_limit) >= 1 - assert grad_scale >= 0 - self.num_groups = num_groups - self.whitening_limit = whitening_limit - self.grad_scale = grad_scale - - if isinstance(prob, float): - prob = (prob, prob) - (self.min_prob, self.max_prob) = prob - assert 0 < self.min_prob <= self.max_prob <= 1 - self.prob = self.max_prob - self.name = None # will be set in training loop - - def forward(self, - x: Tensor) -> Tensor: - """ - In the forward pass, this function just returns the input unmodified. - In the backward pass, it will modify the gradients to ensure that the - distribution in each group has close to (lambda times I) as the covariance - after mean subtraction, with the same lambda across groups. - For whitening_limit > 1, there will be more freedom to violate this - constraint. - - Args: - x: the input of shape (*, num_channels) - - Returns: - x, unmodified. You should make sure - you use the returned value, or the graph will be freed - and nothing will happen in backprop. - """ - grad_scale = float(self.grad_scale) - if not x.requires_grad or random.random() > self.prob or grad_scale == 0: - return _no_op(x) - else: - return WhiteningPenaltyFunction.apply(x, self) - - -class WithLoss(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, y: Tensor, name: str): - ctx.y_shape = y.shape - if random.random() < 0.002 and name is not None: - loss_sum = y.sum().item() - logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") - return x - - @staticmethod - def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones(ctx.y_shape, - dtype=ans_grad.dtype, - device=ans_grad.device), None - - -def with_loss(x, y, name): - # returns x but adds y.sum() to the loss function. - return WithLoss.apply(x, y, name) - - -class ScaleGradFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, alpha: float) -> Tensor: - ctx.alpha = alpha - return x - - @staticmethod - def backward(ctx, grad: Tensor): - return grad * ctx.alpha, None - - -def scale_grad(x: Tensor, alpha: float): - return ScaleGradFunction.apply(x, alpha) - - -class ScaleGrad(nn.Module): - def __init__(self, alpha: float): - super().__init__() - self.alpha = alpha - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return x - return scale_grad(x, self.alpha) - - -class LimitParamValue(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, min: float, max: float): - ctx.save_for_backward(x) - assert max >= min - ctx.min = min - ctx.max = max - return x - - @staticmethod - def backward(ctx, x_grad: Tensor): - x, = ctx.saved_tensors - # where x < ctx.min, ensure all grads are negative (this will tend to make - # x more positive). - x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0) - # where x > ctx.max, ensure all grads are positive (this will tend to make - # x more negative). - x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) - return x_grad, None, None - - -def limit_param_value(x: Tensor, - min: float, max: float, - prob: float = 0.6, - training: bool = True): - # You apply this to (typically) an nn.Parameter during training to ensure that its - # (elements mostly) stays within a supplied range. This is done by modifying the - # gradients in backprop. - # It's not necessary to do this on every batch: do it only some of the time, - # to save a little time. - if training and random.random() < prob: - return LimitParamValue.apply(x, min, max) - else: - return x - - -def _no_op(x: Tensor) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x - else: - # a no-op function that will have a node in the autograd graph, - # to avoid certain bugs relating to backward hooks - return x.chunk(1, dim=-1)[0] - - -class Identity(torch.nn.Module): - def __init__(self): - super(Identity, self).__init__() - - def forward(self, x): - return _no_op(x) - - -class DoubleSwishFunction(torch.autograd.Function): - """ - double_swish(x) = x * torch.sigmoid(x-1) - - This is a definition, originally motivated by its close numerical - similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). - - Memory-efficient derivative computation: - double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) - double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). - Now, s'(x) = s(x) * (1-s(x)). - double_swish'(x) = x * s'(x) + s(x). - = x * s(x) * (1-s(x)) + s(x). - = double_swish(x) * (1-s(x)) + s(x) - ... so we just need to remember s(x) but not x itself. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - if x.dtype == torch.float16: - x = x.to(torch.float32) - - s = torch.sigmoid(x - 1.0) - y = x * s - - if requires_grad: - deriv = (y * (1 - s) + s) - - # notes on derivative of x * sigmoid(x - 1): - # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 - # min \simeq -0.043638. Take floor as -0.044 so it's a lower bund - # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. - # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which - # floors), should be expectation-preserving. - floor = -0.044 - ceil = 1.2 - d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.043637 - ceil = 1.2 - - d = (d * ((ceil - floor) / 255.0) + floor) - return y_grad * d - - -class DoubleSwish(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x * torch.sigmoid(x - 1.0) - return DoubleSwishFunction.apply(x) - - -# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates. -class Dropout2(nn.Module): - def __init__(self, p: FloatLike): - super().__init__() - self.p = p - - def forward(self, x: Tensor) -> Tensor: - return torch.nn.functional.dropout(x, - p=float(self.p), - training=self.training) - - -class MulForDropout3(torch.autograd.Function): - # returns (x * y * alpha) where alpha is a float and y doesn't require - # grad and is zero-or-one. - @staticmethod - @custom_fwd - def forward(ctx, x, y, alpha): - assert not y.requires_grad - ans = x * y * alpha - ctx.save_for_backward(ans) - ctx.alpha = alpha - return ans - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad): - ans, = ctx.saved_tensors - x_grad = ctx.alpha * ans_grad * (ans != 0) - return x_grad, None, None - - -# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates, -# and it lets you choose one dimension to share the dropout mask over -class Dropout3(nn.Module): - def __init__(self, p: FloatLike, shared_dim: int): - super().__init__() - self.p = p - self.shared_dim = shared_dim - - def forward(self, x: Tensor) -> Tensor: - p = float(self.p) - if not self.training or p == 0: - return _no_op(x) - scale = 1.0 / (1 - p) - rand_shape = list(x.shape) - rand_shape[self.shared_dim] = 1 - mask = torch.rand(*rand_shape, device=x.device) > p - ans = MulForDropout3.apply(x, mask, scale) - return ans - - -class SwooshLFunction(torch.autograd.Function): - """ - swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - if x.dtype == torch.float16: - x = x.to(torch.float32) - - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - - coeff = -0.08 - - with torch.cuda.amp.autocast(enabled=False): - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 - - if not requires_grad: - return y - - y.backward(gradient = torch.ones_like(y)) - - grad = x.grad - floor = coeff - ceil = 1.0 + coeff + 0.005 - - d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors - # the same constants as used in forward pass. - - coeff = -0.08 - floor = coeff - ceil = 1.0 + coeff + 0.005 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) - - -class SwooshL(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation. - """ - if torch.jit.is_scripting() or torch.jit.is_tracing(): - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 - if not x.requires_grad: - return k2.swoosh_l_forward(x) - else: - return k2.swoosh_l(x) - # return SwooshLFunction.apply(x) - -class SwooshLOnnx(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation. - """ - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035 - - -class SwooshRFunction(torch.autograd.Function): - """ - swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 - - derivatives are between -0.08 and 0.92. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - - if x.dtype == torch.float16: - x = x.to(torch.float32) - - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - - with torch.cuda.amp.autocast(enabled=False): - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - y = torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 - - if not requires_grad: - return y - y.backward(gradient = torch.ones_like(y)) - - grad = x.grad - floor = -0.08 - ceil = 0.925 - - d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.08 - ceil = 0.925 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) - - -class SwooshR(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation. - """ - if torch.jit.is_scripting() or torch.jit.is_tracing(): - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 - if not x.requires_grad: - return k2.swoosh_r_forward(x) - else: - return k2.swoosh_r(x) - # return SwooshRFunction.apply(x) - -class SwooshROnnx(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation. - """ - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp_onnx(zero, x - 1.) - 0.08 * x - 0.313261687 - - -# simple version of SwooshL that does not redefine the backprop, used in -# ActivationDropoutAndLinearFunction. -def SwooshLForward(x: Tensor): - x_offset = x - 4.0 - log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) - return log_sum - 0.08 * x - 0.035 - - -# simple version of SwooshR that does not redefine the backprop, used in -# ActivationDropoutAndLinearFunction. -def SwooshRForward(x: Tensor): - x_offset = x - 1.0 - log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) - return log_sum - 0.08 * x - 0.313261687 - - -class ActivationDropoutAndLinearFunction(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward(ctx, - x: Tensor, - weight: Tensor, - bias: Optional[Tensor], - activation: str, - dropout_p: float, - dropout_shared_dim: Optional[int]): - if dropout_p != 0.0: - dropout_shape = list(x.shape) - if dropout_shared_dim is not None: - dropout_shape[dropout_shared_dim] = 1 - # else it won't be very memory efficient. - dropout_mask = ((1.0 / (1.0 - dropout_p)) * - (torch.rand(*dropout_shape, - device=x.device, dtype=x.dtype) > dropout_p)) - else: - dropout_mask = None - - ctx.save_for_backward(x, weight, bias, dropout_mask) - - ctx.activation = activation - - forward_activation_dict = { - 'SwooshL': k2.swoosh_l_forward, - 'SwooshR': k2.swoosh_r_forward - } - # it will raise a KeyError if this fails. This will be an error. We let it - # propagate to the user. - activation_func = forward_activation_dict[activation] - x = activation_func(x) - if dropout_mask is not None: - x = x * dropout_mask - x = torch.nn.functional.linear(x, weight, bias) - return x - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad: Tensor): - saved = ctx.saved_tensors - (x, weight, bias, dropout_mask) = saved - - forward_and_deriv_activation_dict = { - 'SwooshL': k2.swoosh_l_forward_and_deriv, - 'SwooshR': k2.swoosh_r_forward_and_deriv - } - # the following lines a KeyError if the activation is unrecognized. - # This will be an error. We let it propagate to the user. - func = forward_and_deriv_activation_dict[ctx.activation] - - y, func_deriv = func(x) - if dropout_mask is not None: - y = y * dropout_mask - # now compute derivative of y w.r.t. weight and bias.. - # y: (..., in_channels), ans_grad: (..., out_channels), - (out_channels, in_channels) = weight.shape - - in_channels = y.shape[-1] - g = ans_grad.reshape(-1, out_channels) - weight_deriv = torch.matmul(g.t(), - y.reshape(-1, in_channels)) - y_deriv = torch.matmul(ans_grad, weight) - bias_deriv = None if bias is None else g.sum(dim=0) - x_deriv = y_deriv * func_deriv - if dropout_mask is not None: - # order versus func_deriv does not matter - x_deriv = x_deriv * dropout_mask - - return x_deriv, weight_deriv, bias_deriv, None, None, None - - -class ActivationDropoutAndLinear(torch.nn.Module): - """ - This merges an activation function followed by dropout and then a nn.Linear module; - it does so in a memory efficient way so that it only stores the input to the whole - module. If activation == SwooshL and dropout_shared_dim != None, this will be - equivalent to: - nn.Sequential(SwooshL(), - Dropout3(dropout_p, shared_dim=dropout_shared_dim), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=initial_scale)) - If dropout_shared_dim is None, the dropout would be equivalent to - Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout - mask is smaller. - - Args: - in_channels: number of input channels, e.g. 256 - out_channels: number of output channels, e.g. 256 - bias: if true, have a bias - activation: the activation function, for now just support SwooshL. - dropout_p: the dropout probability or schedule (happens after nonlinearity). - dropout_shared_dim: the dimension, if any, across which the dropout mask is - shared (e.g. the time dimension). If None, this may be less memory - efficient if there are modules before this one that cache the input - for their backprop (e.g. Balancer or Whiten). - """ - def __init__(self, - in_channels: int, - out_channels: int, - bias: bool = True, - activation: str = 'SwooshL', - dropout_p: FloatLike = 0.0, - dropout_shared_dim: Optional[int] = -1, - initial_scale: float = 1.0): - super().__init__() - # create a temporary module of nn.Linear that we'll steal the - # weights and bias from - l = ScaledLinear(in_channels, out_channels, - bias=bias, - initial_scale=initial_scale) - - self.weight = l.weight - # register_parameter properly handles making it a parameter when l.bias - # is None. I think there is some reason for doing it this way rather - # than just setting it to None but I don't know what it is, maybe - # something to do with exporting the module.. - self.register_parameter('bias', l.bias) - - self.activation = activation - self.dropout_p = dropout_p - self.dropout_shared_dim = dropout_shared_dim - - def forward(self, - x: Tensor): - if torch.jit.is_scripting() or torch.jit.is_tracing(): - if self.activation == 'SwooshL': - x = SwooshLForward(x) - elif self.activation == "SwooshR": - x = SwooshRForward(x) - else: - assert False, self.activation - return torch.nn.functional.linear(x, - self.weight, - self.bias) - - return ActivationDropoutAndLinearFunction.apply( - x, self.weight, self.bias, self.activation, - float(self.dropout_p), self.dropout_shared_dim) - - -def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: - if num_channels <= x.shape[-1]: - return x[..., :num_channels] - else: - shape = list(x.shape) - shape[-1] = num_channels - shape[-1] - zeros = torch.zeros(shape, dtype=x.dtype, device=x.device) - return torch.cat((x, zeros), dim=-1) - - -def _test_whiten(): - for proportion in [0.1, 0.5, 10.0]: - logging.info(f"_test_whiten(): proportion = {proportion}") - x = torch.randn(100, 128) - direction = torch.randn(128) - coeffs = torch.randn(100, 1) - x += proportion * direction * coeffs - - x.requires_grad = True - - m = Whiten(1, # num_groups - 5.0, # whitening_limit, - prob=1.0, - grad_scale=0.1) # grad_scale - - for _ in range(4): - y = m(x) - - y_grad = torch.randn_like(x) - y.backward(gradient=y_grad) - - if proportion < 0.2: - assert torch.allclose(x.grad, y_grad) - elif proportion > 1.0: - assert not torch.allclose(x.grad, y_grad) - - -def _test_balancer_sign(): - probs = torch.arange(0, 1, 0.01) - N = 1000 - x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) - x = x.detach() - x.requires_grad = True - m = Balancer( - probs.numel(), - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - min_abs=0.0, - prob=1.0, - ) - - y_grad = torch.sign(torch.randn(probs.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_balancer_sign: x = ", x) - print("_test_balancer_sign: y grad = ", y_grad) - print("_test_balancer_sign: x grad = ", x.grad) - - -def _test_balancer_magnitude(): - magnitudes = torch.arange(0, 1, 0.01) - N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) - x = x.detach() - x.requires_grad = True - m = Balancer( - magnitudes.numel(), - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - min_abs=0.2, - max_abs=0.7, - prob=1.0, - ) - - y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_balancer_magnitude: x = ", x) - print("_test_balancer_magnitude: y grad = ", y_grad) - print("_test_balancer_magnitude: x grad = ", x.grad) - - -def _test_double_swish_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = DoubleSwish() - - tol = ((1.2-(-0.043637))/255.0) - torch.autograd.gradcheck(m, x, atol=tol) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_swooshl_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = SwooshL() - - tol = (1.0 / 255.0) - torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_swooshr_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = SwooshR() - - tol = (1.0 / 255.0) - torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_softmax(): - a = torch.randn(2, 10, dtype=torch.float64) - b = a.clone() - a.requires_grad = True - b.requires_grad = True - a.softmax(dim=1)[:,0].sum().backward() - print("a grad = ", a.grad) - softmax(b, dim=1)[:,0].sum().backward() - print("b grad = ", b.grad) - assert torch.allclose(a.grad, b.grad) - - -def _test_piecewise_linear(): - p = PiecewiseLinear( (0, 10.0) ) - for x in [-100, 0, 100]: - assert p(x) == 10.0 - p = PiecewiseLinear( (0, 10.0), (1, 0.0) ) - for x, y in [ (-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0) ]: - print("x, y = ", x, y) - assert p(x) == y, (x, p(x), y) - - q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) - x_vals = [ -1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0 ] - pq = p.max(q) - for x in x_vals: - y1 = max(p(x), q(x)) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - pq = p.min(q) - for x in x_vals: - y1 = min(p(x), q(x)) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - pq = p + q - for x in x_vals: - y1 = p(x) + q(x) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - - -def _test_activation_dropout_and_linear(): - in_channels = 20 - out_channels = 30 - - for bias in [True, False]: - # actually we don't test for dropout_p != 0.0 because forward functions will give - # different answers. This is because we are using the k2 implementation of - # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() - # internally, messing up the random state. - for dropout_p in [0.0]: - for activation in ['SwooshL', 'SwooshR']: - m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(), - Dropout3(p=dropout_p, shared_dim=-1), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=0.5)) - m2 = ActivationDropoutAndLinear(in_channels, out_channels, - bias=bias, initial_scale=0.5, - activation=activation, - dropout_p=dropout_p) - with torch.no_grad(): - m2.weight[:] = m1[2].weight - if bias: - m2.bias[:] = m1[2].bias - # make sure forward gives same result. - x1 = torch.randn(10, in_channels) - x1.requires_grad = True - - # TEMP. - assert torch.allclose(SwooshRFunction.apply(x1), - SwooshRForward(x1), - atol=1.0e-03) - - x2 = x1.clone().detach() - x2.requires_grad = True - seed = 10 - torch.manual_seed(seed) - y1 = m1(x1) - y_grad = torch.randn_like(y1) - y1.backward(gradient=y_grad) - torch.manual_seed(seed) - y2 = m2(x2) - y2.backward(gradient=y_grad) - - print(f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}") - print("y1 = ", y1) - print("y2 = ", y2) - assert torch.allclose(y1, y2, atol=0.02) - assert torch.allclose(m1[2].weight.grad, m2.weight.grad, - atol=1.0e-05) - if bias: - assert torch.allclose(m1[2].bias.grad, m2.bias.grad, - atol=1.0e-05) - print("x1.grad = ", x1.grad) - print("x2.grad = ", x2.grad) - - def isclose(a, b): - # return true if cosine similarity is > 0.9. - return (a * b).sum() > 0.9 * ((a**2).sum() * (b**2).sum()).sqrt() - # the SwooshL() implementation has a noisy gradient due to 1-byte - # storage of it. - assert isclose(x1.grad, x2.grad) - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_piecewise_linear() - _test_softmax() - _test_whiten() - _test_balancer_sign() - _test_balancer_magnitude() - _test_double_swish_deriv() - _test_swooshr_deriv() - _test_swooshl_deriv() - _test_activation_dropout_and_linear() diff --git a/egs/swbd/ASR/zipformer/scaling_converter.py b/egs/swbd/ASR/zipformer/scaling_converter.py deleted file mode 100644 index 76622fa129..0000000000 --- a/egs/swbd/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file replaces various modules in a model. -Specifically, ActivationBalancer is replaced with an identity operator; -Whiten is also replaced with an identity operator; -BasicNorm is replaced by a module with `exp` removed. -""" - -import copy -from typing import List, Tuple - -import torch -import torch.nn as nn -from scaling import ( - Balancer, - Dropout3, - ScaleGrad, - SwooshL, - SwooshLOnnx, - SwooshR, - SwooshROnnx, - Whiten, -) -from zipformer import CompactRelPositionalEncoding - - -# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa -# get_submodule was added to nn.Module at v1.9.0 -def get_submodule(model, target): - if target == "": - return model - atoms: List[str] = target.split(".") - mod: torch.nn.Module = model - for item in atoms: - if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) - mod = getattr(mod, item) - if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") - return mod - - -def convert_scaled_to_non_scaled( - model: nn.Module, - inplace: bool = False, - is_pnnx: bool = False, - is_onnx: bool = False, -): - """ - Args: - model: - The model to be converted. - inplace: - If True, the input model is modified inplace. - If False, the input model is copied and we modify the copied version. - is_pnnx: - True if we are going to export the model for PNNX. - is_onnx: - True if we are going to export the model for ONNX. - Return: - Return a model without scaled layers. - """ - if not inplace: - model = copy.deepcopy(model) - - d = {} - for name, m in model.named_modules(): - if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)): - d[name] = nn.Identity() - elif is_onnx and isinstance(m, SwooshR): - d[name] = SwooshROnnx() - elif is_onnx and isinstance(m, SwooshL): - d[name] = SwooshLOnnx() - elif is_onnx and isinstance(m, CompactRelPositionalEncoding): - # We want to recreate the positional encoding vector when - # the input changes, so we have to use torch.jit.script() - # to replace torch.jit.trace() - d[name] = torch.jit.script(m) - - for k, v in d.items(): - if "." in k: - parent, child = k.rsplit(".", maxsplit=1) - setattr(get_submodule(model, parent), child, v) - else: - setattr(model, k, v) - - return model diff --git a/egs/swbd/ASR/zipformer/sclite_scoring.py b/egs/swbd/ASR/zipformer/sclite_scoring.py deleted file mode 100755 index 0383c4d71b..0000000000 --- a/egs/swbd/ASR/zipformer/sclite_scoring.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Jiayu Du -# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import os - -conversational_filler = [ - "UH", - "UHH", - "UM", - "EH", - "MM", - "HM", - "AH", - "HUH", - "HA", - "ER", - "OOF", - "HEE", - "ACH", - "EEE", - "EW", - "MHM", - "HUM", - "AW", - "OH", - "HMM", - "UMM", -] -unk_tags = ["", ""] -switchboard_garbage_utterance_tags = [ - "[LAUGHTER]", - "[NOISE]", - "[VOCALIZED-NOISE]", - "[SILENCE]", -] -non_scoring_words = ( - conversational_filler + unk_tags + switchboard_garbage_utterance_tags -) - - -def asr_text_post_processing(text: str) -> str: - # 1. convert to uppercase - text = text.upper() - - # 2. remove non-scoring words from evaluation - remaining_words = [] - text_split = text.split() - word_to_skip = 0 - for idx, word in enumerate(text_split): - if word_to_skip > 0: - word_to_skip -= 1 - continue - if word in non_scoring_words: - continue - elif word == "CANCELLED": - remaining_words.append("CANCELED") - continue - elif word == "AIRFLOW": - remaining_words.append("AIR") - remaining_words.append("FLOW") - continue - elif word == "PHD": - remaining_words.append("P") - remaining_words.append("H") - remaining_words.append("D") - continue - elif word == "UCLA": - remaining_words.append("U") - remaining_words.append("C") - remaining_words.append("L") - remaining_words.append("A") - continue - elif word == "ONTO": - remaining_words.append("ON") - remaining_words.append("TO") - continue - elif word == "DAY": - try: - if text_split[idx + 1] == "CARE": - remaining_words.append("DAYCARE") - word_to_skip = 1 - except: - remaining_words.append(word) - continue - remaining_words.append(word) - - return " ".join(remaining_words) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="This script evaluates GigaSpeech ASR result via" - "SCTK's tool sclite" - ) - parser.add_argument( - "ref", - type=str, - help="sclite's standard transcription(trn) reference file", - ) - parser.add_argument( - "hyp", - type=str, - help="sclite's standard transcription(trn) hypothesis file", - ) - parser.add_argument( - "work_dir", - type=str, - help="working dir", - ) - args = parser.parse_args() - - if not os.path.isdir(args.work_dir): - os.mkdir(args.work_dir) - - REF = os.path.join(args.work_dir, "REF") - HYP = os.path.join(args.work_dir, "HYP") - RESULT = os.path.join(args.work_dir, "RESULT") - - for io in [(args.ref, REF), (args.hyp, HYP)]: - with open(io[0], "r", encoding="utf8") as fi: - with open(io[1], "w+", encoding="utf8") as fo: - for line in fi: - line = line.strip() - if line: - cols = line.split() - text = asr_text_post_processing(" ".join(cols[0:-1])) - uttid_field = cols[-1] - print(f"{text} {uttid_field}", file=fo) - - # GigaSpeech's uttid comforms to swb - os.system(f"sclite -r {REF} trn -h {HYP} trn -i swb | tee {RESULT}") diff --git a/egs/swbd/ASR/zipformer/streaming_beam_search.py b/egs/swbd/ASR/zipformer/streaming_beam_search.py deleted file mode 100644 index 3c8565b330..0000000000 --- a/egs/swbd/ASR/zipformer/streaming_beam_search.py +++ /dev/null @@ -1,295 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from typing import List - -import k2 -import torch -import torch.nn as nn -from beam_search import Hypothesis, HypothesisList, get_hyps_shape -from decode_stream import DecodeStream - -from icefall.decode import one_best_decoding -from icefall.utils import get_texts - - -def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], - blank_penalty: float = 0.0, -) -> None: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - streams: - A list of Stream objects. - """ - assert len(streams) == encoder_out.size(0) - assert encoder_out.ndim == 3 - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = model.device - T = encoder_out.size(1) - - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - # decoder_out is of shape (N, 1, decoder_out_dim) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - for t in range(T): - # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) - current_encoder_out = encoder_out[:, t : t + 1, :] # noqa - - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - # logits'shape (batch_size, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - if blank_penalty != 0.0: - logits[:, 0] -= blank_penalty - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - streams[i].hyp.append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ) - decoder_out = model.joiner.decoder_proj(decoder_out) - - -def modified_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], - num_active_paths: int = 4, - blank_penalty: float = 0.0, -) -> None: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The RNN-T model. - encoder_out: - A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of - the encoder model. - streams: - A list of stream objects. - num_active_paths: - Number of active paths during the beam search. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert len(streams) == encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = next(model.parameters()).device - batch_size = len(streams) - T = encoder_out.size(1) - - B = [stream.hyps for stream in streams] - - for t in range(T): - current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.stack( - [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, encoder_out_dim) - - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) - # logits is of shape (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) - - if blank_penalty != 0.0: - logits[:, 0] -= blank_penalty - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token != blank_id: - new_ys.append(new_token) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B[i].add(new_hyp) - - for i in range(batch_size): - streams[i].hyps = B[i] - - -def fast_beam_search_one_best( - model: nn.Module, - encoder_out: torch.Tensor, - processed_lens: torch.Tensor, - streams: List[DecodeStream], - beam: float, - max_states: int, - max_contexts: int, - blank_penalty: float = 0.0, -) -> None: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first generated by Fsa-based beam search, then we get the - recognition by applying shortest path on the lattice. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - processed_lens: - A tensor of shape (N,) containing the number of processed frames - in `encoder_out` before padding. - streams: - A list of stream objects. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - """ - assert encoder_out.ndim == 3 - B, T, C = encoder_out.shape - assert B == len(streams) - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(streams[i].rnnt_decoding_stream) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - - if blank_penalty != 0.0: - logits[:, 0] -= blank_penalty - - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - - decoding_streams.terminate_and_flush_to_streams() - - lattice = decoding_streams.format_output(processed_lens.tolist()) - best_path = one_best_decoding(lattice) - hyp_tokens = get_texts(best_path) - - for i in range(B): - streams[i].hyp = hyp_tokens[i] diff --git a/egs/swbd/ASR/zipformer/streaming_decode.py b/egs/swbd/ASR/zipformer/streaming_decode.py deleted file mode 100755 index 44ff392a3a..0000000000 --- a/egs/swbd/ASR/zipformer/streaming_decode.py +++ /dev/null @@ -1,876 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, -# Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -./zipformer/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --causal 1 \ - --chunk-size 32 \ - --left-context-frames 256 \ - --exp-dir ./zipformer/exp \ - --decoding-method greedy_search \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import sentencepiece as spm -import torch -from asr_datamodule import LibriSpeechAsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch import Tensor, nn -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def get_init_states( - model: nn.Module, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), -) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = model.encoder.get_init_states(batch_size, device) - - embed_states = model.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. For element-n, - state_list[n] is a list of cached tensors of all encoder layers. For layer-i, - state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, - cached_val2, cached_conv1, cached_conv2). - state_list[n][-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - state_list[n][-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Note: - It is the inverse of :func:`unstack_states`. - """ - batch_size = len(state_list) - assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) - tot_num_layers = (len(state_list[0]) - 2) // 6 - - batch_states = [] - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key = torch.cat( - [state_list[i][layer_offset] for i in range(batch_size)], dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn = torch.cat( - [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1 = torch.cat( - [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2 = torch.cat( - [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1 = torch.cat( - [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2 = torch.cat( - [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 - ) - batch_states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - cached_embed_left_pad = torch.cat( - [state_list[i][-2] for i in range(batch_size)], dim=0 - ) - batch_states.append(cached_embed_left_pad) - - processed_lens = torch.cat( - [state_list[i][-1] for i in range(batch_size)], dim=0 - ) - batch_states.append(processed_lens) - - return batch_states - - -def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - batch_states: A list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - state_list[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Returns: - state_list: A list of list. Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. - """ - assert (len(batch_states) - 2) % 6 == 0, len(batch_states) - tot_num_layers = (len(batch_states) - 2) // 6 - - processed_lens = batch_states[-1] - batch_size = processed_lens.shape[0] - - state_list = [[] for _ in range(batch_size)] - - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk( - chunks=batch_size, dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( - chunks=batch_size, dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1_list = batch_states[layer_offset + 2].chunk( - chunks=batch_size, dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2_list = batch_states[layer_offset + 3].chunk( - chunks=batch_size, dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1_list = batch_states[layer_offset + 4].chunk( - chunks=batch_size, dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2_list = batch_states[layer_offset + 5].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i] += [ - cached_key_list[i], - cached_nonlin_attn_list[i], - cached_val1_list[i], - cached_val2_list[i], - cached_conv1_list[i], - cached_conv2_list[i], - ] - - cached_embed_left_pad_list = batch_states[-2].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i].append(cached_embed_left_pad_list[i]) - - processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(processed_lens_list[i]) - - return state_list - - -def streaming_forward( - features: Tensor, - feature_lens: Tensor, - model: nn.Module, - states: List[Tensor], - chunk_size: int, - left_context_len: int, -) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Returns encoder outputs, output lengths, and updated states. - """ - cached_embed_left_pad = states[-2] - ( - x, - x_lens, - new_cached_embed_left_pad, - ) = model.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat( - [processed_mask, src_key_padding_mask], dim=1 - ) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = model.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - chunk_size = int(params.chunk_size) - left_context_len = int(params.left_context_frames) - - features = [] - feature_lens = [] - states = [] - processed_lens = [] # Used in fast-beam-search - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(chunk_size * 2) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # Make sure the length after encoder_embed is at least 1. - # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - tail_length = chunk_size * 2 + 7 + 2 * 3 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - - encoder_out, encoder_out_lens, new_states = streaming_forward( - features=features, - feature_lens=feature_lens, - model=model, - states=states, - chunk_size=chunk_size, - left_context_len=left_context_len, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) - elif params.decoding_method == "fast_beam_search": - processed_lens = torch.tensor(processed_lens, device=device) - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - log_interval = 100 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. - initial_states = get_init_states( - model=model, batch_size=1, device=device - ) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=30) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - assert params.causal, params.causal - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_sets = ["test-clean", "test-other"] - test_cuts = [test_clean_cuts, test_other_cuts] - - for test_set, test_cut in zip(test_sets, test_cuts): - results_dict = decode_dataset( - cuts=test_cut, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/swbd/ASR/zipformer/subsampling.py b/egs/swbd/ASR/zipformer/subsampling.py deleted file mode 100644 index 6532ddccb6..0000000000 --- a/egs/swbd/ASR/zipformer/subsampling.py +++ /dev/null @@ -1,410 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Tuple -import warnings - -import torch -from torch import Tensor, nn -from scaling import ( - Balancer, - BiasNorm, - Dropout3, - FloatLike, - Optional, - ScaledConv2d, - ScaleGrad, - ScheduledFloat, - SwooshL, - SwooshR, - Whiten, -) - - -class ConvNeXt(nn.Module): - """ - Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf - """ - - def __init__( - self, - channels: int, - hidden_ratio: int = 3, - kernel_size: Tuple[int, int] = (7, 7), - layerdrop_rate: FloatLike = None, - ): - super().__init__() - self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) - hidden_channels = channels * hidden_ratio - if layerdrop_rate is None: - layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015)) - self.layerdrop_rate = layerdrop_rate - - self.depthwise_conv = nn.Conv2d( - in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=self.padding, - ) - - self.pointwise_conv1 = nn.Conv2d( - in_channels=channels, out_channels=hidden_channels, kernel_size=1 - ) - - self.hidden_balancer = Balancer( - hidden_channels, - channel_dim=1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0, - ) - - self.activation = SwooshL() - self.pointwise_conv2 = ScaledConv2d( - in_channels=hidden_channels, - out_channels=channels, - kernel_size=1, - initial_scale=0.01, - ) - - self.out_balancer = Balancer( - channels, - channel_dim=1, - min_positive=0.4, - max_positive=0.6, - min_abs=1.0, - max_abs=6.0, - ) - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return self.forward_internal(x) - layerdrop_rate = float(self.layerdrop_rate) - - if layerdrop_rate != 0.0: - batch_size = x.shape[0] - mask = ( - torch.rand( - (batch_size, 1, 1, 1), dtype=x.dtype, device=x.device - ) - > layerdrop_rate - ) - else: - mask = None - # turns out this caching idea does not work with --world-size > 1 - # return caching_eval(self.forward_internal, x, mask) - return self.forward_internal(x, mask) - - def forward_internal( - self, x: Tensor, layer_skip_mask: Optional[Tensor] = None - ) -> Tensor: - """ - x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) - - The returned value has the same shape as x. - """ - bypass = x - x = self.depthwise_conv(x) - x = self.pointwise_conv1(x) - x = self.hidden_balancer(x) - x = self.activation(x) - x = self.pointwise_conv2(x) - - if layer_skip_mask is not None: - x = x * layer_skip_mask - - x = bypass + x - x = self.out_balancer(x) - - if x.requires_grad: - x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last - x = self.out_whiten(x) - x = x.transpose(1, 3) # (N, C, H, W) - - return x - - def streaming_forward( - self, - x: Tensor, - cached_left_pad: Tensor, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) - cached_left_pad: (batch_size, num_channels, left_pad, num_freqs) - - Returns: - - The returned value has the same shape as x. - - Updated cached_left_pad. - """ - padding = self.padding - - # The length without right padding for depth-wise conv - T = x.size(2) - padding[0] - - bypass = x[:, :, :T, :] - - # Pad left side - assert cached_left_pad.size(2) == padding[0], ( - cached_left_pad.size(2), - padding[0], - ) - x = torch.cat([cached_left_pad, x], dim=2) - # Update cached left padding - cached_left_pad = x[:, :, T : padding[0] + T, :] - - # depthwise_conv - x = torch.nn.functional.conv2d( - x, - weight=self.depthwise_conv.weight, - bias=self.depthwise_conv.bias, - padding=(0, padding[1]), - groups=self.depthwise_conv.groups, - ) - x = self.pointwise_conv1(x) - x = self.hidden_balancer(x) - x = self.activation(x) - x = self.pointwise_conv2(x) - - x = bypass + x - return x, cached_left_pad - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/2 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = (T-3)//2 - 2 == (T-7)//2 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128, - dropout: FloatLike = 0.1, - ) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >=7, in_channels >=7 - out_channels - Output dim. The output shape is (N, (T-3)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer1_channels: - Number of channels in layer2 - bottleneck: - bottleneck dimension for 1d squeeze-excite - """ - assert in_channels >= 7 - super().__init__() - - # The ScaleGrad module is there to prevent the gradients - # w.r.t. the weight or bias of the first Conv2d module in self.conv from - # exceeding the range of fp16 when using automatic mixed precision (amp) - # training. (The second one is necessary to stop its bias from getting - # a too-large gradient). - - self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, - out_channels=layer1_channels, - kernel_size=3, - padding=(0, 1), # (time, freq) - ), - ScaleGrad(0.2), - Balancer(layer1_channels, channel_dim=1, max_abs=1.0), - SwooshR(), - nn.Conv2d( - in_channels=layer1_channels, - out_channels=layer2_channels, - kernel_size=3, - stride=2, - padding=0, - ), - Balancer(layer2_channels, channel_dim=1, max_abs=4.0), - SwooshR(), - nn.Conv2d( - in_channels=layer2_channels, - out_channels=layer3_channels, - kernel_size=3, - stride=(1, 2), # (time, freq) - ), - Balancer(layer3_channels, channel_dim=1, max_abs=4.0), - SwooshR(), - ) - - # just one convnext layer - self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7)) - - # (in_channels-3)//4 - self.out_width = (((in_channels - 1) // 2) - 1) // 2 - self.layer3_channels = layer3_channels - - self.out = nn.Linear(self.out_width * layer3_channels, out_channels) - # use a larger than normal grad_scale on this whitening module; there is - # only one such module, so there is not a concern about adding together - # many copies of this extra gradient term. - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=ScheduledFloat( - (0.0, 4.0), (20000.0, 8.0), default=4.0 - ), - prob=(0.025, 0.25), - grad_scale=0.02, - ) - - # max_log_eps=0.0 is to prevent both eps and the output of self.out from - # getting large, there is an unnecessary degree of freedom. - self.out_norm = BiasNorm(out_channels) - self.dropout = Dropout3(dropout, shared_dim=1) - - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - - Returns: - - a tensor of shape (N, (T-7)//2, odim) - - output lengths, of shape (batch_size,) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision) - # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite - # gradients. - x = self.conv(x) - x = self.convnext(x) - - # Now x is of shape (N, odim, (T-7)//2, (idim-3)//4) - b, c, t, f = x.size() - - x = x.transpose(1, 2).reshape(b, t, c * f) - # now x: (N, (T-7)//2, out_width * layer3_channels)) - - x = self.out(x) - # Now x is of shape (N, (T-7)//2, odim) - x = self.out_whiten(x) - x = self.out_norm(x) - x = self.dropout(x) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - x_lens = (x_lens - 7) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - x_lens = (x_lens - 7) // 2 - assert x.size(1) == x_lens.max().item() , (x.size(1), x_lens.max()) - - return x, x_lens - - def streaming_forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - cached_left_pad: Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - - Returns: - - a tensor of shape (N, (T-7)//2, odim) - - output lengths, of shape (batch_size,) - - updated cache - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - - # T' = (T-7)//2 - x = self.conv(x) - - # T' = (T-7)//2-3 - x, cached_left_pad = self.convnext.streaming_forward( - x, cached_left_pad=cached_left_pad - ) - - # Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - - x = x.transpose(1, 2).reshape(b, t, c * f) - # now x: (N, T', out_width * layer3_channels)) - - x = self.out(x) - # Now x is of shape (N, T', odim) - x = self.out_norm(x) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - assert self.convnext.padding[0] == 3 - # The ConvNeXt module needs 3 frames of right padding after subsampling - x_lens = (x_lens - 7) // 2 - 3 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # The ConvNeXt module needs 3 frames of right padding after subsampling - assert self.convnext.padding[0] == 3 - x_lens = (x_lens - 7) // 2 - 3 - - assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max()) - - return x, x_lens, cached_left_pad - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> Tensor: - """Get initial states for Conv2dSubsampling module. - It is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - """ - left_pad = self.convnext.padding[0] - freq = self.out_width - channels = self.layer3_channels - cached_embed_left_pad = torch.zeros( - batch_size, channels, left_pad, freq - ).to(device) - - return cached_embed_left_pad diff --git a/egs/swbd/ASR/zipformer/test_scaling.py b/egs/swbd/ASR/zipformer/test_scaling.py deleted file mode 100755 index 5c04291e73..0000000000 --- a/egs/swbd/ASR/zipformer/test_scaling.py +++ /dev/null @@ -1,82 +0,0 @@ -#!/usr/bin/env python3 - -import matplotlib.pyplot as plt -import torch -from scaling import PiecewiseLinear, ScheduledFloat, SwooshL, SwooshR - - -def test_piecewise_linear(): - # An identity map in the range [0, 1]. - # 1 - identity map in the range [1, 2] - # x1=0, y1=0 - # x2=1, y2=1 - # x3=2, y3=0 - pl = PiecewiseLinear((0, 0), (1, 1), (2, 0)) - assert pl(0.25) == 0.25, pl(0.25) - assert pl(0.625) == 0.625, pl(0.625) - assert pl(1.25) == 0.75, pl(1.25) - - assert pl(-10) == pl(0), pl(-10) # out of range - assert pl(10) == pl(2), pl(10) # out of range - - # multiplication - pl10 = pl * 10 - assert pl10(1) == 10 * pl(1) - assert pl10(0.5) == 10 * pl(0.5) - - -def test_scheduled_float(): - # Initial value is 0.2 and it decreases linearly towards 0 at 4000 - dropout = ScheduledFloat((0, 0.2), (4000, 0.0), default=0.0) - dropout.batch_count = 0 - assert float(dropout) == 0.2, (float(dropout), dropout.batch_count) - - dropout.batch_count = 1000 - assert abs(float(dropout) - 0.15) < 1e-5, (float(dropout), dropout.batch_count) - - dropout.batch_count = 2000 - assert float(dropout) == 0.1, (float(dropout), dropout.batch_count) - - dropout.batch_count = 3000 - assert abs(float(dropout) - 0.05) < 1e-5, (float(dropout), dropout.batch_count) - - dropout.batch_count = 4000 - assert float(dropout) == 0.0, (float(dropout), dropout.batch_count) - - dropout.batch_count = 5000 # out of range - assert float(dropout) == 0.0, (float(dropout), dropout.batch_count) - - -def test_swoosh(): - x1 = torch.linspace(start=-10, end=0, steps=100, dtype=torch.float32) - x2 = torch.linspace(start=0, end=10, steps=100, dtype=torch.float32) - x = torch.cat([x1, x2[1:]]) - - left = SwooshL()(x) - r = SwooshR()(x) - - relu = torch.nn.functional.relu(x) - print(left[x == 0], r[x == 0]) - plt.plot(x, left, "k") - plt.plot(x, r, "r") - plt.plot(x, relu, "b") - plt.axis([-10, 10, -1, 10]) # [xmin, xmax, ymin, ymax] - plt.legend( - [ - "SwooshL(x) = log(1 + exp(x-4)) - 0.08x - 0.035 ", - "SwooshR(x) = log(1 + exp(x-1)) - 0.08x - 0.313261687", - "ReLU(x) = max(0, x)", - ] - ) - plt.grid() - plt.savefig("swoosh.pdf") - - -def main(): - test_piecewise_linear() - test_scheduled_float() - test_swoosh() - - -if __name__ == "__main__": - main() diff --git a/egs/swbd/ASR/zipformer/test_subsampling.py b/egs/swbd/ASR/zipformer/test_subsampling.py deleted file mode 100755 index 078227fb68..0000000000 --- a/egs/swbd/ASR/zipformer/test_subsampling.py +++ /dev/null @@ -1,152 +0,0 @@ -#!/usr/bin/env python3 - -import torch -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling - - -def test_conv2d_subsampling(): - layer1_channels = 8 - layer2_channels = 32 - layer3_channels = 128 - - out_channels = 192 - encoder_embed = Conv2dSubsampling( - in_channels=80, - out_channels=out_channels, - layer1_channels=layer1_channels, - layer2_channels=layer2_channels, - layer3_channels=layer3_channels, - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - N = 2 - T = 200 - num_features = 80 - x = torch.rand(N, T, num_features) - x_copy = x.clone() - - x = x.unsqueeze(1) # (N, 1, T, num_features) - - x = encoder_embed.conv[0](x) # conv2d, in 1, out 8, kernel 3, padding (0,1) - assert x.shape == (N, layer1_channels, T - 2, num_features) - # (2, 8, 198, 80) - - x = encoder_embed.conv[1](x) # scale grad - x = encoder_embed.conv[2](x) # balancer - x = encoder_embed.conv[3](x) # swooshR - - x = encoder_embed.conv[4](x) # conv2d, in 8, out 32, kernel 3, stride 2 - assert x.shape == ( - N, - layer2_channels, - ((T - 2) - 3) // 2 + 1, - (num_features - 3) // 2 + 1, - ) - # (2, 32, 98, 39) - - x = encoder_embed.conv[5](x) # balancer - x = encoder_embed.conv[6](x) # swooshR - - # conv2d: - # in 32, out 128, kernel 3, stride (1, 2) - x = encoder_embed.conv[7](x) - assert x.shape == ( - N, - layer3_channels, - (((T - 2) - 3) // 2 + 1) - 2, - (((num_features - 3) // 2 + 1) - 3) // 2 + 1, - ) - # (2, 128, 96, 19) - - x = encoder_embed.conv[8](x) # balancer - x = encoder_embed.conv[9](x) # swooshR - - # (((T - 2) - 3) // 2 + 1) - 2 - # = (T - 2) - 3) // 2 + 1 - 2 - # = ((T - 2) - 3) // 2 - 1 - # = (T - 2 - 3) // 2 - 1 - # = (T - 5) // 2 - 1 - # = (T - 7) // 2 - assert x.shape[2] == (x_copy.shape[1] - 7) // 2 - - # (((num_features - 3) // 2 + 1) - 3) // 2 + 1, - # = ((num_features - 3) // 2 + 1 - 3) // 2 + 1, - # = ((num_features - 3) // 2 - 2) // 2 + 1, - # = (num_features - 3 - 4) // 2 // 2 + 1, - # = (num_features - 7) // 2 // 2 + 1, - # = (num_features - 7) // 4 + 1, - # = (num_features - 3) // 4 - assert x.shape[3] == (x_copy.shape[2] - 3) // 4 - - assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) - - # Input shape to convnext is - # - # (N, layer3_channels, (T-7)//2, (num_features - 3)//4) - - # conv2d: in layer3_channels, out layer3_channels, groups layer3_channels - # kernel_size 7, padding 3 - x = encoder_embed.convnext.depthwise_conv(x) - assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) - - # conv2d: in layer3_channels, out hidden_ratio * layer3_channels, kernel_size 1 - x = encoder_embed.convnext.pointwise_conv1(x) - assert x.shape == (N, layer3_channels * 3, (T - 7) // 2, (num_features - 3) // 4) - - x = encoder_embed.convnext.hidden_balancer(x) # balancer - x = encoder_embed.convnext.activation(x) # swooshL - - # conv2d: in hidden_ratio * layer3_channels, out layer3_channels, kernel 1 - x = encoder_embed.convnext.pointwise_conv2(x) - assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) - - # bypass and layer drop, omitted here. - x = encoder_embed.convnext.out_balancer(x) - - # Note: the input and output shape of ConvNeXt are the same - - x = x.transpose(1, 2).reshape(N, (T - 7) // 2, -1) - assert x.shape == (N, (T - 7) // 2, layer3_channels * ((num_features - 3) // 4)) - - x = encoder_embed.out(x) - assert x.shape == (N, (T - 7) // 2, out_channels) - - x = encoder_embed.out_whiten(x) - x = encoder_embed.out_norm(x) - # final layer is dropout - - # test streaming forward - - subsampling_factor = 2 - cached_left_padding = encoder_embed.get_init_states(batch_size=N) - depthwise_conv_kernel_size = 7 - pad_size = (depthwise_conv_kernel_size - 1) // 2 - - assert cached_left_padding.shape == ( - N, - layer3_channels, - pad_size, - (num_features - 3) // 4, - ) - - chunk_size = 16 - right_padding = pad_size * subsampling_factor - T = chunk_size * subsampling_factor + 7 + right_padding - x = torch.rand(N, T, num_features) - x_lens = torch.tensor([T] * N) - y, y_lens, next_cached_left_padding = encoder_embed.streaming_forward( - x, x_lens, cached_left_padding - ) - - assert y.shape == (N, chunk_size, out_channels), y.shape - assert next_cached_left_padding.shape == cached_left_padding.shape - - assert y.shape[1] == y_lens[0] == y_lens[1] - - -def main(): - test_conv2d_subsampling() - - -if __name__ == "__main__": - main() diff --git a/egs/swbd/ASR/zipformer/train.py b/egs/swbd/ASR/zipformer/train.py deleted file mode 100755 index a4cac01142..0000000000 --- a/egs/swbd/ASR/zipformer/train.py +++ /dev/null @@ -1,1383 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --max-duration 1000 - -It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import SwitchBoardAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import AsrModel -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). - return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - parser.add_argument( - "--use-transducer", - type=str2bool, - default=True, - help="If True, use Transducer head.", - ) - - parser.add_argument( - "--use-ctc", - type=str2bool, - default=False, - help="If True, use CTC head.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - assert params.use_transducer or params.use_ctc, ( - f"At least one of them should be True, " - f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}" - ) - - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - - if params.use_transducer: - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - else: - decoder = None - joiner = None - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - vocab_size=params.vocab_size, - use_transducer=params.use_transducer, - use_ctc=params.use_ctc, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - loss = 0.0 - - if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if not params.use_transducer: - params.ctc_loss_scale = 1.0 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2**22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - switchboard = SwitchBoardAsrDataModule(args) - - train_cuts = switchboard.train_all_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = switchboard.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = switchboard.dev_cuts() - valid_dl = switchboard.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - SwitchBoardAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/swbd/ASR/zipformer/zipformer.py b/egs/swbd/ASR/zipformer/zipformer.py deleted file mode 100644 index b39af02b8a..0000000000 --- a/egs/swbd/ASR/zipformer/zipformer.py +++ /dev/null @@ -1,2254 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import math -import warnings -from typing import List, Optional, Tuple, Union -import logging -import torch -import random -from encoder_interface import EncoderInterface -from scaling import ( - Balancer, - BiasNorm, - Dropout2, - ChunkCausalDepthwiseConv1d, - ActivationDropoutAndLinear, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. - Whiten, - Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. - penalize_abs_values_gt, - softmax, - ScheduledFloat, - FloatLike, - limit_param_value, - convert_num_channels, -) -from torch import Tensor, nn - - -class Zipformer2(EncoderInterface): - """ - Args: - - Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length - as downsampling_factor if they are single ints or one-element tuples. The length of - downsampling_factor defines the number of stacks. - - output_downsampling_factor (int): how much to downsample at the output. Note: - we also downsample by a factor of 2 in the Conv2dSubsampling encoder. - You should probably leave this at 2. - downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. - Note: this is in addition to the downsampling factor of 2 that is applied in - the frontend (self.encoder_embed). - encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per - encoder stack. - num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack - encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of - the encoder stacks for purposes of per-frame dropout (recommend 256 for - now). - query_head_dim (int or Tuple[int]): dimension of query and key per attention - head: per stack, if a tuple.. - pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per - attention head - value_head_dim (int or Tuple[int]): dimension of value in each attention head - num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. - Must be at least 4. - feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules - cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module - - pos_dim (int): the dimension of each positional-encoding vector prior to projection, - e.g. 128. - - dropout (float): dropout rate - warmup_batches (float): number of batches to warm up over; this controls - dropout of encoder layers. - causal (bool): if True, support chunkwise causal convolution. This should - not hurt WER as no modeling power is lost, but the convolution modules will be - slightly slower and use more memory. Enables use of the chunk_size and - left_context_chunks options in forward(), which simulates streaming - decoding. - chunk_size: (list of int): only set this to other than [-1] if causal; - the chunk size will be randomly chosen from this list. -1 means no chunking. - left_context_frames: (list of int): determines the number of left- - context chunks for causal training; will be rounded to a number of - chunks. Must not be less than cnn_module_kernel (after factoring in - rounding and downsampling); an error will be thrown if this is violated. - """ - def __init__( - self, - output_downsampling_factor: int = 2, - downsampling_factor: Tuple[int] = (2, 4), - encoder_dim: Union[int, Tuple[int]] = 384, - num_encoder_layers: Union[int, Tuple[int]] = 4, - encoder_unmasked_dim: Union[int, Tuple[int]] = 256, - query_head_dim: Union[int, Tuple[int]] = 24, - pos_head_dim: Union[int, Tuple[int]] = 4, - value_head_dim: Union[int, Tuple[int]] = 12, - num_heads: Union[int, Tuple[int]] = 8, - feedforward_dim: Union[int, Tuple[int]] = 1536, - cnn_module_kernel: Union[int, Tuple[int]] = 31, - pos_dim: int = 192, - dropout: FloatLike = None, # see code below for default - warmup_batches: float = 4000.0, - causal: bool = False, - chunk_size: Tuple[int] = [-1], - left_context_frames: Tuple[int] = [-1], - ) -> None: - super(Zipformer2, self).__init__() - - if dropout is None: - dropout = ScheduledFloat((0.0, 0.3), - (20000.0, 0.1)) - - def _to_tuple(x): - """ Converts a single int or a 1-tuple of an int to a tuple with the same length - as downsampling_factor""" - if isinstance(x, int): - x = (x,) - if len(x) == 1: - x = x * len(downsampling_factor) - else: - assert len(x) == len(downsampling_factor) and isinstance(x[0], int) - return x - - self.output_downsampling_factor = output_downsampling_factor # int - self.downsampling_factor = downsampling_factor # tuple - self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple - self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple - num_encoder_layers = _to_tuple(num_encoder_layers) - self.num_encoder_layers = num_encoder_layers - self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) - self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) - pos_head_dim = _to_tuple(pos_head_dim) - self.num_heads = num_heads = _to_tuple(num_heads) - feedforward_dim = _to_tuple(feedforward_dim) - self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) - - self.causal = causal - self.chunk_size = chunk_size - self.left_context_frames = left_context_frames - - for u,d in zip(encoder_unmasked_dim, encoder_dim): - assert u <= d - - # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder - encoders = [] - - num_encoders = len(downsampling_factor) - for i in range(num_encoders): - - encoder_layer = Zipformer2EncoderLayer( - embed_dim=encoder_dim[i], - pos_dim=pos_dim, - num_heads=num_heads[i], - query_head_dim=query_head_dim[i], - pos_head_dim=pos_head_dim[i], - value_head_dim=value_head_dim[i], - feedforward_dim=feedforward_dim[i], - dropout=dropout, - cnn_module_kernel=cnn_module_kernel[i], - causal=causal, - ) - - # For the segment of the warmup period, we let the Conv2dSubsampling - # layer learn something. Then we start to warm up the other encoders. - encoder = Zipformer2Encoder( - encoder_layer, - num_encoder_layers[i], - pos_dim=pos_dim, - dropout=dropout, - warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), - final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), - ) - - if downsampling_factor[i] != 1: - encoder = DownsampledZipformer2Encoder( - encoder, - dim=encoder_dim[i], - downsample=downsampling_factor[i], - dropout=dropout, - ) - - encoders.append(encoder) - - self.encoders = nn.ModuleList(encoders) - - self.downsample_output = SimpleDownsample(max(encoder_dim), - downsample=output_downsampling_factor, - dropout=dropout) - - def get_feature_masks( - self, - x: Tensor) -> Union[List[float], List[Tensor]]: - """ - In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of - randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all enocder dims larger than - some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoer dim. - - We generate the random masks at this level because we want the 2 masks to 'agree' - all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_subsampling_factor times. - - Args: - x: the embeddings (needed for the shape and dtype and device), of shape - (1, batch_size, encoder_dims0) - """ - num_encoders = len(self.encoder_dim) - if not self.training: - return [ 1.0 ] * num_encoders - - (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dim[0] == _encoder_dims0, (self.encoder_dim[0], _encoder_dims0) - - feature_mask_dropout_prob = 0.125 - - # mask1 shape: (1, batch_size, 1) - mask1 = (torch.rand(1, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype) - - # mask2 has additional sequences masked, about twice the number. - mask2 = torch.logical_and(mask1, - (torch.rand(1, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype)) - - # dim: (1, batch_size, 2) - mask = torch.cat((mask1, mask2), dim=-1) - - feature_masks = [] - for i in range(num_encoders): - channels = self.encoder_dim[i] - feature_mask = torch.ones(1, batch_size, channels, - dtype=x.dtype, device=x.device) - u1 = self.encoder_unmasked_dim[i] - u2 = u1 + (channels - u1) // 2 - - feature_mask[:, :, u1:u2] *= mask[..., 0:1] - feature_mask[:, :, u2:] *= mask[..., 1:2] - - feature_masks.append(feature_mask) - - return feature_masks - - def get_chunk_info(self) -> Tuple[int, int]: - """ - Returns chunk_size and left_context_chunks. - """ - if not self.causal: - return -1, -1 - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - assert len(self.chunk_size) == 1, self.chunk_size - chunk_size = self.chunk_size[0] - else: - chunk_size = random.choice(self.chunk_size) - - if chunk_size == -1: - left_context_chunks = -1 - else: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - assert len(self.left_context_frames) == 1, self.left_context_frames - left_context_frames = self.left_context_frames[0] - else: - left_context_frames = random.choice(self.left_context_frames) - # Note: in Python, -1 // n == -1 for n > 0 - left_context_chunks = left_context_frames // chunk_size - if left_context_chunks == 0: - left_context_chunks = 1 - - return chunk_size, left_context_chunks - - def forward( - self, x: Tensor, - x_lens: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x: - The input tensor. Its shape is (seq_len, batch_size, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - src_key_padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - """ - outputs = [] - if torch.jit.is_scripting() or torch.jit.is_tracing(): - feature_masks = [1.0] * len(self.encoder_dim) - else: - feature_masks = self.get_feature_masks(x) - - chunk_size, left_context_chunks = self.get_chunk_info() - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - # Not support exporting a model for simulating streaming decoding - attn_mask = None - else: - attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) - - for i, module in enumerate(self.encoders): - ds = self.downsampling_factor[i] - x = convert_num_channels(x, self.encoder_dim[i]) - - x = module(x, - chunk_size=chunk_size, - feature_mask=feature_masks[i], - src_key_padding_mask=(None if src_key_padding_mask is None - else src_key_padding_mask[...,::ds]), - attn_mask=attn_mask) - outputs.append(x) - - # if the last output has the largest dimension, x will be unchanged, - # it will be the same as outputs[-1]. Otherwise it will be concatenated - # from different pieces of 'outputs', taking each dimension from the - # most recent output that has it present. - x = self._get_full_dim_output(outputs) - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2, self.output_downsampling_factor - if torch.jit.is_scripting() or torch.jit.is_tracing(): - lengths = (x_lens + 1) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = (x_lens + 1) // 2 - - return x, lengths - - def _get_attn_mask( - self, x: Tensor, - chunk_size: int, - left_context_chunks: int - ) -> Optional[Tensor]: - """ - Return None if chunk_size == -1, else return attention mask of shape - (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True - means a masked position. - Args: - x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). - chunk_size: chunk size, must divide - """ - if chunk_size <= 0: - return None - assert all(chunk_size % d == 0 for d in self.downsampling_factor) - if left_context_chunks >= 0: - num_encoders = len(self.encoder_dim) - assert all (chunk_size * left_context_chunks >= - (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] - for i in range(num_encoders)) - else: - left_context_chunks = 1000000 - - seq_len = x.shape[0] - - # t is frame index, shape (seq_len,) - t = torch.arange(seq_len, dtype=torch.int32, device=x.device) - # c is chunk index for each frame, shape (seq_len,) - if torch.jit.is_scripting() or torch.jit.is_tracing(): - c = t // chunk_size - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - c = t // chunk_size - src_c = c - tgt_c = c.unsqueeze(-1) - - attn_mask = torch.logical_or(src_c > tgt_c, - src_c < tgt_c - left_context_chunks) - if __name__ == "__main__": - logging.info(f"attn_mask = {attn_mask}") - return attn_mask - - def _get_full_dim_output(self, outputs: List[Tensor]): - num_encoders = len(self.encoder_dim) - assert len(outputs) == num_encoders - output_dim = max(self.encoder_dim) - output_pieces = [ outputs[-1] ] - cur_dim = self.encoder_dim[-1] - for i in range(num_encoders - 2, -1, -1): - d = self.encoder_dim[i] - if d > cur_dim: - this_output = outputs[i] - output_pieces.append(this_output[..., cur_dim:d]) - cur_dim = d - assert cur_dim == output_dim - return torch.cat(output_pieces, dim=-1) - - def streaming_forward( - self, - x: Tensor, - x_lens: Tensor, - states: List[Tensor], - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (seq_len, batch_size, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - states: list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - src_key_padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - - updated states - """ - outputs = [] - new_states = [] - layer_offset = 0 - - for i, module in enumerate(self.encoders): - num_layers = module.num_layers - ds = self.downsampling_factor[i] - x = convert_num_channels(x, self.encoder_dim[i]) - - x, new_layer_states = module.streaming_forward( - x, - states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], - left_context_len=self.left_context_frames[0] // ds, - src_key_padding_mask=src_key_padding_mask[..., ::ds], - ) - layer_offset += num_layers - outputs.append(x) - new_states += new_layer_states - - # if the last output has the largest dimension, x will be unchanged, - # it will be the same as outputs[-1]. Otherwise it will be concatenated - # from different pieces of 'outputs', taking each dimension from the - # most recent output that has it present. - x = self._get_full_dim_output(outputs) - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2 - if torch.jit.is_scripting() or torch.jit.is_tracing(): - lengths = (x_lens + 1) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = (x_lens + 1) // 2 - - return x, lengths, new_states - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> List[Tensor]: - """Get initial states. - - A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - """ - states = [] - for i, module in enumerate(self.encoders): - num_layers = module.num_layers - embed_dim = self.encoder_dim[i] - ds = self.downsampling_factor[i] - num_heads = self.num_heads[i] - key_dim = self.query_head_dim[i] * num_heads - value_dim = self.value_head_dim[i] * num_heads - downsample_left = self.left_context_frames[0] // ds - nonlin_attn_head_dim = 3 * embed_dim // 4 - conv_left_pad = self.cnn_module_kernel[i] // 2 - for layer in range(num_layers): - cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(device) - cached_nonlin_attn = torch.zeros(1, batch_size, downsample_left, nonlin_attn_head_dim).to(device) - cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(device) - cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(device) - cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(device) - cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(device) - states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2] - - return states - - -def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: - return ScheduledFloat((0.0, x), - (20000.0, ratio * x), - default=x) - - -def _balancer_schedule(min_prob: float): - return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) - - -class Zipformer2EncoderLayer(nn.Module): - """ - Args: - embed_dim: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - feedforward_dim: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. - - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - value_head_dim: int, - feedforward_dim: int, - dropout: FloatLike = 0.1, - cnn_module_kernel: int = 31, - causal: bool = False, - attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), - conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), - const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0), - ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), - ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), - bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0), - ) -> None: - super(Zipformer2EncoderLayer, self).__init__() - self.embed_dim = embed_dim - - # self.bypass implements layer skipping as well as bypass; see its default values. - self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate, - straight_through_rate=0) - # bypass_mid is bypass used in the middle of the layer. - self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) - - # skip probability for dynamic modules (meaning: anything but feedforward). - self.attention_skip_rate = copy.deepcopy(attention_skip_rate) - # an additional skip probability that applies to ConvModule to stop it from - # contributing too much early on. - self.conv_skip_rate = copy.deepcopy(conv_skip_rate) - - # ff2_skip_rate is to prevent the ff2 module from having output that's too big - # compared to its residual. - self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) - self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) - - self.const_attention_rate = copy.deepcopy(const_attention_rate) - - self.self_attn_weights = RelPositionMultiheadAttentionWeights( - embed_dim, pos_dim=pos_dim, num_heads=num_heads, - query_head_dim=query_head_dim, pos_head_dim=pos_head_dim, - dropout=0.0, - ) - - self.self_attn1 = SelfAttention(embed_dim, num_heads, - value_head_dim) - - self.self_attn2 = SelfAttention(embed_dim, num_heads, - value_head_dim) - - self.feed_forward1 = FeedforwardModule(embed_dim, - (feedforward_dim * 3) // 4, - dropout) - - self.feed_forward2 = FeedforwardModule(embed_dim, - feedforward_dim, - dropout) - - self.feed_forward3 = FeedforwardModule(embed_dim, - (feedforward_dim * 5) // 4, - dropout) - - self.nonlin_attention = NonlinAttention(embed_dim, - hidden_channels=3 * embed_dim // 4) - - self.conv_module1 = ConvolutionModule(embed_dim, - cnn_module_kernel, - causal=causal) - - self.conv_module2 = ConvolutionModule(embed_dim, - cnn_module_kernel, - causal=causal) - - # TODO: remove it - self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - - self.norm = BiasNorm(embed_dim) - - self.balancer1 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.45, max_positive=0.55, - min_abs=0.2, max_abs=4.0, - ) - - # balancer for output of NonlinAttentionModule - self.balancer_na = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.3, max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), - prob=0.05, # out of concern for memory usage - ) - - # balancer for output of feedforward2, prevent it from staying too - # small. give this a very small probability, even at the start of - # training, it's to fix a rare problem and it's OK to fix it slowly. - self.balancer_ff2 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.3, max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), - max_abs=2.0, - prob=0.05, - ) - - self.balancer_ff3 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.3, max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), - max_abs=4.0, - prob=0.05, - ) - - self.whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(4.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01) - - self.balancer2 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.45, max_positive=0.55, - min_abs=0.1, max_abs=4.0, - ) - - def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]: - if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): - return None - batch_size = x.shape[1] - mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) - return mask - - def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: - """ - Apply sequence-level dropout to x. - x shape: (seq_len, batch_size, embed_dim) - """ - dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) - if dropout_mask is None: - return x - else: - return x * dropout_mask - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - chunk_size: int = -1, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: - A tensor which has the same shape as src - """ - src_orig = src - - # dropout rate for non-feedforward submodules - if torch.jit.is_scripting() or torch.jit.is_tracing(): - attention_skip_rate = 0.0 - else: - attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0 - - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights = self.self_attn_weights( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - - src = src + self.feed_forward1(src) - - self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate) - - selected_attn_weights = attn_weights[0:1] - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif not self.training and random.random() < float(self.const_attention_rate): - # Make attention weights constant. The intention is to - # encourage these modules to do something similar to an - # averaging-over-time operation. - # only need the mask, can just use the 1st one and expand later - selected_attn_weights = selected_attn_weights[0:1] - selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype) - selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)) - - na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) - - src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask) - - self_attn = self.self_attn1(src, attn_weights) - - src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - conv_skip_rate = 0.0 - else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size, - src_key_padding_mask=src_key_padding_mask), - conv_skip_rate) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - ff2_skip_rate = 0.0 - else: - ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)), - ff2_skip_rate) - - # bypass in the middle of the layer. - src = self.bypass_mid(src_orig, src) - - self_attn = self.self_attn2(src, attn_weights) - - src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - conv_skip_rate = 0.0 - else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size, - src_key_padding_mask=src_key_padding_mask), - conv_skip_rate) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - ff3_skip_rate = 0.0 - else: - ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)), - ff3_skip_rate) - - src = self.balancer1(src) - src = self.norm(src) - - src = self.bypass(src_orig, src) - - src = self.balancer2(src) - src = self.whiten(src) - - return src - - def streaming_forward( - self, - src: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - cached_nonlin_attn: Tensor, - cached_val1: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - """Pass the input through the encoder layer in streaming forward mode. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or - (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) - cached_key: cached attention key tensor of left context, - of shape (left_context_len, batch_size, key_dim) - cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape - (num_heads, batch_size, left_context_len, head_dim) - cached_val1: cached left context for the first attention module, - of shape (left_context_len, batch_size, value_dim) - cached_val2: cached left context for the second attention module, - of shape (left_context_len, batch_size, value_dim) - cached_conv1: cached left context for the first convolution module, - of shape (batch_size, channels, left_pad) - cached_conv2: cached left context for the second convolution module, - of shape (batch_size, channels, left_pad) - left_context_len: number of left context frames. - src_key_padding_mask: the mask for padding, of shape - (batch_size, left_context_len + seq_len); True means masked position. - May be None. - - Returns: - - x, with the same shape as src - - updated cached_key - - updated cached_nonlin_attn - - updated cached_val1 - - updated cached_val2 - - updated cached_conv1 - - updated cached_conv2 - """ - src_orig = src - - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights, cached_key = self.self_attn_weights.streaming_forward( - src, - pos_emb=pos_emb, - cached_key=cached_key, - left_context_len=left_context_len, - key_padding_mask=src_key_padding_mask, - ) - - src = src + self.feed_forward1(src) - - na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( - src, - attn_weights[0:1], - cached_x=cached_nonlin_attn, - left_context_len=left_context_len, - ) - src = src + na - - self_attn, cached_val1 = self.self_attn1.streaming_forward( - src, - attn_weights=attn_weights, - cached_val=cached_val1, - left_context_len=left_context_len, - ) - src = src + self_attn - - src_conv, cached_conv1 = self.conv_module1.streaming_forward( - src, - cache=cached_conv1, - src_key_padding_mask=src_key_padding_mask[:, left_context_len:], - ) - src = src + src_conv - - src = src + self.feed_forward2(src) - - # bypass in the middle of the layer. - src = self.bypass_mid(src_orig, src) - - self_attn, cached_val2 = self.self_attn2.streaming_forward( - src, - attn_weights=attn_weights, - cached_val=cached_val2, - left_context_len=left_context_len, - ) - src = src + self_attn - - src_conv, cached_conv2 = self.conv_module2.streaming_forward( - src, - cache=cached_conv2, - src_key_padding_mask=src_key_padding_mask[:, left_context_len:], - ) - src = src + src_conv - - src = src + self.feed_forward3(src) - - src = self.norm(src) - - src = self.bypass(src_orig, src) - - return ( - src, - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ) - - -class Zipformer2Encoder(nn.Module): - r"""Zipformer2Encoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - pos_dim: the dimension for the relative positional encoding - - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = zipformer_encoder(src) - """ - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - pos_dim: int, - dropout: float, - warmup_begin: float, - warmup_end: float, - initial_layerdrop_rate: float = 0.5, - final_layerdrop_rate: float = 0.05, - ) -> None: - super().__init__() - self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15, - length_factor=1.0) - - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - assert 0 <= warmup_begin <= warmup_end - - delta = (1. / num_layers) * (warmup_end - warmup_begin) - cur_begin = warmup_begin # interpreted as a training batch index - for i in range(num_layers): - cur_end = cur_begin + delta - self.layers[i].bypass.skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate), - (cur_end, final_layerdrop_rate), - default=0.0) - cur_begin = cur_end - - def forward( - self, - src: Tensor, - chunk_size: int = -1, - feature_mask: Union[Tensor, float] = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: a Tensor with the same shape as src. - """ - pos_emb = self.encoder_pos(src) - output = src - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - output = output * feature_mask - - for i, mod in enumerate(self.layers): - output = mod( - output, - pos_emb, - chunk_size=chunk_size, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - output = output * feature_mask - - return output - - def streaming_forward( - self, - src: Tensor, - states: List[Tensor], - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, List[Tensor]]: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is - (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - left_context_len: Number of left context frames. - src_key_padding_mask: the mask for padding, of shape - (batch_size, left_context_len + seq_len); True means masked position. - May be None. - - Returns: - - output, a Tensor with the same shape as src. - - updated states - """ - pos_emb = self.encoder_pos(src, left_context_len) - output = src - - new_states = [] - for i, mod in enumerate(self.layers): - ( - cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2 - ) = states[i * 6: (i + 1) * 6] - ( - output, - new_cached_key, - new_cached_nonlin_attn, - new_cached_val1, - new_cached_val2, - new_cached_conv1, - new_cached_conv2 - ) = mod.streaming_forward( - output, - pos_emb, - cached_key=cached_key, - cached_nonlin_attn=cached_nonlin_attn, - cached_val1=cached_val1, - cached_val2=cached_val2, - cached_conv1=cached_conv1, - cached_conv2=cached_conv2, - left_context_len=left_context_len, - src_key_padding_mask=src_key_padding_mask, - ) - new_states += [ - new_cached_key, - new_cached_nonlin_attn, - new_cached_val1, - new_cached_val2, - new_cached_conv1, - new_cached_conv2, - ] - - return output, new_states - - -class BypassModule(nn.Module): - """ - An nn.Module that implements a learnable bypass scale, and also randomized per-sequence - layer-skipping. The bypass is limited during early stages of training to be close to - "straight-through", i.e. to not do the bypass operation much initially, in order to - force all the modules to learn something. - """ - def __init__( - self, - embed_dim: int, - skip_rate: FloatLike = 0.0, - straight_through_rate: FloatLike = 0.0, - scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), - scale_max: FloatLike = 1.0): - super().__init__() - self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - self.skip_rate = copy.deepcopy(skip_rate) - self.straight_through_rate = copy.deepcopy(straight_through_rate) - self.scale_min = copy.deepcopy(scale_min) - self.scale_max = copy.deepcopy(scale_max) - - def _get_bypass_scale(self, batch_size: int): - # returns bypass-scale of shape (num_channels,), - # or (batch_size, num_channels,). This is actually the - # scale on the non-residual term, so 0 correponds to bypassing - # this module. - if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return self.bypass_scale - else: - ans = limit_param_value(self.bypass_scale, - min=float(self.scale_min), - max=float(self.scale_max)) - skip_rate = float(self.skip_rate) - if skip_rate != 0.0: - mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate - ans = ans * mask - # now ans is of shape (batch_size, num_channels), and is zero for sequences - # on which we have randomly chosen to do layer-skipping. - straight_through_rate = float(self.straight_through_rate) - if straight_through_rate != 0.0: - mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate - ans = torch.maximum(ans, mask.to(ans.dtype)) - return ans - - def forward(self, - src_orig: Tensor, - src: Tensor): - """ - Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) - Returns: something with the same shape as src and src_orig - """ - bypass_scale = self._get_bypass_scale(src.shape[1]) - return src_orig + (src - src_orig) * bypass_scale - - -class DownsampledZipformer2Encoder(nn.Module): - r""" - DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, - after convolutional downsampling, and then upsampled again at the output, and combined - with the origin input, so that the output has the same shape as the input. - """ - def __init__(self, - encoder: nn.Module, - dim: int, - downsample: int, - dropout: FloatLike): - super(DownsampledZipformer2Encoder, self).__init__() - self.downsample_factor = downsample - self.downsample = SimpleDownsample(dim, - downsample, dropout) - self.num_layers = encoder.num_layers - self.encoder = encoder - self.upsample = SimpleUpsample(dim, downsample) - self.out_combiner = BypassModule(dim, straight_through_rate=0) - - def forward( - self, - src: Tensor, - chunk_size: int = -1, - feature_mask: Union[Tensor, float] = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: a Tensor with the same shape as src. - """ - src_orig = src - src = self.downsample(src) - ds = self.downsample_factor - if attn_mask is not None: - attn_mask = attn_mask[::ds,::ds] - - src = self.encoder( - src, - chunk_size=chunk_size // ds, - feature_mask=feature_mask, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] - - return self.out_combiner(src_orig, src) - - def streaming_forward( - self, - src: Tensor, - states: List[Tensor], - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, List[Tensor]]: - r"""Downsample, go through encoder, upsample, in streaming forward mode. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is - (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - left_context_len: Number of left context frames. - src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len); - True means masked position. May be None. - - Returns: - - output, a Tensor with the same shape as src. - - updated states - """ - src_orig = src - src = self.downsample(src) - - src, new_states = self.encoder.streaming_forward( - src, - states=states, - left_context_len=left_context_len, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] - - return self.out_combiner(src_orig, src), new_states - - -class SimpleDownsample(torch.nn.Module): - """ - Does downsampling with attention, by weighted sum, and a projection.. - """ - def __init__(self, - channels: int, - downsample: int, - dropout: FloatLike): - super(SimpleDownsample, self).__init__() - - self.bias = nn.Parameter(torch.zeros(downsample)) - - self.name = None # will be set from training code - self.dropout = copy.deepcopy(dropout) - - self.downsample = downsample - - def forward(self, - src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, in_channels) - Returns a tensor of shape - ( (seq_len+downsample-1)//downsample, batch_size, channels) - """ - (seq_len, batch_size, in_channels) = src.shape - ds = self.downsample - d_seq_len = (seq_len + ds - 1) // ds - - # Pad to an exact multiple of self.downsample - # right-pad src, repeating the last element. - pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds - - src = src.reshape(d_seq_len, ds, batch_size, in_channels) - - weights = self.bias.softmax(dim=0) - # weights: (downsample, 1, 1) - weights = weights.unsqueeze(-1).unsqueeze(-1) - - # ans1 is the first `in_channels` channels of the output - ans = (src * weights).sum(dim=1) - - return ans - - -class SimpleUpsample(torch.nn.Module): - """ - A very simple form of upsampling that mostly just repeats the input, but - also adds a position-specific bias. - """ - def __init__(self, - num_channels: int, - upsample: int): - super(SimpleUpsample, self).__init__() - self.upsample = upsample - - def forward(self, - src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*upsample), batch_size, num_channels) - """ - upsample = self.upsample - (seq_len, batch_size, num_channels) = src.shape - src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src.reshape(seq_len * upsample, batch_size, num_channels) - return src - - -class CompactRelPositionalEncoding(torch.nn.Module): - """ - Relative positional encoding module. This version is "compact" meaning it is able to encode - the important information about the relative position in a relatively small number of dimensions. - The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) - make very little difference to the embedding. Such differences were potentially important - when encoding absolute position, but not important when encoding relative position because there - is now no need to compare two large offsets with each other. - - Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval - using the atan() function, before doing the fourier transform of that fixed interval. The - atan() function would compress the "long tails" too small, - making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic - function to compress large offsets to a smaller range before applying atan(). - Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long - as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) - - - Args: - embed_dim: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length: just a heuristic for initialization. - length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives - less weight to small differences of offset near the origin. - """ - def __init__( - self, embed_dim: int, - dropout_rate: FloatLike, - max_len: int = 1000, - length_factor: float = 1.0, - ) -> None: - """Construct a CompactRelPositionalEncoding object.""" - super(CompactRelPositionalEncoding, self).__init__() - self.embed_dim = embed_dim - assert embed_dim % 2 == 0 - self.dropout = Dropout2(dropout_rate) - self.pe = None - assert length_factor >= 1.0 - self.length_factor = length_factor - self.extend_pe(torch.tensor(0.0).expand(max_len)) - - def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: - """Reset the positional encodings.""" - T = x.size(0) + left_context_len - - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(0) >= T * 2 - 1: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - - # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] - x = torch.arange(-(T-1), T, - device=x.device).to(torch.float32).unsqueeze(1) - - freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) - - # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution - # for small time offsets but less resolution for large time offsets. - compression_length = (self.embed_dim ** 0.5) - # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; - # but it does so more slowly than T for large absolute values of T. - # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which - # is important. - x_compressed = compression_length * x.sign() * ((x.abs() + compression_length).log() - math.log(compression_length)) - - # if self.length_factor == 1.0, then length_scale is chosen so that the - # FFT can exactly separate points close to the origin (T == 0). So this - # part of the formulation is not really heuristic. - # But empirically, for ASR at least, length_factor > 1.0 seems to work better. - length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) - - # note for machine implementations: if atan is not available, we can use: - # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) - # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) - x_atan = (x_compressed / length_scale).atan() # results between -pi and pi - - cosines = (x_atan * freqs).cos() - sines = (x_atan * freqs).sin() - - pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) - pe[:, 0::2] = cosines - pe[:, 1::2] = sines - pe[:, -1] = 1.0 # for bias. - - self.pe = pe.to(dtype=x.dtype) - - def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: - """Create positional encoding. - - Args: - x (Tensor): Input tensor (time, batch, `*`). - left_context_len: (int): Length of cached left context. - - Returns: - positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). - """ - self.extend_pe(x, left_context_len) - x_size_left = x.size(0) + left_context_len - # length of positive side: x.size(0) + left_context_len - # length of negative side: x.size(0) - pos_emb = self.pe[ - self.pe.size(0) // 2 - - x_size_left - + 1 : self.pe.size(0) // 2 # noqa E203 - + x.size(0), - : - ] - pos_emb = pos_emb.unsqueeze(0) - return self.dropout(pos_emb) - - -class RelPositionMultiheadAttentionWeights(nn.Module): - r"""Module that computes multi-head attention weights with relative position encoding. - Various other modules consume the resulting attention weights: see, for example, the - SimpleAttention module which allows you to compute conventional attention. - - This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", - we have to write up the differences. - - - Args: - embed_dim: number of channels at the input to this module, e.g. 256 - pos_dim: dimension of the positional encoding vectors, e.g. 128. - num_heads: number of heads to compute weights for, e.g. 8 - query_head_dim: dimension of the query (and key), per head. e.g. 24. - pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. - dropout: dropout probability for attn_output_weights. Default: 0.0. - pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on - any given call to forward(), in training time. - """ - - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - dropout: float = 0.0, - pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), - (4000.0, 0.0)) - ) -> None: - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.query_head_dim = query_head_dim - self.pos_head_dim = pos_head_dim - self.dropout = dropout - self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) - self.name = None # will be overwritten in training code; for diagnostics. - - key_head_dim = query_head_dim - in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads - - # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5 that has been used in previous forms of attention, - # dividing it between the query and key. Note: this module is intended - # to be used with the ScaledAdam optimizer; with most other optimizers, - # it would be necessary to apply the scaling factor in the forward function. - self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, - initial_scale=query_head_dim**-0.25) - - self.whiten_keys = Whiten(num_groups=num_heads, - whitening_limit=_whitening_schedule(3.0), - prob=(0.025, 0.25), - grad_scale=0.025) - - # add a balancer for the keys that runs with very small probability, and - # tries to enforce that all dimensions have mean around zero. The - # weights produced by this module are invariant to adding a constant to - # the keys, so the derivative of the bias is mathematically zero; but - # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero - # bias because the small numerical roundoff tends to have a non-random - # sign. This module is intended to prevent that. Use a very small - # probability; that should be suffixient to fix the problem. - self.balance_keys = Balancer(key_head_dim * num_heads, - channel_dim=-1, - min_positive=0.4, - max_positive=0.6, - min_abs=0.0, - max_abs=100.0, - prob=0.025) - - # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(pos_dim, - num_heads * pos_head_dim, - bias=False, - initial_scale=0.05) - - # the following are for diagnosics only, see --print-diagnostics option - self.copy_pos_query = Identity() - self.copy_query = Identity() - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tensor: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), - interpreted as ([batch_size,] tgt_seq_len, src_seq_len) - saying which positions are allowed to attend to which other positions. - Returns: - a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - pos_head_dim = self.pos_head_dim - num_heads = self.num_heads - - seq_len, batch_size, _ = x.shape - - query_dim = query_head_dim * num_heads - - # self-attention - q = x[...,0:query_dim] - k = x[...,query_dim:2*query_dim] - # p is the position-encoding query - p = x[...,2*query_dim:] - assert p.shape[-1] == num_heads * pos_head_dim - - q = self.copy_query(q) # for diagnostics only, does nothing. - k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. - p = self.copy_pos_query(p) # for diagnostics only, does nothing. - - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) - k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) - - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - - use_pos_scores = False - if torch.jit.is_scripting() or torch.jit.is_tracing(): - # We can't put random.random() in the same line - use_pos_scores = True - elif not self.training or random.random() >= float(self.pos_emb_skip_rate): - use_pos_scores = True - - if use_pos_scores: - pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) - - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_scores = torch.matmul(p, pos_emb) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - if torch.jit.is_tracing(): - (num_heads, batch_size, time1, n) = pos_scores.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(seq_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_scores = pos_scores.reshape(-1, n) - pos_scores = torch.gather(pos_scores, dim=1, index=indexes) - pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) - else: - pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len), - (pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2)-pos_scores.stride(3), - pos_scores.stride(3)), - storage_offset=pos_scores.stride(3) * (seq_len - 1)) - - attn_scores = attn_scores + pos_scores - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif self.training and random.random() < 0.1: - # This is a harder way of limiting the attention scores to not be - # too large. It incurs a penalty if any of them has an absolute - # value greater than 50.0. this should be outside the normal range - # of the attention scores. We use this mechanism instead of, say, - # something added to the loss function involving the entropy, - # because once the entropy gets very small gradients through the - # softmax can become very small, and we'd get zero derivatives. The - # choices of 1.0e-04 as the scale on the penalty makes this - # mechanism vulnerable to the absolute scale of the loss function, - # but we view this as a failsafe to avoid "implausible" parameter - # values rather than a regularization method that should be active - # under normal circumstances. - attn_scores = penalize_abs_values_gt(attn_scores, - limit=25.0, - penalty=1.0e-04, - name=self.name) - - assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) - - if attn_mask is not None: - assert attn_mask.dtype == torch.bool - # use -1000 to avoid nan's where attn_mask and key_padding_mask make - # all scores zero. It's important that this be large enough that exp(-1000) - # is exactly zero, for reasons related to const_attention_rate, it - # compares the final weights with zero. - attn_scores = attn_scores.masked_fill(attn_mask, -1000) - - if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - # We use our own version of softmax, defined in scaling.py, which should - # save a little of the memory used in backprop by, if we are in - # automatic mixed precision mode (amp / autocast), by only storing the - # half-precision output for backprop purposes. - attn_weights = softmax(attn_scores, dim=-1) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif random.random() < 0.001 and not self.training: - self._print_attn_entropy(attn_weights) - - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - - return attn_weights - - def streaming_forward( - self, - x: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - left_context_len: int, - key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) - cached_key: cached attention key tensor of left context, - of shape (left_context_len, batch_size, key_dim) - left_context_len: number of left context frames. - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - - Returns: - - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - - updated cached attention key tensor of left context. - """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - pos_head_dim = self.pos_head_dim - num_heads = self.num_heads - - seq_len, batch_size, _ = x.shape - - query_dim = query_head_dim * num_heads - - # self-attention - q = x[...,0:query_dim] - k = x[...,query_dim:2*query_dim] - # p is the position-encoding query - p = x[...,2*query_dim:] - assert p.shape[-1] == num_heads * pos_head_dim - - # Pad cached left contexts - assert cached_key.shape[0] == left_context_len, (cached_key.shape[0], left_context_len) - k = torch.cat([cached_key, k], dim=0) - # Update cached left contexts - cached_key = k[-left_context_len:, ...] - - # The length of key - k_len = k.shape[0] - - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) - k = k.reshape(k_len, batch_size, num_heads, query_head_dim) - - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - - pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 + left_context_len - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) - - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_scores = torch.matmul(p, pos_emb) - - if torch.jit.is_tracing(): - (num_heads, batch_size, time1, n) = pos_scores.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(k_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_scores = pos_scores.reshape(-1, n) - pos_scores = torch.gather(pos_scores, dim=1, index=indexes) - pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - else: - pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len), - (pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2)-pos_scores.stride(3), - pos_scores.stride(3)), - storage_offset=pos_scores.stride(3) * (seq_len - 1)) - - attn_scores = attn_scores + pos_scores - - assert attn_scores.shape == (num_heads, batch_size, seq_len, k_len), attn_scores.shape - - if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - attn_weights = attn_scores.softmax(dim=-1) - - return attn_weights, cached_key - - def _print_attn_entropy( - self, - attn_weights: Tensor): - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape - - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( - dim=-1).mean(dim=(1,2)) - logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}") - - -class SelfAttention(nn.Module): - """ - The simplest possible attention module. This one works with already-computed attention - weights, e.g. as computed by RelPositionMultiheadAttentionWeights. - - Args: - embed_dim: the input and output embedding dimension - num_heads: the number of attention heads - value_head_dim: the value dimension per head - """ - def __init__( - self, - embed_dim: int, - num_heads: int, - value_head_dim: int, - ) -> None: - super().__init__() - self.in_proj = nn.Linear(embed_dim, - num_heads * value_head_dim, - bias=True) - - self.out_proj = ScaledLinear(num_heads * value_head_dim, - embed_dim, bias=True, - initial_scale=0.05) - - self.whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(7.5, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - Returns: - a tensor with the same shape as x. - """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) - - x = x.permute(2, 1, 0, 3).contiguous().view( - seq_len, batch_size, num_heads * value_head_dim) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = self.out_proj(x) - x = self.whiten(x) - - return x - - def streaming_forward( - self, - x: Tensor, - attn_weights: Tensor, - cached_val: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - cached_val: cached attention value tensor of left context, - of shape (left_context_len, batch_size, value_dim) - left_context_len: number of left context frames. - - Returns: - - attention weighted output, a tensor with the same shape as x. - - updated cached attention value tensor of left context. - """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - seq_len2 = seq_len + left_context_len - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - - # Pad cached left contexts - assert cached_val.shape[0] == left_context_len, (cached_val.shape[0], left_context_len) - x = torch.cat([cached_val, x], dim=0) - # Update cached left contexts - cached_val = x[-left_context_len:, ...] - - x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) - - x = x.permute(2, 1, 0, 3).contiguous().view( - seq_len, batch_size, num_heads * value_head_dim) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = self.out_proj(x) - - return x, cached_val - - -class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer2 model. - """ - def __init__(self, - embed_dim: int, - feedforward_dim: int, - dropout: FloatLike): - super(FeedforwardModule, self).__init__() - self.in_proj = nn.Linear(embed_dim, feedforward_dim) - - self.hidden_balancer = Balancer(feedforward_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0) - - # shared_dim=0 means we share the dropout mask along the time axis - self.out_proj = ActivationDropoutAndLinear(feedforward_dim, embed_dim, - activation='SwooshL', - dropout_p=dropout, - dropout_shared_dim=0, bias=True, - initial_scale=0.1) - - self.out_whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01) - - def forward(self, x: Tensor): - x = self.in_proj(x) - x = self.hidden_balancer(x) - # out_proj contains SwooshL activation, then dropout, then linear. - x = self.out_proj(x) - x = self.out_whiten(x) - return x - - -class NonlinAttention(nn.Module): - """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed - from the attention module) in place of actual convolution. We also took out the second nonlinearity, the - one after the attention mechanism. - - Args: - channels (int): The number of channels of conv layers. - """ - - def __init__( - self, - channels: int, - hidden_channels: int, - ) -> None: - super().__init__() - - self.hidden_channels = hidden_channels - - self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) - - # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, - # because we noticed that well-trained instances of this module have abs-value before the sigmoid - # starting from about 3, and poorly-trained instances of the module have smaller abs values - # before the sigmoid. - self.balancer = Balancer( - hidden_channels, channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), - max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), - min_abs=0.5, - max_abs=5.0, - ) - self.tanh = nn.Tanh() - - self.identity1 = Identity() # for diagnostics. - self.identity2 = Identity() # for diagnostics. - self.identity3 = Identity() # for diagnostics. - - self.out_proj = ScaledLinear(hidden_channels, channels, - bias=True, - initial_scale=0.05) - - self.whiten1 = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(5.0), - prob=(0.025, 0.25), - grad_scale=0.01) - - self.whiten2 = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(5.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) -attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - Returns: - a Tensor with the same shape as x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=-1) - - # s will go through tanh. - - s = self.balancer(s) - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = self.whiten1(x) - x = x * s - x = self.identity1(x) # diagnostics only, it's the identity. - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = torch.matmul(attn_weights, x) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - y = self.identity2(y) - x = x * y - x = self.identity3(x) - - x = self.out_proj(x) - x = self.whiten2(x) - return x - - def streaming_forward( - self, - x: Tensor, - attn_weights: Tensor, - cached_x: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) - attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - cached_x: left context, a Tensor of shape - (num_heads, batch_size, left_context_len, head_dim) - left_context_len: number of left context frames. - Returns: - - a Tensor with the same shape as x - - updated left context with same shape as cached_x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=-1) - - # s will go through tanh. - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = x * s - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, left_context_len + seq_len) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - - # Pad cached tensor - assert cached_x.shape[2] == left_context_len, (cached_x.shape[2], left_context_len) - x_pad = torch.cat([cached_x, x], dim=2) - # Update cached tensor - cached_x = x_pad[:, :, -left_context_len:, :] - - x = torch.matmul(attn_weights, x_pad) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - x = x * y - - x = self.out_proj(x) - return x, cached_x - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Zipformer2 model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - def __init__( - self, channels: int, kernel_size: int, causal: bool, - ) -> None: - """Construct a ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - bottleneck_dim = channels - self.causal = causal - - self.in_proj = nn.Linear( - channels, 2 * bottleneck_dim, - ) - # the gradients on in_proj are a little noisy, likely to do with the - # sigmoid in glu. - - # after in_proj we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, - # but sometimes, for some reason, for layer 0 the rms ends up being very large, - # between 50 and 100 for different channels. This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. (for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. - self.balancer1 = Balancer( - bottleneck_dim, channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), - max_positive=1.0, - min_abs=1.5, - max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), - ) - - self.activation1 = Identity() # for diagnostics - - self.sigmoid = nn.Sigmoid() - - self.activation2 = Identity() # for diagnostics - - assert kernel_size % 2 == 1 - - self.depthwise_conv = ChunkCausalDepthwiseConv1d( - channels=bottleneck_dim, - kernel_size=kernel_size) if causal else nn.Conv1d( - in_channels=bottleneck_dim, - out_channels=bottleneck_dim, - groups=bottleneck_dim, - kernel_size=kernel_size, - padding=kernel_size // 2) - - self.balancer2 = Balancer( - bottleneck_dim, channel_dim=1, - min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), - max_positive=1.0, - min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), - max_abs=10.0, - ) - - self.whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01) - - self.out_proj = ActivationDropoutAndLinear( - bottleneck_dim, channels, activation='SwooshR', - dropout_p=0.0, initial_scale=0.05, - ) - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - chunk_size: int = -1, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - - x = self.in_proj(x) # (time, batch, 2*channels) - - x, s = x.chunk(2, dim=-1) - s = self.balancer1(s) - s = self.sigmoid(s) - x = self.activation1(x) # identity. - x = x * s - x = self.activation2(x) # identity - - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - if not torch.jit.is_scripting() and not torch.jit.is_tracing() and chunk_size >= 0: - # Not support exporting a model for simulated streaming decoding - assert self.causal, "Must initialize model with causal=True if you use chunk_size" - x = self.depthwise_conv(x, chunk_size=chunk_size) - else: - x = self.depthwise_conv(x) - - x = self.balancer2(x) - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.whiten(x) # (time, batch, channels) - x = self.out_proj(x) # (time, batch, channels) - - return x - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Compute convolution module in streaming forward mode. - - Args: - x: Input tensor (#time, batch, channels). - cache: cached left context for depthwise_conv of shape - (#batch, channels, left_pad) - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - - Output tensor (#time, batch, channels). - - Updated cache (#batch, channels, left_pad) - """ - - x = self.in_proj(x) # (time, batch, 2*channels) - - x, s = x.chunk(2, dim=2) - s = self.sigmoid(s) - x = x * s - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) - - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.out_proj(x) # (time, batch, channels) - - return x, cache - - -class ScalarMultiply(nn.Module): - def __init__(self, scale: float): - super().__init__() - self.scale = scale - - def forward(self, x): - return x * self.scale - - -def _test_zipformer_main(causal: bool = False): - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - - c = Zipformer2( - encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4), - causal=causal, - chunk_size=(4,) if causal else (-1,), - left_context_frames=(64,) - ) - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - f = c( - torch.randn(seq_len, batch_size, 64), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f[0].sum().backward() - c.eval() - f = c( - torch.randn(seq_len, batch_size, 64), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f # to remove flake8 warnings - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_zipformer_main(False) - _test_zipformer_main(True) From 54715991b63760a971cb23b1b1629981b28c67c2 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Tue, 26 Sep 2023 16:52:40 +0800 Subject: [PATCH 45/60] final refinement --- egs/swbd/ASR/README.md | 7 - egs/swbd/ASR/conformer_ctc/conformer.py | 910 ----------------- egs/swbd/ASR/conformer_ctc/label_smoothing.py | 109 -- egs/swbd/ASR/conformer_ctc/subsampling.py | 153 --- .../ASR/conformer_ctc/test_transformer.py | 104 -- egs/swbd/ASR/conformer_ctc/transformer.py | 928 ------------------ 6 files changed, 2211 deletions(-) delete mode 100644 egs/swbd/ASR/conformer_ctc/conformer.py delete mode 100644 egs/swbd/ASR/conformer_ctc/label_smoothing.py delete mode 100644 egs/swbd/ASR/conformer_ctc/subsampling.py delete mode 100644 egs/swbd/ASR/conformer_ctc/test_transformer.py delete mode 100644 egs/swbd/ASR/conformer_ctc/transformer.py diff --git a/egs/swbd/ASR/README.md b/egs/swbd/ASR/README.md index 10e38c501c..13b27815a6 100644 --- a/egs/swbd/ASR/README.md +++ b/egs/swbd/ASR/README.md @@ -6,13 +6,6 @@ Switchboard is a collection of about 2,400 two-sided telephone conversations amo (The above introduction is from the [LDC Switchboard-1 Release 2 webpage](https://catalog.ldc.upenn.edu/LDC97S62).) -**Caution**: The `conformer_ctc` recipe for Switchboard is currently very rough and produces a high Word Error Rate, requiring more improvement and refinement. The TODO list for this recipe is as follows. - -## TODO List -- [x] Incorporate Lhotse for data processing -- [x] Further text normalization -- [x] Detailed Word Error Rate summary for eval2000 (callhome, swbd) and rt03 (fsh, swbd) testset -- [x] Switchboard transcript train/dev split for LM training ## Performance Record | | eval2000 | rt03 | diff --git a/egs/swbd/ASR/conformer_ctc/conformer.py b/egs/swbd/ASR/conformer_ctc/conformer.py deleted file mode 100644 index a1cfe6e75b..0000000000 --- a/egs/swbd/ASR/conformer_ctc/conformer.py +++ /dev/null @@ -1,910 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import warnings -from typing import Optional, Tuple, Union - -import torch -from torch import Tensor, nn -from transformer import Supervisions, Transformer, encoder_padding_mask - - -class Conformer(Transformer): - """ - Args: - num_features (int): Number of input features - num_classes (int): Number of output classes - subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) - d_model (int): attention dimension - nhead (int): number of head - dim_feedforward (int): feedforward dimention - num_encoder_layers (int): number of encoder layers - num_decoder_layers (int): number of decoder layers - dropout (float): dropout rate - cnn_module_kernel (int): Kernel size of convolution module - normalize_before (bool): whether to use layer_norm before the first block. - vgg_frontend (bool): whether to use vgg frontend. - """ - - def __init__( - self, - num_features: int, - num_classes: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - num_decoder_layers: int = 6, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - normalize_before: bool = True, - vgg_frontend: bool = False, - use_feat_batchnorm: Union[float, bool] = 0.1, - ) -> None: - super(Conformer, self).__init__( - num_features=num_features, - num_classes=num_classes, - subsampling_factor=subsampling_factor, - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - num_encoder_layers=num_encoder_layers, - num_decoder_layers=num_decoder_layers, - dropout=dropout, - normalize_before=normalize_before, - vgg_frontend=vgg_frontend, - use_feat_batchnorm=use_feat_batchnorm, - ) - - self.encoder_pos = RelPositionalEncoding(d_model, dropout) - - use_conv_batchnorm = True - if isinstance(use_feat_batchnorm, float): - use_conv_batchnorm = False - encoder_layer = ConformerEncoderLayer( - d_model, - nhead, - dim_feedforward, - dropout, - cnn_module_kernel, - normalize_before, - use_conv_batchnorm, - ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) - self.normalize_before = normalize_before - if self.normalize_before: - self.after_norm = nn.LayerNorm(d_model) - else: - # Note: TorchScript detects that self.after_norm could be used inside forward() - # and throws an error without this change. - self.after_norm = identity - - def run_encoder( - self, x: Tensor, supervisions: Optional[Supervisions] = None - ) -> Tuple[Tensor, Optional[Tensor]]: - """ - Args: - x: - The model input. Its shape is (N, T, C). - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. It is used - to compute encoder padding mask, which is used as memory key padding - mask for the decoder. - - Returns: - Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). - Tensor: Mask tensor of dimension (batch_size, input_length) - """ - x = self.encoder_embed(x) - x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - mask = encoder_padding_mask(x.size(0), supervisions) - if mask is not None: - mask = mask.to(x.device) - x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) - - if self.normalize_before: - x = self.after_norm(x) - - return x, mask - - -class ConformerEncoderLayer(nn.Module): - """ - ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. - See: "Conformer: Convolution-augmented Transformer for Speech Recognition" - - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. - normalize_before: whether to use layer_norm before the first block. - - Examples:: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - normalize_before: bool = True, - use_conv_batchnorm: bool = False, - ) -> None: - super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) - - self.feed_forward = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), - ) - - self.feed_forward_macaron = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), - ) - - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm - ) - - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module - self.norm_ff = nn.LayerNorm(d_model) # for the FNN module - self.norm_mha = nn.LayerNorm(d_model) # for the MHA module - - self.ff_scale = 0.5 - - self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block - - self.dropout = nn.Dropout(dropout) - - self.normalize_before = normalize_before - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, N is the batch size, E is the feature number - """ - - # macaron style feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) - if not self.normalize_before: - src = self.norm_ff_macaron(src) - - # multi-headed self-attention module - residual = src - if self.normalize_before: - src = self.norm_mha(src) - src_att = self.self_attn( - src, - src, - src, - pos_emb=pos_emb, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - src = residual + self.dropout(src_att) - if not self.normalize_before: - src = self.norm_mha(src) - - # convolution module - residual = src - if self.normalize_before: - src = self.norm_conv(src) - src = residual + self.dropout( - self.conv_module(src, src_key_padding_mask=src_key_padding_mask) - ) - if not self.normalize_before: - src = self.norm_conv(src) - - # feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) - if not self.normalize_before: - src = self.norm_ff(src) - - if self.normalize_before: - src = self.norm_final(src) - - return src - - -class ConformerEncoder(nn.TransformerEncoder): - r"""ConformerEncoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the ConformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - norm: the layer normalization component (optional). - - Examples:: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = conformer_encoder(src, pos_emb) - """ - - def __init__( - self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None - ) -> None: - super(ConformerEncoder, self).__init__( - encoder_layer=encoder_layer, num_layers=num_layers, norm=norm - ) - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - pos_emb: Positional embedding tensor (required). - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - """ - output = src - - for mod in self.layers: - output = mod( - output, - pos_emb, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - ) - - if self.norm is not None: - output = self.norm(output) - - return output - - -class RelPositionalEncoding(torch.nn.Module): - """Relative positional encoding module. - - See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py - - Args: - d_model: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length. - - """ - - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: - """Construct an PositionalEncoding object.""" - super(RelPositionalEncoding, self).__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = torch.nn.Dropout(p=dropout_rate) - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - - def extend_pe(self, x: Tensor) -> None: - """Reset the positional encodings.""" - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x.size(1) * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - # Suppose `i` means to the position of query vector and `j` means the - # position of key vector. We use position relative positions when keys - # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: - """Add positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, time, `*`). - - Returns: - torch.Tensor: Encoded tensor (batch, time, `*`). - torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). - - """ - self.extend_pe(x) - x = x * self.xscale - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - - x.size(1) - + 1 : self.pe.size(1) // 2 # noqa E203 - + x.size(1), - ] - return self.dropout(x), self.dropout(pos_emb) - - -class RelPositionMultiheadAttention(nn.Module): - r"""Multi-Head Attention layer with relative position encoding - - See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - - Args: - embed_dim: total dimension of the model. - num_heads: parallel attention heads. - dropout: a Dropout layer on attn_output_weights. Default: 0.0. - - Examples:: - - >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - ) -> None: - super(RelPositionMultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert ( - self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" - - self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) - - # linear transformation for positional encoding. - self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - - self._reset_parameters() - - def _reset_parameters(self) -> None: - nn.init.xavier_uniform_(self.in_proj.weight) - nn.init.constant_(self.in_proj.bias, 0.0) - nn.init.constant_(self.out_proj.bias, 0.0) - - nn.init.xavier_uniform_(self.pos_bias_u) - nn.init.xavier_uniform_(self.pos_bias_v) - - def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. - """ - return self.multi_head_attention_forward( - query, - key, - value, - pos_emb, - self.embed_dim, - self.num_heads, - self.in_proj.weight, - self.in_proj.bias, - self.dropout, - self.out_proj.weight, - self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - ) - - def rel_shift(self, x: Tensor) -> Tensor: - """Compute relative positional encoding. - - Args: - x: Input tensor (batch, head, time1, 2*time1-1). - time1 means the length of query vector. - - Returns: - Tensor: tensor of shape (batch, head, time1, time2) - (note: time2 has the same value as time1, but it is for - the key, while time1 is for the query). - """ - (batch_size, num_heads, time1, n) = x.shape - assert n == 2 * time1 - 1 - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time1), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) - - def multi_head_attention_forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_weight: Tensor, - in_proj_bias: Tensor, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - embed_dim_to_check: total dimension of the model. - num_heads: parallel attention heads. - in_proj_weight, in_proj_bias: input projection weight and bias. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence - length, N is the batch size, E is the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. - """ - - tgt_len, bsz, embed_dim = query.size() - assert embed_dim == embed_dim_to_check - assert key.size(0) == value.size(0) and key.size(1) == value.size(1) - - head_dim = embed_dim // num_heads - assert ( - head_dim * num_heads == embed_dim - ), "embed_dim must be divisible by num_heads" - scaling = float(head_dim) ** -0.5 - - if torch.equal(query, key) and torch.equal(key, value): - # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) - - elif torch.equal(key, value): - # encoder-decoder attention - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) - - else: - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = embed_dim * 2 - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - k = nn.functional.linear(key, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim * 2 - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - v = nn.functional.linear(value, _w, _b) - - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." - ) - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - query.size(0), - key.size(0), - ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") - else: - raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." - ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) - k = k.contiguous().view(-1, bsz, num_heads, head_dim) - v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) - - src_len = k.size(0) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == src_len, "{} == {}".format( - key_padding_mask.size(1), src_len - ) - - q = q.transpose(0, 1) # (batch, time1, head, d_k) - - pos_emb_bsz = pos_emb.size(0) - assert pos_emb_bsz in (1, bsz) # actually it is 1 - p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) - p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) - - q_with_bias_u = (q + self.pos_bias_u).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - q_with_bias_v = (q + self.pos_bias_v).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - # compute attention score - # first compute matrix a and matrix c - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) - - # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p.transpose(-2, -1) - ) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd) - - attn_output_weights = ( - matrix_ac + matrix_bd - ) * scaling # (batch, head, time1, time2) - - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) - - assert list(attn_output_weights.size()) == [ - bsz * num_heads, - tgt_len, - src_len, - ] - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float("-inf")) - else: - attn_output_weights += attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, src_len - ) - - attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training - ) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] - attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - if need_weights: - # average attention weights over heads - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - return attn_output, attn_output_weights.sum(dim=1) / num_heads - else: - return attn_output, None - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Conformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - - def __init__( - self, - channels: int, - kernel_size: int, - bias: bool = True, - use_batchnorm: bool = False, - ) -> None: - """Construct an ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - self.use_batchnorm = use_batchnorm - - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.depthwise_conv = nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=(kernel_size - 1) // 2, - groups=channels, - bias=bias, - ) - if self.use_batchnorm: - self.norm = nn.BatchNorm1d(channels) - self.pointwise_conv2 = nn.Conv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.activation = Swish() - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # 1D Depthwise Conv - if src_key_padding_mask is not None: - x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - x = self.depthwise_conv(x) - if self.use_batchnorm: - x = self.norm(x) - x = self.activation(x) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1) - - -class Swish(torch.nn.Module): - """Construct an Swish object.""" - - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x) - - -def identity(x): - return x diff --git a/egs/swbd/ASR/conformer_ctc/label_smoothing.py b/egs/swbd/ASR/conformer_ctc/label_smoothing.py deleted file mode 100644 index 52d2eda3bb..0000000000 --- a/egs/swbd/ASR/conformer_ctc/label_smoothing.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - - -class LabelSmoothingLoss(torch.nn.Module): - """ - Implement the LabelSmoothingLoss proposed in the following paper - https://arxiv.org/pdf/1512.00567.pdf - (Rethinking the Inception Architecture for Computer Vision) - - """ - - def __init__( - self, - ignore_index: int = -1, - label_smoothing: float = 0.1, - reduction: str = "sum", - ) -> None: - """ - Args: - ignore_index: - ignored class id - label_smoothing: - smoothing rate (0.0 means the conventional cross entropy loss) - reduction: - It has the same meaning as the reduction in - `torch.nn.CrossEntropyLoss`. It can be one of the following three - values: (1) "none": No reduction will be applied. (2) "mean": the - mean of the output is taken. (3) "sum": the output will be summed. - """ - super().__init__() - assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}" - assert reduction in ("none", "sum", "mean"), reduction - self.ignore_index = ignore_index - self.label_smoothing = label_smoothing - self.reduction = reduction - - def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Compute loss between x and target. - - Args: - x: - prediction of dimension - (batch_size, input_length, number_of_classes). - target: - target masked with self.ignore_index of - dimension (batch_size, input_length). - - Returns: - A scalar tensor containing the loss without normalization. - """ - assert x.ndim == 3 - assert target.ndim == 2 - assert x.shape[:2] == target.shape - num_classes = x.size(-1) - x = x.reshape(-1, num_classes) - # Now x is of shape (N*T, C) - - # We don't want to change target in-place below, - # so we make a copy of it here - target = target.clone().reshape(-1) - - ignored = target == self.ignore_index - - # See https://github.com/k2-fsa/icefall/issues/240 - # and https://github.com/k2-fsa/icefall/issues/297 - # for why we don't use target[ignored] = 0 here - target = torch.where(ignored, torch.zeros_like(target), target) - - true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) - - true_dist = ( - true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes - ) - - # Set the value of ignored indexes to 0 - # - # See https://github.com/k2-fsa/icefall/issues/240 - # and https://github.com/k2-fsa/icefall/issues/297 - # for why we don't use true_dist[ignored] = 0 here - true_dist = torch.where( - ignored.unsqueeze(1).repeat(1, true_dist.shape[1]), - torch.zeros_like(true_dist), - true_dist, - ) - - loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) - if self.reduction == "sum": - return loss.sum() - elif self.reduction == "mean": - return loss.sum() / (~ignored).sum() - else: - return loss.sum(dim=-1) diff --git a/egs/swbd/ASR/conformer_ctc/subsampling.py b/egs/swbd/ASR/conformer_ctc/subsampling.py deleted file mode 100644 index 8e0f73d05b..0000000000 --- a/egs/swbd/ASR/conformer_ctc/subsampling.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch.nn as nn - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__(self, idim: int, odim: int) -> None: - """ - Args: - idim: - Input dim. The input shape is (N, T, idim). - Caution: It requires: T >=7, idim >=7 - odim: - Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) - """ - assert idim >= 7 - super().__init__() - self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), - nn.ReLU(), - nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), - nn.ReLU(), - ) - self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - - Returns: - Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - x = self.conv(x) - # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) - return x - - -class VggSubsampling(nn.Module): - """Trying to follow the setup described in the following paper: - https://arxiv.org/pdf/1910.09799.pdf - - This paper is not 100% explicit so I am guessing to some extent, - and trying to compare with other VGG implementations. - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 - """ - - def __init__(self, idim: int, odim: int) -> None: - """Construct a VggSubsampling object. - - This uses 2 VGG blocks with 2 Conv2d layers each, - subsampling its input by a factor of 4 in the time dimensions. - - Args: - idim: - Input dim. The input shape is (N, T, idim). - Caution: It requires: T >=7, idim >=7 - odim: - Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) - """ - super().__init__() - - cur_channels = 1 - layers = [] - block_dims = [32, 64] - - # The decision to use padding=1 for the 1st convolution, then padding=0 - # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by - # a back-compatibility concern so that the number of frames at the - # output would be equal to: - # (((T-1)//2)-1)//2. - # We can consider changing this by using padding=1 on the - # 2nd convolution, so the num-frames at the output would be T//4. - for block_dim in block_dims: - layers.append( - torch.nn.Conv2d( - in_channels=cur_channels, - out_channels=block_dim, - kernel_size=3, - padding=1, - stride=1, - ) - ) - layers.append(torch.nn.ReLU()) - layers.append( - torch.nn.Conv2d( - in_channels=block_dim, - out_channels=block_dim, - kernel_size=3, - padding=0, - stride=1, - ) - ) - layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) - ) - cur_channels = block_dim - - self.layers = nn.Sequential(*layers) - - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - - Returns: - Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) - """ - x = x.unsqueeze(1) - x = self.layers(x) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - return x diff --git a/egs/swbd/ASR/conformer_ctc/test_transformer.py b/egs/swbd/ASR/conformer_ctc/test_transformer.py deleted file mode 100644 index 667057c513..0000000000 --- a/egs/swbd/ASR/conformer_ctc/test_transformer.py +++ /dev/null @@ -1,104 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -from torch.nn.utils.rnn import pad_sequence -from transformer import ( - Transformer, - add_eos, - add_sos, - decoder_padding_mask, - encoder_padding_mask, - generate_square_subsequent_mask, -) - - -def test_encoder_padding_mask(): - supervisions = { - "sequence_idx": torch.tensor([0, 1, 2]), - "start_frame": torch.tensor([0, 0, 0]), - "num_frames": torch.tensor([18, 7, 13]), - } - - max_len = ((18 - 1) // 2 - 1) // 2 - mask = encoder_padding_mask(max_len, supervisions) - expected_mask = torch.tensor( - [ - [False, False, False], # ((18 - 1)//2 - 1)//2 = 3, - [False, True, True], # ((7 - 1)//2 - 1)//2 = 1, - [False, False, True], # ((13 - 1)//2 - 1)//2 = 2, - ] - ) - assert torch.all(torch.eq(mask, expected_mask)) - - -def test_transformer(): - num_features = 40 - num_classes = 87 - model = Transformer(num_features=num_features, num_classes=num_classes) - - N = 31 - - for T in range(7, 30): - x = torch.rand(N, T, num_features) - y, _, _ = model(x) - assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes) - - -def test_generate_square_subsequent_mask(): - s = 5 - mask = generate_square_subsequent_mask(s) - inf = float("inf") - expected_mask = torch.tensor( - [ - [0.0, -inf, -inf, -inf, -inf], - [0.0, 0.0, -inf, -inf, -inf], - [0.0, 0.0, 0.0, -inf, -inf], - [0.0, 0.0, 0.0, 0.0, -inf], - [0.0, 0.0, 0.0, 0.0, 0.0], - ] - ) - assert torch.all(torch.eq(mask, expected_mask)) - - -def test_decoder_padding_mask(): - x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])] - y = pad_sequence(x, batch_first=True, padding_value=-1) - mask = decoder_padding_mask(y, ignore_id=-1) - expected_mask = torch.tensor( - [ - [False, False, True], - [False, True, True], - [False, False, False], - ] - ) - assert torch.all(torch.eq(mask, expected_mask)) - - -def test_add_sos(): - x = [[1, 2], [3], [2, 5, 8]] - y = add_sos(x, sos_id=0) - expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]] - assert y == expected_y - - -def test_add_eos(): - x = [[1, 2], [3], [2, 5, 8]] - y = add_eos(x, eos_id=0) - expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]] - assert y == expected_y diff --git a/egs/swbd/ASR/conformer_ctc/transformer.py b/egs/swbd/ASR/conformer_ctc/transformer.py deleted file mode 100644 index 0566cfc81c..0000000000 --- a/egs/swbd/ASR/conformer_ctc/transformer.py +++ /dev/null @@ -1,928 +0,0 @@ -# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -from label_smoothing import LabelSmoothingLoss -from subsampling import Conv2dSubsampling, VggSubsampling -from torch.nn.utils.rnn import pad_sequence - -# Note: TorchScript requires Dict/List/etc. to be fully typed. -Supervisions = Dict[str, torch.Tensor] - - -class Transformer(nn.Module): - def __init__( - self, - num_features: int, - num_classes: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - num_decoder_layers: int = 6, - dropout: float = 0.1, - normalize_before: bool = True, - vgg_frontend: bool = False, - use_feat_batchnorm: Union[float, bool] = 0.1, - ) -> None: - """ - Args: - num_features: - The input dimension of the model. - num_classes: - The output dimension of the model. - subsampling_factor: - Number of output frames is num_in_frames // subsampling_factor. - Currently, subsampling_factor MUST be 4. - d_model: - Attention dimension. - nhead: - Number of heads in multi-head attention. - Must satisfy d_model // nhead == 0. - dim_feedforward: - The output dimension of the feedforward layers in encoder/decoder. - num_encoder_layers: - Number of encoder layers. - num_decoder_layers: - Number of decoder layers. - dropout: - Dropout in encoder/decoder. - normalize_before: - If True, use pre-layer norm; False to use post-layer norm. - vgg_frontend: - True to use vgg style frontend for subsampling. - use_feat_batchnorm: - True to use batchnorm for the input layer. - Float value to scale the input layer. - False to do nothing. - """ - super().__init__() - self.use_feat_batchnorm = use_feat_batchnorm - assert isinstance(use_feat_batchnorm, (float, bool)) - if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm: - self.feat_batchnorm = nn.BatchNorm1d(num_features) - - self.num_features = num_features - self.num_classes = num_classes - self.subsampling_factor = subsampling_factor - if subsampling_factor != 4: - raise NotImplementedError("Support only 'subsampling_factor=4'.") - - # self.encoder_embed converts the input of shape (N, T, num_classes) - # to the shape (N, T//subsampling_factor, d_model). - # That is, it does two things simultaneously: - # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_classes -> d_model - if vgg_frontend: - self.encoder_embed = VggSubsampling(num_features, d_model) - else: - self.encoder_embed = Conv2dSubsampling(num_features, d_model) - - self.encoder_pos = PositionalEncoding(d_model, dropout) - - encoder_layer = TransformerEncoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - normalize_before=normalize_before, - ) - - if normalize_before: - encoder_norm = nn.LayerNorm(d_model) - else: - encoder_norm = None - - self.encoder = nn.TransformerEncoder( - encoder_layer=encoder_layer, - num_layers=num_encoder_layers, - norm=encoder_norm, - ) - - # TODO(fangjun): remove dropout - self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) - ) - - if num_decoder_layers > 0: - self.decoder_num_class = ( - self.num_classes - ) # bpe model already has sos/eos symbol - - self.decoder_embed = nn.Embedding( - num_embeddings=self.decoder_num_class, embedding_dim=d_model - ) - self.decoder_pos = PositionalEncoding(d_model, dropout) - - decoder_layer = TransformerDecoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - normalize_before=normalize_before, - ) - - if normalize_before: - decoder_norm = nn.LayerNorm(d_model) - else: - decoder_norm = None - - self.decoder = nn.TransformerDecoder( - decoder_layer=decoder_layer, - num_layers=num_decoder_layers, - norm=decoder_norm, - ) - - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) - - self.decoder_criterion = LabelSmoothingLoss() - else: - self.decoder_criterion = None - - def forward( - self, x: torch.Tensor, supervision: Optional[Supervisions] = None - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (N, T, C). - supervision: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - (CAUTION: It contains length information, i.e., start and number of - frames, before subsampling) - - Returns: - Return a tuple containing 3 tensors: - - CTC output for ctc decoding. Its shape is (N, T, C) - - Encoder output with shape (T, N, C). It can be used as key and - value for the decoder. - - Encoder output padding mask. It can be used as - memory_key_padding_mask for the decoder. Its shape is (N, T). - It is None if `supervision` is None. - """ - if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) - x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - if isinstance(self.use_feat_batchnorm, float): - x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) - x = self.ctc_output(encoder_memory) - return x, encoder_memory, memory_key_padding_mask - - def run_encoder( - self, x: torch.Tensor, supervisions: Optional[Supervisions] = None - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Run the transformer encoder. - - Args: - x: - The model input. Its shape is (N, T, C). - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. It is used - to compute the encoder padding mask, which is used as memory key - padding mask for the decoder. - Returns: - Return a tuple with two tensors: - - The encoder output, with shape (T, N, C) - - encoder padding mask, with shape (N, T). - The mask is None if `supervisions` is None. - It is used as memory key padding mask in the decoder. - """ - x = self.encoder_embed(x) - x = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - mask = encoder_padding_mask(x.size(0), supervisions) - mask = mask.to(x.device) if mask is not None else None - x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) - - return x, mask - - def ctc_output(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - The output tensor from the transformer encoder. - Its shape is (T, N, C) - - Returns: - Return a tensor that can be used for CTC decoding. - Its shape is (N, T, C) - """ - x = self.encoder_output_layer(x) - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) - return x - - @torch.jit.export - def decoder_forward( - self, - memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[List[int]], - sos_id: int, - eos_id: int, - ) -> torch.Tensor: - """ - Args: - memory: - It's the output of the encoder with shape (T, N, C) - memory_key_padding_mask: - The padding mask from the encoder. - token_ids: - A list-of-list IDs. Each sublist contains IDs for an utterance. - The IDs can be either phone IDs or word piece IDs. - sos_id: - sos token id - eos_id: - eos token id - - Returns: - A scalar, the **sum** of label smoothing loss over utterances - in the batch without any normalization. - """ - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) - - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) - - device = memory.device - ys_in_pad = ys_in_pad.to(device) - ys_out_pad = ys_out_pad.to(device) - - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) - - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. - tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - ) # (T, N, C) - pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) - - decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) - - return decoder_loss - - @torch.jit.export - def decoder_nll( - self, - memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[torch.Tensor], - sos_id: int, - eos_id: int, - ) -> torch.Tensor: - """ - Args: - memory: - It's the output of the encoder with shape (T, N, C) - memory_key_padding_mask: - The padding mask from the encoder. - token_ids: - A list-of-list IDs (e.g., word piece IDs). - Each sublist represents an utterance. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. - Returns: - A 2-D tensor of shape (len(token_ids), max_token_length) - representing the cross entropy loss (i.e., negative log-likelihood). - """ - # The common part between this function and decoder_forward could be - # extracted as a separate function. - if isinstance(token_ids[0], torch.Tensor): - # This branch is executed by torchscript in C++. - # See https://github.com/k2-fsa/k2/pull/870 - # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 - token_ids = [tolist(t) for t in token_ids] - - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) - - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) - - device = memory.device - ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) - ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) - - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. - tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - ) # (T, B, F) - pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) - pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) - # nll: negative log-likelihood - nll = torch.nn.functional.cross_entropy( - pred_pad.view(-1, self.decoder_num_class), - ys_out_pad.view(-1), - ignore_index=-1, - reduction="none", - ) - - nll = nll.view(pred_pad.shape[0], -1) - - return nll - - -class TransformerEncoderLayer(nn.Module): - """ - Modified from torch.nn.TransformerEncoderLayer. - Add support of normalize_before, - i.e., use layer_norm before the first block. - - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - activation: - the activation function of intermediate layer, relu or - gelu (default=relu). - normalize_before: - whether to use layer_norm before the first block. - - Examples:: - >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> out = encoder_layer(src) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - normalize_before: bool = True, - ) -> None: - super(TransformerEncoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - self.normalize_before = normalize_before - - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = nn.functional.relu - super(TransformerEncoderLayer, self).__setstate__(state) - - def forward( - self, - src: torch.Tensor, - src_mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional) - - Shape: - src: (S, N, E). - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, - N is the batch size, E is the feature number - """ - residual = src - if self.normalize_before: - src = self.norm1(src) - src2 = self.self_attn( - src, - src, - src, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - src = residual + self.dropout1(src2) - if not self.normalize_before: - src = self.norm1(src) - - residual = src - if self.normalize_before: - src = self.norm2(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = residual + self.dropout2(src2) - if not self.normalize_before: - src = self.norm2(src) - return src - - -class TransformerDecoderLayer(nn.Module): - """ - Modified from torch.nn.TransformerDecoderLayer. - Add support of normalize_before, - i.e., use layer_norm before the first block. - - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - activation: - the activation function of intermediate layer, relu or - gelu (default=relu). - - Examples:: - >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) - >>> memory = torch.rand(10, 32, 512) - >>> tgt = torch.rand(20, 32, 512) - >>> out = decoder_layer(tgt, memory) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - normalize_before: bool = True, - ) -> None: - super(TransformerDecoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm3 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.dropout3 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - self.normalize_before = normalize_before - - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = nn.functional.relu - super(TransformerDecoderLayer, self).__setstate__(state) - - def forward( - self, - tgt: torch.Tensor, - memory: torch.Tensor, - tgt_mask: Optional[torch.Tensor] = None, - memory_mask: Optional[torch.Tensor] = None, - tgt_key_padding_mask: Optional[torch.Tensor] = None, - memory_key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Pass the inputs (and mask) through the decoder layer. - - Args: - tgt: - the sequence to the decoder layer (required). - memory: - the sequence from the last layer of the encoder (required). - tgt_mask: - the mask for the tgt sequence (optional). - memory_mask: - the mask for the memory sequence (optional). - tgt_key_padding_mask: - the mask for the tgt keys per batch (optional). - memory_key_padding_mask: - the mask for the memory keys per batch (optional). - - Shape: - tgt: (T, N, E). - memory: (S, N, E). - tgt_mask: (T, T). - memory_mask: (T, S). - tgt_key_padding_mask: (N, T). - memory_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, - N is the batch size, E is the feature number - """ - residual = tgt - if self.normalize_before: - tgt = self.norm1(tgt) - tgt2 = self.self_attn( - tgt, - tgt, - tgt, - attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask, - )[0] - tgt = residual + self.dropout1(tgt2) - if not self.normalize_before: - tgt = self.norm1(tgt) - - residual = tgt - if self.normalize_before: - tgt = self.norm2(tgt) - tgt2 = self.src_attn( - tgt, - memory, - memory, - attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask, - )[0] - tgt = residual + self.dropout2(tgt2) - if not self.normalize_before: - tgt = self.norm2(tgt) - - residual = tgt - if self.normalize_before: - tgt = self.norm3(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) - tgt = residual + self.dropout3(tgt2) - if not self.normalize_before: - tgt = self.norm3(tgt) - return tgt - - -def _get_activation_fn(activation: str): - if activation == "relu": - return nn.functional.relu - elif activation == "gelu": - return nn.functional.gelu - - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) - - -class PositionalEncoding(nn.Module): - """This class implements the positional encoding - proposed in the following paper: - - - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf - - PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) - PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) - - Note:: - - 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) - = exp(-1* 2i / d_model * log(100000)) - = exp(2i * -(log(10000) / d_model)) - """ - - def __init__(self, d_model: int, dropout: float = 0.1) -> None: - """ - Args: - d_model: - Embedding dimension. - dropout: - Dropout probability to be applied to the output of this module. - """ - super().__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = nn.Dropout(p=dropout) - # not doing: self.pe = None because of errors thrown by torchscript - self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) - - def extend_pe(self, x: torch.Tensor) -> None: - """Extend the time t in the positional encoding if required. - - The shape of `self.pe` is (1, T1, d_model). The shape of the input x - is (N, T, d_model). If T > T1, then we change the shape of self.pe - to (N, T, d_model). Otherwise, nothing is done. - - Args: - x: - It is a tensor of shape (N, T, C). - Returns: - Return None. - """ - if self.pe is not None: - if self.pe.size(1) >= x.size(1): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - # Now pe is of shape (1, T, d_model), where T is x.size(1) - self.pe = pe.to(device=x.device, dtype=x.dtype) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Add positional encoding. - - Args: - x: - Its shape is (N, T, C) - - Returns: - Return a tensor of shape (N, T, C) - """ - self.extend_pe(x) - x = x * self.xscale + self.pe[:, : x.size(1), :] - return self.dropout(x) - - -class Noam(object): - """ - Implements Noam optimizer. - - Proposed in - "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf - - Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa - - Args: - params: - iterable of parameters to optimize or dicts defining parameter groups - model_size: - attention dimension of the transformer model - factor: - learning rate factor - warm_step: - warmup steps - """ - - def __init__( - self, - params, - model_size: int = 256, - factor: float = 10.0, - warm_step: int = 25000, - weight_decay=0, - ) -> None: - """Construct an Noam object.""" - self.optimizer = torch.optim.Adam( - params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay - ) - self._step = 0 - self.warmup = warm_step - self.factor = factor - self.model_size = model_size - self._rate = 0 - - @property - def param_groups(self): - """Return param_groups.""" - return self.optimizer.param_groups - - def step(self): - """Update parameters and rate.""" - self._step += 1 - rate = self.rate() - for p in self.optimizer.param_groups: - p["lr"] = rate - self._rate = rate - self.optimizer.step() - - def rate(self, step=None): - """Implement `lrate` above.""" - if step is None: - step = self._step - return ( - self.factor - * self.model_size ** (-0.5) - * min(step ** (-0.5), step * self.warmup ** (-1.5)) - ) - - def zero_grad(self): - """Reset gradient.""" - self.optimizer.zero_grad() - - def state_dict(self): - """Return state_dict.""" - return { - "_step": self._step, - "warmup": self.warmup, - "factor": self.factor, - "model_size": self.model_size, - "_rate": self._rate, - "optimizer": self.optimizer.state_dict(), - } - - def load_state_dict(self, state_dict): - """Load state_dict.""" - for key, value in state_dict.items(): - if key == "optimizer": - self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - setattr(self, key, value) - - -def encoder_padding_mask( - max_len: int, supervisions: Optional[Supervisions] = None -) -> Optional[torch.Tensor]: - """Make mask tensor containing indexes of padded part. - - TODO:: - This function **assumes** that the model uses - a subsampling factor of 4. We should remove that - assumption later. - - Args: - max_len: - Maximum length of input features. - CAUTION: It is the length after subsampling. - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - (CAUTION: It contains length information, i.e., start and number of - frames, before subsampling) - - Returns: - Tensor: Mask tensor of dimension (batch_size, input_length), - True denote the masked indices. - """ - if supervisions is None: - return None - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"], - supervisions["num_frames"], - ), - 1, - ).to(torch.int32) - - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] - for idx in range(supervision_segments.size(0)): - # Note: TorchScript doesn't allow to unpack tensors as tuples - sequence_idx = supervision_segments[idx, 0].item() - start_frame = supervision_segments[idx, 1].item() - num_frames = supervision_segments[idx, 2].item() - lengths[sequence_idx] = start_frame + num_frames - - lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] - bs = int(len(lengths)) - seq_range = torch.arange(0, max_len, dtype=torch.int64) - seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) - # Note: TorchScript doesn't implement Tensor.new() - seq_length_expand = torch.tensor( - lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype - ).unsqueeze(-1) - mask = seq_range_expand >= seq_length_expand - - return mask - - -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: - """Generate a length mask for input. - - The masked position are filled with True, - Unmasked positions are filled with False. - - Args: - ys_pad: - padded tensor of dimension (batch_size, input_length). - ignore_id: - the ignored number (the padding number) in ys_pad - - Returns: - Tensor: - a bool tensor of the same shape as the input tensor. - """ - ys_mask = ys_pad == ignore_id - return ys_mask - - -def generate_square_subsequent_mask(sz: int) -> torch.Tensor: - """Generate a square mask for the sequence. The masked positions are - filled with float('-inf'). Unmasked positions are filled with float(0.0). - The mask can be used for masked self-attention. - - For instance, if sz is 3, it returns:: - - tensor([[0., -inf, -inf], - [0., 0., -inf], - [0., 0., 0]]) - - Args: - sz: mask size - - Returns: - A square mask of dimension (sz, sz) - """ - mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) - mask = ( - mask.float() - .masked_fill(mask == 0, float("-inf")) - .masked_fill(mask == 1, float(0.0)) - ) - return mask - - -def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: - """Prepend sos_id to each utterance. - - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - sos_id: - The ID of the SOS token. - - Return: - Return a new list-of-list, where each sublist starts - with SOS ID. - """ - return [[sos_id] + utt for utt in token_ids] - - -def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: - """Append eos_id to each utterance. - - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - eos_id: - The ID of the EOS token. - - Return: - Return a new list-of-list, where each sublist ends - with EOS ID. - """ - return [utt + [eos_id] for utt in token_ids] - - -def tolist(t: torch.Tensor) -> List[int]: - """Used by jit""" - return torch.jit.annotate(List[int], t.tolist()) From 903e37e7861053e434abf027e9ed0037ec62ebcd Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Tue, 26 Sep 2023 16:54:04 +0800 Subject: [PATCH 46/60] replaced several scripts with symlinks --- egs/swbd/ASR/conformer_ctc/conformer.py | 1 + egs/swbd/ASR/conformer_ctc/label_smoothing.py | 1 + egs/swbd/ASR/conformer_ctc/subsampling.py | 1 + egs/swbd/ASR/conformer_ctc/test_transformer.py | 1 + egs/swbd/ASR/conformer_ctc/transformer.py | 1 + 5 files changed, 5 insertions(+) create mode 120000 egs/swbd/ASR/conformer_ctc/conformer.py create mode 120000 egs/swbd/ASR/conformer_ctc/label_smoothing.py create mode 120000 egs/swbd/ASR/conformer_ctc/subsampling.py create mode 120000 egs/swbd/ASR/conformer_ctc/test_transformer.py create mode 120000 egs/swbd/ASR/conformer_ctc/transformer.py diff --git a/egs/swbd/ASR/conformer_ctc/conformer.py b/egs/swbd/ASR/conformer_ctc/conformer.py new file mode 120000 index 0000000000..d1f4209d72 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/conformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/conformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/label_smoothing.py b/egs/swbd/ASR/conformer_ctc/label_smoothing.py new file mode 120000 index 0000000000..e9d239fffb --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/subsampling.py b/egs/swbd/ASR/conformer_ctc/subsampling.py new file mode 120000 index 0000000000..16354dc738 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/test_transformer.py b/egs/swbd/ASR/conformer_ctc/test_transformer.py new file mode 120000 index 0000000000..8b0990ec62 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/test_transformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/test_transformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/transformer.py b/egs/swbd/ASR/conformer_ctc/transformer.py new file mode 120000 index 0000000000..1c3f43fcf7 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/transformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/transformer.py \ No newline at end of file From 01b52165a107081c0ab2475d3442d08981446e77 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Tue, 26 Sep 2023 16:57:40 +0800 Subject: [PATCH 47/60] removed scripts existing in other recipes --- egs/swbd/ASR/local/compile_hlg.py | 167 ------- egs/swbd/ASR/local/compile_lg.py | 147 ------- egs/swbd/ASR/local/prepare_lang.py | 413 ------------------ .../ASR/local/prepare_lm_training_data.py | 167 ------- egs/swbd/ASR/local/validate_bpe_lexicon.py | 77 ---- 5 files changed, 971 deletions(-) delete mode 100755 egs/swbd/ASR/local/compile_hlg.py delete mode 100755 egs/swbd/ASR/local/compile_lg.py delete mode 100755 egs/swbd/ASR/local/prepare_lang.py delete mode 100755 egs/swbd/ASR/local/prepare_lm_training_data.py delete mode 100755 egs/swbd/ASR/local/validate_bpe_lexicon.py diff --git a/egs/swbd/ASR/local/compile_hlg.py b/egs/swbd/ASR/local/compile_hlg.py deleted file mode 100755 index d19d50ae63..0000000000 --- a/egs/swbd/ASR/local/compile_hlg.py +++ /dev/null @@ -1,167 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input lang_dir and generates HLG from - - - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt - - L, the lexicon, built from lang_dir/L_disambig.pt - - Caution: We use a lexicon that contains disambiguation symbols - - - G, the LM, built from data/lm/G_n_gram.fst.txt - -The generated HLG is saved in $lang_dir/HLG.pt -""" -import argparse -import logging -from pathlib import Path - -import k2 -import torch - -from icefall.lexicon import Lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lm", - type=str, - default="G_3_gram", - help="""Stem name for LM used in HLG compiling. - """, - ) - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - """, - ) - - return parser.parse_args() - - -def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: - """ - Args: - lang_dir: - The language directory, e.g., data/lang_phone or data/lang_bpe_5000. - lm: - The language stem base name. - - Return: - An FSA representing HLG. - """ - lexicon = Lexicon(lang_dir) - max_token_id = max(lexicon.tokens) - logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") - H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) - - if Path(f"data/lm/{lm}.pt").is_file(): - logging.info(f"Loading pre-compiled {lm}") - d = torch.load(f"data/lm/{lm}.pt") - G = k2.Fsa.from_dict(d) - else: - logging.info(f"Loading {lm}.fst.txt") - with open(f"data/lm/{lm}.fst.txt") as f: - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - torch.save(G.as_dict(), f"data/lm/{lm}.pt") - - first_token_disambig_id = lexicon.token_table["#0"] - first_word_disambig_id = lexicon.word_table["#0"] - - L = k2.arc_sort(L) - G = k2.arc_sort(G) - - logging.info("Intersecting L and G") - LG = k2.compose(L, G) - logging.info(f"LG shape: {LG.shape}") - - logging.info("Connecting LG") - LG = k2.connect(LG) - logging.info(f"LG shape after k2.connect: {LG.shape}") - - logging.info(type(LG.aux_labels)) - logging.info("Determinizing LG") - - LG = k2.determinize(LG) - logging.info(type(LG.aux_labels)) - - logging.info("Connecting LG after k2.determinize") - LG = k2.connect(LG) - - logging.info("Removing disambiguation symbols on LG") - - # LG.labels[LG.labels >= first_token_disambig_id] = 0 - # see https://github.com/k2-fsa/k2/pull/1140 - labels = LG.labels - labels[labels >= first_token_disambig_id] = 0 - LG.labels = labels - - assert isinstance(LG.aux_labels, k2.RaggedTensor) - LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 - - LG = k2.remove_epsilon(LG) - logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") - - LG = k2.connect(LG) - LG.aux_labels = LG.aux_labels.remove_values_eq(0) - - logging.info("Arc sorting LG") - LG = k2.arc_sort(LG) - - logging.info("Composing H and LG") - # CAUTION: The name of the inner_labels is fixed - # to `tokens`. If you want to change it, please - # also change other places in icefall that are using - # it. - HLG = k2.compose(H, LG, inner_labels="tokens") - - logging.info("Connecting LG") - HLG = k2.connect(HLG) - - logging.info("Arc sorting LG") - HLG = k2.arc_sort(HLG) - logging.info(f"HLG.shape: {HLG.shape}") - - return HLG - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - - if (lang_dir / "HLG.pt").is_file(): - logging.info(f"{lang_dir}/HLG.pt already exists - skipping") - return - - logging.info(f"Processing {lang_dir}") - - HLG = compile_HLG(lang_dir, args.lm) - logging.info(f"Saving HLG.pt to {lang_dir}") - torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/swbd/ASR/local/compile_lg.py b/egs/swbd/ASR/local/compile_lg.py deleted file mode 100755 index 709b140702..0000000000 --- a/egs/swbd/ASR/local/compile_lg.py +++ /dev/null @@ -1,147 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input lang_dir and generates LG from - - - L, the lexicon, built from lang_dir/L_disambig.pt - - Caution: We use a lexicon that contains disambiguation symbols - - - G, the LM, built from data/lm/G_3_gram.fst.txt - -The generated LG is saved in $lang_dir/LG.pt -""" -import argparse -import logging -from pathlib import Path - -import k2 -import torch - -from icefall.lexicon import Lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - """, - ) - parser.add_argument( - "--lm", - type=str, - default="G_3_gram", - help="""Stem name for LM used in HLG compiling. - """, - ) - - return parser.parse_args() - - -def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: - """ - Args: - lang_dir: - The language directory, e.g., data/lang_phone or data/lang_bpe_5000. - - Return: - An FSA representing LG. - """ - lexicon = Lexicon(lang_dir) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) - - if Path(f"data/lm/{lm}.pt").is_file(): - logging.info(f"Loading pre-compiled {lm}") - d = torch.load(f"data/lm/{lm}.pt") - G = k2.Fsa.from_dict(d) - else: - logging.info(f"Loading {lm}.fst.txt") - with open(f"data/lm/{lm}.fst.txt") as f: - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - torch.save(G.as_dict(), f"data/lm/{lm}.pt") - - first_token_disambig_id = lexicon.token_table["#0"] - first_word_disambig_id = lexicon.word_table["#0"] - - L = k2.arc_sort(L) - G = k2.arc_sort(G) - - logging.info("Intersecting L and G") - LG = k2.compose(L, G) - logging.info(f"LG shape: {LG.shape}") - - logging.info("Connecting LG") - LG = k2.connect(LG) - logging.info(f"LG shape after k2.connect: {LG.shape}") - - logging.info(type(LG.aux_labels)) - logging.info("Determinizing LG") - - LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing) - logging.info(type(LG.aux_labels)) - - logging.info("Connecting LG after k2.determinize") - LG = k2.connect(LG) - - logging.info("Removing disambiguation symbols on LG") - - # LG.labels[LG.labels >= first_token_disambig_id] = 0 - # see https://github.com/k2-fsa/k2/pull/1140 - labels = LG.labels - labels[labels >= first_token_disambig_id] = 0 - LG.labels = labels - - assert isinstance(LG.aux_labels, k2.RaggedTensor) - LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 - - LG = k2.remove_epsilon(LG) - logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") - - LG = k2.connect(LG) - LG.aux_labels = LG.aux_labels.remove_values_eq(0) - - logging.info("Arc sorting LG") - LG = k2.arc_sort(LG) - - return LG - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - - if (lang_dir / "LG.pt").is_file(): - logging.info(f"{lang_dir}/LG.pt already exists - skipping") - return - - logging.info(f"Processing {lang_dir}") - - LG = compile_LG(lang_dir, args.lm) - logging.info(f"Saving LG.pt to {lang_dir}") - torch.save(LG.as_dict(), f"{lang_dir}/LG.pt") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/swbd/ASR/local/prepare_lang.py b/egs/swbd/ASR/local/prepare_lang.py deleted file mode 100755 index d913756a19..0000000000 --- a/egs/swbd/ASR/local/prepare_lang.py +++ /dev/null @@ -1,413 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input a lexicon file "data/lang_phone/lexicon.txt" -consisting of words and tokens (i.e., phones) and does the following: - -1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt - -2. Generate tokens.txt, the token table mapping a token to a unique integer. - -3. Generate words.txt, the word table mapping a word to a unique integer. - -4. Generate L.pt, in k2 format. It can be loaded by - - d = torch.load("L.pt") - lexicon = k2.Fsa.from_dict(d) - -5. Generate L_disambig.pt, in k2 format. -""" -import argparse -import math -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, List, Tuple - -import k2 -import torch - -from icefall.lexicon import read_lexicon, write_lexicon -from icefall.utils import str2bool - -Lexicon = List[Tuple[str, List[str]]] - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain a file lexicon.txt. - Generated files by this script are saved into this directory. - """, - ) - - parser.add_argument( - "--debug", - type=str2bool, - default=False, - help="""True for debugging, which will generate - a visualization of the lexicon FST. - - Caution: If your lexicon contains hundreds of thousands - of lines, please set it to False! - """, - ) - - return parser.parse_args() - - -def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: - """Write a symbol to ID mapping to a file. - - Note: - No need to implement `read_mapping` as it can be done - through :func:`k2.SymbolTable.from_file`. - - Args: - filename: - Filename to save the mapping. - sym2id: - A dict mapping symbols to IDs. - Returns: - Return None. - """ - with open(filename, "w", encoding="utf-8") as f: - for sym, i in sym2id.items(): - f.write(f"{sym} {i}\n") - - -def get_tokens(lexicon: Lexicon) -> List[str]: - """Get tokens from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique tokens. - """ - ans = set() - for _, tokens in lexicon: - ans.update(tokens) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def get_words(lexicon: Lexicon) -> List[str]: - """Get words from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique words. - """ - ans = set() - for word, _ in lexicon: - ans.add(word) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: - """It adds pseudo-token disambiguation symbols #1, #2 and so on - at the ends of tokens to ensure that all pronunciations are different, - and that none is a prefix of another. - - See also add_lex_disambig.pl from kaldi. - - Args: - lexicon: - It is returned by :func:`read_lexicon`. - Returns: - Return a tuple with two elements: - - - The output lexicon with disambiguation symbols - - The ID of the max disambiguation symbol that appears - in the lexicon - """ - - # (1) Work out the count of each token-sequence in the - # lexicon. - count = defaultdict(int) - for _, tokens in lexicon: - count[" ".join(tokens)] += 1 - - # (2) For each left sub-sequence of each token-sequence, note down - # that it exists (for identifying prefixes of longer strings). - issubseq = defaultdict(int) - for _, tokens in lexicon: - tokens = tokens.copy() - tokens.pop() - while tokens: - issubseq[" ".join(tokens)] = 1 - tokens.pop() - - # (3) For each entry in the lexicon: - # if the token sequence is unique and is not a - # prefix of another word, no disambig symbol. - # Else output #1, or #2, #3, ... if the same token-seq - # has already been assigned a disambig symbol. - ans = [] - - # We start with #1 since #0 has its own purpose - first_allowed_disambig = 1 - max_disambig = first_allowed_disambig - 1 - last_used_disambig_symbol_of = defaultdict(int) - - for word, tokens in lexicon: - tokenseq = " ".join(tokens) - assert tokenseq != "" - if issubseq[tokenseq] == 0 and count[tokenseq] == 1: - ans.append((word, tokens)) - continue - - cur_disambig = last_used_disambig_symbol_of[tokenseq] - if cur_disambig == 0: - cur_disambig = first_allowed_disambig - else: - cur_disambig += 1 - - if cur_disambig > max_disambig: - max_disambig = cur_disambig - last_used_disambig_symbol_of[tokenseq] = cur_disambig - tokenseq += f" #{cur_disambig}" - ans.append((word, tokenseq.split())) - return ans, max_disambig - - -def generate_id_map(symbols: List[str]) -> Dict[str, int]: - """Generate ID maps, i.e., map a symbol to a unique ID. - - Args: - symbols: - A list of unique symbols. - Returns: - A dict containing the mapping between symbols and IDs. - """ - return {sym: i for i, sym in enumerate(symbols)} - - -def add_self_loops( - arcs: List[List[Any]], disambig_token: int, disambig_word: int -) -> List[List[Any]]: - """Adds self-loops to states of an FST to propagate disambiguation symbols - through it. They are added on each state with non-epsilon output symbols - on at least one arc out of the state. - - See also fstaddselfloops.pl from Kaldi. One difference is that - Kaldi uses OpenFst style FSTs and it has multiple final states. - This function uses k2 style FSTs and it does not need to add self-loops - to the final state. - - The input label of a self-loop is `disambig_token`, while the output - label is `disambig_word`. - - Args: - arcs: - A list-of-list. The sublist contains - `[src_state, dest_state, label, aux_label, score]` - disambig_token: - It is the token ID of the symbol `#0`. - disambig_word: - It is the word ID of the symbol `#0`. - - Return: - Return new `arcs` containing self-loops. - """ - states_needs_self_loops = set() - for arc in arcs: - src, dst, ilabel, olabel, score = arc - if olabel != 0: - states_needs_self_loops.add(src) - - ans = [] - for s in states_needs_self_loops: - ans.append([s, s, disambig_token, disambig_word, 0]) - - return arcs + ans - - -def lexicon_to_fst( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - sil_token: str = "SIL", - sil_prob: float = 0.5, - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format) with optional silence at - the beginning and end of each word. - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - sil_token: - The silence token. - sil_prob: - The probability for adding a silence at the beginning and end - of the word. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - assert sil_prob > 0.0 and sil_prob < 1.0 - # CAUTION: we use score, i.e, negative cost. - sil_score = math.log(sil_prob) - no_sil_score = math.log(1.0 - sil_prob) - - start_state = 0 - loop_state = 1 # words enter and leave from here - sil_state = 2 # words terminate here when followed by silence; this state - # has a silence transition to loop_state. - next_state = 3 # the next un-allocated state, will be incremented as we go. - arcs = [] - - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - sil_token = token2id[sil_token] - - arcs.append([start_state, loop_state, eps, eps, no_sil_score]) - arcs.append([start_state, sil_state, eps, eps, sil_score]) - arcs.append([sil_state, loop_state, sil_token, eps, 0]) - - for word, tokens in lexicon: - assert len(tokens) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - tokens = [token2id[i] for i in tokens] - - for i in range(len(tokens) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, tokens[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last token of this word - # It has two out-going arcs, one to the loop state, - # the other one to the sil_state. - i = len(tokens) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) - arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - lexicon_filename = lang_dir / "lexicon.txt" - sil_token = "SIL" - sil_prob = 0.5 - - lexicon = read_lexicon(lexicon_filename) - tokens = get_tokens(lexicon) - words = get_words(lexicon) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in tokens - tokens.append(f"#{i}") - - assert "" not in tokens - tokens = [""] + tokens - - assert "" not in words - assert "#0" not in words - assert "" not in words - assert "" not in words - - words = [""] + words + ["#0", "", ""] - - token2id = generate_id_map(tokens) - word2id = generate_id_map(words) - - write_mapping(lang_dir / "tokens.txt", token2id) - write_mapping(lang_dir / "words.txt", word2id) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst( - lexicon, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - ) - - L_disambig = lexicon_to_fst( - lexicon_disambig, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - if args.debug: - labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") - aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") - - L.labels_sym = labels_sym - L.aux_labels_sym = aux_labels_sym - L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") - - L_disambig.labels_sym = labels_sym - L_disambig.aux_labels_sym = aux_labels_sym - L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/swbd/ASR/local/prepare_lm_training_data.py b/egs/swbd/ASR/local/prepare_lm_training_data.py deleted file mode 100755 index 70343fef7c..0000000000 --- a/egs/swbd/ASR/local/prepare_lm_training_data.py +++ /dev/null @@ -1,167 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey -# Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script takes a `bpe.model` and a text file such as -./download/lm/librispeech-lm-norm.txt -and outputs the LM training data to a supplied directory such -as data/lm_training_bpe_500. The format is as follows: - -It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a -representation of a dict with the following format: - - 'words' -> a k2.RaggedTensor of two axes [word][token] with dtype torch.int32 - containing the BPE representations of each word, indexed by - integer word ID. (These integer word IDS are present in - 'lm_data'). The sentencepiece object can be used to turn the - words and BPE units into string form. - 'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype - torch.int32 containing all the sentences, as word-ids (we don't - output the string form of this directly but it can be worked out - together with 'words' and the bpe.model). - 'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing - number of BPE tokens of each sentence. -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import sentencepiece as spm -import torch - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--bpe-model", - type=str, - help="Input BPE model, e.g. data/bpe_500/bpe.model", - ) - parser.add_argument( - "--lm-data", - type=str, - help="""Input LM training data as text, e.g. - download/pb.train.txt""", - ) - parser.add_argument( - "--lm-archive", - type=str, - help="""Path to output archive, e.g. data/bpe_500/lm_data.pt; - look at the source of this script to see the format.""", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - - if Path(args.lm_archive).exists(): - logging.warning(f"{args.lm_archive} exists - skipping") - return - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - - # word2index is a dictionary from words to integer ids. No need to reserve - # space for epsilon, etc.; the words are just used as a convenient way to - # compress the sequences of BPE pieces. - word2index = dict() - - word2bpe = [] # Will be a list-of-list-of-int, representing BPE pieces. - sentences = [] # Will be a list-of-list-of-int, representing word-ids. - - if "librispeech-lm-norm" in args.lm_data: - num_lines_in_total = 40418261.0 - step = 5000000 - elif "valid" in args.lm_data: - num_lines_in_total = 5567.0 - step = 3000 - elif "test" in args.lm_data: - num_lines_in_total = 5559.0 - step = 3000 - else: - num_lines_in_total = None - step = None - - processed = 0 - - with open(args.lm_data) as f: - while True: - line = f.readline() - if line == "": - break - - if step and processed % step == 0: - logging.info( - f"Processed number of lines: {processed} " - f"({processed/num_lines_in_total*100: .3f}%)" - ) - processed += 1 - - line_words = line.split() - for w in line_words: - if w not in word2index: - w_bpe = sp.encode(w) - word2index[w] = len(word2bpe) - word2bpe.append(w_bpe) - sentences.append([word2index[w] for w in line_words]) - - logging.info("Constructing ragged tensors") - words = k2.ragged.RaggedTensor(word2bpe) - sentences = k2.ragged.RaggedTensor(sentences) - - output = dict(words=words, sentences=sentences) - - num_sentences = sentences.dim0 - logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}") - sentence_lengths = [0] * num_sentences - for i in range(num_sentences): - if step and i % step == 0: - logging.info( - f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)" - ) - - word_ids = sentences[i] - - # NOTE: If word_ids is a tensor with only 1 entry, - # token_ids is a torch.Tensor - token_ids = words[word_ids] - if isinstance(token_ids, k2.RaggedTensor): - token_ids = token_ids.values - - # token_ids is a 1-D tensor containing the BPE tokens - # of the current sentence - - sentence_lengths[i] = token_ids.numel() - - output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) - - torch.save(output, args.lm_archive) - logging.info(f"Saved to {args.lm_archive}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/swbd/ASR/local/validate_bpe_lexicon.py b/egs/swbd/ASR/local/validate_bpe_lexicon.py deleted file mode 100755 index c542f2fabd..0000000000 --- a/egs/swbd/ASR/local/validate_bpe_lexicon.py +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script checks that there are no OOV tokens in the BPE-based lexicon. - -Usage example: - - python3 ./local/validate_bpe_lexicon.py \ - --lexicon /path/to/lexicon.txt \ - --bpe-model /path/to/bpe.model -""" - -import argparse -from pathlib import Path -from typing import List, Tuple - -import sentencepiece as spm - -from icefall.lexicon import read_lexicon - -# Map word to word pieces -Lexicon = List[Tuple[str, List[str]]] - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--lexicon", - required=True, - type=Path, - help="Path to lexicon.txt", - ) - - parser.add_argument( - "--bpe-model", - required=True, - type=Path, - help="Path to bpe.model", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - assert args.lexicon.is_file(), args.lexicon - assert args.bpe_model.is_file(), args.bpe_model - - lexicon = read_lexicon(args.lexicon) - - sp = spm.SentencePieceProcessor() - sp.load(str(args.bpe_model)) - - word_pieces = set(sp.id_to_piece(list(range(sp.vocab_size())))) - for word, pieces in lexicon: - for p in pieces: - if p not in word_pieces: - raise ValueError(f"The word {word} contains an OOV token {p}") - - -if __name__ == "__main__": - main() From 29a02f3fc538c4e9390999f69caa2b2c71c00e5e Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Tue, 26 Sep 2023 17:00:11 +0800 Subject: [PATCH 48/60] replaced with symlinks --- egs/swbd/ASR/local/compile_hlg.py | 1 + egs/swbd/ASR/local/compile_lg.py | 1 + egs/swbd/ASR/local/prepare_lang.py | 1 + egs/swbd/ASR/local/prepare_lm_training_data.py | 1 + egs/swbd/ASR/local/validate_bpe_lexicon.py | 1 + 5 files changed, 5 insertions(+) create mode 120000 egs/swbd/ASR/local/compile_hlg.py create mode 120000 egs/swbd/ASR/local/compile_lg.py create mode 120000 egs/swbd/ASR/local/prepare_lang.py create mode 120000 egs/swbd/ASR/local/prepare_lm_training_data.py create mode 120000 egs/swbd/ASR/local/validate_bpe_lexicon.py diff --git a/egs/swbd/ASR/local/compile_hlg.py b/egs/swbd/ASR/local/compile_hlg.py new file mode 120000 index 0000000000..471aa7fb40 --- /dev/null +++ b/egs/swbd/ASR/local/compile_hlg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/compile_lg.py b/egs/swbd/ASR/local/compile_lg.py new file mode 120000 index 0000000000..462d6d3fb9 --- /dev/null +++ b/egs/swbd/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/prepare_lang.py b/egs/swbd/ASR/local/prepare_lang.py new file mode 120000 index 0000000000..747f2ab398 --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/prepare_lm_training_data.py b/egs/swbd/ASR/local/prepare_lm_training_data.py new file mode 120000 index 0000000000..abc00d421f --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lm_training_data.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/validate_bpe_lexicon.py b/egs/swbd/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 0000000000..721bb48e7c --- /dev/null +++ b/egs/swbd/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file From 56b70728330a891c68d9b2e542008e4fd22edc3b Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 26 Sep 2023 17:01:20 +0800 Subject: [PATCH 49/60] Update egs/swbd/ASR/RESULTS.md Co-authored-by: Fangjun Kuang --- egs/swbd/ASR/RESULTS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/swbd/ASR/RESULTS.md b/egs/swbd/ASR/RESULTS.md index d285261acc..f3a22c4448 100644 --- a/egs/swbd/ASR/RESULTS.md +++ b/egs/swbd/ASR/RESULTS.md @@ -5,7 +5,7 @@ The best WER, as of 2023-09-04, for the Switchboard is below -Results using attention decoder is given as: +Results using attention decoder are given as: | | eval2000-swbd | eval2000-callhome | eval2000-avg | |--------------------------------|-----------------|---------------------|--------------| From d981e82ea56594a38aaae6eb7128c09e6801d02f Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Tue, 26 Sep 2023 17:10:22 +0800 Subject: [PATCH 50/60] added CI test for the swbd recipe --- .../run-swbd-conformer-ctc-2023-08-26.sh | 41 +++++++++ .github/workflows/run-swbd-conformer-ctc.yml | 84 +++++++++++++++++++ 2 files changed, 125 insertions(+) create mode 100644 .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh create mode 100644 .github/workflows/run-swbd-conformer-ctc.yml diff --git a/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh new file mode 100644 index 0000000000..ef9cadcff8 --- /dev/null +++ b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/swbd/ASR + +repo_url=https://huggingface.co/zrjin/icefall-asr-swbd-conformer-ctc-2023-8-26 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +ln -s epoch-98.pt epoch-99.pt +popd + +ls -lh $repo/exp/*.pt + +for method in ctc-decoding attention-decoder; do + log "$method" + + ./conformer_ctc/pretrained.py \ + --method $method \ + --checkpoint $repo/exp/epoch-99.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done diff --git a/.github/workflows/run-swbd-conformer-ctc.yml b/.github/workflows/run-swbd-conformer-ctc.yml new file mode 100644 index 0000000000..842691d38c --- /dev/null +++ b/.github/workflows/run-swbd-conformer-ctc.yml @@ -0,0 +1,84 @@ +# Copyright 2023 Xiaomi Corp. (author: Zengrui Jin) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-swbd-conformer_ctc + +on: + push: + branches: + - master + pull_request: + types: [labeled] + +concurrency: + group: run-swbd-conformer_ctc-${{ github.ref }} + cancel-in-progress: true + +jobs: + run-swbd-conformer_ctc: + if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'swbd' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2023-05-22 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + sudo apt-get -qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh From 71a860d9db94718eaa6f95863ee8078e81fa3cb6 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Tue, 26 Sep 2023 17:23:40 +0800 Subject: [PATCH 51/60] Update run-swbd-conformer-ctc-2023-08-26.sh --- .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh diff --git a/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh old mode 100644 new mode 100755 From ebf0e5acc00310263a57bac849a27ff4b2212f6c Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Tue, 26 Sep 2023 17:33:01 +0800 Subject: [PATCH 52/60] Delete pretrained.py --- egs/swbd/ASR/conformer_ctc/pretrained.py | 430 ----------------------- 1 file changed, 430 deletions(-) delete mode 100755 egs/swbd/ASR/conformer_ctc/pretrained.py diff --git a/egs/swbd/ASR/conformer_ctc/pretrained.py b/egs/swbd/ASR/conformer_ctc/pretrained.py deleted file mode 100755 index 30def9c404..0000000000 --- a/egs/swbd/ASR/conformer_ctc/pretrained.py +++ /dev/null @@ -1,430 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from conformer import Conformer -from torch.nn.utils.rnn import pad_sequence - -from icefall.decode import ( - get_lattice, - one_best_decoding, - rescore_with_attention_decoder, - rescore_with_whole_lattice, -) -from icefall.utils import AttributeDict, get_texts - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--words-file", - type=str, - help="""Path to words.txt. - Used only when method is not ctc-decoding. - """, - ) - - parser.add_argument( - "--HLG", - type=str, - help="""Path to HLG.pt. - Used only when method is not ctc-decoding. - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, - ) - - parser.add_argument( - "--method", - type=str, - default="1best", - help="""Decoding method. - Possible values are: - (0) ctc-decoding - Use CTC decoding. It uses a sentence - piece model, i.e., lang_dir/bpe.model, to convert - word pieces to words. It needs neither a lexicon - nor an n-gram LM. - (1) 1best - Use the best path as decoding output. Only - the transformer encoder output is used for decoding. - We call it HLG decoding. - (2) whole-lattice-rescoring - Use an LM to rescore the - decoding lattice and then use 1best to decode the - rescored lattice. - We call it HLG decoding + n-gram LM rescoring. - (3) attention-decoder - Extract n paths from the rescored - lattice and use the transformer attention decoder for - rescoring. - We call it HLG decoding + n-gram LM rescoring + attention - decoder rescoring. - """, - ) - - parser.add_argument( - "--G", - type=str, - help="""An LM for rescoring. - Used only when method is - whole-lattice-rescoring or attention-decoder. - It's usually a 4-gram LM. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help=""" - Used only when method is attention-decoder. - It specifies the size of n-best list.""", - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=1.3, - help=""" - Used only when method is whole-lattice-rescoring and attention-decoder. - It specifies the scale for n-gram LM scores. - (Note: You need to tune it on a dataset.) - """, - ) - - parser.add_argument( - "--attention-decoder-scale", - type=float, - default=1.2, - help=""" - Used only when method is attention-decoder. - It specifies the scale for attention decoder scores. - (Note: You need to tune it on a dataset.) - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help=""" - Used only when method is attention-decoder. - It specifies the scale for lattice.scores when - extracting n-best lists. A smaller value results in - more unique number of paths with the risk of missing - the best path. - """, - ) - - parser.add_argument( - "--sos-id", - type=int, - default=1, - help=""" - Used only when method is attention-decoder. - It specifies ID for the SOS token. - """, - ) - - parser.add_argument( - "--num-classes", - type=int, - default=500, - help=""" - Vocab size in the BPE model. - """, - ) - - parser.add_argument( - "--eos-id", - type=int, - default=1, - help=""" - Used only when method is attention-decoder. - It specifies ID for the EOS token. - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "sample_rate": 16000, - # parameters for conformer - "subsampling_factor": 4, - "vgg_frontend": False, - "use_feat_batchnorm": True, - "feature_dim": 80, - "nhead": 8, - "attention_dim": 512, - "num_decoder_layers": 6, - # parameters for decoding - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - if args.method != "attention-decoder": - # to save memory as the attention decoder - # will not be used - params.num_decoder_layers = 0 - - params.update(vars(args)) - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=params.num_classes, - subsampling_factor=params.subsampling_factor, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - # Note: We don't use key padding mask for attention during decoding - with torch.no_grad(): - nnet_output, memory, memory_key_padding_mask = model(features) - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - if params.method == "ctc-decoding": - logging.info("Use CTC decoding") - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.bpe_model) - max_token_id = params.num_classes - 1 - - H = k2.ctc_topo( - max_token=max_token_id, - modified=params.num_classes > 500, - device=device, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=H, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - token_ids = get_texts(best_path) - hyps = bpe_model.decode(token_ids) - hyps = [s.split() for s in hyps] - elif params.method in [ - "1best", - "whole-lattice-rescoring", - "attention-decoder", - ]: - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - if params.method in [ - "whole-lattice-rescoring", - "attention-decoder", - ]: - logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = G.to(device) - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G.lm_scores = G.scores.clone() - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.method == "1best": - logging.info("Use HLG decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - elif params.method == "whole-lattice-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=[params.ngram_lm_scale], - ) - best_path = next(iter(best_path_dict.values())) - elif params.method == "attention-decoder": - logging.info("Use HLG + LM rescoring + attention decoder rescoring") - rescored_lattice = rescore_with_whole_lattice( - lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None - ) - best_path_dict = rescore_with_attention_decoder( - lattice=rescored_lattice, - num_paths=params.num_paths, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=params.sos_id, - eos_id=params.eos_id, - nbest_scale=params.nbest_scale, - ngram_lm_scale=params.ngram_lm_scale, - attention_scale=params.attention_decoder_scale, - ) - best_path = next(iter(best_path_dict.values())) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - else: - raise ValueError(f"Unsupported decoding method: {params.method}") - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() From 191ad9d759823f41f1f4f882cec8f3b6ef2cc8e8 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Tue, 26 Sep 2023 17:33:20 +0800 Subject: [PATCH 53/60] Create pretrained.py --- egs/swbd/ASR/conformer_ctc/pretrained.py | 1 + 1 file changed, 1 insertion(+) create mode 120000 egs/swbd/ASR/conformer_ctc/pretrained.py diff --git a/egs/swbd/ASR/conformer_ctc/pretrained.py b/egs/swbd/ASR/conformer_ctc/pretrained.py new file mode 120000 index 0000000000..04ee4687b3 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/pretrained.py @@ -0,0 +1 @@ +../../librispeech/ASR/conformer_ctc/pretrained.py \ No newline at end of file From 28bcc965391d3f6d4ee0ea0c76c25f18e8fc8a4a Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Tue, 26 Sep 2023 19:38:26 +0800 Subject: [PATCH 54/60] Delete pretrained.py --- egs/swbd/ASR/conformer_ctc/pretrained.py | 1 - 1 file changed, 1 deletion(-) delete mode 120000 egs/swbd/ASR/conformer_ctc/pretrained.py diff --git a/egs/swbd/ASR/conformer_ctc/pretrained.py b/egs/swbd/ASR/conformer_ctc/pretrained.py deleted file mode 120000 index 04ee4687b3..0000000000 --- a/egs/swbd/ASR/conformer_ctc/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../librispeech/ASR/conformer_ctc/pretrained.py \ No newline at end of file From 95866d96becb62afc943e6938e394041417b31a8 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Tue, 26 Sep 2023 19:38:43 +0800 Subject: [PATCH 55/60] Create pretrained.py --- egs/swbd/ASR/conformer_ctc/pretrained.py | 1 + 1 file changed, 1 insertion(+) create mode 120000 egs/swbd/ASR/conformer_ctc/pretrained.py diff --git a/egs/swbd/ASR/conformer_ctc/pretrained.py b/egs/swbd/ASR/conformer_ctc/pretrained.py new file mode 120000 index 0000000000..526bc96787 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/pretrained.py \ No newline at end of file From 078c8722d39a8263dc9d46d79ace82772d7ae9cc Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Tue, 26 Sep 2023 21:26:40 +0800 Subject: [PATCH 56/60] fixed CI test --- .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh index ef9cadcff8..4618039db3 100755 --- a/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh +++ b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh @@ -35,6 +35,7 @@ for method in ctc-decoding attention-decoder; do --method $method \ --checkpoint $repo/exp/epoch-99.pt \ --tokens $repo/data/lang_bpe_500/tokens.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav From b80aae7ab12583e1d9a72edec66199467e9abc01 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Wed, 27 Sep 2023 09:26:42 +0800 Subject: [PATCH 57/60] Update run-swbd-conformer-ctc-2023-08-26.sh --- .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh index 4618039db3..2699fdf14d 100755 --- a/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh +++ b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh @@ -36,6 +36,7 @@ for method in ctc-decoding attention-decoder; do --checkpoint $repo/exp/epoch-99.pt \ --tokens $repo/data/lang_bpe_500/tokens.txt \ --HLG $repo/data/lang_bpe_500/HLG.pt \ + --G $repo/data/lm/G_4_gram.pt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav From d977d47c37a1a3b76b37f164831429f3db0e5c70 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Wed, 27 Sep 2023 09:43:25 +0800 Subject: [PATCH 58/60] removed attention-decoder to fix 143 error --- .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh index 2699fdf14d..4d06132824 100755 --- a/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh +++ b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh @@ -28,7 +28,7 @@ popd ls -lh $repo/exp/*.pt -for method in ctc-decoding attention-decoder; do +for method in ctc-decoding 1best nbest nbest-rescoring; do log "$method" ./conformer_ctc/pretrained.py \ From efd936489ddd321985f088202aaeb856e3db6a5b Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Wed, 27 Sep 2023 09:58:54 +0800 Subject: [PATCH 59/60] Update run-swbd-conformer-ctc-2023-08-26.sh --- .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh index 4d06132824..1702524132 100755 --- a/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh +++ b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh @@ -35,6 +35,7 @@ for method in ctc-decoding 1best nbest nbest-rescoring; do --method $method \ --checkpoint $repo/exp/epoch-99.pt \ --tokens $repo/data/lang_bpe_500/tokens.txt \ + --words-file $repo/data/lang_bpe_500/words.txt \ --HLG $repo/data/lang_bpe_500/HLG.pt \ --G $repo/data/lm/G_4_gram.pt \ $repo/test_wavs/1089-134686-0001.wav \ From 17100985b2f170639fe8ad98a192b43cdddcf0cb Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Wed, 27 Sep 2023 10:03:53 +0800 Subject: [PATCH 60/60] finalize CI test --- .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh index 1702524132..d8cc020e1b 100755 --- a/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh +++ b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh @@ -28,7 +28,7 @@ popd ls -lh $repo/exp/*.pt -for method in ctc-decoding 1best nbest nbest-rescoring; do +for method in ctc-decoding 1best; do log "$method" ./conformer_ctc/pretrained.py \