Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial support for pointers in reverse mode #686

Merged
merged 1 commit into from
Dec 30, 2023

Conversation

vaithak
Copy link
Collaborator

@vaithak vaithak commented Dec 19, 2023

This commit adds support for pointer operation in reverse mode. The technique is to maintain a corresponding derivative pointer variable, which gets updated (and stored/restored) in the exact same way as the primal pointer variable in both forward and reverse passes.
Added a workaround (with a FIXME comment) in the UsefulToStoreGlobal method to essentially bypass TBR analysis results for pointer expr.

Fixes #195, Fixes #197

@vaithak
Copy link
Collaborator Author

vaithak commented Dec 19, 2023

@PetroZarytskyi can you provide some comments on what extra needs to be done to make a pointer update statement to be added in the final output when enable-tbr is on?

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clang-tidy made some suggestions

lib/Differentiator/ReverseModeVisitor.cpp Outdated Show resolved Hide resolved
lib/Differentiator/ReverseModeVisitor.cpp Show resolved Hide resolved
lib/Differentiator/ReverseModeVisitor.cpp Show resolved Hide resolved
lib/Differentiator/ReverseModeVisitor.cpp Outdated Show resolved Hide resolved
lib/Differentiator/VisitorBase.cpp Show resolved Hide resolved
Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

@vaithak vaithak marked this pull request as ready for review December 20, 2023 17:20
Copy link

codecov bot commented Dec 20, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (654faee) 94.48% compared to head (44b6c50) 94.51%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #686      +/-   ##
==========================================
+ Coverage   94.48%   94.51%   +0.02%     
==========================================
  Files          48       48              
  Lines        7091     7177      +86     
==========================================
+ Hits         6700     6783      +83     
- Misses        391      394       +3     
Files Coverage Δ
include/clad/Differentiator/VisitorBase.h 100.00% <ø> (ø)
lib/Differentiator/BaseForwardModeVisitor.cpp 98.83% <ø> (-0.02%) ⬇️
lib/Differentiator/CladUtils.cpp 97.08% <100.00%> (-0.52%) ⬇️
lib/Differentiator/ReverseModeVisitor.cpp 96.08% <100.00%> (+0.04%) ⬆️
lib/Differentiator/VisitorBase.cpp 98.05% <100.00%> (+0.10%) ⬆️

... and 1 file with indirect coverage changes

Files Coverage Δ
include/clad/Differentiator/VisitorBase.h 100.00% <ø> (ø)
lib/Differentiator/BaseForwardModeVisitor.cpp 98.83% <ø> (-0.02%) ⬇️
lib/Differentiator/CladUtils.cpp 97.08% <100.00%> (-0.52%) ⬇️
lib/Differentiator/ReverseModeVisitor.cpp 96.08% <100.00%> (+0.04%) ⬆️
lib/Differentiator/VisitorBase.cpp 98.05% <100.00%> (+0.10%) ⬆️

... and 1 file with indirect coverage changes

Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

Copy link
Owner

@vgvassilev vgvassilev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall it looks good. Can you add more test cases to hit the codecov missing coverage report? Please add more information in the commit log to explain some of the basics. I suspect this commit can close a few bugs that were reported - can you go over them and see if they are fixed and enumerate them in the commit message with the Fixes prefix?

lib/Differentiator/VisitorBase.cpp Outdated Show resolved Hide resolved
test/FirstDerivative/UnsupportedOpsWarn.C Show resolved Hide resolved
Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

@vaithak vaithak force-pushed the pointer-support branch 2 times, most recently from 2aa5dcb to dffbf31 Compare December 20, 2023 20:22
Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

lib/Differentiator/ReverseModeVisitor.cpp Show resolved Hide resolved
lib/Differentiator/ReverseModeVisitor.cpp Outdated Show resolved Hide resolved
lib/Differentiator/ReverseModeVisitor.cpp Show resolved Hide resolved
test/Gradient/Pointers.C Outdated Show resolved Hide resolved
Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

1 similar comment
Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

