Skip to content

Commit

Permalink
Moving DerivedFnCollector out of ClangPlugin
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak authored and vgvassilev committed Apr 25, 2024
1 parent 962e5d9 commit d879f1b
Show file tree
Hide file tree
Showing 13 changed files with 262 additions and 209 deletions.
13 changes: 11 additions & 2 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
#define CLAD_DERIVATIVE_BUILDER_H

#include "Compatibility.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"
#include "clad/Differentiator/DerivedFnCollector.h"
#include "clad/Differentiator/DiffPlanner.h"

#include <array>
#include <stack>
Expand Down Expand Up @@ -85,6 +86,7 @@ namespace clad {
clang::Sema& m_Sema;
plugin::CladPlugin& m_CladPlugin;
clang::ASTContext& m_Context;
const DerivedFnCollector& m_DFC;
std::unique_ptr<utils::StmtClone> m_NodeCloner;
clang::NamespaceDecl* m_BuiltinDerivativesNSD;
/// A reference to the model to use for error estimation (if any).
Expand Down Expand Up @@ -134,7 +136,8 @@ namespace clad {
}

public:
DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P);
DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P,
const DerivedFnCollector& DFC);
~DerivativeBuilder();
/// Reset the model use for error estimation (if any).
/// \param[in] estModel The error estimation model, can be either
Expand Down Expand Up @@ -163,6 +166,12 @@ namespace clad {
/// context.
///
DerivativeAndOverload Derive(const DiffRequest& request);
/// Find the derived function if present in the DerivedFnCollector.
///
/// \param[in] request The request to find the derived function.
///
/// \returns The derived function if found, nullptr otherwise.
clang::FunctionDecl* FindDerivedFunction(const DiffRequest& request);
};

} // end namespace clad
Expand Down
40 changes: 40 additions & 0 deletions include/clad/Differentiator/DerivedFnCollector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#ifndef CLAD_DIFFERENTIATOR_DERIVEDFNCOLLECTOR_H
#define CLAD_DIFFERENTIATOR_DERIVEDFNCOLLECTOR_H

#include "clad/Differentiator/DerivedFnInfo.h"

#include "clang/AST/Decl.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"

namespace clad {
/// This class is designed to store collection of `DerivedFnInfo` objects.
/// It's purpose is to avoid repeated generation of same derivatives by
/// making it possible to reuse previously computed derivatives.
class DerivedFnCollector {
using DerivedFns = llvm::SmallVector<DerivedFnInfo, 16>;
/// Mapping to efficiently find out information about all the derivatives of
/// a function.
llvm::DenseMap<const clang::FunctionDecl*, DerivedFns>
m_DerivedFnInfoCollection;

public:
/// Adds a derived function to the collection.
void Add(const DerivedFnInfo& DFI);

/// Finds a `DerivedFnInfo` object in the collection that satisfies the
/// given differentiation request.
DerivedFnInfo Find(const DiffRequest& request) const;

bool IsDerivative(const clang::FunctionDecl* FD) const;

private:
/// Returns true if the collection already contains a `DerivedFnInfo`
/// object that represents the same derivative object as the provided
/// argument `DFI`.
bool AlreadyExists(const DerivedFnInfo& DFI) const;
};
} // namespace clad

#endif // CLAD_DIFFERENTIATOR_DERIVEDFNCOLLECTOR_H
49 changes: 49 additions & 0 deletions include/clad/Differentiator/DerivedFnInfo.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#ifndef CLAD_DIFFERENTIATOR_DERIVEDFNINFO_H
#define CLAD_DIFFERENTIATOR_DERIVEDFNINFO_H

#include "clang/AST/Decl.h"
#include "clad/Differentiator/DiffMode.h"
#include "clad/Differentiator/ParseDiffArgsTypes.h"

