
support bf16_optimizer moe expert parallel training and moe EP grad_scale/grad_norm fix #5259

Merged
merged 20 commits into from
Mar 27, 2024

Conversation

inkcherry
Contributor

@inkcherry inkcherry commented Mar 12, 2024

  • bf16 MoE EP requires different partitions, which affects the dp gradient allreduce, the zero1 params allgather, and the gradient_norm allreduce. Currently, the bf16_optimizer does not partition the groups correctly. This PR fixes that and adds support for bf16 training.
  • Fix the calculation of the MoE EP grad scale and grad_norm for bf16 & fp16.

@tjruwase tjruwase requested review from tohtana and removed request for mrwyattii March 12, 2024 12:05
@tohtana
Contributor

tohtana commented Mar 13, 2024

Hi @inkcherry, thank you for the contribution!
Can you elaborate on the issue in the current code? It seems the gradient norm used for clipping isn't accurate, right?

@inkcherry
Contributor Author

inkcherry commented Mar 14, 2024

Hi @tohtana,
I think that final_grad_norm should be calculated as follows: final_grad_norm = non_ep_grad_norm + allreduce_sum(ep_grad_norm, group=ep_group), where ep_group is the group used for the all-to-all communication in the MoE layer, and non_ep_grad_norm and ep_grad_norm are the gradient norms computed separately for the non-expert and expert parameters (after the tensor-parallel allreduce-sum, if tensor parallelism is applied).

However, it seems that the current fp16 optimizer implementation computes final_grad_norm as follows:
final_grad_norm = allreduce_avg(all_grad_norm, group=dp_group)
               = allreduce_avg(non_ep_grad_norm + ep_grad_norm, group=dp_group)
               = non_ep_grad_norm + allreduce_avg(ep_grad_norm, group=dp_group)
In this computation, the second term, allreduce_avg(ep_grad_norm, group=dp_group) = allreduce_avg(ep_grad_norm, group=ep_group), may be smaller than the allreduce_sum over ep_group above by a factor of ep_size.

Is this what you were referring to as the issue? I am uncertain whether this is actually a problem or a design choice based on experimental results or other considerations, so for now I have kept the same implementation as the fp16 optimizer. What are your thoughts?
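
To make the comparison concrete, here is a minimal sketch of the two computations (hypothetical helper names, not DeepSpeed code; it works with squared L2 norms, which is how per-group norms actually combine, and assumes non_ep_sq_norm is already identical on every rank after the TP reduction):

```python
import torch
import torch.distributed as dist

def desired_global_grad_norm(non_ep_sq_norm, ep_sq_norm, ep_group):
    """Sketch: sum the expert term once over the expert-parallel group,
    so every expert gradient is counted exactly once."""
    ep_total = torch.tensor([ep_sq_norm], dtype=torch.float32)
    dist.all_reduce(ep_total, op=dist.ReduceOp.SUM, group=ep_group)
    return (non_ep_sq_norm + ep_total.item()) ** 0.5

def current_fp16_style_grad_norm(non_ep_sq_norm, ep_sq_norm, dp_group):
    """Sketch: average the whole squared norm over dp_group, which shrinks
    the expert term by roughly a factor of ep_size relative to the version above."""
    total = torch.tensor([non_ep_sq_norm + ep_sq_norm], dtype=torch.float32)
    dist.all_reduce(total, op=dist.ReduceOp.SUM, group=dp_group)
    total.div_(dist.get_world_size(group=dp_group))
    return total.item() ** 0.5
```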

@inkcherry
Contributor Author

inkcherry commented Mar 14, 2024

Another change: because of EP, the MoE-layer param groups need their partitioning and backward dp gradient allreduce to be done over the ep_dp_group, since MoE-layer parameters are only replicated within the ep_dp_group.
However, the current code treats everything as belonging to the regular dp_group by default.
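
Concretely, the group selection being described looks roughly like this (a sketch with a hypothetical helper name, not the actual bf16_optimizer code; it relies on the 'moe'/'name' keys that DeepSpeed's MoE utilities attach to expert param groups):

```python
from deepspeed.utils import groups

def replication_group_for(param_group, dp_process_group):
    """Sketch: return the process group a param group's gradients should be
    allreduced over and its optimizer states partitioned across."""
    if param_group.get('moe', False):
        # MoE params are replicated only inside their expert-data-parallel group
        return groups._get_expert_data_parallel_group(param_group['name'])
    # non-expert params are replicated across the regular data-parallel group
    return dp_process_group
```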

@mosheisland
Contributor

> @inkcherry: I think that the final_grad_norm should be calculated as follows: final_grad_norm = non_ep_grad_norm + allreduce_sum(ep_grad_norm, group=ep_group) [...] Is this what you referred to as the issue? [...] What are your thoughts?

Hi @inkcherry, I think that you are correct and the final grad norm should be calculated as you wrote:
final_grad_norm = non_ep_grad_norm + allreduce_sum(ep_grad_norm, group=ep_group).

It seems I was working on the same issue in parallel with you, and I have written a method to calculate the global norm.

This is only a suggestion.

```python
# assumes the surrounding module's imports: torch, inf, get_accelerator,
# groups, and get_global_norm_of_tensors
def _get_norm_with_moe_layers(self, non_expert_norm, expert_tensors, norm_type=2):
    """Compute the global norm with MoE experts.

    Inputs:
        non_expert_norm (float): the calculated norm of the non-expert params
        expert_tensors (Dict[str, List[Tensor]]): expert group name -> list of grad tensors
        norm_type (int): the norm to use

    Returns:
        -1 if the norm is (-/+) inf, otherwise the global norm (float)
    """

    def to_tensor(v):
        return get_accelerator().FloatTensor([float(v)]).detach()

    group_norms = [non_expert_norm]
    for exp_name, tensors in expert_tensors.items():
        group_norm = get_global_norm_of_tensors(input_tensors=tensors,
                                                mpu=self.mpu,
                                                norm_type=norm_type,
                                                use_graph=self.graph_harvesting,
                                                moe_ep_group=groups._get_expert_parallel_group(exp_name))
        group_norms.append(group_norm)

    # check if all norms are valid
    group_norms = torch.stack([to_tensor(norm) for norm in group_norms])
    if group_norms.eq(-1).any():
        return -1

    # combine norms
    if norm_type == inf:
        total_norm = group_norms.max().item()
    else:
        total_norm = group_norms.pow(norm_type).sum()
        total_norm = total_norm.item()**(1. / norm_type)
        if total_norm == float('inf') or total_norm == -float('inf'):
            total_norm = -1

    return total_norm
```

@tohtana
Contributor

tohtana commented Mar 15, 2024

@inkcherry Thank you for the detailed explanation! I agree this is the right fix. The behaviors of BF16/FP16 should match.
Can you also check @mosheisland's proposal to make sure you haven't missed anything?

Honestly, I'm not so sure whether we should use different scaling for experts or not. With data parallelism, we usually scale gradients in proportion to the number of samples. It seems reasonable to scale gradients (not only the gradient norm) according to the number of tokens each expert computes. I would appreciate it if you could share any thoughts.

@inkcherry
Contributor Author

inkcherry commented Mar 15, 2024

Thank you very much for your suggestions, @mosheisland @tohtana.
I revisited the gradient section. Indeed, for a single step (using the same amount of data), different ep_size values lead to different gradients, because the gradients over the ep_dp_group are averaged in `allreduce_bucket(self, bucket, dp_group)`. (It seems that summing over the ep_dp_group here but dividing by dp_world_size would be more appropriate.) Consequently, for the same amount of data, the gradient is directly proportional to ep_size. To ensure consistency, convergence should not be affected by parallelism parameters like ep_size, so dividing the expert grad_norm by ep_size at the end seems correct, i.e. scaling down by ep_size is reasonable. I appreciate the suggestion from @mosheisland and hope to refine the code based on it later.
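
Roughly, the reduction being proposed looks like this (a sketch with a hypothetical helper name; the actual change threads a dp_world_size argument through the engine's allreduce path, as in the review snippet further below):

```python
import torch.distributed as dist

def allreduce_expert_grad(grad, ep_dp_group, dp_world_size):
    """Sketch: sum the expert gradient over its expert-data-parallel group,
    but divide by the full data-parallel world size, so the averaged gradient
    does not grow with ep_size."""
    dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=ep_dp_group)
    grad.div_(dp_world_size)
    return grad
```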

Setting aside the EP discussion and focusing solely on how gradients are scaled in the MoE context:

  • For scenarios with uneven token distribution among experts, scaling each expert's gradient based on the volume of data it receives seems reasonable (a rough sketch follows after this list).
  • If load balancing is assumed, scaling by token count, possibly by dividing by num_experts, seems logical, given that the MoE layer receives fewer tokens than the non-MoE layers. However, these are personal ideas that require specific experiments for validation. :)
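
For the first bullet, a purely illustrative sketch of what token-proportional scaling could look like (hypothetical helper and data layout; not something this PR implements):

```python
from typing import Dict, List
import torch

def scale_expert_grads_by_token_share(expert_grads: Dict[int, List[torch.Tensor]],
                                      expert_token_counts: Dict[int, int],
                                      total_tokens: int) -> None:
    """Sketch: scale each expert's gradients by the fraction of tokens it
    actually processed, instead of a uniform factor."""
    for expert_id, grads in expert_grads.items():
        weight = expert_token_counts[expert_id] / float(total_tokens)
        for g in grads:
            g.mul_(weight)  # in-place scale of this expert's gradient
```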

@mosheisland
Contributor

I agree with @inkcherry, and I also think that MoE parameter grads should be summed over the ep_dp_group and scaled using the dp group size.
One strong motivation is that changing the parallel configuration should not affect the actual training (up to numerical differences).
For example, assuming the same global batch size and micro batch size, and the following two configurations:
Configuration 1: 1x GPU with DP=1 EP=1 num_experts=4
Configuration 2: 2x GPU with DP=1 EP=2 num_experts=4 (==> EP_DP=1)
Then both configurations should produce the "same" MoE parameter gradients.

However, currently configuration 2's parameter gradients are exactly 2x configuration 1's parameter gradients.
To fix it, as suggested by @inkcherry, we need to divide the gradients by the DP world size (instead of the EP_DP world size).

Actually, I have already implemented this change locally as part of my changes.
It ensures the same MoE parameter gradients regardless of the parallel topology used for scaling out (I tested this across multiple scaling configurations).

@tohtana
Contributor

tohtana commented Mar 17, 2024

Thank you @inkcherry @mosheisland for sharing your insights. They seem reasonable to me.
Please let me know after you finish adding the remaining fix. I will review again and approve this PR soon.

@inkcherry
Contributor Author

inkcherry commented Mar 19, 2024

The following three configurations were tested for convergence.
8 GPUs, 1B model with 4 experts, global_bs=256, 300 iters.

  • fp16, expert-parallel-size 1 (no EP)
  • bf16, expert-parallel-size 1
  • bf16, expert-parallel-size 4

[image: loss curves]

Hi @mosheisland, I added the modifications discussed above. Do you have any suggestions or comments? I would appreciate your thoughts :) If you're okay with these modifications, let's ping @tohtana to take a look.

```python
values.mul_(self.gradient_predivide_factor() /
            (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size)))
if dp_world_size is None:
    dp_world_size = dist.get_world_size(group=dp_group)
```
Contributor

This needs to be moved to before `if self.postscale_gradients():`

Contributor Author

fixed

```diff
@@ -261,7 +261,7 @@ def step(self, closure=None):
                 pg = self.deepspeed.mpu.get_data_parallel_group()
             else:
                 pg = groups._get_data_parallel_group()
-            all_groups_norm = get_norm_with_moe_layers(all_groups_norm, pg)
+            all_groups_norm = get_norm_with_moe_layers_fast(all_groups_norm, pg)
```
Contributor

Why not use the non-fast version also for fp16? I think the "fast" version will give incorrect results when you use multiple EP sizes in a single model (e.g. the first layer uses EP=4 and the second layer uses EP=2).

Contributor Author

Yes, I have also unified the old fp16 path with the current bf16 path.

@mosheisland
Contributor

@inkcherry - added a few comments inside the source code.
In addition, to verify that the change works correctly without regressions, I would test the following:

  1. Compare loss curve of bf16 without moe before and after this change (to make sure no regressions introduced)
  2. Compare loss curve of fp16 without moe before and after this change (to make sure no regressions introduced)
  3. Compare loss curve of bf16 with 4 experts, 2 devices: first config with EP=1 second config with EP=2
  4. Compare loss curve of fp16 with 4 experts, 2 devices: first config with EP=1 second config with EP=2

In all of the above, I would also compare the grad norm graphs. After the expert grad reduction fix, the grad norm curve should be the same regardless of the EP size used (as opposed to before the fix).

To make the loss curve comparisons clearer, I would go with a smaller model but more steps (e.g. 1000).

@inkcherry inkcherry requested a review from awan-10 as a code owner March 20, 2024 10:05
@inkcherry
Contributor Author

inkcherry commented Mar 21, 2024

1.3B model, 4 experts, 4 GPUs with ep_size=1/2, global_bs=16, mp=1

Before this change, FP16: loss & grad_norm
(The images may appear scaled down on the page; click to view them at full size.)
[image: loss curve]
[image: grad_norm curve]

The gradient norm shows a difference of approximately 0.26 from the first iteration. Some discrepancies in the loss can be observed after roughly 300 steps.

After this change:

  • FP16:

[image: loss curve]
[image: grad_norm curve]

  • BF16:

[image: loss curve]
[image: grad_norm curve]

The grad_norm remains consistent at the beginning (although there are some gaps after 200 steps, which may be accumulated low-precision error). Moreover, the loss is more consistent: at 1000 steps, no discrepancies like those in the old fp16 graph have appeared.

@inkcherry inkcherry changed the title support bf16_optimizer moe expert parallel training support bf16_optimizer moe expert parallel training and moe grad_scale/grad_norm fix Mar 21, 2024
@inkcherry
Contributor Author

Hi @tohtana, I have added the remaining fixes.
As suggested by @mosheisland, I have run local tests for items 1 and 2, and added more detailed tests for items 3 and 4.
Could you please take a look? Thanks.

@inkcherry inkcherry changed the title support bf16_optimizer moe expert parallel training and moe grad_scale/grad_norm fix support bf16_optimizer moe expert parallel training and moe EP grad_scale/grad_norm fix Mar 21, 2024
@tohtana
Contributor

tohtana commented Mar 22, 2024

Thank you @inkcherry, I reviewed the code. This is truly amazing work! I also appreciate the thorough verification.
Many people are now starting to use MoE, and I believe many users will benefit from this work.

```diff
@@ -884,6 +895,8 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F
         total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
         if mpu is not None:
             dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
+        if moe_ep_group is not None:
+            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=moe_ep_group)
             total_norm = total_norm_cuda[0].item()
```
Contributor

@mosheisland mosheisland Mar 25, 2024

`total_norm = total_norm_cuda[0].item()`:
This line needs to be executed also for the `if mpu is not None` case;
maybe shift it left by one tab.

Contributor Author

yes, fixed.

@inkcherry inkcherry requested a review from loadams as a code owner March 27, 2024 03:31
@inkcherry
Contributor Author

Hi @tohtana, thanks for the review. I just fixed the CI error. Could you please trigger the CI again?

@tjruwase tjruwase added this pull request to the merge queue Mar 27, 2024
Merged via the queue into microsoft:master with commit e5dd550 Mar 27, 2024
12 checks passed
rraminen pushed a commit to ROCm/DeepSpeed that referenced this pull request May 9, 2024
…cale/grad_norm fix (microsoft#5259)

dbyoung18 pushed a commit to dbyoung18/DeepSpeed that referenced this pull request Jun 11, 2024
…cale/grad_norm fix (microsoft#5259)
