Fix expert grad scaling problem with ZeRO optimizer #6546

Open · wants to merge 5 commits into base: master
9 changes: 2 additions & 7 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -1070,14 +1070,10 @@ def average_tensor(self, tensor):
for i, param, param_id in self.params_in_ipg_bucket:

process_group = self.dp_process_group
grad_reduc = self.get_gradient_for_reduction(param)
#Averages gradients at parameter level if ipg has a moe param
#Otherwise averaging is done at the entire buffer level at the end of the loop
# MoE param have different groups

if self.ipg_bucket_has_moe_params:
process_group = self.expert_dp_process_group[param.group_name] if is_moe_param(
param) else self.dp_process_group
grad_reduc.data.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size))

partition_ids = self.param_to_partition_ids[i][param_id]
assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids
@@ -1116,8 +1112,7 @@ def average_tensor(self, tensor):
curr_size += numel
prev_id, prev_process_group = partition_id, process_group

if not self.ipg_bucket_has_moe_params:
tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))
Contributor:
If only the expert gradient scaling is incorrect, couldn't we simply change the 'grad_reduc' division from edp_world_size to dp_world_size? Why do we need to divide 'tensor' instead? It may contain more data than just this gradient. I'm a bit confused here.

Author:

In my understanding, 'tensor' contains only gradients that are waiting to be all-reduced.

From the code, 'tensor' is either a buffer from 'self.ipg_buffer' or the gradient of 'self.extra_large_param_to_reduce'. So 'tensor' is composed of the data of one or more weight gradients, and the data pointer of 'grad_reduc' points to an address within 'tensor'.
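
For example (a minimal standalone sketch, not the actual DeepSpeed buffers; the names are made up for illustration): a per-parameter gradient view shares storage with the flat buffer, so one in-place `div_` on the buffer scales every gradient inside it.

```python
import torch

# Toy stand-ins: "flat_buffer" plays the role of 'tensor' (an IPG buffer),
# and the two views play the role of 'grad_reduc' for two parameters.
flat_buffer = torch.arange(8, dtype=torch.float32)
grad_param_a = flat_buffer.narrow(0, 0, 3)   # view into the buffer
grad_param_b = flat_buffer.narrow(0, 3, 5)   # view into the buffer

dp_world_size = 4
flat_buffer.div_(dp_world_size)              # buffer-level averaging

# Both per-parameter views are scaled, because they share storage with the buffer.
assert torch.allclose(grad_param_a, torch.tensor([0.0, 0.25, 0.5]))
assert torch.allclose(grad_param_b, torch.tensor([0.75, 1.0, 1.25, 1.5, 1.75]))
```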

According to the comments in the code, the logic of the old version is:

  • average gradients at the parameter level if the IPG bucket has a MoE param, i.e. divide 'grad_reduc';
  • otherwise, average at the entire buffer level at the end of the loop, i.e. divide 'tensor'.

The original author did this because he wanted to divide expert gradients by edp_size and non-expert gradients by dp_size, which forces the averaging to happen at the parameter level whenever there is a MoE param. In this PR we divide all weight gradients by dp_size, so we can do the averaging directly at the entire buffer level.
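
As a toy illustration of the scaling change (assuming a data-parallel world size of 4 and an expert-data-parallel world size of 2; the numbers are made up for the example):

```python
dp_world_size = 4    # full data-parallel group
edp_world_size = 2   # expert data-parallel group

raw_expert_grad = 1.0
old_scaling = raw_expert_grad / edp_world_size  # old code: 0.5
new_scaling = raw_expert_grad / dp_world_size   # this PR:  0.25, same as non-expert params

assert old_scaling != new_scaling
```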

In addition, I should probably also delete those old comments.

Contributor:

Thank you for the clarification. I agree with deleting those old comments.

tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))

buckets = {}
for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets):
112 changes: 112 additions & 0 deletions tests/unit/moe/test_moe.py
@@ -7,6 +7,7 @@
import deepspeed
import pytest
import gc
import random
from unit.common import DistributedTest
from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader
import deepspeed.comm as dist
@@ -238,3 +239,114 @@ def check_equal(logits, cap, sparse_truth, res):
[2, 1, 1], [2, 2, 1], [2, 3, 1], [3, 0, 0]])
position_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='position')[2]
check_equal(logits2, 2, position_sec_sparse, position_dispatch_res)


class TestExpertWeightGradWithZero(DistributedTest):
world_size = 2

@pytest.mark.parametrize("zero_stage", [0, 1, 2])
def test(self, zero_stage):

if not required_torch_version(min_version=1.8):
pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

def seed_everything(seed=11):
random.seed(seed)
torch.manual_seed(seed)
get_accelerator().manual_seed(seed)
get_accelerator().manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def get_state_dict_ep2(state_dict):
"""
convert state_dict from EP=1 to EP=2
"""
rank = int(deepspeed.comm.get_rank())
ep_state_dict = dict()
dst_sub_key = f"deepspeed_moe.experts.deepspeed_experts.0"
src_sub_key = f"deepspeed_moe.experts.deepspeed_experts.{rank}"
for moe_layer in ["moe_1", "moe_2"]:
for mlp_in_moe in [0, 1]:
dst_key = f"{moe_layer}.{dst_sub_key}.{mlp_in_moe}"
src_key = f"{moe_layer}.{src_sub_key}.{mlp_in_moe}"
ep_state_dict[f"{dst_key}.weight"] = state_dict[f"{src_key}.weight"].detach().clone()
ep_state_dict[f"{dst_key}.bias"] = state_dict[f"{src_key}.bias"].detach().clone()

for key in state_dict.keys():
if "deepspeed_moe.experts.deepspeed_experts" not in key:
ep_state_dict[key] = state_dict[key].detach().clone()
return ep_state_dict

def get_models(hidden_dim):
model_ep1 = SimpleMoEModel(hidden_dim=hidden_dim, num_experts=2, ep_size=1, use_rts=False)
model_ep2 = SimpleMoEModel(hidden_dim=hidden_dim, num_experts=2, ep_size=2, use_rts=False)

state_dict_ep1 = model_ep1.state_dict()
state_dict_ep2 = get_state_dict_ep2(state_dict_ep1)
model_ep2.load_state_dict(state_dict_ep2)

model_ep1, _, _, _ = deepspeed.initialize(config=config_dict, model=model_ep1)
model_ep2, _, _, _ = deepspeed.initialize(config=config_dict, model=model_ep2)

return model_ep1, model_ep2

def extract_expert_grad(model, expert_id):

def _get_weight_bias(experts):
return ([deepspeed.utils.safe_get_full_grad(expert[0].weight)
for expert in experts][expert_id].detach().clone(),
[deepspeed.utils.safe_get_full_grad(expert[0].bias)
for expert in experts][expert_id].detach().clone(),
[deepspeed.utils.safe_get_full_grad(expert[1].weight)
for expert in experts][expert_id].detach().clone(),
[deepspeed.utils.safe_get_full_grad(expert[1].bias)
for expert in experts][expert_id].detach().clone())

return (*_get_weight_bias(model.moe_1.deepspeed_moe.experts.deepspeed_experts),
*_get_weight_bias(model.moe_2.deepspeed_moe.experts.deepspeed_experts))

seed_everything()

config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.1,
}
},
"zero_optimization": {
"stage": zero_stage
}
}

hidden_dim = 4
total_samples = 2
rank = deepspeed.comm.get_rank()
model_ep1, model_ep2 = get_models(hidden_dim)

data_loader = sequence_dataloader(model=model_ep1,
total_samples=total_samples,
hidden_dim=hidden_dim,
device=model_ep1.device,
dtype=torch.float32)
expert_weight_grad_ep1 = []
expert_weight_grad_ep2 = []
for batch in data_loader:
loss_ep1 = model_ep1(batch[0], batch[1])
loss_ep2 = model_ep2(batch[0], batch[1])

model_ep1.backward(loss_ep1)
model_ep2.backward(loss_ep2)

expert_weight_grad_ep1.extend(extract_expert_grad(model_ep1, rank))
expert_weight_grad_ep2.extend(extract_expert_grad(model_ep2, 0))

model_ep1.step()
model_ep2.step()

assert len(expert_weight_grad_ep1) == len(expert_weight_grad_ep2)
for grad_from_ep1, grad_from_ep2 in zip(expert_weight_grad_ep1, expert_weight_grad_ep2):
assert torch.allclose(grad_from_ep1, grad_from_ep2, atol=0, rtol=1e-4)
8 changes: 5 additions & 3 deletions tests/unit/simple_model.py
@@ -79,7 +79,7 @@ def forward(self, x, y, **kwargs):

class SimpleMoEModel(torch.nn.Module):

def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False):
def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False, use_rts=True):
super(SimpleMoEModel, self).__init__()
self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim)
expert = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.Linear(hidden_dim, hidden_dim))
@@ -89,7 +89,8 @@ def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False):
ep_size=ep_size,
use_residual=use_residual,
num_experts=num_experts,
k=1)
k=1,
use_rts=use_rts)
# interleaving MoE modules with dense to create an opportunity
# for gradients to be merged in ZeRO stage 2 average_tensor reduce bucket
self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
@@ -98,7 +99,8 @@ def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False):
ep_size=ep_size,
use_residual=use_residual,
num_experts=num_experts,
k=1)
k=1,
use_rts=use_rts)
self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim)
self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
