Fix bugs introduced by previous commits
Along with reformatting to pass the black lint check.

- egs/libritts/ASR/zipformer/train.py
- egs/libritts/CODEC/encodec/encodec.py
- egs/ljspeech/TTS/vits/vits.py
- egs/wenetspeech4tts/TTS/valle/train.py
Li Peng committed Dec 3, 2024
1 parent 30ba83a commit b29ab59
Showing 4 changed files with 9 additions and 7 deletions.
8 changes: 4 additions & 4 deletions egs/libritts/ASR/zipformer/train.py
@@ -1049,8 +1049,8 @@ def save_bad_model(suffix: str = ""):
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.amp.autocast("cuda",
-                enabled=params.use_autocast, dtype=params.dtype
+            with torch.amp.autocast(
+                "cuda", enabled=params.use_autocast, dtype=params.dtype
             ):
                 loss, loss_info = compute_loss(
                     params=params,
@@ -1478,8 +1478,8 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.amp.autocast("cuda",
-                enabled=params.use_autocast, dtype=params.dtype
+            with torch.amp.autocast(
+                "cuda", enabled=params.use_autocast, dtype=params.dtype
             ):
                 loss, _ = compute_loss(
                     params=params,
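For reference, the call style both hunks converge on is the current torch.amp.autocast API, which takes the device type as its first positional argument; the deprecated torch.cuda.amp.autocast does not accept one. A minimal sketch, not taken from the repository; the tiny model and inputs are placeholders and a CUDA device is assumed:

import torch
import torch.nn as nn

model = nn.Linear(10, 10).cuda()                 # placeholder model
inputs = torch.randn(4, 10, device="cuda")       # placeholder batch

# device_type comes first; enabled and dtype are keyword arguments
with torch.amp.autocast("cuda", enabled=True, dtype=torch.float16):
    output = model(inputs)                       # forward pass runs under mixed precision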
2 changes: 1 addition & 1 deletion egs/libritts/CODEC/encodec/encodec.py
@@ -29,7 +29,7 @@
     WavReconstructionLoss,
 )
 from torch import nn
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 
 
 class Encodec(nn.Module):
2 changes: 1 addition & 1 deletion egs/ljspeech/TTS/vits/vits.py
@@ -25,7 +25,7 @@
     KLDivergenceLoss,
     MelSpectrogramLoss,
 )
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 from utils import get_segments
 
 AVAILABLE_GENERATERS = {
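The import change in encodec.py and vits.py matters because torch.cuda.amp.autocast takes no device-type argument, while torch.amp.autocast expects it first; presumably the call sites touched by the earlier commits already pass "cuda" positionally, which is the bug this fixes. A minimal sketch under that assumption; the tensors and L1 loss are stand-ins for the recipes' actual losses:

import torch
from torch.amp import autocast

pred_wav = torch.randn(1, 16000, device="cuda")    # placeholder tensors,
target_wav = torch.randn(1, 16000, device="cuda")  # assume a CUDA device

with autocast("cuda", enabled=True, dtype=torch.float16):
    loss = torch.nn.functional.l1_loss(pred_wav, target_wav)  # stand-in loss call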
4 changes: 3 additions & 1 deletion egs/wenetspeech4tts/TTS/valle/train.py
@@ -1103,7 +1103,9 @@ def run(rank, world_size, args):
             params=params,
         )
 
-    scaler = GradScaler("cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0)
+    scaler = GradScaler(
+        "cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0
+    )
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
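The reformatted construction uses torch.amp.GradScaler, which, like autocast, takes the device as its first argument. A minimal sketch of how a scaler built this way is typically driven in a training loop; the model, optimizer, and loss below are placeholders, not code from train.py, and a CUDA device is assumed:

import torch
from torch.amp import GradScaler

model = torch.nn.Linear(10, 1).cuda()                    # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # placeholder optimizer
scaler = GradScaler("cuda", enabled=True, init_scale=1.0)

x = torch.randn(8, 10, device="cuda")
with torch.amp.autocast("cuda", dtype=torch.float16):
    loss = model(x).mean()                               # placeholder loss

optimizer.zero_grad()
scaler.scale(loss).backward()   # scale the loss before backpropagation
scaler.step(optimizer)          # unscales gradients, skips the step if they are non-finite
scaler.update()                 # adjust the loss scale for the next iteration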
