Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove excessive parameters from Derive functions #1110

Merged
merged 2 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,12 @@ class BaseForwardModeVisitor

///\brief Produces the first derivative of a given function.
///
///\param[in] FD - the function that will be differentiated.
///
///\returns The differentiated and potentially created enclosing
/// context.
///
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);
DerivativeAndOverload Derive();

DerivativeAndOverload DerivePushforward(const clang::FunctionDecl* FD,
const DiffRequest& request);
DerivativeAndOverload DerivePushforward();

/// Returns the return type for the pushforward function of the function
/// `m_DiffReq->Function`.
Expand Down
5 changes: 1 addition & 4 deletions include/clad/Differentiator/HessianModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,14 @@ namespace clad {
///\brief Produces the hessian second derivative columns of a given
/// function.
///
///\param[in] FD - the function that will be differentiated.
///
///\returns A function containing second derivatives (columns) of a hessian
/// matrix and potentially created enclosing context.
///
/// We name the hessian of f as 'f_hessian'. Uses ForwardModeVisitor and
/// ReverseModeVisitor to generate second derivatives that correspond to
/// columns of the Hessian. uses Merge to return a FunctionDecl
/// containing CallExprs to the generated second derivatives.
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);
DerivativeAndOverload Derive();
};
} // end namespace clad

Expand Down
3 changes: 1 addition & 2 deletions include/clad/Differentiator/ReverseModeForwPassVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ class ReverseModeForwPassVisitor : public ReverseModeVisitor {
public:
ReverseModeForwPassVisitor(DerivativeBuilder& builder,
const DiffRequest& request);
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);
DerivativeAndOverload Derive();

