Skip to content

Commit

Permalink
Separate sets for custom and clad-generated derivatives
Browse files Browse the repository at this point in the history
This is required for allowing users to link custom derivatives
from a separate translation unit.
  • Loading branch information
vaithak committed Jun 12, 2024
1 parent 5752849 commit cc33a2c
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 132 deletions.
16 changes: 14 additions & 2 deletions include/clad/Differentiator/DerivedFnCollector.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,33 @@ class DerivedFnCollector {
/// a function.
llvm::DenseMap<const clang::FunctionDecl*, DerivedFns>
m_DerivedFnInfoCollection;
/// Set to keep track of all the functions that are derivatives.
/// Set to keep track of all the functions that are derivatives
/// functions produced by Clad.
DerivativeSet m_DerivativeSet;

/// Set to keep track of all the functions that are custom derivatives
/// functions provided by the user.
DerivativeSet m_CustomDerivativeSet;

public:
/// Adds a derived function to the collection.
void Add(const DerivedFnInfo& DFI);

/// Adds a function to derivative set.
void AddToDerivativeSet(const clang::FunctionDecl* FD);

/// Adds a function to custom derivative set.
void AddToCustomDerivativeSet(const clang::FunctionDecl* FD);

/// Finds a `DerivedFnInfo` object in the collection that satisfies the
/// given differentiation request.
DerivedFnInfo Find(const DiffRequest& request) const;

bool IsDerivative(const clang::FunctionDecl* FD) const;
/// Returns true if the function is a Clad-generated derivative.
bool IsCladDerivative(const clang::FunctionDecl* FD) const;

/// Returns true if the function is a custom derivative.
bool IsCustomDerivative(const clang::FunctionDecl* FD) const;

private:
/// Returns true if the collection already contains a `DerivedFnInfo`
Expand Down
4 changes: 3 additions & 1 deletion include/clad/Differentiator/HessianModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ namespace clad {
DerivativeAndOverload
Merge(std::vector<clang::FunctionDecl*> secDerivFuncs,
llvm::SmallVector<size_t, 16> IndependentArgsSize,
size_t TotalIndependentArgsSize, std::string hessianFuncName);
size_t TotalIndependentArgsSize, const std::string& hessianFuncName,
clang::DeclContext* FD, clang::QualType hessianFuncType,
llvm::SmallVector<clang::QualType, 16> paramTypes);

public:
HessianModeVisitor(DerivativeBuilder& builder, const DiffRequest& request);
Expand Down
25 changes: 17 additions & 8 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,14 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
if (auto* FD = dyn_cast<FunctionDecl>(ND))
// Check if FD and functionType have the same signature.
if (utils::SameCanonicalType(FD->getType(), functionType))
if (FD->isDefined() || !m_DFC.IsDerivative(FD))
// Make sure that it is not the case that FD is the forward
// declaration generated by Clad. It should be user defined custom
// derivative (either within the same translation unit or linked in
// from another translation unit).
if (FD->isDefined() || !m_DFC.IsCladDerivative(FD)) {
m_DFC.AddToCustomDerivativeSet(FD);
return FD;
}

return nullptr;
}
Expand Down Expand Up @@ -284,7 +290,7 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
// differentiation due to unavailable definition.
if (auto* CE = dyn_cast<CallExpr>(OverloadedFn))
if (FunctionDecl* FD = CE->getDirectCallee())
m_DFC.AddToDerivativeSet(FD);
m_DFC.AddToCustomDerivativeSet(FD);
}
return OverloadedFn;
}
Expand Down Expand Up @@ -328,8 +334,9 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
// If FD is only a declaration, try to find its definition.
if (!FD->getDefinition()) {
// If only declaration is requested, allow this for clad-generated
// functions.
if (!request.DeclarationOnly || !m_DFC.IsDerivative(FD)) {
// functions or custom derivatives.
if (!request.DeclarationOnly ||
!(m_DFC.IsCladDerivative(FD) || m_DFC.IsCustomDerivative(FD))) {
if (request.VerboseDiags)
diag(DiagnosticsEngine::Error,
request.CallContext ? request.CallContext->getBeginLoc() : noLoc,
Expand Down Expand Up @@ -415,10 +422,12 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {

// FIXME: if the derivatives aren't registered in this order and the
// derivative is a member function it goes into an infinite loop
if (auto FD = result.derivative)
registerDerivative(FD, m_Sema);
if (auto OFD = result.overload)
registerDerivative(OFD, m_Sema);
if (!m_DFC.IsCustomDerivative(result.derivative)) {
if (auto FD = result.derivative)
registerDerivative(FD, m_Sema);
if (auto OFD = result.overload)
registerDerivative(OFD, m_Sema);
}

return result;
}
Expand Down
12 changes: 11 additions & 1 deletion lib/Differentiator/DerivedFnCollector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ void DerivedFnCollector::AddToDerivativeSet(const clang::FunctionDecl* FD) {
m_DerivativeSet.insert(FD);
}

void DerivedFnCollector::AddToCustomDerivativeSet(
const clang::FunctionDecl* FD) {
m_CustomDerivativeSet.insert(FD);
}

bool DerivedFnCollector::AlreadyExists(const DerivedFnInfo& DFI) const {
auto subCollectionIt = m_DerivedFnInfoCollection.find(DFI.OriginalFn());
if (subCollectionIt == m_DerivedFnInfoCollection.end())
Expand Down Expand Up @@ -42,7 +47,12 @@ DerivedFnInfo DerivedFnCollector::Find(const DiffRequest& request) const {
return *it;
}

bool DerivedFnCollector::IsDerivative(const clang::FunctionDecl* FD) const {
bool DerivedFnCollector::IsCladDerivative(const clang::FunctionDecl* FD) const {
return m_DerivativeSet.count(FD);
}

bool DerivedFnCollector::IsCustomDerivative(
const clang::FunctionDecl* FD) const {
return m_CustomDerivativeSet.count(FD);
}
} // namespace clad
179 changes: 91 additions & 88 deletions lib/Differentiator/HessianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,69 +48,71 @@ static const StringLiteral* CreateStringLiteral(ASTContext& C,

/// Derives the function w.r.t both forward and reverse mode and returns the
/// FunctionDecl obtained from reverse mode differentiation
static FunctionDecl* DeriveUsingForwardAndReverseMode(
Sema& SemaRef, clad::plugin::CladPlugin& CP,
clad::DerivativeBuilder& Builder, DiffRequest IndependentArgRequest,
const Expr* ForwardModeArgs, const Expr* ReverseModeArgs) {
// Derives function once in forward mode w.r.t to ForwardModeArgs
IndependentArgRequest.Args = ForwardModeArgs;
IndependentArgRequest.Mode = DiffMode::forward;
IndependentArgRequest.CallUpdateRequired = false;
IndependentArgRequest.UpdateDiffParamsInfo(SemaRef);
// FIXME: Find a way to do this without accessing plugin namespace functions
bool alreadyDerived = true;
FunctionDecl* firstDerivative =
Builder.FindDerivedFunction(IndependentArgRequest);
if (!firstDerivative) {
alreadyDerived = false;
// Derive declaration of the the forward mode derivative.
IndependentArgRequest.DeclarationOnly = true;
firstDerivative = plugin::ProcessDiffRequest(CP, IndependentArgRequest);

// It is possible that user has provided a custom derivative for the
// derivative function. In that case, we should not derive the definition
// again.
if (firstDerivative->isDefined())
alreadyDerived = true;

// Add the request to derive the definition of the forward mode derivative
// to the schedule.
IndependentArgRequest.DeclarationOnly = false;
IndependentArgRequest.DerivedFDPrototype = firstDerivative;
}
Builder.AddEdgeToGraph(IndependentArgRequest, alreadyDerived);

// Further derives function w.r.t to ReverseModeArgs
DiffRequest ReverseModeRequest{};
ReverseModeRequest.Mode = DiffMode::reverse;
ReverseModeRequest.Function = firstDerivative;
ReverseModeRequest.Args = ReverseModeArgs;
ReverseModeRequest.BaseFunctionName = firstDerivative->getNameAsString();
ReverseModeRequest.UpdateDiffParamsInfo(SemaRef);

alreadyDerived = true;
FunctionDecl* secondDerivative =
Builder.FindDerivedFunction(ReverseModeRequest);
if (!secondDerivative) {
alreadyDerived = false;
// Derive declaration of the the reverse mode derivative.
ReverseModeRequest.DeclarationOnly = true;
secondDerivative = plugin::ProcessDiffRequest(CP, ReverseModeRequest);

// It is possible that user has provided a custom derivative for the
// derivative function. In that case, we should not derive the definition
// again.
if (secondDerivative->isDefined())
alreadyDerived = true;

// Add the request to derive the definition of the reverse mode
// derivative to the schedule.
ReverseModeRequest.DeclarationOnly = false;
ReverseModeRequest.DerivedFDPrototype = secondDerivative;
}
Builder.AddEdgeToGraph(ReverseModeRequest, alreadyDerived);
return secondDerivative;
static FunctionDecl* DeriveUsingForwardAndReverseMode(
Sema& SemaRef, clad::plugin::CladPlugin& CP,
clad::DerivativeBuilder& Builder, DiffRequest IndependentArgRequest,
const Expr* ForwardModeArgs, const Expr* ReverseModeArgs,
DerivedFnCollector& DFC) {
// Derives function once in forward mode w.r.t to ForwardModeArgs
IndependentArgRequest.Args = ForwardModeArgs;
IndependentArgRequest.Mode = DiffMode::forward;
IndependentArgRequest.CallUpdateRequired = false;
IndependentArgRequest.UpdateDiffParamsInfo(SemaRef);
// FIXME: Find a way to do this without accessing plugin namespace functions
bool alreadyDerived = true;
FunctionDecl* firstDerivative =
Builder.FindDerivedFunction(IndependentArgRequest);
if (!firstDerivative) {
alreadyDerived = false;
// Derive declaration of the the forward mode derivative.
IndependentArgRequest.DeclarationOnly = true;
firstDerivative = plugin::ProcessDiffRequest(CP, IndependentArgRequest);

// It is possible that user has provided a custom derivative for the
// derivative function. In that case, we should not derive the definition
// again.
if (firstDerivative->isDefined() || DFC.IsCustomDerivative(firstDerivative))
alreadyDerived = true;

// Add the request to derive the definition of the forward mode derivative
// to the schedule.
IndependentArgRequest.DeclarationOnly = false;
IndependentArgRequest.DerivedFDPrototype = firstDerivative;
}
Builder.AddEdgeToGraph(IndependentArgRequest, alreadyDerived);

// Further derives function w.r.t to ReverseModeArgs
DiffRequest ReverseModeRequest{};
ReverseModeRequest.Mode = DiffMode::reverse;
ReverseModeRequest.Function = firstDerivative;
ReverseModeRequest.Args = ReverseModeArgs;
ReverseModeRequest.BaseFunctionName = firstDerivative->getNameAsString();
ReverseModeRequest.UpdateDiffParamsInfo(SemaRef);

alreadyDerived = true;
FunctionDecl* secondDerivative =
Builder.FindDerivedFunction(ReverseModeRequest);
if (!secondDerivative) {
alreadyDerived = false;
// Derive declaration of the the reverse mode derivative.
ReverseModeRequest.DeclarationOnly = true;
secondDerivative = plugin::ProcessDiffRequest(CP, ReverseModeRequest);

// It is possible that user has provided a custom derivative for the
// derivative function. In that case, we should not derive the definition
// again.
if (secondDerivative->isDefined() ||
DFC.IsCustomDerivative(secondDerivative))
alreadyDerived = true;

// Add the request to derive the definition of the reverse mode
// derivative to the schedule.
ReverseModeRequest.DeclarationOnly = false;
ReverseModeRequest.DerivedFDPrototype = secondDerivative;
}
Builder.AddEdgeToGraph(ReverseModeRequest, alreadyDerived);
return secondDerivative;
}

DerivativeAndOverload
HessianModeVisitor::Derive(const clang::FunctionDecl* FD,
Expand Down Expand Up @@ -149,6 +151,26 @@ static const StringLiteral* CreateStringLiteral(ASTContext& C,
}
}

llvm::SmallVector<QualType, 16> paramTypes(m_DiffReq->getNumParams() + 1);
std::transform(m_DiffReq->param_begin(), m_DiffReq->param_end(),
std::begin(paramTypes),
[](const ParmVarDecl* PVD) { return PVD->getType(); });
paramTypes.back() = m_Context.getPointerType(m_DiffReq->getReturnType());

const auto* originalFnProtoType =
cast<FunctionProtoType>(m_DiffReq->getType());
QualType hessianFunctionType = m_Context.getFunctionType(
m_Context.VoidTy,
llvm::ArrayRef<QualType>(paramTypes.data(), paramTypes.size()),
// Cast to function pointer.
originalFnProtoType->getExtProtoInfo());

// Check if the function is already declared as a custom derivative.
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl(
hessianFuncName, DC, hessianFunctionType))
return DerivativeAndOverload{customDerivative, nullptr};

// Ascertains the independent arguments and differentiates the function
// in forward and reverse mode by calling ProcessDiffRequest twice each
// iteration, storing each generated second derivative function
Expand Down Expand Up @@ -209,7 +231,7 @@ static const StringLiteral* CreateStringLiteral(ASTContext& C,
CreateStringLiteral(m_Context, independentArgString);
auto* DFD = DeriveUsingForwardAndReverseMode(
m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL,
request.Args);
request.Args, m_Builder.m_DFC);
secondDerivativeColumns.push_back(DFD);
}

Expand All @@ -222,13 +244,14 @@ static const StringLiteral* CreateStringLiteral(ASTContext& C,
CreateStringLiteral(m_Context, PVD->getNameAsString());
auto* DFD = DeriveUsingForwardAndReverseMode(
m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL,
request.Args);
request.Args, m_Builder.m_DFC);
secondDerivativeColumns.push_back(DFD);
}
}
}
return Merge(secondDerivativeColumns, IndependentArgsSize,
TotalIndependentArgsSize, hessianFuncName);
TotalIndependentArgsSize, hessianFuncName, DC,
hessianFunctionType, paramTypes);
}

