Skip to content

Commit

Permalink
Add support for call expressions in vector forward mode AD
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Nov 4, 2023
1 parent 1bdd261 commit 8ef4f4e
Show file tree
Hide file tree
Showing 16 changed files with 499 additions and 230 deletions.
15 changes: 15 additions & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ class BaseForwardModeVisitor
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);

DerivativeAndOverload DerivePushforward(const clang::FunctionDecl* FD,
const DiffRequest& request);

/// Returns the return type for the pushforward function of the function
/// `m_Function`.
/// \note `m_Function` field should be set before using this function.
clang::QualType ComputePushforwardFnReturnType();

virtual void ExecuteInsidePushforwardFunctionBlock();

static bool IsDifferentiableType(clang::QualType T);

virtual StmtDiff
Expand Down Expand Up @@ -93,6 +103,11 @@ class BaseForwardModeVisitor
VisitUnaryExprOrTypeTraitExpr(const clang::UnaryExprOrTypeTraitExpr* UE);
StmtDiff VisitPseudoObjectExpr(const clang::PseudoObjectExpr* POE);

virtual clang::QualType
GetPushForwardDerivativeType(clang::QualType ParamType);
virtual std::string GetPushForwardFunctionSuffix();
virtual DiffMode GetPushForwardMode();

protected:
/// Helper function for differentiating the switch statement body.
///
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ namespace clad {
private:
friend class VisitorBase;
friend class BaseForwardModeVisitor;
friend class ForwardModeVisitor;
friend class PushForwardModeVisitor;
friend class VectorForwardModeVisitor;
friend class VectorPushForwardModeVisitor;
friend class ReverseModeVisitor;
friend class HessianModeVisitor;
friend class JacobianModeVisitor;
Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/DiffMode.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ enum class DiffMode {
vector_forward_mode,
experimental_pushforward,
experimental_pullback,
experimental_vector_pushforward,
reverse,
hessian,
jacobian,
Expand Down
33 changes: 0 additions & 33 deletions include/clad/Differentiator/ForwardModeVisitor.h

This file was deleted.

25 changes: 25 additions & 0 deletions include/clad/Differentiator/PushForwardModeVisitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//--------------------------------------------------------------------*- C++ -*-
// clad - the C++ Clang-based Automatic Differentiator
// version: $Id: ClangPlugin.cpp 7 2013-06-01 22:48:03Z [email protected] $
// author: Vassil Vassilev <vvasilev-at-cern.ch>
//------------------------------------------------------------------------------

#ifndef CLAD_DIFFERENTIATOR_PUSHFORWARDMODEVISITOR_H
#define CLAD_DIFFERENTIATOR_PUSHFORWARDMODEVISITOR_H

#include "BaseForwardModeVisitor.h"

namespace clad {
/// A visitor for processing the function code in forward mode.
/// Used to compute derivatives by clad::differentiate.
class PushForwardModeVisitor : public BaseForwardModeVisitor {

public:
PushForwardModeVisitor(DerivativeBuilder& builder);
~PushForwardModeVisitor() override;

StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
};
} // end namespace clad

#endif // CLAD_DIFFERENTIATOR_PUSHFORWARDMODEVISITOR_H
8 changes: 8 additions & 0 deletions include/clad/Differentiator/VectorForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor {
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
// Decl is not Stmt, so it cannot be visited directly.
VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD) override;

clang::QualType
GetPushForwardDerivativeType(clang::QualType ParamType) override;
std::string GetPushForwardFunctionSuffix() override;
DiffMode GetPushForwardMode() override;

// Function for setting the independent variables for vector mode.
void SetIndependentVarsExpr(clang::Expr* IndVarCountExpr);
};
} // end namespace clad

Expand Down
20 changes: 20 additions & 0 deletions include/clad/Differentiator/VectorPushForwardModeVisitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H
#define CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H

#include "PushForwardModeVisitor.h"
#include "VectorForwardModeVisitor.h"

namespace clad {
class VectorPushForwardModeVisitor : public VectorForwardModeVisitor {

public:
VectorPushForwardModeVisitor(DerivativeBuilder& builder);
~VectorPushForwardModeVisitor() override;

void ExecuteInsidePushforwardFunctionBlock() override;

StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
};
} // end namespace clad

#endif // CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H
175 changes: 167 additions & 8 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,157 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
/*OverloadFunctionDecl=*/nullptr};
}

