From b58c096c323237e8d28803ece9cfa06e2fc3de17 Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Fri, 3 May 2024 11:28:47 +0200 Subject: [PATCH] Small bugfix to fix CI --- mala/network/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mala/network/trainer.py b/mala/network/trainer.py index f8bf391f5..81977c40e 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -240,7 +240,9 @@ def _load_from_run(cls, params, network, data, file=None): # First, load the checkpoint. if params.use_ddp: map_location = {"cuda:%d" % 0: "cuda:%d" % get_local_rank()} - checkpoint = torch.load(file, map_location=map_location) + checkpoint = torch.load(file, map_location=map_location) + else: + checkpoint = torch.load(file) # Now, create the Trainer class with it. loaded_trainer = Trainer(