diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index d92211b9d220..e6a5292d7e4f 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -95,11 +95,7 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
 class _AllToAll(torch.autograd.Function):
 
     @staticmethod
-    def forward(
-            ctx: Any,
-            # TODO: replace with DS process group
-            group: torch.distributed.ProcessGroup,
-            input: Tensor) -> Tensor:  # type: ignore
+    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
         ctx.group = group
         input = input.contiguous()
         output = torch.empty_like(input)
diff --git a/deepspeed/runtime/comm/coalesced_collectives.py b/deepspeed/runtime/comm/coalesced_collectives.py
index d63d7e985e07..543795126fab 100644
--- a/deepspeed/runtime/comm/coalesced_collectives.py
+++ b/deepspeed/runtime/comm/coalesced_collectives.py
@@ -12,8 +12,7 @@
 import torch
 from torch import Tensor
 from deepspeed import comm as dist
-# NOTE: Use torch.distributed's ProcessGroup class until we have our own.
-from torch.distributed import ProcessGroup, all_to_all_single
+from deepspeed.comm import ProcessGroup, all_to_all_single
 from deepspeed.accelerator import get_accelerator
 from deepspeed.utils import instrument_w_nvtx
 from deepspeed.ops import op_builder
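
For context, a minimal sketch of the complete _AllToAll autograd Function this hunk touches, written against deepspeed.comm as the patch now does. The diff itself only shows the first three lines of forward; the dist.all_to_all_single call and the backward method below are assumptions based on the standard all-to-all forward/backward pairing (the coalesced_collectives hunk confirms deepspeed.comm exposes all_to_all_single), not part of this patch.

from typing import Any, Tuple

import torch
from torch import Tensor
from deepspeed import comm as dist


class _AllToAll(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
        # Stash the group so backward can run the mirrored collective.
        ctx.group = group
        input = input.contiguous()
        output = torch.empty_like(input)
        # Exchange equal-sized chunks of `input` with every rank in `group`
        # (assumed to mirror torch.distributed.all_to_all_single's signature).
        dist.all_to_all_single(output, input, group=group)
        return output

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
        # All-to-all is its own transpose: routing the gradients through the
        # same collective sends each gradient chunk back to its source rank.
        # None corresponds to the non-differentiable `group` argument.
        return (None, _AllToAll.apply(ctx.group, *grad_output))

Typed this way, callers invoke the exchange as output = _AllToAll.apply(group, input), and the change lets the signature reference dist.ProcessGroup directly instead of carrying the old TODO about replacing torch.distributed's class.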