Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use deepspeed.comm instead of torch.distributed #5225

Merged
merged 2 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions deepspeed/moe/sharded_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,7 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
class _AllToAll(torch.autograd.Function):

@staticmethod
def forward(
ctx: Any,
# TODO: replace with DS process group
group: torch.distributed.ProcessGroup,
input: Tensor) -> Tensor: # type: ignore
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: # type: ignore
ctx.group = group
input = input.contiguous()
output = torch.empty_like(input)
Expand Down
3 changes: 1 addition & 2 deletions deepspeed/runtime/comm/coalesced_collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
import torch
from torch import Tensor
from deepspeed import comm as dist
# NOTE: Use torch.distributed's ProcessGroup class until we have our own.
from torch.distributed import ProcessGroup, all_to_all_single
from deepspeed.comm import ProcessGroup, all_to_all_single
from deepspeed.accelerator import get_accelerator
from deepspeed.utils import instrument_w_nvtx
from deepspeed.ops import op_builder
Expand Down
Loading