From d44ce3202e77da068b7b8cd5907de3a9b78eb70f Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Thu, 21 Dec 2023 01:00:39 +0530 Subject: [PATCH] Improve test coverage for pointer support --- include/clad/Differentiator/VisitorBase.h | 30 +++++++++++++++++++++ lib/Differentiator/ReverseModeVisitor.cpp | 33 ++++++++++++++--------- lib/Differentiator/VisitorBase.cpp | 30 --------------------- 3 files changed, 50 insertions(+), 43 deletions(-) diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 4a35c5f02..4af3a66bd 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -578,6 +578,36 @@ namespace clad { /// store a pointer to their size expression. clang::QualType CloneType(clang::QualType T); + /// Computes effective derivative operands. It should be used when operands + /// might be of pointer types. + /// + /// In the trivial case, both operands are of non-pointer types, and the + /// effective derivative operands are `LDiff.getExpr_dx()` and + /// `RDiff.getExpr_dx()` respectively. + /// + /// Integers used in pointer arithmetic should be considered + /// non-differentiable entities. For example: + /// + /// ``` + /// p + i; + /// ``` + /// + /// Derived statement should be: + /// + /// ``` + /// _d_p + i; + /// ``` + /// + /// instead of: + /// + /// ``` + /// _d_p + _d_i; + /// ``` + /// + /// Therefore, effective derived expression of `i` is `i` instead of `_d_i`. + /// + /// This functions sets `derivedL` and `derivedR` arguments to effective + /// derived expressions. static void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff, clang::Expr*& derivedL, clang::Expr*& derivedR); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 9f72d3dac..ebe153fd1 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2060,6 +2060,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto op = opCode == UO_PostInc ? UO_PostDec : UO_PostInc; addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepAsExpr())), direction::reverse); + if (isPointerOp) + addToCurrentBlock(BuildOp(op, diff.getExpr_dx()), direction::reverse); } ResultRef = diff.getExpr_dx(); @@ -2121,15 +2123,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } return {cloneE, derivedE}; - } else if (opCode != UO_LNot) { - // We should not output any warning on visiting boolean conditions - // FIXME: We should support boolean differentiation or ignore it - // completely - unsupportedOpWarn(UnOp->getEndLoc()); - } else if (isa(E)) - diff = Visit(E); - else - diff = StmtDiff(E); + } else { + if (opCode != UO_LNot) + // We should not output any warning on visiting boolean conditions + // FIXME: We should support boolean differentiation or ignore it + // completely + unsupportedOpWarn(UnOp->getEndLoc()); + + if (isa(E)) + diff = Visit(E); + else + diff = StmtDiff(E); + } Expr* op = BuildOp(opCode, diff.getExpr()); return StmtDiff(op, ResultRef, nullptr, valueForRevPass); } @@ -2391,15 +2396,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue), direction::reverse); Rdiff = Visit(R, oldValue); - valueForRevPass = BuildOp(BO_Add, Rdiff.getRevSweepAsExpr(), - Ldiff.getRevSweepAsExpr()); + if (!isPointerOp) + valueForRevPass = BuildOp(BO_Add, Rdiff.getRevSweepAsExpr(), + Ldiff.getRevSweepAsExpr()); } else if (opCode == BO_SubAssign) { if (!isPointerOp) addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue), direction::reverse); Rdiff = Visit(R, BuildOp(UO_Minus, oldValue)); - valueForRevPass = BuildOp(BO_Sub, Rdiff.getRevSweepAsExpr(), - Ldiff.getRevSweepAsExpr()); + if (!isPointerOp) + valueForRevPass = BuildOp(BO_Sub, Rdiff.getRevSweepAsExpr(), + Ldiff.getRevSweepAsExpr()); } else if (opCode == BO_MulAssign) { // Create a reference variable to keep the result of LHS, since it // must be used on 2 places: when storing to a global variable diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index de7514243..e3287626f 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -782,36 +782,6 @@ namespace clad { return TAL.get(0).getAsType(); } - /// Computes effective derivative operands. It should be used when operands - /// might be of pointer types. - /// - /// In the trivial case, both operands are of non-pointer types, and the - /// effective derivative operands are `LDiff.getExpr_dx()` and - /// `RDiff.getExpr_dx()` respectively. - /// - /// Integers used in pointer arithmetic should be considered - /// non-differentiable entities. For example: - /// - /// ``` - /// p + i; - /// ``` - /// - /// Derived statement should be: - /// - /// ``` - /// _d_p + i; - /// ``` - /// - /// instead of: - /// - /// ``` - /// _d_p + _d_i; - /// ``` - /// - /// Therefore, effective derived expression of `i` is `i` instead of `_d_i`. - /// - /// This functions sets `derivedL` and `derivedR` arguments to effective - /// derived expressions. void VisitorBase::ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff, clang::Expr*& derivedL, clang::Expr*& derivedR) {