Skip to content

Commit

Permalink
Small bugfix to fix CI
Browse files Browse the repository at this point in the history
  • Loading branch information
RandomDefaultUser committed May 3, 2024
1 parent 873e486 commit b58c096
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion mala/network/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,9 @@ def _load_from_run(cls, params, network, data, file=None):
# First, load the checkpoint.
if params.use_ddp:
map_location = {"cuda:%d" % 0: "cuda:%d" % get_local_rank()}
checkpoint = torch.load(file, map_location=map_location)
checkpoint = torch.load(file, map_location=map_location)
else:
checkpoint = torch.load(file)

# Now, create the Trainer class with it.
loaded_trainer = Trainer(
Expand Down

0 comments on commit b58c096

Please sign in to comment.