@@ -22,7 +22,7 @@ template <typename T> class array_ref {

public:
/// Delete default constructor
array_ref() = delete;
array_ref() = default;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this change required?

Copy link
Collaborator Author

@vaithak vaithak Dec 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the primal contains an operation of the form arr = arr + 1, where arr is a pointer param, this will result in d_arr = d_arr + 1 statement, where d_arr is of type array_ref.
Now that we are updating this value, this means we need to store its old value and restore it in the reverse pass, thus, requiring a clad::array_ref<T> _t0 initialization, which will require a default constructor.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we should not create references out of thin air.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, yes. But this is essentially the same restriction as we have for user-defined objects: #627.

@vaithak vaithak force-pushed the pointer-support branch 2 times, most recently from 52d653e to c86c6d2 Compare December 24, 2023 19:09
Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clang-tidy made some suggestions

lib/Differentiator/VisitorBase.cpp Show resolved Hide resolved
lib/Differentiator/VisitorBase.cpp Show resolved Hide resolved
lib/Differentiator/VisitorBase.cpp Show resolved Hide resolved
Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

@vgvassilev
Copy link
Owner

Can you rebase and squash commits where necessary?

Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

@vaithak
Copy link
Collaborator Author

vaithak commented Dec 29, 2023

Can you rebase and squash commits where necessary?

updated 👍🏼

@vgvassilev vgvassilev requested a review from parth-07 December 29, 2023 16:04
Copy link
Collaborator

@parth-07 parth-07 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good.

// This is not correct, but we need to implement a more advanced analysis
// to determine which pointer operations are useful to store.
if (E->getType()->isPointerType())
return true;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a test for this branch to make code of happy?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is essentially a workaround for some cases of pointers when enableTbr is used, but as mentioned in some comment above, this doesn't cover every case. Adding this to test is essentially adding check with enable-tbr flag for pointer cases, but it will pass for some test and fail for others. Hence, adding it to pointers.c test won't be possible.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well in theory we can add a tbr mode test and mark it xfail.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will try this out. I have very little idea about xfail.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a new test for TBR analysis with pointers and marked it as xfail for now.

@parth-07
Copy link
Collaborator

A few important points to note:

  • This pull request adds clad::array_ref<T>::ptr_ref interface. We need a similar interface / functionality for clad::array<T> as well for the same reason. For example:
double arr[2] = {...};
double *u = arr + 1; // arr derivative type is clad::array, therefore, _d_arr + 1 would be element-wise addition! 
  • Pointer arithmetic in a function call argument causes a compilation crash.
fn(arr + 0, ...)

Two separate causes are responsible for the crash:

  • restore-value analysis tries to restore the value of expression arr + 0. That obviously doesn't work because arr + 0 is not an lvalue. @PetroZarytskyi We perhaps need a check to only restore the value of the function arguments that are lvalue expressions.
  • Currently, clad tries to create a dummy variable clad::array<double> _r0(_d_arr.ptr_ref() + 0). This doesn't work because clad::array cannot be initialized from just a pointer. It needs a size as well, And _d_arr.ptr_ref() + 0 doesn't have size information. One other important thing to add is that this clad::array<double> _r0(_d_arr.ptr_ref() + 0) declaration is completely useless for automatic differentiation. It's only used by the error-estimation framework.

I think we should handle these issues separately in subsequent pull requests. This pull request already contains too many significant changes.

This commit adds support for pointer operation in reverse mode. The technique is to maintain a corresponding derivative pointer variable, which gets updated (and stored/restored) in the exact same way as the primal pointer variable in both forward and reverse passes.
Added a workaround (with a FIXME comment) in the UsefulToStoreGlobal method to essentially bypass TBR analysis results for pointer expr.

Fixes vgvassilev#195, Fixes vgvassilev#197
Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

@vgvassilev vgvassilev merged commit b80f03e into vgvassilev:master Dec 30, 2023
78 checks passed
@vaithak vaithak deleted the pointer-support branch March 13, 2024 13:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
3 participants