Skip to content

Commit

Permalink
Write numerical diff results to c-style arrays in the generated code.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed May 8, 2024
1 parent 3f0e5a1 commit dc8dbb6
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 47 deletions.
77 changes: 73 additions & 4 deletions include/clad/Differentiator/NumericalDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,76 @@ namespace numerical_diff {
}
}

/// A helper function to calculate the numerical derivative of a target
/// function.
///
/// \param[in] \c f The target function to numerically differentiate.
/// \param[out] \c _grad The gradient array reference to which the gradients
/// will be written.
/// \param[in] \c printErrors A flag to decide if we want to print numerical
/// diff errors estimates.
/// \param[in] \c idxSeq The index sequence associated with
/// the input parameter pack.
/// \param[in] \c args The arguments to the function to differentiate.
template <typename F, std::size_t... Ints,
typename RetType = typename clad::function_traits<F>::return_type,
typename... Args>
void central_difference_helper(F f, RetType* _grad, bool printErrors,
clad::IndexSequence<Ints...> idxSeq,
Args&&... args) {

std::size_t argLen = sizeof...(Args);
// loop over all the args, selecting each arg to get the derivative with
// respect to.
for (std::size_t i = 0; i < argLen; i++) {
precision h = 0;
// calculate f[x+h, x-h]
// f(..., x+h,...)
precision xaf = f(updateIndexParamValue(std::forward<Args>(args), Ints, i,
/*multiplier=*/1, h)...);
precision xbf = f(updateIndexParamValue(std::forward<Args>(args), Ints, i,
/*multiplier=*/-1, h)...);
precision xf1 = (xaf - xbf) / (h + h);

// calculate f[x+2h, x-2h]
precision xaf2 =
f(updateIndexParamValue(std::forward<Args>(args), Ints, i,
/*multiplier=*/2, h)...);
precision xbf2 =
f(updateIndexParamValue(std::forward<Args>(args), Ints, i,
/*multiplier=*/-2, h)...);
precision xf2 = (xaf2 - xbf2) / (2 * h + 2 * h);

if (printErrors) {
// calculate f(x+3h) and f(x-3h)
precision xaf3 =
f(updateIndexParamValue(std::forward<Args>(args), Ints, i,
/*multiplier=*/3, h)...);
precision xbf3 =
f(updateIndexParamValue(std::forward<Args>(args), Ints, i,
/*multiplier=*/-3, h)...);
// Error in derivative due to the five-point stencil formula
// E(f'(x)) = f`````(x) * h^4 / 30 + O(h^5) (Taylor Approx) and
// f`````(x) = (f[x+3h, x-3h] - 4f[x+2h, x-2h] + 5f[x+h, x-h])/(2 *
// h^5) Formula courtesy of 'Abramowitz, Milton; Stegun, Irene A.
// (1970), Handbook of Mathematical Functions with Formulas, Graphs,
// and Mathematical Tables, Dover. Ninth printing. Table 25.2.`.
precision error =
((xaf3 - xbf3) - 4 * (xaf2 - xbf2) + 5 * (xaf - xbf)) / (60 * h);
// This is the error in evaluation of all the function values.
precision evalError = std::numeric_limits<precision>::epsilon() *
(std::fabs(xaf2) + std::fabs(xbf2) +
8 * (std::fabs(xaf) + std::fabs(xbf))) /
(12 * h);
// Finally print the error to standard ouput.
printError(std::fabs(error), evalError, i);
}

// five-point stencil formula = (4f[x+h, x-h] - f[x+2h, x-2h])/3
_grad[i] = 4.0 * xf1 / 3.0 - xf2 / 3.0;
}
}

