Skip to content

Commit

Permalink
Improve test coverage for pointer support
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Dec 20, 2023
1 parent eaa534a commit d44ce32
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 43 deletions.
30 changes: 30 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,36 @@ namespace clad {
/// store a pointer to their size expression.
clang::QualType CloneType(clang::QualType T);

/// Computes effective derivative operands. It should be used when operands
/// might be of pointer types.
///
/// In the trivial case, both operands are of non-pointer types, and the
/// effective derivative operands are `LDiff.getExpr_dx()` and
/// `RDiff.getExpr_dx()` respectively.
///
/// Integers used in pointer arithmetic should be considered
/// non-differentiable entities. For example:
///
/// ```
/// p + i;
/// ```
///
/// Derived statement should be:
///
/// ```
/// _d_p + i;
/// ```
///
/// instead of:
///
/// ```
/// _d_p + _d_i;
/// ```
///
/// Therefore, effective derived expression of `i` is `i` instead of `_d_i`.
///
/// This functions sets `derivedL` and `derivedR` arguments to effective
/// derived expressions.
static void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff,
clang::Expr*& derivedL,
clang::Expr*& derivedR);
Expand Down
33 changes: 20 additions & 13 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2060,6 +2060,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto op = opCode == UO_PostInc ? UO_PostDec : UO_PostInc;
addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepAsExpr())),
direction::reverse);
if (isPointerOp)
addToCurrentBlock(BuildOp(op, diff.getExpr_dx()), direction::reverse);

Check warning on line 2064 in lib/Differentiator/ReverseModeVisitor.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/ReverseModeVisitor.cpp#L2064

Added line #L2064 was not covered by tests
}

ResultRef = diff.getExpr_dx();
Expand Down Expand Up @@ -2121,15 +2123,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}
return {cloneE, derivedE};
} else if (opCode != UO_LNot) {
// We should not output any warning on visiting boolean conditions
// FIXME: We should support boolean differentiation or ignore it
// completely
unsupportedOpWarn(UnOp->getEndLoc());
} else if (isa<DeclRefExpr>(E))
diff = Visit(E);
else
diff = StmtDiff(E);
} else {
if (opCode != UO_LNot)
// We should not output any warning on visiting boolean conditions
// FIXME: We should support boolean differentiation or ignore it
// completely
unsupportedOpWarn(UnOp->getEndLoc());

if (isa<DeclRefExpr>(E))
diff = Visit(E);
else
diff = StmtDiff(E);
}
Expr* op = BuildOp(opCode, diff.getExpr());
return StmtDiff(op, ResultRef, nullptr, valueForRevPass);
}
Expand Down Expand Up @@ -2391,15 +2396,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue),
direction::reverse);
Rdiff = Visit(R, oldValue);
valueForRevPass = BuildOp(BO_Add, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
if (!isPointerOp)
valueForRevPass = BuildOp(BO_Add, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
} else if (opCode == BO_SubAssign) {
if (!isPointerOp)
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue),
direction::reverse);
Rdiff = Visit(R, BuildOp(UO_Minus, oldValue));
valueForRevPass = BuildOp(BO_Sub, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
if (!isPointerOp)
valueForRevPass = BuildOp(BO_Sub, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
} else if (opCode == BO_MulAssign) {
// Create a reference variable to keep the result of LHS, since it
// must be used on 2 places: when storing to a global variable
Expand Down
30 changes: 0 additions & 30 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -782,36 +782,6 @@ namespace clad {
return TAL.get(0).getAsType();
}

/// Computes effective derivative operands. It should be used when operands
/// might be of pointer types.
///
/// In the trivial case, both operands are of non-pointer types, and the
/// effective derivative operands are `LDiff.getExpr_dx()` and
/// `RDiff.getExpr_dx()` respectively.
///
/// Integers used in pointer arithmetic should be considered
/// non-differentiable entities. For example:
///
/// ```
/// p + i;
/// ```
///
/// Derived statement should be:
///
/// ```
/// _d_p + i;
/// ```
///
/// instead of:
///
/// ```
/// _d_p + _d_i;
/// ```
///
/// Therefore, effective derived expression of `i` is `i` instead of `_d_i`.
///
/// This functions sets `derivedL` and `derivedR` arguments to effective
/// derived expressions.
void VisitorBase::ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff,
clang::Expr*& derivedL,
clang::Expr*& derivedR) {
Expand Down

0 comments on commit d44ce32

Please sign in to comment.