[Draft][Demo] auto tp training #5445

Open · wants to merge 6 commits into base: master
Conversation

@inkcherry (Contributor) commented Apr 22, 2024

This is an experimental demo of autoTP training, not intended for review. Apologies for the somewhat rough draft; I hope it clarifies the process.

So far I have tested pure TP (the DP=1 case) directly with the HF transformers Trainer. I trained LLaMA-7B (fine-tuning from pretrained weights) on 4 GPUs and 8 GPUs with pure TP and got the expected loss curve (roughly 1.6 down to 0.3).
The main modifications are as follows:

  • On the train-script side:
  1. Explicitly call an API (currently the inference API directly, which invokes autoTP to do the module replacement).
  2. Manually modify the dataloader to synchronize data across all TP ranks (a temporary solution).
  3. Set ds_config.json with zero_stage=0 in the ZeRO config and autotp_size=num_gpus (DP=1).
  • On the DeepSpeed side, in this demo:

    1. Decoupled MPU from Megatron: I took Megatron's code directly and put it into a parallel_states.py file.
    2. Added backward code for the main replaced modules, LinearLayer and LinearAllreduce (see the sketch after this list).
    3. Added the 'tensor_model_parallel' attribute to LinearLayer and LinearAllreduce so they are handled correctly in grad-norm and other calculations.
    4. Set requires_grad=True on the weights and bias of LinearLayer and LinearAllreduce so they are captured in model_params by the transformers prepare_deepspeed logic and fed to the DeepSpeed optimizer.
    5. _broadcast_model: due to inconsistencies in the group settings, the DP group used by _broadcast_model is not correct, so I bypass this logic directly (DP=1).
    6. Gradient allreduce: disabled directly, for the same reason as item 5. Items 5 and 6 can be resolved by a unified group-init function.
    7. Added the autotp_size config.
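
Below is a minimal sketch (not the actual PR code) of the backward handling referred to in item 2, following the usual Megatron-style split: the column-parallel path needs an all-reduce of the input gradient in backward, while the row-parallel path all-reduces its output in forward and is an identity in backward. The class names and the `tp_group` argument are illustrative assumptions.

```python
import torch
import torch.distributed as dist


class _CopyToTensorParallelRegion(torch.autograd.Function):
    """Identity in forward; all-reduce the gradient across TP ranks in backward
    (what a column-parallel linear applies to its input)."""

    @staticmethod
    def forward(ctx, input_, tp_group):
        ctx.tp_group = tp_group
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.clone()
        dist.all_reduce(grad, group=ctx.tp_group)
        return grad, None


class _ReduceFromTensorParallelRegion(torch.autograd.Function):
    """All-reduce the activation across TP ranks in forward; identity in backward
    (what a row-parallel linear applies to its output)."""

    @staticmethod
    def forward(ctx, input_, tp_group):
        output = input_.clone()
        dist.all_reduce(output, group=tp_group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None
```

Roughly, the column-parallel replacement would apply the first function to its input and the row-parallel LinearAllreduce the second to its output, so that gradients stay consistent across TP ranks.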

In this basic version I ran two simple tests. Under the same gbs (global batch size) and gas (gradient accumulation steps), it reaches about 70% of ZeRO-3's performance, but the maximum gbs it supports is lower than ZeRO-3's (beyond that threshold ZeRO-3 performs better while TP runs OOM; this may be due to the dataloader or to missing Megatron optimizations, I didn't analyze it further).

The benefit of this approach is that it decouples TP from the Megatron bindings, letting users train directly with transformers + DeepSpeed using TP plus other features; it can also be applied to other simple models (through module replacement). Additionally, because the autoTP inference code already exists and the ZeRO backend is already integrated with transformers, not much extra logic is needed.

  • For the most basic usage, some issues still need to be handled:
  1. transformers seems to have no notion of real TP (only single-device simulation), which causes minor problems with dataset counting: if two ranks each load 4 identical samples, it counts them as 8. This affects the displayed counters and some optimizer-schedule parameters (equivalent to speeding up the lr decay). The count of trained num_samples needs to be corrected, as sketched below.
  2. Checkpoint load and save: if checkpoints from autoTP training only need to feed autoTP inference, this should be much easier; otherwise some reverse-shard operations may be needed.
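
As a rough illustration of the counting fix in item 1 (a hypothetical helper, not transformers code): since all TP ranks load identical data, the sample count the Trainer accumulates should be divided by the TP world size.

```python
# Hypothetical helper illustrating the num_samples correction; `tp_world_size`
# is assumed to come from whatever TP group API this PR ends up with.
def corrected_num_samples(raw_num_samples: int, tp_world_size: int) -> int:
    # e.g. 2 TP ranks each loading 4 identical samples is 4 real samples, not 8
    return raw_num_samples // tp_world_size
```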

For broader use:
The most basic step is compatibility with ZeRO DP; it could also be made compatible with more features (reusing DeepSpeed's existing logic for Megatron's TP), along with some performance and memory optimizations.

@delock (Collaborator) commented Apr 23, 2024

@inkcherry Is there a link to the demo code? I'm interested in the potential use case of this feature proposal.

@delock (Collaborator) commented Apr 23, 2024

This PR should be addressing this discussion: #4930

@inkcherry (Contributor, Author) commented Apr 23, 2024

@inkcherry Is there a link to the demo code? I'm interested in the potential use case of this feature proposal.
Hi @delock,

FYI: https://github.com/inkcherry/stanford_alpaca/tree/tp_demo
See the latest commit message.
Due to limited bandwidth, it's hard for me to keep continuous focus on this. If possible, I'd really appreciate an experienced engineer like you helping to complete or enhance it.

@tjruwase (Contributor) commented

@inkcherry and @delock, please let us know any way we can help. Thanks!

@skyshine102 commented

It would be super helpful if autoTP training could work together with Domino (#6733).

@tjruwase (Contributor) commented

@delock and @inkcherry, is this still active work?

@GuanhuaWang (Member) commented

Hi @skyshine102, @delock, @inkcherry,

I am leading the Domino project and would like to collaborate, if possible, on this effort of autoTP training and decoupling TP from the Megatron bindings.

@delock (Collaborator) commented Nov 12, 2024

Hi @GuanhuaWang, does Domino refer to this paper? https://arxiv.org/html/2409.15241v1 Thanks!

@skyshine102 commented

@delock Yes, @GuanhuaWang is the first author.
I would suggest refactoring parallel_states.py first. The DeepSpeed team may need to discuss:

  1. Whether to maintain DeepSpeed's own parallel_states version (Megatron-DeepSpeed) or to make it more consistent with NVIDIA's megatron-core API.
  • It would be great to reserve TP/PP/EP process-group APIs, even if they are not used by the ZeRO algorithm.
  2. Whether to support the torch DeviceMesh API (a minimal sketch follows below).
  • This will be the main entry point for most torch & huggingface transformers users.
  • Currently the DeviceMesh API is supported for sequence parallel, but not for PP/TP/EP etc. A well-designed process-group management API can reduce future contributors' development cost at the prototyping stage.
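
For reference, a minimal sketch of the DeviceMesh entry point mentioned in point 2. This is illustrative only; the mesh shape and dimension names are assumptions, and it requires torch >= 2.2 with the distributed environment already initialized.

```python
from torch.distributed.device_mesh import init_device_mesh

# e.g. 8 GPUs arranged as 2 data-parallel replicas x 4 tensor-parallel ranks
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

dp_group = mesh_2d.get_group("dp")  # ProcessGroup for data-parallel collectives
tp_group = mesh_2d.get_group("tp")  # ProcessGroup for tensor-parallel collectives
```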

@inkcherry (Contributor, Author) commented

@GuanhuaWang, sure, I can help rebase the code in the near term. I think this PR still needs three things:

  1. As mentioned by @skyshine102, a comm-group API, at least for TP and DP (a minimal sketch follows below).
  2. A dataloader API.
  3. A checkpoint API.
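
A minimal sketch of what a unified TP/DP group-init function (item 1) could look like, using plain torch.distributed. This is not the DeepSpeed API, and the rank layout (consecutive ranks in the same TP group, the usual Megatron convention) is an assumption.

```python
import torch.distributed as dist

_TP_GROUP = None
_DP_GROUP = None


def init_tp_dp_groups(tp_size: int) -> None:
    """Create TP and DP process groups; must be called on every rank."""
    global _TP_GROUP, _DP_GROUP
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dp_size = world_size // tp_size

    # TP groups: [0..tp_size-1], [tp_size..2*tp_size-1], ...
    for i in range(dp_size):
        ranks = list(range(i * tp_size, (i + 1) * tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            _TP_GROUP = group

    # DP groups: ranks occupying the same position inside their TP group
    for j in range(tp_size):
        ranks = list(range(j, world_size, tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            _DP_GROUP = group
```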

@skyshine102 commented

@GuanhuaWang, sure, I can help rebase the code in the near term. I think this PR still needs three things:

  1. As mentioned by @skyshine102, a comm-group API, at least for TP and DP.
  2. A dataloader API.
  3. A checkpoint API.

@inkcherry, I just noticed that you copied parallel_state from megatron-core into your draft, but the DeepSpeed engine already has one (here). You may want to modify that file instead of copying. Currently the DeepSpeed engine supports DP/PP/EP/SP, and now TP; I suppose all of these process groups will be needed.

  • Dataloader API: not sure whether the DeepSpeed team wants to maintain it ("it's already implemented in the Megatron-DeepSpeed repo"). Maybe a quick demo snippet in example.py would be enough.
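
Along those lines, a quick demo-style snippet of what such an example could show: the first rank in each TP group loads the real batch and broadcasts it to its TP peers. The `tp_group` and `tp_src_rank` arguments are assumed to come from whichever group API is settled on, and `tp_src_rank` must be the global rank of that group's source rank.

```python
import torch
import torch.distributed as dist


def broadcast_batch(batch: dict, tp_group, tp_src_rank: int) -> dict:
    """Broadcast every tensor in the batch from the TP source rank to its TP peers."""
    for value in batch.values():
        if torch.is_tensor(value):
            dist.broadcast(value, src=tp_src_rank, group=tp_group)
    return batch
```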

…-precision version before the rebase, but the grad norm differs (display issue)
@delock (Collaborator) commented Nov 27, 2024

Hi @inkcherry, we need to make sure this PR does not impact autoTP inference performance and compatibility. When your PR is stable, check with Guobing and @rogerxfeng8 for internal testing.
