Autograd with DDP #2745
Labels: enhancement (New feature or request)
It looks like the relevant function is `_simple_gather_all_tensors`, found here:

```python
def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
    gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
    torch.distributed.all_gather(gathered_result, result, group)
    gathered_result[torch.distributed.get_rank(group)] = result
    return gathered_result
```
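For what it's worth, the last assignment in the snippet above is the part that touches autograd: `all_gather` fills the pre-allocated zero tensors in-place (so they carry no graph), and re-inserting `result` restores the graph for the current rank only. A minimal single-process sketch illustrating this, assuming the `gloo` backend is available on CPU (the port number is arbitrary):

```python
import os
import torch
import torch.distributed as dist

# set up a single-process "distributed" group just to exercise all_gather
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29517")
dist.init_process_group("gloo", rank=0, world_size=1)

# a non-leaf tensor that carries an autograd graph
result = torch.tensor([1.0, 2.0], requires_grad=True) * 3.0

gathered = [torch.zeros_like(result) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, result)

# the gathered copies were filled in-place and have no autograd history
detached_after_gather = not gathered[0].requires_grad

# re-inserting the local tensor restores the graph for this rank only
gathered[dist.get_rank()] = result
local_keeps_graph = gathered[dist.get_rank()].requires_grad

# backward flows through the re-inserted local slot
torch.stack(gathered).sum().backward()

dist.destroy_process_group()
```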
I have a setup with PyTorch Lightning where I'm using custom `torchmetrics.Metric`s as loss function contributions. Now I want to be able to do this with `ddp` by setting `dist_sync_on_step=True`, but the gradients are not propagated during the `all_gather`. All I want is for the tensor on the current process to keep its autograd graph for the backward pass after the syncing operations.

I've only just begun looking into distributed stuff in `torch`, so I'm not experienced in these matters. But following the `forward()` call of `Metric` (at each training batch step), it calls `_forward_reduce_state_update()`, which calls the `compute()` function wrapped by `_wrap_compute()`, which would do `sync()`, which finally calls `_sync_dist()`. And it looks like the syncing uses `torchmetrics.utilities.distributed.gather_all_tensors`.

I just wanted to ask if it is possible to achieve what I want by modifying `_simple_gather_all_tensors` (here)? `_simple_gather_all_tensors` is presented here for reference.

I'm guessing that `result` still carries the autograd graph. My naive hope is that we can just update `gathered_result` with the input `result` (carrying the autograd graph) to achieve the desired effect.

For context, my use case is such that batches can have very inhomogeneous `numel`s, so each device could have error tensors with very different `numel`s, such that taking a mean of `MeanSquaredError`s may not be ideal. Ideally, if the syncing holds the autograd graph, the per-step loss would be the "true" metric as per its definition, and the gradients would be consistent with that definition (so syncing is done once for each loss metric contribution, and once for the backward at each training step, I think).

Thank you!