MALA-DDP #466

Merged: 23 commits, May 12, 2024

Conversation

@dytnvgl (Contributor) commented Jul 6, 2023

Replaces horovod with PyTorch DistributedDataParallel (DDP) for GPU parallelization.
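
For orientation, a minimal sketch of the general DDP setup pattern this change moves toward; the model, layer sizes, and environment variables are illustrative assumptions, not MALA code.

# Minimal DDP training setup sketch (illustrative, not MALA's actual trainer code).
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT and LOCAL_RANK are set by the launcher.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(91, 11).cuda()        # placeholder for a MALA network
model = DDP(model, device_ids=[local_rank])   # gradients are all-reduced across ranks automatically
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # no horovod optimizer wrapper needed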

@dytnvgl marked this pull request as ready for review July 6, 2023 03:41
@romerojosh (Contributor) left a comment

I found a few small issues, which I've marked, when I tried this out on my local system. The original patch was developed before the lazy loading changes were merged, so I think there might be some changes needed to handle that case, specifically with the distributed samplers.

compression=compression,
op=hvd.Average)
if self.parameters_full.use_distributed_sampler_test:
if self.data.test_data_sets is not None:

Suggested change
if self.data.test_data_sets is not None:
if self.data.test_data_sets:
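
(Presumably the point of the suggestion: test_data_sets is a list, and an empty list is not None but is falsy, so the truthiness check also covers the case where no test data sets were configured.)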

rank=hvd.rank(),
if self.parameters_full.use_distributed_sampler_train:
self.train_sampler = torch.utils.data.\
distributed.DistributedSampler(self.data.training_data_sets,

Suggested change
distributed.DistributedSampler(self.data.training_data_sets,
distributed.DistributedSampler(self.data.training_data_sets[0],

The sampler needs to be associated with an existing dataset. The existing line is trying to associate a sampler with the list of data sets, which does not work.

Setting up the sampler for [0] works for the in-memory loader only. For the lazy-loading case where there are actually multiple datasets in the self.data.training_data_sets list, I think we will need to create and manage an individual sampler per dataset.
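
For illustration, a rough sketch of that per-dataset sampler idea; the helper below is hypothetical, not existing MALA API, and assumes the process group has already been initialized.

# Hypothetical sketch: one DistributedSampler (and DataLoader) per dataset in the
# lazy-loading list, instead of a single sampler over the list itself.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_distributed_loaders(datasets, batch_size, shuffle=True):
    # DistributedSampler infers rank and world size from the initialized process group.
    samplers = [DistributedSampler(ds, shuffle=shuffle) for ds in datasets]
    loaders = [
        DataLoader(ds, batch_size=batch_size, sampler=sampler)
        for ds, sampler in zip(datasets, samplers)
    ]
    return samplers, loaders

# Each epoch, every sampler needs set_epoch(epoch) so that shuffling differs between epochs:
#     for sampler in samplers:
#         sampler.set_epoch(epoch)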

shuffle=do_shuffle)
if self.parameters_full.use_distributed_sampler_val:
self.validation_sampler = torch.utils.data.\
distributed.DistributedSampler(self.data.validation_data_sets,

Suggested change
distributed.DistributedSampler(self.data.validation_data_sets,
distributed.DistributedSampler(self.data.validation_data_sets[0],

if self.parameters_full.use_distributed_sampler_test:
if self.data.test_data_sets is not None:
self.test_sampler = torch.utils.data.\
distributed.DistributedSampler(self.data.test_data_sets,

Suggested change
distributed.DistributedSampler(self.data.test_data_sets,
distributed.DistributedSampler(self.data.test_data_sets[0],

@srajama1 (Contributor)

@dytnvgl @RandomDefaultUser Now that the release is done, can we prioritize merging the DDP changes, please?

@RandomDefaultUser (Member)

@dytnvgl @RandomDefaultUser Now that the release is done, can we prioritize merging the DDP changes, please?

Hi Siva, definitely! I still have to work around my PhD thesis schedule, but I can dedicate my "programming" time to this issue now!

RandomDefaultUser and others added 3 commits April 25, 2024 17:58
# Conflicts:
#	mala/common/check_modules.py
#	mala/common/parallelizer.py
#	mala/common/parameters.py
#	mala/datahandling/data_handler.py
#	mala/datahandling/lazy_load_dataset.py
#	mala/datahandling/lazy_load_dataset_clustered.py
#	mala/datahandling/lazy_load_dataset_single.py
#	mala/network/network.py
#	mala/network/objective_naswot.py
#	mala/network/predictor.py
#	mala/network/runner.py
#	mala/network/tester.py
#	mala/network/trainer.py
Co-authored-by: Josh Romero <[email protected]>
@RandomDefaultUser (Member)

Hi @dytnvgl @romerojosh @srajama1
I finally found the time to look into this.
The first thing I did is merge the current develop branch into this branch, i.e., update it. This was a bit more tedious than usual because we recently enforced black formatting on the entire codebase.

I will test the code now, and then make a few small adjustments.

One question right from merging: why is the use of distributed samplers handled via bools? What happens when I run with DDP and do not use distributed samplers? With horovod I just always used distributed samplers when horovod was enabled; why is this different here now?

@romerojosh (Contributor)

Hi @RandomDefaultUser, glad to see you are able to start working on getting this merged!

It's been a while, but I recall that the reason I made distributed sampler usage controllable via bools is that I considered two possible ways of distributing snapshot data across workers:

  1. All workers load the same set of snapshots: in this case, you need the distributed sampler to shard the commonly loaded dataset across workers.
  2. Each worker is assigned a unique set of snapshots and only loads those: in this case, you would not want to use the distributed sampler, since each worker only loads unique data samples.

The thought behind option 2 is that MALA without lazy loading loads all snapshot data into memory, which for large datasets ends up being very limiting. With this second option, workers would only load a smaller subset of the dataset, enabling training on larger datasets without lazy loading by increasing the number of workers.
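
For concreteness, a rough sketch of what option 2 could look like on the user side; the snapshot identifiers and the round-robin split are assumptions for illustration, not existing MALA behavior.

# Illustrative sketch of option 2: each rank loads only its own subset of snapshots,
# so no DistributedSampler is needed. Snapshot names and the split are placeholders.
import torch.distributed as dist

all_snapshots = [f"snapshot_{i}" for i in range(16)]  # placeholder identifiers

rank = dist.get_rank()
world_size = dist.get_world_size()

# Round-robin assignment: rank r takes snapshots r, r + world_size, r + 2*world_size, ...
my_snapshots = all_snapshots[rank::world_size]

# Each rank would then only load my_snapshots into memory and train on unique samples,
# while DDP still averages gradients across all ranks.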

@RandomDefaultUser (Member)

I see, @romerojosh, now I understand; thanks for the explanation! I definitely understand the approach, and given the structure of MALA models, it makes a lot of sense.

I recall that @ellisja tinkered with the same concept with horovod years ago, when MALA was in its infancy. While he was still at Oak Ridge, this type of training (I think the term is "federated learning"?) was part of his general research, and he tested similar approaches for different application cases. His conclusion back then was that splitting data like this, despite increasing performance, reduced training accuracy.

But since the field of ML is always changing, maybe this problem has been solved, or it is not as drastic in our case. Have you perhaps tested this type of training, or do you have experience in that regard from a different model? I personally unfortunately do not; this is all second-hand experience, so to speak.

Looking at the code, I think we would need to make a few more adjustments to get this running natively: in its current form, one would have to assign different snapshots to each node (per rank, in the Python script), if I am not mistaken. If we do indeed want to implement such a mechanism, I would argue it would be better to let MALA handle the splitting. Of course, this is all possible and I would be happy to take care of it, but I think we should first confirm that it yields the results we anticipate.

Thus, my suggestion for now would be to make distributed samplers the default and merge with that. Afterwards, whoever has the time can take a look at this type of snapshot splitting and check whether it gives reasonable accuracy; if so, we implement it in the MALA code so that the splitting is done automatically. Does that sound OK to you?

@RandomDefaultUser (Member) commented Apr 28, 2024

I have also tested the code. For a single node it works great; thank you so much for the great work, @romerojosh and @dytnvgl! I achieved almost ideal speed-up with up to 4 GPUs (the maximum on our nodes).
[figure: n1_results]

For more than one node, my script is currently crashing. I get a RuntimeError: Socket Timeout, which is not an unknown issue (NVIDIA/Megatron-LM#386, pytorch/pytorch#25767). I suspect the issue is on my side, i.e., my slurm/HPC setup. I will try to debug the error further this week, and I also attach my slurm script below in case someone has an idea.

However, if this is indeed an error on my HPC/slurm side, we could already merge the code. Have you, @dytnvgl @romerojosh, already successfully tested this across multiple nodes? If so, I would just add some documentation and then we could merge.

#SBATCH --job-name=N2G4
#SBATCH --nodes=2
#SBATCH --ntasks=4
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4
#SBATCH --time=01:00:00
#SBATCH --mem=360G
#SBATCH --output=parallel_n2_g4.out
#SBATCH -A casus
#SBATCH -p casus

# Loading the modules I always load
module load gcc/8.2.0
module load openmpi/4.0.4
module load python/3.8.0
module load lapack

# NCCL v2.14.3 is installed, I confirmed that
export NCCL_DEBUG=INFO

# Found this code in this helpful tutorial here: https://gist.github.com/TengdaHan/1dd10d335c7ca6f13810fff41e809904
export MASTER_PORT=12342
echo "NODELIST="${SLURM_NODELIST}
master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr
echo "MASTER_ADDR="$MASTER_ADDR

# And then just run with torchrun
torchrun --nnodes 2 --nproc_per_node 4 --rdzv_id "$SLURM_JOB_ID" parallel_n2_g4.py

@romerojosh (Contributor)

I rarely ever use torchrun to launch jobs, as it doesn't seem to be designed for systems that already have good multi-node process facilities via SLURM, and the error handling of that tool is not that great.

I typically just use srun to launch my Python processes, and export required env vars for torch distributed to initialize. Something like:

#!/bin/bash
#SBATCH -N 2
#SBATCH --ntasks-per-node=4    # number of tasks per node
# Other sbatch options....

# Exporting vars for DDP
export MASTER_ADDR=$(hostname)
export MASTER_PORT=29500 # default from torch launcher

srun -u bash -c '
# Export additional per process variables
export RANK=$SLURM_PROCID
export LOCAL_RANK=$SLURM_LOCALID
export WORLD_SIZE=$SLURM_NTASKS

python train.py
'

See https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization for a description of the environment variables you need to specify.

Your batch script using torchrun above does not work correctly because you are using SLURM to launch all 8 tasks (2 nodes, 4 per node), but torchrun expects you to launch only a single task per node and to let it launch the additional tasks itself (set by the --nproc_per_node 4 argument). As I mentioned though, directly using srun is simpler and better IMO.
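
For completeness, a small sketch of the consuming side under the srun approach, assuming the RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT variables exported in the script above; this shows the general pattern, not MALA code.

# Minimal sketch of initializing torch.distributed from environment variables.
import os
import torch
import torch.distributed as dist

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# With init_method="env://" (the default), torch reads MASTER_ADDR, MASTER_PORT,
# RANK and WORLD_SIZE from the environment.
dist.init_process_group(backend="nccl", init_method="env://")

print(f"rank {dist.get_rank()} of {dist.get_world_size()} using GPU {local_rank}")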

@romerojosh (Contributor)

Thus, my suggestion for now would be to make distributed samplers the default and merge with that. Afterwards, whoever has the time can take a look at this type of snapshot splitting and check whether it gives reasonable accuracy; if so, we implement it in the MALA code so that the splitting is done automatically. Does that sound OK to you?

This is fine with me. I'll leave it to @dytnvgl or someone else to make a call on whether the snapshot splitting idea mentioned is useful.

@RandomDefaultUser (Member)

Thus, my suggestion for now would be to make distributed samplers the default and merge with that. Afterwards, whoever has the time can take a look at this type of snapshot splitting and check whether it gives reasonable accuracy; if so, we implement it in the MALA code so that the splitting is done automatically. Does that sound OK to you?

This is fine with me. I'll leave it to @dytnvgl or someone else to make a call on whether the snapshot splitting idea mentioned is useful.

Alright, sounds good. Then this part of the PR seems finished (for now) and we can get back to the splitting whenever we want/need to.

@RandomDefaultUser (Member)

I rarely ever use torchrun to launch jobs, as it doesn't seem to be designed for systems that already have good multi-node process facilities via SLURM, and the error handling of that tool is not that great.

I typically just use srun to launch my Python processes, and export required env vars for torch distributed to initialize. Something like:

Thanks for the advice! I didn't know torchrun was more difficult to use; I thought it was the standard way. But I totally see your point, and I tried out the srun route. With it, I still get an error, namely:

Traceback (most recent call last):
  File "parallel_n2_g4.py", line 66, in <module>
    data_handler.prepare_data()
  File "/home/fiedle09/tools/mala/mala/datahandling/data_handler.py", line 217, in prepare_data
    barrier()
  File "/home/fiedle09/tools/mala/mala/common/parallelizer.py", line 209, in barrier
    dist.barrier()
  File "/home/fiedle09/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 3145, in barrier
    work = default_pg.barrier(opts=opts)
RuntimeError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1269, internal error, NCCL version 2.14.3
ncclInternalError: Internal check failed.
Last error:
Net : Connect to 169.254.3.1<33323> failed : Connection refused
srun: error: ga011: tasks 4,6: Exited with exit code 1
srun: Job step aborted: Waiting up to 152 seconds for job step to finish.

After some research, this seems to be caused by the nodes on my cluster not being able to communicate over TCP, and thus NCCL communication failing. This really seems like a problem related to my cluster, and I will contact the HPC staff at my research facility for help with it.

On top of looking into this issue I tested

  • loading models trained in parallel for inference
  • checkpointing and loading models
  • using lazy loading with shuffling

All three needed small adjustments, but work now! Here are some timing results for lazy loading:

[figure: n1_ll_results]

We move away from ideal speed-up a bit, but I think that is to be expected, as file I/O adds a larger overhead. Overall, we still see a decent speed-up. It would be interesting to see how this behavior changes as we move to even larger data sets, and how the ratio of data per file to number of data files impacts results. There may be even more room for optimization, as has already been pointed out, but that would exceed the scope of this PR. Please note that the absolute values of the epoch times cannot be compared to my previous plot, since I used V100 GPUs there and A100 GPUs here.

So from my side, this PR is ready to be merged! I know that the CI is currently failing, but that is not due to the code; it is due to the download quota of our test data repository being exceeded. It will reset in 9 days; until then, the CI will fail. We have to find a better solution for this, as it is not the first time this has happened. I have confirmed that the tests pass locally.

I would merge this PR, if that is OK with you, @dytnvgl and @romerojosh? Again, thanks for the great work!

@RandomDefaultUser merged commit 7abd9d7 into mala-project:develop May 12, 2024
4 of 5 checks passed
@RandomDefaultUser (Member)

@romerojosh @dytnvgl

Together with the HPC administrators at HZDR I was able to track down the problem with multi-node training and got it to work on my system! :)
As expected, the MALA code did not need adjusting; there was a configuration problem on the cluster, which is now fixed. Results are looking good!

[figure: multi_n_results]

@zirui commented Aug 1, 2024

Together with the HPC administrators at HZDR I was able to track down the problem with multi-node training and got it to work on my system! :) As expected, the MALA code did not need adjusting; there was a configuration problem on the cluster, which is now fixed. Results are looking good!

I'm running into the same Socket Timeout issue. Could you please share more details on what the configuration problem was and how it was fixed?

@RandomDefaultUser (Member)

Hi @zirui, I don't have the full picture, but the admin told me that on the compute node I was using, my Python process was attempting to use some internal virtual adapter rather than the actual network connection. Why DDP went for this adapter first instead of the network connection, I don't know, but after the admin deactivated the virtual adapter, everything worked. Sorry for not being able to provide more info; I hope you find a solution!
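
As a purely illustrative aside (not the fix described above): NCCL's choice of network interface can also be steered explicitly via the NCCL_SOCKET_IFNAME environment variable; the interface name below is a placeholder and depends on the cluster.

# Hypothetical workaround sketch: pin NCCL to a specific interface before the
# process group is created. "ib0" is a placeholder, cluster-specific value.
import os
import torch.distributed as dist

os.environ.setdefault("NCCL_SOCKET_IFNAME", "ib0")  # e.g. "ib0" or "eth0"
os.environ.setdefault("NCCL_DEBUG", "INFO")         # log which interface NCCL selects

dist.init_process_group(backend="nccl", init_method="env://")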

@zirui commented Aug 2, 2024

Hi @zirui, I don't have the full picture, but the admin told me that on the compute node I was using, my Python process was attempting to use some internal virtual adapter rather than the actual network connection. Why DDP went for this adapter first instead of the network connection, I don't know, but after the admin deactivated the virtual adapter, everything worked. Sorry for not being able to provide more info; I hope you find a solution!

Hi @RandomDefaultUser
Thank you so much for the update and the insights. The information is very valuable; it gives me a good starting point to troubleshoot on my end.
