support bf16_optimizer moe expert parallel training and moe EP grad_scale/grad_norm fix #5259
Conversation
inkcherry commented Mar 12, 2024 (edited)
- bf16 MoE EP requires different partitions, which affects the dp gradient allreduce, the ZeRO-1 params allgather, and the gradient-norm allreduce. Currently, the bf16_optimizer does not partition these groups correctly. Fix this and support bf16 training (see the sketch after this list).
- Fix the calculation of the MoE EP grad scale and grad norm for bf16 & fp16.
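To make the first bullet concrete, here is a minimal sketch of the grouping it describes. This is not the PR's actual code; the dict keys ("moe", "flat") and the pre-built dp_group/ep_dp_group handles are assumptions for illustration. The point is that ZeRO-1 style partitioning (and the matching allgather) must use the group a parameter group is actually replicated in.

```python
import torch
import torch.distributed as dist


def partition_for_rank(flat_params: torch.Tensor, group) -> torch.Tensor:
    """Return this rank's shard of a flat parameter tensor within `group`.

    Assumes the tensor length is divisible by the group's world size.
    """
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    return flat_params.chunk(world_size)[rank]


def partition_param_groups(param_groups, dp_group, ep_dp_group):
    """Partition each flat parameter group over the group it is replicated in."""
    shards = []
    for g in param_groups:
        # An expert (MoE) group is replicated only inside its expert-data-parallel
        # group, so partitioning it over the full dp_group would shard it incorrectly.
        group = ep_dp_group if g.get("moe", False) else dp_group
        shards.append(partition_for_rank(g["flat"], group))
    return shards
```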
Hi @inkcherry, thank you for the contribution!
Hi @tohtana. However, it seems that the current fp16 optimizer implementation computes final_grad_norm as follows: Is this what you referred to as the issue? I am uncertain whether this is actually an issue or whether the current design choice is based on experimental results or other considerations, so I have kept the same implementation as the fp16 optimizer for now. What are your thoughts?
Another change is that, because of EP, the params_group of the MoE layers needs to do its partitioning and backward dp gradient allreduce within ep_dp_group: the parameters of a MoE layer are replicated only within the ep_dp_group.
Hi @inkcherry, I think that you are correct and the final grad norm should be calculated as you wrote. Apparently, I was working on the same issue in parallel with you, and I have written a method to calculate the global norm. This is only a suggestion.
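The referenced method itself is not shown above. Purely as a hedged illustration of one way a global grad norm can be combined across dense and expert parameters (the group handle, parameter split, and device argument are assumptions, not the method mentioned in the comment), the squared expert norms can be reduced over the expert-parallel group before taking the square root:

```python
import math

import torch
import torch.distributed as dist


def global_grad_norm(dense_params, expert_params, ep_group, device):
    # Squared local L2 norms; dense grads are already identical across dp ranks
    # after the gradient allreduce, so they need no further reduction here.
    dense_sq = sum(float(p.grad.float().norm(2)) ** 2 for p in dense_params if p.grad is not None)
    expert_sq = sum(float(p.grad.float().norm(2)) ** 2 for p in expert_params if p.grad is not None)

    # Each rank only sees its own experts, so their squared norms have to be
    # summed across the expert-parallel group.
    buf = torch.tensor([expert_sq], dtype=torch.float32, device=device)
    dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=ep_group)

    return math.sqrt(dense_sq + buf.item())
```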
@inkcherry Thank you for the detailed explanation! I agree this is the right fix; the behaviors of BF16/FP16 should match. Honestly, I'm not sure whether we should use a different scaling for experts or not. With data parallel, we usually scale gradients according to the ratio of the number of samples, so it seems reasonable to scale gradients (not only the gradient norm) according to the number of tokens each expert computes. I would appreciate it if you could share any thoughts.
Thank you very much for your suggestions, @mosheisland @tohtana. (See DeepSpeed deepspeed/runtime/engine.py, line 2338 at b112c99.)
Excluding the discussion about EP and focusing solely on the scaling of gradients in the context of MoE:
I agree with @inkcherry, and I also think that MoE parameter grads should be summed using ep_dp_group and scaled using the dp group. However, currently configuration 2 parameter gradients are exactly 2x those of configuration 1. Actually, I have already implemented this change locally as part of my own changes.
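A small sketch of the reduction being described here, with assumed group handles (this is not the actual DeepSpeed code): expert gradients are summed only inside ep_dp_group but normalized by the full data-parallel world size, so the effective average matches the dense parameters regardless of the EP degree.

```python
import torch.distributed as dist


def reduce_expert_grads(expert_params, ep_dp_group, dp_world_size):
    for p in expert_params:
        if p.grad is None:
            continue
        # Sum over the ranks that actually hold a replica of this expert...
        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=ep_dp_group)
        # ...but divide by the global data-parallel size, not the ep_dp group size.
        p.grad.div_(dp_world_size)
```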
Thank you @inkcherry @mosheisland for sharing your insights. They seem reasonable to me.
The following three configurations were tested for convergence testing. Hi @mosheisland, I added the modifications discussed above. Do you have any suggestions or comments? I would appreciate your thoughts. If you're okay with these modifications, let's ping @tohtana to take a look.
deepspeed/runtime/engine.py
Outdated
values.mul_(self.gradient_predivide_factor() /
            (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size)))
if dp_world_size is None:
    dp_world_size = dist.get_world_size(group=dp_group)
This needs to be moved to before if self.postscale_gradients():
fixed
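For context, a hedged sketch of the control flow the review comment asks for (simplified; the actual engine method has more branches and names may differ): the dp_world_size default has to be resolved before the prescale/postscale branch so that both paths divide by the intended group size.

```python
import torch.distributed as dist


def allreduce_bucket_sketch(self, values, dp_group, dp_world_size=None):
    # Resolve the default up front, not only on one of the two paths below.
    if dp_world_size is None:
        dp_world_size = dist.get_world_size(group=dp_group)

    if self.postscale_gradients():
        # Reduce first, then scale by predivide_factor / dp_world_size,
        # mirroring the quoted line above with the new parameter substituted.
        dist.all_reduce(values, group=dp_group)
        values.mul_(self.gradient_predivide_factor() / dp_world_size)
    else:
        # Prescale path: divide before the reduction.
        values.div_(dp_world_size)
        dist.all_reduce(values, group=dp_group)
    return values
```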
@@ -261,7 +261,7 @@ def step(self, closure=None):
pg = self.deepspeed.mpu.get_data_parallel_group()
else:
pg = groups._get_data_parallel_group()
all_groups_norm = get_norm_with_moe_layers(all_groups_norm, pg)
all_groups_norm = get_norm_with_moe_layers_fast(all_groups_norm, pg)
Why not use the non-fast version also for fp16? I think the "fast" version will give incorrect results when you are using multiple EP sizes in a single model (e.g., the first layer uses EP=4 and the second layer uses EP=2).
Yes, I have also unified the previous fp16 implementation with the current bf16 one.
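To illustrate the multiple-EP-size concern, here is a rough sketch of an "exact" combination (assumed data layout; not the library's get_norm_with_moe_layers): each distinct expert-parallel group contributes its own reduced squared norm, rather than the whole expert norm being corrected with a single factor.

```python
import math

import torch
import torch.distributed as dist


def exact_norm_with_moe(non_expert_sq_norm: float, expert_sq_norms, device):
    """expert_sq_norms: list of (local_squared_norm, process_group) pairs,
    one per distinct expert-parallel group used in the model."""
    total = non_expert_sq_norm
    for sq_norm, group in expert_sq_norms:
        buf = torch.tensor([sq_norm], dtype=torch.float32, device=device)
        # Reduce within this group only; a layer with EP=4 and a layer with
        # EP=2 each contribute their correctly reduced share.
        dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
        total += buf.item()
    return math.sqrt(total)
```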
@inkcherry - added a few comments inside the source code.
In all of the above, I would also compare the grad norm graphs. After the expert grad reduction fix, the grad norm curve should be the same regardless of the EP used (as opposed to before the fix). To make the loss curve comparisons clearer, I would go with a smaller model but more steps (e.g. 1000).
Hi @tohtana, I have added the remaining fixes.
Thank you @inkcherry, I reviewed the code. This is truly amazing work! I also appreciate the thorough verification.
deepspeed/runtime/utils.py
Outdated
@@ -884,6 +895,8 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
if mpu is not None:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
if moe_ep_group is not None:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=moe_ep_group)
total_norm = total_norm_cuda[0].item()
total_norm = total_norm_cuda[0].item() - this line needs to be activated also for the if mpu is not None case; maybe shift it left one tab.
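A sketch of what the comment suggests, reusing the names from the quoted diff above (so get_accelerator, dist, mpu, and moe_ep_group are assumed to be in scope); only the placement of the last line changes so it also runs for the mpu-only case.

```python
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
if mpu is not None:
    dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
if moe_ep_group is not None:
    dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=moe_ep_group)
total_norm = total_norm_cuda[0].item()  # moved out of the if-block so it always runs
```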
yes, fixed.
Hi @tohtana, thanks for the review. I just fixed the CI error. Could you please trigger the CI again?
…cale/grad_norm fix (microsoft#5259) - bf16 moe EP requires different partitions and this will impact dp gradient allreduce, zero1 params allgather, as well as gradient_norm allreduce. Currently, the bf16_optimizer does not correctly partition the group. fix and support bf16 type training. - fix calculation of moe ep grad scale and grad_norm for bf16&fp16 --------- Co-authored-by: Olatunji Ruwase <[email protected]>