From 25f1566f1d7f2de81d6f1542db6d1e77b619dcdf Mon Sep 17 00:00:00 2001
From: Michal Grzadkowski
Date: Mon, 16 Dec 2024 21:42:21 -0500
Subject: [PATCH] addressing torch.amp.GradScaler FutureWarnings

---
 cryodrgn/commands/abinit_het.py  | 5 ++++-
 cryodrgn/commands/abinit_homo.py | 5 ++++-
 cryodrgn/commands/train_nn.py    | 5 ++++-
 cryodrgn/commands/train_vae.py   | 5 ++++-
 4 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/cryodrgn/commands/abinit_het.py b/cryodrgn/commands/abinit_het.py
index 885b0473..f165d550 100644
--- a/cryodrgn/commands/abinit_het.py
+++ b/cryodrgn/commands/abinit_het.py
@@ -910,7 +910,10 @@ def main(args):
             model, optim = amp.initialize(model, optim, opt_level="O1")
         # mixed precision with pytorch (v1.6+)
         except:  # noqa: E722
-            scaler = torch.cuda.amp.grad_scaler.GradScaler()
+            try:
+                scaler = torch.amp.GradScaler("cuda")
+            except AttributeError:
+                scaler = torch.cuda.amp.grad_scaler.GradScaler()
 
     if args.load == "latest":
         args = get_latest(args)
diff --git a/cryodrgn/commands/abinit_homo.py b/cryodrgn/commands/abinit_homo.py
index 404afd75..efcf32ec 100644
--- a/cryodrgn/commands/abinit_homo.py
+++ b/cryodrgn/commands/abinit_homo.py
@@ -679,7 +679,10 @@ def main(args):
             model, optim = amp.initialize(model, optim, opt_level="O1")
         # mixed precision with pytorch (v1.6+)
         except:  # noqa: E722
-            scaler = torch.cuda.amp.grad_scaler.GradScaler()
+            try:
+                scaler = torch.amp.GradScaler("cuda")
+            except AttributeError:
+                scaler = torch.cuda.amp.grad_scaler.GradScaler()
 
     sorted_poses = []
     if args.load:
diff --git a/cryodrgn/commands/train_nn.py b/cryodrgn/commands/train_nn.py
index a76cecf3..65b7bcee 100644
--- a/cryodrgn/commands/train_nn.py
+++ b/cryodrgn/commands/train_nn.py
@@ -513,7 +513,10 @@ def main(args: argparse.Namespace) -> None:
             model, optim = amp.initialize(model, optim, opt_level="O1")
         # Mixed precision with pytorch (v1.6+)
         except:  # noqa: E722
-            scaler = torch.cuda.amp.grad_scaler.GradScaler()
+            try:
+                scaler = torch.amp.GradScaler("cuda")
+            except AttributeError:
+                scaler = torch.cuda.amp.grad_scaler.GradScaler()
 
     # parallelize
     if args.multigpu and torch.cuda.device_count() > 1:
diff --git a/cryodrgn/commands/train_vae.py b/cryodrgn/commands/train_vae.py
index deb2ceac..75535127 100755
--- a/cryodrgn/commands/train_vae.py
+++ b/cryodrgn/commands/train_vae.py
@@ -864,7 +864,10 @@ def main(args: argparse.Namespace) -> None:
             model, optim = amp.initialize(model, optim, opt_level="O1")
         # mixed precision with pytorch (v1.6+)
         except:  # noqa: E722
-            scaler = torch.cuda.amp.grad_scaler.GradScaler()
+            try:
+                scaler = torch.amp.GradScaler("cuda")
+            except AttributeError:
+                scaler = torch.cuda.amp.grad_scaler.GradScaler()
 
     # restart from checkpoint
     if args.load:
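
All four hunks apply the same compatibility fallback: prefer the device-generic
torch.amp.GradScaler (which does not trigger the FutureWarning emitted for the
torch.cuda.amp classes on newer PyTorch), and fall back to the CUDA-specific
class when torch.amp has no GradScaler attribute on older releases. A minimal
standalone sketch of that pattern, where the make_grad_scaler helper name is
hypothetical and not defined anywhere in cryodrgn:

    import torch

    def make_grad_scaler():
        # Hypothetical helper, not part of the patch: same fallback as above.
        try:
            # Newer PyTorch exposes a device-generic GradScaler under
            # torch.amp; the first positional argument selects the device.
            return torch.amp.GradScaler("cuda")
        except AttributeError:
            # Older releases (back to the v1.6 API referenced in the hunks)
            # lack torch.amp.GradScaler, so keep the CUDA-specific class.
            return torch.cuda.amp.grad_scaler.GradScaler()

    scaler = make_grad_scaler()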