
Enable autograd graph to propagate after multi-device syncing (for loss functions in ddp) #2754

Open
wants to merge 15 commits into master from the all_gather_ad branch
Conversation

@cw-tan commented Sep 17, 2024

What does this PR do?

Fixes #2745

A single-line enhancement, as proposed in #2745: enable propagation of the autograd graph after the all_gather operation. This is useful for syncing loss functions in a DDP setting.
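
For reference, a minimal sketch of the idea (simplified and illustrative; the function name is made up and this is not the exact torchmetrics code):

import torch
import torch.distributed as dist

def gather_all_tensors_sketch(result: torch.Tensor, group=None):
    """Gather a tensor from all ranks while keeping the local autograd graph."""
    world_size = dist.get_world_size(group)
    gathered = [torch.zeros_like(result) for _ in range(world_size)]
    # all_gather fills `gathered` with plain copies that carry no autograd history
    dist.all_gather(gathered, result, group=group)
    # the proposed single-line enhancement: put the original tensor back into its
    # own slot so gradients can flow through this rank's contribution
    gathered[dist.get_rank(group)] = result
    return gathered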

Before submitting
  • Was this discussed/agreed via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃


📚 Documentation preview 📚: https://torchmetrics--2754.org.readthedocs.build/en/2754/

@Borda (Member) commented Sep 17, 2024

That sounds good to me, but can we add a test for this enhancement?

@cw-tan (Author) commented Sep 17, 2024

That sounds good to me, but can we add a test for this enhancement?

Thanks for the prompt response @Borda.

I'm thinking that _test_ddp_gather_uneven_tensors (here) and _test_ddp_gather_uneven_tensors_multidim (here) in tests/unittests/bases/test_ddp.py already cover the correctness of gather_all_tensors. I'm not sure what other ddp tests there are, but those tests should tell us whether the change I made breaks existing functionality. Let me know if you had something else in mind for this.

I can make an additional unittest in tests/unittests/bases/test_ddp.py to give a tensor that requires_grad to gather_all_tensors, compute some scalar from them (proxy for a loss), and compute grads two ways (one going through the all_gather, one that doesn't) and compare. So this tests that the change achieves the desired effect. How does that sound?
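
A rough sketch of what the per-rank body of such a test could look like (the name and the exact loss are illustrative; the real test would run under the DDP process-spawning harness already used in tests/unittests/bases/test_ddp.py):

import torch
from torchmetrics.utilities.distributed import gather_all_tensors

def _test_ddp_gather_autograd(rank: int, worldsize: int) -> None:
    tensor = torch.ones(5, requires_grad=True)

    # gradient path that goes through all_gather
    gathered = gather_all_tensors(tensor)
    loss_synced = torch.stack(gathered).sum()  # stand-in for a synced loss
    grad_synced = torch.autograd.grad(loss_synced, tensor, retain_graph=True)[0]

    # reference path that never leaves this rank: for a plain sum, the gradient
    # of the local contribution is simply a tensor of ones
    grad_local = torch.autograd.grad(tensor.sum(), tensor)[0]

    assert torch.allclose(grad_synced, grad_local)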

codecov bot commented Sep 17, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 69%. Comparing base (6bfb775) to head (b5f285d).
Report is 5 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #2754    +/-   ##
=======================================
- Coverage      69%     69%    -0%     
=======================================
  Files         329     316    -13     
  Lines       18083   17926   -157     
=======================================
- Hits        12505   12347   -158     
- Misses       5578    5579     +1     

@Borda (Member) commented Sep 17, 2024

I can make an additional unittest in tests/unittests/bases/test_ddp.py to give a tensor that requires_grad to gather_all_tensors, compute some scalar from them (proxy for a loss), and compute grads two ways (one going through the all_gather, one that doesn't) and compare. So this tests that the change achieves the desired effect. How does that sound?

yeah, that sounds good to me :)

@Borda added the enhancement label on Sep 17, 2024
@cw-tan force-pushed the all_gather_ad branch 4 times, most recently from 6c926d7 to 1d0dabe on September 18, 2024 02:54
@cw-tan (Author) commented Sep 18, 2024

Update: to accommodate both cases, where tensors from different ranks have the same shape or different shapes, the line that puts the original tensor (holding the autograd graph) back into the gathered list was added in two places in the code.

Because of the two cases, I wrote two unit tests, one for each. Interestingly, both pass on 2.X stable, but on 1.X LTS the "same shape" test passes while the "different shape" test fails, and on the oldest 1.10 the "different shape" test passes while the "same shape" test fails 😅. I'll double-check for bugs, but the actual code change is just two lines (and all other tests pass, so existing functionality still works), and the unit tests are pretty short. The fact that which tests pass depends on the torch version suggests it might be a torch versioning issue, maybe related to ddp behavior? Any thoughts, @Borda?
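
For context, the "different shape" code path roughly pads every tensor up to the largest size, all_gathers the padded tensors, and slices them back, so the second restoration happens after the slicing step. A simplified, illustrative sketch (not the exact torchmetrics code):

import torch
import torch.distributed as dist
import torch.nn.functional as F

def _gather_uneven_sketch(result: torch.Tensor, group=None):
    world_size = dist.get_world_size(group)
    # exchange shapes so every rank knows the maximum size per dimension
    local_size = torch.tensor(result.shape, device=result.device)
    sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(sizes, local_size, group=group)
    max_size = torch.stack(sizes).max(dim=0).values
    # pad the local tensor up to the maximum size (F.pad pads the last dim first)
    pad = []
    for dim, size in enumerate(max_size.tolist()):
        pad = [0, size - result.shape[dim]] + pad
    padded = F.pad(result, pad)
    # gather the padded tensors and slice each one back to its true shape
    gathered = [torch.zeros_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded, group=group)
    gathered = [t[tuple(slice(0, d) for d in s.tolist())] for t, s in zip(gathered, sizes)]
    # second place where the original tensor is put back, so the autograd graph
    # also survives the uneven-shape case
    gathered[dist.get_rank(group)] = result
    return gathered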

@Borda (Member) commented Sep 19, 2024

I wrote two unit tests, one for each. Interestingly, both pass on 2.X stable, but on 1.X LTS the "same shape" test passes while the "different shape" test fails, and on the oldest 1.10 the "different shape" test passes while the "same shape" test fails 😅.

That is strange and warrants some more investigation...
cc: @SkafteNicki

@SkafteNicki (Member) left a comment

I looked briefly into why the tests do not pass on older versions of Pytorch but could not find a reason.

I think we should only support this for Pytorch > 2.0 and then add this to the documentation.
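
If the feature does end up being gated to newer PyTorch, the new tests could be guarded roughly like this (illustrative only; torchmetrics may already expose a suitable version flag in torchmetrics.utilities.imports):

import pytest
import torch
from packaging.version import Version

# hypothetical guard, following the suggestion above to only support torch 2.0+
_TORCH_2_0 = Version(torch.__version__.split("+")[0]) >= Version("2.0.0")

@pytest.mark.skipif(not _TORCH_2_0, reason="autograd through all_gather requires torch>=2.0")
def test_ddp_autograd():
    ...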

Review threads on src/torchmetrics/utilities/distributed.py and tests/unittests/bases/test_ddp.py.
@cw-tan force-pushed the all_gather_ad branch 2 times, most recently from dc35370 to e693ace on October 8, 2024 16:28
@Borda requested a review from SkafteNicki on October 8, 2024 17:26
@cw-tan force-pushed the all_gather_ad branch 2 times, most recently from ce5dca1 to ffc67f6 on October 8, 2024 18:47
@SkafteNicki (Member) left a comment

Seems the two test functions are now included twice in the test_ddp.py file, please check.

Review threads on src/torchmetrics/utilities/distributed.py.
@SkafteNicki added this to the v1.4.x milestone on Oct 9, 2024
@github-actions bot added the documentation label on Oct 9, 2024
@mergify bot added the ready label on Oct 9, 2024
@cw-tan (Author) commented Oct 9, 2024

Seems the two test functions are now included twice in the test_ddp.py file, please check.

ah yes, sorry about that -- probably left behind when I tried rebasing and force-pushing with the pre-commit CI commits.

Thanks for the review and changes @SkafteNicki. Unfortunately, it looks like the 2.X stable tests are failing now. This may suggest that something more subtle was happening with the failure of the torch < 2.0 tests earlier?

@SkafteNicki (Member) commented

Seems the two test functions are now included twice in the test_ddp.py file, please check.

ah yes, sorry about that -- probably left behind when I tried rebasing and force-pushing with the pre-commit CI commits.

Thanks for the review and changes @SkafteNicki. Unfortunately, it looks like the 2.X stable tests are failing now. This may suggest that something more subtle was happening with the failure of the torch < 2.0 tests earlier?

That is strange, but yes @cw-tan, it may very well be that this is also what caused the errors on the older Pytorch versions.
I will try to find some time to debug, but feel free to also take a stab at this new issue.

@SkafteNicki (Member) commented
@cw-tan this is really strange. I am trying to debug this locally and I am seeing that the tests fail at random. E.g. if I run them 10 times in a row, I get output from pytest like this:

FAILED tests/unittests/bases/test_ddp.py::test_ddp_autograd[_test_ddp_gather_autograd_same_shape-1-10] - AssertionError
FAILED tests/unittests/bases/test_ddp.py::test_ddp_autograd[_test_ddp_gather_autograd_same_shape-3-10] - AssertionError
FAILED tests/unittests/bases/test_ddp.py::test_ddp_autograd[_test_ddp_gather_autograd_same_shape-7-10] - AssertionError
FAILED tests/unittests/bases/test_ddp.py::test_ddp_autograd[_test_ddp_gather_autograd_same_shape-9-10] - AssertionError
FAILED tests/unittests/bases/test_ddp.py::test_ddp_autograd[_test_ddp_gather_autograd_different_shape-5-10] - AssertionError
FAILED tests/unittests/bases/test_ddp.py::test_ddp_autograd[_test_ddp_gather_autograd_different_shape-6-10] - AssertionError
FAILED tests/unittests/bases/test_ddp.py::test_ddp_autograd[_test_ddp_gather_autograd_different_shape-8-10] - AssertionError
FAILED tests/unittests/bases/test_ddp.py::test_ddp_autograd[_test_ddp_gather_autograd_different_shape-9-10] - AssertionError
FAILED tests/unittests/bases/test_ddp.py::test_ddp_autograd[_test_ddp_gather_autograd_different_shape-10-10] - AssertionError

with 4/10 of the "same shape" tests failing and 5/10 of the "different shape" tests failing. But I cannot see any randomization going on in the tests?

@mergify bot removed the ready label on Oct 10, 2024
@cw-tan (Author) commented Oct 10, 2024

@SkafteNicki indeed this is an odd one. Though after adding the with torch.no_grad(): in my recent commit, only the "different shape" test was failing -- were both the "same shape" and "different shape" tests failing before? I'm thinking it might be worth trying to cover more of the code with torch.no_grad(), except for the parts where we want the autograd graph to be propagated. Though I don't actually know a priori why it would help.
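
Sketched out, that restructuring would look something like this (illustrative; essentially the earlier sketch from the PR description with the collective bookkeeping moved under no_grad):

import torch
import torch.distributed as dist

def gather_all_tensors_no_grad_sketch(result: torch.Tensor, group=None):
    # collective bookkeeping under no_grad, so none of it ends up in the graph
    with torch.no_grad():
        gathered = [torch.zeros_like(result) for _ in range(dist.get_world_size(group))]
        dist.all_gather(gathered, result, group=group)
    # outside no_grad: the local slot keeps the original, graph-carrying tensor
    gathered[dist.get_rank(group)] = result
    return gathered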

@cw-tan (Author) commented Oct 11, 2024

@SkafteNicki sorry for the mess, I'm just trying to run the CI tests on all torch versions again, this time hopefully incorporating several trials (to check for non-determinism) and with the no_grad changes I made.

@SkafteNicki (Member) commented
@SkafteNicki sorry for the mess, I'm just trying to run the CI tests on all torch versions again, this time hopefully incorporating several trials (to check for non-determinism) and with the no_grad changes I made.

@cw-tan That is completely okay, whatever it takes to debug the issue. If you want to run tests multiple times locally, I recommend installing
https://pypi.org/project/pytest-repeat/
and then running the pytest command with
pytest --count=X
for X repeated evaluations.
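
For example, something like pytest tests/unittests/bases/test_ddp.py -k "autograd" --count=10 would repeat just the new autograd tests ten times (the exact -k filter depends on the final test names).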

Labels: documentation, enhancement, has conflicts
Projects: None yet
Development: successfully merging this pull request may close the issue "Autograd with DDP"
3 participants