Skip to content

Commit

Permalink
Add initial support for pointers in reverse mode
Browse files Browse the repository at this point in the history
This commit adds support for pointer operation in reverse mode.
The technique is maintain a corresponding derivative pointer
variable, which gets updated (and stored/restored) in the exact same way
as the primal pointer variable in both forward and reverse passes.

Added a workaround (with a FIXME comment) in the UsefulToStoreGlobal
method to essentially bypass TBR analysis results for pointer expr.

Fixes #195, #197
  • Loading branch information
vaithak committed Dec 20, 2023
1 parent 6a42b46 commit 2aa5dcb
Show file tree
Hide file tree
Showing 8 changed files with 341 additions and 101 deletions.
3 changes: 3 additions & 0 deletions include/clad/Differentiator/ArrayRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ template <typename T> class array_ref {
/// Constructor for clad::array types
CUDA_HOST_DEVICE array_ref(array<T>& a) : m_arr(a.ptr()), m_size(a.size()) {}

/// Operator for conversion from array_ref<T> to T*.
CUDA_HOST_DEVICE operator T*() { return m_arr; }

template <typename U>
CUDA_HOST_DEVICE array_ref<T>& operator=(const array<U>& a) {
assert(m_size == a.size());
Expand Down
34 changes: 34 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,40 @@ namespace clad {
/// Cloning types is necessary since VariableArrayType
/// store a pointer to their size expression.
clang::QualType CloneType(clang::QualType T);

/// Computes effective derivative operands. It should be used when operands
/// might be of pointer types.
///
/// In the trivial case, both operands are of non-pointer types, and the
/// effective derivative operands are `LDiff.getExpr_dx()` and
/// `RDiff.getExpr_dx()` respectively.
///
/// Integers used in pointer arithmetic should be considered
/// non-differentiable entities. For example:
///
/// ```
/// p + i;
/// ```
///
/// Derived statement should be:
///
/// ```
/// _d_p + i;
/// ```
///
/// instead of:
///
/// ```
/// _d_p + _d_i;
/// ```
///
/// Therefore, effective derived expression of `i` is `i` instead of `_d_i`.
///
/// This functions sets `derivedL` and `derivedR` arguments to effective
/// derived expressions.
static void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff,
clang::Expr*& derivedL,
clang::Expr*& derivedR);
};
} // end namespace clad

Expand Down
46 changes: 0 additions & 46 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1344,52 +1344,6 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
}
}

/// Computes effective derivative operands. It should be used when operands
/// might be of pointer types.
///
/// In the trivial case, both operands are of non-pointer types, and the
/// effective derivative operands are `LDiff.getExpr_dx()` and
/// `RDiff.getExpr_dx()` respectively.
///
/// Integers used in pointer arithmetic should be considered
/// non-differentiable entities. For example:
///
/// ```
/// p + i;
/// ```
///
/// Derived statement should be:
///
/// ```
/// _d_p + i;
/// ```
///
/// instead of:
///
/// ```
/// _d_p + _d_i;
/// ```
///
/// Therefore, effective derived expression of `i` is `i` instead of `_d_i`.
///
/// This functions sets `derivedL` and `derivedR` arguments to effective
/// derived expressions.
static void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff,
clang::Expr*& derivedL,
clang::Expr*& derivedR) {
derivedL = LDiff.getExpr_dx();
derivedR = RDiff.getExpr_dx();
if (utils::isArrayOrPointerType(LDiff.getExpr_dx()->getType()) &&
!utils::isArrayOrPointerType(RDiff.getExpr_dx()->getType())) {
derivedL = LDiff.getExpr_dx();
derivedR = RDiff.getExpr();
} else if (utils::isArrayOrPointerType(RDiff.getExpr_dx()->getType()) &&
!utils::isArrayOrPointerType(LDiff.getExpr_dx()->getType())) {
derivedL = LDiff.getExpr();
derivedR = RDiff.getExpr_dx();
}
}

StmtDiff
BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {
StmtDiff Ldiff = Visit(BinOp->getLHS());
Expand Down
3 changes: 2 additions & 1 deletion lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,8 @@ namespace clad {
/// be more complex than just a DeclRefExpr.
/// (e.g. `__real (n++ ? z1 : z2)`)
m_Exprs.push_back(UnOp);
}
} else if (opCode == UnaryOperatorKind::UO_Deref)
m_Exprs.push_back(UnOp);
}

