Draft: Add multi gpu support #3548
base: master
Conversation
Hello @jeffpicard this is awesome, thanks for the PR! @helpmefindaname @HallerPatrick can you take a look?
flair/distributed_utils.py (Outdated)

```python
log = logging.getLogger("flair")


def launch_distributed(fp, *args):
```
Can we get an example of this function being called, or is this run explicitly?
It's no longer run by ordinary users -- it's called inside `Trainer.train` and `.fine_tune`.
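To make the dispatch described above concrete, here is a minimal sketch of what "called inside `Trainer.train`" can look like. Only the `launch_distributed` name and the `multi_gpu` flag come from this PR; the `Trainer` class and `_train_impl` method here are hypothetical stand-ins, not flair's actual code.

```python
# Hypothetical sketch: user code calls Trainer.train() once, and the
# trainer itself hands the real training function to launch_distributed.

def launch_distributed(fn, *args):
    """Stand-in for the PR's helper, which spawns one process per GPU
    via torch.multiprocessing.spawn; reduced to a direct call here so
    the sketch runs anywhere."""
    return fn(*args)

class Trainer:
    def train(self, base_path, multi_gpu=False, **kwargs):
        if multi_gpu:
            # Fan the training function out to one process per GPU
            # (simulated by the direct call above).
            return launch_distributed(self._train_impl, base_path)
        return self._train_impl(base_path)

    def _train_impl(self, base_path):
        return f"trained -> {base_path}"
```

The point of this shape is that existing single-GPU scripts keep working unchanged: only the `multi_gpu=True` argument opts into the spawning path.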
flair/distributed_utils.py (Outdated)

```python
mp.spawn(entrypoint, args=(world_size, fp, *args), nprocs=world_size)


def entrypoint(rank, world_size, fp, *args):
```
Can we add type hints and a docstring here?
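In the spirit of this request, one possible shape for the hints and docstring (a sketch only; the real per-process setup work is elided here):

```python
from typing import Any, Callable

def entrypoint(rank: int, world_size: int, fp: Callable[..., Any], *args: Any) -> None:
    """Run `fp(*args)` inside one spawned worker process.

    mp.spawn passes `rank` automatically as the first argument; the
    remaining arguments come from its `args=` tuple. In a real
    implementation this is also where per-process setup would happen
    (e.g. torch.distributed.init_process_group and selecting GPU
    `rank`), which this sketch omits.

    Args:
        rank: Index of this process; 0 is the main process.
        world_size: Total number of processes, typically one per GPU.
        fp: The training function each worker executes.
        *args: Arguments forwarded to `fp`.
    """
    fp(*args)
```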
Hey @jeffpicard, thanks for the PR. I tested your changes with different numbers of GPUs and can, more or less, reproduce your speedups! I also like the approach of setting everything up in-process to isolate the distribution logic to the training logic. For the logging, we could go simple with:

```python
if flair.distributed:
    self.model = DistributedModel(self.model, device_ids=[flair.device.index])

# Disable logging in distributed mode for all but the main process
log.disabled = not is_main_process()
```

Here are some points from my side:
On a side note, maybe we can also implement multi-gpu support for the LM trainer @alanakbik :) Thank you!
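The `is_main_process()` gate used in the snippet above is not shown in this thread; a minimal version might look like the following. This is a sketch assuming the common convention of exposing the process rank via the `RANK` environment variable; a torch-native variant would instead check `torch.distributed.get_rank() == 0` once the process group is initialized.

```python
import os

def is_main_process() -> bool:
    """Return True only in the rank-0 worker, and in ordinary
    single-process runs where no rank has been set."""
    return int(os.environ.get("RANK", "0")) == 0
```

With a helper like this, `log.disabled = not is_main_process()` silences every worker except rank 0.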
Many thanks for the thoughtful review!
I'll look into distributing across processes inside the call to
Ah, great idea, thanks!
I'll follow up soon.
Hi @jeffpicard thank you for creating this draft. Conceptually, I think this is a good way to finally integrate multi-gpu training in flair. I tested this on 2 RTX A4000 GPUs, by increasing the

I observed that somehow the logging at multi-gpu is off by 1 epoch:

Also, since currently the non-main processes also log values, I could observe the following:

I wonder how the plugins are impacted by multi-gpu. Logger plugins should obviously only work on the main process, while others, like the lr-scheduler plugins, need to be run on every process.

using the
Thanks for looking at this @helpmefindaname!
Ahh, sorry about that. I think it's from the new call to
DistributedDataParallel should be synchronizing every backward(), but there was a bug. I fixed it.
Thanks to @HallerPatrick's good idea to disable the logger on all but the main process, I've simplified the plugins to run on all processes. This makes the

I ran into an unfortunate wrinkle -- I no longer see a speedup after the following bug fix: I noticed the gradients were not the same on each process/gpu for a given epoch_num and batch_no across all GPUs, like they should be. I think this is because pytorch's synchronization implementation relies on hooks that get called when you

fixes the gradients, but makes multiple GPUs a bit slower than a single GPU. Any idea what could be going on that's making it slower?
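The invariant being checked here — identical gradients on every process — is exactly what DDP's all-reduce is supposed to produce: the average of the per-replica gradients, which for a mean loss over equal-sized shards equals the full-batch gradient. A torch-free check of that identity, using a one-parameter model y = w*x with mean squared error (all numbers are illustrative, not from this PR):

```python
def grad_mse(w, xs, ys):
    """Gradient of L = mean((w*x - y)^2) with respect to w."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

# Shard the batch across two "GPUs", as a distributed sampler would.
g0 = grad_mse(w, xs[:2], ys[:2])   # replica 0's local gradient
g1 = grad_mse(w, xs[2:], ys[2:])   # replica 1's local gradient
averaged = (g0 + g1) / 2           # what DDP's all-reduce computes

full = grad_mse(w, xs, ys)         # single-process full-batch gradient
assert abs(averaged - full) < 1e-12
```

If the hooks that perform this averaging never fire (e.g. because `backward()` is wrapped in a way that bypasses them), each replica keeps its local `g0`/`g1` and the replicas drift apart, which matches the divergent gradients described above.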
Aha, with a bigger batch size, multiple GPUs are faster again. There's a little overhead to synchronizing the gradients, so the bigger the batch size, the more the overhead can be amortized. I've fixed most of what's mentioned above
I'll push these changes up. I'm still stuck on:
Let me know if you have any thoughts on
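The batch-size effect mentioned above can be made concrete with a toy cost model (all numbers are illustrative, not measurements from this PR): each optimizer step costs compute time proportional to batch_size / n_gpus, plus a fixed gradient-sync overhead, so larger batches mean fewer steps per epoch and less total overhead.

```python
def epoch_time(num_examples, batch_size, n_gpus, rate=1000.0, sync=0.05):
    """Toy cost model: each step processes batch_size examples split
    across n_gpus at `rate` examples/sec per GPU, plus a fixed `sync`
    seconds of gradient all-reduce overhead when n_gpus > 1."""
    steps = num_examples / batch_size
    compute = batch_size / (n_gpus * rate)
    overhead = sync if n_gpus > 1 else 0.0
    return steps * (compute + overhead)

# Small batches: many steps, so the per-step sync overhead dominates
# and 4 GPUs come out slower than 1.
small_1 = epoch_time(10000, 8, 1)
small_4 = epoch_time(10000, 8, 4)

# Large batches: far fewer steps, the overhead is amortized, and
# 4 GPUs pull ahead again.
large_1 = epoch_time(10000, 256, 1)
large_4 = epoch_time(10000, 256, 4)
```

With these made-up constants, 4 GPUs at batch size 8 take longer than 1 GPU, while at batch size 256 they finish the epoch in less than half the time — the same qualitative behavior reported in this thread.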
And here's an example of running it on the latest commit:

```python
from flair.datasets import IMDB
from flair.embeddings import DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

if __name__ == "__main__":
    corpus = IMDB()
    corpus.downsample(0.01)
    label_type = "sentiment"
    label_dictionary = corpus.make_label_dictionary(label_type)
    embeddings = DocumentTFIDFEmbeddings(train_dataset=corpus.train)
    # embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")  # serialization error
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=1, mini_batch_size=16, multi_gpu=True)
```
Hi! This is a draft PR that adds multi gpu support. @alanakbik and others: would you be interested in incorporating something like this? The core functionality is working and I've pasted a short script below demonstrating its usage. I get a near-linear speed increase -- for 1 epoch it took: 16 cpus --> 368s, 1 gpu --> 32s, 4 gpus --> 8s when running on an AWS g5.12xlarge instance with 4 A10 GPUs.
There's a related issue here, and a past PR that never ended up merging.
The approach:

- Uses pytorch's `DistributedDataParallel` rather than another package like fabric, accelerate, or deepspeed. This gives more control and visibility into exactly what's happening and avoids needing to integrate another large pytorch project's design on how to handle e.g. AMP. However, it leaves more to be handled in flair, such as multi-node / TPUs etc. I'm open to discussing/implementing other approaches if you have preferences.
- Training is launched through the `launch_distributed` mechanism. This means 1) user code will be running num_gpus times, which can be unintuitive, and 2) existing flair scripts won't automatically use multi-gpus without refactoring. I think a simpler approach may be possible by spawning processes inside `Trainer.train_custom`. However, I ran into problems doing it this way (e.g. `TransformerEmbeddings` and `Pluggable._event_queue` would not serialize correctly), and many multi-gpu projects involve this kind of complexity. I think this PR is still a step toward that better future though, and existing CPU/single-gpu usage is unchanged.

There are still TODOs. For example, the logging inside `.train_custom` prints out multiple times (once for each process/gpu). If you connect with the approach, I can add new commits fixing this by adding statements like `if is_main_process():` or `torch.distributed.gather_object` to aggregate metrics across processes, similar to what's done for the eval steps in this PR.

Example usage: