From af1081ead5f7c8e5ec2b66cfc676bc2dd3d617bd Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Tue, 30 Apr 2024 15:21:24 +0200 Subject: [PATCH] This should fix the inference --- mala/common/parameters.py | 16 ++++++++++++---- mala/network/runner.py | 10 +++++++++- mala/network/trainer.py | 4 ++++ 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/mala/common/parameters.py b/mala/common/parameters.py index 65523d048..6a8baec76 100644 --- a/mala/common/parameters.py +++ b/mala/common/parameters.py @@ -1543,7 +1543,9 @@ def optuna_singlenode_setup(self, wait_time=0): self.hyperparameters._update_device(device_temp) @classmethod - def load_from_file(cls, file, save_format="json", no_snapshots=False): + def load_from_file( + cls, file, save_format="json", no_snapshots=False, force_no_ddp=False + ): """ Load a Parameters object from a file. @@ -1598,7 +1600,10 @@ def load_from_file(cls, file, save_format="json", no_snapshots=False): not isinstance(json_dict[key], dict) or key == "openpmd_configuration" ): - setattr(loaded_parameters, key, json_dict[key]) + if key == "use_ddp" and force_no_ddp is True: + setattr(loaded_parameters, key, False) + else: + setattr(loaded_parameters, key, json_dict[key]) if no_snapshots is True: loaded_parameters.data.snapshot_directories_list = [] else: @@ -1631,7 +1636,7 @@ def load_from_pickle(cls, file, no_snapshots=False): ) @classmethod - def load_from_json(cls, file, no_snapshots=False): + def load_from_json(cls, file, no_snapshots=False, force_no_ddp=False): """ Load a Parameters object from a json file. @@ -1651,5 +1656,8 @@ def load_from_json(cls, file, no_snapshots=False): """ return Parameters.load_from_file( - file, save_format="json", no_snapshots=no_snapshots + file, + save_format="json", + no_snapshots=no_snapshots, + force_no_ddp=force_no_ddp, ) diff --git a/mala/network/runner.py b/mala/network/runner.py index 896e8b720..5e6ecdafa 100644 --- a/mala/network/runner.py +++ b/mala/network/runner.py @@ -7,6 +7,7 @@ import torch import torch.distributed as dist +import mala from mala.common.parallelizer import get_rank from mala.common.parameters import ParametersRunning from mala.network.network import Network @@ -145,6 +146,7 @@ def load_run( prepare_data=False, load_with_mpi=None, load_with_gpu=None, + load_with_ddp=None, ): """ Load a run. @@ -231,7 +233,13 @@ def load_run( path, run_name + ".params." + params_format ) - loaded_params = Parameters.load_from_json(loaded_params) + # Neither Predictor nor Runner classes can work with DDP. + if cls is mala.Trainer: + loaded_params = Parameters.load_from_json(loaded_params) + else: + loaded_params = Parameters.load_from_json( + loaded_params, force_no_ddp=True + ) # MPI has to be specified upon loading, in contrast to GPU. if load_with_mpi is not None: diff --git a/mala/network/trainer.py b/mala/network/trainer.py index bb9d4d41b..430a0cf47 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -156,6 +156,7 @@ def load_run( params_format="json", load_runner=True, prepare_data=True, + load_with_ddp=None, ): """ Load a run. @@ -205,6 +206,9 @@ def load_run( params_format=params_format, load_runner=load_runner, prepare_data=prepare_data, + load_with_gpu=None, + load_with_mpi=None, + load_with_ddp=load_with_ddp, ) @classmethod