Skip to content

Commit

Permalink
This should fix the inference
Browse files Browse the repository at this point in the history
  • Loading branch information
RandomDefaultUser committed Apr 30, 2024
1 parent a9027a7 commit af1081e
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
16 changes: 12 additions & 4 deletions mala/common/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1543,7 +1543,9 @@ def optuna_singlenode_setup(self, wait_time=0):
self.hyperparameters._update_device(device_temp)

@classmethod
-def load_from_file(cls, file, save_format="json", no_snapshots=False):
+def load_from_file(
+    cls, file, save_format="json", no_snapshots=False, force_no_ddp=False
+):
"""
Load a Parameters object from a file.
Expand Down Expand Up @@ -1598,7 +1600,10 @@ def load_from_file(cls, file, save_format="json", no_snapshots=False):
not isinstance(json_dict[key], dict)
or key == "openpmd_configuration"
):
-setattr(loaded_parameters, key, json_dict[key])
+if key == "use_ddp" and force_no_ddp is True:
+    setattr(loaded_parameters, key, False)
+else:
+    setattr(loaded_parameters, key, json_dict[key])
if no_snapshots is True:
loaded_parameters.data.snapshot_directories_list = []
else:
Expand Down Expand Up @@ -1631,7 +1636,7 @@ def load_from_pickle(cls, file, no_snapshots=False):
)

@classmethod
-def load_from_json(cls, file, no_snapshots=False):
+def load_from_json(cls, file, no_snapshots=False, force_no_ddp=False):
"""
Load a Parameters object from a json file.
Expand All @@ -1651,5 +1656,8 @@ def load_from_json(cls, file, no_snapshots=False):
"""
return Parameters.load_from_file(
-file, save_format="json", no_snapshots=no_snapshots
+file,
+save_format="json",
+no_snapshots=no_snapshots,
+force_no_ddp=force_no_ddp,
)
10 changes: 9 additions & 1 deletion mala/network/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
import torch.distributed as dist

+import mala
from mala.common.parallelizer import get_rank
from mala.common.parameters import ParametersRunning
from mala.network.network import Network
Expand Down Expand Up @@ -145,6 +146,7 @@ def load_run(
prepare_data=False,
load_with_mpi=None,
load_with_gpu=None,
+load_with_ddp=None,
):
"""
Load a run.
Expand Down Expand Up @@ -231,7 +233,13 @@ def load_run(
path, run_name + ".params." + params_format
)

-loaded_params = Parameters.load_from_json(loaded_params)
+# Neither Predictor nor Runner classes can work with DDP.
+if cls is mala.Trainer:
+    loaded_params = Parameters.load_from_json(loaded_params)
+else:
+    loaded_params = Parameters.load_from_json(
+        loaded_params, force_no_ddp=True
+    )

# MPI has to be specified upon loading, in contrast to GPU.
if load_with_mpi is not None:
Expand Down
4 changes: 4 additions & 0 deletions mala/network/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def load_run(
params_format="json",
load_runner=True,
prepare_data=True,
+load_with_ddp=None,
):
"""
Load a run.
Expand Down Expand Up @@ -205,6 +206,9 @@ def load_run(
params_format=params_format,
load_runner=load_runner,
prepare_data=prepare_data,
+load_with_gpu=None,
+load_with_mpi=None,
+load_with_ddp=load_with_ddp,
)

@classmethod
Expand Down

0 comments on commit af1081e

Please sign in to comment.