diff --git a/include/clad/Differentiator/ArrayRef.h b/include/clad/Differentiator/ArrayRef.h index 8fbe4155a..d38e93eb7 100644 --- a/include/clad/Differentiator/ArrayRef.h +++ b/include/clad/Differentiator/ArrayRef.h @@ -22,7 +22,7 @@ template class array_ref { public: /// Delete default constructor - array_ref() = delete; + array_ref() = default; /// Constructor to store the pointer to and size of an array supplied by the /// user CUDA_HOST_DEVICE array_ref(T* arr, std::size_t size) @@ -43,9 +43,16 @@ template class array_ref { m_arr[i] = a[i]; return *this; } + template + CUDA_HOST_DEVICE array_ref& operator=(const array_ref& a) { + m_arr = a.ptr(); + m_size = a.size(); + return *this; + } /// Returns the size of the underlying array CUDA_HOST_DEVICE std::size_t size() const { return m_size; } CUDA_HOST_DEVICE T* ptr() const { return m_arr; } + CUDA_HOST_DEVICE T*& ptr_ref() { return m_arr; } /// Returns an array_ref to a part of the underlying array starting at /// offset and having the specified size CUDA_HOST_DEVICE array_ref slice(std::size_t offset, std::size_t size) { @@ -61,6 +68,34 @@ template class array_ref { /// Returns the reference to the underlying array CUDA_HOST_DEVICE T& operator*() { return *m_arr; } + // Increment and decrement operators - update the underlying pointer. + /// Prefix increment operator. + CUDA_HOST_DEVICE array_ref& operator++() { + ++m_arr; + --m_size; + return *this; + } + /// Postfix increment operator. + CUDA_HOST_DEVICE array_ref operator++(int) { + array_ref tmp(*this); + ++m_arr; + --m_size; + return tmp; + } + /// Prefix decrement operator. + CUDA_HOST_DEVICE array_ref& operator--() { + --m_arr; + ++m_size; + return *this; + } + /// Postfix decrement operator. + CUDA_HOST_DEVICE array_ref operator--(int) { + array_ref tmp(*this); + --m_arr; + ++m_size; + return tmp; + } + // Arithmetic overloads /// Divides the arrays element wise template diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 4af3a66bd..543bbcdf7 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -513,6 +513,9 @@ namespace clad { /// Creates the expression Base.size() for the given Base expr. The Base /// expr must be of clad::array_ref type clang::Expr* BuildArrayRefSizeExpr(clang::Expr* Base); + /// Creates the expression Base.ptr_ref() for the given Base expr. The Base + /// expr must be of clad::array_ref type + clang::Expr* BuildArrayRefPtrRefExpr(clang::Expr* Base); /// Checks if the type is of clad::ValueAndPushforward type bool isCladValueAndPushforwardType(clang::QualType QT); /// Creates the expression Base.slice(Args) for the given Base expr and Args @@ -608,9 +611,9 @@ namespace clad { /// /// This functions sets `derivedL` and `derivedR` arguments to effective /// derived expressions. - static void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff, - clang::Expr*& derivedL, - clang::Expr*& derivedR); + void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff, + clang::Expr*& derivedL, + clang::Expr*& derivedR); }; } // end namespace clad diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index dea966744..80285f7d5 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2696,6 +2696,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE, initDiff.getExpr_dx()); addToCurrentBlock(assignDerivativeE, direction::forward); + if (isInsideLoop) { + auto tape = MakeCladTapeFor(derivedVDE); + addToCurrentBlock(tape.Push); + auto* reverseSweepDerivativePointerE = + BuildVarDecl(derivedVDE->getType(), "_t", tape.Pop); + m_LoopBlock.back().push_back( + BuildDeclStmt(reverseSweepDerivativePointerE)); + derivedVDE = BuildDeclRef(reverseSweepDerivativePointerE); + } } m_Variables.emplace(VDClone, derivedVDE); diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index e3287626f..019801b1f 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -651,6 +651,10 @@ namespace clad { return BuildCallExprToMemFn(Base, /*MemberFunctionName=*/"slice", Args); } + Expr* VisitorBase::BuildArrayRefPtrRefExpr(Expr* Base) { + return BuildCallExprToMemFn(Base, /*MemberFunctionName=*/"ptr_ref", {}); + } + bool VisitorBase::isCladArrayType(QualType QT) { // FIXME: Replace this check with a clang decl check return QT.getAsString().find("clad::array") != std::string::npos || @@ -788,13 +792,23 @@ namespace clad { derivedL = LDiff.getExpr_dx(); derivedR = RDiff.getExpr_dx(); if (utils::isArrayOrPointerType(LDiff.getExpr()->getType()) && - !utils::isArrayOrPointerType(RDiff.getExpr()->getType())) { + utils::isArrayOrPointerType(RDiff.getExpr()->getType())) { + if (isCladArrayType(derivedL->getType())) + derivedL = BuildArrayRefPtrRefExpr(derivedL); + if (isCladArrayType(derivedR->getType())) + derivedR = BuildArrayRefPtrRefExpr(derivedR); + } else if (utils::isArrayOrPointerType(LDiff.getExpr()->getType()) && + !utils::isArrayOrPointerType(RDiff.getExpr()->getType())) { derivedL = LDiff.getExpr_dx(); + if (isCladArrayType(derivedL->getType())) + derivedL = BuildArrayRefPtrRefExpr(derivedL); derivedR = RDiff.getExpr(); } else if (utils::isArrayOrPointerType(RDiff.getExpr()->getType()) && !utils::isArrayOrPointerType(LDiff.getExpr()->getType())) { derivedL = LDiff.getExpr(); derivedR = RDiff.getExpr_dx(); + if (isCladArrayType(derivedR->getType())) + derivedR = BuildArrayRefPtrRefExpr(derivedR); } } } // end namespace clad diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C index 86c9602fb..538c2e8eb 100644 --- a/test/Gradient/Pointers.C +++ b/test/Gradient/Pointers.C @@ -1,7 +1,5 @@ // RUN: %cladclang %s -I%S/../../include -oPointers.out 2>&1 | FileCheck %s // RUN: ./Pointers.out | FileCheck -check-prefix=CHECK-EXEC %s -// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oPointers.out -// RUN: ./Pointers.out | FileCheck -check-prefix=CHECK-EXEC %s // CHECK-NOT: {{.*error|warning|note:.*}} #include "clad/Differentiator/Differentiator.h" @@ -172,6 +170,16 @@ double arrayPointer(const double* arr) { // CHECK-NEXT: } // CHECK-NEXT: } +double pointerParam(const double* arr, size_t n) { + double sum = 0; + for (size_t i=0; i < n; ++i) { + size_t* j = &i; + sum += arr[0] * (*j); + arr = arr + 1; + } + return sum; +} + #define NON_MEM_FN_TEST(var)\ res[0]=0;\ var.execute(5,res);\ @@ -252,4 +260,9 @@ int main() { clad::array_ref d_arr_ref(d_arr, 5); d_arrayPointer.execute(arr, d_arr_ref); printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 5.00 1.00 2.00 4.00 3.00 + + auto d_pointerParam = clad::gradient(pointerParam, "arr"); + d_arr[0] = d_arr[1] = d_arr[2] = d_arr[3] = d_arr[4] = 0; + d_pointerParam.execute(arr, 5, d_arr_ref); + printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 0.00 1.00 2.00 3.00 4.00 }