More idist allreduce/gather test fixes
vfdev-5 committed Nov 7, 2024
1 parent 448a75b commit 61a2c29
Showing 4 changed files with 75 additions and 74 deletions.
3 changes: 3 additions & 0 deletions ignite/distributed/comp_models/native.py
@@ -418,6 +418,9 @@ def spawn(
         }
 
     def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[Any] = None) -> torch.Tensor:
+        if group == dist.GroupMember.NON_GROUP_MEMBER:
+            return tensor
+
         if op not in self._reduce_op_map:
             raise ValueError(f"Unsupported reduction operation: '{op}'")
         if group is not None and not isinstance(group, dist.ProcessGroup):
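For context, a minimal sketch of the situation the new early return handles (an illustration based on PyTorch's documented new_group semantics, not code from this commit): torch.distributed.new_group hands back dist.GroupMember.NON_GROUP_MEMBER on ranks that are not listed, and an all_reduce on such a rank should simply leave the local tensor untouched.

# Hedged sketch, not part of the commit; assumes torch.distributed is already
# initialized with a native (gloo/nccl) backend and world_size >= 3.
import torch.distributed as dist

group = dist.new_group(ranks=[1, 2])
if dist.get_rank() in (1, 2):
    assert isinstance(group, dist.ProcessGroup)
else:
    # Ranks left out of the group get this sentinel instead of a ProcessGroup,
    # which is why _do_all_reduce now returns the input tensor unchanged for them.
    assert group == dist.GroupMember.NON_GROUP_MEMBER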
127 changes: 69 additions & 58 deletions tests/ignite/distributed/utils/__init__.py
@@ -120,26 +120,21 @@ def _test_distrib_all_reduce(device):

 def _test_distrib_all_reduce_group(device):
     if idist.get_world_size() > 1 and idist.backend() is not None:
-        ranks = [0, 1]
+        ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
         rank = idist.get_rank()
-        t = torch.tensor([rank], device=device)
         bnd = idist.backend()
 
-        group = idist.new_group(ranks)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
+        for group in [idist.new_group(ranks), ranks]:
+            t = torch.tensor([rank], device=device)
+            if bnd in ("horovod"):
+                with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
+                    res = idist.all_reduce(t, group=group)
+            else:
                 res = idist.all_reduce(t, group=group)
-        else:
-            res = idist.all_reduce(t, group=group)
-            assert res == torch.tensor([sum(ranks)], device=device)
-
-        t = torch.tensor([rank], device=device)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
-                res = idist.all_reduce(t, group=ranks)
-        else:
-            res = idist.all_reduce(t, group=ranks)
-            assert res == torch.tensor([sum(ranks)], device=device)
+                if rank in ranks:
+                    assert res == torch.tensor([sum(ranks)], device=device)
+                else:
+                    assert res == t
 
         ranks = "abc"

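A hedged restatement of what the rewritten loop asserts (inferred from the test body above, not from separate documentation): whether the group is passed as a handle from idist.new_group or as a plain list of ranks, members of the group receive the sum over the group, while the excluded rank gets its own tensor back.

# Illustrative only; assumes a 4-process native (gloo/nccl) run, ranks 0..3.
import torch
import ignite.distributed as idist

device = idist.device()
rank = idist.get_rank()
ranks = [3, 2, 1]  # rank 0 stays outside the group
for group in [idist.new_group(ranks), ranks]:
    res = idist.all_reduce(torch.tensor([rank], device=device), group=group)
    if rank in ranks:
        assert res == torch.tensor([6], device=device)  # 3 + 2 + 1
    else:
        assert res == torch.tensor([0], device=device)  # rank 0 keeps its own value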
@@ -218,33 +213,23 @@ def _test_distrib_all_gather(device):


 def _test_distrib_all_gather_group(device):
-    if idist.get_world_size() > 1:
+    if idist.get_world_size() > 1 and idist.backend() is not None:
         ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
+        sorted_ranks = sorted(ranks)
         rank = idist.get_rank()
         bnd = idist.backend()
 
         t = torch.tensor([rank], device=device)
-        group = idist.new_group(ranks)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-                res = idist.all_gather(t, group=group)
-        else:
-            res = idist.all_gather(t, group=group)
-            if rank in ranks:
-                assert torch.equal(res, torch.tensor(ranks, device=device))
-            else:
-                assert res == t
-
-        t = torch.tensor([rank], device=device)
-        if bnd in ("horovod"):
-            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-                res = idist.all_gather(t, group=ranks)
-        else:
-            res = idist.all_gather(t, group=ranks)
-            if rank in ranks:
-                assert torch.equal(res, torch.tensor(ranks, device=device))
+        for group in [idist.new_group(ranks), ranks]:
+            if bnd in ("horovod"):
+                with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+                    res = idist.all_gather(t, group=group)
+            else:
-                assert res == t
+                res = idist.all_gather(t, group=group)
+                if rank in ranks:
+                    assert (res == torch.tensor(sorted_ranks, device=device)).all(), (res, ranks)
+                else:
+                    assert res == t
 
         t = {
             "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
@@ -261,7 +246,7 @@ def _test_distrib_all_gather_group(device):
             res = idist.all_gather(t, group=ranks)
             if rank in ranks:
                 assert isinstance(res, list) and len(res) == len(ranks)
-                for i, obj in zip(ranks, res):
+                for i, obj in zip(sorted_ranks, res):
                     assert isinstance(obj, dict)
                     assert list(obj.keys()) == ["a", "b", "c"], obj
                     expected_device = (
@@ -295,20 +280,44 @@ def _test_idist_all_gather_tensors_with_shapes(device):
     torch.manual_seed(41)
     rank = idist.get_rank()
     ws = idist.get_world_size()
-    reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
+    reference = torch.randn(ws + 6, ws + 6, ws + 6, device=device)
+
+    ref_indices_per_rank = {
+        # rank: (start_index, end_index, size)
+        r: (r + 1, 2 * r + 2, r + 1)
+        for r in range(ws)
+    }
+    start_index, end_index, _ = ref_indices_per_rank[rank]
     rank_tensor = reference[
-        rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
-        rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
-        rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
+        start_index : end_index + 1,
+        start_index : end_index + 2,
+        start_index : end_index + 3,
     ]
-    tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in range(ws)])
+    tensors = all_gather_tensors_with_shapes(
+        rank_tensor,
+        [
+            [
+                ref_indices_per_rank[r][2] + 1,
+                ref_indices_per_rank[r][2] + 2,
+                ref_indices_per_rank[r][2] + 3,
+            ]
+            for r in range(ws)
+        ],
+    )
     for r in range(ws):
-        r_tensor = reference[
-            r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
-            r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
-            r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
+        start_index, end_index, _ = ref_indices_per_rank[r]
+        ref_tensor = reference[
+            start_index : end_index + 1,
+            start_index : end_index + 2,
+            start_index : end_index + 3,
         ]
-        assert (r_tensor == tensors[r]).all()
+        assert torch.allclose(ref_tensor, tensors[r]), (
+            r,
+            ref_tensor.shape,
+            ref_tensor.mean(),
+            tensors[r].shape,
+            tensors[r].mean(),
+        )

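For orientation, a small worked example of the new index bookkeeping in _test_idist_all_gather_tensors_with_shapes (illustrative, assuming a world size of 3): rank r owns a block starting at index r + 1 whose per-dimension sizes are r + 2, r + 3 and r + 4, which is exactly the shape list handed to all_gather_tensors_with_shapes.

# Worked example for ws = 3 (illustration only, no distributed setup needed).
ws = 3
ref_indices_per_rank = {r: (r + 1, 2 * r + 2, r + 1) for r in range(ws)}
print(ref_indices_per_rank)  # {0: (1, 2, 1), 1: (2, 4, 2), 2: (3, 6, 3)}
shapes = [[ref_indices_per_rank[r][2] + k for k in (1, 2, 3)] for r in range(ws)]
print(shapes)  # [[2, 3, 4], [3, 4, 5], [4, 5, 6]]
# For ws = 3 the deepest slice ends at 2 * 2 + 2 + 3 = 9, i.e. exactly ws + 6,
# so every rank's block fits inside the (ws + 6)^3 reference tensor.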

def _test_idist_all_gather_tensors_with_shapes_group(device):
@@ -320,27 +329,29 @@ def _test_idist_all_gather_tensors_with_shapes_group(device):
         ws = idist.get_world_size()
         bnd = idist.backend()
         if rank in ranks:
-            reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
+            reference = torch.randn(
+                ws * (ws + 1) // 2 + 1, ws * (ws + 3) // 2 + 1, ws * (ws + 5) // 2 + 1, device=device
+            )
             rank_tensor = reference[
-                rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
-                rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
-                rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
+                rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 2,
+                rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 3,
+                rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 4,
             ]
         else:
             rank_tensor = torch.tensor([rank], device=device)
         if bnd in ("horovod"):
             with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-                tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+                tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 2, r + 3, r + 4] for r in ranks], ranks)
         else:
-            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
+            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 2, r + 3, r + 4] for r in ranks], ranks)
             if rank in ranks:
                 for r in ranks:
                     r_tensor = reference[
-                        r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
-                        r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
-                        r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
+                        r * (r + 1) // 2 : r * (r + 1) // 2 + r + 2,
+                        r * (r + 3) // 2 : r * (r + 3) // 2 + r + 3,
+                        r * (r + 5) // 2 : r * (r + 5) // 2 + r + 4,
                     ]
-                    assert (r_tensor == tensors[r - 1]).all()
+                    assert r_tensor.allclose(tensors[r - 1])
             else:
                 assert [rank_tensor] == tensors

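One more hedged note on the sorted_ranks change in _test_distrib_all_gather_group above (an inference from the updated assertions, not from separate documentation): even when the rank list is given in descending order, the gathered values come back ordered by ascending rank, so the expected tensor has to be built from sorted(ranks).

# Illustrative only; assumes a 4-process native run and the group [3, 2, 1].
import torch
import ignite.distributed as idist

rank = idist.get_rank()
ranks = [3, 2, 1]
res = idist.all_gather(torch.tensor([rank], device=idist.device()), group=ranks)
if rank in ranks:
    # Gathered values are ordered by rank, not by the order of `ranks`.
    assert (res == torch.tensor(sorted(ranks), device=idist.device())).all()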
17 changes: 2 additions & 15 deletions tests/ignite/distributed/utils/test_native.py
@@ -244,6 +244,8 @@ def test_idist_all_gather_nccl(distributed_context_single_node_nccl):
     device = idist.device()
     _test_distrib_all_gather(device)
     _test_distrib_all_gather_group(device)
+    _test_idist_all_gather_tensors_with_shapes(device)
+    _test_idist_all_gather_tensors_with_shapes_group(device)
 
 
 @pytest.mark.distributed
@@ -253,21 +255,6 @@ def test_idist_all_gather_gloo(distributed_context_single_node_gloo):
     device = idist.device()
     _test_distrib_all_gather(device)
     _test_distrib_all_gather_group(device)
-
-
-@pytest.mark.distributed
-@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
-@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-def test_idist_all_gather_tensors_with_shapes_nccl(distributed_context_single_node_nccl):
-    device = idist.device()
-    _test_idist_all_gather_tensors_with_shapes(device)
-    _test_idist_all_gather_tensors_with_shapes_group(device)
-
-
-@pytest.mark.distributed
-@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
-def test_idist_all_gather_tensors_with_shapes_gloo(distributed_context_single_node_gloo):
-    device = idist.device()
     _test_idist_all_gather_tensors_with_shapes(device)
     _test_idist_all_gather_tensors_with_shapes_group(device)
 
2 changes: 1 addition & 1 deletion tests/run_cpu_tests.sh
@@ -22,7 +22,7 @@ fi
 # Run 2 processes with --dist=each
 CUDA_VISIBLE_DEVICES="" run_tests \
     --core_args "-m distributed -vvv tests/ignite" \
-    --world_size 2 \
+    --world_size 4 \
     --cache_dir ".cpu-distrib" \
     --skip_distrib_tests 0 \
     --use_coverage 1 \
