Skip to content

Commit

Permalink
Reword reverse-mode non-differentiable attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
MihailMihov committed May 30, 2024
1 parent 7553ad9 commit 4e9f2b0
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 7 deletions.
56 changes: 51 additions & 5 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
DiffParams args{};
if (!request.DVI.empty())
for (const auto& dParam : request.DVI)
args.push_back(dParam.param);
args.push_back(dParam.param);
else
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));
#ifndef NDEBUG
Expand Down Expand Up @@ -1387,8 +1387,25 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(Clone(CE));
}

if (clad::utils::hasNonDifferentiableAttribute(CE))
return { Clone(CE), nullptr };
SourceLocation validLoc { CE->getBeginLoc() };
// If the function is non_differentiable, return zero derivative.
if (clad::utils::hasNonDifferentiableAttribute(CE)) {
// Calling the function without computing derivatives
llvm::SmallVector<Expr*, 4> ClonedArgs;
for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i)
ClonedArgs.push_back(Clone(CE->getArg(i)));

Expr* Call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
validLoc, ClonedArgs, validLoc)
.get();
// Creating a zero derivative
auto* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);

// Returning the function call and zero derivative
return StmtDiff(Call, zero);
}

auto NArgs = FD->getNumParams();
// If the function has no args and is not a member function call then we
Expand Down Expand Up @@ -2793,6 +2810,33 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
llvm::SmallVector<Decl*, 4> declsDiff;
// Need to put array decls inlined.
llvm::SmallVector<Decl*, 4> localDeclsDiff;

// If the type is marked as non_differentiable, skip generating its derivative
// Get the iterator
const auto* declsBegin = DS->decls().begin();
const auto* declsEnd = DS->decls().end();

// If the DeclStmt is not empty, check the first declaration.
if (declsBegin != declsEnd && isa<VarDecl>(*declsBegin)) {
auto* VD = dyn_cast<VarDecl>(*declsBegin);
// Check for non-differentiable types.
QualType QT = VD->getType();
if (QT->isPointerType())
QT = QT->getPointeeType();
auto* typeDecl = QT->getAsCXXRecordDecl();
if (typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl)) {
for (auto* D : DS->decls()) {
if (auto* VD = dyn_cast<VarDecl>(D))
decls.push_back(VD);
else
diag(DiagnosticsEngine::Warning, D->getEndLoc(),
"Unsupported declaration");
}
Stmt* DSClone = BuildDeclStmt(decls);
return StmtDiff(DSClone, nullptr);
}
}

// reverse_mode_forward_pass does not have a reverse pass so declarations
// don't have to be moved to the function global scope.
bool promoteToFnScope = !getCurrentScope()->isFunctionScope() &&
Expand Down Expand Up @@ -2975,8 +3019,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
"CXXMethodDecl nodes not supported yet!");
MemberExpr* clonedME = utils::BuildMemberExpr(
m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName());
auto zero =
ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, m_Context, 0);
if (clad::utils::hasNonDifferentiableAttribute(ME))
return { clonedME, nullptr };
return {clonedME, zero};
if (!baseDiff.getExpr_dx())
return {clonedME, nullptr};
MemberExpr* derivedME = utils::BuildMemberExpr(
Expand Down Expand Up @@ -4003,7 +4049,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
CreateUniqueIdentifier("_d_" + PVD->getNameAsString());
auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType,
PVD->getStorageClass());
paramDerivatives.push_back(dPVD);
paramDerivatives.push_back(dPVD);
++dParamTypesIdx;

if (dPVD->getIdentifier())
Expand Down
4 changes: 2 additions & 2 deletions test/ReverseMode/NonDifferentiable.C
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ public:
SimpleFunctions1() noexcept : x(0), y(0) {}
SimpleFunctions1(double p_x, double p_y) noexcept : x(p_x), y(p_y) {}
double x;
non_differentiable double y;
double y;
double mem_fn_1(double i, double j) { return (x + y) * i + i * j * j; }
non_differentiable double mem_fn_2(double i, double j) { return i * j; }
double mem_fn_2(double i, double j) { return i * j; }
double mem_fn_3(double i, double j) { return mem_fn_1(i, j) + i * j; }
double mem_fn_4(double i, double j) { return mem_fn_2(i, j) + i * j; }
double mem_fn_5(double i, double j) { return mem_fn_2(i, j) * mem_fn_1(i, j) * i; }
Expand Down

0 comments on commit 4e9f2b0

Please sign in to comment.