
MALA-DDP #466

Merged (23 commits, May 12, 2024)

Commits
f80a983
all files updated with changes for ddp implementation
dytnvgl Jun 30, 2023
2f10a23
allowing ddp wrapper to push network for saving during checkpoint
Jun 30, 2023
05f551b
allow checkpoint network save when not using ddp
Jul 3, 2023
50272a7
Merge branch 'refs/heads/develop' into ddp
RandomDefaultUser Apr 25, 2024
e60e995
blackified remaining files
RandomDefaultUser Apr 25, 2024
4175cd0
Added suggestions by Josh
RandomDefaultUser Apr 25, 2024
2cde3f1
Removed DDP in yaml file
RandomDefaultUser Apr 25, 2024
bd68063
Minor reformatting
RandomDefaultUser Apr 26, 2024
3cebb9d
Small bug
RandomDefaultUser Apr 26, 2024
ffa3082
Model only saved on master rank
RandomDefaultUser Apr 26, 2024
04b0050
Adjusted output for parallel
RandomDefaultUser Apr 26, 2024
1fb2c98
Testing if distributed samplers work as default
RandomDefaultUser Apr 30, 2024
a9027a7
Added some documentation
RandomDefaultUser Apr 30, 2024
af1081e
This should fix the inference
RandomDefaultUser Apr 30, 2024
18fa6e2
Trying to fix checkpointing
RandomDefaultUser May 2, 2024
f49e63d
Added docs for new loading parameters
RandomDefaultUser May 2, 2024
e1753d0
This should fix lazy loading mixing
RandomDefaultUser May 2, 2024
36e626c
Missing comma
RandomDefaultUser May 2, 2024
d9c7a73
Made printing for DDP init debug only
RandomDefaultUser May 2, 2024
51235b4
Forgot an equals sign
RandomDefaultUser May 2, 2024
325cf65
Lazy loading working now
RandomDefaultUser May 3, 2024
873e486
Adapted docs to use srun instead of torchrun for example
RandomDefaultUser May 3, 2024
b58c096
Small bugfix to fix CI
RandomDefaultUser May 3, 2024
65 changes: 65 additions & 0 deletions docs/source/advanced_usage/trainingmodel.rst
@@ -220,3 +220,68 @@ via

The full path for ``path_to_visualization`` can be accessed via
``trainer.full_visualization_path``.


Training in parallel
********************

If large models or large data sets are employed, training may be slow even
when a GPU is used. In this case, multiple GPUs can be used with MALA via the
``DistributedDataParallel`` (DDP) framework of the ``torch`` library.
To use DDP, make sure you have `NCCL <https://developer.nvidia.com/nccl>`_
installed on your system.
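
If you want to verify beforehand that your ``torch`` installation can actually
use NCCL, the following quick check (standard PyTorch API, not specific to
MALA) may be helpful:

.. code-block:: python

    import torch.cuda.nccl
    import torch.distributed as dist

    # True if this torch build ships NCCL support.
    print(dist.is_nccl_available())
    # NCCL version this torch build was compiled against.
    print(torch.cuda.nccl.version())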

To use DDP in MALA, almost no modification of your training script is
necessary; simply activate DDP in your ``Parameters`` object. Make sure to
also enable GPU usage, since parallel training is currently only supported on GPUs.

.. code-block:: python

parameters = mala.Parameters()
parameters.use_gpu = True
parameters.use_ddp = True

MALA is now set up for parallel training. DDP works across multiple compute
nodes on HPC infrastructure as well as on a single machine hosting multiple
GPUs. While the Python script itself needs essentially no modification, the
way it is launched may have to be adapted so that DDP has all the information
it needs for inter- and intra-node communication. This setup *may* differ
across machines/clusters. During testing, the following setup was confirmed
to work on an HPC cluster using the ``slurm`` scheduler.

.. code-block:: bash

#SBATCH --nodes=NUMBER_OF_NODES
#SBATCH --ntasks-per-node=NUMBER_OF_TASKS_PER_NODE
#SBATCH --gres=gpu:NUMBER_OF_TASKS_PER_NODE
# Add more arguments as needed
...

# Load more modules as needed
...

# This port can be arbitrarily chosen.
# Given here is the torchrun default
export MASTER_PORT=29500

# Find out the host node.
echo "NODELIST="${SLURM_NODELIST}
master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr
echo "MASTER_ADDR="$MASTER_ADDR

# Run using srun.
srun -u bash -c '
# Export additional per process variables
export RANK=$SLURM_PROCID
export LOCAL_RANK=$SLURM_LOCALID
export WORLD_SIZE=$SLURM_NTASKS

python3 -u training.py
'

An overview of environment variables to be set can be found `in the official documentation <https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization>`_.
A general tutorial on DDP itself can be found `here <https://pytorch.org/tutorials/beginner/ddp_series_theory.html>`_.
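
The environment variables exported in the batch script above are essentially
what ``torch``'s default environment-variable (``env://``) initialization
consumes. The following sketch is purely illustrative of what happens on each
rank behind the scenes; it is not something you need to add to your MALA
training script.

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist

    # The default (env://) initialization reads MASTER_ADDR, MASTER_PORT,
    # RANK and WORLD_SIZE from the environment prepared by the batch script.
    dist.init_process_group(backend="nccl")

    # Each process drives the GPU matching its node-local rank.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))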


4 changes: 3 additions & 1 deletion install/mala_gpu_base_environment.yml
@@ -1,4 +1,6 @@
 name: mala-gpu
 channels:
-  - defaults
+  - conda-forge
+  - defaults
 dependencies:
+  - python=3.10
4 changes: 0 additions & 4 deletions mala/common/check_modules.py
@@ -11,10 +11,6 @@ def check_modules():
             "available": False,
             "description": "Enables inference parallelization.",
         },
-        "horovod": {
-            "available": False,
-            "description": "Enables training parallelization.",
-        },
         "lammps": {
             "available": False,
             "description": "Enables descriptor calculation for data preprocessing "
48 changes: 22 additions & 26 deletions mala/common/parallelizer.py
@@ -2,15 +2,13 @@
 
 from collections import defaultdict
 import platform
+import os
 import warnings
 
-try:
-    import horovod.torch as hvd
-except ModuleNotFoundError:
-    pass
 import torch
+import torch.distributed as dist
 
-use_horovod = False
+use_ddp = False
 use_mpi = False
 comm = None
 local_mpi_rank = None
@@ -33,45 +31,43 @@ def set_current_verbosity(new_value):
     current_verbosity = new_value
 
 
-def set_horovod_status(new_value):
+def set_ddp_status(new_value):
     """
-    Set the horovod status.
+    Set the ddp status.
 
-    By setting the horovod status via this function it can be ensured that
+    By setting the ddp status via this function it can be ensured that
     printing works in parallel. The Parameters class does that for the user.
 
     Parameters
     ----------
     new_value : bool
-        Value the horovod status has.
+        Value the ddp status has.
 
     """
     if use_mpi is True and new_value is True:
         raise Exception(
-            "Cannot use horovod and inference-level MPI at "
-            "the same time yet."
+            "Cannot use ddp and inference-level MPI at " "the same time yet."
         )
-    global use_horovod
-    use_horovod = new_value
+    global use_ddp
+    use_ddp = new_value
 
 
 def set_mpi_status(new_value):
     """
     Set the MPI status.
 
-    By setting the horovod status via this function it can be ensured that
+    By setting the MPI status via this function it can be ensured that
     printing works in parallel. The Parameters class does that for the user.
 
     Parameters
     ----------
     new_value : bool
-        Value the horovod status has.
+        Value the MPI status has.
 
     """
-    if use_horovod is True and new_value is True:
+    if use_ddp is True and new_value is True:
         raise Exception(
-            "Cannot use horovod and inference-level MPI at "
-            "the same time yet."
+            "Cannot use ddp and inference-level MPI at " "the same time yet."
        )
     global use_mpi
     use_mpi = new_value
@@ -119,8 +115,8 @@ def get_rank():
         The rank of the current thread.
 
     """
-    if use_horovod:
-        return hvd.rank()
+    if use_ddp:
+        return dist.get_rank()
     if use_mpi:
         return comm.Get_rank()
     return 0
@@ -159,8 +155,8 @@ def get_local_rank():
     FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
     DEALINGS IN THE SOFTWARE.
     """
-    if use_horovod:
-        return hvd.local_rank()
+    if use_ddp:
+        return int(os.environ.get("LOCAL_RANK"))
     if use_mpi:
         global local_mpi_rank
         if local_mpi_rank is None:
@@ -187,8 +183,8 @@ def get_size():
     size : int
         The number of ranks.
     """
-    if use_horovod:
-        return hvd.size()
+    if use_ddp:
+        return dist.get_world_size()
     if use_mpi:
         return comm.Get_size()
 
@@ -209,8 +205,8 @@ def get_comm():
 
 def barrier():
     """General interface for a barrier."""
-    if use_horovod:
-        hvd.allreduce(torch.tensor(0), name="barrier")
+    if use_ddp:
+        dist.barrier()
     if use_mpi:
         comm.Barrier()
     return
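
For orientation, the helpers changed in this file are also what user-facing
code can call to stay rank-aware. A minimal sketch of the assumed usage (the
import path simply mirrors the file location shown above; this snippet is not
part of the PR):

    from mala.common.parallelizer import barrier, get_rank, get_size

    # Restrict output to the master rank, then synchronize all ranks.
    if get_rank() == 0:
        print("Running on", get_size(), "ranks.")
    barrier()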