Skip to content

Commit

Permalink
fix pointer operations inside loop and as params
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Dec 22, 2023
1 parent 3028b7c commit 7f455a7
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 7 deletions.
37 changes: 36 additions & 1 deletion include/clad/Differentiator/ArrayRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ template <typename T> class array_ref {

public:
/// Delete default constructor
array_ref() = delete;
array_ref() = default;
/// Constructor to store the pointer to and size of an array supplied by the
/// user
CUDA_HOST_DEVICE array_ref(T* arr, std::size_t size)
Expand All @@ -43,9 +43,16 @@ template <typename T> class array_ref {
m_arr[i] = a[i];
return *this;
}
template <typename U>
CUDA_HOST_DEVICE array_ref<T>& operator=(const array_ref<T>& a) {
m_arr = a.ptr();
m_size = a.size();
return *this;
}
/// Returns the size of the underlying array
CUDA_HOST_DEVICE std::size_t size() const { return m_size; }
CUDA_HOST_DEVICE T* ptr() const { return m_arr; }
CUDA_HOST_DEVICE T*& ptr_ref() { return m_arr; }
/// Returns an array_ref to a part of the underlying array starting at
/// offset and having the specified size
CUDA_HOST_DEVICE array_ref<T> slice(std::size_t offset, std::size_t size) {
Expand All @@ -61,6 +68,34 @@ template <typename T> class array_ref {
/// Returns the reference to the underlying array
CUDA_HOST_DEVICE T& operator*() { return *m_arr; }

// Increment and decrement operators - update the underlying pointer.
/// Prefix increment operator.
CUDA_HOST_DEVICE array_ref<T>& operator++() {
++m_arr;
--m_size;
return *this;
}
/// Postfix increment operator.
CUDA_HOST_DEVICE array_ref<T> operator++(int) {
array_ref<T> tmp(*this);
++m_arr;
--m_size;
return tmp;
}
/// Prefix decrement operator.
CUDA_HOST_DEVICE array_ref<T>& operator--() {
--m_arr;
++m_size;
return *this;
}
/// Postfix decrement operator.
CUDA_HOST_DEVICE array_ref<T> operator--(int) {
array_ref<T> tmp(*this);
--m_arr;
++m_size;
return tmp;
}

// Arithmetic overloads
/// Divides the arrays element wise
template <typename U>
Expand Down
9 changes: 6 additions & 3 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,9 @@ namespace clad {
/// Creates the expression Base.size() for the given Base expr. The Base
/// expr must be of clad::array_ref<T> type
clang::Expr* BuildArrayRefSizeExpr(clang::Expr* Base);
/// Creates the expression Base.ptr_ref() for the given Base expr. The Base
/// expr must be of clad::array_ref<T> type
clang::Expr* BuildArrayRefPtrRefExpr(clang::Expr* Base);
/// Checks if the type is of clad::ValueAndPushforward<T,U> type
bool isCladValueAndPushforwardType(clang::QualType QT);
/// Creates the expression Base.slice(Args) for the given Base expr and Args
Expand Down Expand Up @@ -608,9 +611,9 @@ namespace clad {
///
/// This functions sets `derivedL` and `derivedR` arguments to effective
/// derived expressions.
static void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff,
clang::Expr*& derivedL,
clang::Expr*& derivedR);
void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff,
clang::Expr*& derivedL,
clang::Expr*& derivedR);
};
} // end namespace clad

Expand Down
9 changes: 9 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2696,6 +2696,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign,
derivedVDE, initDiff.getExpr_dx());
addToCurrentBlock(assignDerivativeE, direction::forward);
if (isInsideLoop) {
auto tape = MakeCladTapeFor(derivedVDE);
addToCurrentBlock(tape.Push);
auto* reverseSweepDerivativePointerE =
BuildVarDecl(derivedVDE->getType(), "_t", tape.Pop);
m_LoopBlock.back().push_back(
BuildDeclStmt(reverseSweepDerivativePointerE));
derivedVDE = BuildDeclRef(reverseSweepDerivativePointerE);
}
}
m_Variables.emplace(VDClone, derivedVDE);

