-
Notifications
You must be signed in to change notification settings - Fork 123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Incorrect derivative when loops contains continue #833
Fix Incorrect derivative when loops contains continue #833
Conversation
clang-tidy review says "All clean, LGTM! 👍" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the pull-request. The solution seems good. Can you please add tests and update the pull-request description?
Also, can you please open an issue regarding fixing incorrect derivatives when loops contains break
? You can use this test case:
#include "clad/Differentiator/Differentiator.h"
#include <iostream>
#define show(x) std::cout << #x << ": " << x << "\n";
double fn(double u, double v) {
double res = 0;
for (int i = 1; i < 3; i++) {
res += i * u;
if (i == 1)
break;
}
return res;
}
int main() {
auto d_fn = clad::gradient(fn);
double u = 3, v = 5;
double du, dv;
du = dv = 0;
d_fn.execute(u, v, &du, &dv);
show(du);
show(dv);
}
activeBreakContHandler->EndCFSwitchStmtScope(); | ||
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); | ||
PopBreakContStmtHandler(); | ||
|
||
// Increment statement in the for-loop is only executed if the iteration |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update this comment.
I will update the tests, the description and the comment.
Sure thing! |
Upon checking out the above example, the error occurs after the change of this PR right? I noticed that the decrement of i-- should be after the switch case when there's a break stmt as otherwise we "miss" a value of i. I will open a new PR for the new issue after this PR is merged. |
Yes, however, the
In the case of a |
I have not forgotten this. I will update it in the upcoming week. Sorry for the delay. |
clang-tidy review says "All clean, LGTM! 👍" |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #833 +/- ##
==========================================
+ Coverage 94.42% 94.43% +0.01%
==========================================
Files 50 50
Lines 8729 8751 +22
==========================================
+ Hits 8242 8264 +22
Misses 487 487
|
clang-tidy review says "All clean, LGTM! 👍" |
1 similar comment
clang-tidy review says "All clean, LGTM! 👍" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please add a test case where Clad was giving incorrect derivative result before?
clang-tidy review says "All clean, LGTM! 👍" |
@parth-07 I tried to add a new function in Loops.C, but even though I get the correct result when I compile locally this file and run it, somehow when I test using Update: the problem lies in TBR analysis, the increment is not included. I'm working on it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clang-tidy made some suggestions
CreateCFTapeSizeExprForCurrentCase() { | ||
return m_RMV.BuildOp(BinaryOperatorKind::BO_NE, m_ControlFlowTape->Size(), | ||
ConstantFolder::synthesizeLiteral( | ||
m_RMV.m_Context.IntTy, m_RMV.m_Context, 0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: argument comment missing for literal argument 'val' [bugprone-argument-comment]
m_RMV.m_Context.IntTy, m_RMV.m_Context, 0)); | |
m_RMV.m_Context.IntTy, m_RMV.m_Context, /*val=*/0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should use the new ZeroInit interface here that you wrote?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this doesn't refer to an initialization, I think ConstantFolder::synthesizeLiteral
is a better fit. getZeroInit
calls this function internally either way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please add tests as well?
56b0c9f
to
6f8bc21
Compare
Can we undo the formatting changes which are not related to the changes of this pr? |
70e3cc5
to
6f8bc21
Compare
Yes sorry about that, I was trying to fix it, cause my local version of git-clang-format produces different results than the CI. I think now it includes only the necessary changes more or less. |
@kchristin22 Can you please add tests in the PR? |
0622a9b
to
b4917c0
Compare
@@ -2081,19 +2086,19 @@ double fn26(double i, double j) { | |||
// CHECK-NEXT: if (!_t0) | |||
// CHECK-NEXT: break; | |||
// CHECK-NEXT: } | |||
// CHECK-NEXT: { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@PetroZarytskyi @ovdiiuv without using TBR, this follows the pattern of avoiding performing the increment differentiation step when this is the case of the break branch. So, this is derived like so when TBR analysis is not performed:
if (clad::size(_t3) != 0 && clad::back(_t3) != 1) {
res = clad::pop(_t2);
double _r_d1 = _d_res;
_d_res = 0.;
*_d_i += 7 * _r_d1 * j;
*_d_j += 7 * i * _r_d1;
_d_c += 0;
--c;
}
It seems that when the increment's diff is a compound stmt this is instead:
{
res = clad::pop(_t2);
double _r_d1 = _d_res;
_d_res = 0.;
*_d_i += 7 * _r_d1 * j;
*_d_j += 7 * i * _r_d1;
_d_c += 0;
--c;
}
Is this intentional? Should I open an issue for this?
CreateCFTapeSizeExprForCurrentCase() { | ||
return m_RMV.BuildOp(BinaryOperatorKind::BO_NE, m_ControlFlowTape->Size(), | ||
ConstantFolder::synthesizeLiteral( | ||
m_RMV.m_Context.IntTy, m_RMV.m_Context, 0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this doesn't refer to an initialization, I think ConstantFolder::synthesizeLiteral
is a better fit. getZeroInit
calls this function internally either way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clang-tidy made some suggestions
m_CurrentBreakFlagExpr = | ||
BuildOp(BinaryOperatorKind::BO_LOr, | ||
BuildOp(BinaryOperatorKind::BO_NE, revCounter, | ||
BuildDeclRef(loopCounter.getNumRevIterations())), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: initializing non-owner 'DefaultStmt *' with a newly created 'gsl::owner<>' [cppcoreguidelines-owning-memory]
auto* newDefaultStmt =
^
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clang-tidy made some suggestions
// Increment statement in the for-loop is only executed if the iteration | ||
// did not end with a break/continue statement. Therefore, forLoopIncDiff | ||
// should be inside the last switch case in the reverse pass. | ||
activeBreakContHandler->EndCFSwitchStmtScope(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: use auto when initializing with a template cast to avoid duplicating the type name [modernize-use-auto]
activeBreakContHandler->EndCFSwitchStmtScope(); | |
auto* forwardSS = |
e1745be
to
348b3d2
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clang-tidy made some suggestions
} | ||
|
||
/// Returns the number of reverse iterations to be executed. | ||
clang::VarDecl* getNumRevIterations() const { return m_numRevIterations; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: function 'getNumRevIterations' should be marked [[nodiscard]] [modernize-use-nodiscard]
clang::VarDecl* getNumRevIterations() const { return m_numRevIterations; } | |
[[nodiscard]] clang::VarDecl* getNumRevIterations() const { return m_numRevIterations; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good.
05da570
to
1bae2cd
Compare
For more details please see #710 (comment). This PR concerns only the fix of #710, so the only change needed is the one regarding the position of the step of the forward loop's variable in the reverse pass. Tests need to be updated accordingly if the change is approved.
Closes #710, closes #851.