Redesign of loop's body in reverse pass #835

Draft: wants to merge 7 commits into master

kchristin22
Collaborator

This PR closes #832. What these changes establish is a cleaner switch body, where the cases are independent and indexed correctly, AST speaking.

This comment contains a comparison of the current version and the proposed one. Cases of nested continue statements have also been tested and the derived body and the results were correct.

The workflow in a nutshell is the following:

  • When the then or else branch of an if contains a continue stmt, replace that branch with the corresponding case stmt
  • The branch of the if that doesn't contain a continue stmt is added first, as it belongs to the previous (outer) switch case
  • After visiting the nodes of the loop's body in the reverse pass, the case stmts are brought forward to the first level of the loop's body by performing a depth-first search on its compounds
  • The compounds/stmts that are on the same level as a case stmt and follow it (i.e. belong to this case) are also brought forward, while the remaining stmts stay at their level and are later appended at the beginning of the loop's body (as the main loop's body compound)
  • At that point we add the switch case of the main body at the beginning and set the stmts present between two cases as a substmt of the first of them
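
The hoisting step in the list above can be pictured with a toy model. This is only a sketch under stated assumptions: `Node`, `hoistCases`, `render` and `demo` are hypothetical names that do not appear in the PR, the statement tree is a stand-in for Clang's AST, the place where the rebuilt compound is appended is a guess based on the visible fragments of `AppendCaseStmts`, and the `if` handling is left out entirely.

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a Clang statement; not a Clad or Clang type.
struct Node {
  enum Kind { Compound, Case, Other };
  Kind kind;
  std::string label;           // text of a Case/Other statement
  std::vector<Node> children;  // only used by Compound nodes
};

// Depth-first walk over the loop body. A case label and the statements that
// follow it on the same level are moved to `cases`; all other statements stay
// in `curBlock`, nested as they were.
void hoistCases(const Node& n, std::vector<Node>& curBlock,
                std::vector<Node>& cases, bool& afterCase) {
  if (n.kind == Node::Compound) {
    const bool saved = afterCase;  // a nested compound starts a new level
    afterCase = false;
    std::vector<Node> newBlock;
    for (const Node& child : n.children)
      hoistCases(child, newBlock, cases, afterCase);
    afterCase = saved;  // restore the flag of the enclosing level
    if (!newBlock.empty()) {
      Node rebuilt{Node::Compound, "", newBlock};
      (afterCase ? cases : curBlock).push_back(rebuilt);
    }
  } else if (n.kind == Node::Case) {
    afterCase = true;
    cases.push_back(n);
  } else {
    (afterCase ? cases : curBlock).push_back(n);
  }
}

// Prints a statement list on one line, with braces around compounds.
std::string render(const std::vector<Node>& stmts) {
  std::string out;
  for (const Node& s : stmts) {
    if (!out.empty())
      out += ' ';
    out += s.kind == Node::Compound ? "{ " + render(s.children) + " }"
                                    : s.label;
  }
  return out;
}

// Splits the body `{ a; { b; case 1: c; } d; }` and returns the rendered
// remaining block (first) and the hoisted case statements (second).
std::pair<std::string, std::string> demo() {
  const Node a{Node::Other, "a;", {}};
  const Node b{Node::Other, "b;", {}};
  const Node case1{Node::Case, "case 1:", {}};
  const Node c{Node::Other, "c;", {}};
  const Node d{Node::Other, "d;", {}};
  const Node inner{Node::Compound, "", {b, case1, c}};
  const Node body{Node::Compound, "", {a, inner, d}};
  std::vector<Node> curBlock, cases;
  bool afterCase = false;
  hoistCases(body, curBlock, cases, afterCase);
  return {render(curBlock), render(cases)};
}
```

In this toy run, `case 1:` and `c;` are hoisted, while `a;`, `b;` and `d;` keep their nesting. Note that `d;` follows the compound that contains the case but is not attached to it, because the flag is restored when the walk leaves that compound.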

Keep in mind that, since the generated code differs completely from that of the current implementation in cases where there's a continue stmt, the tests have to be updated (so they're currently expected to fail).

If the changes are well received, support for the break stmt in a similar manner only requires an additional variable.

Contributor

@github-actions github-actions bot left a comment

clang-tidy made some suggestions

There were too many comments to post at once. Showing the first 10 out of 18. Check the log or trigger a new build to see more.

@@ -58,6 +58,8 @@ namespace clad {
std::set<clang::SourceLocation> m_ToBeRecorded;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
bool isInsideLoop = false;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
bool hasContStmt = false;

warning: member variable 'hasContStmt' has protected visibility [cppcoreguidelines-non-private-member-variables-in-classes]

    bool hasContStmt = false;
         ^

@@ -58,6 +58,8 @@
std::set<clang::SourceLocation> m_ToBeRecorded;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
bool isInsideLoop = false;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
bool hasContStmt = false;

warning: invalid case style for protected member 'hasContStmt' [readability-identifier-naming]

Suggested change
bool hasContStmt = false;
bool m_hasContStmt = false;

/// because we need to register all the switch cases later with the
/// switch statement that will be used to manage the control flow in
/// the reverse block.
llvm::SmallVector<clang::SwitchCase*, 4> m_SwitchCases;

warning: member variable 'm_SwitchCases' has public visibility [cppcoreguidelines-non-private-member-variables-in-classes]

      llvm::SmallVector<clang::SwitchCase*, 4> m_SwitchCases;
                                               ^

/// the reverse block.
llvm::SmallVector<clang::SwitchCase*, 4> m_SwitchCases;

ReverseModeVisitor& m_RMV;

warning: member variable 'm_RMV' has public visibility [cppcoreguidelines-non-private-member-variables-in-classes]

      ReverseModeVisitor& m_RMV;
                          ^

void ReverseModeVisitor::AppendCaseStmts(llvm::SmallVectorImpl<Stmt*>& curBlock,
llvm::SmallVectorImpl<Stmt*>& cases,
Stmt* S, bool& afterCase) {
if (auto CS = dyn_cast_or_null<CompoundStmt>(S)) {

warning: 'auto CS' can be declared as 'auto *CS' [llvm-qualified-auto]

Suggested change
if (auto CS = dyn_cast_or_null<CompoundStmt>(S)) {
if (auto *CS = dyn_cast_or_null<CompoundStmt>(S)) {

// hence, we store the original flag's value
SaveAndRestore<bool> SaveAfterCase(afterCase);
afterCase = false;
for (auto stmt : CS->body())

warning: 'auto stmt' can be declared as 'auto *stmt' [llvm-qualified-auto]

Suggested change
for (auto stmt : CS->body())
for (auto *stmt : CS->body())

AppendCaseStmts(newBlock, cases, stmt, afterCase);
if (!newBlock.empty()){
auto Stmts_ref = clad_compat::makeArrayRef(newBlock.data(), newBlock.size());
auto newCS = clad_compat::CompoundStmt_Create(

warning: 'auto newCS' can be declared as 'auto *newCS' [llvm-qualified-auto]

Suggested change
auto newCS = clad_compat::CompoundStmt_Create(
auto *newCS = clad_compat::CompoundStmt_Create(

} else if (isa<CaseStmt>(S)) {
afterCase = true;
cases.push_back(S);
} else if (auto If = dyn_cast_or_null<IfStmt>(S)) {

warning: 'auto If' can be declared as 'auto *If' [llvm-qualified-auto]

Suggested change
} else if (auto If = dyn_cast_or_null<IfStmt>(S)) {
} else if (auto *If = dyn_cast_or_null<IfStmt>(S)) {

afterCase = true;
cases.push_back(S);
} else if (auto If = dyn_cast_or_null<IfStmt>(S)) {
if (auto IfThenCS = dyn_cast_or_null<CompoundStmt>(If->getThen())) {

warning: 'auto IfThenCS' can be declared as 'auto *IfThenCS' [llvm-qualified-auto]

Suggested change
if (auto IfThenCS = dyn_cast_or_null<CompoundStmt>(If->getThen())) {
if (auto *IfThenCS = dyn_cast_or_null<CompoundStmt>(If->getThen())) {

Stmts thenBlock;
SaveAndRestore<bool> SaveAfterCase(afterCase);
afterCase = false;
for (auto stmt : IfThenCS->body())

warning: 'auto stmt' can be declared as 'auto *stmt' [llvm-qualified-auto]

Suggested change
for (auto stmt : IfThenCS->body())
for (auto *stmt : IfThenCS->body())

Stmt* Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(),
condVarClone, reverseCond, noLoc, noLoc,
m_Sema.ActOnNullStmt(noLoc).get(), noLoc, elseDiff.getStmt_dx());
Collaborator

Can we invert the condition instead of using a null stmt?
e.g.

 if(!cond)
   *else stmt*

instead of

 if(cond)
   ;
 else
   *else stmt*

Collaborator Author

I considered this approach before, but I believe it would require checking for each binary comparison operator (<=, >=, <, >, ==, !=) and replacing it with its complement. I will look into it further though; thanks for the comment.
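
For reference, the two forms under discussion can be compared in plain C++. This is a minimal sketch, not code from the PR: `with_null_stmt` and `with_negated_cond` are hypothetical names, and the condition `i <= 1` is an arbitrary example. It suggests that wrapping the whole condition in a logical not avoids any per-operator case analysis; in Clang AST terms that would presumably be a unary `UO_LNot` node around the cloned condition, which is an assumption and has not been tried against this PR.

```cpp
// Form discussed above: an empty then-branch, with the work in the else-branch.
int with_null_stmt(int i, int acc) {
  if (i <= 1)
    ;
  else
    acc += i;
  return acc;
}

// Suggested form: negate the whole condition instead of rewriting the
// comparison operator, so any bool-convertible condition is handled alike.
int with_negated_cond(int i, int acc) {
  if (!(i <= 1))
    acc += i;
  return acc;
}
```

Both functions return the same value for every input, since `!(cond)` is true exactly when the else-branch of `if (cond)` would run.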

Collaborator

@parth-07 parth-07 left a comment

Thank you for the pull request. We need to verify it on more cases and discuss it further to make this change ready to merge, as I am currently not confident about the correctness of this design change. The below example fails when Clad is built using the changes in this pull request:

#include "clad/Differentiator/Differentiator.h"
#include <iostream>

#define show(x) std::cout << #x << ": " << x << "\n";

double fn(double u, double v) {
    double res = 0;
    for (int i = 0; i < 3; i++) {
        res += u;
        if (i == 1)
            continue;
    }
    return res;
}

int main() {
    auto d_fn = clad::gradient(fn);
    double u = 3, v = 5;
    double du, dv;
    du = dv = 0;
    d_fn.execute(u, v, &du, &dv);
    show(du);
    show(dv);
}

Output:
du: 1
dv: 0

Expected output:
du: 3
dv: 0

Clad on master gives correct result.
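
To make the expected result concrete, here is a hand-written sketch of a switch-based reverse pass for `fn` that uses fall-through. It is not Clad's generated code: the tape `exits`, the struct `FnResult`, and the name `fn_grad_fallthrough` are hypothetical and chosen only for illustration. The point it tries to show is that the adjoint of `res += u` is shared by both paths through the loop body, so it has to run for every iteration, including the one that leaves through `continue`.

```cpp
#include <vector>

// Hypothetical result type: function value plus both partial derivatives.
struct FnResult {
  double value;
  double du;
  double dv;
};

FnResult fn_grad_fallthrough(double u, double v) {
  (void)v;  // v does not influence the result of fn
  // Forward pass: record how each iteration left the loop body.
  std::vector<int> exits;  // 1 = left through `continue`, 0 = reached the end
  double res = 0;
  for (int i = 0; i < 3; i++) {
    res += u;
    if (i == 1) {
      exits.push_back(1);
      continue;
    }
    exits.push_back(0);
  }

  // Reverse pass: replay the iterations backwards.
  double d_res = 1, du = 0, dv = 0;
  while (!exits.empty()) {
    const int exit = exits.back();
    exits.pop_back();
    switch (exit) {
      case 0:
        // Adjoints of statements after the `if` would go here (none in fn);
        // control then falls through to the shared part.
      case 1:
        du += d_res;  // adjoint of `res += u`, shared by both paths
        break;
    }
  }
  return {res, du, dv};
}
```

For `u = 3, v = 5` this sketch yields `du = 3` and `dv = 0`, matching the expected output above.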


where the cases are independent and indexed correctly, AST speaking.

In the current design, is the indexing of case statements invalid as per C++ standard?

@kchristin22
Collaborator Author

Yes, I completely agree that more tests need to be executed to verify that all cases are correctly derived. May I suggest creating a table to gather all the possible cases that may occur in order to make sure that we have not missed any? I can start working on it and append whatever else you and @PetroZarytskyi would like.

Thank you for noticing it. Indeed, I did not cover the case where multiple cases share some stmts. Right now, the approach I have in mind is to keep track of the stmts added at each level (probably the curBlock list at each level) until we hit a case, and then append them to that case. So it would look something like this:
image

However, I understand that this is not very efficient, and if I tried to emulate a fall-through approach, it would end up looking similar to the original approach. I will have to investigate further.
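
One reading of this idea, written out by hand for the failing `fn` example, is that every case carries its own copy of the shared adjoints and ends in a `break`. The sketch below is an assumption about what that layout could look like, not output of this PR: `paths`, `GradResult` and `fn_grad_independent_cases` are hypothetical names.

```cpp
#include <vector>

// Hypothetical result type: function value plus both partial derivatives.
struct GradResult {
  double value;
  double du;
  double dv;
};

FnResult_placeholder_guard:;
```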

@parth-07
Collaborator

May I suggest creating a table to gather all the possible cases that may occur in order to make sure that we have not missed any? I can start working on it and append whatever else you and @PetroZarytskyi would like.

I would suggest waiting, because this feature is not a priority right now. It would be important to resolve if the current indexing of case statements were invalid per the C++ standard. By current, I mean the implementation in the master branch. I know the indexing of case statements in the current implementation is quirky and complex, but we should be okay as long as it does not violate the C++ standard.

Right now, the approach I have in mind is to keep track of the stmts added at each level (probably the curBlock list at each level) until we hit a case, and then append them to that case. So it would look something like this:

Implementing this seems a little complicated and can lead to a significant code-size increase.

@kchristin22
Collaborator Author

Yes, I agree; we can revisit this if I find another way to do it using fall-through. Feel free to mark this as a Draft PR or even close it if that helps in scheduling the reviews of PRs.

@kchristin22 kchristin22 marked this pull request as draft November 19, 2024 15:41
Successfully merging this pull request may close these issues.

Consider a redesign of the for loop's body in reverse pass
3 participants