Expand Down
16 changes: 15 additions & 1 deletion lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,10 @@ namespace clad {
return BuildCallExprToMemFn(Base, /*MemberFunctionName=*/"slice", Args);
}

Expr* VisitorBase::BuildArrayRefPtrRefExpr(Expr* Base) {
return BuildCallExprToMemFn(Base, /*MemberFunctionName=*/"ptr_ref", {});
}

bool VisitorBase::isCladArrayType(QualType QT) {
// FIXME: Replace this check with a clang decl check
return QT.getAsString().find("clad::array") != std::string::npos ||
Expand Down Expand Up @@ -788,13 +792,23 @@ namespace clad {
derivedL = LDiff.getExpr_dx();
derivedR = RDiff.getExpr_dx();
if (utils::isArrayOrPointerType(LDiff.getExpr()->getType()) &&
!utils::isArrayOrPointerType(RDiff.getExpr()->getType())) {
utils::isArrayOrPointerType(RDiff.getExpr()->getType())) {
if (isCladArrayType(derivedL->getType()))
derivedL = BuildArrayRefPtrRefExpr(derivedL);
if (isCladArrayType(derivedR->getType()))
derivedR = BuildArrayRefPtrRefExpr(derivedR);
} else if (utils::isArrayOrPointerType(LDiff.getExpr()->getType()) &&
!utils::isArrayOrPointerType(RDiff.getExpr()->getType())) {
derivedL = LDiff.getExpr_dx();
if (isCladArrayType(derivedL->getType()))
derivedL = BuildArrayRefPtrRefExpr(derivedL);
derivedR = RDiff.getExpr();
} else if (utils::isArrayOrPointerType(RDiff.getExpr()->getType()) &&
!utils::isArrayOrPointerType(LDiff.getExpr()->getType())) {
derivedL = LDiff.getExpr();
derivedR = RDiff.getExpr_dx();
if (isCladArrayType(derivedR->getType()))
derivedR = BuildArrayRefPtrRefExpr(derivedR);
}
}
} // end namespace clad
17 changes: 15 additions & 2 deletions test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
// RUN: %cladclang %s -I%S/../../include -oPointers.out 2>&1 | FileCheck %s
// RUN: ./Pointers.out | FileCheck -check-prefix=CHECK-EXEC %s
// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oPointers.out
// RUN: ./Pointers.out | FileCheck -check-prefix=CHECK-EXEC %s
// CHECK-NOT: {{.*error|warning|note:.*}}

#include "clad/Differentiator/Differentiator.h"
Expand Down Expand Up @@ -172,6 +170,16 @@ double arrayPointer(const double* arr) {
// CHECK-NEXT: }
// CHECK-NEXT: }

double pointerParam(const double* arr, size_t n) {
double sum = 0;
for (size_t i=0; i < n; ++i) {
size_t* j = &i;
sum += arr[0] * (*j);
arr = arr + 1;
}
return sum;
}

#define NON_MEM_FN_TEST(var)\
res[0]=0;\
var.execute(5,res);\
Expand Down Expand Up @@ -252,4 +260,9 @@ int main() {
clad::array_ref<double> d_arr_ref(d_arr, 5);
d_arrayPointer.execute(arr, d_arr_ref);
printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 5.00 1.00 2.00 4.00 3.00

auto d_pointerParam = clad::gradient(pointerParam, "arr");
d_arr[0] = d_arr[1] = d_arr[2] = d_arr[3] = d_arr[4] = 0;
d_pointerParam.execute(arr, 5, d_arr_ref);
printf("%.2f %.2f %.2f %.2f %.2f\n", d_arr[0], d_arr[1], d_arr[2], d_arr[3], d_arr[4]); // CHECK-EXEC: 0.00 1.00 2.00 3.00 4.00
}

0 comments on commit 7f455a7

Please sign in to comment.