From ef658d691e75041398abb76567c810af1c22c7fc Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 24 Sep 2023 17:06:47 +0800 Subject: [PATCH] fixes for init value of `diagnostics.TensorDiagnosticOptions` (#1269) * fixes for `diagnostics` Replace `2 ** 22` with `512` as the default value of `diagnostics.TensorDiagnosticOptions` also black formatted some scripts * fixed formatting issues --- .../ASR/pruned_transducer_stateless2/train.py | 3 +- .../ASR/pruned_transducer_stateless2/train.py | 2 +- .../ASR/pruned_transducer_stateless3/train.py | 2 +- .../ASR/pruned_transducer_stateless7/train.py | 2 +- .../pruned_transducer_stateless7/train2.py | 2 +- .../train.py | 2 +- .../ASR/pruned_transducer_stateless5/train.py | 3 +- .../ASR/pruned_transducer_stateless5/train.py | 2 +- .../ASR/pruned_transducer_stateless2/train.py | 3 +- .../pruned_transducer_stateless7/train.py | 2 +- .../ASR/pruned_transducer_stateless7/train.py | 2 +- .../ASR/pruned_transducer_stateless7/train.py | 2 +- .../train.py | 2 +- .../train2.py | 2 +- .../train.py | 2 +- .../train.py | 2 +- .../train2.py | 2 +- .../ASR/pruned2_knowledge/train.py | 2 +- .../ASR/pruned_transducer_stateless5/train.py | 2 +- .../pruned_transducer_stateless7/finetune.py | 2 +- .../ASR/pruned_transducer_stateless7/optim.py | 12 +- .../ASR/pruned_transducer_stateless7/train.py | 2 +- .../pruned_transducer_stateless7_ctc/train.py | 2 +- .../train.py | 2 +- .../train.py | 2 +- .../train2.py | 2 +- .../train.py | 2 +- .../ASR/pruned_transducer_stateless8/train.py | 2 +- .../ASR/streaming_conformer_ctc/conformer.py | 4 +- egs/librispeech/ASR/zipformer/decoder.py | 30 +- egs/librispeech/ASR/zipformer/joiner.py | 9 +- egs/librispeech/ASR/zipformer/onnx_decode.py | 4 +- egs/librispeech/ASR/zipformer/optim.py | 22 +- egs/librispeech/ASR/zipformer/profile.py | 12 +- egs/librispeech/ASR/zipformer/scaling.py | 719 ++++++++++-------- .../ASR/zipformer/streaming_decode.py | 57 +- egs/librispeech/ASR/zipformer/subsampling.py | 16 +- egs/librispeech/ASR/zipformer/train.py | 21 +- egs/librispeech/ASR/zipformer_mmi/train.py | 2 +- .../ASR/pruned_transducer_stateless5/train.py | 5 +- egs/multi_zh-hans/ASR/zipformer/train.py | 2 +- .../ASR/pruned_transducer_stateless5/train.py | 2 +- .../train.py | 4 +- egs/tedlium3/ASR/conformer_ctc2/train.py | 2 +- egs/tedlium3/ASR/zipformer/train.py | 2 +- .../pruned_transducer_stateless2/finetune.py | 2 +- .../ASR/pruned_transducer_stateless2/train.py | 2 +- .../ASR/pruned_transducer_stateless5/train.py | 2 +- egs/wenetspeech/ASR/zipformer/train.py | 2 +- .../ASR/pruned_transducer_stateless5/train.py | 2 +- .../ASR/pruned_transducer_stateless7/train.py | 2 +- 51 files changed, 513 insertions(+), 481 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py index c9d9c4aa8a..fa809b768a 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py @@ -635,7 +635,6 @@ def train_one_epoch( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -800,7 +799,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py 
b/egs/aishell/ASR/pruned_transducer_stateless2/train.py index d089082386..60f014c48d 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py @@ -872,7 +872,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py index 62e67530de..7c23041cad 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py @@ -1045,7 +1045,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train.py b/egs/aishell/ASR/pruned_transducer_stateless7/train.py index cbb7db0861..11671db92e 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/train.py @@ -1028,7 +1028,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py b/egs/aishell/ASR/pruned_transducer_stateless7/train2.py index c30f6f9606..057af297f0 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/train2.py @@ -1031,7 +1031,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py index 4e52f95732..3858bafd7c 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -1019,7 +1019,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py index 74bf68ccbd..8c7448d4c8 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py @@ -730,7 +730,6 @@ def train_one_epoch( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -919,7 +918,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py index 47015cbe7d..a354f761e5 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py @@ -908,7 +908,7 @@ def 
run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py index e57b5c8593..30154291df 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py @@ -635,7 +635,6 @@ def train_one_epoch( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -800,7 +799,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py index 45d7779229..8f09f1aa54 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py @@ -999,7 +999,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py index 8c8d9593b8..9b67141c0d 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/train.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py @@ -988,7 +988,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py index 4bd5b83a23..4aedeffe4d 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py @@ -1019,7 +1019,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py index 18cb75c375..73fcd67aad 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1074,7 +1074,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py index bc4bcf2534..4c866ddd81 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -1075,7 +1075,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = 
diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index c5a05d349e..ca21bd6bf8 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -953,7 +953,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 6bb37b0179..23ddb6bec8 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -953,7 +953,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py index 36067510c2..420dc1065a 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py @@ -955,7 +955,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 77e06d3b7e..a4899f7bd6 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -811,7 +811,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 3b5a635e47..66dc5f991f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -1003,7 +1003,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index 3ee2b7d656..4e261dbc1d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -1132,7 +1132,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index aa3cef338e..8ab3589dab 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -117,7 +117,7 @@ def batched_params(self, param_group, group_params_names): yield tuples # <-- calling code will do the actual optimization here! - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for (stacked_params, _state, _names), batch in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -181,7 +181,6 @@ def __init__( parameters_names=None, show_dominant_parameters=True, ): - assert parameters_names is not None, ( "Please prepare parameters_names," "which is a List[List[str]]. Each List[str] is for a group" @@ -224,9 +223,7 @@ def step(self, closure=None): batch = True for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. @@ -325,7 +322,7 @@ def _get_clipping_scale( clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: + for p, state, param_names in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( @@ -410,7 +407,7 @@ def _show_gradient_dominating_parameter( from tuples, we still pass it to save some time. """ all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: + for p, state, batch_param_names in tuples: # p is a stacked batch parameters. batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars @@ -426,7 +423,6 @@ def _show_gradient_dominating_parameter( for name, sumsq_orig, rms, grad in zip( batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad ): - proportion_orig = sumsq_orig / tot_sumsq all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) @@ -1039,7 +1035,7 @@ def _test_scaled_adam(hidden_dim: int): # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( - # 2 ** 22 + # 512 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 2b4d51089c..fac3706d2e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -1028,7 +1028,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index b387968a9c..d8fa08372a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -1052,7 +1052,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index 23fb6f4975..25a1aa6743 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -1042,7 +1042,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 99090b2c17..2d915ff870 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1029,7 +1029,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py index 9be629149b..aa6c0668a9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -1030,7 +1030,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py index b494253d6b..565dc7a162 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -1141,7 +1141,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index bee414292d..3f271c5b4d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -1154,7 +1154,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py index be6fabf353..0b982f4bfc 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py @@ -230,7 +230,9 @@ def train_run_encoder( x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask ) # (T, B, F) else: - x = self.encoder(x, pos_emb, src_key_padding_mask=src_key_padding_mask) # (T, B, F) + x = self.encoder( + x, pos_emb, src_key_padding_mask=src_key_padding_mask + ) # (T, B, F) if self.normalize_before: x = self.after_norm(x) diff --git a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py index e8db988f6e..e77e541187 100644 --- a/egs/librispeech/ASR/zipformer/decoder.py +++ b/egs/librispeech/ASR/zipformer/decoder.py @@ -61,10 +61,15 @@ def __init__( ) # the balancers are to avoid any drift in the magnitude of the # 
embeddings, which would interact badly with parameter averaging. - self.balancer = Balancer(decoder_dim, channel_dim=-1, - min_positive=0.0, max_positive=1.0, - min_abs=0.5, max_abs=1.0, - prob=0.05) + self.balancer = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) self.blank_id = blank_id @@ -81,10 +86,15 @@ def __init__( groups=decoder_dim // 4, # group size == 4 bias=False, ) - self.balancer2 = Balancer(decoder_dim, channel_dim=-1, - min_positive=0.0, max_positive=1.0, - min_abs=0.5, max_abs=1.0, - prob=0.05) + self.balancer2 = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ @@ -107,9 +117,7 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/zipformer/joiner.py b/egs/librispeech/ASR/zipformer/joiner.py index f03cc930ec..dfb0a0057b 100644 --- a/egs/librispeech/ASR/zipformer/joiner.py +++ b/egs/librispeech/ASR/zipformer/joiner.py @@ -52,12 +52,13 @@ def forward( Returns: Return a tensor of shape (N, T, s_range, C). """ - assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape) + assert encoder_out.ndim == decoder_out.ndim, ( + encoder_out.shape, + decoder_out.shape, + ) if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/zipformer/onnx_decode.py b/egs/librispeech/ASR/zipformer/onnx_decode.py index 2aca36ca94..356c2a8303 100755 --- a/egs/librispeech/ASR/zipformer/onnx_decode.py +++ b/egs/librispeech/ASR/zipformer/onnx_decode.py @@ -303,7 +303,9 @@ def main(): for test_set, test_dl in zip(test_sets, test_dl): start_time = time.time() - results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table) + results, total_duration = decode_dataset( + dl=test_dl, model=model, token_table=token_table + ) end_time = time.time() elapsed_seconds = end_time - start_time rtf = elapsed_seconds / total_duration diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index abfb2092cd..c9b76526c6 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -116,7 +116,7 @@ def batched_params(self, param_group, group_params_names): yield tuples # <-- calling code will do the actual optimization here! - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for (stacked_params, _state, _names), batch in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -181,7 +181,6 @@ def __init__( size_update_period=4, clipping_update_period=100, ): - defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -299,8 +298,8 @@ def _get_names_of_parameters( # the input is groups of parameter or named parameter. 
for cur_group in iterable_or_groups: assert "named_params" in cur_group - name_list = [ x[0] for x in cur_group["named_params"] ] - p_list = [ x[1] for x in cur_group["named_params"] ] + name_list = [x[0] for x in cur_group["named_params"]] + p_list = [x[1] for x in cur_group["named_params"]] del cur_group["named_params"] cur_group["params"] = p_list param_groups.append(cur_group) @@ -327,9 +326,7 @@ def step(self, closure=None): batch = True for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. @@ -428,7 +425,7 @@ def _get_clipping_scale( clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: + for p, state, param_names in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( @@ -513,7 +510,7 @@ def _show_gradient_dominating_parameter( from tuples, we still pass it to save some time. """ all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: + for p, state, batch_param_names in tuples: # p is a stacked batch parameters. batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars @@ -529,7 +526,6 @@ def _show_gradient_dominating_parameter( for name, sumsq_orig, rms, grad in zip( batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad ): - proportion_orig = sumsq_orig / tot_sumsq all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) @@ -667,8 +663,7 @@ def _size_update( # We have to look at the trained model for parameters at or around the # param_max_rms, because sometimes they can indicate a problem with the # topology or settings. - scale_step = torch.minimum(scale_step, - (param_max_rms - param_rms) / param_rms) + scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) delta = state["delta"] # the factor of (1-beta1) relates to momentum. 
@@ -879,7 +874,8 @@ def get_lr(self): warmup_factor = ( 1.0 if self.batch >= self.warmup_batches - else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) + else self.warmup_start + + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) # else 0.5 + 0.5 * (self.batch / self.warmup_batches) ) @@ -1111,7 +1107,7 @@ def _test_scaled_adam(hidden_dim: int): # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( - # 2 ** 22 + # 512 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) diff --git a/egs/librispeech/ASR/zipformer/profile.py b/egs/librispeech/ASR/zipformer/profile.py index b460b53389..57f44a90a8 100755 --- a/egs/librispeech/ASR/zipformer/profile.py +++ b/egs/librispeech/ASR/zipformer/profile.py @@ -100,17 +100,13 @@ def __init__( self.encoder_embed = encoder_embed self.encoder_proj = encoder_proj - def forward( - self, feature: Tensor, feature_lens: Tensor - ) -> Tuple[Tensor, Tensor]: + def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]: x, x_lens = self.encoder_embed(feature, feature_lens) src_key_padding_mask = make_pad_mask(x_lens) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens = self.encoder( - x, x_lens, src_key_padding_mask - ) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C) logits = self.encoder_proj(encoder_out) @@ -168,9 +164,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 7c98ef0456..23fd279b31 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -25,6 +25,7 @@ import torch.nn as nn from torch import Tensor + def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: max_value = torch.max(x, y) diff = torch.abs(x - y) @@ -55,28 +56,34 @@ def logaddexp(x: Tensor, y: Tensor) -> Tensor: # for torch.jit.trace() return torch.logaddexp(x, y) + class PiecewiseLinear(object): """ Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] respectively. """ + def __init__(self, *args): assert len(args) >= 1, len(args) if len(args) == 1 and isinstance(args[0], PiecewiseLinear): self.pairs = list(args[0].pairs) else: - self.pairs = [ (float(x), float(y)) for x,y in args ] - for (x,y) in self.pairs: + self.pairs = [(float(x), float(y)) for x, y in args] + for (x, y) in self.pairs: assert isinstance(x, (float, int)), type(x) assert isinstance(y, (float, int)), type(y) for i in range(len(self.pairs) - 1): - assert self.pairs[i + 1][0] > self.pairs[i][0], (i, self.pairs[i], self.pairs[i + 1]) + assert self.pairs[i + 1][0] > self.pairs[i][0], ( + i, + self.pairs[i], + self.pairs[i + 1], + ) def __str__(self): # e.g. 
'PiecewiseLinear((0., 10.), (100., 0.))' - return f'PiecewiseLinear({str(self.pairs)[1:-1]})' + return f"PiecewiseLinear({str(self.pairs)[1:-1]})" def __call__(self, x): if x <= self.pairs[0][0]: @@ -93,37 +100,36 @@ def __call__(self, x): assert False def __mul__(self, alpha): - return PiecewiseLinear( - * [(x, y * alpha) for x, y in self.pairs]) + return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) def __add__(self, x): if isinstance(x, (float, int)): - return PiecewiseLinear( - * [(p[0], p[1] + x) for p in self.pairs]) + return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) s, x = self.get_common_basis(x) return PiecewiseLinear( - * [(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]) + *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)] + ) def max(self, x): if isinstance(x, (float, int)): - x = PiecewiseLinear( (0, x) ) + x = PiecewiseLinear((0, x)) s, x = self.get_common_basis(x, include_crossings=True) return PiecewiseLinear( - * [(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) + *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) def min(self, x): if isinstance(x, float) or isinstance(x, int): - x = PiecewiseLinear( (0, x) ) + x = PiecewiseLinear((0, x)) s, x = self.get_common_basis(x, include_crossings=True) return PiecewiseLinear( - * [ (sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) + *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) def __eq__(self, other): return self.pairs == other.pairs - def get_common_basis(self, - p: 'PiecewiseLinear', - include_crossings: bool = False): + def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): """ Returns (self_mod, p_mod) which are equivalent piecewise linear functions to self and p, but with the same x values. @@ -135,28 +141,30 @@ def get_common_basis(self, assert isinstance(p, PiecewiseLinear), type(p) # get sorted x-values without repetition. - x_vals = sorted(set([ x for x, _ in self.pairs ] + [ x for x, _ in p.pairs ])) - y_vals1 = [ self(x) for x in x_vals ] - y_vals2 = [ p(x) for x in x_vals ] + x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] if include_crossings: extra_x_vals = [] for i in range(len(x_vals) - 1): - if (y_vals1[i] > y_vals2[i]) != (y_vals1[i+1] > y_vals2[i+1]): + if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): # if the two lines in this subsegment potentially cross each other.. diff_cur = abs(y_vals1[i] - y_vals2[i]) - diff_next = abs(y_vals1[i+1] - y_vals2[i+1]) + diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) # `pos`, between 0 and 1, gives the relative x position, # with 0 being x_vals[i] and 1 being x_vals[i+1]. 
pos = diff_cur / (diff_cur + diff_next) - extra_x_val = x_vals[i] + pos * (x_vals[i+1] - x_vals[i]) + extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) extra_x_vals.append(extra_x_val) if len(extra_x_vals) > 0: x_vals = sorted(set(x_vals + extra_x_vals)) - y_vals1 = [ self(x) for x in x_vals ] - y_vals2 = [ p(x) for x in x_vals ] - return ( PiecewiseLinear(* zip(x_vals, y_vals1)), - PiecewiseLinear(* zip(x_vals, y_vals2)) ) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + return ( + PiecewiseLinear(*zip(x_vals, y_vals1)), + PiecewiseLinear(*zip(x_vals, y_vals2)), + ) class ScheduledFloat(torch.nn.Module): @@ -176,9 +184,8 @@ class ScheduledFloat(torch.nn.Module): `default` is used when self.batch_count is not set or not in training mode or in torch.jit scripting mode. """ - def __init__(self, - *args, - default: float = 0.0): + + def __init__(self, *args, default: float = 0.0): super().__init__() # self.batch_count and self.name will be written to in the training loop. self.batch_count = None @@ -187,47 +194,55 @@ def __init__(self, self.schedule = PiecewiseLinear(*args) def extra_repr(self) -> str: - return f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}' + return ( + f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" + ) def __float__(self): batch_count = self.batch_count - if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): + if ( + batch_count is None + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): return float(self.default) else: ans = self.schedule(self.batch_count) if random.random() < 0.0002: - logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}") + logging.info( + f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}" + ) return ans def __add__(self, x): if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule + x, - default=self.default) + return ScheduledFloat(self.schedule + x, default=self.default) else: - return ScheduledFloat(self.schedule + x.schedule, - default=self.default+x.default) + return ScheduledFloat( + self.schedule + x.schedule, default=self.default + x.default + ) def max(self, x): if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule.max(x), - default=self.default) + return ScheduledFloat(self.schedule.max(x), default=self.default) else: - return ScheduledFloat(self.schedule.max(x.schedule), - default=max(self.default, x.default)) + return ScheduledFloat( + self.schedule.max(x.schedule), default=max(self.default, x.default) + ) FloatLike = Union[float, ScheduledFloat] -def random_cast_to_half(x: Tensor, - min_abs: float = 5.0e-06) -> Tensor: +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: """ A randomized way of casting a floating point value to half precision. """ if x.dtype == torch.float16: return x x_abs = x.abs() - is_too_small = (x_abs < min_abs) + is_too_small = x_abs < min_abs # for elements where is_too_small is true, random_val will contain +-min_abs with # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, # for those elements]. @@ -242,6 +257,7 @@ class CutoffEstimator: p is the proportion of items that should be above the cutoff. """ + def __init__(self, p: float): self.p = p # total count of items @@ -255,7 +271,7 @@ def __call__(self, x: float) -> bool: """ Returns true if x is above the cutoff. 
""" - ans = (x > self.cutoff) + ans = x > self.cutoff self.count += 1 if ans: self.count_above += 1 @@ -263,7 +279,7 @@ def __call__(self, x: float) -> bool: delta_p = cur_p - self.p if (delta_p > 0) == ans: q = abs(delta_p) - self.cutoff = x * q + self.cutoff * (1-q) + self.cutoff = x * q + self.cutoff * (1 - q) return ans @@ -272,6 +288,7 @@ class SoftmaxFunction(torch.autograd.Function): Tries to handle half-precision derivatives in a randomized way that should be more accurate for training than the default behavior. """ + @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) @@ -287,7 +304,7 @@ def forward(ctx, x: Tensor, dim: int): @staticmethod def backward(ctx, ans_grad: Tensor): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) @@ -306,17 +323,16 @@ def softmax(x: Tensor, dim: int): class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float) -> Tensor: + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: ctx.channel_dim = channel_dim ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), - coeffs.detach(), - direction.detach()) + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) return x @staticmethod @@ -328,15 +344,20 @@ def backward(ctx, x_grad, *args): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x ** 2).mean() + x_var = (x**2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() + x_residual_var = (x_residual**2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() x_orig_grad = x_orig.grad - x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) return x_grad + x_extra_grad.detach(), None, None, None, None @@ -348,8 +369,14 @@ class BiasNormFunction(torch.autograd.Function): # it can just store the returned value (chances are, this will also be needed for # some other reason, related to the next operation, so we can save memory). 
@staticmethod - def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int, - store_output_for_backprop: bool) -> Tensor: + def forward( + ctx, + x: Tensor, + bias: Tensor, + log_scale: Tensor, + channel_dim: int, + store_output_for_backprop: bool, + ) -> Tensor: assert bias.ndim == 1 if channel_dim < 0: channel_dim = channel_dim + x.ndim @@ -357,10 +384,16 @@ def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int, ctx.channel_dim = channel_dim for _ in range(channel_dim + 1, x.ndim): bias = bias.unsqueeze(-1) - scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp() + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() ans = x * scales - ctx.save_for_backward(ans.detach() if store_output_for_backprop else x, - scales.detach(), bias.detach(), log_scale.detach()) + ctx.save_for_backward( + ans.detach() if store_output_for_backprop else x, + scales.detach(), + bias.detach(), + log_scale.detach(), + ) return ans @staticmethod @@ -376,7 +409,9 @@ def backward(ctx, ans_grad: Tensor) -> Tensor: log_scale.requires_grad = True with torch.enable_grad(): # recompute scales from x, bias and log_scale. - scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp() + scales = ( + torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() ans = x * scales ans.backward(gradient=ans_grad) return x.grad, bias.grad.flatten(), log_scale.grad, None, None @@ -412,14 +447,15 @@ class BiasNorm(torch.nn.Module): than the input of this module to be required to be stored for the backprop. """ + def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - log_scale: float = 1.0, - log_scale_min: float = -1.5, - log_scale_max: float = 1.5, - store_output_for_backprop: bool = False + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + log_scale: float = 1.0, + log_scale_min: float = -1.5, + log_scale_max: float = 1.5, + store_output_for_backprop: bool = False, ) -> None: super(BiasNorm, self).__init__() self.num_channels = num_channels @@ -442,23 +478,24 @@ def forward(self, x: Tensor) -> Tensor: bias = self.bias for _ in range(channel_dim + 1, x.ndim): bias = bias.unsqueeze(-1) - scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * - self.log_scale.exp()) + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * self.log_scale.exp() return x * scales - log_scale = limit_param_value(self.log_scale, - min=float(self.log_scale_min), - max=float(self.log_scale_max), - training=self.training) + log_scale = limit_param_value( + self.log_scale, + min=float(self.log_scale_min), + max=float(self.log_scale_max), + training=self.training, + ) - return BiasNormFunction.apply(x, self.bias, log_scale, - self.channel_dim, - self.store_output_for_backprop) + return BiasNormFunction.apply( + x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop + ) -def ScaledLinear(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Linear: +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear that gives an easy way to set the default initial parameter scale. 
@@ -477,15 +514,11 @@ def ScaledLinear(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans -def ScaledConv1d(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Conv1d: +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: """ Behaves like a constructor of a modified version of nn.Conv1d that gives an easy way to set the default initial parameter scale. @@ -504,15 +537,11 @@ def ScaledConv1d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans -def ScaledConv2d(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Conv2d: +def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: """ Behaves like a constructor of a modified version of nn.Conv2d that gives an easy way to set the default initial parameter scale. @@ -532,9 +561,7 @@ def ScaledConv2d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans @@ -562,29 +589,36 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): Another option, if you want to do something like this, is to re-initialize the parameters. """ - def __init__(self, - channels: int, - kernel_size: int, - initial_scale: float = 1.0, - bias: bool = True): + + def __init__( + self, + channels: int, + kernel_size: int, + initial_scale: float = 1.0, + bias: bool = True, + ): super().__init__() assert kernel_size % 2 == 1 half_kernel_size = (kernel_size + 1) // 2 # will pad manually, on one side. - self.causal_conv = nn.Conv1d(in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=half_kernel_size, - padding=0, - bias=True) - - self.chunkwise_conv = nn.Conv1d(in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=kernel_size // 2, - bias=bias) + self.causal_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=half_kernel_size, + padding=0, + bias=True, + ) + + self.chunkwise_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + bias=bias, + ) # first row is correction factors added to the scale near the left edge of the chunk, # second row is correction factors added to the scale near the right edge of the chunk, @@ -596,17 +630,15 @@ def __init__(self, self.causal_conv.weight[:] *= initial_scale self.chunkwise_conv.weight[:] *= initial_scale if bias: - torch.nn.init.uniform_(self.causal_conv.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_( + self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) - def forward(self, - x: Tensor, - chunk_size: int = -1) -> Tensor: + def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: """ - Forward function. Args: - x: a Tensor of shape (batch_size, channels, seq_len) - chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. + Forward function. 
Args: + x: a Tensor of shape (batch_size, channels, seq_len) + chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. """ (batch_size, num_channels, seq_len) = x.shape @@ -622,28 +654,32 @@ def forward(self, x = torch.nn.functional.pad(x, (left_pad, right_pad)) - x_causal = self.causal_conv(x[..., :left_pad + seq_len]) + x_causal = self.causal_conv(x[..., : left_pad + seq_len]) assert x_causal.shape == (batch_size, num_channels, seq_len) x_chunk = x[..., left_pad:] num_chunks = x_chunk.shape[2] // chunk_size x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) - x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(batch_size * num_chunks, - num_channels, chunk_size) + x_chunk = x_chunk.permute(0, 2, 1, 3).reshape( + batch_size * num_chunks, num_channels, chunk_size + ) x_chunk = self.chunkwise_conv(x_chunk) # does not change shape chunk_scale = self._get_chunk_scale(chunk_size) x_chunk = x_chunk * chunk_scale - x_chunk = x_chunk.reshape(batch_size, num_chunks, - num_channels, chunk_size).permute(0, 2, 1, 3) - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[..., :seq_len] + x_chunk = x_chunk.reshape( + batch_size, num_chunks, num_channels, chunk_size + ).permute(0, 2, 1, 3) + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[ + ..., :seq_len + ] return x_chunk + x_causal def _get_chunk_scale(self, chunk_size: int): """Returns tensor of shape (num_channels, chunk_size) that will be used to - scale the output of self.chunkwise_conv.""" + scale the output of self.chunkwise_conv.""" left_edge = self.chunkwise_conv_scale[0] right_edge = self.chunkwise_conv_scale[1] if chunk_size < self.kernel_size: @@ -652,9 +688,9 @@ def _get_chunk_scale(self, chunk_size: int): else: t = chunk_size - self.kernel_size channels = left_edge.shape[0] - pad = torch.zeros(channels, t, - device=left_edge.device, - dtype=left_edge.dtype) + pad = torch.zeros( + channels, t, device=left_edge.device, dtype=left_edge.dtype + ) left_edge = torch.cat((left_edge, pad), dim=-1) right_edge = torch.cat((pad, right_edge), dim=-1) return 1.0 + (left_edge + right_edge) @@ -698,14 +734,14 @@ def streaming_forward( class BalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - min_mean: float, - max_mean: float, - min_rms: float, - max_rms: float, - grad_scale: float, - channel_dim: int, + ctx, + x: Tensor, + min_mean: float, + max_mean: float, + min_rms: float, + max_rms: float, + grad_scale: float, + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim @@ -715,10 +751,8 @@ def forward( return x @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None, None, None]: - x, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + (x,) = ctx.saved_tensors (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config try: @@ -727,8 +761,8 @@ def backward( x = x.to(torch.float32) x = x.detach() x.requires_grad = True - mean_dims = [ i for i in range(x.ndim) if i != channel_dim ] - uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True) + mean_dims = [i for i in range(x.ndim) if i != channel_dim] + uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) mean = x.mean(dim=mean_dims, keepdim=True) stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() rms = uncentered_var.clamp(min=1.0e-20).sqrt() @@ -742,11 +776,16 @@ def backward( rms_clamped = rms.clamp(min=min_rms, 
max=max_rms) r_loss = (rms_clamped / rms).log().abs() - loss = (m_loss + r_loss) + loss = m_loss + r_loss loss.backward(gradient=torch.ones_like(loss)) loss_grad = x.grad - loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20) + loss_grad_rms = ( + (loss_grad**2) + .mean(dim=mean_dims, keepdim=True) + .sqrt() + .clamp(min=1.0e-20) + ) loss_grad = loss_grad * (grad_scale / loss_grad_rms) @@ -757,7 +796,9 @@ def backward( x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) x_grad = x_grad_mod.to(x_grad.dtype) except Exception as e: - logging.info(f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue.") + logging.info( + f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." + ) return x_grad, None, None, None, None, None, None @@ -793,16 +834,17 @@ class Balancer(torch.nn.Module): on each forward(). This is done randomly to prevent all layers from doing it at the same time. """ + def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: FloatLike = 0.05, - max_positive: FloatLike = 0.95, - min_abs: FloatLike = 0.2, - max_abs: FloatLike = 100.0, - grad_scale: FloatLike = 0.04, - prob: Optional[FloatLike] = None, + self, + num_channels: int, + channel_dim: int, + min_positive: FloatLike = 0.05, + max_positive: FloatLike = 0.95, + min_abs: FloatLike = 0.2, + max_abs: FloatLike = 100.0, + grad_scale: FloatLike = 0.04, + prob: Optional[FloatLike] = None, ): super().__init__() @@ -823,8 +865,11 @@ def __init__( self.grad_scale = grad_scale def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or not x.requires_grad or - (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))): + if ( + torch.jit.is_scripting() + or not x.requires_grad + or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) + ): return _no_op(x) prob = float(self.prob) @@ -842,7 +887,7 @@ def _atanh(x): eps = 1.0e-10 # eps is to prevent crashes if x is exactly 0 or 1. # we'll just end up returning a fairly large value. - return (math.log (1+x+eps) - math.log (1-x+eps)) / 2. + return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0 def _approx_inverse_erf(x): # 1 / (sqrt(pi) * ln(2)), @@ -853,6 +898,7 @@ def _approx_inverse_erf(x): # and math.erf(0.0407316414078772) = 0.045935330944660666, # which is pretty close to 0.05. return 0.8139535143 * _atanh(x) + # first convert x from the range 0..1 to the range -1..1 which the error # function returns x = -1 + (2 * x) @@ -873,8 +919,9 @@ def _approx_inverse_erf(x): return _no_op(x) -def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float, - name: str = None) -> Tensor: +def penalize_abs_values_gt( + x: Tensor, limit: float, penalty: float, name: str = None +) -> Tensor: """ Returns x unmodified, but in backprop will put a penalty for the excess of the absolute values of elements of x over the limit "limit". E.g. if @@ -910,13 +957,12 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. 
else: (batch, dim, dim) = x.shape x = x.reshape(batch, dim * dim) - x = x[:, ::dim+1] + x = x[:, :: dim + 1] assert x.shape == (batch, dim) return x -def _whitening_metric(x: Tensor, - num_groups: int): +def _whitening_metric(x: Tensor, num_groups: int): """ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of of the centered feature covariance are the same within each group's covariance matrix @@ -946,25 +992,22 @@ def _whitening_metric(x: Tensor, # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) return metric class WhiteningPenaltyFunction(torch.autograd.Function): @staticmethod - def forward(ctx, - x: Tensor, - module: nn.Module) -> Tensor: + def forward(ctx, x: Tensor, module: nn.Module) -> Tensor: ctx.save_for_backward(x) ctx.module = module return x @staticmethod - def backward(ctx, - x_grad: Tensor): - x_orig, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors w = ctx.module try: @@ -976,8 +1019,10 @@ def backward(ctx, metric = _whitening_metric(x_detached, w.num_groups) if random.random() < 0.005 or __name__ == "__main__": - logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}") + logging.info( + f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}" + ) if metric < float(w.whitening_limit): w.prob = w.min_prob @@ -986,22 +1031,27 @@ def backward(ctx, w.prob = w.max_prob metric.backward() penalty_grad = x_detached.grad - scale = w.grad_scale * (x_grad.to(torch.float32).norm() / - (penalty_grad.norm() + 1.0e-20)) + scale = w.grad_scale * ( + x_grad.to(torch.float32).norm() + / (penalty_grad.norm() + 1.0e-20) + ) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None except Exception as e: - logging.info(f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue.") + logging.info( + f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue." + ) return x_grad, None class Whiten(nn.Module): def __init__( - self, - num_groups: int, - whitening_limit: FloatLike, - prob: Union[float, Tuple[float,float]], - grad_scale: FloatLike): + self, + num_groups: int, + whitening_limit: FloatLike, + prob: Union[float, Tuple[float, float]], + grad_scale: FloatLike, + ): """ Args: num_groups: the number of groups to divide the channel dim into before @@ -1033,10 +1083,9 @@ def __init__( (self.min_prob, self.max_prob) = prob assert 0 < self.min_prob <= self.max_prob <= 1 self.prob = self.max_prob - self.name = None # will be set in training loop + self.name = None # will be set in training loop - def forward(self, - x: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: """ In the forward pass, this function just returns the input unmodified. 
In the backward pass, it will modify the gradients to ensure that the @@ -1071,9 +1120,11 @@ def forward(ctx, x: Tensor, y: Tensor, name: str): @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones(ctx.y_shape, - dtype=ans_grad.dtype, - device=ans_grad.device), None + return ( + ans_grad, + torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), + None, + ) def with_loss(x, y, name): @@ -1118,20 +1169,21 @@ def forward(ctx, x: Tensor, min: float, max: float): @staticmethod def backward(ctx, x_grad: Tensor): - x, = ctx.saved_tensors + (x,) = ctx.saved_tensors # where x < ctx.min, ensure all grads are negative (this will tend to make # x more positive). - x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0) + x_grad = x_grad * torch.where( + torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 + ) # where x > ctx.max, ensure all grads are positive (this will tend to make # x more negative). x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) return x_grad, None, None -def limit_param_value(x: Tensor, - min: float, max: float, - prob: float = 0.6, - training: bool = True): +def limit_param_value( + x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True +): # You apply this to (typically) an nn.Parameter during training to ensure that its # (elements mostly) stays within a supplied range. This is done by modifying the # gradients in backprop. @@ -1187,7 +1239,7 @@ def forward(ctx, x: Tensor) -> Tensor: y = x * s if requires_grad: - deriv = (y * (1 - s) + s) + deriv = y * (1 - s) + s # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 @@ -1197,7 +1249,9 @@ def forward(ctx, x: Tensor) -> Tensor: # floors), should be expectation-preserving. floor = -0.044 ceil = 1.2 - d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1210,12 +1264,12 @@ def forward(ctx, x: Tensor) -> Tensor: @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. 
floor = -0.043637 ceil = 1.2 - d = (d * ((ceil - floor) / 255.0) + floor) + d = d * ((ceil - floor) / 255.0) + floor return y_grad * d @@ -1239,9 +1293,7 @@ def __init__(self, p: FloatLike): self.p = p def forward(self, x: Tensor) -> Tensor: - return torch.nn.functional.dropout(x, - p=float(self.p), - training=self.training) + return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) class MulForDropout3(torch.autograd.Function): @@ -1259,7 +1311,7 @@ def forward(ctx, x, y, alpha): @staticmethod @custom_bwd def backward(ctx, ans_grad): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors x_grad = ctx.alpha * ans_grad * (ans != 0) return x_grad, None, None @@ -1286,7 +1338,7 @@ def forward(self, x: Tensor) -> Tensor: class SwooshLFunction(torch.autograd.Function): """ - swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 + swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 """ @staticmethod @@ -1308,13 +1360,15 @@ def forward(ctx, x: Tensor) -> Tensor: if not requires_grad: return y - y.backward(gradient = torch.ones_like(y)) + y.backward(gradient=torch.ones_like(y)) grad = x.grad floor = coeff ceil = 1.0 + coeff + 0.005 - d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1328,20 +1382,19 @@ def forward(ctx, x: Tensor) -> Tensor: @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. coeff = -0.08 floor = coeff ceil = 1.0 + coeff + 0.005 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class SwooshL(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation. - """ + """Return Swoosh-L activation.""" if torch.jit.is_scripting() or torch.jit.is_tracing(): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 @@ -1351,19 +1404,19 @@ def forward(self, x: Tensor) -> Tensor: return k2.swoosh_l(x) # return SwooshLFunction.apply(x) + class SwooshLOnnx(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation. - """ + """Return Swoosh-L activation.""" zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035 class SwooshRFunction(torch.autograd.Function): """ - swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 + swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 - derivatives are between -0.08 and 0.92. + derivatives are between -0.08 and 0.92. """ @staticmethod @@ -1379,17 +1432,19 @@ def forward(ctx, x: Tensor) -> Tensor: with torch.enable_grad(): x = x.detach() x.requires_grad = True - y = torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 + y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 if not requires_grad: return y - y.backward(gradient = torch.ones_like(y)) + y.backward(gradient=torch.ones_like(y)) grad = x.grad floor = -0.08 ceil = 0.925 - d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) if __name__ == "__main__": # for self-testing only. 
assert d_scaled.min() >= 0.0 @@ -1403,33 +1458,32 @@ def forward(ctx, x: Tensor) -> Tensor: @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.08 ceil = 0.925 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class SwooshR(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation. - """ + """Return Swoosh-R activation.""" if torch.jit.is_scripting() or torch.jit.is_tracing(): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 + return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 if not x.requires_grad: return k2.swoosh_r_forward(x) else: return k2.swoosh_r(x) # return SwooshRFunction.apply(x) + class SwooshROnnx(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation. - """ + """Return Swoosh-R activation.""" zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp_onnx(zero, x - 1.) - 0.08 * x - 0.313261687 + return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687 # simple version of SwooshL that does not redefine the backprop, used in @@ -1437,7 +1491,7 @@ def forward(self, x: Tensor) -> Tensor: def SwooshLForward(x: Tensor): x_offset = x - 4.0 log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) return log_sum - 0.08 * x - 0.035 @@ -1446,28 +1500,30 @@ def SwooshLForward(x: Tensor): def SwooshRForward(x: Tensor): x_offset = x - 1.0 log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) return log_sum - 0.08 * x - 0.313261687 class ActivationDropoutAndLinearFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, - x: Tensor, - weight: Tensor, - bias: Optional[Tensor], - activation: str, - dropout_p: float, - dropout_shared_dim: Optional[int]): + def forward( + ctx, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + activation: str, + dropout_p: float, + dropout_shared_dim: Optional[int], + ): if dropout_p != 0.0: dropout_shape = list(x.shape) if dropout_shared_dim is not None: dropout_shape[dropout_shared_dim] = 1 # else it won't be very memory efficient. - dropout_mask = ((1.0 / (1.0 - dropout_p)) * - (torch.rand(*dropout_shape, - device=x.device, dtype=x.dtype) > dropout_p)) + dropout_mask = (1.0 / (1.0 - dropout_p)) * ( + torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p + ) else: dropout_mask = None @@ -1476,8 +1532,8 @@ def forward(ctx, ctx.activation = activation forward_activation_dict = { - 'SwooshL': k2.swoosh_l_forward, - 'SwooshR': k2.swoosh_r_forward + "SwooshL": k2.swoosh_l_forward, + "SwooshR": k2.swoosh_r_forward, } # it will raise a KeyError if this fails. This will be an error. We let it # propagate to the user. @@ -1495,8 +1551,8 @@ def backward(ctx, ans_grad: Tensor): (x, weight, bias, dropout_mask) = saved forward_and_deriv_activation_dict = { - 'SwooshL': k2.swoosh_l_forward_and_deriv, - 'SwooshR': k2.swoosh_r_forward_and_deriv + "SwooshL": k2.swoosh_l_forward_and_deriv, + "SwooshR": k2.swoosh_r_forward_and_deriv, } # the following lines a KeyError if the activation is unrecognized. 
# This will be an error. We let it propagate to the user. @@ -1511,8 +1567,7 @@ def backward(ctx, ans_grad: Tensor): in_channels = y.shape[-1] g = ans_grad.reshape(-1, out_channels) - weight_deriv = torch.matmul(g.t(), - y.reshape(-1, in_channels)) + weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) y_deriv = torch.matmul(ans_grad, weight) bias_deriv = None if bias is None else g.sum(dim=0) x_deriv = y_deriv * func_deriv @@ -1525,71 +1580,76 @@ def backward(ctx, ans_grad: Tensor): class ActivationDropoutAndLinear(torch.nn.Module): """ - This merges an activation function followed by dropout and then a nn.Linear module; - it does so in a memory efficient way so that it only stores the input to the whole - module. If activation == SwooshL and dropout_shared_dim != None, this will be - equivalent to: - nn.Sequential(SwooshL(), - Dropout3(dropout_p, shared_dim=dropout_shared_dim), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=initial_scale)) - If dropout_shared_dim is None, the dropout would be equivalent to - Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout - mask is smaller. - - Args: - in_channels: number of input channels, e.g. 256 - out_channels: number of output channels, e.g. 256 - bias: if true, have a bias - activation: the activation function, for now just support SwooshL. - dropout_p: the dropout probability or schedule (happens after nonlinearity). - dropout_shared_dim: the dimension, if any, across which the dropout mask is - shared (e.g. the time dimension). If None, this may be less memory - efficient if there are modules before this one that cache the input - for their backprop (e.g. Balancer or Whiten). + This merges an activation function followed by dropout and then a nn.Linear module; + it does so in a memory efficient way so that it only stores the input to the whole + module. If activation == SwooshL and dropout_shared_dim != None, this will be + equivalent to: + nn.Sequential(SwooshL(), + Dropout3(dropout_p, shared_dim=dropout_shared_dim), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=initial_scale)) + If dropout_shared_dim is None, the dropout would be equivalent to + Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout + mask is smaller. + + Args: + in_channels: number of input channels, e.g. 256 + out_channels: number of output channels, e.g. 256 + bias: if true, have a bias + activation: the activation function, for now just support SwooshL. + dropout_p: the dropout probability or schedule (happens after nonlinearity). + dropout_shared_dim: the dimension, if any, across which the dropout mask is + shared (e.g. the time dimension). If None, this may be less memory + efficient if there are modules before this one that cache the input + for their backprop (e.g. Balancer or Whiten). 
""" - def __init__(self, - in_channels: int, - out_channels: int, - bias: bool = True, - activation: str = 'SwooshL', - dropout_p: FloatLike = 0.0, - dropout_shared_dim: Optional[int] = -1, - initial_scale: float = 1.0): + + def __init__( + self, + in_channels: int, + out_channels: int, + bias: bool = True, + activation: str = "SwooshL", + dropout_p: FloatLike = 0.0, + dropout_shared_dim: Optional[int] = -1, + initial_scale: float = 1.0, + ): super().__init__() # create a temporary module of nn.Linear that we'll steal the # weights and bias from - l = ScaledLinear(in_channels, out_channels, - bias=bias, - initial_scale=initial_scale) + l = ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=initial_scale + ) self.weight = l.weight # register_parameter properly handles making it a parameter when l.bias # is None. I think there is some reason for doing it this way rather # than just setting it to None but I don't know what it is, maybe # something to do with exporting the module.. - self.register_parameter('bias', l.bias) + self.register_parameter("bias", l.bias) self.activation = activation self.dropout_p = dropout_p self.dropout_shared_dim = dropout_shared_dim - def forward(self, - x: Tensor): + def forward(self, x: Tensor): if torch.jit.is_scripting() or torch.jit.is_tracing(): - if self.activation == 'SwooshL': + if self.activation == "SwooshL": x = SwooshLForward(x) elif self.activation == "SwooshR": x = SwooshRForward(x) else: assert False, self.activation - return torch.nn.functional.linear(x, - self.weight, - self.bias) + return torch.nn.functional.linear(x, self.weight, self.bias) return ActivationDropoutAndLinearFunction.apply( - x, self.weight, self.bias, self.activation, - float(self.dropout_p), self.dropout_shared_dim) + x, + self.weight, + self.bias, + self.activation, + float(self.dropout_p), + self.dropout_shared_dim, + ) def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: @@ -1612,10 +1672,9 @@ def _test_whiten(): x.requires_grad = True - m = Whiten(1, # num_groups - 5.0, # whitening_limit, - prob=1.0, - grad_scale=0.1) # grad_scale + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale for _ in range(4): y = m(x) @@ -1656,9 +1715,7 @@ def _test_balancer_sign(): def _test_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = Balancer( @@ -1685,7 +1742,7 @@ def _test_double_swish_deriv(): x.requires_grad = True m = DoubleSwish() - tol = ((1.2-(-0.043637))/255.0) + tol = (1.2 - (-0.043637)) / 255.0 torch.autograd.gradcheck(m, x, atol=tol) # for self-test. @@ -1699,7 +1756,7 @@ def _test_swooshl_deriv(): x.requires_grad = True m = SwooshL() - tol = (1.0 / 255.0) + tol = 1.0 / 255.0 torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) # for self-test. @@ -1713,7 +1770,7 @@ def _test_swooshr_deriv(): x.requires_grad = True m = SwooshR() - tol = (1.0 / 255.0) + tol = 1.0 / 255.0 torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) # for self-test. 
@@ -1727,24 +1784,24 @@ def _test_softmax(): b = a.clone() a.requires_grad = True b.requires_grad = True - a.softmax(dim=1)[:,0].sum().backward() + a.softmax(dim=1)[:, 0].sum().backward() print("a grad = ", a.grad) - softmax(b, dim=1)[:,0].sum().backward() + softmax(b, dim=1)[:, 0].sum().backward() print("b grad = ", b.grad) assert torch.allclose(a.grad, b.grad) def _test_piecewise_linear(): - p = PiecewiseLinear( (0, 10.0) ) + p = PiecewiseLinear((0, 10.0)) for x in [-100, 0, 100]: assert p(x) == 10.0 - p = PiecewiseLinear( (0, 10.0), (1, 0.0) ) - for x, y in [ (-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0) ]: + p = PiecewiseLinear((0, 10.0), (1, 0.0)) + for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]: print("x, y = ", x, y) assert p(x) == y, (x, p(x), y) q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) - x_vals = [ -1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0 ] + x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0] pq = p.max(q) for x in x_vals: y1 = max(p(x), q(x)) @@ -1757,7 +1814,7 @@ def _test_piecewise_linear(): assert abs(y1 - y2) < 0.001 pq = p + q for x in x_vals: - y1 = p(x) + q(x) + y1 = p(x) + q(x) y2 = pq(x) assert abs(y1 - y2) < 0.001 @@ -1772,15 +1829,22 @@ def _test_activation_dropout_and_linear(): # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() # internally, messing up the random state. for dropout_p in [0.0]: - for activation in ['SwooshL', 'SwooshR']: - m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(), - Dropout3(p=dropout_p, shared_dim=-1), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=0.5)) - m2 = ActivationDropoutAndLinear(in_channels, out_channels, - bias=bias, initial_scale=0.5, - activation=activation, - dropout_p=dropout_p) + for activation in ["SwooshL", "SwooshR"]: + m1 = nn.Sequential( + SwooshL() if activation == "SwooshL" else SwooshR(), + Dropout3(p=dropout_p, shared_dim=-1), + ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=0.5 + ), + ) + m2 = ActivationDropoutAndLinear( + in_channels, + out_channels, + bias=bias, + initial_scale=0.5, + activation=activation, + dropout_p=dropout_p, + ) with torch.no_grad(): m2.weight[:] = m1[2].weight if bias: @@ -1790,9 +1854,9 @@ def _test_activation_dropout_and_linear(): x1.requires_grad = True # TEMP. - assert torch.allclose(SwooshRFunction.apply(x1), - SwooshRForward(x1), - atol=1.0e-03) + assert torch.allclose( + SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03 + ) x2 = x1.clone().detach() x2.requires_grad = True @@ -1805,21 +1869,24 @@ def _test_activation_dropout_and_linear(): y2 = m2(x2) y2.backward(gradient=y_grad) - print(f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}") + print( + f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}" + ) print("y1 = ", y1) print("y2 = ", y2) assert torch.allclose(y1, y2, atol=0.02) - assert torch.allclose(m1[2].weight.grad, m2.weight.grad, - atol=1.0e-05) + assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) if bias: - assert torch.allclose(m1[2].bias.grad, m2.bias.grad, - atol=1.0e-05) + assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) print("x1.grad = ", x1.grad) print("x2.grad = ", x2.grad) def isclose(a, b): # return true if cosine similarity is > 0.9. 
- return (a * b).sum() > 0.9 * ((a**2).sum() * (b**2).sum()).sqrt() + return (a * b).sum() > 0.9 * ( + (a**2).sum() * (b**2).sum() + ).sqrt() + # the SwooshL() implementation has a noisy gradient due to 1-byte # storage of it. assert isclose(x1.grad, x2.grad) diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py index 44ff392a3a..904caf8af1 100755 --- a/egs/librispeech/ASR/zipformer/streaming_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_decode.py @@ -282,9 +282,7 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: ) batch_states.append(cached_embed_left_pad) - processed_lens = torch.cat( - [state_list[i][-1] for i in range(batch_size)], dim=0 - ) + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) batch_states.append(processed_lens) return batch_states @@ -322,9 +320,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: for layer in range(tot_num_layers): layer_offset = layer * 6 # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk( - chunks=batch_size, dim=1 - ) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( chunks=batch_size, dim=1 @@ -355,9 +351,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: cached_conv2_list[i], ] - cached_embed_left_pad_list = batch_states[-2].chunk( - chunks=batch_size, dim=0 - ) + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) for i in range(batch_size): state_list[i].append(cached_embed_left_pad_list[i]) @@ -380,11 +374,7 @@ def streaming_forward( Returns encoder outputs, output lengths, and updated states. """ cached_embed_left_pad = states[-2] - ( - x, - x_lens, - new_cached_embed_left_pad, - ) = model.encoder_embed.streaming_forward( + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( x=features, x_lens=feature_lens, cached_left_pad=cached_embed_left_pad, @@ -404,9 +394,7 @@ def streaming_forward( new_processed_lens = processed_lens + x_lens # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat( - [processed_mask, src_key_padding_mask], dim=1 - ) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) encoder_states = states[:-2] @@ -494,9 +482,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = torch.tensor(processed_lens, device=device) processed_lens = processed_lens + encoder_out_lens @@ -517,9 +503,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = unstack_states(new_states) @@ -577,9 +561,7 @@ def decode_dataset( decode_streams = [] for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
- initial_states = get_init_states( - model=model, batch_size=1, device=device - ) + initial_states = get_init_states(model=model, batch_size=1, device=device) decode_stream = DecodeStream( params=params, cut_id=cut.id, @@ -649,9 +631,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -684,8 +664,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -718,9 +697,7 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" assert params.causal, params.causal - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." @@ -760,9 +737,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -789,9 +766,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 6532ddccb6..d16d87bac9 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -107,9 +107,7 @@ def forward(self, x: Tensor) -> Tensor: if layerdrop_rate != 0.0: batch_size = x.shape[0] mask = ( - torch.rand( - (batch_size, 1, 1, 1), dtype=x.dtype, device=x.device - ) + torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate ) else: @@ -278,9 +276,7 @@ def __init__( # many copies of this extra gradient term. 
self.out_whiten = Whiten( num_groups=1, - whitening_limit=ScheduledFloat( - (0.0, 4.0), (20000.0, 8.0), default=4.0 - ), + whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0), prob=(0.025, 0.25), grad_scale=0.02, ) @@ -331,7 +327,7 @@ def forward( with warnings.catch_warnings(): warnings.simplefilter("ignore") x_lens = (x_lens - 7) // 2 - assert x.size(1) == x_lens.max().item() , (x.size(1), x_lens.max()) + assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max()) return x, x_lens @@ -403,8 +399,8 @@ def get_init_states( left_pad = self.convnext.padding[0] freq = self.out_width channels = self.layer3_channels - cached_embed_left_pad = torch.zeros( - batch_size, channels, left_pad, freq - ).to(device) + cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to( + device + ) return cached_embed_left_pad diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index bc3e9c1bae..7009f33466 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -604,11 +604,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: def get_model(params: AttributeDict) -> nn.Module: - assert ( - params.use_transducer or params.use_ctc - ), (f"At least one of them should be True, " + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}") + f"params.use_ctc={params.use_ctc}" + ) encoder_embed = get_encoder_embed(params) encoder = get_encoder_model(params) @@ -808,17 +808,16 @@ def compute_loss( # take down the scale on the simple loss from 1.0 at the start # to params.simple_loss scale by warm_step. simple_loss_scale = ( - s if batch_idx_train >= warm_step + s + if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) ) pruned_loss_scale = ( - 1.0 if batch_idx_train >= warm_step + 1.0 + if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step) ) - loss += ( - simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss - ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss @@ -1166,7 +1165,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py index c1b3ea3e0a..4b50acddef 100755 --- a/egs/librispeech/ASR/zipformer_mmi/train.py +++ b/egs/librispeech/ASR/zipformer_mmi/train.py @@ -981,7 +981,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py index a687027767..48468cfbdb 100755 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py @@ -746,7 +746,6 @@ def train_one_epoch( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): - if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -966,7 +965,7 @@ def 
run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1019,7 +1018,6 @@ def remove_short_and_long_text(c: Cut): scaler.load_state_dict(checkpoints["grad_scaler"]) for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) fix_random_seed(params.seed + epoch - 1) train_dl.sampler.set_epoch(epoch - 1) @@ -1118,7 +1116,6 @@ def scan_pessimistic_batches_for_oom( # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _, _ = compute_loss( params=params, model=model, diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py index 4f2d728bef..c1bbd2ee83 100755 --- a/egs/multi_zh-hans/ASR/zipformer/train.py +++ b/egs/multi_zh-hans/ASR/zipformer/train.py @@ -1164,7 +1164,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py index 417515968c..d039702659 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py @@ -915,7 +915,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py index d80e0147c8..aee3972cd5 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -69,7 +69,7 @@ from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer -from icefall import diagnostics, byte_encode, tokenize_by_CJK_char +from icefall import byte_encode, diagnostics, tokenize_by_CJK_char from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -1018,7 +1018,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tedlium3/ASR/conformer_ctc2/train.py b/egs/tedlium3/ASR/conformer_ctc2/train.py index 42e4c010af..fc3e3b2d92 100755 --- a/egs/tedlium3/ASR/conformer_ctc2/train.py +++ b/egs/tedlium3/ASR/conformer_ctc2/train.py @@ -905,7 +905,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tedlium3/ASR/zipformer/train.py b/egs/tedlium3/ASR/zipformer/train.py index 9271c84384..33d03908c3 100755 --- a/egs/tedlium3/ASR/zipformer/train.py +++ b/egs/tedlium3/ASR/zipformer/train.py @@ -1126,7 +1126,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) 
# allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py index e703100a96..82bc882bdd 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py @@ -886,7 +886,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py index 48b347b64a..49977e01b5 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py @@ -851,7 +851,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py index 8e1b12dba9..931e699d92 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py @@ -985,7 +985,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/zipformer/train.py b/egs/wenetspeech/ASR/zipformer/train.py index 83dbfa22fb..b1557dedbe 100755 --- a/egs/wenetspeech/ASR/zipformer/train.py +++ b/egs/wenetspeech/ASR/zipformer/train.py @@ -1128,7 +1128,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py index 5b5ac17be8..a6fa46b171 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py @@ -1001,7 +1001,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py index f8dd7b2876..8c53972fd2 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py @@ -993,7 +993,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts)
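As a reading aid for the recurring diagnostics hunks above: every recipe now constructs the options object with 512 where it previously passed 2**22. The sketch below shows the surrounding usage pattern under stated assumptions; only the `TensorDiagnosticOptions(512)` and `attach_diagnostics(model, opts)` calls are taken from the hunks in this patch, while the stand-in `nn.Linear` model and the closing `print_diagnostics()` call are illustrative assumptions about how a recipe's `train.py` consumes the attached diagnostics, not part of the patch itself.

    import torch
    import torch.nn as nn

    from icefall import diagnostics

    # Hypothetical stand-in for a recipe's model; any nn.Module would do here.
    model = nn.Linear(80, 512)

    # Value now used across the recipes in this patch (previously 2**22).
    opts = diagnostics.TensorDiagnosticOptions(512)
    diagnostic = diagnostics.attach_diagnostics(model, opts)

    # Run a few forward/backward passes so the attached hooks can
    # accumulate statistics.
    for _ in range(5):
        x = torch.randn(16, 80)
        model(x).sum().backward()

    # Assumed to mirror how the recipes report the collected statistics
    # at the end of a diagnostics run.
    diagnostic.print_diagnostics()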