Skip to content

Commit

Permalink
Remove excessive parameters from the Derive functions. Partially addr…
Browse files Browse the repository at this point in the history
…esses #721.
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Oct 6, 2024
1 parent 4ac4f77 commit 118b2c9
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 232 deletions.
8 changes: 2 additions & 6 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,12 @@ class BaseForwardModeVisitor

///\brief Produces the first derivative of a given function.
///
///\param[in] FD - the function that will be differentiated.
///
///\returns The differentiated and potentially created enclosing
/// context.
///
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);
DerivativeAndOverload Derive();

DerivativeAndOverload DerivePushforward(const clang::FunctionDecl* FD,
const DiffRequest& request);
DerivativeAndOverload DerivePushforward();

/// Returns the return type for the pushforward function of the function
/// `m_DiffReq->Function`.
Expand Down
5 changes: 1 addition & 4 deletions include/clad/Differentiator/HessianModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,14 @@ namespace clad {
///\brief Produces the hessian second derivative columns of a given
/// function.
///
///\param[in] FD - the function that will be differentiated.
///
///\returns A function containing second derivatives (columns) of a hessian
/// matrix and potentially created enclosing context.
///
/// We name the hessian of f as 'f_hessian'. Uses ForwardModeVisitor and
/// ReverseModeVisitor to generate second derivatives that correspond to
/// columns of the Hessian. uses Merge to return a FunctionDecl
/// containing CallExprs to the generated second derivatives.
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);
DerivativeAndOverload Derive();
};
} // end namespace clad

Expand Down
3 changes: 1 addition & 2 deletions include/clad/Differentiator/ReverseModeForwPassVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ class ReverseModeForwPassVisitor : public ReverseModeVisitor {
public:
ReverseModeForwPassVisitor(DerivativeBuilder& builder,
const DiffRequest& request);
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);
DerivativeAndOverload Derive();

StmtDiff ProcessSingleStmt(const clang::Stmt* S);
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS) override;
Expand Down
8 changes: 2 additions & 6 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,6 @@ namespace clad {

///\brief Produces the gradient of a given function.
///
///\param[in] FD - the function that will be differentiated.
///
///\returns The gradient of the function, potentially created enclosing
/// context and if generated, its overload.
///
Expand All @@ -373,10 +371,8 @@ namespace clad {
/// Improved naming scheme is required. Hence, we append the indices to of
/// the requested parameters to 'f_grad', i.e. in the previous example "x,
/// y" will give 'f_grad_0_1' and "x, z" will give 'f_grad_0_2'.
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);
DerivativeAndOverload DerivePullback(const clang::FunctionDecl* FD,
const DiffRequest& request);
DerivativeAndOverload Derive();
DerivativeAndOverload DerivePullback();
StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
StmtDiff VisitCallExpr(const clang::CallExpr* CE);
Expand Down
48 changes: 21 additions & 27 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,15 @@ bool IsRealNonReferenceType(QualType T) {
return T.getNonReferenceType()->isRealType();
}

DerivativeAndOverload
BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
const DiffRequest& request) {
assert(m_DiffReq == request && "Can't pass two different requests!");
m_Functor = request.Functor;
DerivativeAndOverload BaseForwardModeVisitor::Derive() {
m_Functor = m_DiffReq.Functor;
const FunctionDecl* FD = m_DiffReq.Function;
assert(m_DiffReq.Mode == DiffMode::forward);
assert(!m_DerivativeInFlight &&
"Doesn't support recursive diff. Use DiffPlan.");
m_DerivativeInFlight = true;

DiffInputVarsInfo DVI = request.DVI;

DVI = request.DVI;
DiffInputVarsInfo DVI = m_DiffReq.DVI;

// FIXME: Shouldn't we give error here that no arg is specified?
if (DVI.empty())
Expand All @@ -84,7 +80,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
if (DVI.size() > 1 || (isArrayOrPointerType(diffVarInfo.param->getType()) &&
(diffVarInfo.paramIndexInterval.size() != 1))) {
diag(DiagnosticsEngine::Error,
request.Args ? request.Args->getEndLoc() : noLoc,
m_DiffReq.Args ? m_DiffReq.Args->getEndLoc() : noLoc,
"Forward mode differentiation w.r.t. several parameters at once is "
"not "
"supported, call 'clad::differentiate' for each parameter "
Expand Down Expand Up @@ -129,7 +125,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
isField = true;
}
if (!IsRealNonReferenceType(T)) {
diag(DiagnosticsEngine::Error, request.Args->getEndLoc(),
diag(DiagnosticsEngine::Error, m_DiffReq.Args->getEndLoc(),
"Attempted differentiation w.r.t. %0 '%1' which is not "
"of real type.",
{(isField ? "member" : "parameter"), diffVarInfo.source});
Expand Down Expand Up @@ -157,12 +153,12 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
argInfo += "_" + field;

std::string s;
if (request.CurrentDerivativeOrder > 1)
s = std::to_string(request.CurrentDerivativeOrder);
if (m_DiffReq.CurrentDerivativeOrder > 1)
s = std::to_string(m_DiffReq.CurrentDerivativeOrder);

// Check if the function is already declared as a custom derivative.
std::string gradientName =
request.BaseFunctionName + "_d" + s + "arg" + argInfo + derivativeSuffix;
std::string gradientName = m_DiffReq.BaseFunctionName + "_d" + s + "arg" +
argInfo + derivativeSuffix;
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
Expand Down Expand Up @@ -221,7 +217,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
derivedFD->setParams(paramsRef);
derivedFD->setBody(nullptr);

if (!request.DeclarationOnly) {
if (!m_DiffReq.DeclarationOnly) {
// Function body scope
beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
Expand Down Expand Up @@ -365,9 +361,10 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,

// Size >= current derivative order means that there exists a declaration
// or prototype for the currently derived function.
if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder)
if (m_DiffReq.DerivedFDPrototypes.size() >=
m_DiffReq.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]);
m_DiffReq.DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]);
}
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
Expand Down Expand Up @@ -401,13 +398,9 @@ void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() {
addToCurrentBlock(S);
}

DerivativeAndOverload
BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
const DiffRequest& request) {
// FIXME: We must not reset the diff request here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<DiffRequest&>(m_DiffReq) = request;
m_Functor = request.Functor;
DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() {
const FunctionDecl* FD = m_DiffReq.Function;
m_Functor = m_DiffReq.Functor;
assert(m_DiffReq.Mode == GetPushForwardMode());
assert(!m_DerivativeInFlight &&
"Doesn't support recursive diff. Use DiffPlan.");
Expand Down Expand Up @@ -517,7 +510,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
m_Derivative->setParams(params);
m_Derivative->setBody(nullptr);

if (!request.DeclarationOnly) {
if (!m_DiffReq.DeclarationOnly) {
beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
beginBlock();
Expand All @@ -532,9 +525,10 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,

// Size >= current derivative order means that there exists a declaration
// or prototype for the currently derived function.
if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder)
if (m_DiffReq.DerivedFDPrototypes.size() >=
m_DiffReq.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]);
m_DiffReq.DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]);
}

