-
Notifications
You must be signed in to change notification settings - Fork 122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simplify the reverse mode by having just one DiffMode for it #964
base: master
Are you sure you want to change the base?
Changes from all commits
f0f031f
6385045
cab276d
9cd45b5
a0eba51
5fb5536
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -482,8 +482,7 @@ namespace clad { | |
// GradientDerivedEstFnTraits specializations for pure function pointer types | ||
template <class ReturnType, class... Args> | ||
struct GradientDerivedEstFnTraits<ReturnType (*)(Args...)> { | ||
using type = void (*)(Args..., OutputParamType_t<Args, Args>..., | ||
double&); | ||
using type = void (*)(Args..., OutputParamType_t<Args, void>..., void*); | ||
}; | ||
|
||
/// These macro expansions are used to cover all possible cases of | ||
|
@@ -495,12 +494,12 @@ namespace clad { | |
/// qualifier and reference respectively. The AddNOEX adds cases for noexcept | ||
/// qualifier only if it is supported and finally AddSPECS declares the | ||
/// function with all the cases | ||
#define GradientDerivedEstFnTraits_AddSPECS(var, cv, vol, ref, noex) \ | ||
template <typename R, typename C, typename... Args> \ | ||
struct GradientDerivedEstFnTraits<R (C::*)(Args...) cv vol ref noex> { \ | ||
using type = void (C::*)(Args..., OutputParamType_t<Args, Args>..., \ | ||
double&) cv vol ref noex; \ | ||
}; | ||
#define GradientDerivedEstFnTraits_AddSPECS(var, cv, vol, ref, noex) \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: function-like macro 'GradientDerivedEstFnTraits_AddSPECS' used; consider a 'constexpr' template function [cppcoreguidelines-macro-usage] #define GradientDerivedEstFnTraits_AddSPECS(var, cv, vol, ref, noex) \
^ |
||
template <typename R, typename C, typename... Args> \ | ||
struct GradientDerivedEstFnTraits<R (C::*)(Args...) cv vol ref noex> { \ | ||
using type = void (C::*)(Args..., OutputParamType_t<Args, void>..., \ | ||
void*) cv vol ref noex; \ | ||
}; | ||
|
||
#if __cpp_noexcept_function_type > 0 | ||
#define GradientDerivedEstFnTraits_AddNOEX(var, con, vol, ref) \ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -73,12 +73,36 @@ namespace clad { | |
// 'MultiplexExternalRMVSource.h' file | ||
MultiplexExternalRMVSource* m_ExternalSource = nullptr; | ||
clang::Expr* m_Pullback = nullptr; | ||
const char* funcPostfix() const { | ||
if (m_DiffReq.Mode == DiffMode::jacobian) | ||
return "_jac"; | ||
if (m_DiffReq.use_enzyme) | ||
return "_grad_enzyme"; | ||
return "_grad"; | ||
|
||
static std::string diffParamsPostfix(const DiffRequest& request) { | ||
std::string postfix; | ||
const DiffInputVarsInfo& DVI = request.DVI; | ||
std::size_t numParams = request->getNumParams(); | ||
// If Jacobian is asked, the last parameter is the result parameter | ||
// and should be ignored | ||
if (request.Mode == DiffMode::jacobian) | ||
numParams -= 1; | ||
// To be consistent with older tests, nothing is appended to 'f_grad' if | ||
// we differentiate w.r.t. all the parameters at once. | ||
if (DVI.size() != numParams) | ||
for (const auto& dParam : DVI) { | ||
const clang::ValueDecl* arg = dParam.param; | ||
const auto* begin = request->param_begin(); | ||
const auto* end = std::next(begin, numParams); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: narrowing conversion from 'std::size_t' (aka 'unsigned long') to signed type 'typename iterator_traits<ParmVarDecl *const *>::difference_type' (aka 'long') is implementation-defined [cppcoreguidelines-narrowing-conversions] const auto* end = std::next(begin, numParams);
^ |
||
const auto* it = std::find(begin, end, arg); | ||
auto idx = std::distance(begin, it); | ||
postfix += ('_' + std::to_string(idx)); | ||
} | ||
return postfix; | ||
} | ||
|
||
static std::string funcPostfix(const DiffRequest& request) { | ||
std::string postfix = "_pullback"; | ||
if (request.Mode == DiffMode::jacobian) | ||
postfix = "_jac"; | ||
if (request.use_enzyme) | ||
postfix = "_grad_enzyme"; | ||
return postfix + diffParamsPostfix(request); | ||
} | ||
|
||
/// Removes the local const qualifiers from a QualType and returns a new | ||
|
@@ -361,8 +385,6 @@ namespace clad { | |
/// y" will give 'f_grad_0_1' and "x, z" will give 'f_grad_0_2'. | ||
DerivativeAndOverload Derive(const clang::FunctionDecl* FD, | ||
const DiffRequest& request); | ||
DerivativeAndOverload DerivePullback(const clang::FunctionDecl* FD, | ||
const DiffRequest& request); | ||
StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE); | ||
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp); | ||
StmtDiff VisitCallExpr(const clang::CallExpr* CE); | ||
|
@@ -444,7 +466,7 @@ namespace clad { | |
/// Builds an overload for the gradient function that has derived params for | ||
/// all the arguments of the requested function and it calls the original | ||
/// gradient function internally | ||
clang::FunctionDecl* CreateGradientOverload(); | ||
clang::FunctionDecl* CreateGradientOverload(unsigned numExtraParams); | ||
|
||
/// Returns the type that should be used to represent the derivative of a | ||
/// variable of type `yType` with respect to a parameter variable of type | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this became a
void*
from adouble&
? What is the benefit?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is necessary to enable overloads in error estimation.
double&
cannot be converted tovoid*
so I replaced it withvoid*
in the overload anddouble*
in other places. This is the reason_final_error
became pointer-type.