void VisitDeclRefExpr(clang::DeclRefExpr* DRE) {
Expand Down
154 changes: 114 additions & 40 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1340,7 +1340,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Create the (_d_param[idx] += dfdx) statement.
if (dfdx()) {
// FIXME: not sure if this is generic.
// Don't update derivatives of non-record types.
// Don't update derivatives of record types.
if (!VD->getType()->isRecordType()) {
auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx());
// Add it to the body statements.
Expand Down Expand Up @@ -2035,6 +2035,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If it is a post-increment/decrement operator, its result is a reference
// and we should return it.
Expr* ResultRef = nullptr;

// For increment/decrement of pointer, perform the same on the
// derivative pointer also.
bool isPointerOp = E->getType()->isPointerType();

if (opCode == UO_Plus)
// xi = +xj
// dxi/dxj = +1.0
Expand All @@ -2048,10 +2053,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
diff = Visit(E, d);
} else if (opCode == UO_PostInc || opCode == UO_PostDec) {
diff = Visit(E, dfdx());
if (isPointerOp)
addToCurrentBlock(BuildOp(opCode, diff.getExpr_dx()),
direction::forward);
if (UsefulToStoreGlobal(diff.getRevSweepAsExpr())) {
auto op = opCode == UO_PostInc ? UO_PostDec : UO_PostInc;
addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepAsExpr())),
direction::reverse);
if (isPointerOp)
addToCurrentBlock(BuildOp(op, diff.getExpr_dx()), direction::reverse);
}

ResultRef = diff.getExpr_dx();
Expand All @@ -2060,10 +2070,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_ExternalSource->ActBeforeFinalisingPostIncDecOp(diff);
} else if (opCode == UO_PreInc || opCode == UO_PreDec) {
diff = Visit(E, dfdx());
if (isPointerOp)
addToCurrentBlock(BuildOp(opCode, diff.getExpr_dx()),
direction::forward);
if (UsefulToStoreGlobal(diff.getRevSweepAsExpr())) {
auto op = opCode == UO_PreInc ? UO_PreDec : UO_PreInc;
addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepAsExpr())),
direction::reverse);
if (isPointerOp)
addToCurrentBlock(BuildOp(op, diff.getExpr_dx()), direction::reverse);
}
auto op = opCode == UO_PreInc ? BinaryOperatorKind::BO_Add
: BinaryOperatorKind::BO_Sub;
Expand All @@ -2081,35 +2096,38 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Add it to the body statements.
addToCurrentBlock(add_assign, direction::reverse);
}
} else {
// FIXME: This is not adding 'address-of' operator support.
// This is just making this special case differentiable that is required
// for computing hessian:
// ```
// Class _d_this_obj;
// Class* _d_this = &_d_this_obj;
// ```
// This code snippet should be removed once reverse mode officially
// supports pointers.
if (opCode == UnaryOperatorKind::UO_AddrOf) {
if (const auto* MD = dyn_cast<CXXMethodDecl>(m_Function)) {
if (MD->isInstance()) {
auto thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD);
if (utils::SameCanonicalType(thisType, UnOp->getType())) {
diff = Visit(E);
Expr* cloneE =
BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr());
Expr* derivedE =
BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr_dx());
return {cloneE, derivedE};
}
}
} else if (opCode == UnaryOperatorKind::UO_AddrOf) {
diff = Visit(E);
Expr* cloneE = BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr());
Expr* derivedE = BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr_dx());
return {cloneE, derivedE};
} else if (opCode == UnaryOperatorKind::UO_Deref) {
diff = Visit(E);
Expr* cloneE = BuildOp(UnaryOperatorKind::UO_Deref, diff.getExpr());
Expr* diff_dx = diff.getExpr_dx();
bool specialDThisCase = false;
Expr* derivedE = nullptr;
if (const auto* MD = dyn_cast<CXXMethodDecl>(m_Function)) {
if (MD->isInstance() && !diff_dx->getType()->isPointerType())
specialDThisCase = true; // _d_this is already dereferenced.
}
if (specialDThisCase)
derivedE = diff_dx;
else {
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, diff_dx);
// Create the (target += dfdx) statement.
if (dfdx()) {
auto* add_assign = BuildOp(BO_AddAssign, derivedE, dfdx());
// Add it to the body statements.
addToCurrentBlock(add_assign, direction::reverse);
}
}
// We should not output any warning on visiting boolean conditions
// FIXME: We should support boolean differentiation or ignore it
// completely
return {cloneE, derivedE};
} else {
if (opCode != UO_LNot)
// We should not output any warning on visiting boolean conditions
// FIXME: We should support boolean differentiation or ignore it
// completely
unsupportedOpWarn(UnOp->getEndLoc());

if (isa<DeclRefExpr>(E))
Expand All @@ -2134,6 +2152,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// we should return it.
Expr* ResultRef = nullptr;

bool isPointerOp =
L->getType()->isPointerType() || R->getType()->isPointerType();

if (opCode == BO_Add) {
// xi = xl + xr
// dxi/xl = 1.0
Expand Down Expand Up @@ -2306,6 +2327,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto* Lblock = endBlock(direction::reverse);
llvm::SmallVector<Expr*, 4> ExprsToStore;
utils::GetInnermostReturnExpr(Ldiff.getExpr(), ExprsToStore);

// We need to store values of derivative pointer variables in forward pass
// and restore them in reverese pass.
if (isPointerOp) {
Expr* Edx = Ldiff.getExpr_dx();
ExprsToStore.push_back(Edx);
}

if (L->HasSideEffects(m_Context)) {
Expr* E = Ldiff.getExpr();
auto* storeE =
Expand Down Expand Up @@ -2352,24 +2381,32 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// Save old value for the derivative of LHS, to avoid problems with cases
// like x = x.
auto* oldValue = StoreAndRef(AssignedDiff, direction::reverse, "_r_d",
/*forceDeclCreation=*/true);
clang::Expr* oldValue = nullptr;

// For pointer types, no need to store old derivatives.
if (!isPointerOp)
oldValue = StoreAndRef(AssignedDiff, direction::reverse, "_r_d",
/*forceDeclCreation=*/true);

if (opCode == BO_Assign) {
Rdiff = Visit(R, oldValue);
valueForRevPass = Rdiff.getRevSweepAsExpr();
} else if (opCode == BO_AddAssign) {
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue),
direction::reverse);
if (!isPointerOp)
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue),
direction::reverse);
Rdiff = Visit(R, oldValue);
valueForRevPass = BuildOp(BO_Add, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
if (!isPointerOp)
valueForRevPass = BuildOp(BO_Add, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
} else if (opCode == BO_SubAssign) {
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue),
direction::reverse);
if (!isPointerOp)
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue),
direction::reverse);
Rdiff = Visit(R, BuildOp(UO_Minus, oldValue));
valueForRevPass = BuildOp(BO_Sub, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
if (!isPointerOp)
valueForRevPass = BuildOp(BO_Sub, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
} else if (opCode == BO_MulAssign) {
// Create a reference variable to keep the result of LHS, since it
// must be used on 2 places: when storing to a global variable
Expand Down Expand Up @@ -2427,8 +2464,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActBeforeFinalisingAssignOp(LCloned, oldValue);

// Update the derivative.
addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue), direction::reverse);
// Update the derivative only if LHS is not a pointer type.
if (!isPointerOp)
addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue),
direction::reverse);