namespace clad {
struct DiffRequest;

/// `DerivedFnInfo` is designed to effectively store information about a
/// derived function.
struct DerivedFnInfo {
const clang::FunctionDecl* m_OriginalFn = nullptr;
clang::FunctionDecl* m_DerivedFn = nullptr;
clang::FunctionDecl* m_OverloadedDerivedFn = nullptr;
DiffMode m_Mode = DiffMode::unknown;
unsigned m_DerivativeOrder = 0;
DiffInputVarsInfo m_DiffVarsInfo;
bool m_UsesEnzyme = false;
bool m_DeclarationOnly = false;

DerivedFnInfo() = default;
DerivedFnInfo(const DiffRequest& request, clang::FunctionDecl* derivedFn,
clang::FunctionDecl* overloadedDerivedFn);

/// Returns true if the derived function represented by the object,
/// satisfies the requirements of the given differentiation request.
bool SatisfiesRequest(const DiffRequest& request) const;

/// Returns true if the object represents any derived function; otherwise
/// returns false.
bool IsValid() const;

const clang::FunctionDecl* OriginalFn() const { return m_OriginalFn; }
clang::FunctionDecl* DerivedFn() const { return m_DerivedFn; }
clang::FunctionDecl* OverloadedDerivedFn() const {
return m_OverloadedDerivedFn;
}
bool DeclarationOnly() const { return m_DeclarationOnly; }

/// Returns true if `lhs` and `rhs` represents same derivative.
/// Here derivative is any function derived by clad.
static bool RepresentsSameDerivative(const DerivedFnInfo& lhs,
const DerivedFnInfo& rhs);
};
} // namespace clad

#endif // CLAD_DIFFERENTIATOR_DERIVEDFNINFO_H
25 changes: 15 additions & 10 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1150,21 +1150,26 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
DiffRequest pushforwardFnRequest;
pushforwardFnRequest.Function = FD;
pushforwardFnRequest.Mode = GetPushForwardMode();
pushforwardFnRequest.BaseFunctionName = FD->getNameAsString();
pushforwardFnRequest.BaseFunctionName = utils::ComputeEffectiveFnName(FD);
// pushforwardFnRequest.RequestedDerivativeOrder = m_DerivativeOrder;
// Silence diag outputs in nested derivation process.
pushforwardFnRequest.VerboseDiags = false;

// Derive declaration of the pushforward function.
pushforwardFnRequest.DeclarationOnly = true;
// Check if request already derived in DerivedFunctions.
FunctionDecl* pushforwardFD =
plugin::ProcessDiffRequest(m_CladPlugin, pushforwardFnRequest);

// Add the request to derive the definition of the pushforward function
// into the queue.
pushforwardFnRequest.DeclarationOnly = false;
pushforwardFnRequest.DerivedFDPrototype = pushforwardFD;
plugin::AddRequestToSchedule(m_CladPlugin, pushforwardFnRequest);
m_Builder.FindDerivedFunction(pushforwardFnRequest);
if (!pushforwardFD) {
// Derive declaration of the pushforward function.
pushforwardFnRequest.DeclarationOnly = true;
pushforwardFD =
plugin::ProcessDiffRequest(m_CladPlugin, pushforwardFnRequest);

// Add the request to derive the definition of the pushforward function
// into the queue.
pushforwardFnRequest.DeclarationOnly = false;
pushforwardFnRequest.DerivedFDPrototype = pushforwardFD;
plugin::AddRequestToSchedule(m_CladPlugin, pushforwardFnRequest);
}

