Simplify the reverse mode by having just one DiffMode for it #964

Open · wants to merge 6 commits into master
6 changes: 3 additions & 3 deletions README.md
````diff
@@ -337,12 +337,12 @@ namespace custom_derivatives {
 double my_pow_darg1(double x, double y) { return my_pow(x, y) * std::log(x); }
 }
 ```
-You can also specify a custom gradient:
+You can also specify a custom pullback:
 ```cpp
 namespace custom_derivatives {
-void my_pow_grad(double x, double y, array_ref<double> _d_x, array_ref<double> _d_y) {
+void my_pow_pullback(double x, double y, double _d_y0, double *_d_x, double *_d_y) {
   double t = my_pow(x, y - 1);
-  *_d_x = y * t;
+  *_d_x = y * t * _d_y0;
   *_d_y = x * t * std::log(x);
 }
 }
````
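Two details are worth making concrete for readers: `_d_y0` is the adjoint of `my_pow`'s output (the seed of the reverse pass), and the hunk leaves `*_d_y` unscaled by `_d_y0`, which looks like an oversight. Below is a minimal, self-contained sketch (the driver is illustrative, not from the README) that applies the seed to both partials; calling the pullback from the reverse pass of `g(x, y) = 2 * my_pow(x, y)` delivers an incoming adjoint of 2.

```cpp
#include <cmath>
#include <cstdio>

double my_pow(double x, double y) { return std::pow(x, y); }

// Custom pullback in the convention above: _d_y0 is the incoming adjoint,
// and it scales every accumulated partial derivative.
void my_pow_pullback(double x, double y, double _d_y0, double* _d_x, double* _d_y) {
  double t = my_pow(x, y - 1);
  *_d_x = y * t * _d_y0;
  *_d_y = x * t * std::log(x) * _d_y0;
}

int main() {
  // Reverse pass of g(x, y) = 2 * my_pow(x, y): the adjoint reaching my_pow is 2.
  double dx = 0, dy = 0;
  my_pow_pullback(2.0, 3.0, /*_d_y0=*/2.0, &dx, &dy);
  std::printf("dg/dx = %f, dg/dy = %f\n", dx, dy); // 24 and 16*ln(2)
}
```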
12 changes: 6 additions & 6 deletions demos/ErrorEstimation/CustomModel/README.md
````diff
@@ -42,26 +42,26 @@ clang++ -Xclang -add-plugin -Xclang clad -Xclang -load -Xclang CLAD_INST/lib/cla
 To verify your results, you can build the dummy `test.cpp` file with the commands shown above. Once you compile and run the test file correctly, you will notice the generated code is as follows:
 
 ```cpp
-The code is: void func_grad(float x, float y, float *_d_x, float *_d_y, double &_final_error) {
+The code is: void func_pullback(float x, float y, float _d_y0, float *_d_x, float *_d_y, double *_final_error) {
     float _d_z = 0;
     float _t0;
     float z;
     _t0 = z;
     z = x + y;
-    _d_z += 1;
+    _d_z += _d_y0;
     {
-        _final_error += _d_z * z;
+        *_final_error += _d_z * z;
         z = _t0;
         float _r_d0 = _d_z;
         _d_z = 0;
         *_d_x += _r_d0;
         *_d_y += _r_d0;
     }
-    _final_error += *_d_x * x;
-    _final_error += *_d_y * y;
+    *_final_error += *_d_x * x;
+    *_final_error += *_d_y * y;
 }
 ```
 
-Here, notice that the result in the `_final_error` variable now reflects the error expression defined in the custom model we just compiled!
+Here, notice that the result in the `*_final_error` variable now reflects the error expression defined in the custom model we just compiled!
 
 This demo is also a runnable test under `CLAD_BASE/test/Misc/RunDemos.C` and will run as a part of the lit test suite. Thus, the same can be verified by running `make check-clad`.
````
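A quick sketch of how a caller drives the new signature (the driver is hypothetical; it assumes the clad-generated `func_pullback` shown above is linked in): the seed `_d_y0` is set to 1 to get the plain gradient, and the error accumulator is now passed by address.

```cpp
#include <cstdio>

// Declared here and assumed to be provided by the clad-processed TU;
// the definition is the generated function shown above.
void func_pullback(float x, float y, float _d_y0, float* _d_x, float* _d_y,
                   double* _final_error);

int main() {
  float dx = 0, dy = 0;
  double finalError = 0;
  func_pullback(2.0f, 3.0f, /*_d_y0=*/1.0f, &dx, &dy, &finalError);
  std::printf("dx=%f dy=%f err=%g\n", dx, dy, finalError);
}
```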
2 changes: 1 addition & 1 deletion demos/ErrorEstimation/FloatSum.cpp
```diff
@@ -106,7 +106,7 @@ int main() {
   finalError = 0;
   unsigned int dn = 0;
   // First execute the derived function.
-  df.execute(x, n, &ret[0], &dn, finalError);
+  df.execute(x, n, &ret[0], &dn, &finalError);
 
   double kahanResult = kahanSum(x, n);
   double vanillaResult = vanillaSum(x, n);
```
12 changes: 6 additions & 6 deletions demos/ErrorEstimation/PrintModel/README.md
````diff
@@ -41,26 +41,26 @@ clang++ -Xclang -add-plugin -Xclang clad -Xclang -load -Xclang CLAD_INST/lib/cla
 To verify your results, you can build the dummy `test.cpp` file with the commands shown above. Once you compile and run the test file correctly, you will notice the generated code is as follows:
 
 ```cpp
-The code is: void func_grad(float x, float y, float *_d_x, float *_d_y, double &_final_error) {
+The code is: void func_pullback(float x, float y, float _d_y0, float *_d_x, float *_d_y, double *_final_error) {
     float _d_z = 0;
     float _t0;
     float z;
     _t0 = z;
     z = x + y;
-    _d_z += 1;
+    _d_z += _d_y0;
     {
-        _final_error += _d_z * z;
+        *_final_error += _d_z * z;
         z = _t0;
         float _r_d0 = _d_z;
         _d_z = 0;
         *_d_x += _r_d0;
         *_d_y += _r_d0;
     }
-    _final_error += *_d_x * x;
-    _final_error += *_d_y * y;
+    *_final_error += *_d_x * x;
+    *_final_error += *_d_y * y;
 }
 ```
 
-Here, notice that the result in the `_final_error` variable now reflects the error expression defined in the custom model we just compiled!
+Here, notice that the result in the `*_final_error` variable now reflects the error expression defined in the custom model we just compiled!
 
 This demo is also a runnable test under `CLAD_BASE/test/Misc/RunDemos.C` and will run as a part of the lit test suite. Thus, the same can be verified by running `make check-clad`.
````
2 changes: 1 addition & 1 deletion demos/ErrorEstimation/PrintModel/test.cpp
```diff
@@ -29,5 +29,5 @@ int main() {
   // Calculate the error
   float dx, dy;
   double error;
-  df.execute(2, 3, &dx, &dy, error);
+  df.execute(2, 3, &dx, &dy, &error);
 }
```
4 changes: 2 additions & 2 deletions demos/Jupyter/Intro.ipynb
```diff
@@ -356,7 +356,7 @@
     "output_type": "stream",
     "text": [
      "The code is: \n",
-     "void fn_grad(double x, double y, double *_d_x, double *_d_y, double &_final_error) {\n",
+     "void fn_grad(double x, double y, double *_d_x, double *_d_y, double *_final_error) {\n",
      "    double _t2;\n",
      "    double _t3;\n",
      "    double _t4;\n",
@@ -390,7 +390,7 @@
      "    _delta_x += std::abs(*_d_x * x * 1.1920928955078125E-7);\n",
      "    double _delta_y = 0;\n",
      "    _delta_y += std::abs(*_d_y * y * 1.1920928955078125E-7);\n",
-     "    _final_error += _delta_y + _delta_x + std::abs(1. * _ret_value0 * 1.1920928955078125E-7);\n",
+     "    *_final_error += _delta_y + _delta_x + std::abs(1. * _ret_value0 * 1.1920928955078125E-7);\n",
      "}\n",
      "\n"
     ]
```
3 changes: 0 additions & 3 deletions include/clad/Differentiator/DiffMode.h
```diff
@@ -7,7 +7,6 @@ enum class DiffMode {
   forward,
   vector_forward_mode,
   experimental_pushforward,
-  experimental_pullback,
   experimental_vector_pushforward,
   reverse,
   hessian,
@@ -26,8 +25,6 @@ inline const char* DiffModeToString(DiffMode mode) {
     return "vector_forward_mode";
   case DiffMode::experimental_pushforward:
     return "pushforward";
-  case DiffMode::experimental_pullback:
-    return "pullback";
   case DiffMode::experimental_vector_pushforward:
     return "vector_pushforward";
   case DiffMode::reverse:
```
15 changes: 7 additions & 8 deletions include/clad/Differentiator/FunctionTraits.h
```diff
@@ -482,8 +482,7 @@ namespace clad {
   // GradientDerivedEstFnTraits specializations for pure function pointer types
   template <class ReturnType, class... Args>
   struct GradientDerivedEstFnTraits<ReturnType (*)(Args...)> {
-    using type = void (*)(Args..., OutputParamType_t<Args, Args>...,
-                          double&);
+    using type = void (*)(Args..., OutputParamType_t<Args, void>..., void*);
```
**Owner:** Why did this become a `void*` instead of a `double&`? What is the benefit?

**Collaborator (author):** This is necessary to enable overloads in error estimation. A `double&` cannot be converted to `void*`, so I replaced it with `void*` in the overload and `double*` in other places. This is why `_final_error` became pointer-typed.
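The C++ rule being relied on can be seen in a standalone sketch (names are illustrative): any object pointer converts implicitly to `void*`, while a reference parameter requires the exact type, so a single `void*` slot lets the error-estimation signature share the overload machinery.

```cpp
#include <iostream>

// Overload-style entry point: the extra slot is an opaque void*.
void overload(double x, void* extra) {
  // The error-estimation variant recovers the real type internally.
  *static_cast<double*>(extra) += x;
}

int main() {
  double finalError = 0;
  overload(1.5, &finalError); // OK: double* converts implicitly to void*
  // A double& in the same slot could not be funneled through a void*:
  // there is no implicit reference-to-void* conversion, hence double*.
  std::cout << finalError << '\n'; // prints 1.5
}
```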

```diff
 };
 
 /// These macro expansions are used to cover all possible cases of
@@ -495,12 +494,12 @@ namespace clad {
 /// qualifier and reference respectively. The AddNOEX adds cases for noexcept
 /// qualifier only if it is supported and finally AddSPECS declares the
 /// function with all the cases
-#define GradientDerivedEstFnTraits_AddSPECS(var, cv, vol, ref, noex)          \
-  template <typename R, typename C, typename... Args>                         \
-  struct GradientDerivedEstFnTraits<R (C::*)(Args...) cv vol ref noex> {      \
-    using type = void (C::*)(Args..., OutputParamType_t<Args, Args>...,       \
-                             double&) cv vol ref noex;                        \
-  };
+#define GradientDerivedEstFnTraits_AddSPECS(var, cv, vol, ref, noex)          \
+  template <typename R, typename C, typename... Args>                         \
+  struct GradientDerivedEstFnTraits<R (C::*)(Args...) cv vol ref noex> {      \
+    using type = void (C::*)(Args..., OutputParamType_t<Args, void>...,       \
+                             void*) cv vol ref noex;                          \
+  };
 
 #if __cpp_noexcept_function_type > 0
 #define GradientDerivedEstFnTraits_AddNOEX(var, con, vol, ref)                \
```

**clang-tidy** (on the `#define` line): warning: function-like macro 'GradientDerivedEstFnTraits_AddSPECS' used; consider a 'constexpr' template function [cppcoreguidelines-macro-usage]
40 changes: 31 additions & 9 deletions include/clad/Differentiator/ReverseModeVisitor.h
```diff
@@ -73,12 +73,36 @@ namespace clad {
     // 'MultiplexExternalRMVSource.h' file
     MultiplexExternalRMVSource* m_ExternalSource = nullptr;
     clang::Expr* m_Pullback = nullptr;
-    const char* funcPostfix() const {
-      if (m_DiffReq.Mode == DiffMode::jacobian)
-        return "_jac";
-      if (m_DiffReq.use_enzyme)
-        return "_grad_enzyme";
-      return "_grad";
+
+    static std::string diffParamsPostfix(const DiffRequest& request) {
+      std::string postfix;
+      const DiffInputVarsInfo& DVI = request.DVI;
+      std::size_t numParams = request->getNumParams();
+      // If Jacobian is asked, the last parameter is the result parameter
+      // and should be ignored
+      if (request.Mode == DiffMode::jacobian)
+        numParams -= 1;
+      // To be consistent with older tests, nothing is appended to 'f_grad' if
+      // we differentiate w.r.t. all the parameters at once.
+      if (DVI.size() != numParams)
+        for (const auto& dParam : DVI) {
+          const clang::ValueDecl* arg = dParam.param;
+          const auto* begin = request->param_begin();
+          const auto* end = std::next(begin, numParams);
+          const auto* it = std::find(begin, end, arg);
+          auto idx = std::distance(begin, it);
+          postfix += ('_' + std::to_string(idx));
+        }
+      return postfix;
+    }
+
+    static std::string funcPostfix(const DiffRequest& request) {
+      std::string postfix = "_pullback";
+      if (request.Mode == DiffMode::jacobian)
+        postfix = "_jac";
+      if (request.use_enzyme)
+        postfix = "_grad_enzyme";
+      return postfix + diffParamsPostfix(request);
     }
 
     /// Removes the local const qualifiers from a QualType and returns a new
@@ -361,8 +385,6 @@ namespace clad {
   /// y" will give 'f_grad_0_1' and "x, z" will give 'f_grad_0_2'.
   DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
                                const DiffRequest& request);
-  DerivativeAndOverload DerivePullback(const clang::FunctionDecl* FD,
-                                       const DiffRequest& request);
   StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
   StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
   StmtDiff VisitCallExpr(const clang::CallExpr* CE);
@@ -444,7 +466,7 @@ namespace clad {
   /// Builds an overload for the gradient function that has derived params for
   /// all the arguments of the requested function and it calls the original
   /// gradient function internally
-  clang::FunctionDecl* CreateGradientOverload();
+  clang::FunctionDecl* CreateGradientOverload(unsigned numExtraParams);
```

**clang-tidy** (on `const auto* end = std::next(begin, numParams);`): warning: narrowing conversion from 'std::size_t' (aka 'unsigned long') to signed type 'typename iterator_traits<ParmVarDecl *const *>::difference_type' (aka 'long') is implementation-defined [cppcoreguidelines-narrowing-conversions]
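To illustrate the naming scheme these helpers implement, a hedged sketch (standard clad user API; the generated names stated in the comments follow this PR's `funcPostfix` and are otherwise an assumption): differentiating w.r.t. all parameters appends no indices, while a subset appends one index per selected parameter.

```cpp
#include "clad/Differentiator/Differentiator.h"
#include <cstdio>

double f(double x, double y, double z) { return x * y * z; }

int main() {
  // All parameters at once: no index suffix, derived name f_pullback
  // (per this PR; it was f_grad before).
  auto g_all = clad::gradient(f);
  double dx = 0, dy = 0, dz = 0;
  g_all.execute(2, 3, 4, &dx, &dy, &dz);
  std::printf("%f %f %f\n", dx, dy, dz); // 12 8 6

  // Subset "x, z": parameter indices 0 and 2 give f_pullback_0_2.
  auto g_xz = clad::gradient(f, "x, z");
  double dx2 = 0, dz2 = 0;
  g_xz.execute(2, 3, 4, &dx2, &dz2);
  std::printf("%f %f\n", dx2, dz2); // 12 6
}
```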
9 changes: 3 additions & 6 deletions lib/Differentiator/DerivativeBuilder.cpp
```diff
@@ -419,15 +419,12 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
     result = V.DerivePushforward(FD, request);
   } else if (request.Mode == DiffMode::reverse) {
     ReverseModeVisitor V(*this, request);
-    result = V.Derive(FD, request);
-  } else if (request.Mode == DiffMode::experimental_pullback) {
-    ReverseModeVisitor V(*this, request);
-    if (!m_ErrorEstHandler.empty()) {
+    if (!request.CallUpdateRequired && !m_ErrorEstHandler.empty()) {
       InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request);
       V.AddExternalSource(*m_ErrorEstHandler.back());
     }
-    result = V.DerivePullback(FD, request);
-    if (!m_ErrorEstHandler.empty())
+    result = V.Derive(FD, request);
+    if (!request.CallUpdateRequired && !m_ErrorEstHandler.empty())
       CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel);
   } else if (request.Mode == DiffMode::reverse_mode_forward_pass) {
     ReverseModeForwPassVisitor V(*this, request);
```
5 changes: 3 additions & 2 deletions lib/Differentiator/DiffPlanner.cpp
```diff
@@ -275,9 +275,10 @@ namespace clad {
   }
 
   void DiffRequest::UpdateDiffParamsInfo(Sema& semaRef) {
-    // Diff info for pullbacks is generated automatically,
+    // Some Diff info is generated automatically,
     // its parameters are not provided by the user.
-    if (Mode == DiffMode::experimental_pullback) {
+    // Update parameters only if no DVI info is present.
+    if (Mode == DiffMode::reverse && !DVI.empty()) {
      // Might need to update DVI args, as they may be pointing to the
      // declaration parameters, not the definition parameters.
      if (!Function->getPreviousDecl())
```
7 changes: 4 additions & 3 deletions lib/Differentiator/ErrorEstimator.cpp
```diff
@@ -285,7 +285,7 @@ void ErrorEstimationHandler::ActAfterCreatingDerivedFnParamTypes(
   // If we are performing error estimation, our gradient function
   // will have an extra argument which will hold the final error value
   paramTypes.push_back(
-      m_RMV->m_Context.getLValueReferenceType(m_RMV->m_Context.DoubleTy));
+      m_RMV->m_Context.getPointerType(m_RMV->m_Context.DoubleTy));
 }
 
 void ErrorEstimationHandler::ActAfterCreatingDerivedFnParams(
@@ -307,7 +307,8 @@ void ErrorEstimationHandler::ActAfterCreatingDerivedFnParams(
 
 void ErrorEstimationHandler::ActBeforeCreatingDerivedFnBodyScope() {
   // Reference to the final error statement
-  SetFinalErrorExpr(m_RMV->BuildDeclRef(m_Params->back()));
+  Expr* finalErrorPointer = m_RMV->BuildDeclRef(m_Params->back());
+  SetFinalErrorExpr(m_RMV->BuildOp(UO_Deref, finalErrorPointer));
 }
 
 void ErrorEstimationHandler::ActOnEndOfDerivedFnBody() {
@@ -473,7 +474,7 @@ void ErrorEstimationHandler::ActBeforeDifferentiatingCallExpr(
       m_RMV->getZeroInit(m_RMV->m_Context.DoubleTy));
   ArgDecls.push_back(m_RMV->BuildDeclStmt(errorRef));
   auto finErr = m_RMV->BuildDeclRef(errorRef);
-  pullbackArgs.push_back(finErr);
+  pullbackArgs.push_back(m_RMV->BuildOp(UO_AddrOf, finErr));
   if (hasAssignee) {
     if (m_NestedFuncError)
       m_NestedFuncError = m_RMV->BuildOp(BO_Add, m_NestedFuncError, finErr);
```
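Taken together, these three changes alter the shape of the emitted code: the error parameter is now a `double*`, accesses go through a dereference (`UO_Deref`), and nested pullback calls receive a local accumulator by address (`UO_AddrOf`). A hedged, hand-written illustration of that shape (not actual clad output):

```cpp
#include <cstdio>

// Assumed shape of a generated callee pullback after this change.
void inner_pullback(double x, double _d_y0, double* _d_x, double* _final_error) {
  *_d_x += _d_y0;        // pullback of inner(x) = x
  *_final_error += 0.0;  // inner's own error contribution (none here)
}

void outer_pullback(double x, double _d_y0, double* _d_x, double* _final_error) {
  double _t_err = 0;                        // fresh accumulator per nested call
  inner_pullback(x, _d_y0, _d_x, &_t_err);  // passed by address, as with UO_AddrOf
  *_final_error += _t_err;                  // parameter used through UO_Deref
}

int main() {
  double dx = 0, err = 0;
  outer_pullback(1.0, 1.0, &dx, &err);
  std::printf("dx=%f err=%g\n", dx, err);
}
```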
10 changes: 10 additions & 0 deletions lib/Differentiator/HessianModeVisitor.cpp
```diff
@@ -346,6 +346,16 @@ static FunctionDecl* DeriveUsingForwardModeTwice(
       });
   DeclRefToParams.pop_back();
 
+  if (m_DiffReq.Mode == DiffMode::hessian) {
+    // Pass 1 as the middle parameter to the pullback to effectively get the
+    // gradient.
+    QualType intTy = m_Context.IntTy;
+    llvm::APInt APVal(m_Context.getIntWidth(intTy), 1);
+    Expr* one =
+        IntegerLiteral::Create(m_Context, APVal, m_Context.IntTy, noLoc);
+    DeclRefToParams.push_back(one);
+  }
+
   /// If we are differentiating a member function then create a parameter
   /// that can represent the derivative for the implicit `this` pointer. It
   /// is required because reverse mode derived function expects an explicit
```
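The rationale for the literal 1: a pullback accumulates `seed * (partial derivative)`, so seeding with 1 degenerates to the plain gradient, which is what the Hessian pipeline needs before differentiating again in forward mode. A self-contained sketch with a hand-written pullback (illustrative, not clad output):

```cpp
#include <cstdio>

// f(x, y) = x * y
double f(double x, double y) { return x * y; }

// Pullback in the shape this PR generates: the middle seed parameter
// (the output adjoint) scales every accumulated partial.
void f_pullback(double x, double y, double seed, double* _d_x, double* _d_y) {
  *_d_x += seed * y;
  *_d_y += seed * x;
}

int main() {
  double dx = 0, dy = 0;
  f_pullback(2.0, 3.0, /*seed=*/1.0, &dx, &dy);
  std::printf("df/dx = %f, df/dy = %f\n", dx, dy); // 3 and 2, the gradient
}
```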