remove wait in all_to_all_single custom op #2646

Closed
wants to merge 1 commit
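The change drops the eager wait: the old `torchrec::all_to_all_single` custom op issued `_c10d_functional.all_to_all_single` and immediately called `wait_tensor` on the result, while the new `AllToAllSingle` autograd.Function returns the output of `torch.distributed._functional_collectives.all_to_all_single`, whose wait is deferred until the tensor is first used. Below is a minimal standalone sketch of that deferred-wait behavior; it is not code from this diff and assumes a recent PyTorch with a 2-process gloo job launched via `torchrun --nproc_per_node=2 sketch.py`.

# Minimal sketch (not from this PR) of the deferred-wait behavior that
# torch.distributed._functional_collectives.all_to_all_single provides.
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

dist.init_process_group(backend="gloo")
rank, world_size = dist.get_rank(), dist.get_world_size()

# Each rank sends one element to every rank (uniform splits).
inp = torch.arange(world_size, dtype=torch.float32) + rank * world_size
splits = [1] * world_size

out = funcol.all_to_all_single(inp, splits, splits, dist.group.WORLD)
# `out` is an AsyncCollectiveTensor: no explicit wait_tensor() is issued here;
# the wait happens implicitly on first use (the arithmetic below), which is
# what lets the forward in this PR skip the eager wait.
print(f"rank {rank}: {out * 1.0}")

dist.destroy_process_group()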
104 changes: 37 additions & 67 deletions torchrec/distributed/comm_ops.py
@@ -443,7 +443,7 @@ def all2all_pooled_sync(
qcomm_ctx = None

with record_function("## alltoall_fwd_single ##"):
sharded_output_embeddings = torch.ops.torchrec.all_to_all_single(
sharded_output_embeddings = AllToAllSingle.apply(
sharded_input_embeddings,
output_split_sizes,
input_split_sizes,
@@ -572,7 +572,7 @@ def variable_batch_all2all_pooled_sync(
torch._check(s0 <= sharded_input_embeddings.size(0))
sharded_output_embeddings.copy_(sharded_input_embeddings[:s0])
else:
sharded_output_embeddings = torch.ops.torchrec.all_to_all_single(
sharded_output_embeddings = AllToAllSingle.apply(
sharded_input_embeddings,
output_split_sizes,
input_split_sizes,
@@ -722,7 +722,7 @@ def all2all_sequence_sync(
qcomm_ctx = None

with record_function("## alltoall_seq_embedding_fwd_single ##"):
sharded_output_embeddings = torch.ops.torchrec.all_to_all_single(
sharded_output_embeddings = AllToAllSingle.apply(
sharded_input_embeddings,
output_splits,
input_splits,
@@ -1004,7 +1004,7 @@ def reduce_scatter_v_sync(
input_splits = rsi.input_splits
output_splits = [rsi.input_splits[rank]] * world_size
# TODO(ivankobzarev): Replace with _functional_collectives.reduce_scatter_v when it is added
a2a_output = torch.ops.torchrec.all_to_all_single(
a2a_output = AllToAllSingle.apply(
input,
output_splits,
input_splits,
@@ -2348,71 +2348,41 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:

if not torch._running_with_deploy(): # noqa C901
# Torch Library op def can not be used in Deploy
class AllToAllSingle(torch.autograd.Function):
@staticmethod
# pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
def forward(
# pyre-fixme[2]: Parameter must be annotated.
ctx,
input: Tensor,
output_split_sizes: List[int],
input_split_sizes: List[int],
group_name: str,
group_size: int,
gradient_division: bool,
) -> Tensor:
ctx.output_split_sizes = input_split_sizes
ctx.input_split_sizes = output_split_sizes
ctx.group_name = group_name
ctx.group_size = group_size
ctx.gradient_division = gradient_division
return torch.distributed._functional_collectives.all_to_all_single(
input, output_split_sizes, input_split_sizes, group_name
)

# torchrec::all_to_all_single
@torch.library.custom_op("torchrec::all_to_all_single", mutates_args=())
def all_to_all_single(
input: Tensor,
output_split_sizes: List[int],
input_split_sizes: List[int],
group_name: str,
group_size: int,
gradient_division: bool,
) -> Tensor:
out = torch.ops._c10d_functional.all_to_all_single(
input, output_split_sizes, input_split_sizes, group_name
)
return torch.ops._c10d_functional.wait_tensor(out)

@torch.library.register_fake("torchrec::all_to_all_single")
def all_to_all_single_fake(
input: Tensor,
output_split_sizes: List[int],
input_split_sizes: List[int],
group_name: str,
group_size: int,
gradient_division: bool,
) -> Tensor:
return torch.ops._c10d_functional.all_to_all_single(
input, output_split_sizes, input_split_sizes, group_name
)

# pyre-ignore
def all_to_all_single_setup_context(ctx, inputs, output) -> None:
(
_,
output_split_sizes,
input_split_sizes,
group_name,
group_size,
gradient_division,
) = inputs
ctx.output_split_sizes = input_split_sizes
ctx.input_split_sizes = output_split_sizes
ctx.group_name = group_name
ctx.group_size = group_size
ctx.gradient_division = gradient_division

# pyre-ignore
def all_to_all_single_backward(ctx, grad):
# TODO(ivankobzarev): Support codecs(quantization) on backward
a2a_out = torch.ops._c10d_functional.all_to_all_single(
grad,
ctx.output_split_sizes,
ctx.input_split_sizes,
ctx.group_name,
)
grad = torch.ops._c10d_functional.wait_tensor(a2a_out)
if ctx.gradient_division:
grad.div_(ctx.group_size)

return grad, None, None, None, None, None
@staticmethod
# pyre-ignore
def backward(ctx, grad):
grad = torch.distributed._functional_collectives.all_to_all_single(
grad,
ctx.output_split_sizes,
ctx.input_split_sizes,
ctx.group_name,
)
if ctx.gradient_division:
grad.div_(ctx.group_size)

torch.library.register_autograd(
"torchrec::all_to_all_single",
all_to_all_single_backward,
setup_context=all_to_all_single_setup_context,
)
return grad, None, None, None, None, None

# torchrec::reduce_scatter_tensor
@torch.library.custom_op("torchrec::reduce_scatter_tensor", mutates_args=())