Commit dc35370

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Oct 8, 2024
1 parent 3591435 commit dc35370
Showing 1 changed file with 2 additions and 0 deletions.
tests/unittests/bases/test_ddp.py: 2 additions, 0 deletions
@@ -84,6 +84,7 @@ def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESS
    This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended
    in preserving the local rank's autograd graph upon the gather. The function compares derivative
    values obtained with the local rank results from the ``gather_all_tensors`` output and the original
    local rank tensor. This test only considers tensors of the same shape across different ranks.
    Note that this test only works for torch>=2.0.
    """
    tensor = torch.ones(50, requires_grad=True)
    result = gather_all_tensors(tensor)
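
The hunk above ends right after the gather. A minimal sketch of the derivative comparison the docstring describes might continue as follows; these assertions are an illustration of the described check, not lines from this diff:

    # Assumed continuation, not part of this diff: backprop through the local
    # rank's slot of the gathered output, then through the original tensor,
    # and check that both paths produce the same gradient.
    result[rank].sum().backward()
    grad_via_gather = tensor.grad.clone()

    tensor.grad = None
    tensor.sum().backward()
    assert torch.allclose(grad_via_gather, tensor.grad)
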
@@ -108,6 +109,7 @@ def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PR
    This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended
    in preserving the local rank's autograd graph upon the gather. The function compares derivative
    values obtained with the local rank results from the ``gather_all_tensors`` output and the original
    local rank tensor. This test considers tensors of different shapes across different ranks.
    Note that this test only works for torch>=2.0.
    """
    tensor = torch.ones(rank + 1, 2 - rank, requires_grad=True)
    result = gather_all_tensors(tensor)
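For context, here is a self-contained sketch of how such a distributed test can be driven end to end. The worker function, backend, address, and port below are illustrative assumptions rather than the repository's actual test harness; ``gather_all_tensors`` is the real torchmetrics utility, and the gradient check relies on the docstring's claim that the local rank's autograd graph is preserved (torch>=2.0):

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    from torchmetrics.utilities.distributed import gather_all_tensors


    def _sketch_worker(rank: int, worldsize: int) -> None:
        # Assumed single-machine rendezvous; address and port are placeholders.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group("gloo", rank=rank, world_size=worldsize)

        # Different-shape case from the second hunk: with two processes,
        # rank 0 holds a (1, 2) tensor and rank 1 holds a (2, 1) tensor.
        tensor = torch.ones(rank + 1, 2 - rank, requires_grad=True)
        result = gather_all_tensors(tensor)

        # The local rank's slot should remain attached to the local autograd
        # graph, so backprop through it populates `tensor.grad`.
        result[rank].sum().backward()
        assert torch.allclose(tensor.grad, torch.ones_like(tensor))

        dist.destroy_process_group()


    if __name__ == "__main__":
        mp.spawn(_sketch_worker, args=(2,), nprocs=2)

Running the module spawns two CPU processes over the gloo backend; each asserts that gradients flow back through its own slot of the gathered list.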
