MALA-DDP #466

Merged: 23 commits, May 12, 2024
Commits
f80a983
all files updated with changes for ddp implementation
dytnvgl Jun 30, 2023
2f10a23
allowing ddp wrapper to push network for saving during checkpoint
Jun 30, 2023
05f551b
allow checkpoint network save when not using ddp
Jul 3, 2023
50272a7
Merge branch 'refs/heads/develop' into ddp
RandomDefaultUser Apr 25, 2024
e60e995
blackified remaining files
RandomDefaultUser Apr 25, 2024
4175cd0
Added suggestions by Josh
RandomDefaultUser Apr 25, 2024
2cde3f1
Removed DDP in yaml file
RandomDefaultUser Apr 25, 2024
bd68063
Minor reformatting
RandomDefaultUser Apr 26, 2024
3cebb9d
Small bug
RandomDefaultUser Apr 26, 2024
ffa3082
Model only saved on master rank
RandomDefaultUser Apr 26, 2024
04b0050
Adjusted output for parallel
RandomDefaultUser Apr 26, 2024
1fb2c98
Testing if distributed samplers work as default
RandomDefaultUser Apr 30, 2024
a9027a7
Added some documentation
RandomDefaultUser Apr 30, 2024
af1081e
This should fix the inference
RandomDefaultUser Apr 30, 2024
18fa6e2
Trying to fix checkpointing
RandomDefaultUser May 2, 2024
f49e63d
Added docs for new loading parameters
RandomDefaultUser May 2, 2024
e1753d0
This should fix lazy loading mixing
RandomDefaultUser May 2, 2024
36e626c
Missing comma
RandomDefaultUser May 2, 2024
d9c7a73
Made printing for DDP init debug only
RandomDefaultUser May 2, 2024
51235b4
Forgot an equals sign
RandomDefaultUser May 2, 2024
325cf65
Lazy loading working now
RandomDefaultUser May 3, 2024
873e486
Adapted docs to use srun instead of torchrun for example
RandomDefaultUser May 3, 2024
b58c096
Small bugfix to fix CI
RandomDefaultUser May 3, 2024
6 changes: 4 additions & 2 deletions install/mala_gpu_base_environment.yml
@@ -1,4 +1,6 @@
name: mala-gpu
name: mala-gpu-ddp
channels:
- defaults
- conda-forge
- defaults
dependencies:
- python=3.10
2 changes: 0 additions & 2 deletions mala/common/check_modules.py
@@ -8,8 +8,6 @@ def check_modules():
optional_libs = {
"mpi4py": {"available": False, "description":
"Enables inference parallelization."},
"horovod": {"available": False, "description":
"Enables training parallelization."},
"lammps": {"available": False, "description":
"Enables descriptor calculation for data preprocessing "
"and inference."},
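For context, the optional_libs dictionary above feeds MALA's optional-import check; with Horovod support removed, only the remaining backends are probed. Below is a minimal sketch of that probing pattern, assuming nothing beyond the dictionary structure visible in the diff (the importlib-based loop and the reporting line are illustrative, not MALA's actual implementation):

```python
import importlib.util

# Dictionary structure mirrors the optional_libs dict in check_modules.py;
# the probing loop below is an illustrative assumption.
optional_libs = {
    "mpi4py": {"available": False,
               "description": "Enables inference parallelization."},
    "lammps": {"available": False,
               "description": "Enables descriptor calculation for data "
                              "preprocessing and inference."},
}

for name, info in optional_libs.items():
    # A library counts as available if its module can be located.
    info["available"] = importlib.util.find_spec(name) is not None
    print(f"{name}: {'found' if info['available'] else 'missing'} "
          f"- {info['description']}")
```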
46 changes: 22 additions & 24 deletions mala/common/parallelizer.py
@@ -1,15 +1,13 @@
"""Functions for operating MALA in parallel."""
from collections import defaultdict
import platform
import os
import warnings

try:
import horovod.torch as hvd
except ModuleNotFoundError:
pass
import torch
import torch.distributed as dist

use_horovod = False
use_ddp = False
use_mpi = False
comm = None
local_mpi_rank = None
@@ -32,41 +30,41 @@ def set_current_verbosity(new_value):
current_verbosity = new_value


def set_horovod_status(new_value):
def set_ddp_status(new_value):
"""
Set the horovod status.
Set the ddp status.

By setting the horovod status via this function it can be ensured that
By setting the ddp status via this function it can be ensured that
printing works in parallel. The Parameters class does that for the user.

Parameters
----------
new_value : bool
Value the horovod status has.
Value the ddp status has.

"""
if use_mpi is True and new_value is True:
raise Exception("Cannot use horovod and inference-level MPI at "
raise Exception("Cannot use ddp and inference-level MPI at "
"the same time yet.")
global use_horovod
use_horovod = new_value
global use_ddp
use_ddp = new_value


def set_mpi_status(new_value):
"""
Set the MPI status.

By setting the horovod status via this function it can be ensured that
By setting the ddp status via this function it can be ensured that
printing works in parallel. The Parameters class does that for the user.

Parameters
----------
new_value : bool
Value the horovod status has.
Value the ddp status has.

"""
if use_horovod is True and new_value is True:
raise Exception("Cannot use horovod and inference-level MPI at "
if use_ddp is True and new_value is True:
raise Exception("Cannot use ddp and inference-level MPI at "
"the same time yet.")
global use_mpi
use_mpi = new_value
@@ -113,8 +111,8 @@ def get_rank():
The rank of the current thread.

"""
if use_horovod:
return hvd.rank()
if use_ddp:
return dist.get_rank()
if use_mpi:
return comm.Get_rank()
return 0
@@ -153,8 +151,8 @@ def get_local_rank():
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
if use_horovod:
return hvd.local_rank()
if use_ddp:
return int(os.environ.get("LOCAL_RANK"))
if use_mpi:
global local_mpi_rank
if local_mpi_rank is None:
@@ -181,8 +179,8 @@ def get_size():
size : int
The number of ranks.
"""
if use_horovod:
return hvd.size()
if use_ddp:
return dist.get_world_size()
if use_mpi:
return comm.Get_size()

@@ -203,8 +201,8 @@ def get_comm():

def barrier():
"""General interface for a barrier."""
if use_horovod:
hvd.allreduce(torch.tensor(0), name='barrier')
if use_ddp:
dist.barrier()
if use_mpi:
comm.Barrier()
return
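The hunks above replace each Horovod query with its torch.distributed counterpart: hvd.rank() becomes dist.get_rank(), hvd.size() becomes dist.get_world_size(), the allreduce-based barrier becomes dist.barrier(), and the local rank is now read from the LOCAL_RANK environment variable. A minimal, self-contained sketch of those calls is shown below; it is not MALA code, and it uses the gloo backend plus localhost defaults purely so it can run on a single CPU-only machine, whereas the Parameters setter in the next file initializes the nccl backend:

```python
import os
import torch.distributed as dist

# RANK, WORLD_SIZE and LOCAL_RANK are normally exported by the launcher
# (torchrun or srun); the fallbacks below are only for a single-process test.
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

# "gloo" keeps this sketch runnable without GPUs; MALA's setter uses "nccl".
dist.init_process_group("gloo", rank=rank, world_size=world_size)

print("rank:", dist.get_rank())              # replaces hvd.rank()
print("world size:", dist.get_world_size())  # replaces hvd.size()
print("local rank:", int(os.environ.get("LOCAL_RANK", "0")))  # replaces hvd.local_rank()
dist.barrier()                               # replaces hvd.allreduce(torch.tensor(0), name='barrier')

dist.destroy_process_group()
```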
106 changes: 69 additions & 37 deletions mala/common/parameters.py
@@ -6,14 +6,11 @@
import pickle
from time import sleep

try:
import horovod.torch as hvd
except ModuleNotFoundError:
pass
import numpy as np
import torch
import torch.distributed as dist

from mala.common.parallelizer import printout, set_horovod_status, \
from mala.common.parallelizer import printout, set_ddp_status, \
set_mpi_status, get_rank, get_local_rank, set_current_verbosity, \
parallel_warn
from mala.common.json_serializable import JSONSerializable
@@ -26,7 +23,7 @@ class ParametersBase(JSONSerializable):

def __init__(self,):
super(ParametersBase, self).__init__()
self._configuration = {"gpu": False, "horovod": False, "mpi": False,
self._configuration = {"gpu": False, "ddp": False, "mpi": False,
"device": "cpu", "openpmd_configuration": {},
"openpmd_granularity": 1}
pass
@@ -54,8 +51,8 @@ def show(self, indent=""):
def _update_gpu(self, new_gpu):
self._configuration["gpu"] = new_gpu

def _update_horovod(self, new_horovod):
self._configuration["horovod"] = new_horovod
def _update_ddp(self, new_ddp):
self._configuration["ddp"] = new_ddp

def _update_mpi(self, new_mpi):
self._configuration["mpi"] = new_mpi
@@ -675,10 +672,6 @@ class ParametersRunning(ParametersBase):
validation loss has to plateau before the schedule takes effect).
Default: 0.

use_compression : bool
If True and horovod is used, horovod compression will be used for
allreduce communication. This can improve performance.

num_workers : int
Number of workers to be used for data loading.

@@ -739,7 +732,6 @@ def __init__(self):
self.learning_rate_scheduler = None
self.learning_rate_decay = 0.1
self.learning_rate_patience = 0
self.use_compression = False
self.num_workers = 0
self.use_shuffling_for_samplers = True
self.checkpoints_each_epoch = 0
@@ -755,8 +747,8 @@ def __init__(self):
self.training_report_frequency = 1000
self.profiler_range = [1000, 2000]

def _update_horovod(self, new_horovod):
super(ParametersRunning, self)._update_horovod(new_horovod)
def _update_ddp(self, new_ddp):
super(ParametersRunning, self)._update_ddp(new_ddp)
self.during_training_metric = self.during_training_metric
self.after_before_training_metric = self.after_before_training_metric

@@ -778,9 +770,9 @@ def during_training_metric(self):
@during_training_metric.setter
def during_training_metric(self, value):
if value != "ldos":
if self._configuration["horovod"]:
if self._configuration["ddp"]:
raise Exception("Currently, MALA can only operate with the "
"\"ldos\" metric for horovod runs.")
"\"ldos\" metric for ddp runs.")
self._during_training_metric = value

@property
@@ -801,17 +793,17 @@ def after_before_training_metric(self):
@after_before_training_metric.setter
def after_before_training_metric(self, value):
if value != "ldos":
if self._configuration["horovod"]:
if self._configuration["ddp"]:
raise Exception("Currently, MALA can only operate with the "
"\"ldos\" metric for horovod runs.")
"\"ldos\" metric for ddp runs.")
self._after_before_training_metric = value

@during_training_metric.setter
def during_training_metric(self, value):
if value != "ldos":
if self._configuration["horovod"]:
if self._configuration["ddp"]:
raise Exception("Currently, MALA can only operate with the "
"\"ldos\" metric for horovod runs.")
"\"ldos\" metric for ddp runs.")
self._during_training_metric = value

@property
@@ -1178,7 +1170,10 @@ def __init__(self):

# Properties
self.use_gpu = False
self.use_horovod = False
self.use_ddp = False
self.use_distributed_sampler_train = True
self.use_distributed_sampler_val = True
self.use_distributed_sampler_test = True
self.use_mpi = False
self.verbosity = 1
self.device = "cpu"
@@ -1259,25 +1254,62 @@ def use_gpu(self, value):
self.hyperparameters._update_gpu(self.use_gpu)

@property
def use_horovod(self):
"""Control whether or not horovod is used for parallel training."""
return self._use_horovod
def use_ddp(self):
"""Control whether or not dd is used for parallel training."""
return self._use_ddp

@property
def use_distributed_sampler_train(self):
"""Control wether or not distributed sampler is used to distribute training data."""
return self._use_distributed_sampler_train

@use_distributed_sampler_train.setter
def use_distributed_sampler_train(self, value):
"""Control whether or not distributed sampler is used to distribute training data."""
self._use_distributed_sampler_train = value

@property
def use_distributed_sampler_val(self):
"""Control whether or not distributed sampler is used to distribute validation data."""
return self._use_distributed_sampler_val

@use_distributed_sampler_val.setter
def use_distributed_sampler_val(self, value):
"""Control whether or not distributed sampler is used to distribute validation data."""
self._use_distributed_sampler_val = value

@property
def use_distributed_sampler_test(self):
"""Control whether or not distributed sampler is used to distribute test data."""
return self._use_distributed_sampler_test

@use_distributed_sampler_test.setter
def use_distributed_sampler_test(self, value):
"""Control whether or not distributed sampler is used to distribute test data."""
self._use_distributed_sampler_test = value

@use_horovod.setter
def use_horovod(self, value):
@use_ddp.setter
def use_ddp(self, value):
if value:
hvd.init()
print("initializing torch.distributed.")
# JOSHR:
# We start up torch distributed here. As is fairly standard convention, we get the rank
# and world size arguments via environment variables (RANK, WORLD_SIZE). In addition to
# those variables, LOCAL_RANK, MASTER_ADDR and MASTER_PORT should be set.
rank = int(os.environ.get("RANK"))
world_size = int(os.environ.get("WORLD_SIZE"))
dist.init_process_group("nccl", rank=rank, world_size=world_size)

# Invalidate, will be updated in setter.
set_horovod_status(value)
set_ddp_status(value)
self.device = None
self._use_horovod = value
self.network._update_horovod(self.use_horovod)
self.descriptors._update_horovod(self.use_horovod)
self.targets._update_horovod(self.use_horovod)
self.data._update_horovod(self.use_horovod)
self.running._update_horovod(self.use_horovod)
self.hyperparameters._update_horovod(self.use_horovod)
self._use_ddp = value
self.network._update_ddp(self.use_ddp)
self.descriptors._update_ddp(self.use_ddp)
self.targets._update_ddp(self.use_ddp)
self.data._update_ddp(self.use_ddp)
self.running._update_ddp(self.use_ddp)
self.hyperparameters._update_ddp(self.use_ddp)

@property
def device(self):
Expand All @@ -1301,7 +1333,7 @@ def device(self, value):

@property
def use_mpi(self):
"""Control whether or not horovod is used for parallel training."""
"""Control whether or not ddp is used for parallel training."""
return self._use_mpi

@use_mpi.setter
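Taken together, the new use_ddp setter and the per-split sampler switches would be driven from user code roughly as follows. This is a hypothetical usage sketch: the attribute names come from the diff above, while the top-level mala.Parameters entry point and the assumption that RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT have already been exported by the launcher (torchrun, or srun as the updated docs use) are not shown in this PR excerpt:

```python
import mala  # assumed top-level package exposing the Parameters class

parameters = mala.Parameters()

# Setting use_ddp = True runs dist.init_process_group("nccl", ...) in the
# setter, reading RANK and WORLD_SIZE from the launcher's environment.
parameters.use_ddp = True
parameters.use_gpu = True

# Distributed samplers now default to True for all three splits; each split
# can still be opted out individually, e.g. to evaluate the test set on a
# single rank (purely illustrative, not a recommendation from the PR).
parameters.use_distributed_sampler_train = True
parameters.use_distributed_sampler_val = True
parameters.use_distributed_sampler_test = False
```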