Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable autograd graph to propagate after multi-device syncing (for loss functions in ddp) #2754

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/torchmetrics/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def class_reduce(
def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result, group)
# to propagate autograd graph from local rank (achieves intended effect for torch> 2.0)
gathered_result[torch.distributed.get_rank(group)] = result
Borda marked this conversation as resolved.
Show resolved Hide resolved
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
return gathered_result


Expand Down Expand Up @@ -144,4 +146,6 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
gathered_result[idx] = gathered_result[idx][slice_param]
# to propagate autograd graph from local rank (achieves intended effect for torch> 2.0)
gathered_result[torch.distributed.get_rank(group)] = result
Borda marked this conversation as resolved.
Show resolved Hide resolved
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
return gathered_result
126 changes: 126 additions & 0 deletions tests/unittests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,62 @@ def _test_ddp_gather_uneven_tensors_multidim(rank: int, worldsize: int = NUM_PRO
assert (val == torch.ones_like(val)).all()


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions")
def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
"""Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks.

This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in
preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained
with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor.
This test only considers tensors of the same shape across different ranks.

Note that this test only works for torch>=2.0.

"""
tensor = torch.ones(50, requires_grad=True)
Borda marked this conversation as resolved.
Show resolved Hide resolved
result = gather_all_tensors(tensor)
assert len(result) == worldsize
scalar1 = 0
scalar2 = 0
for idx in range(worldsize):
if idx == rank:
scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor))
else:
scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx]))
scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx]))
gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0]
gradient2 = torch.autograd.grad(scalar2, [tensor])[0]
assert torch.allclose(gradient1, gradient2)


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions")
def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
"""Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks.

This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in
preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained
with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor.
This test considers tensors of different shapes across different ranks.

Note that this test only works for torch>=2.0.

"""
tensor = torch.ones(rank + 1, 2 - rank, requires_grad=True)
Borda marked this conversation as resolved.
Show resolved Hide resolved
result = gather_all_tensors(tensor)
assert len(result) == worldsize
scalar1 = 0
scalar2 = 0
for idx in range(worldsize):
if idx == rank:
scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor))
else:
scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx]))
scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx]))
gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0]
gradient2 = torch.autograd.grad(scalar2, [tensor])[0]
assert torch.allclose(gradient1, gradient2)


def _test_ddp_compositional_tensor(rank: int, worldsize: int = NUM_PROCESSES) -> None:
dummy = DummyMetricSum()
dummy._reductions = {"x": torch.sum}
Expand Down Expand Up @@ -105,6 +161,76 @@ def test_ddp(process):
pytest.pool.map(process, range(NUM_PROCESSES))


def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
"""Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks.

This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in
preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained
with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor.
This test only considers tensors of the same shape across different ranks.

Note that this test only works for torch>=2.0.

"""
tensor = torch.ones(50, requires_grad=True)
result = gather_all_tensors(tensor)
assert len(result) == worldsize
scalar1 = 0
scalar2 = 0
for idx in range(worldsize):
if idx == rank:
scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor))
else:
scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx]))
scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx]))
gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0]
gradient2 = torch.autograd.grad(scalar2, [tensor])[0]
assert torch.allclose(gradient1, gradient2)


def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
"""Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks.

This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in
preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained
with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor.
This test considers tensors of different shapes across different ranks.

Note that this test only works for torch>=2.0.

"""
tensor = torch.ones(rank + 1, 2 - rank, requires_grad=True)
result = gather_all_tensors(tensor)
assert len(result) == worldsize
scalar1 = 0
scalar2 = 0
for idx in range(worldsize):
if idx == rank:
scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor))
else:
scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx]))
scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx]))
gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0]
gradient2 = torch.autograd.grad(scalar2, [tensor])[0]
assert torch.allclose(gradient1, gradient2)


@pytest.mark.DDP()
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.")
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions")
@pytest.mark.parametrize(
"process",
[
_test_ddp_gather_autograd_same_shape,
_test_ddp_gather_autograd_different_shape,
],
)
def test_ddp_autograd(process):
"""Test ddp functions for autograd compatibility."""
pytest.pool.map(process, range(NUM_PROCESSES))


def _test_non_contiguous_tensors(rank):
class DummyCatMetric(Metric):
full_state_update = True
Expand Down
Loading