From 2aa5dcb1ad85bd39e4c56c2729bfae299e0afd74 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Tue, 19 Dec 2023 23:48:27 +0530 Subject: [PATCH] Add initial support for pointers in reverse mode This commit adds support for pointer operation in reverse mode. The technique is maintain a corresponding derivative pointer variable, which gets updated (and stored/restored) in the exact same way as the primal pointer variable in both forward and reverse passes. Added a workaround (with a FIXME comment) in the UsefulToStoreGlobal method to essentially bypass TBR analysis results for pointer expr. Fixes #195, #197 --- include/clad/Differentiator/ArrayRef.h | 3 + include/clad/Differentiator/VisitorBase.h | 34 ++++ lib/Differentiator/BaseForwardModeVisitor.cpp | 46 ----- lib/Differentiator/CladUtils.cpp | 3 +- lib/Differentiator/ReverseModeVisitor.cpp | 154 ++++++++++++---- lib/Differentiator/VisitorBase.cpp | 16 ++ test/FirstDerivative/UnsupportedOpsWarn.C | 14 -- test/Gradient/Pointers.C | 172 ++++++++++++++++++ 8 files changed, 341 insertions(+), 101 deletions(-) diff --git a/include/clad/Differentiator/ArrayRef.h b/include/clad/Differentiator/ArrayRef.h index efc227522..8fbe4155a 100644 --- a/include/clad/Differentiator/ArrayRef.h +++ b/include/clad/Differentiator/ArrayRef.h @@ -33,6 +33,9 @@ template class array_ref { /// Constructor for clad::array types CUDA_HOST_DEVICE array_ref(array& a) : m_arr(a.ptr()), m_size(a.size()) {} + /// Operator for conversion from array_ref to T*. + CUDA_HOST_DEVICE operator T*() { return m_arr; } + template CUDA_HOST_DEVICE array_ref& operator=(const array& a) { assert(m_size == a.size()); diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 09a7af293..4af3a66bd 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -577,6 +577,40 @@ namespace clad { /// Cloning types is necessary since VariableArrayType /// store a pointer to their size expression. clang::QualType CloneType(clang::QualType T); + + /// Computes effective derivative operands. It should be used when operands + /// might be of pointer types. + /// + /// In the trivial case, both operands are of non-pointer types, and the + /// effective derivative operands are `LDiff.getExpr_dx()` and + /// `RDiff.getExpr_dx()` respectively. + /// + /// Integers used in pointer arithmetic should be considered + /// non-differentiable entities. For example: + /// + /// ``` + /// p + i; + /// ``` + /// + /// Derived statement should be: + /// + /// ``` + /// _d_p + i; + /// ``` + /// + /// instead of: + /// + /// ``` + /// _d_p + _d_i; + /// ``` + /// + /// Therefore, effective derived expression of `i` is `i` instead of `_d_i`. + /// + /// This functions sets `derivedL` and `derivedR` arguments to effective + /// derived expressions. + static void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff, + clang::Expr*& derivedL, + clang::Expr*& derivedR); }; } // end namespace clad diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 986b61bdf..ff08e68b2 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1344,52 +1344,6 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { } } -/// Computes effective derivative operands. It should be used when operands -/// might be of pointer types. -/// -/// In the trivial case, both operands are of non-pointer types, and the -/// effective derivative operands are `LDiff.getExpr_dx()` and -/// `RDiff.getExpr_dx()` respectively. -/// -/// Integers used in pointer arithmetic should be considered -/// non-differentiable entities. For example: -/// -/// ``` -/// p + i; -/// ``` -/// -/// Derived statement should be: -/// -/// ``` -/// _d_p + i; -/// ``` -/// -/// instead of: -/// -/// ``` -/// _d_p + _d_i; -/// ``` -/// -/// Therefore, effective derived expression of `i` is `i` instead of `_d_i`. -/// -/// This functions sets `derivedL` and `derivedR` arguments to effective -/// derived expressions. -static void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff, - clang::Expr*& derivedL, - clang::Expr*& derivedR) { - derivedL = LDiff.getExpr_dx(); - derivedR = RDiff.getExpr_dx(); - if (utils::isArrayOrPointerType(LDiff.getExpr_dx()->getType()) && - !utils::isArrayOrPointerType(RDiff.getExpr_dx()->getType())) { - derivedL = LDiff.getExpr_dx(); - derivedR = RDiff.getExpr(); - } else if (utils::isArrayOrPointerType(RDiff.getExpr_dx()->getType()) && - !utils::isArrayOrPointerType(LDiff.getExpr_dx()->getType())) { - derivedL = LDiff.getExpr(); - derivedR = RDiff.getExpr_dx(); - } -} - StmtDiff BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { StmtDiff Ldiff = Visit(BinOp->getLHS()); diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index a49d52477..164f5ddda 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -584,7 +584,8 @@ namespace clad { /// be more complex than just a DeclRefExpr. /// (e.g. `__real (n++ ? z1 : z2)`) m_Exprs.push_back(UnOp); - } + } else if (opCode == UnaryOperatorKind::UO_Deref) + m_Exprs.push_back(UnOp); } void VisitDeclRefExpr(clang::DeclRefExpr* DRE) { diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 0b3d53d0f..ebe153fd1 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1340,7 +1340,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Create the (_d_param[idx] += dfdx) statement. if (dfdx()) { // FIXME: not sure if this is generic. - // Don't update derivatives of non-record types. + // Don't update derivatives of record types. if (!VD->getType()->isRecordType()) { auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx()); // Add it to the body statements. @@ -2035,6 +2035,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // If it is a post-increment/decrement operator, its result is a reference // and we should return it. Expr* ResultRef = nullptr; + + // For increment/decrement of pointer, perform the same on the + // derivative pointer also. + bool isPointerOp = E->getType()->isPointerType(); + if (opCode == UO_Plus) // xi = +xj // dxi/dxj = +1.0 @@ -2048,10 +2053,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, diff = Visit(E, d); } else if (opCode == UO_PostInc || opCode == UO_PostDec) { diff = Visit(E, dfdx()); + if (isPointerOp) + addToCurrentBlock(BuildOp(opCode, diff.getExpr_dx()), + direction::forward); if (UsefulToStoreGlobal(diff.getRevSweepAsExpr())) { auto op = opCode == UO_PostInc ? UO_PostDec : UO_PostInc; addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepAsExpr())), direction::reverse); + if (isPointerOp) + addToCurrentBlock(BuildOp(op, diff.getExpr_dx()), direction::reverse); } ResultRef = diff.getExpr_dx(); @@ -2060,10 +2070,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_ExternalSource->ActBeforeFinalisingPostIncDecOp(diff); } else if (opCode == UO_PreInc || opCode == UO_PreDec) { diff = Visit(E, dfdx()); + if (isPointerOp) + addToCurrentBlock(BuildOp(opCode, diff.getExpr_dx()), + direction::forward); if (UsefulToStoreGlobal(diff.getRevSweepAsExpr())) { auto op = opCode == UO_PreInc ? UO_PreDec : UO_PreInc; addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepAsExpr())), direction::reverse); + if (isPointerOp) + addToCurrentBlock(BuildOp(op, diff.getExpr_dx()), direction::reverse); } auto op = opCode == UO_PreInc ? BinaryOperatorKind::BO_Add : BinaryOperatorKind::BO_Sub; @@ -2081,35 +2096,38 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Add it to the body statements. addToCurrentBlock(add_assign, direction::reverse); } - } else { - // FIXME: This is not adding 'address-of' operator support. - // This is just making this special case differentiable that is required - // for computing hessian: - // ``` - // Class _d_this_obj; - // Class* _d_this = &_d_this_obj; - // ``` - // This code snippet should be removed once reverse mode officially - // supports pointers. - if (opCode == UnaryOperatorKind::UO_AddrOf) { - if (const auto* MD = dyn_cast(m_Function)) { - if (MD->isInstance()) { - auto thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD); - if (utils::SameCanonicalType(thisType, UnOp->getType())) { - diff = Visit(E); - Expr* cloneE = - BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr()); - Expr* derivedE = - BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr_dx()); - return {cloneE, derivedE}; - } - } + } else if (opCode == UnaryOperatorKind::UO_AddrOf) { + diff = Visit(E); + Expr* cloneE = BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr()); + Expr* derivedE = BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr_dx()); + return {cloneE, derivedE}; + } else if (opCode == UnaryOperatorKind::UO_Deref) { + diff = Visit(E); + Expr* cloneE = BuildOp(UnaryOperatorKind::UO_Deref, diff.getExpr()); + Expr* diff_dx = diff.getExpr_dx(); + bool specialDThisCase = false; + Expr* derivedE = nullptr; + if (const auto* MD = dyn_cast(m_Function)) { + if (MD->isInstance() && !diff_dx->getType()->isPointerType()) + specialDThisCase = true; // _d_this is already dereferenced. + } + if (specialDThisCase) + derivedE = diff_dx; + else { + derivedE = BuildOp(UnaryOperatorKind::UO_Deref, diff_dx); + // Create the (target += dfdx) statement. + if (dfdx()) { + auto* add_assign = BuildOp(BO_AddAssign, derivedE, dfdx()); + // Add it to the body statements. + addToCurrentBlock(add_assign, direction::reverse); } } - // We should not output any warning on visiting boolean conditions - // FIXME: We should support boolean differentiation or ignore it - // completely + return {cloneE, derivedE}; + } else { if (opCode != UO_LNot) + // We should not output any warning on visiting boolean conditions + // FIXME: We should support boolean differentiation or ignore it + // completely unsupportedOpWarn(UnOp->getEndLoc()); if (isa(E)) @@ -2134,6 +2152,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // we should return it. Expr* ResultRef = nullptr; + bool isPointerOp = + L->getType()->isPointerType() || R->getType()->isPointerType(); + if (opCode == BO_Add) { // xi = xl + xr // dxi/xl = 1.0 @@ -2306,6 +2327,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto* Lblock = endBlock(direction::reverse); llvm::SmallVector ExprsToStore; utils::GetInnermostReturnExpr(Ldiff.getExpr(), ExprsToStore); + + // We need to store values of derivative pointer variables in forward pass + // and restore them in reverese pass. + if (isPointerOp) { + Expr* Edx = Ldiff.getExpr_dx(); + ExprsToStore.push_back(Edx); + } + if (L->HasSideEffects(m_Context)) { Expr* E = Ldiff.getExpr(); auto* storeE = @@ -2352,24 +2381,32 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Save old value for the derivative of LHS, to avoid problems with cases // like x = x. - auto* oldValue = StoreAndRef(AssignedDiff, direction::reverse, "_r_d", - /*forceDeclCreation=*/true); + clang::Expr* oldValue = nullptr; + + // For pointer types, no need to store old derivatives. + if (!isPointerOp) + oldValue = StoreAndRef(AssignedDiff, direction::reverse, "_r_d", + /*forceDeclCreation=*/true); if (opCode == BO_Assign) { Rdiff = Visit(R, oldValue); valueForRevPass = Rdiff.getRevSweepAsExpr(); } else if (opCode == BO_AddAssign) { - addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue), - direction::reverse); + if (!isPointerOp) + addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue), + direction::reverse); Rdiff = Visit(R, oldValue); - valueForRevPass = BuildOp(BO_Add, Rdiff.getRevSweepAsExpr(), - Ldiff.getRevSweepAsExpr()); + if (!isPointerOp) + valueForRevPass = BuildOp(BO_Add, Rdiff.getRevSweepAsExpr(), + Ldiff.getRevSweepAsExpr()); } else if (opCode == BO_SubAssign) { - addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue), - direction::reverse); + if (!isPointerOp) + addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue), + direction::reverse); Rdiff = Visit(R, BuildOp(UO_Minus, oldValue)); - valueForRevPass = BuildOp(BO_Sub, Rdiff.getRevSweepAsExpr(), - Ldiff.getRevSweepAsExpr()); + if (!isPointerOp) + valueForRevPass = BuildOp(BO_Sub, Rdiff.getRevSweepAsExpr(), + Ldiff.getRevSweepAsExpr()); } else if (opCode == BO_MulAssign) { // Create a reference variable to keep the result of LHS, since it // must be used on 2 places: when storing to a global variable @@ -2427,8 +2464,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActBeforeFinalisingAssignOp(LCloned, oldValue); - // Update the derivative. - addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue), direction::reverse); + // Update the derivative only if LHS is not a pointer type. + if (!isPointerOp) + addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue), + direction::reverse); + // Output statements from Visit(L). for (auto it = Lblock_begin; it != Lblock_end; ++it) addToCurrentBlock(*it, direction::reverse); @@ -2460,6 +2500,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return BuildOp(opCode, LExpr, RExpr); } Expr* op = BuildOp(opCode, Ldiff.getExpr(), Rdiff.getExpr()); + + // For pointer types. + if (isPointerOp) { + if (opCode == BO_Add || opCode == BO_Sub) { + Expr* derivedL = nullptr; + Expr* derivedR = nullptr; + ComputeEffectiveDOperands(Ldiff, Rdiff, derivedL, derivedR); + if (opCode == BO_Sub) + derivedR = BuildParens(derivedR); + return StmtDiff(op, BuildOp(opCode, derivedL, derivedR), nullptr, + valueForRevPass); + } + if (opCode == BO_Assign || opCode == BO_AddAssign || + opCode == BO_SubAssign) { + Expr* derivedL = nullptr; + Expr* derivedR = nullptr; + ComputeEffectiveDOperands(Ldiff, Rdiff, derivedL, derivedR); + addToCurrentBlock(BuildOp(opCode, derivedL, derivedR), + direction::forward); + } + } return StmtDiff(op, ResultRef, nullptr, valueForRevPass); } @@ -2469,6 +2530,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto VDDerivedType = ComputeAdjointType(VD->getType()); bool isDerivativeOfRefType = VD->getType()->isReferenceType(); VarDecl* VDDerived = nullptr; + bool isPointerType = VD->getType()->isPointerType(); // VDDerivedInit now serves two purposes -- as the initial derivative value // or the size of the derivative array -- depending on the primal type. @@ -2529,6 +2591,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (initDiff.getExpr_dx()) VDDerivedInit = initDiff.getExpr_dx(); } + // if VD is a pointer type, then the initial value is set to the derived + // expression of the corresponding pointer type. + else if (isPointerType && VD->getInit()) { + initDiff = Visit(VD->getInit()); + if (initDiff.getExpr_dx()) + VDDerivedInit = initDiff.getExpr_dx(); + } // Here separate behaviour for record and non-record types is only // necessary to preserve the old tests. if (VDDerivedType->isRecordType()) @@ -2546,7 +2615,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // differentiated and should not be differentiated again. // If `VD` is a reference to a non-local variable then also there's no // need to call `Visit` since non-local variables are not differentiated. - if (!isDerivativeOfRefType) { + if (!isDerivativeOfRefType && !isPointerType) { Expr* derivedE = BuildDeclRef(VDDerived); initDiff = StmtDiff{}; if (VD->getInit()) { @@ -2824,6 +2893,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // If TBR analysis is off, assume E is useful to store. if (!enableTBR) return true; + // FIXME: currently, we allow all pointer operations to be stored. + // This is not correct, but we need to implement a more advanced analysis + // to determine which pointer operations are useful to store. + if (E->getType()->isPointerType()) + return true; auto found = m_ToBeRecorded.find(B->getBeginLoc()); return found != m_ToBeRecorded.end(); } diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 80b6678c6..e3287626f 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -781,4 +781,20 @@ namespace clad { auto& TAL = specialization->getTemplateArgs(); return TAL.get(0).getAsType(); } + + void VisitorBase::ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff, + clang::Expr*& derivedL, + clang::Expr*& derivedR) { + derivedL = LDiff.getExpr_dx(); + derivedR = RDiff.getExpr_dx(); + if (utils::isArrayOrPointerType(LDiff.getExpr()->getType()) && + !utils::isArrayOrPointerType(RDiff.getExpr()->getType())) { + derivedL = LDiff.getExpr_dx(); + derivedR = RDiff.getExpr(); + } else if (utils::isArrayOrPointerType(RDiff.getExpr()->getType()) && + !utils::isArrayOrPointerType(LDiff.getExpr()->getType())) { + derivedL = LDiff.getExpr(); + derivedR = RDiff.getExpr_dx(); + } + } } // end namespace clad diff --git a/test/FirstDerivative/UnsupportedOpsWarn.C b/test/FirstDerivative/UnsupportedOpsWarn.C index 2a618c54d..9c58b7cf9 100644 --- a/test/FirstDerivative/UnsupportedOpsWarn.C +++ b/test/FirstDerivative/UnsupportedOpsWarn.C @@ -34,23 +34,9 @@ int unOpWarn_0(int x){ // CHECK-NEXT: return 0; // CHECK-NEXT: } -int unOpWarn_1(int x){ - auto pnt = &x; // expected-warning {{attempt to differentiate unsupported operator, ignored.}} - return x; -} - -// CHECK: void unOpWarn_1_grad(int x, clad::array_ref _d_x) { -// CHECK-NEXT: int *_d_pnt = 0; -// CHECK-NEXT: int *pnt = &x; -// CHECK-NEXT: goto _label0; -// CHECK-NEXT: _label0: -// CHECK-NEXT: * _d_x += 1; -// CHECK-NEXT: } - int main(){ clad::differentiate(binOpWarn_0, 0); clad::gradient(binOpWarn_1); clad::differentiate(unOpWarn_0, 0); - clad::gradient(unOpWarn_1); } diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C index 199c3041e..7cbb3922f 100644 --- a/test/Gradient/Pointers.C +++ b/test/Gradient/Pointers.C @@ -20,6 +20,167 @@ double nonMemFn(double i) { // CHECK-NEXT: } // CHECK-NEXT: } +double minimalPointer(double x) { + double *p; + p = &x; + *p = (*p)*(*p); + return *p; // x*x +} + +// CHECK: void minimalPointer_grad(double x, clad::array_ref _d_x) { +// CHECK-NEXT: double *_d_p = 0; +// CHECK-NEXT: double *_t0; +// CHECK-NEXT: double *_t1; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double *p; +// CHECK-NEXT: _t0 = p; +// CHECK-NEXT: _t1 = _d_p; +// CHECK-NEXT: _d_p = &* _d_x; +// CHECK-NEXT: p = &x; +// CHECK-NEXT: _t2 = *p; +// CHECK-NEXT: *p = *p * (*p); +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: *_d_p += 1; +// CHECK-NEXT: { +// CHECK-NEXT: *p = _t2; +// CHECK-NEXT: double _r_d0 = *_d_p; +// CHECK-NEXT: double _r0 = _r_d0 * (*p); +// CHECK-NEXT: *_d_p += _r0; +// CHECK-NEXT: double _r1 = *p * _r_d0; +// CHECK-NEXT: *_d_p += _r1; +// CHECK-NEXT: *_d_p -= _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: p = _t0; +// CHECK-NEXT: _d_p = _t1; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double arrayPointer(double* arr) { + double *p = arr; + p = p + 1; + double sum = *p; + p++; + sum += (*p)*2; + p += 1; + sum += (*p)*4; + ++p; + sum += (*p)*3; + p -= 2; + p = p - 2; + sum += 5 * (*p); + return sum; // 5*arr[0] + arr[1] + 2*arr[2] + 4*arr[3] + 3*arr[4] +} + +// CHECK: void arrayPointer_grad(double *arr, clad::array_ref _d_arr) { +// CHECK-NEXT: double *_d_p = _d_arr; +// CHECK-NEXT: double *_t0; +// CHECK-NEXT: double *_t1; +// CHECK-NEXT: double _d_sum = 0; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double *_t3; +// CHECK-NEXT: double *_t4; +// CHECK-NEXT: double _t5; +// CHECK-NEXT: double _t6; +// CHECK-NEXT: double *_t7; +// CHECK-NEXT: double *_t8; +// CHECK-NEXT: double *_t9; +// CHECK-NEXT: double *_t10; +// CHECK-NEXT: double _t11; +// CHECK-NEXT: double *p = arr; +// CHECK-NEXT: _t0 = p; +// CHECK-NEXT: _t1 = _d_p; +// CHECK-NEXT: _d_p = _d_p + 1; +// CHECK-NEXT: p = p + 1; +// CHECK-NEXT: double sum = *p; +// CHECK-NEXT: _d_p++; +// CHECK-NEXT: p++; +// CHECK-NEXT: _t2 = sum; +// CHECK-NEXT: sum += *p * 2; +// CHECK-NEXT: _t3 = p; +// CHECK-NEXT: _t4 = _d_p; +// CHECK-NEXT: _d_p += 1; +// CHECK-NEXT: p += 1; +// CHECK-NEXT: _t5 = sum; +// CHECK-NEXT: sum += *p * 4; +// CHECK-NEXT: ++_d_p; +// CHECK-NEXT: ++p; +// CHECK-NEXT: _t6 = sum; +// CHECK-NEXT: sum += *p * 3; +// CHECK-NEXT: _t7 = p; +// CHECK-NEXT: _t8 = _d_p; +// CHECK-NEXT: _d_p -= 2; +// CHECK-NEXT: p -= 2; +// CHECK-NEXT: _t9 = p; +// CHECK-NEXT: _t10 = _d_p; +// CHECK-NEXT: _d_p = _d_p - 2; +// CHECK-NEXT: p = p - 2; +// CHECK-NEXT: _t11 = sum; +// CHECK-NEXT: sum += 5 * (*p); +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_sum += 1; +// CHECK-NEXT: { +// CHECK-NEXT: sum = _t11; +// CHECK-NEXT: double _r_d3 = _d_sum; +// CHECK-NEXT: _d_sum += _r_d3; +// CHECK-NEXT: double _r3 = _r_d3 * (*p); +// CHECK-NEXT: double _r4 = 5 * _r_d3; +// CHECK-NEXT: *_d_p += _r4; +// CHECK-NEXT: _d_sum -= _r_d3; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: p = _t9; +// CHECK-NEXT: _d_p = _t10; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: p = _t7; +// CHECK-NEXT: _d_p = _t8; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: sum = _t6; +// CHECK-NEXT: double _r_d2 = _d_sum; +// CHECK-NEXT: _d_sum += _r_d2; +// CHECK-NEXT: double _r2 = _r_d2 * 3; +// CHECK-NEXT: *_d_p += _r2; +// CHECK-NEXT: _d_sum -= _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: --p; +// CHECK-NEXT: --_d_p; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: sum = _t5; +// CHECK-NEXT: double _r_d1 = _d_sum; +// CHECK-NEXT: _d_sum += _r_d1; +// CHECK-NEXT: double _r1 = _r_d1 * 4; +// CHECK-NEXT: *_d_p += _r1; +// CHECK-NEXT: _d_sum -= _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: p = _t3; +// CHECK-NEXT: _d_p = _t4; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: sum = _t2; +// CHECK-NEXT: double _r_d0 = _d_sum; +// CHECK-NEXT: _d_sum += _r_d0; +// CHECK-NEXT: double _r0 = _r_d0 * 2; +// CHECK-NEXT: *_d_p += _r0; +// CHECK-NEXT: _d_sum -= _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: p--; +// CHECK-NEXT: _d_p--; +// CHECK-NEXT: } +// CHECK-NEXT: *_d_p += _d_sum; +// CHECK-NEXT: { +// CHECK-NEXT: p = _t0; +// CHECK-NEXT: _d_p = _t1; +// CHECK-NEXT: } +// CHECK-NEXT: } + #define NON_MEM_FN_TEST(var)\ res[0]=0;\ var.execute(5,res);\ @@ -89,4 +250,15 @@ int main() { NON_MEM_FN_TEST(d_nonMemFnReinterpretCast); // CHECK-EXEC: 10.00 NON_MEM_FN_TEST(d_nonMemFnCStyleCast); // CHECK-EXEC: 10.00 + + // Pointer operation tests. + auto d_minimalPointer = clad::gradient(minimalPointer, "x"); + NON_MEM_FN_TEST(d_minimalPointer); // CHECK-EXEC: 10.00 + + auto d_arrayPointer = clad::gradient(arrayPointer, "arr"); + double arr[5] = {1, 2, 3, 4, 5}; + double d_arr[5] = {0, 0, 0, 0, 0}; + clad::array_ref d_arr_ref(d_arr, 5); + d_arrayPointer.execute(arr, d_arr_ref); + printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 5.00 1.00 2.00 4.00 3.00 }