Skip to content

Commit

Permalink
Add support for call expressions in vector forward mode AD
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Oct 9, 2023
1 parent 41699e8 commit e04b101
Show file tree
Hide file tree
Showing 14 changed files with 442 additions and 203 deletions.
5 changes: 5 additions & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ class BaseForwardModeVisitor
VisitUnaryExprOrTypeTraitExpr(const clang::UnaryExprOrTypeTraitExpr* UE);
StmtDiff VisitPseudoObjectExpr(const clang::PseudoObjectExpr* POE);

virtual clang::QualType
GetPushForwardDerivativeType(clang::QualType ParamType);
virtual std::string GetPushForwardFunctionSuffix();
virtual DiffMode GetPushForwardMode();

protected:
/// Helper function for differentiating the switch statement body.
///
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ namespace clad {
private:
friend class VisitorBase;
friend class BaseForwardModeVisitor;
friend class ForwardModeVisitor;
friend class PushForwardModeVisitor;
friend class VectorForwardModeVisitor;
friend class VectorPushForwardModeVisitor;
friend class ReverseModeVisitor;
friend class HessianModeVisitor;
friend class JacobianModeVisitor;
Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/DiffMode.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ enum class DiffMode {
vector_forward_mode,
experimental_pushforward,
experimental_pullback,
experimental_vector_pushforward,
reverse,
hessian,
jacobian,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,19 @@
#include "BaseForwardModeVisitor.h"

namespace clad {
/// A visitor for processing the function code in forward mode.
/// Used to compute derivatives by clad::differentiate.
class ForwardModeVisitor : public BaseForwardModeVisitor {
/// A visitor for processing the function code in forward mode.
/// Used to compute derivatives by clad::differentiate.
class PushForwardModeVisitor : virtual public BaseForwardModeVisitor {

public:
ForwardModeVisitor(DerivativeBuilder& builder);
~ForwardModeVisitor();
PushForwardModeVisitor(DerivativeBuilder& builder);
~PushForwardModeVisitor() override;

DerivativeAndOverload DerivePushforward(const clang::FunctionDecl* FD,
const DiffRequest& request);

virtual void ExecuteInsideFunctionBlock();

/// Returns the return type for the pushforward function of the function
/// `m_Function`.
/// \note `m_Function` field should be set before using this function.
Expand Down
10 changes: 9 additions & 1 deletion include/clad/Differentiator/VectorForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
namespace clad {
/// A visitor for processing the function code in vector forward mode.
/// Used to compute derivatives by clad::vector_forward_differentiate.
class VectorForwardModeVisitor : public BaseForwardModeVisitor {
class VectorForwardModeVisitor : virtual public BaseForwardModeVisitor {
private:
llvm::SmallVector<const clang::ValueDecl*, 16> m_IndependentVars;
/// Map used to keep track of parameter variables w.r.t which the
Expand Down Expand Up @@ -78,6 +78,14 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor {
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
// Decl is not Stmt, so it cannot be visited directly.
VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD) override;

clang::QualType
GetPushForwardDerivativeType(clang::QualType ParamType) override;
std::string GetPushForwardFunctionSuffix() override;
DiffMode GetPushForwardMode() override;

// Function for setting the independent variables for vector mode.
void SetIndependentVarsExpr(clang::Expr* IndVarsExpr);
};
} // end namespace clad

Expand Down
21 changes: 21 additions & 0 deletions include/clad/Differentiator/VectorPushForwardModeVisitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H
#define CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H

#include "PushForwardModeVisitor.h"
#include "VectorForwardModeVisitor.h"

namespace clad {
class VectorPushForwardModeVisitor : public PushForwardModeVisitor,
public VectorForwardModeVisitor {

public:
VectorPushForwardModeVisitor(DerivativeBuilder& builder);
~VectorPushForwardModeVisitor() override;

void ExecuteInsideFunctionBlock() override;

StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
};
} // end namespace clad

#endif // CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H
24 changes: 16 additions & 8 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,19 @@ Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff(
return OverloadedFn;
}

QualType
BaseForwardModeVisitor::GetPushForwardDerivativeType(QualType ParamType) {
return ParamType;
}

std::string BaseForwardModeVisitor::GetPushForwardFunctionSuffix() {
return "_pushforward";
}

DiffMode BaseForwardModeVisitor::GetPushForwardMode() {
return DiffMode::experimental_pushforward;
}

StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
const FunctionDecl* FD = CE->getDirectCallee();
if (!FD) {
Expand Down Expand Up @@ -1008,10 +1021,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
CallArgs.push_back(argDiff.getExpr());
if (BaseForwardModeVisitor::IsDifferentiableType(arg->getType())) {
Expr* dArg = argDiff.getExpr_dx();
QualType CallArgTy = CallArgs.back()->getType();
assert((!dArg || m_Context.hasSameType(CallArgTy, dArg->getType())) &&
"Type mismatch, we might fail to instantiate a pullback");
(void)CallArgTy;
// FIXME: What happens when dArg is nullptr?
diffArgs.push_back(dArg);
}
Expand All @@ -1034,14 +1043,13 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {

// Try to find a user-defined overloaded derivative.
std::string customPushforward =
clad::utils::ComputeEffectiveFnName(FD) + "_pushforward";
clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix();
Expr* callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));

// Check if it is a recursive call.
if (!callDiff && (FD == m_Function) &&
m_Mode == DiffMode::experimental_pushforward) {
if (!callDiff && (FD == m_Function) && m_Mode == GetPushForwardMode()) {
// The differentiated function is called recursively.
Expr* derivativeRef =
m_Sema
Expand All @@ -1060,7 +1068,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
// derive the called function.
DiffRequest pushforwardFnRequest;
pushforwardFnRequest.Function = FD;
pushforwardFnRequest.Mode = DiffMode::experimental_pushforward;
pushforwardFnRequest.Mode = GetPushForwardMode();
pushforwardFnRequest.BaseFunctionName = FD->getNameAsString();
// pushforwardFnRequest.RequestedDerivativeOrder = m_DerivativeOrder;
// Silence diag outputs in nested derivation process.
Expand Down
3 changes: 2 additions & 1 deletion lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@ add_llvm_library(cladDifferentiator
DiffPlanner.cpp
ErrorEstimator.cpp
EstimationModel.cpp
ForwardModeVisitor.cpp
HessianModeVisitor.cpp
JacobianModeVisitor.cpp
MultiplexExternalRMVSource.cpp
PushForwardModeVisitor.cpp
ReverseModeForwPassVisitor.cpp
ReverseModeVisitor.cpp
StmtClone.cpp
VectorForwardModeVisitor.cpp
VectorPushForwardModeVisitor.cpp
Version.cpp
VisitorBase.cpp
${version_inc}
Expand Down
8 changes: 6 additions & 2 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/ForwardModeVisitor.h"
#include "clad/Differentiator/HessianModeVisitor.h"
#include "clad/Differentiator/JacobianModeVisitor.h"
#include "clad/Differentiator/PushForwardModeVisitor.h"
#include "clad/Differentiator/ReverseModeForwPassVisitor.h"
#include "clad/Differentiator/ReverseModeVisitor.h"
#include "clad/Differentiator/StmtClone.h"
#include "clad/Differentiator/VectorForwardModeVisitor.h"
#include "clad/Differentiator/VectorPushForwardModeVisitor.h"

#include <algorithm>

Expand Down Expand Up @@ -209,11 +210,14 @@ namespace clad {
BaseForwardModeVisitor V(*this);
result = V.Derive(FD, request);
} else if (request.Mode == DiffMode::experimental_pushforward) {
ForwardModeVisitor V(*this);
PushForwardModeVisitor V(*this);
result = V.DerivePushforward(FD, request);
} else if (request.Mode == DiffMode::vector_forward_mode) {
VectorForwardModeVisitor V(*this);
result = V.DeriveVectorMode(FD, request);
} else if (request.Mode == DiffMode::experimental_vector_pushforward) {
VectorPushForwardModeVisitor V(*this);
result = V.DerivePushforward(FD, request);
} else if (request.Mode == DiffMode::reverse) {
ReverseModeVisitor V(*this);
result = V.Derive(FD, request);
Expand Down
185 changes: 0 additions & 185 deletions lib/Differentiator/ForwardModeVisitor.cpp

This file was deleted.

Loading

0 comments on commit e04b101

Please sign in to comment.