From b00c0c1e3e7e308d95e6f7648da2875ea244e42a Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Thu, 25 Apr 2024 13:17:05 +0200 Subject: [PATCH] Moving DerivedFnCollector out of ClangPlugin --- .../clad/Differentiator/DerivativeBuilder.h | 13 ++- .../clad/Differentiator/DerivedFnCollector.h | 40 +++++++++ include/clad/Differentiator/DerivedFnInfo.h | 49 +++++++++++ lib/Differentiator/BaseForwardModeVisitor.cpp | 25 +++--- lib/Differentiator/CMakeLists.txt | 2 + lib/Differentiator/DerivativeBuilder.cpp | 88 +++++++++++-------- lib/Differentiator/DerivedFnCollector.cpp | 39 ++++++++ .../Differentiator}/DerivedFnInfo.cpp | 24 +++-- lib/Differentiator/ReverseModeVisitor.cpp | 61 +++++++------ tools/CMakeLists.txt | 1 - tools/ClangPlugin.cpp | 37 +------- tools/ClangPlugin.h | 45 +++------- tools/DerivedFnInfo.h | 47 ---------- 13 files changed, 262 insertions(+), 209 deletions(-) create mode 100644 include/clad/Differentiator/DerivedFnCollector.h create mode 100644 include/clad/Differentiator/DerivedFnInfo.h create mode 100644 lib/Differentiator/DerivedFnCollector.cpp rename {tools => lib/Differentiator}/DerivedFnInfo.cpp (62%) delete mode 100644 tools/DerivedFnInfo.h diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index d51280bb1..efe369500 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -8,10 +8,11 @@ #define CLAD_DERIVATIVE_BUILDER_H #include "Compatibility.h" -#include "clad/Differentiator/DiffPlanner.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/AST/StmtVisitor.h" #include "clang/Sema/Sema.h" +#include "clad/Differentiator/DerivedFnCollector.h" +#include "clad/Differentiator/DiffPlanner.h" #include #include @@ -85,6 +86,7 @@ namespace clad { clang::Sema& m_Sema; plugin::CladPlugin& m_CladPlugin; clang::ASTContext& m_Context; + const DerivedFnCollector& m_DFC; std::unique_ptr m_NodeCloner; clang::NamespaceDecl* m_BuiltinDerivativesNSD; /// A reference to the model to use for error estimation (if any). @@ -134,7 +136,8 @@ namespace clad { } public: - DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P); + DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, + const DerivedFnCollector& DFC); ~DerivativeBuilder(); /// Reset the model use for error estimation (if any). /// \param[in] estModel The error estimation model, can be either @@ -163,6 +166,12 @@ namespace clad { /// context. /// DerivativeAndOverload Derive(const DiffRequest& request); + /// Find the derived function if present in the DerivedFnCollector. + /// + /// \param[in] request The request to find the derived function. + /// + /// \returns The derived function if found, nullptr otherwise. + clang::FunctionDecl* FindDerivedFunction(const DiffRequest& request); }; } // end namespace clad diff --git a/include/clad/Differentiator/DerivedFnCollector.h b/include/clad/Differentiator/DerivedFnCollector.h new file mode 100644 index 000000000..b0d297176 --- /dev/null +++ b/include/clad/Differentiator/DerivedFnCollector.h @@ -0,0 +1,40 @@ +#ifndef CLAD_DIFFERENTIATOR_DERIVEDFNCOLLECTOR_H +#define CLAD_DIFFERENTIATOR_DERIVEDFNCOLLECTOR_H + +#include "clad/Differentiator/DerivedFnInfo.h" + +#include "clang/AST/Decl.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" + +namespace clad { +/// This class is designed to store collection of `DerivedFnInfo` objects. +/// It's purpose is to avoid repeated generation of same derivatives by +/// making it possible to reuse previously computed derivatives. +class DerivedFnCollector { + using DerivedFns = llvm::SmallVector; + /// Mapping to efficiently find out information about all the derivatives of + /// a function. + llvm::DenseMap + m_DerivedFnInfoCollection; + +public: + /// Adds a derived function to the collection. + void Add(const DerivedFnInfo& DFI); + + /// Finds a `DerivedFnInfo` object in the collection that satisfies the + /// given differentiation request. + DerivedFnInfo Find(const DiffRequest& request) const; + + bool IsDerivative(const clang::FunctionDecl* FD) const; + +private: + /// Returns true if the collection already contains a `DerivedFnInfo` + /// object that represents the same derivative object as the provided + /// argument `DFI`. + bool AlreadyExists(const DerivedFnInfo& DFI) const; +}; +} // namespace clad + +#endif // CLAD_DIFFERENTIATOR_DERIVEDFNCOLLECTOR_H \ No newline at end of file diff --git a/include/clad/Differentiator/DerivedFnInfo.h b/include/clad/Differentiator/DerivedFnInfo.h new file mode 100644 index 000000000..b2a95598e --- /dev/null +++ b/include/clad/Differentiator/DerivedFnInfo.h @@ -0,0 +1,49 @@ +#ifndef CLAD_DIFFERENTIATOR_DERIVEDFNINFO_H +#define CLAD_DIFFERENTIATOR_DERIVEDFNINFO_H + +#include "clang/AST/Decl.h" +#include "clad/Differentiator/DiffMode.h" +#include "clad/Differentiator/ParseDiffArgsTypes.h" + +namespace clad { +struct DiffRequest; + +/// `DerivedFnInfo` is designed to effectively store information about a +/// derived function. +struct DerivedFnInfo { + const clang::FunctionDecl* m_OriginalFn = nullptr; + clang::FunctionDecl* m_DerivedFn = nullptr; + clang::FunctionDecl* m_OverloadedDerivedFn = nullptr; + DiffMode m_Mode = DiffMode::unknown; + unsigned m_DerivativeOrder = 0; + DiffInputVarsInfo m_DiffVarsInfo; + bool m_UsesEnzyme = false; + bool m_DeclarationOnly = false; + + DerivedFnInfo() = default; + DerivedFnInfo(const DiffRequest& request, clang::FunctionDecl* derivedFn, + clang::FunctionDecl* overloadedDerivedFn); + + /// Returns true if the derived function represented by the object, + /// satisfies the requirements of the given differentiation request. + bool SatisfiesRequest(const DiffRequest& request) const; + + /// Returns true if the object represents any derived function; otherwise + /// returns false. + bool IsValid() const; + + const clang::FunctionDecl* OriginalFn() const { return m_OriginalFn; } + clang::FunctionDecl* DerivedFn() const { return m_DerivedFn; } + clang::FunctionDecl* OverloadedDerivedFn() const { + return m_OverloadedDerivedFn; + } + bool DeclarationOnly() const { return m_DeclarationOnly; } + + /// Returns true if `lhs` and `rhs` represents same derivative. + /// Here derivative is any function derived by clad. + static bool RepresentsSameDerivative(const DerivedFnInfo& lhs, + const DerivedFnInfo& rhs); +}; +} // namespace clad + +#endif // CLAD_DIFFERENTIATOR_DERIVEDFNINFO_H \ No newline at end of file diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 2d2c06609..acad4b3a1 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1150,21 +1150,26 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { DiffRequest pushforwardFnRequest; pushforwardFnRequest.Function = FD; pushforwardFnRequest.Mode = GetPushForwardMode(); - pushforwardFnRequest.BaseFunctionName = FD->getNameAsString(); + pushforwardFnRequest.BaseFunctionName = utils::ComputeEffectiveFnName(FD); // pushforwardFnRequest.RequestedDerivativeOrder = m_DerivativeOrder; // Silence diag outputs in nested derivation process. pushforwardFnRequest.VerboseDiags = false; - // Derive declaration of the pushforward function. - pushforwardFnRequest.DeclarationOnly = true; + // Check if request already derived in DerivedFunctions. FunctionDecl* pushforwardFD = - plugin::ProcessDiffRequest(m_CladPlugin, pushforwardFnRequest); - - // Add the request to derive the definition of the pushforward function - // into the queue. - pushforwardFnRequest.DeclarationOnly = false; - pushforwardFnRequest.DerivedFDPrototype = pushforwardFD; - plugin::AddRequestToSchedule(m_CladPlugin, pushforwardFnRequest); + m_Builder.FindDerivedFunction(pushforwardFnRequest); + if (!pushforwardFD) { + // Derive declaration of the pushforward function. + pushforwardFnRequest.DeclarationOnly = true; + pushforwardFD = + plugin::ProcessDiffRequest(m_CladPlugin, pushforwardFnRequest); + + // Add the request to derive the definition of the pushforward function + // into the queue. + pushforwardFnRequest.DeclarationOnly = false; + pushforwardFnRequest.DerivedFDPrototype = pushforwardFD; + plugin::AddRequestToSchedule(m_CladPlugin, pushforwardFnRequest); + } if (pushforwardFD) { if (baseDiff.getExpr()) { diff --git a/lib/Differentiator/CMakeLists.txt b/lib/Differentiator/CMakeLists.txt index 0398df321..561019342 100644 --- a/lib/Differentiator/CMakeLists.txt +++ b/lib/Differentiator/CMakeLists.txt @@ -25,6 +25,8 @@ llvm_add_library(cladDifferentiator CladUtils.cpp ConstantFolder.cpp DerivativeBuilder.cpp + DerivedFnCollector.cpp + DerivedFnInfo.cpp DiffPlanner.cpp ErrorEstimator.cpp EstimationModel.cpp diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 96db90302..3a6c57ec0 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -36,52 +36,54 @@ using namespace clang; namespace clad { - DerivativeBuilder::DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P) - : m_Sema(S), m_CladPlugin(P), m_Context(S.getASTContext()), +DerivativeBuilder::DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, + const DerivedFnCollector& DFC) + : m_Sema(S), m_CladPlugin(P), m_Context(S.getASTContext()), m_DFC(DFC), m_NodeCloner(new utils::StmtClone(m_Sema, m_Context)), m_BuiltinDerivativesNSD(nullptr), m_NumericalDiffNSD(nullptr) {} - DerivativeBuilder::~DerivativeBuilder() {} - - static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { - LookupResult R(semaRef, derivedFD->getNameInfo(), Sema::LookupOrdinaryName); - // FIXME: Attach out-of-line virtual function definitions to the TUScope. - Scope* S = semaRef.getScopeForContext(derivedFD->getDeclContext()); - semaRef.CheckFunctionDeclaration(S, derivedFD, R, - /*IsMemberSpecialization=*/false - /*DeclIsDefn*/CLAD_COMPAT_CheckFunctionDeclaration_DeclIsDefn_ExtraParam(derivedFD)); - - // FIXME: Avoid the DeclContext lookup and the manual setPreviousDecl. - // Consider out-of-line virtual functions. - { - DeclContext* LookupCtx = derivedFD->getDeclContext(); - auto R = LookupCtx->noload_lookup(derivedFD->getDeclName()); - - for (NamedDecl* I : R) { - if (auto* FD = dyn_cast(I)) { - // FIXME: We still do extra work in creating a derivative and throwing - // it away. - if (FD->getDefinition()) - return; - - if (derivedFD->getASTContext() - .hasSameFunctionTypeIgnoringExceptionSpec(derivedFD - ->getType(), - FD->getType())) { - // Register the function on the redecl chain. - derivedFD->setPreviousDecl(FD); - break; - } +DerivativeBuilder::~DerivativeBuilder() {} + +static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { + LookupResult R(semaRef, derivedFD->getNameInfo(), Sema::LookupOrdinaryName); + // FIXME: Attach out-of-line virtual function definitions to the TUScope. + Scope* S = semaRef.getScopeForContext(derivedFD->getDeclContext()); + semaRef.CheckFunctionDeclaration( + S, derivedFD, R, + /*IsMemberSpecialization=*/ + false + /*DeclIsDefn*/ CLAD_COMPAT_CheckFunctionDeclaration_DeclIsDefn_ExtraParam( + derivedFD)); + + // FIXME: Avoid the DeclContext lookup and the manual setPreviousDecl. + // Consider out-of-line virtual functions. + { + DeclContext* LookupCtx = derivedFD->getDeclContext(); + auto R = LookupCtx->noload_lookup(derivedFD->getDeclName()); + + for (NamedDecl* I : R) { + if (auto* FD = dyn_cast(I)) { + // FIXME: We still do extra work in creating a derivative and throwing + // it away. + if (FD->getDefinition()) + return; + + if (derivedFD->getASTContext().hasSameFunctionTypeIgnoringExceptionSpec( + derivedFD->getType(), FD->getType())) { + // Register the function on the redecl chain. + derivedFD->setPreviousDecl(FD); + break; } } - // Inform the decl's decl context for its existance after the lookup, - // otherwise it would end up in the LookupResult. - derivedFD->getDeclContext()->addDecl(derivedFD); - - // FIXME: Rebuild VTable to remove requirements for "forward" declared - // virtual methods } + // Inform the decl's decl context for its existance after the lookup, + // otherwise it would end up in the LookupResult. + derivedFD->getDeclContext()->addDecl(derivedFD); + + // FIXME: Rebuild VTable to remove requirements for "forward" declared + // virtual methods } +} static bool hasAttribute(const Decl *D, attr::Kind Kind) { for (const auto *Attribute : D->attrs()) @@ -378,4 +380,12 @@ namespace clad { return result; } + + FunctionDecl* + DerivativeBuilder::FindDerivedFunction(const DiffRequest& request) { + auto DFI = m_DFC.Find(request); + if (DFI.IsValid()) + return DFI.DerivedFn(); + return nullptr; + } }// end namespace clad diff --git a/lib/Differentiator/DerivedFnCollector.cpp b/lib/Differentiator/DerivedFnCollector.cpp new file mode 100644 index 000000000..1e9c9a837 --- /dev/null +++ b/lib/Differentiator/DerivedFnCollector.cpp @@ -0,0 +1,39 @@ +#include "clad/Differentiator/DerivedFnCollector.h" +#include "clad/Differentiator/DiffPlanner.h" + +namespace clad { +void DerivedFnCollector::Add(const DerivedFnInfo& DFI) { + assert(!AlreadyExists(DFI) && + "We are generating same derivative more than once, or calling " + "`DerivedFnCollector::Add` more than once for the same derivative " + ". Ideally, we shouldn't do either."); + m_DerivedFnInfoCollection[DFI.OriginalFn()].push_back(DFI); +} + +bool DerivedFnCollector::AlreadyExists(const DerivedFnInfo& DFI) const { + auto subCollectionIt = m_DerivedFnInfoCollection.find(DFI.OriginalFn()); + if (subCollectionIt == m_DerivedFnInfoCollection.end()) + return false; + const auto& subCollection = subCollectionIt->second; + const auto* it = + std::find_if(subCollection.begin(), subCollection.end(), + [&DFI](const DerivedFnInfo& info) { + return DerivedFnInfo::RepresentsSameDerivative(DFI, info); + }); + return it != subCollection.end(); +} + +DerivedFnInfo DerivedFnCollector::Find(const DiffRequest& request) const { + auto subCollectionIt = m_DerivedFnInfoCollection.find(request.Function); + if (subCollectionIt == m_DerivedFnInfoCollection.end()) + return DerivedFnInfo(); + const auto& subCollection = subCollectionIt->second; + const auto* it = std::find_if(subCollection.begin(), subCollection.end(), + [&request](const DerivedFnInfo& DFI) { + return DFI.SatisfiesRequest(request); + }); + if (it == subCollection.end()) + return DerivedFnInfo(); + return *it; +} +} // namespace clad \ No newline at end of file diff --git a/tools/DerivedFnInfo.cpp b/lib/Differentiator/DerivedFnInfo.cpp similarity index 62% rename from tools/DerivedFnInfo.cpp rename to lib/Differentiator/DerivedFnInfo.cpp index 21f97bce0..dce053b2a 100644 --- a/tools/DerivedFnInfo.cpp +++ b/lib/Differentiator/DerivedFnInfo.cpp @@ -1,5 +1,4 @@ -#include "DerivedFnInfo.h" - +#include "clad/Differentiator/DerivedFnInfo.h" #include "clad/Differentiator/DiffPlanner.h" using namespace clang; @@ -19,17 +18,16 @@ bool DerivedFnInfo::SatisfiesRequest(const DiffRequest& request) const { request.CurrentDerivativeOrder == m_DerivativeOrder && request.DVI == m_DiffVarsInfo && request.use_enzyme == m_UsesEnzyme && request.DeclarationOnly == m_DeclarationOnly); - } +} - bool DerivedFnInfo::IsValid() const { return m_OriginalFn && m_DerivedFn; } +bool DerivedFnInfo::IsValid() const { return m_OriginalFn && m_DerivedFn; } - bool DerivedFnInfo::RepresentsSameDerivative(const DerivedFnInfo& lhs, - const DerivedFnInfo& rhs) { - return lhs.m_OriginalFn == rhs.m_OriginalFn && - lhs.m_DerivativeOrder == rhs.m_DerivativeOrder && - lhs.m_Mode == rhs.m_Mode && - lhs.m_DiffVarsInfo == rhs.m_DiffVarsInfo && - lhs.m_UsesEnzyme == rhs.m_UsesEnzyme && - lhs.m_DeclarationOnly == rhs.m_DeclarationOnly; - } +bool DerivedFnInfo::RepresentsSameDerivative(const DerivedFnInfo& lhs, + const DerivedFnInfo& rhs) { + return lhs.m_OriginalFn == rhs.m_OriginalFn && + lhs.m_DerivativeOrder == rhs.m_DerivativeOrder && + lhs.m_Mode == rhs.m_Mode && lhs.m_DiffVarsInfo == rhs.m_DiffVarsInfo && + lhs.m_UsesEnzyme == rhs.m_UsesEnzyme && + lhs.m_DeclarationOnly == rhs.m_DeclarationOnly; +} } // namespace clad \ No newline at end of file diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index e5b08e58f..2f5607b74 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1799,24 +1799,28 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (DerivedCallOutputArgs[i + isaMethod]) pullbackRequest.DVI.push_back(FD->getParamDecl(i)); - FunctionDecl* pullbackFD = nullptr; - if (!m_ExternalSource) { - // Derive the declaration of the pullback function. - pullbackRequest.DeclarationOnly = true; - pullbackFD = - plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); - - // Add the request to derive the definition of the pullback function. - pullbackRequest.DeclarationOnly = false; - pullbackRequest.DerivedFDPrototype = pullbackFD; - plugin::AddRequestToSchedule(m_CladPlugin, pullbackRequest); - } else { - // FIXME: Error estimation currently uses singleton objects - - // m_ErrorEstHandler and m_EstModel, which is cleared after each - // error_estimate request. This requires the pullback to be derived at - // the same time to access the singleton objects. - pullbackFD = - plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); + FunctionDecl* pullbackFD = + m_Builder.FindDerivedFunction(pullbackRequest); + if (!pullbackFD) { + if (!m_ExternalSource) { + // Derive the declaration of the pullback function. + pullbackRequest.DeclarationOnly = true; + pullbackFD = + plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); + + // Add the request to derive the definition of the pullback + // function. + pullbackRequest.DeclarationOnly = false; + pullbackRequest.DerivedFDPrototype = pullbackFD; + plugin::AddRequestToSchedule(m_CladPlugin, pullbackRequest); + } else { + // FIXME: Error estimation currently uses singleton objects - + // m_ErrorEstHandler and m_EstModel, which is cleared after each + // error_estimate request. This requires the pullback to be derived + // at the same time to access the singleton objects. + pullbackFD = + plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); + } } // Clad failed to derive it. @@ -1907,15 +1911,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, clad::utils::ComputeEffectiveFnName(FD); calleeFnForwPassReq.VerboseDiags = true; - // Derive declaration of the the forward pass function. - calleeFnForwPassReq.DeclarationOnly = true; FunctionDecl* calleeFnForwPassFD = - plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq); - - // Add the request to derive the definition of the forward pass function. - calleeFnForwPassReq.DeclarationOnly = false; - calleeFnForwPassReq.DerivedFDPrototype = calleeFnForwPassFD; - plugin::AddRequestToSchedule(m_CladPlugin, calleeFnForwPassReq); + m_Builder.FindDerivedFunction(calleeFnForwPassReq); + if (!calleeFnForwPassFD) { + // Derive declaration of the the forward pass function. + calleeFnForwPassReq.DeclarationOnly = true; + calleeFnForwPassFD = + plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq); + + // Add the request to derive the definition of the forward pass + // function. + calleeFnForwPassReq.DeclarationOnly = false; + calleeFnForwPassReq.DerivedFDPrototype = calleeFnForwPassFD; + plugin::AddRequestToSchedule(m_CladPlugin, calleeFnForwPassReq); + } assert(calleeFnForwPassFD && "Clad failed to generate callee function forward pass function"); diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index cecee04b3..b5884863d 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -5,7 +5,6 @@ endif() set(CLAD_PLUGIN_SRC ClangPlugin.cpp ClangBackendPlugin.cpp - DerivedFnInfo.cpp RequiredSymbols.cpp ) diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index f4c4c1c32..3f2ba506a 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -122,7 +122,7 @@ namespace clad { Sema& S = m_CI.getSema(); if (!m_DerivativeBuilder) - m_DerivativeBuilder.reset(new DerivativeBuilder(S, *this)); + m_DerivativeBuilder.reset(new DerivativeBuilder(S, *this, m_DFC)); RequestOptions opts{}; SetRequestOptions(opts); @@ -505,41 +505,6 @@ namespace clad { else return true; } - - void DerivedFnCollector::Add(const DerivedFnInfo& DFI) { - assert(!AlreadyExists(DFI) && - "We are generating same derivative more than once, or calling " - "`DerivedFnCollector::Add` more than once for the same derivative " - ". Ideally, we shouldn't do either."); - m_DerivedFnInfoCollection[DFI.OriginalFn()].push_back(DFI); - } - - bool DerivedFnCollector::AlreadyExists(const DerivedFnInfo& DFI) const { - auto subCollectionIt = m_DerivedFnInfoCollection.find(DFI.OriginalFn()); - if (subCollectionIt == m_DerivedFnInfoCollection.end()) - return false; - auto& subCollection = subCollectionIt->second; - auto it = std::find_if(subCollection.begin(), subCollection.end(), - [&DFI](const DerivedFnInfo& info) { - return DerivedFnInfo:: - RepresentsSameDerivative(DFI, info); - }); - return it != subCollection.end(); - } - - DerivedFnInfo DerivedFnCollector::Find(const DiffRequest& request) const { - auto subCollectionIt = m_DerivedFnInfoCollection.find(request.Function); - if (subCollectionIt == m_DerivedFnInfoCollection.end()) - return DerivedFnInfo(); - auto& subCollection = subCollectionIt->second; - auto it = std::find_if(subCollection.begin(), subCollection.end(), - [&request](DerivedFnInfo DFI) { - return DFI.SatisfiesRequest(request); - }); - if (it == subCollection.end()) - return DerivedFnInfo(); - return *it; - } } // end namespace clad // Attach the frontend plugin. diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index b7f6075c2..c69016681 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -7,8 +7,8 @@ #ifndef CLAD_CLANG_PLUGIN #define CLAD_CLANG_PLUGIN -#include "DerivedFnInfo.h" #include "clad/Differentiator/DerivativeBuilder.h" +#include "clad/Differentiator/DerivedFnCollector.h" #include "clad/Differentiator/DiffMode.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/Version.h" @@ -37,41 +37,16 @@ namespace clang { namespace clad { - bool checkClangVersion(); - /// This class is designed to store collection of `DerivedFnInfo` objects. - /// It's purpose is to avoid repeated generation of same derivatives by - /// making it possible to reuse previously computed derivatives. - class DerivedFnCollector { - using DerivedFns = llvm::SmallVector; - /// Mapping to efficiently find out information about all the derivatives of - /// a function. - llvm::DenseMap m_DerivedFnInfoCollection; +bool checkClangVersion(); +class CladTimerGroup { + llvm::TimerGroup m_Tg; + std::vector> m_Timers; - public: - /// Adds a derived function to the collection. - void Add(const DerivedFnInfo& DFI); - - /// Finds a `DerivedFnInfo` object in the collection that satisfies the - /// given differentiation request. - DerivedFnInfo Find(const DiffRequest& request) const; - - bool IsDerivative(const clang::FunctionDecl* FD) const; - - private: - /// Returns true if the collection already contains a `DerivedFnInfo` - /// object that represents the same derivative object as the provided - /// argument `DFI`. - bool AlreadyExists(const DerivedFnInfo& DFI) const; - }; - class CladTimerGroup { - llvm::TimerGroup m_Tg; - std::vector> m_Timers; - - public: - CladTimerGroup(); - void StartNewTimer(llvm::StringRef TimerName, llvm::StringRef TimerDesc); - void StopTimer(); - }; +public: + CladTimerGroup(); + void StartNewTimer(llvm::StringRef TimerName, llvm::StringRef TimerDesc); + void StopTimer(); +}; namespace plugin { struct DifferentiationOptions { diff --git a/tools/DerivedFnInfo.h b/tools/DerivedFnInfo.h deleted file mode 100644 index 955d51f98..000000000 --- a/tools/DerivedFnInfo.h +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef CLAD_DERIVED_FN_H -#define CLAD_DERIVED_FN_H - -#include "clang/AST/Decl.h" -#include "clad/Differentiator/DiffMode.h" -#include "clad/Differentiator/ParseDiffArgsTypes.h" - -namespace clad { - struct DiffRequest; - - /// `DerivedFnInfo` is designed to effectively store information about a - /// derived function. - struct DerivedFnInfo { - const clang::FunctionDecl* m_OriginalFn = nullptr; - clang::FunctionDecl* m_DerivedFn = nullptr; - clang::FunctionDecl* m_OverloadedDerivedFn = nullptr; - DiffMode m_Mode = DiffMode::unknown; - unsigned m_DerivativeOrder = 0; - DiffInputVarsInfo m_DiffVarsInfo; - bool m_UsesEnzyme = false; - bool m_DeclarationOnly = false; - - DerivedFnInfo() {} - DerivedFnInfo(const DiffRequest& request, clang::FunctionDecl* derivedFn, - clang::FunctionDecl* overloadedDerivedFn); - - /// Returns true if the derived function represented by the object, - /// satisfies the requirements of the given differentiation request. - bool SatisfiesRequest(const DiffRequest& request) const; - - /// Returns true if the object represents any derived function; otherwise - /// returns false. - bool IsValid() const; - - const clang::FunctionDecl* OriginalFn() const { return m_OriginalFn; } - clang::FunctionDecl* DerivedFn() const { return m_DerivedFn; } - clang::FunctionDecl* OverloadedDerivedFn() const { return m_OverloadedDerivedFn; } - bool DeclarationOnly() const { return m_DeclarationOnly; } - - /// Returns true if `lhs` and `rhs` represents same derivative. - /// Here derivative is any function derived by clad. - static bool RepresentsSameDerivative(const DerivedFnInfo& lhs, - const DerivedFnInfo& rhs); - }; -} // namespace clad - -#endif \ No newline at end of file