Draft: Add multi gpu support #3548

Open
wants to merge 2 commits into base: master

Conversation


@jeffpicard jeffpicard commented Sep 24, 2024

Hi! This is a draft PR that adds multi-GPU support. @alanakbik and others: would you be interested in incorporating something like this? The core functionality is working and I've pasted a short script below demonstrating its usage. I get a near-linear speed increase -- for 1 epoch it took: 16 CPUs --> 368s, 1 GPU --> 32s, 4 GPUs --> 8s when running on an AWS g5.12xlarge instance with 4 A10 GPUs.

There's a related issue here, and a past PR that never ended up merging.

The approach:

  • This PR uses PyTorch's DistributedDataParallel directly rather than another package like Fabric, Accelerate, or DeepSpeed. This gives more control and visibility into exactly what's happening and avoids needing to integrate another large PyTorch project's design decisions on how to handle e.g. AMP. However, it leaves more to be handled in flair, such as multi-node training, TPUs, etc. I'm open to discussing/implementing other approaches if you have preferences.
  • In order to use multiple GPUs, users would call a launch_distributed mechanism. This means 1) user code will run num_gpus times, which can be unintuitive, and 2) existing flair scripts won't automatically use multiple GPUs without refactoring. I think a simpler approach may be possible by spawning processes inside Trainer.train_custom. However, I ran into problems doing it this way (e.g. TransformerEmbeddings and Pluggable._event_queue would not serialize correctly), and many multi-GPU projects involve this kind of complexity. I think this PR is still a step toward that better future, and existing CPU/single-GPU usage is unchanged.

There are still TODOs. For example, the logging inside .train_custom prints out multiple times (once for each process/GPU). If you're on board with the approach, I can add new commits fixing this by adding statements like if is_main_process(): or torch.distributed.gather_object to aggregate metrics across processes, similar to what's done for the eval steps in this PR.

Example usage:

import flair
from flair.data import Sentence
from flair.datasets import TREC_6
from flair.distributed_utils import launch_distributed
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer


def example(max_epochs):
    corpus = TREC_6()
    label_type = "question_class"
    label_dictionary = corpus.make_label_dictionary(label_type)
    embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=max_epochs)
    model.predict(Sentence("Hello, world!"))


if __name__ == "__main__":
    mode = "multi_gpu"
    epochs = 2
    if mode == "multi_gpu":
        launch_distributed(example, epochs)
    elif mode == "single_gpu":
        example(epochs)
    elif mode == "cpu":
        flair.device = "cpu"
        example(epochs)
    print("Done")

@alanakbik
Collaborator

Hello @jeffpicard this is awesome, thanks for the PR!

@helpmefindaname @HallerPatrick can you take a look?

log = logging.getLogger("flair")


def launch_distributed(fp, *args):
Contributor

Can we get an example of this function being called, or is this run explicitly?

Author

It's no longer run by ordinary users -- it's called inside Trainer.train and .fine_tune.

mp.spawn(entrypoint, args=(world_size, fp, *args), nprocs=world_size)


def entrypoint(rank, world_size, fp, *args):
Contributor

Can we add type hints and a docstring here?
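For reference, one possible shape (the exact docstring wording is just a suggestion):

from typing import Any, Callable


def entrypoint(rank: int, world_size: int, fp: Callable[..., Any], *args: Any) -> None:
    """Run fp(*args) in one spawned worker process.

    Called once per GPU by launch_distributed via mp.spawn; sets up the
    process group for this rank before handing control to the user function.
    """
    ...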

@HallerPatrick
Collaborator

Hey @jeffpicard, thanks for the PR.

I tested your changes with different numbers of GPUs and can, more or less, reproduce your speedups!

I also like the approach of setting everything up in-process to isolate the distribution logic to the training logic. For the logging, we could keep it simple with:

if flair.distributed:
    self.model = DistributedModel(self.model, device_ids=[flair.device.index])

    # Disable logging in distributed mode for all but the main process
    log.disabled = not is_main_process()

Here are some points from my side:

  1. I am a little suspicious about the DistributedModel wrapper, where we can now arbitrarily update the DistributedDataParallel model without knowing if it affects any distributed logic. I do see the convenience of it.
    Maybe we can check every __getattr__ and __setattr__ call just to be on the safe side here :P

  2. Model saving logic still runs on every process. Easily fixable.

  3. How to handle the "best model" logic after each epoch? Should we just naively test the main process model? I don't know if this is nitpicky...

On a side note, maybe we can also implement multi-gpu support for the LM trainer @alanakbik :)

Thank you!

@jeffpicard
Author

jeffpicard commented Sep 25, 2024

Many thanks for the thoughtful review!

isolate the distribution logic to the training logic

I'll look into distributing across processes inside the call to .train/.fine_tune rather than before. Some of the serialization issues (e.g. Pluggable._event_queue) should be solvable, I think.

log.disabled = not is_main_process()

Ah, great idea, thanks!

  1. I felt the same way! Thanks for calling it out. The idea is inspired by other implementations like Lightning Fabric. I'll be more careful about the implementation.

  2. I did add an if is_main_process() to Model.save, but I can move that in front of all the calls to model.save in trainer to be less surprising.

  3. I believe testing the main process model should be fine since the models on each process/GPU should be the same. However, the data also needs to be the same. The dev Dataset should already be the same on each process, but if train_loss is used, that's calculated only for the fraction of the data the given process handles. I'll try torch.distributed.gather_object(train_loss) to average it across all processes/GPUs; this will also help for logging the training progress (rough sketch below).
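A rough sketch of what that could look like, assuming the default process group is already initialized (the helper name is just illustrative):

from typing import Optional

import torch.distributed as dist


def average_across_processes(local_loss: float) -> Optional[float]:
    """Gather each rank's train_loss on the main process and average it."""
    world_size = dist.get_world_size()
    gathered = [None] * world_size if dist.get_rank() == 0 else None
    dist.gather_object(local_loss, gathered, dst=0)
    if dist.get_rank() == 0:
        return sum(gathered) / world_size
    return None  # non-main processes don't log, so they don't need the average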

I'll follow up soon.

@helpmefindaname
Collaborator

Hi @jeffpicard

thank you for creating this draft. Conceptually, I think this is a good way to finally integrate multi-GPU training into flair.

I tested this on 2 RTX A4000 GPUs, increasing the mini_batch_chunk_size until all GPU memory was used, and setting the mini_batch_size to be either the same (multi-GPU) or 2x (single-GPU) to have a fair comparison in terms of batch updates.
Also, I used ClearML for logging.
With that, I can confirm the ~2x speed improvement for 2 GPUs and that the metrics at the end are about the same (although slightly worse for multi-GPU).

I observed that somehow the logging at multi-gpu is off by 1 epoch:
[screenshot]
Here you can see that there was no report for epoch 1, but an epoch 21 magically appeared. I am not sure why that is.

Also, since the non-main processes currently log values too, I could observe the following:
[screenshot]
Here, the non-main process is ahead of the main process because it doesn't have to evaluate. I am not sure if that is fine, or if we should rather synchronize the processes at the end of each epoch.
Obviously, splitting the evaluation would also be an option, but I think that would imply a lot of changes that would make this PR more complicated.
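If synchronizing at epoch boundaries turns out to be the preferred behavior, a minimal sketch (assuming the process group is initialized) could be:

import torch.distributed as dist

# At the end of each epoch, block every rank until all ranks arrive here,
# so the non-main processes can't run ahead while rank 0 evaluates.
if dist.is_available() and dist.is_initialized():
    dist.barrier()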

I wonder how the plugins are impacted by multi-GPU training. Logger plugins should obviously only run on the main process, while others, like the lr-scheduler plugins, need to run on every process.
Note: currently the lr-scheduler doesn't know that multi-GPU training uses a larger effective batch size and therefore fewer train steps:
[screenshot]
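To make that concrete, a back-of-the-envelope with made-up numbers (the variable names are illustrative, not flair's actual ones):

# Each process sees 1/world_size of the training data per epoch, so the
# effective batch size grows and the number of optimizer steps shrinks.
world_size = 2              # e.g. 2 GPUs
mini_batch_size = 16        # per-process batch size passed to the trainer
num_train_examples = 5000

effective_batch_size = mini_batch_size * world_size            # 32
steps_per_epoch = num_train_examples // effective_batch_size   # 156 instead of 312

# A OneCycle-style scheduler sized with the single-GPU step count would decay
# too slowly; its total steps should be based on steps_per_epoch * max_epochs.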

Using the AnnealOnPlateau scheduler doesn't work, as the non-main processes fail without an eval metric.

@jeffpicard
Author

Thanks for looking at this @helpmefindaname !

logging at multi-gpu is off by 1 epoch

Ahh, sorry about that. I think it's from the new call to .set_epoch(epoch) which was off by 1.

the non-main process is ahead of the main process [...] we should rather synchronize the processes

DistributedDataParallel should synchronize on every backward(), but there was a bug. I fixed it.

plugins

Thanks to @HallerPatrick's good idea to disable the logger on all but the main process, I've simplified the plugins to run on all processes. This makes the AnnealOnPlateau plugin work. However, yes, if a plugin needs to synchronize information from all processes, it'll have to do that explicitly.


I ran into an unfortunate wrinkle -- I no longer see a speedup after the following bug fix: I noticed the gradients were not the same on each process/GPU for a given epoch_num and batch_no, as they should be. I think this is because PyTorch's gradient synchronization relies on hooks that get called when you __call__ a model, rather than when you call forward_loss directly. Changing the Trainer:

loss, datapoint_count = self.model.forward_loss(batch_step)
# becomes
loss, datapoint_count = self.model(batch_step)

fixes the gradients, but makes multiple GPUs a bit slower than a single GPU. Any idea what could be going on that's making it slower?

@jeffpicard
Author

Any idea what could be going on that's making it slower?

Aha, with a bigger batch size, multiple GPUs are faster again. There's a little overhead to synchronizing the gradients, so the bigger the batch size, the more the overhead can be amortized.
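A back-of-the-envelope with made-up timings to illustrate the amortization:

# Illustrative numbers only. Per optimizer step:
sync_overhead_s = 0.020         # fixed cost of all-reducing the gradients
compute_per_example_s = 0.002   # per-example forward/backward time on one GPU

for per_gpu_batch in (16, 128):
    compute_s = per_gpu_batch * compute_per_example_s
    efficiency = compute_s / (compute_s + sync_overhead_s)
    print(per_gpu_batch, f"{efficiency:.0%}")  # 16 -> 62%, 128 -> 93%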

I've fixed most of what's mentioned above:

  • Process spawning now happens inside .train, so all users have to do is add the multi_gpu=True argument
  • The metrics logged during training are now averaged/summed across all GPUs rather than printing the rank=0 data
  • Removed the DistributedModel wrapper

I'll push these changes up.


I'm still stuck on:

  • What to do about forward vs forward_loss. In order to get the gradients to synchronize, PyTorch relies on hooks run by __call__, which then invokes the special method forward. flair's trainer calls forward_loss instead. That's potentially convenient because forward could just be redirected to forward_loss, but some Models also use forward themselves. One option is to refactor all models so that either all of them use forward or none of them do, but that's complex ¯\_(ツ)_/¯ (rough sketch of the redirect idea below).
  • I need to make TransformerEmbeddings work with pickle. Currently I'm getting TypeError: DistilBertModel.__init__() got an unexpected keyword argument 'instance_parameters'.

Let me know if you have any thoughts on forward.
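For illustration, the redirect option from the first bullet could look roughly like this on the base class (hypothetical sketch; flair's real Model hierarchy is more involved):

import torch


class Model(torch.nn.Module):
    def forward(self, data_points):
        # DistributedDataParallel registers its gradient-sync hooks around
        # __call__/forward, so routing forward through forward_loss lets
        # self.model(batch) in the trainer trigger them.
        return self.forward_loss(data_points)

    def forward_loss(self, data_points):
        raise NotImplementedError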

@jeffpicard
Author

jeffpicard commented Oct 8, 2024

And here's an example of running it on the latest commit:

from flair.datasets import IMDB
from flair.embeddings import DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

if __name__ == "__main__":
    corpus = IMDB()
    corpus.downsample(0.01)
    label_type = "sentiment"

    label_dictionary = corpus.make_label_dictionary(label_type)
    embeddings = DocumentTFIDFEmbeddings(train_dataset=corpus.train)
    # embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")  # serialization error
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=1, mini_batch_size=16, multi_gpu=True)
