diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 9e0394c8a..8e12af090 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -47,8 +47,6 @@ class BaseForwardModeVisitor virtual void ExecuteInsidePushforwardFunctionBlock(); - static bool IsDifferentiableType(clang::QualType T); - virtual StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE); StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp); diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index 6e3d88a8a..d51a5c253 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -334,8 +334,6 @@ namespace clad { bool IsMemoryFunction(const clang::FunctionDecl* FD); bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD); - bool IsDifferentiableType(clang::QualType QT); - /// Removes the local const qualifiers from a QualType and returns a new /// type. clang::QualType getNonConstType(clang::QualType T, clang::ASTContext& C, diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 04a67ca86..d38d3c089 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -208,6 +208,8 @@ namespace clad { return QT->isArrayType() || QT->isPointerType(); } + static bool IsDifferentiableType(clang::QualType T); + clang::CompoundStmt* MakeCompoundStmt(const Stmts& Stmts); /// Get the latest block of code (i.e. place for statements output). diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index d0e53581a..6f2dd6962 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -38,21 +38,6 @@ BaseForwardModeVisitor::BaseForwardModeVisitor(DerivativeBuilder& builder) BaseForwardModeVisitor::~BaseForwardModeVisitor() {} -bool BaseForwardModeVisitor::IsDifferentiableType(QualType T) { - QualType origType = T; - // FIXME: arbitrary dimension array type as well. - while (utils::isArrayOrPointerType(T)) - T = utils::GetValueType(T); - T = T.getNonReferenceType(); - if (T->isEnumeralType()) - return false; - if (T->isRealType() || T->isStructureOrClassType()) - return true; - if (origType->isPointerType() && T->isVoidType()) - return true; - return false; -} - bool IsRealNonReferenceType(QualType T) { return T.getNonReferenceType()->isRealType(); } @@ -224,7 +209,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, // non-reference type for creating the derivatives. QualType dParamType = param->getType().getNonReferenceType(); // We do not create derived variable for array/pointer parameters. - if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType) || + if (!IsDifferentiableType(dParamType) || utils::isArrayOrPointerType(dParamType)) continue; Expr* dParam = nullptr; @@ -420,7 +405,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, for (auto* PVD : m_Function->parameters()) { paramTypes.push_back(PVD->getType()); - if (BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) + if (IsDifferentiableType(PVD->getType())) derivedParamTypes.push_back(GetPushForwardDerivativeType(PVD->getType())); } @@ -485,7 +470,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, if (identifierMissing) m_DeclReplacements[PVD] = newPVD; - if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) + if (!IsDifferentiableType(PVD->getType())) continue; auto derivedPVDName = "_d_" + std::string(PVDII->getName()); IdentifierInfo* derivedPVDII = CreateUniqueIdentifier(derivedPVDName); @@ -1069,7 +1054,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { } } CallArgs.push_back(argDiff.getExpr()); - if (BaseForwardModeVisitor::IsDifferentiableType(arg->getType())) { + if (IsDifferentiableType(arg->getType())) { Expr* dArg = argDiff.getExpr_dx(); // FIXME: What happens when dArg is nullptr? diffArgs.push_back(dArg); diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index f447db52e..d4a2ba0ee 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -685,11 +685,6 @@ namespace clad { #endif } - bool IsDifferentiableType(clang::QualType QT) { - // FIXME: consider analysing object types with this - return !utils::GetValueType(QT)->isIntegerType(); - } - clang::QualType getNonConstType(clang::QualType T, clang::ASTContext& C, clang::Sema& S) { clang::Qualifiers quals(T.getQualifiers()); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 46a694108..7bb084267 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -271,7 +271,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (request.Args) { DVI = request.DVI; for (const auto& dParam : DVI) - if (utils::IsDifferentiableType(dParam.param->getType())) + // no need to create adjoints for non-differentiable parameters. + if (IsDifferentiableType(dParam.param->getType())) args.push_back(dParam.param); } else @@ -600,7 +601,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (std::size_t i = 0; i < m_Function->getNumParams(); ++i) { ParmVarDecl* param = paramsRef[i]; // no need to create adjoints for non-differentiable variables. - if (!utils::IsDifferentiableType(param->getType())) + if (!IsDifferentiableType(param->getType())) continue; // derived variables are already created for independent variables. if (m_Variables.count(param)) @@ -1545,7 +1546,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // modified by the derived callee function. // Also, no need to create adjoint variables for non-differentiable types. if (utils::IsReferenceOrPointerArg(arg) || - !utils::IsDifferentiableType(arg->getType())) { + !IsDifferentiableType(arg->getType())) { argDiff = Visit(arg); CallArgDx.push_back(argDiff.getExpr_dx()); } else { @@ -2582,7 +2583,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Integer types are not differentiable, // no need to construct an adjoint. - if (!utils::IsDifferentiableType(VD->getType())) { + if (!IsDifferentiableType(VD->getType())) { Expr* init = nullptr; if (VD->getInit()) init = Visit(VD->getInit()).getExpr(); diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 694d293f8..0390a7dda 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -902,7 +902,7 @@ namespace clad { m_Mode != DiffMode::experimental_pushforward) for (size_t i = 0, e = originalFD->getNumParams(); i < e; ++i) { QualType paramTy = originalFD->getParamDecl(i)->getType(); - if (!utils::IsDifferentiableType(paramTy)) { + if (!IsDifferentiableType(paramTy)) { QualType argTy = utils::getNonConstType(paramTy, m_Context, m_Sema); VarDecl* argDecl = BuildVarDecl(argTy, "_r", getZeroInit(argTy)); Expr* arg = BuildDeclRef(argDecl); @@ -961,4 +961,19 @@ namespace clad { } return false; } + + bool VisitorBase::IsDifferentiableType(QualType T) { + QualType origType = T; + // FIXME: arbitrary dimension array type as well. + while (utils::isArrayOrPointerType(T)) + T = utils::GetValueType(T); + T = T.getNonReferenceType(); + if (T->isEnumeralType()) + return false; + if (T->isFloatingType() || T->isStructureOrClassType()) + return true; + if (origType->isPointerType() && T->isVoidType()) + return true; + return false; + } } // end namespace clad