StmtDiff ProcessSingleStmt(const clang::Stmt* S);
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS) override;
Expand Down
8 changes: 2 additions & 6 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,6 @@ namespace clad {

///\brief Produces the gradient of a given function.
///
///\param[in] FD - the function that will be differentiated.
///
///\returns The gradient of the function, potentially created enclosing
/// context and if generated, its overload.
///
Expand All @@ -373,10 +371,8 @@ namespace clad {
/// Improved naming scheme is required. Hence, we append the indices to of
/// the requested parameters to 'f_grad', i.e. in the previous example "x,
/// y" will give 'f_grad_0_1' and "x, z" will give 'f_grad_0_2'.
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);
DerivativeAndOverload DerivePullback(const clang::FunctionDecl* FD,
const DiffRequest& request);
DerivativeAndOverload Derive();
DerivativeAndOverload DerivePullback();
StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
StmtDiff VisitCallExpr(const clang::CallExpr* CE);
Expand Down
2 changes: 0 additions & 2 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,6 @@ namespace clad {
std::vector<Stmts> m_Blocks;
/// Stores output variables for vector-valued functions
VectorOutputs m_VectorOutput;
/// The functor type that is currently being differentiated, if any.
const clang::CXXRecordDecl* m_Functor = nullptr;
/// Stores derivative expression of the implicit `this` pointer.
///
/// In the forward mode, `this` pointer derivative expression is of pointer
Expand Down
64 changes: 28 additions & 36 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,14 @@ bool IsRealNonReferenceType(QualType T) {
return T.getNonReferenceType()->isRealType();
}

DerivativeAndOverload
BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
const DiffRequest& request) {
assert(m_DiffReq == request && "Can't pass two different requests!");
m_Functor = request.Functor;
DerivativeAndOverload BaseForwardModeVisitor::Derive() {
const FunctionDecl* FD = m_DiffReq.Function;
assert(m_DiffReq.Mode == DiffMode::forward);
assert(!m_DerivativeInFlight &&
"Doesn't support recursive diff. Use DiffPlan.");
m_DerivativeInFlight = true;

DiffInputVarsInfo DVI = request.DVI;

DVI = request.DVI;
DiffInputVarsInfo DVI = m_DiffReq.DVI;

// FIXME: Shouldn't we give error here that no arg is specified?
if (DVI.empty())
Expand All @@ -84,7 +79,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
if (DVI.size() > 1 || (isArrayOrPointerType(diffVarInfo.param->getType()) &&
(diffVarInfo.paramIndexInterval.size() != 1))) {
diag(DiagnosticsEngine::Error,
request.Args ? request.Args->getEndLoc() : noLoc,
m_DiffReq.Args ? m_DiffReq.Args->getEndLoc() : noLoc,
"Forward mode differentiation w.r.t. several parameters at once is "
"not "
"supported, call 'clad::differentiate' for each parameter "
Expand Down Expand Up @@ -129,7 +124,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
isField = true;
}
if (!IsRealNonReferenceType(T)) {
diag(DiagnosticsEngine::Error, request.Args->getEndLoc(),
diag(DiagnosticsEngine::Error, m_DiffReq.Args->getEndLoc(),
"Attempted differentiation w.r.t. %0 '%1' which is not "
"of real type.",
{(isField ? "member" : "parameter"), diffVarInfo.source});
Expand All @@ -142,11 +137,11 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
// class defining the call operator.
// Thus, we need to find index of the member variable instead.
unsigned argIndex = ~0;
if (m_DiffReq->param_empty() && m_Functor)
argIndex =
std::distance(m_Functor->field_begin(),
std::find(m_Functor->field_begin(),
m_Functor->field_end(), m_IndependentVar));
const CXXRecordDecl* functor = m_DiffReq.Functor;
if (m_DiffReq->param_empty() && functor)
argIndex = std::distance(functor->field_begin(),
std::find(functor->field_begin(),
functor->field_end(), m_IndependentVar));
else
argIndex = std::distance(
FD->param_begin(),
Expand All @@ -157,12 +152,12 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
argInfo += "_" + field;

std::string s;
if (request.CurrentDerivativeOrder > 1)
s = std::to_string(request.CurrentDerivativeOrder);
if (m_DiffReq.CurrentDerivativeOrder > 1)
s = std::to_string(m_DiffReq.CurrentDerivativeOrder);

// Check if the function is already declared as a custom derivative.
std::string gradientName =
request.BaseFunctionName + "_d" + s + "arg" + argInfo + derivativeSuffix;
std::string gradientName = m_DiffReq.BaseFunctionName + "_d" + s + "arg" +
argInfo + derivativeSuffix;
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
Expand Down Expand Up @@ -221,7 +216,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
derivedFD->setParams(paramsRef);
derivedFD->setBody(nullptr);

if (!request.DeclarationOnly) {
if (!m_DiffReq.DeclarationOnly) {
// Function body scope
beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
Expand Down Expand Up @@ -300,8 +295,8 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,

// Create derived variable for each member variable if we are
// differentiating a call operator.
if (m_Functor) {
for (FieldDecl* fieldDecl : m_Functor->fields()) {
if (m_DiffReq.Functor) {
for (FieldDecl* fieldDecl : m_DiffReq.Functor->fields()) {
Expr* dInitializer = nullptr;
QualType fieldType = fieldDecl->getType();

Expand Down Expand Up @@ -365,9 +360,10 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,

// Size >= current derivative order means that there exists a declaration
// or prototype for the currently derived function.
if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder)
if (m_DiffReq.DerivedFDPrototypes.size() >=
m_DiffReq.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]);
m_DiffReq.DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]);
}
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
Expand Down Expand Up @@ -401,13 +397,8 @@ void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() {
addToCurrentBlock(S);
}

DerivativeAndOverload
BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
const DiffRequest& request) {
// FIXME: We must not reset the diff request here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<DiffRequest&>(m_DiffReq) = request;
m_Functor = request.Functor;
DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() {
const FunctionDecl* FD = m_DiffReq.Function;
assert(m_DiffReq.Mode == GetPushForwardMode());
assert(!m_DerivativeInFlight &&
"Doesn't support recursive diff. Use DiffPlan.");
Expand Down Expand Up @@ -517,7 +508,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
m_Derivative->setParams(params);
m_Derivative->setBody(nullptr);

if (!request.DeclarationOnly) {
if (!m_DiffReq.DeclarationOnly) {
beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
beginBlock();
Expand All @@ -532,9 +523,10 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,

// Size >= current derivative order means that there exists a declaration
// or prototype for the currently derived function.
if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder)
if (m_DiffReq.DerivedFDPrototypes.size() >=
m_DiffReq.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]);
m_DiffReq.DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]);
}

m_Sema.PopFunctionScopeInfo();
Expand Down Expand Up @@ -890,7 +882,7 @@ StmtDiff BaseForwardModeVisitor::VisitMemberExpr(const MemberExpr* ME) {
auto clonedME = dyn_cast<MemberExpr>(Clone(ME));
// Currently, we only differentiate member variables if we are
// differentiating a call operator.
if (m_Functor) {
if (m_DiffReq.Functor) {
if (isa<CXXThisExpr>(ME->getBase()->IgnoreParenImpCasts())) {
// Try to find the derivative of the member variable wrt independent
// variable
Expand Down Expand Up @@ -962,7 +954,7 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) {
ValueDecl* VD = nullptr;
// Derived variables for member variables are also created when we are
// differentiating a call operator.
if (m_Functor) {
if (m_DiffReq.Functor) {
if (auto ME = dyn_cast<MemberExpr>(clonedBase->IgnoreParenImpCasts())) {
ValueDecl* decl = ME->getMemberDecl();
auto it = m_Variables.find(decl);
Expand Down
18 changes: 9 additions & 9 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,44 +399,44 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
DerivativeAndOverload result{};
if (request.Mode == DiffMode::forward) {
BaseForwardModeVisitor V(*this, request);
result = V.Derive(FD, request);
result = V.Derive();
} else if (request.Mode == DiffMode::experimental_pushforward) {
PushForwardModeVisitor V(*this, request);
result = V.DerivePushforward(FD, request);
result = V.DerivePushforward();
} else if (request.Mode == DiffMode::vector_forward_mode) {
VectorForwardModeVisitor V(*this, request);
result = V.DeriveVectorMode(FD, request);
} else if (request.Mode == DiffMode::experimental_vector_pushforward) {
VectorPushForwardModeVisitor V(*this, request);
result = V.DerivePushforward(FD, request);
result = V.DerivePushforward();
} else if (request.Mode == DiffMode::reverse) {
ReverseModeVisitor V(*this, request);
result = V.Derive(FD, request);
result = V.Derive();
} else if (request.Mode == DiffMode::experimental_pullback) {
ReverseModeVisitor V(*this, request);
if (!m_ErrorEstHandler.empty()) {
InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request);
V.AddExternalSource(*m_ErrorEstHandler.back());
}
result = V.DerivePullback(FD, request);
result = V.DerivePullback();
if (!m_ErrorEstHandler.empty())
CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel);
} else if (request.Mode == DiffMode::reverse_mode_forward_pass) {
ReverseModeForwPassVisitor V(*this, request);
result = V.Derive(FD, request);
result = V.Derive();
} else if (request.Mode == DiffMode::hessian ||
request.Mode == DiffMode::hessian_diagonal) {
HessianModeVisitor H(*this, request);
result = H.Derive(FD, request);
result = H.Derive();
} else if (request.Mode == DiffMode::jacobian) {
ReverseModeVisitor R(*this, request);
result = R.Derive(FD, request);
result = R.Derive();
} else if (request.Mode == DiffMode::error_estimation) {
ReverseModeVisitor R(*this, request);
InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request);
R.AddExternalSource(*m_ErrorEstHandler.back());
// Finally begin estimation.
result = R.Derive(FD, request);
result = R.Derive();
// Once we are done, we want to clear the model for any further
// calls to estimate_error.
CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel);
Expand Down
Loading
Loading