clang::QualType BaseForwardModeVisitor::ComputePushforwardFnReturnType() {
assert(m_Mode == GetPushForwardMode());
QualType originalFnRT = m_Function->getReturnType();
if (originalFnRT->isVoidType())
return m_Context.VoidTy;
TemplateDecl* valueAndPushforward =
LookupTemplateDeclInCladNamespace("ValueAndPushforward");
assert(valueAndPushforward &&
"clad::ValueAndPushforward template not found!!");
QualType RT = InstantiateTemplate(
valueAndPushforward,
{originalFnRT, GetPushForwardDerivativeType(originalFnRT)});
return RT;
}

void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() {
Stmt* bodyDiff = Visit(m_Function->getBody()).getStmt();
auto* CS = cast<CompoundStmt>(bodyDiff);
for (Stmt* S : CS->body())
addToCurrentBlock(S);
}

DerivativeAndOverload
BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
const DiffRequest& request) {
m_Function = FD;
m_Functor = request.Functor;
m_DerivativeOrder = request.CurrentDerivativeOrder;
m_Mode = GetPushForwardMode();
assert(!m_DerivativeInFlight &&
"Doesn't support recursive diff. Use DiffPlan.");
m_DerivativeInFlight = true;

auto originalFnEffectiveName = utils::ComputeEffectiveFnName(m_Function);

IdentifierInfo* derivedFnII = &m_Context.Idents.get(
originalFnEffectiveName + GetPushForwardFunctionSuffix());
DeclarationNameInfo derivedFnName(derivedFnII, noLoc);
llvm::SmallVector<QualType, 16> paramTypes;
llvm::SmallVector<QualType, 16> derivedParamTypes;

// If we are differentiating an instance member function then
// create a parameter type for the parameter that will represent the
// derivative of `this` pointer with respect to the independent parameter.
if (const auto* MD = dyn_cast<CXXMethodDecl>(FD)) {
if (MD->isInstance()) {
QualType thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD);
derivedParamTypes.push_back(thisType);
}
}

for (auto* PVD : m_Function->parameters()) {
paramTypes.push_back(PVD->getType());

if (BaseForwardModeVisitor::IsDifferentiableType(PVD->getType()))
derivedParamTypes.push_back(GetPushForwardDerivativeType(PVD->getType()));
}

paramTypes.insert(paramTypes.end(), derivedParamTypes.begin(),
derivedParamTypes.end());

const auto* originalFnType =
dyn_cast<FunctionProtoType>(m_Function->getType());
QualType returnType = ComputePushforwardFnReturnType();
QualType derivedFnType = m_Context.getFunctionType(
returnType, paramTypes, originalFnType->getExtProtoInfo());
llvm::SaveAndRestore<DeclContext*> saveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> saveScope(m_CurScope);
auto* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
m_Sema.CurContext = DC;

DeclWithContext cloneFunctionResult = m_Builder.cloneFunction(
m_Function, *this, DC, noLoc, derivedFnName, derivedFnType);
m_Derivative = cloneFunctionResult.first;

llvm::SmallVector<ParmVarDecl*, 16> params;
llvm::SmallVector<ParmVarDecl*, 16> derivedParams;
beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope |
Scope::DeclScope);
m_Sema.PushFunctionScope();
m_Sema.PushDeclContext(getCurrentScope(), m_Derivative);

// If we are differentiating an instance member function then
// create a parameter for representing derivative of
// `this` pointer with respect to the independent parameter.
if (const auto* MFD = dyn_cast<CXXMethodDecl>(FD)) {
if (MFD->isInstance()) {
auto thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MFD);
IdentifierInfo* derivedPVDII = CreateUniqueIdentifier("_d_this");
auto* derivedPVD = utils::BuildParmVarDecl(m_Sema, m_Sema.CurContext,
derivedPVDII, thisType);
m_Sema.PushOnScopeChains(derivedPVD, getCurrentScope(),
/*AddToContext=*/false);
derivedParams.push_back(derivedPVD);
m_ThisExprDerivative = BuildDeclRef(derivedPVD);
}
}

