Skip to content

Commit

Permalink
Add initial support for pointers in reverse mode (#686)
Browse files Browse the repository at this point in the history
This commit adds support for pointer operation in reverse mode. The technique is to maintain a corresponding derivative pointer variable, which gets updated (and stored/restored) in the exact same way as the primal pointer variable in both forward and reverse passes.
Added a workaround (with a FIXME comment) in the UsefulToStoreGlobal method to essentially bypass TBR analysis results for pointer expr.

Fixes #195, Fixes #197
  • Loading branch information
vaithak authored Dec 30, 2023
1 parent 654faee commit b80f03e
Show file tree
Hide file tree
Showing 9 changed files with 598 additions and 99 deletions.
12 changes: 11 additions & 1 deletion include/clad/Differentiator/ArrayRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ template <typename T> class array_ref {

public:
/// Delete default constructor
array_ref() = delete;
array_ref() = default;
/// Constructor to store the pointer to and size of an array supplied by the
/// user
CUDA_HOST_DEVICE array_ref(T* arr, std::size_t size)
Expand All @@ -33,16 +33,26 @@ template <typename T> class array_ref {
/// Constructor for clad::array types
CUDA_HOST_DEVICE array_ref(array<T>& a) : m_arr(a.ptr()), m_size(a.size()) {}

/// Operator for conversion from array_ref<T> to T*.
CUDA_HOST_DEVICE operator T*() { return m_arr; }

template <typename U>
CUDA_HOST_DEVICE array_ref<T>& operator=(const array<U>& a) {
assert(m_size == a.size());
for (std::size_t i = 0; i < m_size; ++i)
m_arr[i] = a[i];
return *this;
}
template <typename U>
CUDA_HOST_DEVICE array_ref<T>& operator=(const array_ref<T>& a) {
m_arr = a.ptr();
m_size = a.size();
return *this;
}
/// Returns the size of the underlying array
CUDA_HOST_DEVICE std::size_t size() const { return m_size; }
CUDA_HOST_DEVICE T* ptr() const { return m_arr; }
CUDA_HOST_DEVICE T*& ptr_ref() { return m_arr; }
/// Returns an array_ref to a part of the underlying array starting at
/// offset and having the specified size
CUDA_HOST_DEVICE array_ref<T> slice(std::size_t offset, std::size_t size) {
Expand Down
37 changes: 37 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,9 @@ namespace clad {
/// Creates the expression Base.size() for the given Base expr. The Base
/// expr must be of clad::array_ref<T> type
clang::Expr* BuildArrayRefSizeExpr(clang::Expr* Base);
/// Creates the expression Base.ptr_ref() for the given Base expr. The Base
/// expr must be of clad::array_ref<T> type
clang::Expr* BuildArrayRefPtrRefExpr(clang::Expr* Base);
/// Checks if the type is of clad::ValueAndPushforward<T,U> type
bool isCladValueAndPushforwardType(clang::QualType QT);
/// Creates the expression Base.slice(Args) for the given Base expr and Args
Expand Down Expand Up @@ -591,6 +594,40 @@ namespace clad {
/// Cloning types is necessary since VariableArrayType
/// store a pointer to their size expression.
clang::QualType CloneType(clang::QualType T);

/// Computes effective derivative operands. It should be used when operands
/// might be of pointer types.
///
/// In the trivial case, both operands are of non-pointer types, and the
/// effective derivative operands are `LDiff.getExpr_dx()` and
/// `RDiff.getExpr_dx()` respectively.
///
/// Integers used in pointer arithmetic should be considered
/// non-differentiable entities. For example:
///
/// ```
/// p + i;
/// ```
///
/// Derived statement should be:
///
/// ```
/// _d_p + i;
/// ```
///
/// instead of:
///
/// ```
/// _d_p + _d_i;
/// ```
///
/// Therefore, effective derived expression of `i` is `i` instead of `_d_i`.
///
/// This functions sets `derivedL` and `derivedR` arguments to effective
/// derived expressions.
void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff,
clang::Expr*& derivedL,
clang::Expr*& derivedR);
};
} // end namespace clad

Expand Down
46 changes: 0 additions & 46 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1346,52 +1346,6 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
}
}

/// Computes effective derivative operands. It should be used when operands
/// might be of pointer types.
///
/// In the trivial case, both operands are of non-pointer types, and the
/// effective derivative operands are `LDiff.getExpr_dx()` and
/// `RDiff.getExpr_dx()` respectively.
///
/// Integers used in pointer arithmetic should be considered
/// non-differentiable entities. For example:
///
/// ```
/// p + i;
/// ```
///
/// Derived statement should be:
///
/// ```
/// _d_p + i;
/// ```
///
/// instead of:
///
/// ```
/// _d_p + _d_i;
/// ```
///
/// Therefore, effective derived expression of `i` is `i` instead of `_d_i`.
///
/// This functions sets `derivedL` and `derivedR` arguments to effective
/// derived expressions.
static void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff,
clang::Expr*& derivedL,
clang::Expr*& derivedR) {
derivedL = LDiff.getExpr_dx();
derivedR = RDiff.getExpr_dx();
if (utils::isArrayOrPointerType(LDiff.getExpr_dx()->getType()) &&
!utils::isArrayOrPointerType(RDiff.getExpr_dx()->getType())) {
derivedL = LDiff.getExpr_dx();
derivedR = RDiff.getExpr();
} else if (utils::isArrayOrPointerType(RDiff.getExpr_dx()->getType()) &&
!utils::isArrayOrPointerType(LDiff.getExpr_dx()->getType())) {
derivedL = LDiff.getExpr();
derivedR = RDiff.getExpr_dx();
}
}

StmtDiff
BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {
StmtDiff Ldiff = Visit(BinOp->getLHS());
Expand Down
3 changes: 2 additions & 1 deletion lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,8 @@ namespace clad {
/// be more complex than just a DeclRefExpr.
/// (e.g. `__real (n++ ? z1 : z2)`)
m_Exprs.push_back(UnOp);
}
} else if (opCode == UnaryOperatorKind::UO_Deref)
m_Exprs.push_back(UnOp);
}

void VisitDeclRefExpr(clang::DeclRefExpr* DRE) {
Expand Down
Loading

0 comments on commit b80f03e

Please sign in to comment.