fixes for init value of diagnostics.TensorDiagnosticOptions (#1269)
* fixes for `diagnostics`

Replace `2 ** 22` with `512` as the init value passed to `diagnostics.TensorDiagnosticOptions`.

Also reformatted some scripts with `black`.

* fixed formatting issues
JinZr authored Sep 24, 2023
1 parent 34e40a8 commit ef658d6
Showing 51 changed files with 513 additions and 481 deletions.
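
The substantive change is the value passed at every diagnostics call site. The sketch below shows that pattern in isolation; it is a minimal example assuming the `from icefall import diagnostics` import used by the icefall train.py scripts, and the `maybe_attach_diagnostics` wrapper and `params` object are hypothetical, introduced only to show where the new value goes.

from icefall import diagnostics  # import as used in the icefall training scripts


def maybe_attach_diagnostics(model, params):
    # Hypothetical helper mirroring the call-site pattern updated by this commit.
    if not params.print_diagnostics:
        return None
    # Before this commit the scripts passed 2**22 here; they now pass 512.
    opts = diagnostics.TensorDiagnosticOptions(512)
    # attach_diagnostics registers hooks on the model so tensor statistics
    # can be collected during the diagnostic run.
    return diagnostics.attach_diagnostics(model, opts)
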
(diff for a changed file; file name not shown)
@@ -635,7 +635,6 @@ def train_one_epoch(
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(train_dl):
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
@@ -800,7 +799,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
2 changes: 1 addition & 1 deletion egs/aishell/ASR/pruned_transducer_stateless2/train.py
@@ -872,7 +872,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
2 changes: 1 addition & 1 deletion egs/aishell/ASR/pruned_transducer_stateless3/train.py
@@ -1045,7 +1045,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
2 changes: 1 addition & 1 deletion egs/aishell/ASR/pruned_transducer_stateless7/train.py
@@ -1028,7 +1028,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
2 changes: 1 addition & 1 deletion egs/aishell/ASR/pruned_transducer_stateless7/train2.py
@@ -1031,7 +1031,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
2 changes: 1 addition & 1 deletion egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py
@@ -1019,7 +1019,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
3 changes: 1 addition & 2 deletions egs/aishell2/ASR/pruned_transducer_stateless5/train.py
@@ -730,7 +730,6 @@ def train_one_epoch(
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(train_dl):
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
@@ -919,7 +918,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
2 changes: 1 addition & 1 deletion egs/aishell4/ASR/pruned_transducer_stateless5/train.py
@@ -908,7 +908,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
3 changes: 1 addition & 2 deletions egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
@@ -635,7 +635,6 @@ def train_one_epoch(
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(train_dl):
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
@@ -800,7 +799,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
(diff for a changed file; file name not shown)
@@ -999,7 +999,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
2 changes: 1 addition & 1 deletion egs/ami/ASR/pruned_transducer_stateless7/train.py
@@ -988,7 +988,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
2 changes: 1 addition & 1 deletion egs/commonvoice/ASR/pruned_transducer_stateless7/train.py
@@ -1019,7 +1019,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
(diff for a changed file; file name not shown)
@@ -1074,7 +1074,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
(diff for a changed file; file name not shown)
@@ -1075,7 +1075,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
(diff for a changed file; file name not shown)
@@ -953,7 +953,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
(diff for a changed file; file name not shown)
@@ -953,7 +953,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
(diff for a changed file; file name not shown)
@@ -955,7 +955,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/pruned2_knowledge/train.py
@@ -811,7 +811,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/pruned_transducer_stateless5/train.py
@@ -1003,7 +1003,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
(diff for a changed file; file name not shown)
@@ -1132,7 +1132,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
12 changes: 4 additions & 8 deletions egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -117,7 +117,7 @@ def batched_params(self, param_group, group_params_names):
 
         yield tuples  # <-- calling code will do the actual optimization here!
 
-        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
+        for (stacked_params, _state, _names), batch in zip(tuples, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])
 
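
The loop changes in this file are purely stylistic: black drops the redundant parentheses around the unpacking target, which does not change what gets bound. A tiny standalone check with made-up data (not the real optimizer state) illustrates the equivalence.

tuples = [("stacked_a", {"step": 0}, ["a1", "a2"]), ("stacked_b", {"step": 3}, ["b1"])]
batches = [["p_a1", "p_a2"], ["p_b1"]]

old_style = []
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
    old_style.append((stacked_params, batch))

new_style = []
for (stacked_params, _state, _names), batch in zip(tuples, batches):
    new_style.append((stacked_params, batch))

assert old_style == new_style  # identical bindings; only the formatting differs
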
@@ -181,7 +181,6 @@ def __init__(
         parameters_names=None,
         show_dominant_parameters=True,
     ):
-
         assert parameters_names is not None, (
             "Please prepare parameters_names,"
             "which is a List[List[str]]. Each List[str] is for a group"
@@ -224,9 +223,7 @@ def step(self, closure=None):
         batch = True
 
         for group, group_params_names in zip(self.param_groups, self.parameters_names):
-
             with self.batched_params(group["params"], group_params_names) as batches:
-
                 # batches is list of pairs (stacked_param, state). stacked_param is like
                 # a regular parameter, and will have a .grad, but the 1st dim corresponds to
                 # a stacking dim, it is not a real dim.
@@ -325,7 +322,7 @@ def _get_clipping_scale(
         clipping_update_period = group["clipping_update_period"]
 
         tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for (p, state, param_names) in tuples:
+        for p, state, param_names in tuples:
             grad = p.grad
             if grad.is_sparse:
                 raise RuntimeError(
@@ -410,7 +407,7 @@ def _show_gradient_dominating_parameter(
         from tuples, we still pass it to save some time.
         """
         all_sumsq_orig = {}
-        for (p, state, batch_param_names) in tuples:
+        for p, state, batch_param_names in tuples:
             # p is a stacked batch parameters.
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
@@ -426,7 +423,6 @@ def _show_gradient_dominating_parameter(
             for name, sumsq_orig, rms, grad in zip(
                 batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
             ):
-
                 proportion_orig = sumsq_orig / tot_sumsq
                 all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
 
@@ -1039,7 +1035,7 @@ def _test_scaled_adam(hidden_dim: int):
 
     # if epoch == 130:
     #     opts = diagnostics.TensorDiagnosticOptions(
-    #         2 ** 22
+    #         512
     #     )  # allow 4 megabytes per sub-module
     #     diagnostic = diagnostics.attach_diagnostics(m, opts)
 
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -1028,7 +1028,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
(diff for a changed file; file name not shown)
@@ -1052,7 +1052,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
(diff for a changed file; file name not shown)
@@ -1042,7 +1042,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
(diff for a changed file; file name not shown)
@@ -1029,7 +1029,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
(diff for a changed file; file name not shown)
@@ -1030,7 +1030,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
(diff for a changed file; file name not shown)
@@ -1141,7 +1141,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/pruned_transducer_stateless8/train.py
@@ -1154,7 +1154,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
4 changes: 3 additions & 1 deletion egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
@@ -230,7 +230,9 @@ def train_run_encoder(
                 x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
             )  # (T, B, F)
         else:
-            x = self.encoder(x, pos_emb, src_key_padding_mask=src_key_padding_mask)  # (T, B, F)
+            x = self.encoder(
+                x, pos_emb, src_key_padding_mask=src_key_padding_mask
+            )  # (T, B, F)
 
         if self.normalize_before:
             x = self.after_norm(x)
30 changes: 19 additions & 11 deletions egs/librispeech/ASR/zipformer/decoder.py
@@ -61,10 +61,15 @@ def __init__(
         )
         # the balancers are to avoid any drift in the magnitude of the
         # embeddings, which would interact badly with parameter averaging.
-        self.balancer = Balancer(decoder_dim, channel_dim=-1,
-                                 min_positive=0.0, max_positive=1.0,
-                                 min_abs=0.5, max_abs=1.0,
-                                 prob=0.05)
+        self.balancer = Balancer(
+            decoder_dim,
+            channel_dim=-1,
+            min_positive=0.0,
+            max_positive=1.0,
+            min_abs=0.5,
+            max_abs=1.0,
+            prob=0.05,
+        )
 
         self.blank_id = blank_id
 
@@ -81,10 +86,15 @@ def __init__(
             groups=decoder_dim // 4,  # group size == 4
             bias=False,
         )
-        self.balancer2 = Balancer(decoder_dim, channel_dim=-1,
-                                  min_positive=0.0, max_positive=1.0,
-                                  min_abs=0.5, max_abs=1.0,
-                                  prob=0.05)
+        self.balancer2 = Balancer(
+            decoder_dim,
+            channel_dim=-1,
+            min_positive=0.0,
+            max_positive=1.0,
+            min_abs=0.5,
+            max_abs=1.0,
+            prob=0.05,
+        )
 
     def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         """
@@ -107,9 +117,7 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
(diffs for the remaining changed files are not shown)
