
Commit

propagate rank result to gathered result for autograd compatibility
cw-tan committed Oct 8, 2024
1 parent fd11f33 commit 8122e9f
Showing 1 changed file with 4 additions and 0 deletions.
src/torchmetrics/utilities/distributed.py: 4 additions & 0 deletions
@@ -91,6 +91,8 @@ def class_reduce(
def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result, group)
# to propagate autograd graph from local rank (achieves intended effect for torch > 2.0)
gathered_result[torch.distributed.get_rank(group)] = result
return gathered_result


@@ -144,4 +146,6 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
gathered_result[idx] = gathered_result[idx][slice_param]
# to propagate autograd graph from local rank (achieves intended effect for torch > 2.0)
gathered_result[torch.distributed.get_rank(group)] = result
return gathered_result