std::size_t numParamsOriginalFn = m_Function->getNumParams();
for (std::size_t i = 0; i < numParamsOriginalFn; ++i) {
const auto* PVD = m_Function->getParamDecl(i);
// Some of the special member functions created implicitly by compilers
// have missing parameter identifier.
bool identifierMissing = false;
IdentifierInfo* PVDII = PVD->getIdentifier();
if (!PVDII || PVDII->getLength() == 0) {
PVDII = CreateUniqueIdentifier("param");
identifierMissing = true;
}
auto* newPVD = CloneParmVarDecl(PVD, PVDII,
/*pushOnScopeChains=*/true,
/*cloneDefaultArg=*/false);
params.push_back(newPVD);

if (identifierMissing)
m_DeclReplacements[PVD] = newPVD;

if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType()))
continue;
auto derivedPVDName = "_d_" + std::string(PVDII->getName());
IdentifierInfo* derivedPVDII = CreateUniqueIdentifier(derivedPVDName);
auto* derivedPVD = utils::BuildParmVarDecl(
m_Sema, m_Derivative, derivedPVDII,
GetPushForwardDerivativeType(PVD->getType()), PVD->getStorageClass());
derivedParams.push_back(derivedPVD);
m_Variables[newPVD] = BuildDeclRef(derivedPVD);
}

params.insert(params.end(), derivedParams.begin(), derivedParams.end());
m_Derivative->setParams(params);
m_Derivative->setBody(nullptr);

beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
beginBlock();

// execute the functor inside the function body.
ExecuteInsidePushforwardFunctionBlock();

Stmt* derivativeBody = endBlock();
m_Derivative->setBody(derivativeBody);

endScope(); // Function body scope
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope(); // Function decl scope

m_DerivativeInFlight = false;
return DerivativeAndOverload{cloneFunctionResult.first};
}

StmtDiff BaseForwardModeVisitor::VisitStmt(const Stmt* S) {
diag(DiagnosticsEngine::Warning, S->getBeginLoc(),
"attempted to differentiate unsupported statement, no changes applied");
Expand Down Expand Up @@ -913,6 +1064,19 @@ Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff(
return OverloadedFn;
}

QualType
BaseForwardModeVisitor::GetPushForwardDerivativeType(QualType ParamType) {
return ParamType;
}

std::string BaseForwardModeVisitor::GetPushForwardFunctionSuffix() {
return "_pushforward";
}

DiffMode BaseForwardModeVisitor::GetPushForwardMode() {
return DiffMode::experimental_pushforward;
}

StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
const FunctionDecl* FD = CE->getDirectCallee();
if (!FD) {
Expand Down Expand Up @@ -1007,10 +1171,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
CallArgs.push_back(argDiff.getExpr());
if (BaseForwardModeVisitor::IsDifferentiableType(arg->getType())) {
Expr* dArg = argDiff.getExpr_dx();
QualType CallArgTy = CallArgs.back()->getType();
assert((!dArg || m_Context.hasSameType(CallArgTy, dArg->getType())) &&
"Type mismatch, we might fail to instantiate a pullback");
(void)CallArgTy;
// FIXME: What happens when dArg is nullptr?
diffArgs.push_back(dArg);
}
Expand All @@ -1033,14 +1193,13 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {

// Try to find a user-defined overloaded derivative.
std::string customPushforward =
clad::utils::ComputeEffectiveFnName(FD) + "_pushforward";
clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix();
Expr* callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));

// Check if it is a recursive call.
if (!callDiff && (FD == m_Function) &&
m_Mode == DiffMode::experimental_pushforward) {
if (!callDiff && (FD == m_Function) && m_Mode == GetPushForwardMode()) {
// The differentiated function is called recursively.
Expr* derivativeRef =
m_Sema
Expand Down Expand Up @@ -1089,7 +1248,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
// derive the called function.
DiffRequest pushforwardFnRequest;
pushforwardFnRequest.Function = FD;
pushforwardFnRequest.Mode = DiffMode::experimental_pushforward;
pushforwardFnRequest.Mode = GetPushForwardMode();
pushforwardFnRequest.BaseFunctionName = FD->getNameAsString();
// pushforwardFnRequest.RequestedDerivativeOrder = m_DerivativeOrder;
// Silence diag outputs in nested derivation process.
Expand Down
3 changes: 2 additions & 1 deletion lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@ add_llvm_library(cladDifferentiator
DiffPlanner.cpp
ErrorEstimator.cpp
EstimationModel.cpp
ForwardModeVisitor.cpp
HessianModeVisitor.cpp
JacobianModeVisitor.cpp
MultiplexExternalRMVSource.cpp
PushForwardModeVisitor.cpp
ReverseModeForwPassVisitor.cpp
ReverseModeVisitor.cpp
StmtClone.cpp
VectorForwardModeVisitor.cpp
VectorPushForwardModeVisitor.cpp
Version.cpp
VisitorBase.cpp
${version_inc}
Expand Down
Loading

0 comments on commit 8ef4f4e

Please sign in to comment.