Added some documentation
RandomDefaultUser committed Apr 30, 2024
1 parent 1fb2c98 commit a9027a7
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions docs/source/advanced_usage/trainingmodel.rst
@@ -220,3 +220,60 @@ via
The full path for ``path_to_visualization`` can be accessed via
``trainer.full_visualization_path``.


Training in parallel
********************

If large models or large data sets are used, training may be slow even
on a GPU. In this case, multiple GPUs can be employed with MALA
using the ``DistributedDataParallel`` (DDP) formalism of the ``torch`` library.
To use DDP, make sure you have `NCCL <https://developer.nvidia.com/nccl>`_
installed on your system.
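
Whether the installed ``torch`` build can actually use NCCL can be checked
directly from Python. This is a generic ``torch`` check, not a MALA feature:

.. code-block:: python

    import torch

    # True if this torch build ships the NCCL backend that DDP uses
    # for GPU-to-GPU communication.
    print(torch.distributed.is_nccl_available())
    # True if CUDA-capable GPUs are visible to torch.
    print(torch.cuda.is_available())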

Using DDP in MALA requires almost no modification of your training script:
simply activate DDP in your ``Parameters`` object. Make sure to also enable
GPU usage, since parallel training is currently only supported on GPUs.

.. code-block:: python

    parameters = mala.Parameters()
    parameters.use_gpu = True
    parameters.use_ddp = True

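
For orientation, the following is a minimal sketch of where these flags sit in
a full training script. It follows the structure of the basic MALA training
workflow; the snapshot files, paths and hyperparameters below are placeholders,
and the exact attribute names may differ slightly between MALA versions.

.. code-block:: python

    import mala

    parameters = mala.Parameters()
    # Enable GPU usage and DDP-based parallel training.
    parameters.use_gpu = True
    parameters.use_ddp = True
    parameters.running.max_number_epochs = 100

    # Placeholder snapshot files/paths - replace with your own data.
    data_handler = mala.DataHandler(parameters)
    data_handler.add_snapshot("snapshot0.in.npy", "/path/to/data",
                              "snapshot0.out.npy", "/path/to/data",
                              add_snapshot_as="tr")
    data_handler.add_snapshot("snapshot1.in.npy", "/path/to/data",
                              "snapshot1.out.npy", "/path/to/data",
                              add_snapshot_as="va")
    data_handler.prepare_data()

    # Network layout: input and output dimensions come from the data.
    parameters.network.layer_sizes = [data_handler.input_dimension,
                                      100,
                                      data_handler.output_dimension]
    network = mala.Network(parameters)
    trainer = mala.Trainer(parameters, network, data_handler)
    trainer.train_network()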
MALA is now set up for parallel training. DDP works across multiple compute
nodes on HPC infrastructure as well as on a single machine hosting multiple
GPUs. While essentially no modification of the Python script itself is
necessary, the way the script is launched may have to be adapted so that
DDP has all the information it needs for inter-/intra-node communication.
This setup *may* differ across machines/clusters. During testing, the
following setup was confirmed to work on an HPC cluster using the
``slurm`` scheduler.

.. code-block:: bash

    #SBATCH --nodes=NUMBER_OF_NODES
    #SBATCH --ntasks-per-node=NUMBER_OF_TASKS_PER_NODE
    #SBATCH --gres=gpu:NUMBER_OF_TASKS_PER_NODE
    # Add more arguments as needed
    ...

    # Load more modules as needed
    ...

    # This port can be arbitrarily chosen.
    export MASTER_PORT=12342

    # Find out the host node.
    echo "NODELIST="${SLURM_NODELIST}
    master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
    export MASTER_ADDR=$master_addr
    echo "MASTER_ADDR="$MASTER_ADDR

    # Run using torchrun.
    torchrun --nnodes NUMBER_OF_NODES --nproc_per_node NUMBER_OF_TASKS_PER_NODE --rdzv_id "$SLURM_JOB_ID" training.py

This script follows `this tutorial <https://gist.github.com/TengdaHan/1dd10d335c7ca6f13810fff41e809904>`_.
A tutorial on DDP itself can be found `here <https://pytorch.org/tutorials/beginner/ddp_series_theory.html>`_.
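
As a quick sanity check that the launch configuration reached every process,
one can, for example, print the environment variables that ``torchrun`` sets
for each spawned process. This is a generic ``torchrun`` check rather than a
MALA feature, and could be placed at the top of ``training.py``:

.. code-block:: python

    import os

    # torchrun sets these variables for every process it launches.
    # WORLD_SIZE should equal NUMBER_OF_NODES * NUMBER_OF_TASKS_PER_NODE
    # from the batch script above.
    print("RANK:", os.environ.get("RANK"))
    print("LOCAL_RANK:", os.environ.get("LOCAL_RANK"))
    print("WORLD_SIZE:", os.environ.get("WORLD_SIZE"))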

