Skip to content

Commit

Permalink
reduce all-to-all communication volume when both expert and non-exper…
Browse files Browse the repository at this point in the history
…t are tensor-parallel (#5626)

Example: E + M + D parallel
world_size = 8
model_degree = 2
expert_degree = 4 
mp_group = [0, 1], [2,3], [4,5],[6,7]
expert_parallel_group = [0,2,4,6], [1,3,5,7]

The original execution method was that before executing Expert, there
was no drop operation, and two EPs did all-to-all separately. In the
end, they both obtained complete data, but 0 and 1 obtained exactly the
same data. Similarly, 2, 3, and so on all obtained the same data.
Therefore, we can drop the data before executing all-to-all, and then
execute allgather after all-to-all to obtain the complete data.

After executing Expert, the data on 0 and 1 is exactly the same, so we
can drop it and then execute all-to-all , and then execute allgather to
obtain the complete data.


1. non-expert use TP, expert not use TP: drop -> alltoall -> exe MOE ->
alltoall -> allgather
2. both non-expert and expert all use TP: 
- the original execution order: alltoall -> exe MOE-> allreduce ->
alltoall
- optimized execution order: drop -> alltoall -> allgather -> exe MOE ->
drop ->alltoall -> allgather

Signed-off-by: --local <[email protected]>
Co-authored-by: --local <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
4 people authored Jul 22, 2024
1 parent 213e2d9 commit f5d6c63
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 21 deletions.
26 changes: 18 additions & 8 deletions deepspeed/moe/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,23 @@ def _gather_tokens(input_, dim=0):
mpu = deepspeed.utils.groups.mpu

input_ = input_.contiguous()
# Size and dimension.
rank = bwc_tensor_model_parallel_rank(mpu)

tensor_list = [torch.empty_like(input_) for _ in range(bwc_tensor_model_parallel_world_size(mpu))]
tensor_list[rank] = input_
deepspeed.comm.all_gather(tensor_list, input_, group=bwc_tensor_model_parallel_group(mpu))
world_size = bwc_tensor_model_parallel_world_size(mpu)
if world_size == 1:
return input_

# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=dim).contiguous()
gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device)
deepspeed.comm.all_gather_into_tensor(gather_buffer, input_, group=bwc_tensor_model_parallel_group(mpu))
if dim == 0:
shape = list(input_.size())
shape[0] = shape[0] * world_size
output = gather_buffer.view(shape)
else:
tensor_list = [
gather_buffer.narrow(0,
input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size)
]
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=dim).contiguous()

return output

Expand All @@ -50,6 +58,8 @@ def _drop_tokens(input_, dim=0):
mpu = deepspeed.utils.groups.mpu

total_chunks = bwc_tensor_model_parallel_world_size(mpu)
if total_chunks == 1:
return input_
this_chunk = bwc_tensor_model_parallel_rank(mpu)
assert input_.shape[
dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
Expand Down
36 changes: 25 additions & 11 deletions deepspeed/moe/sharded_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,13 +533,18 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
if self.wall_clock_breakdown:
self.timers(FIRST_ALLTOALL_TIMER).start()

if groups._get_expert_model_parallel_world_size() == 1:
# If the non-expert is tensor-parallel, it will create
tensor_model_world_size = bwc_tensor_model_parallel_world_size(groups.mpu)
if tensor_model_world_size > 1:
# If the non-expert is tensor-parallel,
# Whether expert is tensor-parallel or not , it will create
# duplicate tokens on the tensor-parallel ranks.
# Since our experts are not tensor-parallel, these duplicates
# need to be dropped to ensure correctness.
# this also doubles up as a communication optimization as we are
# reducing the all-to-all communication volume.
# drop duplicate tokens also doubles up as a communication
# optimization as we are reducing the all-to-all communication volume.
# 1: for not tensor-parallel expert,drop duplicate tokens to ensure
# both correctness and reduce all-to-all communication.
# 2: for tensor-parallel expert,drop duplicate tokens to reduce all-to-all
# communication volume,before expert execution, it is necessary to perform
# an allgather to ensure correctness,
dispatched_input = drop_tokens(dispatched_input, dim=1)

dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
Expand All @@ -548,10 +553,22 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
self.timers(FIRST_ALLTOALL_TIMER).stop()
self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)

if tensor_model_world_size > 1 and groups._get_expert_model_parallel_world_size() > 1:
# if both expert and non-expert are tensor-parallel
# the dropped duplicate tokens need to be gathered on each
# tensor parallel rank again to ensure correctness
dispatched_input = gather_tokens(dispatched_input, dim=1)

# Re-shape after all-to-all: ecm -> gecm
dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)

expert_output = self.experts(dispatched_input)
# Re-shape before drop_tokens: gecm -> ecm
expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
if tensor_model_world_size > 1 and groups._get_expert_model_parallel_world_size() > 1:
# if both expert and non-expert are tensor-parallel
# drop duplicate tokens to ensure both correctness
# and reduce all-to-all communication.
expert_output = drop_tokens(expert_output, dim=1)

if self.wall_clock_breakdown:
self.timers(SECOND_ALLTOALL_TIMER).start()
Expand All @@ -562,10 +579,7 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
self.timers(SECOND_ALLTOALL_TIMER).stop()
self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)

# Re-shape back: gecm -> ecm
expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)

if groups._get_expert_model_parallel_world_size() == 1:
if tensor_model_world_size > 1:
# the dropped duplicate tokens need to be gathered on each
# tensor parallel rank again for the tensor-parallel
# non-expert of the next layer.
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/ops/transformer/inference/moe_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def forward(self,

if self.expert_mp_group is not None:
world_size = dist.get_world_size(group=self.expert_mp_group)
gather_buffer = torch.zeros(world_size * attention_output.numel(),
gather_buffer = torch.empty(world_size * attention_output.numel(),
dtype=attention_output.dtype,
device=attention_output.device)
dist.all_gather_into_tensor(gather_buffer, attention_output, group=self.expert_mp_group)
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2237,7 +2237,7 @@ def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]:
return grad_dict

def _fp32_state_allgather(self, param, fp32_state_partition):
reduce_buffer = torch.zeros(self.partition_count * fp32_state_partition.numel(),
reduce_buffer = torch.empty(self.partition_count * fp32_state_partition.numel(),
dtype=torch.float32,
device=param.device)
my_rank = dist.get_rank(group=self.dp_process_group)
Expand Down

0 comments on commit f5d6c63

Please sign in to comment.