addressing torch.amp.GradScaler FutureWarnings
michal-g committed Dec 17, 2024
1 parent 9989827 commit 25f1566
Showing 4 changed files with 16 additions and 4 deletions.
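All four files receive the same fix. Recent PyTorch releases (roughly 2.4 onward; the exact version boundary is stated here as an assumption) deprecate the CUDA-specific torch.cuda.amp.GradScaler in favor of the device-agnostic torch.amp.GradScaler, so constructing the old class now emits a FutureWarning. The commit tries the new constructor first and falls back on older PyTorch versions, where torch.amp has no GradScaler attribute. A minimal standalone sketch of the pattern, requiring only that torch is installed:

import torch

try:
    # Newer PyTorch: a unified GradScaler that takes the device type as an argument.
    scaler = torch.amp.GradScaler("cuda")
except AttributeError:
    # Older releases expose only the CUDA-specific class.
    scaler = torch.cuda.amp.grad_scaler.GradScaler()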
5 changes: 4 additions & 1 deletion cryodrgn/commands/abinit_het.py
@@ -910,7 +910,10 @@ def main(args):
             model, optim = amp.initialize(model, optim, opt_level="O1")
         # mixed precision with pytorch (v1.6+)
         except:  # noqa: E722
-            scaler = torch.cuda.amp.grad_scaler.GradScaler()
+            try:
+                scaler = torch.amp.GradScaler("cuda")
+            except AttributeError:
+                scaler = torch.cuda.amp.grad_scaler.GradScaler()

     if args.load == "latest":
         args = get_latest(args)
5 changes: 4 additions & 1 deletion cryodrgn/commands/abinit_homo.py
@@ -679,7 +679,10 @@ def main(args):
             model, optim = amp.initialize(model, optim, opt_level="O1")
         # mixed precision with pytorch (v1.6+)
         except:  # noqa: E722
-            scaler = torch.cuda.amp.grad_scaler.GradScaler()
+            try:
+                scaler = torch.amp.GradScaler("cuda")
+            except AttributeError:
+                scaler = torch.cuda.amp.grad_scaler.GradScaler()

     sorted_poses = []
     if args.load:
5 changes: 4 additions & 1 deletion cryodrgn/commands/train_nn.py
@@ -513,7 +513,10 @@ def main(args: argparse.Namespace) -> None:
             model, optim = amp.initialize(model, optim, opt_level="O1")
         # Mixed precision with pytorch (v1.6+)
         except:  # noqa: E722
-            scaler = torch.cuda.amp.grad_scaler.GradScaler()
+            try:
+                scaler = torch.amp.GradScaler("cuda")
+            except AttributeError:
+                scaler = torch.cuda.amp.grad_scaler.GradScaler()

     # parallelize
     if args.multigpu and torch.cuda.device_count() > 1:
5 changes: 4 additions & 1 deletion cryodrgn/commands/train_vae.py
@@ -864,7 +864,10 @@ def main(args: argparse.Namespace) -> None:
             model, optim = amp.initialize(model, optim, opt_level="O1")
         # mixed precision with pytorch (v1.6+)
         except:  # noqa: E722
-            scaler = torch.cuda.amp.grad_scaler.GradScaler()
+            try:
+                scaler = torch.amp.GradScaler("cuda")
+            except AttributeError:
+                scaler = torch.cuda.amp.grad_scaler.GradScaler()

     # restart from checkpoint
     if args.load:
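For context, either constructor returns an object with the same training-loop interface, so the surrounding code in these commands is unchanged. A hedged sketch of how such a scaler is typically driven in a single mixed-precision step; the toy model, optimizer, and data below are hypothetical, not taken from cryodrgn:

import torch

model = torch.nn.Linear(8, 1).cuda()          # hypothetical toy model
optim = torch.optim.Adam(model.parameters())  # hypothetical optimizer
x = torch.randn(4, 8, device="cuda")
y = torch.randn(4, 1, device="cuda")

scaler = torch.amp.GradScaler("cuda")
with torch.autocast("cuda"):  # run the forward pass in mixed precision
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
scaler.step(optim)             # unscale gradients; skip the step if they overflowed
scaler.update()                # adapt the scale factor for the next iteration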