if (pushforwardFD) {
if (baseDiff.getExpr()) {
Expand Down
2 changes: 2 additions & 0 deletions lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ llvm_add_library(cladDifferentiator
CladUtils.cpp
ConstantFolder.cpp
DerivativeBuilder.cpp
DerivedFnCollector.cpp
DerivedFnInfo.cpp
DiffPlanner.cpp
ErrorEstimator.cpp
EstimationModel.cpp
Expand Down
88 changes: 49 additions & 39 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,52 +36,54 @@ using namespace clang;

namespace clad {

DerivativeBuilder::DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P)
: m_Sema(S), m_CladPlugin(P), m_Context(S.getASTContext()),
DerivativeBuilder::DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P,
const DerivedFnCollector& DFC)
: m_Sema(S), m_CladPlugin(P), m_Context(S.getASTContext()), m_DFC(DFC),
m_NodeCloner(new utils::StmtClone(m_Sema, m_Context)),
m_BuiltinDerivativesNSD(nullptr), m_NumericalDiffNSD(nullptr) {}

DerivativeBuilder::~DerivativeBuilder() {}

static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
LookupResult R(semaRef, derivedFD->getNameInfo(), Sema::LookupOrdinaryName);
// FIXME: Attach out-of-line virtual function definitions to the TUScope.
Scope* S = semaRef.getScopeForContext(derivedFD->getDeclContext());
semaRef.CheckFunctionDeclaration(S, derivedFD, R,
/*IsMemberSpecialization=*/false
/*DeclIsDefn*/CLAD_COMPAT_CheckFunctionDeclaration_DeclIsDefn_ExtraParam(derivedFD));

// FIXME: Avoid the DeclContext lookup and the manual setPreviousDecl.
// Consider out-of-line virtual functions.
{
DeclContext* LookupCtx = derivedFD->getDeclContext();
auto R = LookupCtx->noload_lookup(derivedFD->getDeclName());

for (NamedDecl* I : R) {
if (auto* FD = dyn_cast<FunctionDecl>(I)) {
// FIXME: We still do extra work in creating a derivative and throwing
// it away.
if (FD->getDefinition())
return;

if (derivedFD->getASTContext()
.hasSameFunctionTypeIgnoringExceptionSpec(derivedFD
->getType(),
FD->getType())) {
// Register the function on the redecl chain.
derivedFD->setPreviousDecl(FD);
break;
}
DerivativeBuilder::~DerivativeBuilder() {}

static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
LookupResult R(semaRef, derivedFD->getNameInfo(), Sema::LookupOrdinaryName);
// FIXME: Attach out-of-line virtual function definitions to the TUScope.
Scope* S = semaRef.getScopeForContext(derivedFD->getDeclContext());
semaRef.CheckFunctionDeclaration(
S, derivedFD, R,
/*IsMemberSpecialization=*/
false
/*DeclIsDefn*/ CLAD_COMPAT_CheckFunctionDeclaration_DeclIsDefn_ExtraParam(
derivedFD));

// FIXME: Avoid the DeclContext lookup and the manual setPreviousDecl.
// Consider out-of-line virtual functions.
{
DeclContext* LookupCtx = derivedFD->getDeclContext();
auto R = LookupCtx->noload_lookup(derivedFD->getDeclName());

for (NamedDecl* I : R) {
if (auto* FD = dyn_cast<FunctionDecl>(I)) {
// FIXME: We still do extra work in creating a derivative and throwing
// it away.
if (FD->getDefinition())
return;

if (derivedFD->getASTContext().hasSameFunctionTypeIgnoringExceptionSpec(
derivedFD->getType(), FD->getType())) {
// Register the function on the redecl chain.
derivedFD->setPreviousDecl(FD);
break;
}
}
// Inform the decl's decl context for its existance after the lookup,
// otherwise it would end up in the LookupResult.
derivedFD->getDeclContext()->addDecl(derivedFD);

// FIXME: Rebuild VTable to remove requirements for "forward" declared
// virtual methods
}
// Inform the decl's decl context for its existance after the lookup,
// otherwise it would end up in the LookupResult.
derivedFD->getDeclContext()->addDecl(derivedFD);

// FIXME: Rebuild VTable to remove requirements for "forward" declared
// virtual methods
}
}

static bool hasAttribute(const Decl *D, attr::Kind Kind) {
for (const auto *Attribute : D->attrs())
Expand Down Expand Up @@ -378,4 +380,12 @@ namespace clad {

return result;
}

FunctionDecl*
DerivativeBuilder::FindDerivedFunction(const DiffRequest& request) {
auto DFI = m_DFC.Find(request);
if (DFI.IsValid())
return DFI.DerivedFn();
return nullptr;
}
}// end namespace clad
39 changes: 39 additions & 0 deletions lib/Differentiator/DerivedFnCollector.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "clad/Differentiator/DerivedFnCollector.h"
#include "clad/Differentiator/DiffPlanner.h"

namespace clad {
void DerivedFnCollector::Add(const DerivedFnInfo& DFI) {
assert(!AlreadyExists(DFI) &&
"We are generating same derivative more than once, or calling "
"`DerivedFnCollector::Add` more than once for the same derivative "
". Ideally, we shouldn't do either.");
m_DerivedFnInfoCollection[DFI.OriginalFn()].push_back(DFI);
}

bool DerivedFnCollector::AlreadyExists(const DerivedFnInfo& DFI) const {
auto subCollectionIt = m_DerivedFnInfoCollection.find(DFI.OriginalFn());
if (subCollectionIt == m_DerivedFnInfoCollection.end())
return false;
const auto& subCollection = subCollectionIt->second;
const auto* it =
std::find_if(subCollection.begin(), subCollection.end(),
[&DFI](const DerivedFnInfo& info) {
return DerivedFnInfo::RepresentsSameDerivative(DFI, info);
});
return it != subCollection.end();
}

DerivedFnInfo DerivedFnCollector::Find(const DiffRequest& request) const {
auto subCollectionIt = m_DerivedFnInfoCollection.find(request.Function);
if (subCollectionIt == m_DerivedFnInfoCollection.end())
return DerivedFnInfo();
const auto& subCollection = subCollectionIt->second;
const auto* it = std::find_if(subCollection.begin(), subCollection.end(),
[&request](const DerivedFnInfo& DFI) {
return DFI.SatisfiesRequest(request);
});
if (it == subCollection.end())
return DerivedFnInfo();
return *it;
}
} // namespace clad
24 changes: 11 additions & 13 deletions tools/DerivedFnInfo.cpp → lib/Differentiator/DerivedFnInfo.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "DerivedFnInfo.h"

#include "clad/Differentiator/DerivedFnInfo.h"
#include "clad/Differentiator/DiffPlanner.h"

using namespace clang;
Expand All @@ -19,17 +18,16 @@ bool DerivedFnInfo::SatisfiesRequest(const DiffRequest& request) const {
request.CurrentDerivativeOrder == m_DerivativeOrder &&
request.DVI == m_DiffVarsInfo && request.use_enzyme == m_UsesEnzyme &&
request.DeclarationOnly == m_DeclarationOnly);
}
}

bool DerivedFnInfo::IsValid() const { return m_OriginalFn && m_DerivedFn; }
bool DerivedFnInfo::IsValid() const { return m_OriginalFn && m_DerivedFn; }

bool DerivedFnInfo::RepresentsSameDerivative(const DerivedFnInfo& lhs,
const DerivedFnInfo& rhs) {
return lhs.m_OriginalFn == rhs.m_OriginalFn &&
lhs.m_DerivativeOrder == rhs.m_DerivativeOrder &&
lhs.m_Mode == rhs.m_Mode &&
lhs.m_DiffVarsInfo == rhs.m_DiffVarsInfo &&
lhs.m_UsesEnzyme == rhs.m_UsesEnzyme &&
lhs.m_DeclarationOnly == rhs.m_DeclarationOnly;
}
bool DerivedFnInfo::RepresentsSameDerivative(const DerivedFnInfo& lhs,
const DerivedFnInfo& rhs) {
return lhs.m_OriginalFn == rhs.m_OriginalFn &&
lhs.m_DerivativeOrder == rhs.m_DerivativeOrder &&
lhs.m_Mode == rhs.m_Mode && lhs.m_DiffVarsInfo == rhs.m_DiffVarsInfo &&
lhs.m_UsesEnzyme == rhs.m_UsesEnzyme &&
lhs.m_DeclarationOnly == rhs.m_DeclarationOnly;
}
} // namespace clad
Loading

0 comments on commit d879f1b

Please sign in to comment.