diff --git a/docs/source/advanced_usage/trainingmodel.rst b/docs/source/advanced_usage/trainingmodel.rst
index ddb429368..4413ab078 100644
--- a/docs/source/advanced_usage/trainingmodel.rst
+++ b/docs/source/advanced_usage/trainingmodel.rst
@@ -220,3 +220,68 @@ via
 The full path for ``path_to_visualization`` can be accessed via
 ``trainer.full_visualization_path``.
+
+
+Training in parallel
+********************
+
+If large models or large data sets are employed, training may be slow even
+if a GPU is used. In this case, multiple GPUs can be employed with MALA
+using the ``DistributedDataParallel`` (DDP) framework of the ``torch`` library.
+To use DDP, make sure you have `NCCL <https://developer.nvidia.com/nccl>`_
+installed on your system.
+
+To activate and use DDP in MALA, almost no modification of your training script
+is necessary. Simply activate DDP in your ``Parameters`` object. Make sure to
+also enable GPU, since parallel training is currently only supported on GPUs.
+
+    .. code-block:: python
+
+        parameters = mala.Parameters()
+        parameters.use_gpu = True
+        parameters.use_ddp = True
+
+MALA is now set up for parallel training. DDP works across multiple compute
+nodes on HPC infrastructure as well as on a single machine hosting multiple
+GPUs. While essentially no modification of the Python script itself is
+necessary, the way the script is called may have to be adapted to ensure
+that DDP has all the information it needs for inter- and intra-node
+communication. This setup *may* differ across machines/clusters. During
+testing, the following setup was confirmed to work on an HPC cluster using
+the ``slurm`` scheduler.
+
+    .. code-block:: bash
+
+        #SBATCH --nodes=NUMBER_OF_NODES
+        #SBATCH --ntasks-per-node=NUMBER_OF_TASKS_PER_NODE
+        #SBATCH --gres=gpu:NUMBER_OF_TASKS_PER_NODE
+        # Add more arguments as needed
+        ...
+
+        # Load more modules as needed
+        ...
+
+        # This port can be arbitrarily chosen.
+        # Given here is the torchrun default.
+        export MASTER_PORT=29500
+
+        # Find out the host node.
+        echo "NODELIST="${SLURM_NODELIST}
+        master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+        export MASTER_ADDR=$master_addr
+        echo "MASTER_ADDR="$MASTER_ADDR
+
+        # Run using srun.
+        srun -u bash -c '
+          # Export additional per-process variables
+          export RANK=$SLURM_PROCID
+          export LOCAL_RANK=$SLURM_LOCALID
+          export WORLD_SIZE=$SLURM_NTASKS
+
+          python3 -u training.py
+        '
+
+An overview of environment variables to be set can be found `in the official documentation <https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization>`_.
+A general tutorial on DDP itself can be found `here <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_.
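+
+For a quick test on a single machine with multiple GPUs, the same script can
+also be launched with ``torchrun`` instead of ``srun``. ``torchrun`` sets
+``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT``
+itself, so no scheduler setup is required. A minimal sketch (assuming four
+GPUs on a single node) could look like this:
+
+    .. code-block:: bash
+
+        torchrun --standalone --nproc_per_node=4 training.py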
+ + diff --git a/install/mala_gpu_base_environment.yml b/install/mala_gpu_base_environment.yml index c3e9e6c9f..340fef170 100644 --- a/install/mala_gpu_base_environment.yml +++ b/install/mala_gpu_base_environment.yml @@ -1,4 +1,6 @@ name: mala-gpu channels: - - defaults - conda-forge + - defaults +dependencies: + - python=3.10 diff --git a/mala/common/check_modules.py b/mala/common/check_modules.py index 6bb96094d..b504f213a 100644 --- a/mala/common/check_modules.py +++ b/mala/common/check_modules.py @@ -11,10 +11,6 @@ def check_modules(): "available": False, "description": "Enables inference parallelization.", }, - "horovod": { - "available": False, - "description": "Enables training parallelization.", - }, "lammps": { "available": False, "description": "Enables descriptor calculation for data preprocessing " diff --git a/mala/common/parallelizer.py b/mala/common/parallelizer.py index 1bffdfedb..160695a42 100644 --- a/mala/common/parallelizer.py +++ b/mala/common/parallelizer.py @@ -2,15 +2,13 @@ from collections import defaultdict import platform +import os import warnings -try: - import horovod.torch as hvd -except ModuleNotFoundError: - pass import torch +import torch.distributed as dist -use_horovod = False +use_ddp = False use_mpi = False comm = None local_mpi_rank = None @@ -33,45 +31,43 @@ def set_current_verbosity(new_value): current_verbosity = new_value -def set_horovod_status(new_value): +def set_ddp_status(new_value): """ - Set the horovod status. + Set the ddp status. - By setting the horovod status via this function it can be ensured that + By setting the ddp status via this function it can be ensured that printing works in parallel. The Parameters class does that for the user. Parameters ---------- new_value : bool - Value the horovod status has. + Value the ddp status has. """ if use_mpi is True and new_value is True: raise Exception( - "Cannot use horovod and inference-level MPI at " - "the same time yet." + "Cannot use ddp and inference-level MPI at " "the same time yet." ) - global use_horovod - use_horovod = new_value + global use_ddp + use_ddp = new_value def set_mpi_status(new_value): """ Set the MPI status. - By setting the horovod status via this function it can be ensured that + By setting the MPI status via this function it can be ensured that printing works in parallel. The Parameters class does that for the user. Parameters ---------- new_value : bool - Value the horovod status has. + Value the MPI status has. """ - if use_horovod is True and new_value is True: + if use_ddp is True and new_value is True: raise Exception( - "Cannot use horovod and inference-level MPI at " - "the same time yet." + "Cannot use ddp and inference-level MPI at " "the same time yet." ) global use_mpi use_mpi = new_value @@ -119,8 +115,8 @@ def get_rank(): The rank of the current thread. """ - if use_horovod: - return hvd.rank() + if use_ddp: + return dist.get_rank() if use_mpi: return comm.Get_rank() return 0 @@ -159,8 +155,8 @@ def get_local_rank(): FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - if use_horovod: - return hvd.local_rank() + if use_ddp: + return int(os.environ.get("LOCAL_RANK")) if use_mpi: global local_mpi_rank if local_mpi_rank is None: @@ -187,8 +183,8 @@ def get_size(): size : int The number of ranks. 
""" - if use_horovod: - return hvd.size() + if use_ddp: + return dist.get_world_size() if use_mpi: return comm.Get_size() @@ -209,8 +205,8 @@ def get_comm(): def barrier(): """General interface for a barrier.""" - if use_horovod: - hvd.allreduce(torch.tensor(0), name="barrier") + if use_ddp: + dist.barrier() if use_mpi: comm.Barrier() return diff --git a/mala/common/parameters.py b/mala/common/parameters.py index 20c471334..3627bd40f 100644 --- a/mala/common/parameters.py +++ b/mala/common/parameters.py @@ -7,19 +7,13 @@ import pickle from time import sleep -horovod_available = False -try: - import horovod.torch as hvd - - horovod_available = True -except ModuleNotFoundError: - pass import numpy as np import torch +import torch.distributed as dist from mala.common.parallelizer import ( printout, - set_horovod_status, + set_ddp_status, set_mpi_status, get_rank, get_local_rank, @@ -40,7 +34,7 @@ def __init__( super(ParametersBase, self).__init__() self._configuration = { "gpu": False, - "horovod": False, + "ddp": False, "mpi": False, "device": "cpu", "openpmd_configuration": {}, @@ -76,8 +70,8 @@ def show(self, indent=""): def _update_gpu(self, new_gpu): self._configuration["gpu"] = new_gpu - def _update_horovod(self, new_horovod): - self._configuration["horovod"] = new_horovod + def _update_ddp(self, new_ddp): + self._configuration["ddp"] = new_ddp def _update_mpi(self, new_mpi): self._configuration["mpi"] = new_mpi @@ -686,10 +680,6 @@ class ParametersRunning(ParametersBase): validation loss has to plateau before the schedule takes effect). Default: 0. - use_compression : bool - If True and horovod is used, horovod compression will be used for - allreduce communication. This can improve performance. - num_workers : int Number of workers to be used for data loading. @@ -750,7 +740,6 @@ def __init__(self): self.learning_rate_scheduler = None self.learning_rate_decay = 0.1 self.learning_rate_patience = 0 - self.use_compression = False self.num_workers = 0 self.use_shuffling_for_samplers = True self.checkpoints_each_epoch = 0 @@ -766,8 +755,8 @@ def __init__(self): self.training_report_frequency = 1000 self.profiler_range = None # [1000, 2000] - def _update_horovod(self, new_horovod): - super(ParametersRunning, self)._update_horovod(new_horovod) + def _update_ddp(self, new_ddp): + super(ParametersRunning, self)._update_ddp(new_ddp) self.during_training_metric = self.during_training_metric self.after_before_training_metric = self.after_before_training_metric @@ -789,10 +778,10 @@ def during_training_metric(self): @during_training_metric.setter def during_training_metric(self, value): if value != "ldos": - if self._configuration["horovod"]: + if self._configuration["ddp"]: raise Exception( "Currently, MALA can only operate with the " - '"ldos" metric for horovod runs.' + '"ldos" metric for ddp runs.' ) self._during_training_metric = value @@ -814,20 +803,20 @@ def after_before_training_metric(self): @after_before_training_metric.setter def after_before_training_metric(self, value): if value != "ldos": - if self._configuration["horovod"]: + if self._configuration["ddp"]: raise Exception( "Currently, MALA can only operate with the " - '"ldos" metric for horovod runs.' + '"ldos" metric for ddp runs.' 
) self._after_before_training_metric = value @during_training_metric.setter def during_training_metric(self, value): if value != "ldos": - if self._configuration["horovod"]: + if self._configuration["ddp"]: raise Exception( "Currently, MALA can only operate with the " - '"ldos" metric for horovod runs.' + '"ldos" metric for ddp runs.' ) self._during_training_metric = value @@ -1218,7 +1207,7 @@ def __init__(self): # Properties self.use_gpu = False - self.use_horovod = False + self.use_ddp = False self.use_mpi = False self.verbosity = 1 self.device = "cpu" @@ -1304,32 +1293,36 @@ def use_gpu(self, value): self.hyperparameters._update_gpu(self.use_gpu) @property - def use_horovod(self): - """Control whether or not horovod is used for parallel training.""" - return self._use_horovod - - @use_horovod.setter - def use_horovod(self, value): - if value is False: - self._use_horovod = False - else: - if horovod_available: - hvd.init() - # Invalidate, will be updated in setter. - set_horovod_status(value) - self.device = None - self._use_horovod = value - self.network._update_horovod(self.use_horovod) - self.descriptors._update_horovod(self.use_horovod) - self.targets._update_horovod(self.use_horovod) - self.data._update_horovod(self.use_horovod) - self.running._update_horovod(self.use_horovod) - self.hyperparameters._update_horovod(self.use_horovod) - else: - parallel_warn( - "Horovod requested, but not installed found. " - "MALA will operate without horovod only." - ) + def use_ddp(self): + """Control whether or not dd is used for parallel training.""" + return self._use_ddp + + @use_ddp.setter + def use_ddp(self, value): + if value: + if self.verbosity > 1: + print("Initializing torch.distributed.") + # JOSHR: + # We start up torch distributed here. As is fairly standard + # convention, we get the rank and world size arguments via + # environment variables (RANK, WORLD_SIZE). In addition to + # those variables, LOCAL_RANK, MASTER_ADDR and MASTER_PORT + # should be set. + rank = int(os.environ.get("RANK")) + world_size = int(os.environ.get("WORLD_SIZE")) + + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + set_ddp_status(value) + # Invalidate, will be updated in setter. + self.device = None + self._use_ddp = value + self.network._update_ddp(self.use_ddp) + self.descriptors._update_ddp(self.use_ddp) + self.targets._update_ddp(self.use_ddp) + self.data._update_ddp(self.use_ddp) + self.running._update_ddp(self.use_ddp) + self.hyperparameters._update_ddp(self.use_ddp) @property def device(self): @@ -1352,7 +1345,7 @@ def device(self, value): @property def use_mpi(self): - """Control whether or not horovod is used for parallel training.""" + """Control whether or not MPI is used for paralle inference.""" return self._use_mpi @use_mpi.setter @@ -1551,7 +1544,9 @@ def optuna_singlenode_setup(self, wait_time=0): self.hyperparameters._update_device(device_temp) @classmethod - def load_from_file(cls, file, save_format="json", no_snapshots=False): + def load_from_file( + cls, file, save_format="json", no_snapshots=False, force_no_ddp=False + ): """ Load a Parameters object from a file. 
@@ -1606,7 +1601,10 @@ def load_from_file(cls, file, save_format="json", no_snapshots=False): not isinstance(json_dict[key], dict) or key == "openpmd_configuration" ): - setattr(loaded_parameters, key, json_dict[key]) + if key == "use_ddp" and force_no_ddp is True: + setattr(loaded_parameters, key, False) + else: + setattr(loaded_parameters, key, json_dict[key]) if no_snapshots is True: loaded_parameters.data.snapshot_directories_list = [] else: @@ -1639,7 +1637,7 @@ def load_from_pickle(cls, file, no_snapshots=False): ) @classmethod - def load_from_json(cls, file, no_snapshots=False): + def load_from_json(cls, file, no_snapshots=False, force_no_ddp=False): """ Load a Parameters object from a json file. @@ -1659,5 +1657,8 @@ def load_from_json(cls, file, no_snapshots=False): """ return Parameters.load_from_file( - file, save_format="json", no_snapshots=no_snapshots + file, + save_format="json", + no_snapshots=no_snapshots, + force_no_ddp=force_no_ddp, ) diff --git a/mala/datahandling/data_handler.py b/mala/datahandling/data_handler.py index b40a93ea1..7b8fc2a43 100644 --- a/mala/datahandling/data_handler.py +++ b/mala/datahandling/data_handler.py @@ -72,14 +72,14 @@ def __init__( if self.input_data_scaler is None: self.input_data_scaler = DataScaler( self.parameters.input_rescaling_type, - use_horovod=self.use_horovod, + use_ddp=self.use_ddp, ) self.output_data_scaler = output_data_scaler if self.output_data_scaler is None: self.output_data_scaler = DataScaler( self.parameters.output_rescaling_type, - use_horovod=self.use_horovod, + use_ddp=self.use_ddp, ) # Actual data points in the different categories. @@ -639,7 +639,8 @@ def __build_datasets(self): self.output_data_scaler, self.descriptor_calculator, self.target_calculator, - self.use_horovod, + self.use_ddp, + self.parameters._configuration["device"], ) ) self.validation_data_sets.append( @@ -650,7 +651,8 @@ def __build_datasets(self): self.output_data_scaler, self.descriptor_calculator, self.target_calculator, - self.use_horovod, + self.use_ddp, + self.parameters._configuration["device"], ) ) @@ -663,7 +665,8 @@ def __build_datasets(self): self.output_data_scaler, self.descriptor_calculator, self.target_calculator, - self.use_horovod, + self.use_ddp, + self.parameters._configuration["device"], input_requires_grad=True, ) ) @@ -706,7 +709,7 @@ def __build_datasets(self): self.output_data_scaler, self.descriptor_calculator, self.target_calculator, - self.use_horovod, + self.use_ddp, ) ) if snapshot.snapshot_function == "va": @@ -720,7 +723,7 @@ def __build_datasets(self): self.output_data_scaler, self.descriptor_calculator, self.target_calculator, - self.use_horovod, + self.use_ddp, ) ) if snapshot.snapshot_function == "te": @@ -734,7 +737,7 @@ def __build_datasets(self): self.output_data_scaler, self.descriptor_calculator, self.target_calculator, - self.use_horovod, + self.use_ddp, input_requires_grad=True, ) ) diff --git a/mala/datahandling/data_handler_base.py b/mala/datahandling/data_handler_base.py index e59627cc5..54e27e959 100644 --- a/mala/datahandling/data_handler_base.py +++ b/mala/datahandling/data_handler_base.py @@ -37,7 +37,7 @@ def __init__( descriptor_calculator=None, ): self.parameters: ParametersData = parameters.data - self.use_horovod = parameters.use_horovod + self.use_ddp = parameters.use_ddp # Calculators used to parse data from compatible files. 
self.target_calculator = target_calculator diff --git a/mala/datahandling/data_scaler.py b/mala/datahandling/data_scaler.py index 4eebad467..e3c8a5328 100644 --- a/mala/datahandling/data_scaler.py +++ b/mala/datahandling/data_scaler.py @@ -1,14 +1,9 @@ """DataScaler class for scaling DFT data.""" import pickle - -try: - import horovod.torch as hvd -except ModuleNotFoundError: - # Warning is thrown by parameters class - pass import numpy as np import torch +import torch.distributed as dist from mala.common.parameters import printout @@ -34,13 +29,13 @@ class DataScaler: - "feature-wise-normal": Row Min-Max scaling (Scale to be in range 0...1) - use_horovod : bool - If True, the DataScaler will use horovod to check that data is + use_ddp : bool + If True, the DataScaler will use ddp to check that data is only saved on the root process in parallel execution. """ - def __init__(self, typestring, use_horovod=False): - self.use_horovod = use_horovod + def __init__(self, typestring, use_ddp=False): + self.use_ddp = use_ddp self.typestring = typestring self.scale_standard = False self.scale_normal = False @@ -409,9 +404,9 @@ def save(self, filename, save_format="pickle"): save_format : File format which will be used for saving. """ - # If we use horovod, only save the network on root. - if self.use_horovod: - if hvd.rank() != 0: + # If we use ddp, only save the network on root. + if self.use_ddp: + if dist.get_rank() != 0: return if save_format == "pickle": with open(filename, "wb") as handle: diff --git a/mala/datahandling/lazy_load_dataset.py b/mala/datahandling/lazy_load_dataset.py index ac07cdcb6..00810beb3 100644 --- a/mala/datahandling/lazy_load_dataset.py +++ b/mala/datahandling/lazy_load_dataset.py @@ -2,13 +2,9 @@ import os -try: - import horovod.torch as hvd -except ModuleNotFoundError: - # Warning is thrown by Parameters class. - pass import numpy as np import torch +import torch.distributed as dist from torch.utils.data import Dataset from mala.common.parallelizer import barrier @@ -47,8 +43,8 @@ class LazyLoadDataset(Dataset): target_calculator : mala.targets.target.Target or derivative Used to do unit conversion on output data. - use_horovod : bool - If true, it is assumed that horovod is used. + use_ddp : bool + If true, it is assumed that ddp is used. input_requires_grad : bool If True, then the gradient is stored for the inputs. 
@@ -62,7 +58,8 @@ def __init__( output_data_scaler, descriptor_calculator, target_calculator, - use_horovod, + use_ddp, + device, input_requires_grad=False, ): self.snapshot_list = [] @@ -80,9 +77,10 @@ def __init__( self.currently_loaded_file = None self.input_data = np.empty(0) self.output_data = np.empty(0) - self.use_horovod = use_horovod + self.use_ddp = use_ddp self.return_outputs_directly = False self.input_requires_grad = input_requires_grad + self.device = device @property def return_outputs_directly(self): @@ -122,9 +120,14 @@ def mix_datasets(self): """ used_perm = torch.randperm(self.number_of_snapshots) barrier() - if self.use_horovod: - used_perm = hvd.broadcast(used_perm, 0) - self.snapshot_list = [self.snapshot_list[i] for i in used_perm] + if self.use_ddp: + used_perm = used_perm.to(device=self.device) + dist.broadcast(used_perm, 0) + self.snapshot_list = [ + self.snapshot_list[i] for i in used_perm.to("cpu") + ] + else: + self.snapshot_list = [self.snapshot_list[i] for i in used_perm] self.get_new_data(0) def get_new_data(self, file_index): diff --git a/mala/datahandling/lazy_load_dataset_single.py b/mala/datahandling/lazy_load_dataset_single.py index 83fa30548..33d7fee87 100644 --- a/mala/datahandling/lazy_load_dataset_single.py +++ b/mala/datahandling/lazy_load_dataset_single.py @@ -39,8 +39,8 @@ class LazyLoadDatasetSingle(Dataset): target_calculator : mala.targets.target.Target or derivative Used to do unit conversion on output data. - use_horovod : bool - If true, it is assumed that horovod is used. + use_ddp : bool + If true, it is assumed that ddp is used. input_requires_grad : bool If True, then the gradient is stored for the inputs. @@ -56,7 +56,7 @@ def __init__( output_data_scaler, descriptor_calculator, target_calculator, - use_horovod, + use_ddp, input_requires_grad=False, ): self.snapshot = snapshot @@ -74,7 +74,7 @@ def __init__( self.currently_loaded_file = None self.input_data = np.empty(0) self.output_data = np.empty(0) - self.use_horovod = use_horovod + self.use_ddp = use_ddp self.return_outputs_directly = False self.input_requires_grad = input_requires_grad diff --git a/mala/network/network.py b/mala/network/network.py index 668d02a6d..3835702b9 100644 --- a/mala/network/network.py +++ b/mala/network/network.py @@ -3,18 +3,13 @@ from abc import abstractmethod import numpy as np import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as functional from mala.common.parameters import Parameters from mala.common.parallelizer import printout -try: - import horovod.torch as hvd -except ModuleNotFoundError: - # Warning is thrown by parameters class - pass - class Network(nn.Module): """ @@ -69,7 +64,7 @@ def __new__(cls, params: Parameters): def __init__(self, params: Parameters): # copy the network params from the input parameter object - self.use_horovod = params.use_horovod + self.use_ddp = params.use_ddp self.mini_batch_size = params.running.mini_batch_size self.params = params.network @@ -162,9 +157,9 @@ def save_network(self, path_to_file): path_to_file : string Path to the file in which the network should be saved. """ - # If we use horovod, only save the network on root. - if self.use_horovod: - if hvd.rank() != 0: + # If we use ddp, only save the network on root. 
+ if self.use_ddp: + if dist.get_rank() != 0: return torch.save( self.state_dict(), diff --git a/mala/network/objective_naswot.py b/mala/network/objective_naswot.py index a4fd68d25..96377e527 100644 --- a/mala/network/objective_naswot.py +++ b/mala/network/objective_naswot.py @@ -76,7 +76,7 @@ def __call__(self, trial): do_shuffle = self.params.running.use_shuffling_for_samplers if ( self.data_handler.parameters.use_lazy_loading - or self.params.use_horovod + or self.params.use_ddp ): do_shuffle = False if self.params.running.use_shuffling_for_samplers: diff --git a/mala/network/runner.py b/mala/network/runner.py index 4ed514266..a5f620071 100644 --- a/mala/network/runner.py +++ b/mala/network/runner.py @@ -3,14 +3,12 @@ import os from zipfile import ZipFile, ZIP_STORED -try: - import horovod.torch as hvd -except ModuleNotFoundError: - # Warning is thrown by Parameters class - pass import numpy as np import torch +import torch.distributed as dist +import mala +from mala.common.parallelizer import get_rank from mala.common.parameters import ParametersRunning from mala.network.network import Network from mala.datahandling.data_scaler import DataScaler @@ -80,50 +78,62 @@ def save_run( data is already present in the DataHandler object, it can be saved by setting. """ - model_file = run_name + ".network.pth" - iscaler_file = run_name + ".iscaler.pkl" - oscaler_file = run_name + ".oscaler.pkl" - params_file = run_name + ".params.json" - if save_runner: - optimizer_file = run_name + ".optimizer.pth" - - self.parameters_full.save(os.path.join(save_path, params_file)) - self.network.save_network(os.path.join(save_path, model_file)) - self.data.input_data_scaler.save(os.path.join(save_path, iscaler_file)) - self.data.output_data_scaler.save( - os.path.join(save_path, oscaler_file) - ) + # If a model is trained via DDP, we need to make sure saving is only + # performed on rank 0. 
+ if get_rank() == 0: + model_file = run_name + ".network.pth" + iscaler_file = run_name + ".iscaler.pkl" + oscaler_file = run_name + ".oscaler.pkl" + params_file = run_name + ".params.json" + if save_runner: + optimizer_file = run_name + ".optimizer.pth" + + self.parameters_full.save(os.path.join(save_path, params_file)) + if self.parameters_full.use_ddp: + self.network.module.save_network( + os.path.join(save_path, model_file) + ) + else: + self.network.save_network(os.path.join(save_path, model_file)) + self.data.input_data_scaler.save( + os.path.join(save_path, iscaler_file) + ) + self.data.output_data_scaler.save( + os.path.join(save_path, oscaler_file) + ) - files = [model_file, iscaler_file, oscaler_file, params_file] - if save_runner: - files += [optimizer_file] - if zip_run: - if additional_calculation_data is not None: - additional_calculation_file = run_name + ".info.json" - if isinstance(additional_calculation_data, str): - self.data.target_calculator.read_additional_calculation_data( - additional_calculation_data - ) - self.data.target_calculator.write_additional_calculation_data( - os.path.join(save_path, additional_calculation_file) - ) - elif isinstance(additional_calculation_data, bool): - if additional_calculation_data: + files = [model_file, iscaler_file, oscaler_file, params_file] + if save_runner: + files += [optimizer_file] + if zip_run: + if additional_calculation_data is not None: + additional_calculation_file = run_name + ".info.json" + if isinstance(additional_calculation_data, str): + self.data.target_calculator.read_additional_calculation_data( + additional_calculation_data + ) self.data.target_calculator.write_additional_calculation_data( os.path.join( save_path, additional_calculation_file ) ) + elif isinstance(additional_calculation_data, bool): + if additional_calculation_data: + self.data.target_calculator.write_additional_calculation_data( + os.path.join( + save_path, additional_calculation_file + ) + ) - files.append(additional_calculation_file) - with ZipFile( - os.path.join(save_path, run_name + ".zip"), - "w", - compression=ZIP_STORED, - ) as zip_obj: - for file in files: - zip_obj.write(os.path.join(save_path, file), file) - os.remove(os.path.join(save_path, file)) + files.append(additional_calculation_file) + with ZipFile( + os.path.join(save_path, run_name + ".zip"), + "w", + compression=ZIP_STORED, + ) as zip_obj: + for file in files: + zip_obj.write(os.path.join(save_path, file), file) + os.remove(os.path.join(save_path, file)) @classmethod def load_run( @@ -136,6 +146,7 @@ def load_run( prepare_data=False, load_with_mpi=None, load_with_gpu=None, + load_with_ddp=None, ): """ Load a run. @@ -163,7 +174,7 @@ def load_run( If True, the data will be loaded into memory. This is needed when continuing a model training. - load_with_mpi : bool + load_with_mpi : bool or None Can be used to actively enable/disable MPI during loading. Default is None, so that the MPI parameters set during training/saving of the model are not overwritten. @@ -171,7 +182,7 @@ def load_run( MPI already has to be activated here, if it was not activated during training! - load_with_gpu : bool + load_with_gpu : bool or None Can be used to actively enable/disable GPU during loading. Default is None, so that the GPU parameters set during training/saving of the model are not overwritten. @@ -180,6 +191,14 @@ def load_run( activated during training. Can also be used to activate a CPU based inference, by setting it to False. 
+ load_with_ddp : bool or None + Can be used to actively disable DDP (pytorch distributed + data parallel used for parallel training) during loading. + Default is None, which for loading a Trainer object will not + interfere with DDP settings. For Predictor and Tester class, + this command will automatically disable DDP during loading, + as inference is using MPI rather than DDP for parallelization. + Return ------ loaded_params : mala.common.parameters.Parameters @@ -222,7 +241,13 @@ def load_run( path, run_name + ".params." + params_format ) - loaded_params = Parameters.load_from_json(loaded_params) + # Neither Predictor nor Runner classes can work with DDP. + if cls is mala.Trainer: + loaded_params = Parameters.load_from_json(loaded_params) + else: + loaded_params = Parameters.load_from_json( + loaded_params, force_no_ddp=True + ) # MPI has to be specified upon loading, in contrast to GPU. if load_with_mpi is not None: @@ -421,23 +446,26 @@ def __prepare_to_run(self): """ Prepare the Runner to run the Network. - This includes e.g. horovod setup. + This includes e.g. ddp setup. """ - # See if we want to use horovod. - if self.parameters_full.use_horovod: + # See if we want to use ddp. + if self.parameters_full.use_ddp: if self.parameters_full.use_gpu: # We cannot use "printout" here because this is supposed # to happen on every rank. + size = dist.get_world_size() + rank = dist.get_rank() + local_rank = int(os.environ.get("LOCAL_RANK")) if self.parameters_full.verbosity >= 2: print( "size=", - hvd.size(), + size, "global_rank=", - hvd.rank(), + rank, "local_rank=", - hvd.local_rank(), + local_rank, "device=", - torch.cuda.get_device_name(hvd.local_rank()), + torch.cuda.get_device_name(local_rank), ) # pin GPU to local rank - torch.cuda.set_device(hvd.local_rank()) + torch.cuda.set_device(local_rank) diff --git a/mala/network/trainer.py b/mala/network/trainer.py index bc4a93454..81977c40e 100644 --- a/mala/network/trainer.py +++ b/mala/network/trainer.py @@ -5,18 +5,16 @@ from datetime import datetime from packaging import version -try: - import horovod.torch as hvd -except ModuleNotFoundError: - # Warning is thrown by Parameters class - pass import numpy as np import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP from torch import optim from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from mala.common.parameters import printout +from mala.common.parallelizer import get_local_rank from mala.datahandling.fast_tensor_dataset import FastTensorDataset from mala.network.runner import Runner from mala.datahandling.lazy_load_dataset_single import LazyLoadDatasetSingle @@ -46,6 +44,16 @@ class Trainer(Runner): def __init__(self, params, network, data, optimizer_dict=None): # copy the parameters into the class. super(Trainer, self).__init__(params, network, data) + + if self.parameters_full.use_ddp: + printout("DDP activated, wrapping model in DDP.", min_verbosity=1) + # JOSHR: using streams here to maintain compatibility with + # graph capture + s = torch.cuda.Stream() + with torch.cuda.stream(s): + self.network = DDP(self.network) + torch.cuda.current_stream().wait_stream(s) + self.final_test_loss = float("inf") self.initial_test_loss = float("inf") self.final_validation_loss = float("inf") @@ -59,7 +67,7 @@ def __init__(self, params, network, data, optimizer_dict=None): self.validation_data_loaders = [] self.test_data_loaders = [] - # Samplers for the horovod case. + # Samplers for the ddp case. 
self.train_sampler = None self.test_sampler = None self.validation_sampler = None @@ -198,6 +206,9 @@ def load_run( params_format=params_format, load_runner=load_runner, prepare_data=prepare_data, + load_with_gpu=None, + load_with_mpi=None, + load_with_ddp=None, ) @classmethod @@ -227,7 +238,11 @@ def _load_from_run(cls, params, network, data, file=None): The trainer that was loaded from the file. """ # First, load the checkpoint. - checkpoint = torch.load(file) + if params.use_ddp: + map_location = {"cuda:%d" % 0: "cuda:%d" % get_local_rank()} + checkpoint = torch.load(file, map_location=map_location) + else: + checkpoint = torch.load(file) # Now, create the Trainer class with it. loaded_trainer = Trainer( @@ -256,11 +271,17 @@ def train_network(self): ) # Collect and average all the losses from all the devices - if self.parameters_full.use_horovod: - vloss = self.__average_validation(vloss, "average_loss") + if self.parameters_full.use_ddp: + vloss = self.__average_validation( + vloss, "average_loss", self.parameters._configuration["device"] + ) self.initial_validation_loss = vloss - if self.data.test_data_set is not None: - tloss = self.__average_validation(tloss, "average_loss") + if self.data.test_data_sets: + tloss = self.__average_validation( + tloss, + "average_loss", + self.parameters._configuration["device"], + ) self.initial_test_loss = tloss printout( @@ -301,7 +322,7 @@ def train_network(self): ) # train sampler - if self.parameters_full.use_horovod: + if self.train_sampler: self.train_sampler.set_epoch(epoch) # shuffle dataset if necessary @@ -406,8 +427,12 @@ def train_network(self): self.parameters.during_training_metric, ) - if self.parameters_full.use_horovod: - vloss = self.__average_validation(vloss, "average_loss") + if self.parameters_full.use_ddp: + vloss = self.__average_validation( + vloss, + "average_loss", + self.parameters._configuration["device"], + ) if self.parameters_full.verbosity > 1: printout( "Epoch {0}: validation data loss: {1}, " @@ -526,8 +551,12 @@ def train_network(self): "validation", self.parameters.after_before_training_metric, ) - if self.parameters_full.use_horovod: - vloss = self.__average_validation(vloss, "average_loss") + if self.parameters_full.use_ddp: + vloss = self.__average_validation( + vloss, + "average_loss", + self.parameters._configuration["device"], + ) # Calculate final loss. self.final_validation_loss = vloss @@ -540,8 +569,12 @@ def train_network(self): "test", self.parameters.after_before_training_metric, ) - if self.parameters_full.use_horovod: - tloss = self.__average_validation(tloss, "average_loss") + if self.parameters_full.use_ddp: + tloss = self.__average_validation( + tloss, + "average_loss", + self.parameters._configuration["device"], + ) printout("Final test data loss: ", tloss, min_verbosity=0) self.final_test_loss = tloss @@ -566,16 +599,16 @@ def __prepare_to_train(self, optimizer_dict): if optimizer_dict is not None: self.last_epoch = optimizer_dict["epoch"] + 1 - # Scale the learning rate according to horovod. - if self.parameters_full.use_horovod: - if hvd.size() > 1 and self.last_epoch == 0: + # Scale the learning rate according to ddp. + if self.parameters_full.use_ddp: + if dist.get_world_size() > 1 and self.last_epoch == 0: printout( "Rescaling learning rate because multiple workers are" " used for training.", min_verbosity=1, ) self.parameters.learning_rate = ( - self.parameters.learning_rate * hvd.size() + self.parameters.learning_rate * dist.get_world_size() ) # Choose an optimizer to use. 
@@ -614,16 +647,10 @@ def __prepare_to_train(self, optimizer_dict): self.patience_counter = optimizer_dict["early_stopping_counter"] self.last_loss = optimizer_dict["early_stopping_last_loss"] - if self.parameters_full.use_horovod: + if self.parameters_full.use_ddp: # scaling the batch size for multiGPU per node # self.batch_size= self.batch_size*hvd.local_size() - compression = ( - hvd.Compression.fp16 - if self.parameters_full.running.use_compression - else hvd.Compression.none - ) - # If lazy loading is used we do not shuffle the data points on # their own, but rather shuffle them # by shuffling the files themselves and then reading file by file @@ -636,17 +663,16 @@ def __prepare_to_train(self, optimizer_dict): self.train_sampler = ( torch.utils.data.distributed.DistributedSampler( self.data.training_data_sets[0], - num_replicas=hvd.size(), - rank=hvd.rank(), + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), shuffle=do_shuffle, ) ) - self.validation_sampler = ( torch.utils.data.distributed.DistributedSampler( self.data.validation_data_sets[0], - num_replicas=hvd.size(), - rank=hvd.rank(), + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), shuffle=False, ) ) @@ -655,25 +681,12 @@ def __prepare_to_train(self, optimizer_dict): self.test_sampler = ( torch.utils.data.distributed.DistributedSampler( self.data.test_data_sets[0], - num_replicas=hvd.size(), - rank=hvd.rank(), + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), shuffle=False, ) ) - # broadcaste parameters and optimizer state from root device to - # other devices - hvd.broadcast_parameters(self.network.state_dict(), root_rank=0) - hvd.broadcast_optimizer_state(self.optimizer, root_rank=0) - - # Wraps the opimizer for multiGPU operation - self.optimizer = hvd.DistributedOptimizer( - self.optimizer, - named_parameters=self.network.named_parameters(), - compression=compression, - op=hvd.Average, - ) - # Instantiate the learning rate scheduler, if necessary. 
if self.parameters.learning_rate_scheduler == "ReduceLROnPlateau": self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( @@ -700,7 +713,7 @@ def __prepare_to_train(self, optimizer_dict): do_shuffle = self.parameters.use_shuffling_for_samplers if ( self.data.parameters.use_lazy_loading - or self.parameters_full.use_horovod + or self.parameters_full.use_ddp ): do_shuffle = False @@ -796,9 +809,15 @@ def __process_mini_batch(self, network, input_data, target_data): enabled=self.parameters.use_mixed_precision ): prediction = network(input_data) - loss = network.calculate_loss( - prediction, target_data - ) + if self.parameters_full.use_ddp: + # JOSHR: We have to use "module" here to access custom method of DDP wrapped model + loss = network.module.calculate_loss( + prediction, target_data + ) + else: + loss = network.calculate_loss( + prediction, target_data + ) if self.gradscaler: self.gradscaler.scale(loss).backward() @@ -814,7 +833,7 @@ def __process_mini_batch(self, network, input_data, target_data): # Capture graph self.train_graph = torch.cuda.CUDAGraph() - self.network.zero_grad(set_to_none=True) + network.zero_grad(set_to_none=True) with torch.cuda.graph(self.train_graph): with torch.cuda.amp.autocast( enabled=self.parameters.use_mixed_precision @@ -823,9 +842,14 @@ def __process_mini_batch(self, network, input_data, target_data): self.static_input_data ) - self.static_loss = network.calculate_loss( - self.static_prediction, self.static_target_data - ) + if self.parameters_full.use_ddp: + self.static_loss = network.module.calculate_loss( + self.static_prediction, self.static_target_data + ) + else: + self.static_loss = network.calculate_loss( + self.static_prediction, self.static_target_data + ) if self.gradscaler: self.gradscaler.scale(self.static_loss).backward() @@ -851,7 +875,12 @@ def __process_mini_batch(self, network, input_data, target_data): torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_push("loss") - loss = network.calculate_loss(prediction, target_data) + if self.parameters_full.use_ddp: + loss = network.module.calculate_loss( + prediction, target_data + ) + else: + loss = network.calculate_loss(prediction, target_data) # loss torch.cuda.nvtx.range_pop() @@ -874,7 +903,10 @@ def __process_mini_batch(self, network, input_data, target_data): return loss else: prediction = network(input_data) - loss = network.calculate_loss(prediction, target_data) + if self.parameters_full.use_ddp: + loss = network.module.calculate_loss(prediction, target_data) + else: + loss = network.calculate_loss(prediction, target_data) loss.backward() self.optimizer.step() self.optimizer.zero_grad() @@ -950,9 +982,14 @@ def __validate_network(self, network, data_set_type, validation_type): enabled=self.parameters.use_mixed_precision ): prediction = network(x) - loss = network.calculate_loss( - prediction, y - ) + if self.parameters_full.use_ddp: + loss = network.module.calculate_loss( + prediction, y + ) + else: + loss = network.calculate_loss( + prediction, y + ) torch.cuda.current_stream( self.parameters._configuration["device"] ).wait_stream(s) @@ -976,10 +1013,16 @@ def __validate_network(self, network, data_set_type, validation_type): self.static_input_validation ) ) - self.static_loss_validation = network.calculate_loss( - self.static_prediction_validation, - self.static_target_validation, - ) + if self.parameters_full.use_ddp: + self.static_loss_validation = network.module.calculate_loss( + self.static_prediction_validation, + self.static_target_validation, + ) + else: + 
self.static_loss_validation = network.calculate_loss( + self.static_prediction_validation, + self.static_target_validation, + ) if self.validation_graph: self.static_input_validation.copy_(x) @@ -993,9 +1036,14 @@ def __validate_network(self, network, data_set_type, validation_type): enabled=self.parameters.use_mixed_precision ): prediction = network(x) - loss = network.calculate_loss( - prediction, y - ) + if self.parameters_full.use_ddp: + loss = network.module.calculate_loss( + prediction, y + ) + else: + loss = network.calculate_loss( + prediction, y + ) validation_loss_sum += loss if ( batchid != 0 @@ -1027,9 +1075,16 @@ def __validate_network(self, network, data_set_type, validation_type): x = x.to(self.parameters._configuration["device"]) y = y.to(self.parameters._configuration["device"]) prediction = network(x) - validation_loss_sum += network.calculate_loss( - prediction, y - ).item() + if self.parameters_full.use_ddp: + validation_loss_sum += ( + network.module.calculate_loss( + prediction, y + ).item() + ) + else: + validation_loss_sum += network.calculate_loss( + prediction, y + ).item() batchid += 1 validation_loss = validation_loss_sum.item() / batchid @@ -1189,8 +1244,8 @@ def __create_training_checkpoint(self): # Next, we save all the other objects. - if self.parameters_full.use_horovod: - if hvd.rank() != 0: + if self.parameters_full.use_ddp: + if dist.get_rank() != 0: return if self.scheduler is None: save_dict = { @@ -1214,8 +1269,9 @@ def __create_training_checkpoint(self): self.save_run(self.parameters.checkpoint_name, save_runner=True) @staticmethod - def __average_validation(val, name): + def __average_validation(val, name, device="cpu"): """Average validation over multiple parallel processes.""" - tensor = torch.tensor(val) - avg_loss = hvd.allreduce(tensor, name=name, op=hvd.Average) + tensor = torch.tensor(val, device=device) + dist.all_reduce(tensor) + avg_loss = tensor / dist.get_world_size() return avg_loss.item()
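
For reference, the loss averaging introduced in ``__average_validation`` follows the standard ``torch.distributed`` pattern: ``all_reduce`` sums the tensor across all ranks (SUM is the default reduction op), and dividing by the world size yields the mean. Below is a minimal, standalone sketch of that pattern; the single-process ``gloo`` group and the ``average_scalar`` helper name are purely illustrative and not part of this patch.

    import os
    import torch
    import torch.distributed as dist

    # Illustrative single-process setup; under srun/torchrun these variables
    # are provided by the launcher instead.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    def average_scalar(value, device="cpu"):
        # Same pattern as Trainer.__average_validation: all_reduce sums the
        # per-rank values; dividing by the world size gives the average.
        tensor = torch.tensor(value, device=device)
        dist.all_reduce(tensor)  # default reduction op is SUM
        return (tensor / dist.get_world_size()).item()

    print(average_scalar(0.42))  # prints 0.42 with a single rank
    dist.destroy_process_group()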