// Output statements from Visit(L).
for (auto it = Lblock_begin; it != Lblock_end; ++it)
addToCurrentBlock(*it, direction::reverse);
Expand Down Expand Up @@ -2460,6 +2500,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return BuildOp(opCode, LExpr, RExpr);
}
Expr* op = BuildOp(opCode, Ldiff.getExpr(), Rdiff.getExpr());

// For pointer types.
if (isPointerOp) {
if (opCode == BO_Add || opCode == BO_Sub) {
Expr* derivedL = nullptr;
Expr* derivedR = nullptr;
ComputeEffectiveDOperands(Ldiff, Rdiff, derivedL, derivedR);
if (opCode == BO_Sub)
derivedR = BuildParens(derivedR);
return StmtDiff(op, BuildOp(opCode, derivedL, derivedR), nullptr,
valueForRevPass);
}
if (opCode == BO_Assign || opCode == BO_AddAssign ||
opCode == BO_SubAssign) {
Expr* derivedL = nullptr;
Expr* derivedR = nullptr;
ComputeEffectiveDOperands(Ldiff, Rdiff, derivedL, derivedR);
addToCurrentBlock(BuildOp(opCode, derivedL, derivedR),
direction::forward);
}
}
return StmtDiff(op, ResultRef, nullptr, valueForRevPass);
}

Expand All @@ -2469,6 +2530,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto VDDerivedType = ComputeAdjointType(VD->getType());
bool isDerivativeOfRefType = VD->getType()->isReferenceType();
VarDecl* VDDerived = nullptr;
bool isPointerType = VD->getType()->isPointerType();

// VDDerivedInit now serves two purposes -- as the initial derivative value
// or the size of the derivative array -- depending on the primal type.
Expand Down Expand Up @@ -2529,6 +2591,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (initDiff.getExpr_dx())
VDDerivedInit = initDiff.getExpr_dx();
}
// if VD is a pointer type, then the initial value is set to the derived
// expression of the corresponding pointer type.
else if (isPointerType && VD->getInit()) {
initDiff = Visit(VD->getInit());
if (initDiff.getExpr_dx())
VDDerivedInit = initDiff.getExpr_dx();
}
// Here separate behaviour for record and non-record types is only
// necessary to preserve the old tests.
if (VDDerivedType->isRecordType())
Expand All @@ -2546,7 +2615,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// differentiated and should not be differentiated again.
// If `VD` is a reference to a non-local variable then also there's no
// need to call `Visit` since non-local variables are not differentiated.
if (!isDerivativeOfRefType) {
if (!isDerivativeOfRefType && !isPointerType) {
Expr* derivedE = BuildDeclRef(VDDerived);
initDiff = StmtDiff{};
if (VD->getInit()) {
Expand Down Expand Up @@ -2824,6 +2893,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If TBR analysis is off, assume E is useful to store.
if (!enableTBR)
return true;
// FIXME: currently, we allow all pointer operations to be stored.
// This is not correct, but we need to implement a more advanced analysis
// to determine which pointer operations are useful to store.
if (E->getType()->isPointerType())
return true;
auto found = m_ToBeRecorded.find(B->getBeginLoc());
return found != m_ToBeRecorded.end();
}
Expand Down
Loading

0 comments on commit 2aa5dcb

Please sign in to comment.