Skip to content

Commit

Permalink
Remove unused alltoallv function
Browse files Browse the repository at this point in the history
Summary: This code isn't referenced or used anywhere.

Differential Revision: D67110533
  • Loading branch information
sarckk authored and facebook-github-bot committed Dec 12, 2024
1 parent 3928a1b commit 4a227d0
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 187 deletions.
109 changes: 0 additions & 109 deletions torchrec/distributed/comm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,115 +739,6 @@ def all2all_sequence_sync(
return sharded_output_embeddings.view(-1, D)


def alltoallv(
inputs: List[Tensor],
out_split: Optional[List[int]] = None,
per_rank_split_lengths: Optional[List[int]] = None,
group: Optional[dist.ProcessGroup] = None,
codecs: Optional[QuantizedCommCodecs] = None,
) -> Awaitable[List[Tensor]]:
"""
Performs `alltoallv` operation for a list of input embeddings. Each process scatters
the list to all processes in the group.
Args:
inputs (List[Tensor]): list of tensors to scatter, one per rank. The tensors in
the list usually have different lengths.
out_split (Optional[List[int]]): output split sizes (or dim_sum_per_rank), if
not specified, we will use `per_rank_split_lengths` to construct a output
split with the assumption that all the embs have the same dimension.
per_rank_split_lengths (Optional[List[int]]): split lengths per rank. If not
specified, the `out_split` must be specified.
group (Optional[dist.ProcessGroup]): the process group to work on. If None, the
default process group will be used.
codecs (Optional[QuantizedCommCodecs]): quantized communication codecs.
Returns:
Awaitable[List[Tensor]]: async work handle (`Awaitable`), which can be `wait()` later to get the resulting list of tensors.
.. warning::
`alltoallv` is experimental and subject to change.
"""

if group is None:
group = dist.distributed_c10d._get_default_group()

world_size: int = group.size()
my_rank: int = group.rank()

B_global = inputs[0].size(0)

D_local_list = []
for e in inputs:
D_local_list.append(e.size()[1])

B_local, B_local_list = _get_split_lengths_by_len(world_size, my_rank, B_global)

if out_split is not None:
dims_sum_per_rank = out_split
elif per_rank_split_lengths is not None:
# all the embs have the same dimension
dims_sum_per_rank = []
for s in per_rank_split_lengths:
dims_sum_per_rank.append(s * D_local_list[0])
else:
raise RuntimeError("Need to specify either out_split or per_rank_split_lengths")

a2ai = All2AllVInfo(
dims_sum_per_rank=dims_sum_per_rank,
B_local=B_local,
B_local_list=B_local_list,
D_local_list=D_local_list,
B_global=B_global,
codecs=codecs,
)

if get_use_sync_collectives():
return NoWait(all2allv_sync(group, a2ai, inputs))

myreq = Request(group, device=inputs[0].device)
All2Allv_Req.apply(group, myreq, a2ai, inputs)

return myreq


def all2allv_sync(
pg: dist.ProcessGroup,
a2ai: All2AllVInfo,
inputs: List[Tensor],
) -> List[Tensor]:
input_split_sizes = []
sum_D_local_list = sum(a2ai.D_local_list)
for m in a2ai.B_local_list:
input_split_sizes.append(m * sum_D_local_list)

output_split_sizes = []
for e in a2ai.dims_sum_per_rank:
output_split_sizes.append(a2ai.B_local * e)

input = torch.cat(inputs, dim=1).view([-1])
if a2ai.codecs is not None:
input = a2ai.codecs.forward.encode(input)

with record_function("## alltoallv_bwd_single ##"):
output = torch.ops.torchrec.all_to_all_single(
input,
output_split_sizes,
input_split_sizes,
pg_name(pg),
pg.size(),
get_gradient_division(),
)

if a2ai.codecs is not None:
output = a2ai.codecs.forward.decode(output)

outputs = []
for out in output.split(output_split_sizes):
outputs.append(out.view([a2ai.B_local, -1]))
return outputs


def reduce_scatter_pooled(
inputs: List[Tensor],
group: Optional[dist.ProcessGroup] = None,
Expand Down
78 changes: 0 additions & 78 deletions torchrec/distributed/tests/test_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,84 +204,6 @@ def _run_multi_process_test(
p.join()
self.assertEqual(0, p.exitcode)

@classmethod
def _test_alltoallv(
cls,
rank: int,
world_size: int,
backend: str,
compile_config: _CompileConfig,
specify_pg: bool,
) -> None:
dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
pg = GroupMember.WORLD
assert pg is not None

device = torch.device(f"cuda:{rank}")

torch.cuda.set_device(device)

B_global = 10
D0 = 8
D1 = 9

input_embedding0 = torch.rand(
(B_global, D0),
device=device,
requires_grad=True,
)
input_embedding1 = torch.rand(
(B_global, D1),
device=device,
requires_grad=True,
)

input_embeddings = [input_embedding0, input_embedding1]
out_split = [17, 17]

# pyre-ignore
def fn(*args, **kwargs) -> List[torch.Tensor]:
return comm_ops.alltoallv(*args, **kwargs).wait()

fn_transform = compile_config_to_fn_transform(compile_config)

with unittest.mock.patch(
"torch._dynamo.config.skip_torchrec",
False,
):
v_embs_out = fn_transform(fn)(
input_embeddings, out_split=out_split, group=pg if specify_pg else None
)

res = torch.cat(v_embs_out, dim=1).cpu()
assert tuple(res.size()) == (5, 34)
dist.destroy_process_group()

@unittest.skipIf(
torch.cuda.device_count() < 2, "Need at least two ranks to run this test"
)
# pyre-ignore
@given(
specify_pg=st.sampled_from([True]),
test_compiled_with_noncompiled_ranks=st.sampled_from([False, True]),
)
@settings(deadline=None)
def test_alltoallv(
self,
specify_pg: bool,
test_compiled_with_noncompiled_ranks: bool,
) -> None:
self._run_multi_process_test(
world_size=self.WORLD_SIZE,
backend="nccl",
# pyre-ignore [6]
callable=self._test_alltoallv,
compile_config=_CompileConfig(
test_compiled_with_noncompiled_ranks=test_compiled_with_noncompiled_ranks
),
specify_pg=specify_pg,
)

@classmethod
def _test_alltoall_sequence(
cls,
Expand Down

0 comments on commit 4a227d0

Please sign in to comment.