From 693d84a3011b1bda51ac6f95c3002af93efa772d Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 21 Oct 2024 10:35:26 +0800 Subject: [PATCH] Add Consistency-Regularized CTC (#1766) * support consistency-regularized CTC * update arguments of cr-ctc * set default value of cr_loss_masked_scale to 1.0 * minor fix * refactor codes * update RESULTS.md --- egs/librispeech/ASR/README.md | 8 +- egs/librispeech/ASR/RESULTS.md | 310 +++++++++++++++++++++++++ egs/librispeech/ASR/zipformer/model.py | 123 +++++++++- egs/librispeech/ASR/zipformer/train.py | 95 +++++++- icefall/utils.py | 40 ++++ 5 files changed, 556 insertions(+), 20 deletions(-) diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index 8b87ee19b4..0dbfdc931c 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -50,7 +50,7 @@ We place an additional Conv1d layer right after the input embedding layer. | `conformer-ctc2` | Reworked Conformer | Use auxiliary attention head | | `conformer-ctc3` | Reworked Conformer | Streaming version + delay penalty | | `zipformer-ctc` | Zipformer | Use auxiliary attention head | -| `zipformer` | Upgraded Zipformer | Use auxiliary transducer head / attention-decoder head | The latest recipe | +| `zipformer` | Upgraded Zipformer | Use auxiliary transducer head / attention-decoder head (the latest recipe) | # MMI @@ -58,3 +58,9 @@ We place an additional Conv1d layer right after the input embedding layer. |------------------------------|-----------|---------------------------------------------------| | `conformer-mmi` | Conformer | | | `zipformer-mmi` | Zipformer | CTC warmup + use HP as decoding graph for decoding | + +# CR-CTC + +| | Encoder | Comment | +|------------------------------|--------------------|------------------------------| +| `zipformer` | Upgraded Zipformer | Could also be an auxiliary loss to improve transducer or CTC/AED (the latest recipe) | diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index bc7d8a5efb..6a669f072d 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,315 @@ ## Results +### zipformer (zipformer + pruned-transducer w/ CR-CTC) + +See for more details. + +[zipformer](./zipformer) + +#### Non-streaming + +##### large-scale model, number of model parameters: 148824074, i.e., 148.8 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. 
+ +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| greedy_search | 1.9 | 3.96 | --epoch 50 --avg 26 | +| modified_beam_search | 1.88 | 3.95 | --epoch 50 --avg 26 | + +The training command using 2 80G-A100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +# for non-streaming model training: +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large-cr-ctc-rnnt \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 1 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --ctc-loss-scale 0.1 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.02 \ + --time-mask-ratio 2.5 \ + --full-libri 1 \ + --max-duration 1400 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in greedy_search modified_beam_search; do + ./zipformer/decode.py \ + --epoch 50 \ + --avg 26 \ + --exp-dir zipformer/exp-large-cr-ctc-rnnt \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 1 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --max-duration 300 \ + --decoding-method $m +done +``` + +### zipformer (zipformer + CR-CTC-AED) + +See for more details. + +[zipformer](./zipformer) + +#### Non-streaming + +##### large-scale model, number of model parameters: 174319650, i.e., 174.3 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| attention-decoder-rescoring-no-ngram | 1.96 | 4.08 | --epoch 50 --avg 20 | + +The training command using 2 80G-A100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +# for non-streaming model training: +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large-cr-ctc-aed \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 1 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --ctc-loss-scale 0.1 \ + --attention-decoder-loss-scale 0.9 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.02 \ + --time-mask-ratio 2.5 \ + --full-libri 1 \ + --max-duration 1200 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 20 \ + --exp-dir zipformer/exp-large-cr-ctc-aed/ \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 1 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --max-duration 200 \ + --decoding-method attention-decoder-rescoring-no-ngram +done +``` + +### zipformer (zipformer + CR-CTC) + +See for more details. 
+ +[zipformer](./zipformer) + +#### Non-streaming + +##### small-scale model, number of model parameters: 22118279, i.e., 22.1 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| ctc-greedy-decoding | 2.57 | 5.95 | --epoch 50 --avg 25 | + +The training command using 2 32G-V100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +# for non-streaming model training: +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-small/ \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 \ + --base-lr 0.04 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.2 \ + --time-mask-ratio 2.5 \ + --full-libri 1 \ + --max-duration 850 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in ctc-greedy-search; do + ./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 25 \ + --exp-dir zipformer/exp-small \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 \ + --max-duration 600 \ + --decoding-method $m +done +``` + +##### medium-scale model, number of model parameters: 64250603, i.e., 64.3 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| ctc-greedy-decoding | 2.12 | 4.62 | --epoch 50 --avg 24 | + +The training command using 4 32G-V100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.2 \ + --time-mask-ratio 2.5 \ + --full-libri 1 \ + --max-duration 700 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in ctc-greedy-search; do + ./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 24 \ + --exp-dir zipformer/exp \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --max-duration 600 \ + --decoding-method $m +done +``` + +##### large-scale model, number of model parameters: 147010094, i.e., 147.0 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. 
+ +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| ctc-greedy-decoding | 2.03 | 4.37 | --epoch 50 --avg 26 | + +The training command using 2 80G-A100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +# For non-streaming model training: +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.2 \ + --time-mask-ratio 2.5 \ + --full-libri 1 \ + --max-duration 1400 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in ctc-greedy-search; do + ./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 26 \ + --exp-dir zipformer/exp-large \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --max-duration 600 \ + --decoding-method $m +done +``` + ### zipformer (zipformer + CTC/AED) See for more details. diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index bd1ed26d8d..deebb2a754 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -24,7 +24,8 @@ from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos, make_pad_mask +from icefall.utils import add_sos, make_pad_mask, time_warp +from lhotse.dataset import SpecAugment class AsrModel(nn.Module): @@ -181,6 +182,49 @@ def forward_ctc( ) return ctc_loss + def forward_cr_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute CTC loss with consistency regularization loss. + Args: + encoder_out: + Encoder output, of shape (2 * N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (2 * N,). + targets: + Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. 
+ """ + # Compute CTC loss + ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) + targets=targets.cpu(), + input_lengths=encoder_out_lens.cpu(), + target_lengths=target_lengths.cpu(), + reduction="sum", + ) + + # Compute consistency regularization loss + exchanged_targets = ctc_output.detach().chunk(2, dim=0) + exchanged_targets = torch.cat( + [exchanged_targets[1], exchanged_targets[0]], dim=0 + ) # exchange: [x1, x2] -> [x2, x1] + cr_loss = nn.functional.kl_div( + input=ctc_output, + target=exchanged_targets, + reduction="none", + log_target=True, + ) # (2 * N, T, C) + length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) + cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() + + return ctc_loss, cr_loss + def forward_transducer( self, encoder_out: torch.Tensor, @@ -296,7 +340,12 @@ def forward( prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + use_cr_ctc: bool = False, + use_spec_aug: bool = False, + spec_augment: Optional[SpecAugment] = None, + supervision_segments: Optional[torch.Tensor] = None, + time_warp_factor: Optional[int] = 80, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -316,9 +365,26 @@ def forward( lm_scale: The scale to smooth the loss with lm (output of predictor network) part + use_cr_ctc: + Whether use consistency-regularized CTC. + use_spec_aug: + Whether apply spec-augment manually, used only if use_cr_ctc is True. + spec_augment: + The SpecAugment instance that returns time masks, + used only if use_cr_ctc is True. + supervision_segments: + An int tensor of shape ``(S, 3)``. ``S`` is the number of + supervision segments that exist in ``features``. + Used only if use_cr_ctc is True. + time_warp_factor: + Parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + Used only if use_cr_ctc is True. 
+ Returns: - Return the transducer losses and CTC loss, - in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss) + Return the transducer losses, CTC loss, AED loss, + and consistency-regularization loss in form of + (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss) Note: Regarding am_scale & lm_scale, it will make the loss-function one of @@ -334,6 +400,24 @@ def forward( device = x.device + if use_cr_ctc: + assert self.use_ctc + if use_spec_aug: + assert spec_augment is not None and spec_augment.time_warp_factor < 1 + # Apply time warping before input duplicating + assert supervision_segments is not None + x = time_warp( + x, + time_warp_factor=time_warp_factor, + supervision_segments=supervision_segments, + ) + # Independently apply frequency masking and time masking to the two copies + x = spec_augment(x.repeat(2, 1, 1)) + else: + x = x.repeat(2, 1, 1) + x_lens = x_lens.repeat(2) + y = k2.ragged.cat([y, y], axis=0) + # Compute encoder outputs encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) @@ -351,6 +435,9 @@ def forward( am_scale=am_scale, lm_scale=lm_scale, ) + if use_cr_ctc: + simple_loss = simple_loss * 0.5 + pruned_loss = pruned_loss * 0.5 else: simple_loss = torch.empty(0) pruned_loss = torch.empty(0) @@ -358,14 +445,26 @@ def forward( if self.use_ctc: # Compute CTC loss targets = y.values - ctc_loss = self.forward_ctc( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - targets=targets, - target_lengths=y_lens, - ) + if not use_cr_ctc: + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + cr_loss = torch.empty(0) + else: + ctc_loss, cr_loss = self.forward_cr_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + ctc_loss = ctc_loss * 0.5 + cr_loss = cr_loss * 0.5 else: ctc_loss = torch.empty(0) + cr_loss = torch.empty(0) if self.use_attention_decoder: attention_decoder_loss = self.attention_decoder.calc_att_loss( @@ -374,7 +473,9 @@ def forward( ys=y.to(device), ys_lens=y_lens.to(device), ) + if use_cr_ctc: + attention_decoder_loss = attention_decoder_loss * 0.5 else: attention_decoder_loss = torch.empty(0) - return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 9c1c7f5a78..c074c32ec7 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -45,11 +45,10 @@ --max-duration 1000 It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` - - ctc loss & attention decoder loss, no transducer loss, - with `--use-transducer False --use-ctc True --use-attention-decoder True` + - transducer loss (default) + - ctc loss + - attention decoder loss + - cr-ctc loss (should use half the max-duration compared to regular ctc) """ @@ -72,6 +71,7 @@ from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut +from lhotse.dataset import SpecAugment from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel @@ -304,6 +304,13 @@ def add_model_arguments(parser: 
argparse.ArgumentParser): help="If True, use attention-decoder head.", ) + parser.add_argument( + "--use-cr-ctc", + type=str2bool, + default=False, + help="If True, use consistency-regularized CTC.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -449,6 +456,20 @@ def get_parser(): help="Scale for CTC loss.", ) + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.2, + help="Scale for consistency-regularization loss.", + ) + + parser.add_argument( + "--time-mask-ratio", + type=float, + default=2.5, + help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.", + ) + parser.add_argument( "--attention-decoder-loss-scale", type=float, @@ -717,6 +738,24 @@ def get_model(params: AttributeDict) -> nn.Module: return model +def get_spec_augment(params: AttributeDict) -> SpecAugment: + num_frame_masks = int(10 * params.time_mask_ratio) + max_frames_mask_fraction = 0.15 * params.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + return spec_augment + + def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, @@ -839,6 +878,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, + spec_augment: Optional[SpecAugment] = None, ) -> Tuple[Tensor, MetricsTracker]: """ Compute loss given the model and its inputs. @@ -855,8 +895,8 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. 
""" device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] @@ -874,14 +914,34 @@ def compute_loss( y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y) + use_cr_ctc = params.use_cr_ctc + use_spec_aug = use_cr_ctc and is_training + if use_spec_aug: + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model( x=feature, x_lens=feature_lens, y=y, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, + use_cr_ctc=use_cr_ctc, + use_spec_aug=use_spec_aug, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=params.spec_aug_time_warp_factor, ) loss = 0.0 @@ -904,6 +964,8 @@ def compute_loss( if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss + if use_cr_ctc: + loss += params.cr_loss_scale * cr_loss if params.use_attention_decoder: loss += params.attention_decoder_loss_scale * attention_decoder_loss @@ -922,6 +984,8 @@ def compute_loss( info["pruned_loss"] = pruned_loss.detach().cpu().item() if params.use_ctc: info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_cr_ctc: + info["cr_loss"] = cr_loss.detach().cpu().item() if params.use_attention_decoder: info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() @@ -971,6 +1035,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + spec_augment: Optional[SpecAugment] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -997,6 +1062,8 @@ def train_one_epoch( Dataloader for the validation dataset. scaler: The scaler used for mix precision training. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. model_avg: The stored model averaged from the start of training. 
tb_writer: @@ -1043,6 +1110,7 @@ def save_bad_model(suffix: str = ""): sp=sp, batch=batch, is_training=True, + spec_augment=spec_augment, ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -1238,6 +1306,13 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + if params.use_cr_ctc: + assert params.use_ctc + assert not params.enable_spec_aug # we will do spec_augment in model.py + spec_augment = get_spec_augment(params) + else: + spec_augment = None + assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None if rank == 0: @@ -1360,6 +1435,7 @@ def remove_short_and_long_utt(c: Cut): optimizer=optimizer, sp=sp, params=params, + spec_augment=spec_augment, ) scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) @@ -1387,6 +1463,7 @@ def remove_short_and_long_utt(c: Cut): train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, + spec_augment=spec_augment, tb_writer=tb_writer, world_size=world_size, rank=rank, @@ -1452,6 +1529,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + spec_augment: Optional[SpecAugment] = None, ): from lhotse.dataset import find_pessimistic_batches @@ -1471,6 +1549,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, + spec_augment=spec_augment, ) loss.backward() optimizer.zero_grad() diff --git a/icefall/utils.py b/icefall/utils.py index 1dbb954ded..b0a42cefaa 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -21,6 +21,7 @@ import collections import logging import os +import random import re import subprocess from collections import defaultdict @@ -38,6 +39,7 @@ import torch import torch.distributed as dist import torch.nn as nn +from lhotse.dataset.signal_transforms import time_warp as time_warp_impl from pypinyin import lazy_pinyin, pinyin from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials from torch.utils.tensorboard import SummaryWriter @@ -2271,3 +2273,41 @@ def num_tokens( if 0 in ans: num_tokens -= 1 return num_tokens + + +# Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py +def time_warp( + features: torch.Tensor, + p: float = 0.9, + time_warp_factor: Optional[int] = 80, + supervision_segments: Optional[torch.Tensor] = None, +): + """Apply time warping on a batch of features + """ + if time_warp_factor is None or time_warp_factor < 1: + return features + assert len(features.shape) == 3, ( + "SpecAugment only supports batches of single-channel feature matrices." + ) + features = features.clone() + if supervision_segments is None: + # No supervisions - apply spec augment to full feature matrices. + for sequence_idx in range(features.size(0)): + if random.random() > p: + # Randomly choose whether this transform is applied + continue + features[sequence_idx] = time_warp_impl( + features[sequence_idx], factor=time_warp_factor + ) + else: + # Supervisions provided - we will apply time warping only on the supervised areas. + for sequence_idx, start_frame, num_frames in supervision_segments: + if random.random() > p: + # Randomly choose whether this transform is applied + continue + end_frame = start_frame + num_frames + features[sequence_idx, start_frame:end_frame] = time_warp_impl( + features[sequence_idx, start_frame:end_frame], factor=time_warp_factor + ) + + return features
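
For reference (not part of the patch above): a minimal, self-contained sketch of the CR-CTC loss that `forward_cr_ctc` in `egs/librispeech/ASR/zipformer/model.py` computes. The batch is duplicated into two augmented copies, plain CTC is applied to both, and each copy's log-probabilities are additionally pulled towards the detached distribution predicted for the other copy via a length-masked KL term. The function names, the locally re-implemented `make_pad_mask`, and the toy shapes below are illustrative assumptions, not icefall APIs.

```python
import torch
import torch.nn.functional as F


def make_pad_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Return a bool mask of shape (N, T) that is True at padded frames."""
    n = lengths.size(0)
    frame_idx = torch.arange(max_len, device=lengths.device).expand(n, max_len)
    return frame_idx >= lengths.unsqueeze(1)


def cr_ctc_loss(
    log_probs: torch.Tensor,    # (2 * N, T, C): log-softmax CTC outputs of the
                                # two augmented copies stacked along the batch dim
    out_lens: torch.Tensor,     # (2 * N,): number of valid frames per copy
    targets: torch.Tensor,      # (sum(target_lens),): concatenated label ids (no blank)
    target_lens: torch.Tensor,  # (2 * N,)
):
    # Plain CTC loss over both copies.
    ctc = F.ctc_loss(
        log_probs=log_probs.permute(1, 0, 2),  # (T, 2 * N, C)
        targets=targets,
        input_lengths=out_lens,
        target_lengths=target_lens,
        reduction="sum",
    )

    # Consistency regularization: swap the two halves of the batch and use the
    # detached log-probs of the other copy as the KL target for each copy.
    first, second = log_probs.detach().chunk(2, dim=0)
    exchanged = torch.cat([second, first], dim=0)  # [x1, x2] -> [x2, x1]
    cr = F.kl_div(input=log_probs, target=exchanged, reduction="none", log_target=True)

    # Ignore padded frames when summing the consistency term.
    pad_mask = make_pad_mask(out_lens, log_probs.size(1)).unsqueeze(-1)  # (2 * N, T, 1)
    cr = cr.masked_fill(pad_mask, 0.0).sum()
    return ctc, cr


if __name__ == "__main__":
    torch.manual_seed(0)
    N, T, C = 2, 50, 30            # utterances, frames, vocab size (0 = blank)
    x = torch.randn(2 * N, T, C)   # stands in for two augmented encoder passes
    log_probs = x.log_softmax(dim=-1)
    out_lens = torch.tensor([50, 40, 50, 40])
    target_lens = torch.tensor([10, 8, 10, 8])
    targets = torch.randint(1, C, (int(target_lens.sum()),))
    ctc, cr = cr_ctc_loss(log_probs, out_lens, targets, target_lens)
    print(f"ctc_loss={ctc.item():.1f}  cr_loss={cr.item():.1f}")
```

In the recipe itself, the two copies are independently time/frequency-masked views of the same utterances (time warping is applied once, before duplication, via the new `time_warp` helper in `icefall/utils.py`), and the transducer/CTC/AED losses are scaled by 0.5 because every utterance appears twice in the duplicated batch.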