// Combines all generated second derivative functions into a
Expand All @@ -238,36 +261,16 @@ static const StringLiteral* CreateStringLiteral(ASTContext& C,
HessianModeVisitor::Merge(std::vector<FunctionDecl*> secDerivFuncs,
SmallVector<size_t, 16> IndependentArgsSize,
size_t TotalIndependentArgsSize,
std::string hessianFuncName) {
const std::string& hessianFuncName, DeclContext* DC,
QualType hessianFunctionType,
llvm::SmallVector<QualType, 16> paramTypes) {
DiffParams args;
std::copy(m_DiffReq->param_begin(), m_DiffReq->param_end(),
std::back_inserter(args));

IdentifierInfo* II = &m_Context.Idents.get(hessianFuncName);
DeclarationNameInfo name(II, noLoc);

llvm::SmallVector<QualType, 16> paramTypes(m_DiffReq->getNumParams() + 1);

std::transform(m_DiffReq->param_begin(), m_DiffReq->param_end(),
std::begin(paramTypes),
[](const ParmVarDecl* PVD) { return PVD->getType(); });

paramTypes.back() = m_Context.getPointerType(m_DiffReq->getReturnType());

const auto* originalFnProtoType =
cast<FunctionProtoType>(m_DiffReq->getType());
QualType hessianFunctionType = m_Context.getFunctionType(
m_Context.VoidTy,
llvm::ArrayRef<QualType>(paramTypes.data(), paramTypes.size()),
// Cast to function pointer.
originalFnProtoType->getExtProtoInfo());

// Check if the function is already declared as a custom derivative.
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl(
hessianFuncName, DC, hessianFunctionType))
return DerivativeAndOverload{customDerivative, nullptr};

// Create the gradient function declaration.
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope(),
Expand Down
Loading

0 comments on commit cc33a2c

Please sign in to comment.