m_Sema.PopFunctionScopeInfo();
Expand Down
18 changes: 9 additions & 9 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,44 +399,44 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
DerivativeAndOverload result{};
if (request.Mode == DiffMode::forward) {
BaseForwardModeVisitor V(*this, request);
result = V.Derive(FD, request);
result = V.Derive();
} else if (request.Mode == DiffMode::experimental_pushforward) {
PushForwardModeVisitor V(*this, request);
result = V.DerivePushforward(FD, request);
result = V.DerivePushforward();
} else if (request.Mode == DiffMode::vector_forward_mode) {
VectorForwardModeVisitor V(*this, request);
result = V.DeriveVectorMode(FD, request);
} else if (request.Mode == DiffMode::experimental_vector_pushforward) {
VectorPushForwardModeVisitor V(*this, request);
result = V.DerivePushforward(FD, request);
result = V.DerivePushforward();
} else if (request.Mode == DiffMode::reverse) {
ReverseModeVisitor V(*this, request);
result = V.Derive(FD, request);
result = V.Derive();
} else if (request.Mode == DiffMode::experimental_pullback) {
ReverseModeVisitor V(*this, request);
if (!m_ErrorEstHandler.empty()) {
InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request);
V.AddExternalSource(*m_ErrorEstHandler.back());
}
result = V.DerivePullback(FD, request);
result = V.DerivePullback();
if (!m_ErrorEstHandler.empty())
CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel);
} else if (request.Mode == DiffMode::reverse_mode_forward_pass) {
ReverseModeForwPassVisitor V(*this, request);
result = V.Derive(FD, request);
result = V.Derive();
} else if (request.Mode == DiffMode::hessian ||
request.Mode == DiffMode::hessian_diagonal) {
HessianModeVisitor H(*this, request);
result = H.Derive(FD, request);
result = H.Derive();
} else if (request.Mode == DiffMode::jacobian) {
ReverseModeVisitor R(*this, request);
result = R.Derive(FD, request);
result = R.Derive();
} else if (request.Mode == DiffMode::error_estimation) {
ReverseModeVisitor R(*this, request);
InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request);
R.AddExternalSource(*m_ErrorEstHandler.back());
// Finally begin estimation.
result = R.Derive(FD, request);
result = R.Derive();
// Once we are done, we want to clear the model for any further
// calls to estimate_error.
CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel);
Expand Down
Loading

0 comments on commit 118b2c9

Please sign in to comment.