From 4e9f2b013cbf4b6dfb39c848b0305490a1e2480a Mon Sep 17 00:00:00 2001 From: Mihail Mihov Date: Thu, 30 May 2024 17:01:25 +0300 Subject: [PATCH] Reword reverse-mode non-differentiable attribute --- lib/Differentiator/ReverseModeVisitor.cpp | 56 +++++++++++++++++++++-- test/ReverseMode/NonDifferentiable.C | 4 +- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index f596cdfb0..18c11d4f5 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -482,7 +482,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, DiffParams args{}; if (!request.DVI.empty()) for (const auto& dParam : request.DVI) - args.push_back(dParam.param); + args.push_back(dParam.param); else std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); #ifndef NDEBUG @@ -1387,8 +1387,25 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(Clone(CE)); } - if (clad::utils::hasNonDifferentiableAttribute(CE)) - return { Clone(CE), nullptr }; + SourceLocation validLoc { CE->getBeginLoc() }; + // If the function is non_differentiable, return zero derivative. + if (clad::utils::hasNonDifferentiableAttribute(CE)) { + // Calling the function without computing derivatives + llvm::SmallVector ClonedArgs; + for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i) + ClonedArgs.push_back(Clone(CE->getArg(i))); + + Expr* Call = m_Sema + .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), + validLoc, ClonedArgs, validLoc) + .get(); + // Creating a zero derivative + auto* zero = + ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); + + // Returning the function call and zero derivative + return StmtDiff(Call, zero); + } auto NArgs = FD->getNumParams(); // If the function has no args and is not a member function call then we @@ -2793,6 +2810,33 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SmallVector declsDiff; // Need to put array decls inlined. llvm::SmallVector localDeclsDiff; + + // If the type is marked as non_differentiable, skip generating its derivative + // Get the iterator + const auto* declsBegin = DS->decls().begin(); + const auto* declsEnd = DS->decls().end(); + + // If the DeclStmt is not empty, check the first declaration. + if (declsBegin != declsEnd && isa(*declsBegin)) { + auto* VD = dyn_cast(*declsBegin); + // Check for non-differentiable types. + QualType QT = VD->getType(); + if (QT->isPointerType()) + QT = QT->getPointeeType(); + auto* typeDecl = QT->getAsCXXRecordDecl(); + if (typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl)) { + for (auto* D : DS->decls()) { + if (auto* VD = dyn_cast(D)) + decls.push_back(VD); + else + diag(DiagnosticsEngine::Warning, D->getEndLoc(), + "Unsupported declaration"); + } + Stmt* DSClone = BuildDeclStmt(decls); + return StmtDiff(DSClone, nullptr); + } + } + // reverse_mode_forward_pass does not have a reverse pass so declarations // don't have to be moved to the function global scope. bool promoteToFnScope = !getCurrentScope()->isFunctionScope() && @@ -2975,8 +3019,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, "CXXMethodDecl nodes not supported yet!"); MemberExpr* clonedME = utils::BuildMemberExpr( m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName()); + auto zero = + ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, m_Context, 0); if (clad::utils::hasNonDifferentiableAttribute(ME)) - return { clonedME, nullptr }; + return {clonedME, zero}; if (!baseDiff.getExpr_dx()) return {clonedME, nullptr}; MemberExpr* derivedME = utils::BuildMemberExpr( @@ -4003,7 +4049,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CreateUniqueIdentifier("_d_" + PVD->getNameAsString()); auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType, PVD->getStorageClass()); - paramDerivatives.push_back(dPVD); + paramDerivatives.push_back(dPVD); ++dParamTypesIdx; if (dPVD->getIdentifier()) diff --git a/test/ReverseMode/NonDifferentiable.C b/test/ReverseMode/NonDifferentiable.C index 4731bda33..bce3e2b1a 100644 --- a/test/ReverseMode/NonDifferentiable.C +++ b/test/ReverseMode/NonDifferentiable.C @@ -13,9 +13,9 @@ public: SimpleFunctions1() noexcept : x(0), y(0) {} SimpleFunctions1(double p_x, double p_y) noexcept : x(p_x), y(p_y) {} double x; - non_differentiable double y; + double y; double mem_fn_1(double i, double j) { return (x + y) * i + i * j * j; } - non_differentiable double mem_fn_2(double i, double j) { return i * j; } + double mem_fn_2(double i, double j) { return i * j; } double mem_fn_3(double i, double j) { return mem_fn_1(i, j) + i * j; } double mem_fn_4(double i, double j) { return mem_fn_2(i, j) + i * j; } double mem_fn_5(double i, double j) { return mem_fn_2(i, j) * mem_fn_1(i, j) * i; }