diff --git a/include/clad/Differentiator/ArrayRef.h b/include/clad/Differentiator/ArrayRef.h index efc227522..6fd89247d 100644 --- a/include/clad/Differentiator/ArrayRef.h +++ b/include/clad/Differentiator/ArrayRef.h @@ -22,7 +22,7 @@ template class array_ref { public: /// Delete default constructor - array_ref() = delete; + array_ref() = default; /// Constructor to store the pointer to and size of an array supplied by the /// user CUDA_HOST_DEVICE array_ref(T* arr, std::size_t size) @@ -33,6 +33,9 @@ template class array_ref { /// Constructor for clad::array types CUDA_HOST_DEVICE array_ref(array& a) : m_arr(a.ptr()), m_size(a.size()) {} + /// Operator for conversion from array_ref to T*. + CUDA_HOST_DEVICE operator T*() { return m_arr; } + template CUDA_HOST_DEVICE array_ref& operator=(const array& a) { assert(m_size == a.size()); @@ -40,9 +43,16 @@ template class array_ref { m_arr[i] = a[i]; return *this; } + template + CUDA_HOST_DEVICE array_ref& operator=(const array_ref& a) { + m_arr = a.ptr(); + m_size = a.size(); + return *this; + } /// Returns the size of the underlying array CUDA_HOST_DEVICE std::size_t size() const { return m_size; } CUDA_HOST_DEVICE T* ptr() const { return m_arr; } + CUDA_HOST_DEVICE T*& ptr_ref() { return m_arr; } /// Returns an array_ref to a part of the underlying array starting at /// offset and having the specified size CUDA_HOST_DEVICE array_ref slice(std::size_t offset, std::size_t size) { diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 6e8c65006..91870f6e8 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -527,6 +527,9 @@ namespace clad { /// Creates the expression Base.size() for the given Base expr. The Base /// expr must be of clad::array_ref type clang::Expr* BuildArrayRefSizeExpr(clang::Expr* Base); + /// Creates the expression Base.ptr_ref() for the given Base expr. 
The Base + /// expr must be of clad::array_ref type + clang::Expr* BuildArrayRefPtrRefExpr(clang::Expr* Base); /// Checks if the type is of clad::ValueAndPushforward type bool isCladValueAndPushforwardType(clang::QualType QT); /// Creates the expression Base.slice(Args) for the given Base expr and Args @@ -591,6 +594,40 @@ namespace clad { /// Cloning types is necessary since VariableArrayType /// store a pointer to their size expression. clang::QualType CloneType(clang::QualType T); + + /// Computes effective derivative operands. It should be used when operands + /// might be of pointer types. + /// + /// In the trivial case, both operands are of non-pointer types, and the + /// effective derivative operands are `LDiff.getExpr_dx()` and + /// `RDiff.getExpr_dx()` respectively. + /// + /// Integers used in pointer arithmetic should be considered + /// non-differentiable entities. For example: + /// + /// ``` + /// p + i; + /// ``` + /// + /// Derived statement should be: + /// + /// ``` + /// _d_p + i; + /// ``` + /// + /// instead of: + /// + /// ``` + /// _d_p + _d_i; + /// ``` + /// + /// Therefore, effective derived expression of `i` is `i` instead of `_d_i`. + /// + /// This function sets `derivedL` and `derivedR` arguments to effective + /// derived expressions. + void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff, + clang::Expr*& derivedL, + clang::Expr*& derivedR); }; } // end namespace clad diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index d65fae998..1dc39bce1 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1346,52 +1346,6 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { } } -/// Computes effective derivative operands. It should be used when operands -/// might be of pointer types.
-/// -/// In the trivial case, both operands are of non-pointer types, and the -/// effective derivative operands are `LDiff.getExpr_dx()` and -/// `RDiff.getExpr_dx()` respectively. -/// -/// Integers used in pointer arithmetic should be considered -/// non-differentiable entities. For example: -/// -/// ``` -/// p + i; -/// ``` -/// -/// Derived statement should be: -/// -/// ``` -/// _d_p + i; -/// ``` -/// -/// instead of: -/// -/// ``` -/// _d_p + _d_i; -/// ``` -/// -/// Therefore, effective derived expression of `i` is `i` instead of `_d_i`. -/// -/// This functions sets `derivedL` and `derivedR` arguments to effective -/// derived expressions. -static void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff, - clang::Expr*& derivedL, - clang::Expr*& derivedR) { - derivedL = LDiff.getExpr_dx(); - derivedR = RDiff.getExpr_dx(); - if (utils::isArrayOrPointerType(LDiff.getExpr_dx()->getType()) && - !utils::isArrayOrPointerType(RDiff.getExpr_dx()->getType())) { - derivedL = LDiff.getExpr_dx(); - derivedR = RDiff.getExpr(); - } else if (utils::isArrayOrPointerType(RDiff.getExpr_dx()->getType()) && - !utils::isArrayOrPointerType(LDiff.getExpr_dx()->getType())) { - derivedL = LDiff.getExpr(); - derivedR = RDiff.getExpr_dx(); - } -} - StmtDiff BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { StmtDiff Ldiff = Visit(BinOp->getLHS()); diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index a49d52477..164f5ddda 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -584,7 +584,8 @@ namespace clad { /// be more complex than just a DeclRefExpr. /// (e.g. `__real (n++ ? 
z1 : z2)`) m_Exprs.push_back(UnOp); - } + } else if (opCode == UnaryOperatorKind::UO_Deref) + m_Exprs.push_back(UnOp); } void VisitDeclRefExpr(clang::DeclRefExpr* DRE) { diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 7c8f25ad4..b67f23043 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1338,7 +1338,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Create the (_d_param[idx] += dfdx) statement. if (dfdx()) { // FIXME: not sure if this is generic. - // Don't update derivatives of non-record types. + // Don't update derivatives of record types. if (!VD->getType()->isRecordType()) { auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx()); // Add it to the body statements. @@ -2033,6 +2033,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // If it is a post-increment/decrement operator, its result is a reference // and we should return it. Expr* ResultRef = nullptr; + + // For increment/decrement of pointer, perform the same on the + // derivative pointer also. + bool isPointerOp = E->getType()->isPointerType(); + if (opCode == UO_Plus) // xi = +xj // dxi/dxj = +1.0 @@ -2046,22 +2051,36 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, diff = Visit(E, d); } else if (opCode == UO_PostInc || opCode == UO_PostDec) { diff = Visit(E, dfdx()); + Expr* diff_dx = diff.getExpr_dx(); + if (isPointerOp && isCladArrayType(diff_dx->getType())) + diff_dx = BuildArrayRefPtrRefExpr(diff_dx); + if (isPointerOp) + addToCurrentBlock(BuildOp(opCode, diff_dx), direction::forward); if (UsefulToStoreGlobal(diff.getRevSweepAsExpr())) { auto op = opCode == UO_PostInc ? 
UO_PostDec : UO_PostInc; addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepAsExpr())), direction::reverse); + if (isPointerOp) + addToCurrentBlock(BuildOp(op, diff_dx), direction::reverse); } - ResultRef = diff.getExpr_dx(); + ResultRef = diff_dx; valueForRevPass = diff.getRevSweepAsExpr(); if (m_ExternalSource) m_ExternalSource->ActBeforeFinalisingPostIncDecOp(diff); } else if (opCode == UO_PreInc || opCode == UO_PreDec) { diff = Visit(E, dfdx()); + Expr* diff_dx = diff.getExpr_dx(); + if (isPointerOp && isCladArrayType(diff_dx->getType())) + diff_dx = BuildArrayRefPtrRefExpr(diff_dx); + if (isPointerOp) + addToCurrentBlock(BuildOp(opCode, diff_dx), direction::forward); if (UsefulToStoreGlobal(diff.getRevSweepAsExpr())) { auto op = opCode == UO_PreInc ? UO_PreDec : UO_PreInc; addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepAsExpr())), direction::reverse); + if (isPointerOp) + addToCurrentBlock(BuildOp(op, diff_dx), direction::reverse); } auto op = opCode == UO_PreInc ? BinaryOperatorKind::BO_Add : BinaryOperatorKind::BO_Sub; @@ -2079,35 +2098,40 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Add it to the body statements. addToCurrentBlock(add_assign, direction::reverse); } - } else { - // FIXME: This is not adding 'address-of' operator support. - // This is just making this special case differentiable that is required - // for computing hessian: - // ``` - // Class _d_this_obj; - // Class* _d_this = &_d_this_obj; - // ``` - // This code snippet should be removed once reverse mode officially - // supports pointers. 
- if (opCode == UnaryOperatorKind::UO_AddrOf) { - if (const auto* MD = dyn_cast(m_Function)) { - if (MD->isInstance()) { - auto thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD); - if (utils::SameCanonicalType(thisType, UnOp->getType())) { - diff = Visit(E); - Expr* cloneE = - BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr()); - Expr* derivedE = - BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr_dx()); - return {cloneE, derivedE}; - } - } + } else if (opCode == UnaryOperatorKind::UO_AddrOf) { + diff = Visit(E); + Expr* cloneE = BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr()); + Expr* derivedE = BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr_dx()); + return {cloneE, derivedE}; + } else if (opCode == UnaryOperatorKind::UO_Deref) { + diff = Visit(E); + Expr* cloneE = BuildOp(UnaryOperatorKind::UO_Deref, diff.getExpr()); + Expr* diff_dx = diff.getExpr_dx(); + bool specialDThisCase = false; + Expr* derivedE = nullptr; + if (const auto* MD = dyn_cast(m_Function)) { + if (MD->isInstance() && !diff_dx->getType()->isPointerType()) + specialDThisCase = true; // _d_this is already dereferenced. + } + if (specialDThisCase) + derivedE = diff_dx; + else { + derivedE = BuildOp(UnaryOperatorKind::UO_Deref, diff_dx); + // Create the (target += dfdx) statement. + if (dfdx()) { + auto* add_assign = BuildOp(BO_AddAssign, derivedE, dfdx()); + // Add it to the body statements. + addToCurrentBlock(add_assign, direction::reverse); } } - // We should not output any warning on visiting boolean conditions - // FIXME: We should support boolean differentiation or ignore it - // completely + return {cloneE, derivedE}; + } else { if (opCode != UO_LNot) + // We should only output warnings on visiting boolean conditions + // when it is related to some independent variable and causes + // discontinuity in the function space.
+ // FIXME: We should support boolean differentiation or ignore it + // completely unsupportedOpWarn(UnOp->getEndLoc()); if (isa(E)) @@ -2132,6 +2156,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // we should return it. Expr* ResultRef = nullptr; + bool isPointerOp = + L->getType()->isPointerType() || R->getType()->isPointerType(); + if (opCode == BO_Add) { // xi = xl + xr // dxi/xl = 1.0 @@ -2307,6 +2334,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto* Lblock = endBlock(direction::reverse); llvm::SmallVector ExprsToStore; utils::GetInnermostReturnExpr(Ldiff.getExpr(), ExprsToStore); + + // We need to store values of derivative pointer variables in forward pass + // and restore them in reverse pass. + if (isPointerOp) { + Expr* Edx = Ldiff.getExpr_dx(); + ExprsToStore.push_back(Edx); + } + if (L->HasSideEffects(m_Context)) { Expr* E = Ldiff.getExpr(); auto* storeE = @@ -2353,20 +2388,26 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Save old value for the derivative of LHS, to avoid problems with cases // like x = x. - auto* oldValue = StoreAndRef(AssignedDiff, direction::reverse, "_r_d", - /*forceDeclCreation=*/true); + clang::Expr* oldValue = nullptr; + + // For pointer types, no need to store old derivatives. 
+ if (!isPointerOp) + oldValue = StoreAndRef(AssignedDiff, direction::reverse, "_r_d", + /*forceDeclCreation=*/true); if (opCode == BO_Assign) { Rdiff = Visit(R, oldValue); valueForRevPass = Rdiff.getRevSweepAsExpr(); } else if (opCode == BO_AddAssign) { Rdiff = Visit(R, oldValue); - valueForRevPass = BuildOp(BO_Add, Rdiff.getRevSweepAsExpr(), - Ldiff.getRevSweepAsExpr()); + if (!isPointerOp) + valueForRevPass = BuildOp(BO_Add, Rdiff.getRevSweepAsExpr(), + Ldiff.getRevSweepAsExpr()); } else if (opCode == BO_SubAssign) { Rdiff = Visit(R, BuildOp(UO_Minus, oldValue)); - valueForRevPass = BuildOp(BO_Sub, Rdiff.getRevSweepAsExpr(), - Ldiff.getRevSweepAsExpr()); + if (!isPointerOp) + valueForRevPass = BuildOp(BO_Sub, Rdiff.getRevSweepAsExpr(), + Ldiff.getRevSweepAsExpr()); } else if (opCode == BO_MulAssign) { // Create a reference variable to keep the result of LHS, since it // must be used on 2 places: when storing to a global variable @@ -2459,6 +2500,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return BuildOp(opCode, LExpr, RExpr); } Expr* op = BuildOp(opCode, Ldiff.getExpr(), Rdiff.getExpr()); + + // For pointer types. 
+ if (isPointerOp) { + if (opCode == BO_Add || opCode == BO_Sub) { + Expr* derivedL = nullptr; + Expr* derivedR = nullptr; + ComputeEffectiveDOperands(Ldiff, Rdiff, derivedL, derivedR); + if (opCode == BO_Sub) + derivedR = BuildParens(derivedR); + return StmtDiff(op, BuildOp(opCode, derivedL, derivedR), nullptr, + valueForRevPass); + } + if (opCode == BO_Assign || opCode == BO_AddAssign || + opCode == BO_SubAssign) { + Expr* derivedL = nullptr; + Expr* derivedR = nullptr; + ComputeEffectiveDOperands(Ldiff, Rdiff, derivedL, derivedR); + addToCurrentBlock(BuildOp(opCode, derivedL, derivedR), + direction::forward); + } + } return StmtDiff(op, ResultRef, nullptr, valueForRevPass); } @@ -2468,6 +2530,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto VDDerivedType = ComputeAdjointType(VD->getType()); bool isDerivativeOfRefType = VD->getType()->isReferenceType(); VarDecl* VDDerived = nullptr; + bool isPointerType = VD->getType()->isPointerType(); // VDDerivedInit now serves two purposes -- as the initial derivative value // or the size of the derivative array -- depending on the primal type. @@ -2528,6 +2591,22 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (initDiff.getExpr_dx()) VDDerivedInit = initDiff.getExpr_dx(); } + // if VD is a pointer type, then the initial value is set to the derived + // expression of the corresponding pointer type. + else if (isPointerType && VD->getInit()) { + initDiff = Visit(VD->getInit()); + VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema); + // If it's a pointer to a constant type, then remove the constness. 
+ if (VD->getType()->getPointeeType().isConstQualified()) { + // first extract the pointee type + auto pointeeType = VD->getType()->getPointeeType(); + // then remove the constness + pointeeType.removeLocalConst(); + // then create a new pointer type with the new pointee type + VDDerivedType = m_Context.getPointerType(pointeeType); + } + VDDerivedInit = getZeroInit(VDDerivedType); + } // Here separate behaviour for record and non-record types is only // necessary to preserve the old tests. if (VDDerivedType->isRecordType()) @@ -2545,7 +2624,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // differentiated and should not be differentiated again. // If `VD` is a reference to a non-local variable then also there's no // need to call `Visit` since non-local variables are not differentiated. - if (!isDerivativeOfRefType) { + if (!isDerivativeOfRefType && !isPointerType) { Expr* derivedE = BuildDeclRef(VDDerived); initDiff = StmtDiff{}; if (VD->getInit()) { @@ -2611,6 +2690,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, derivedVDE); } } + if (isPointerType) { + Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign, + derivedVDE, initDiff.getExpr_dx()); + addToCurrentBlock(assignDerivativeE, direction::forward); + if (isInsideLoop) { + auto tape = MakeCladTapeFor(derivedVDE); + addToCurrentBlock(tape.Push); + auto* reverseSweepDerivativePointerE = + BuildVarDecl(derivedVDE->getType(), "_t", tape.Pop); + m_LoopBlock.back().push_back( + BuildDeclStmt(reverseSweepDerivativePointerE)); + derivedVDE = BuildDeclRef(reverseSweepDerivativePointerE); + } + } m_Variables.emplace(VDClone, derivedVDE); return VarDeclDiff(VDClone, VDDerived); @@ -2827,6 +2920,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // If TBR analysis is off, assume E is useful to store. if (!enableTBR) return true; + // FIXME: currently, we allow all pointer operations to be stored. 
+ // This is not correct, but we need to implement a more advanced analysis + // to determine which pointer operations are useful to store. + if (E->getType()->isPointerType()) + return true; auto found = m_ToBeRecorded.find(B->getBeginLoc()); return found != m_ToBeRecorded.end(); } diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 3de0394e9..5b0cf2507 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -713,6 +713,10 @@ namespace clad { return BuildCallExprToMemFn(Base, /*MemberFunctionName=*/"slice", Args); } + Expr* VisitorBase::BuildArrayRefPtrRefExpr(Expr* Base) { + return BuildCallExprToMemFn(Base, /*MemberFunctionName=*/"ptr_ref", {}); + } + bool VisitorBase::isCladArrayType(QualType QT) { // FIXME: Replace this check with a clang decl check return QT.getAsString().find("clad::array") != std::string::npos || @@ -843,4 +847,30 @@ namespace clad { auto& TAL = specialization->getTemplateArgs(); return TAL.get(0).getAsType(); } + + void VisitorBase::ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff, + clang::Expr*& derivedL, + clang::Expr*& derivedR) { + derivedL = LDiff.getExpr_dx(); + derivedR = RDiff.getExpr_dx(); + if (utils::isArrayOrPointerType(LDiff.getExpr()->getType()) && + utils::isArrayOrPointerType(RDiff.getExpr()->getType())) { + if (isCladArrayType(derivedL->getType())) + derivedL = BuildArrayRefPtrRefExpr(derivedL); + if (isCladArrayType(derivedR->getType())) + derivedR = BuildArrayRefPtrRefExpr(derivedR); + } else if (utils::isArrayOrPointerType(LDiff.getExpr()->getType()) && + !utils::isArrayOrPointerType(RDiff.getExpr()->getType())) { + derivedL = LDiff.getExpr_dx(); + if (isCladArrayType(derivedL->getType())) + derivedL = BuildArrayRefPtrRefExpr(derivedL); + derivedR = RDiff.getExpr(); + } else if (utils::isArrayOrPointerType(RDiff.getExpr()->getType()) && + !utils::isArrayOrPointerType(LDiff.getExpr()->getType())) { + derivedL = LDiff.getExpr(); + 
derivedR = RDiff.getExpr_dx(); + if (isCladArrayType(derivedR->getType())) + derivedR = BuildArrayRefPtrRefExpr(derivedR); + } + } } // end namespace clad diff --git a/test/FirstDerivative/UnsupportedOpsWarn.C b/test/FirstDerivative/UnsupportedOpsWarn.C index 2a618c54d..9c58b7cf9 100644 --- a/test/FirstDerivative/UnsupportedOpsWarn.C +++ b/test/FirstDerivative/UnsupportedOpsWarn.C @@ -34,23 +34,9 @@ int unOpWarn_0(int x){ // CHECK-NEXT: return 0; // CHECK-NEXT: } -int unOpWarn_1(int x){ - auto pnt = &x; // expected-warning {{attempt to differentiate unsupported operator, ignored.}} - return x; -} - -// CHECK: void unOpWarn_1_grad(int x, clad::array_ref _d_x) { -// CHECK-NEXT: int *_d_pnt = 0; -// CHECK-NEXT: int *pnt = &x; -// CHECK-NEXT: goto _label0; -// CHECK-NEXT: _label0: -// CHECK-NEXT: * _d_x += 1; -// CHECK-NEXT: } - int main(){ clad::differentiate(binOpWarn_0, 0); clad::gradient(binOpWarn_1); clad::differentiate(unOpWarn_0, 0); - clad::gradient(unOpWarn_1); } diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C index e4bd9fbb1..0ea8da25f 100644 --- a/test/Gradient/Pointers.C +++ b/test/Gradient/Pointers.C @@ -1,9 +1,10 @@ // RUN: %cladclang %s -I%S/../../include -oPointers.out 2>&1 | FileCheck %s // RUN: ./Pointers.out | FileCheck -check-prefix=CHECK-EXEC %s -// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oPointers.out -// RUN: ./Pointers.out | FileCheck -check-prefix=CHECK-EXEC %s // CHECK-NOT: {{.*error|warning|note:.*}} +// FIXME: This test does not work with enable-tbr flag, because the +// current implementation of TBR analysis doesn't support pointers. 
+ #include "clad/Differentiator/Differentiator.h" double nonMemFn(double i) { @@ -18,6 +19,322 @@ double nonMemFn(double i) { // CHECK-NEXT: } // CHECK-NEXT: } +double minimalPointer(double x) { + double* const p = &x; + *p = (*p)*(*p); + return *p; // x*x +} + +// CHECK: void minimalPointer_grad(double x, clad::array_ref _d_x) { +// CHECK-NEXT: double *_d_p = 0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: _d_p = &* _d_x; +// CHECK-NEXT: double *const p = &x; +// CHECK-NEXT: _t0 = *p; +// CHECK-NEXT: *p = *p * (*p); +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: *_d_p += 1; +// CHECK-NEXT: { +// CHECK-NEXT: *p = _t0; +// CHECK-NEXT: double _r_d0 = *_d_p; +// CHECK-NEXT: *_d_p += _r_d0 * (*p); +// CHECK-NEXT: *_d_p += *p * _r_d0; +// CHECK-NEXT: *_d_p -= _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double arrayPointer(const double* arr) { + const double *p = arr; + p = p + 1; + double sum = *p; + p++; + sum += (*p)*2; + p += 1; + sum += (*p)*4; + ++p; + sum += (*p)*3; + p -= 2; + p = p - 2; + sum += 5 * (*p); + return sum; // 5*arr[0] + arr[1] + 2*arr[2] + 4*arr[3] + 3*arr[4] +} + +// CHECK: void arrayPointer_grad(const double *arr, clad::array_ref _d_arr) { +// CHECK-NEXT: double *_d_p = 0; +// CHECK-NEXT: const double *_t0; +// CHECK-NEXT: double *_t1; +// CHECK-NEXT: double _d_sum = 0; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: const double *_t3; +// CHECK-NEXT: double *_t4; +// CHECK-NEXT: double _t5; +// CHECK-NEXT: double _t6; +// CHECK-NEXT: const double *_t7; +// CHECK-NEXT: double *_t8; +// CHECK-NEXT: const double *_t9; +// CHECK-NEXT: double *_t10; +// CHECK-NEXT: double _t11; +// CHECK-NEXT: _d_p = _d_arr; +// CHECK-NEXT: const double *p = arr; +// CHECK-NEXT: _t0 = p; +// CHECK-NEXT: _t1 = _d_p; +// CHECK-NEXT: _d_p = _d_p + 1; +// CHECK-NEXT: p = p + 1; +// CHECK-NEXT: double sum = *p; +// CHECK-NEXT: _d_p++; +// CHECK-NEXT: p++; +// CHECK-NEXT: _t2 = sum; +// CHECK-NEXT: sum += *p * 2; +// CHECK-NEXT: _t3 = p; +// CHECK-NEXT: 
_t4 = _d_p; +// CHECK-NEXT: _d_p += 1; +// CHECK-NEXT: p += 1; +// CHECK-NEXT: _t5 = sum; +// CHECK-NEXT: sum += *p * 4; +// CHECK-NEXT: ++_d_p; +// CHECK-NEXT: ++p; +// CHECK-NEXT: _t6 = sum; +// CHECK-NEXT: sum += *p * 3; +// CHECK-NEXT: _t7 = p; +// CHECK-NEXT: _t8 = _d_p; +// CHECK-NEXT: _d_p -= 2; +// CHECK-NEXT: p -= 2; +// CHECK-NEXT: _t9 = p; +// CHECK-NEXT: _t10 = _d_p; +// CHECK-NEXT: _d_p = _d_p - 2; +// CHECK-NEXT: p = p - 2; +// CHECK-NEXT: _t11 = sum; +// CHECK-NEXT: sum += 5 * (*p); +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_sum += 1; +// CHECK-NEXT: { +// CHECK-NEXT: sum = _t11; +// CHECK-NEXT: double _r_d3 = _d_sum; +// CHECK-NEXT: *_d_p += 5 * _r_d3; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: p = _t9; +// CHECK-NEXT: _d_p = _t10; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: p = _t7; +// CHECK-NEXT: _d_p = _t8; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: sum = _t6; +// CHECK-NEXT: double _r_d2 = _d_sum; +// CHECK-NEXT: *_d_p += _r_d2 * 3; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: --p; +// CHECK-NEXT: --_d_p; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: sum = _t5; +// CHECK-NEXT: double _r_d1 = _d_sum; +// CHECK-NEXT: *_d_p += _r_d1 * 4; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: p = _t3; +// CHECK-NEXT: _d_p = _t4; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: sum = _t2; +// CHECK-NEXT: double _r_d0 = _d_sum; +// CHECK-NEXT: *_d_p += _r_d0 * 2; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: p--; +// CHECK-NEXT: _d_p--; +// CHECK-NEXT: } +// CHECK-NEXT: *_d_p += _d_sum; +// CHECK-NEXT: { +// CHECK-NEXT: p = _t0; +// CHECK-NEXT: _d_p = _t1; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double pointerParam(const double* arr, size_t n) { + double sum = 0; + for (size_t i=0; i < n; ++i) { + size_t* j = &i; + sum += arr[0] * (*j); + arr = arr + 1; + } + return sum; +} + +// CHECK: void pointerParam_grad_0(const double *arr, size_t n, clad::array_ref _d_arr) { +// 
CHECK-NEXT: size_t _d_n = 0; +// CHECK-NEXT: double _d_sum = 0; +// CHECK-NEXT: unsigned long _t0; +// CHECK-NEXT: size_t _d_i = 0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: size_t *_d_j = 0; +// CHECK-NEXT: clad::tape _t3 = {}; +// CHECK-NEXT: clad::tape _t4 = {}; +// CHECK-NEXT: clad::tape > _t5 = {}; +// CHECK-NEXT: double sum = 0; +// CHECK-NEXT: _t0 = 0; +// CHECK-NEXT: for (size_t i = 0; i < n; ++i) { +// CHECK-NEXT: _t0++; +// CHECK-NEXT: _d_j = &_d_i; +// CHECK-NEXT: clad::push(_t1, _d_j); +// CHECK-NEXT: size_t *j = &i; +// CHECK-NEXT: clad::push(_t3, sum); +// CHECK-NEXT: sum += arr[0] * (*j); +// CHECK-NEXT: clad::push(_t4, arr); +// CHECK-NEXT: clad::push(_t5, _d_arr); +// CHECK-NEXT: _d_arr.ptr_ref() = _d_arr.ptr_ref() + 1; +// CHECK-NEXT: arr = arr + 1; +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_sum += 1; +// CHECK-NEXT: for (; _t0; _t0--) { +// CHECK-NEXT: --i; +// CHECK-NEXT: size_t *_t2 = clad::pop(_t1); +// CHECK-NEXT: { +// CHECK-NEXT: arr = clad::pop(_t4); +// CHECK-NEXT: _d_arr = clad::pop(_t5); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: sum = clad::pop(_t3); +// CHECK-NEXT: double _r_d0 = _d_sum; +// CHECK-NEXT: _d_arr[0] += _r_d0 * (*j); +// CHECK-NEXT: *_t2 += arr[0] * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +double pointerMultipleParams(const double* a, const double* b) { + double sum = b[2]; + b = a; + a = 1+a; + ++b; + sum += a[0] + b[0]; // += 2*a[1] + b++; a++; + sum += a[0] + b[0]; // += 2*a[2] + b--; a--; + sum += a[0] + b[0]; // += 2*a[1] + --b; --a; + sum += a[0] + b[0]; // += 2*a[0] + return sum; // 2*a[0] + 4*a[1] + 2*a[2] + b[2] +} + +// CHECK: void pointerMultipleParams_grad(const double *a, const double *b, clad::array_ref _d_a, clad::array_ref _d_b) { +// CHECK-NEXT: double _d_sum = 0; +// CHECK-NEXT: const double *_t0; +// CHECK-NEXT: clad::array_ref _t1; +// CHECK-NEXT: const double *_t2; +// CHECK-NEXT: clad::array_ref _t3; +// 
CHECK-NEXT: double _t4; +// CHECK-NEXT: double _t5; +// CHECK-NEXT: double _t6; +// CHECK-NEXT: double _t7; +// CHECK-NEXT: double sum = b[2]; +// CHECK-NEXT: _t0 = b; +// CHECK-NEXT: _t1 = _d_b; +// CHECK-NEXT: _d_b.ptr_ref() = _d_a.ptr_ref(); +// CHECK-NEXT: b = a; +// CHECK-NEXT: _t2 = a; +// CHECK-NEXT: _t3 = _d_a; +// CHECK-NEXT: _d_a.ptr_ref() = 1 + _d_a.ptr_ref(); +// CHECK-NEXT: a = 1 + a; +// CHECK-NEXT: ++_d_b.ptr_ref(); +// CHECK-NEXT: ++b; +// CHECK-NEXT: _t4 = sum; +// CHECK-NEXT: sum += a[0] + b[0]; +// CHECK-NEXT: _d_b.ptr_ref()++; +// CHECK-NEXT: b++; +// CHECK-NEXT: _d_a.ptr_ref()++; +// CHECK-NEXT: a++; +// CHECK-NEXT: _t5 = sum; +// CHECK-NEXT: sum += a[0] + b[0]; +// CHECK-NEXT: _d_b.ptr_ref()--; +// CHECK-NEXT: b--; +// CHECK-NEXT: _d_a.ptr_ref()--; +// CHECK-NEXT: a--; +// CHECK-NEXT: _t6 = sum; +// CHECK-NEXT: sum += a[0] + b[0]; +// CHECK-NEXT: --_d_b.ptr_ref(); +// CHECK-NEXT: --b; +// CHECK-NEXT: --_d_a.ptr_ref(); +// CHECK-NEXT: --a; +// CHECK-NEXT: _t7 = sum; +// CHECK-NEXT: sum += a[0] + b[0]; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_sum += 1; +// CHECK-NEXT: { +// CHECK-NEXT: sum = _t7; +// CHECK-NEXT: double _r_d3 = _d_sum; +// CHECK-NEXT: _d_a[0] += _r_d3; +// CHECK-NEXT: _d_b[0] += _r_d3; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: ++a; +// CHECK-NEXT: ++_d_a.ptr_ref(); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: ++b; +// CHECK-NEXT: ++_d_b.ptr_ref(); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: sum = _t6; +// CHECK-NEXT: double _r_d2 = _d_sum; +// CHECK-NEXT: _d_a[0] += _r_d2; +// CHECK-NEXT: _d_b[0] += _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: a++; +// CHECK-NEXT: _d_a.ptr_ref()++; +// CHECK-NEXT: _d_a.ptr_ref(); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: b++; +// CHECK-NEXT: _d_b.ptr_ref()++; +// CHECK-NEXT: _d_b.ptr_ref(); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: sum = _t5; +// CHECK-NEXT: double _r_d1 = _d_sum; +// CHECK-NEXT: _d_a[0] 
+= _r_d1; +// CHECK-NEXT: _d_b[0] += _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: a--; +// CHECK-NEXT: _d_a.ptr_ref()--; +// CHECK-NEXT: _d_a.ptr_ref(); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: b--; +// CHECK-NEXT: _d_b.ptr_ref()--; +// CHECK-NEXT: _d_b.ptr_ref(); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: sum = _t4; +// CHECK-NEXT: double _r_d0 = _d_sum; +// CHECK-NEXT: _d_a[0] += _r_d0; +// CHECK-NEXT: _d_b[0] += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: --b; +// CHECK-NEXT: --_d_b.ptr_ref(); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: a = _t2; +// CHECK-NEXT: _d_a = _t3; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: b = _t0; +// CHECK-NEXT: _d_b = _t1; +// CHECK-NEXT: } +// CHECK-NEXT: _d_b[2] += _d_sum; +// CHECK-NEXT: } + #define NON_MEM_FN_TEST(var)\ res[0]=0;\ var.execute(5,res);\ @@ -87,4 +404,29 @@ int main() { NON_MEM_FN_TEST(d_nonMemFnReinterpretCast); // CHECK-EXEC: 10.00 NON_MEM_FN_TEST(d_nonMemFnCStyleCast); // CHECK-EXEC: 10.00 + + // Pointer operation tests. 
+ auto d_minimalPointer = clad::gradient(minimalPointer, "x"); + NON_MEM_FN_TEST(d_minimalPointer); // CHECK-EXEC: 10.00 + + auto d_arrayPointer = clad::gradient(arrayPointer, "arr"); + double arr[5] = {1, 2, 3, 4, 5}; + double d_arr[5] = {0, 0, 0, 0, 0}; + clad::array_ref d_arr_ref(d_arr, 5); + d_arrayPointer.execute(arr, d_arr_ref); + printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 5.00 1.00 2.00 4.00 3.00 + + auto d_pointerParam = clad::gradient(pointerParam, "arr"); + d_arr[0] = d_arr[1] = d_arr[2] = d_arr[3] = d_arr[4] = 0; + d_pointerParam.execute(arr, 5, d_arr_ref); + printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 0.00 1.00 2.00 3.00 4.00 + + auto d_pointerMultipleParams = clad::gradient(pointerMultipleParams); + double b_arr[5] = {1, 2, 3, 4, 5}; + double d_b_arr[5] = {0, 0, 0, 0, 0}; + clad::array_ref d_b_arr_ref(d_b_arr, 5); + d_arr[0] = d_arr[1] = d_arr[2] = d_arr[3] = d_arr[4] = 0; + d_pointerMultipleParams.execute(arr, b_arr, d_arr_ref, d_b_arr_ref); + printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 2.00 4.00 2.00 0.00 0.00 + printf("%.2f %.2f %.2f %.2f %.2f\n", d_b_arr[0], d_b_arr[1], d_b_arr[2], d_b_arr[3], d_b_arr[4]); // CHECK-EXEC: 0.00 0.00 1.00 0.00 0.00 }