/// A function to calculate the derivative of a function using the central
/// difference formula. Note: we do not propogate errors resulting in the
/// following function, it is likely the errors are large enough to be of
Expand All @@ -338,11 +408,10 @@ namespace numerical_diff {
/// \param[in] \c printErrors A flag to decide if we want to print numerical
/// diff errors estimates.
/// \param[in] \c args The arguments to the function to differentiate.
template <typename F, std::size_t... Ints,
typename RetType = typename clad::function_traits<F>::return_type,
template <typename F, std::size_t... Ints, typename GradType,
typename... Args>
void central_difference(F f, clad::tape_impl<clad::array_ref<RetType>>& _grad,
bool printErrors, Args&&... args) {
void central_difference(F f, GradType& _grad, bool printErrors,
Args&&... args) {
return central_difference_helper(f, _grad, printErrors,
clad::MakeIndexSequence<sizeof...(Args)>{},
std::forward<Args>(args)...);
Expand Down
45 changes: 18 additions & 27 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "clang/AST/Expr.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/TemplateBase.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/Basic/TokenKinds.h"
#include "clang/Sema/Lookup.h"
#include "clang/Sema/Overload.h"
Expand Down Expand Up @@ -2007,40 +2008,30 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
int printErrorInf = m_Builder.shouldPrintNumDiffErrs();
llvm::SmallVector<Expr*, 16U> NumDiffArgs = {};
NumDiffArgs.push_back(targetFuncCall);
// build the clad::tape<clad::array_ref>> = {};
QualType RefType = GetCladArrayRefOfType(retType);
QualType TapeType = GetCladTapeOfType(RefType);
auto* VD = BuildVarDecl(
TapeType, "_t", getZeroInit(TapeType), /*DirectInit=*/false,
/*TSI=*/nullptr, VarDecl::InitializationStyle::CInit);
// build the output array declaration.
Expr* size =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, numArgs);
QualType GradType = m_Context.getConstantArrayType(
retType, llvm::APInt(m_Context.getTargetInfo().getIntWidth(), numArgs),
size, ArrayType::ArraySizeModifier::Normal, 0);
Expr* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Expr* init = m_Sema.ActOnInitList(noLoc, {zero}, noLoc).get();
auto* VD = BuildVarDecl(GradType, "_grad", init);

PreCallStmts.push_back(BuildDeclStmt(VD));
Expr* TapeRef = BuildDeclRef(VD);
NumDiffArgs.push_back(TapeRef);
NumDiffArgs.push_back(BuildDeclRef(VD));
NumDiffArgs.push_back(ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, printErrorInf));

// Build the tape push expressions.
VD->setLocation(m_Function->getLocation());
m_Sema.AddInitializerToDecl(VD, getZeroInit(TapeType), false);
CXXScopeSpec CSS;
CSS.Extend(m_Context, GetCladNamespace(), noLoc, noLoc);
LookupResult& Push = GetCladTapePush();
Expr* PushDRE =
m_Sema.BuildDeclarationNameExpr(CSS, Push, /*ADL*/ false).get();
for (unsigned i = 0, e = numArgs; i < e; i++) {
QualType argTy = args[i]->getType();
VarDecl* gradVar = BuildVarDecl(argTy, "_grad", getZeroInit(argTy));
PreCallStmts.push_back(BuildDeclStmt(gradVar));
Expr* PushExpr = BuildDeclRef(gradVar);
if (!isCladArrayType(argTy))
PushExpr = BuildOp(UO_AddrOf, PushExpr);
std::array<Expr*, 2> callArgs = {TapeRef, PushExpr};
Stmt* PushStmt =
m_Sema
.ActOnCallExpr(getCurrentScope(), PushDRE, noLoc, callArgs, noLoc)
.get();
PreCallStmts.push_back(PushStmt);
Expr* gradExpr = BuildOp(BO_Mul, dfdx, BuildDeclRef(gradVar));
Expr* gradRef = BuildDeclRef(VD);
Expr* idx =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, i);
Expr* gradElem = BuildArraySubscript(gradRef, {idx});
Expr* gradExpr = BuildOp(BO_Mul, dfdx, gradElem);
PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr));
NumDiffArgs.push_back(args[i]);
}
Expand Down
12 changes: 4 additions & 8 deletions test/NumericalDiff/GradientMultiArg.C
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,10 @@ double test_1(double x, double y){
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: double _r1 = 0;
// CHECK-NEXT: clad::tape<clad::array_ref<double> > _t0 = {};
// CHECK-NEXT: double _grad0 = 0;
// CHECK-NEXT: clad::push(_t0, &_grad0);
// CHECK-NEXT: double _grad1 = 0;
// CHECK-NEXT: clad::push(_t0, &_grad1);
// CHECK-NEXT: numerical_diff::central_difference(std::hypot, _t0, 0, x, y);
// CHECK-NEXT: _r0 += 1 * _grad0;
// CHECK-NEXT: _r1 += 1 * _grad1;
// CHECK-NEXT: double _grad0[2] = {0};
// CHECK-NEXT: numerical_diff::central_difference(std::hypot, _grad0, 0, x, y);
// CHECK-NEXT: _r0 += 1 * _grad0[0];
// CHECK-NEXT: _r1 += 1 * _grad0[1];
// CHECK-NEXT: *_d_x += _r0;
// CHECK-NEXT: *_d_y += _r1;
// CHECK-NEXT: }
Expand Down
12 changes: 4 additions & 8 deletions test/NumericalDiff/NumDiff.C
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,10 @@ double test_3(double x) {
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 0;
//CHECK-NEXT: double _r1 = 0;
//CHECK-NEXT: clad::tape<clad::array_ref<double> > _t0 = {};
//CHECK-NEXT: double _grad0 = 0;
//CHECK-NEXT: clad::push(_t0, &_grad0);
//CHECK-NEXT: double _grad1 = 0;
//CHECK-NEXT: clad::push(_t0, &_grad1);
//CHECK-NEXT: numerical_diff::central_difference(std::hypot, _t0, 0, x, constant);
//CHECK-NEXT: _r0 += 1 * _grad0;
//CHECK-NEXT: _r1 += 1 * _grad1;
//CHECK-NEXT: double _grad0[2] = {0};
//CHECK-NEXT: numerical_diff::central_difference(std::hypot, _grad0, 0, x, constant);
//CHECK-NEXT: _r0 += 1 * _grad0[0];
//CHECK-NEXT: _r1 += 1 * _grad0[1];
//CHECK-NEXT: *_d_x += _r0;
//CHECK-NEXT: _d_constant += _r1;
//CHECK-NEXT: }
Expand Down

0 comments on commit dc8dbb6

Please sign in to comment.