diff --git a/include/clad/Differentiator/NumericalDiff.h b/include/clad/Differentiator/NumericalDiff.h index 792812be4..a33b9d098 100644 --- a/include/clad/Differentiator/NumericalDiff.h +++ b/include/clad/Differentiator/NumericalDiff.h @@ -326,6 +326,76 @@ namespace numerical_diff { } } + /// A helper function to calculate the numerical derivative of a target + /// function. + /// + /// \param[in] \c f The target function to numerically differentiate. + /// \param[out] \c _grad The gradient array reference to which the gradients + /// will be written. + /// \param[in] \c printErrors A flag to decide if we want to print numerical + /// diff errors estimates. + /// \param[in] \c idxSeq The index sequence associated with + /// the input parameter pack. + /// \param[in] \c args The arguments to the function to differentiate. + template ::return_type, + typename... Args> + void central_difference_helper(F f, RetType* _grad, bool printErrors, + clad::IndexSequence idxSeq, + Args&&... args) { + + std::size_t argLen = sizeof...(Args); + // loop over all the args, selecting each arg to get the derivative with + // respect to. + for (std::size_t i = 0; i < argLen; i++) { + precision h = 0; + // calculate f[x+h, x-h] + // f(..., x+h,...) + precision xaf = f(updateIndexParamValue(std::forward(args), Ints, i, + /*multiplier=*/1, h)...); + precision xbf = f(updateIndexParamValue(std::forward(args), Ints, i, + /*multiplier=*/-1, h)...); + precision xf1 = (xaf - xbf) / (h + h); + + // calculate f[x+2h, x-2h] + precision xaf2 = + f(updateIndexParamValue(std::forward(args), Ints, i, + /*multiplier=*/2, h)...); + precision xbf2 = + f(updateIndexParamValue(std::forward(args), Ints, i, + /*multiplier=*/-2, h)...); + precision xf2 = (xaf2 - xbf2) / (2 * h + 2 * h); + + if (printErrors) { + // calculate f(x+3h) and f(x-3h) + precision xaf3 = + f(updateIndexParamValue(std::forward(args), Ints, i, + /*multiplier=*/3, h)...); + precision xbf3 = + f(updateIndexParamValue(std::forward(args), Ints, i, + /*multiplier=*/-3, h)...); + // Error in derivative due to the five-point stencil formula + // E(f'(x)) = f`````(x) * h^4 / 30 + O(h^5) (Taylor Approx) and + // f`````(x) = (f[x+3h, x-3h] - 4f[x+2h, x-2h] + 5f[x+h, x-h])/(2 * + // h^5) Formula courtesy of 'Abramowitz, Milton; Stegun, Irene A. + // (1970), Handbook of Mathematical Functions with Formulas, Graphs, + // and Mathematical Tables, Dover. Ninth printing. Table 25.2.`. + precision error = + ((xaf3 - xbf3) - 4 * (xaf2 - xbf2) + 5 * (xaf - xbf)) / (60 * h); + // This is the error in evaluation of all the function values. + precision evalError = std::numeric_limits::epsilon() * + (std::fabs(xaf2) + std::fabs(xbf2) + + 8 * (std::fabs(xaf) + std::fabs(xbf))) / + (12 * h); + // Finally print the error to standard ouput. + printError(std::fabs(error), evalError, i); + } + + // five-point stencil formula = (4f[x+h, x-h] - f[x+2h, x-2h])/3 + _grad[i] = 4.0 * xf1 / 3.0 - xf2 / 3.0; + } + } + /// A function to calculate the derivative of a function using the central /// difference formula. Note: we do not propogate errors resulting in the /// following function, it is likely the errors are large enough to be of @@ -338,11 +408,10 @@ namespace numerical_diff { /// \param[in] \c printErrors A flag to decide if we want to print numerical /// diff errors estimates. /// \param[in] \c args The arguments to the function to differentiate. - template ::return_type, + template - void central_difference(F f, clad::tape_impl>& _grad, - bool printErrors, Args&&... args) { + void central_difference(F f, GradType& _grad, bool printErrors, + Args&&... args) { return central_difference_helper(f, _grad, printErrors, clad::MakeIndexSequence{}, std::forward(args)...); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 32febed08..c6628ee6f 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -20,6 +20,7 @@ #include "clang/AST/Expr.h" #include "clang/AST/Stmt.h" #include "clang/AST/TemplateBase.h" +#include "clang/Basic/TargetInfo.h" #include "clang/Basic/TokenKinds.h" #include "clang/Sema/Lookup.h" #include "clang/Sema/Overload.h" @@ -2007,40 +2008,30 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, int printErrorInf = m_Builder.shouldPrintNumDiffErrs(); llvm::SmallVector NumDiffArgs = {}; NumDiffArgs.push_back(targetFuncCall); - // build the clad::tape> = {}; - QualType RefType = GetCladArrayRefOfType(retType); - QualType TapeType = GetCladTapeOfType(RefType); - auto* VD = BuildVarDecl( - TapeType, "_t", getZeroInit(TapeType), /*DirectInit=*/false, - /*TSI=*/nullptr, VarDecl::InitializationStyle::CInit); + // build the output array declaration. + Expr* size = + ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, numArgs); + QualType GradType = m_Context.getConstantArrayType( + retType, llvm::APInt(m_Context.getTargetInfo().getIntWidth(), numArgs), + size, ArrayType::ArraySizeModifier::Normal, 0); + Expr* zero = + ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); + Expr* init = m_Sema.ActOnInitList(noLoc, {zero}, noLoc).get(); + auto* VD = BuildVarDecl(GradType, "_grad", init); + PreCallStmts.push_back(BuildDeclStmt(VD)); - Expr* TapeRef = BuildDeclRef(VD); - NumDiffArgs.push_back(TapeRef); + NumDiffArgs.push_back(BuildDeclRef(VD)); NumDiffArgs.push_back(ConstantFolder::synthesizeLiteral( m_Context.IntTy, m_Context, printErrorInf)); // Build the tape push expressions. VD->setLocation(m_Function->getLocation()); - m_Sema.AddInitializerToDecl(VD, getZeroInit(TapeType), false); - CXXScopeSpec CSS; - CSS.Extend(m_Context, GetCladNamespace(), noLoc, noLoc); - LookupResult& Push = GetCladTapePush(); - Expr* PushDRE = - m_Sema.BuildDeclarationNameExpr(CSS, Push, /*ADL*/ false).get(); for (unsigned i = 0, e = numArgs; i < e; i++) { - QualType argTy = args[i]->getType(); - VarDecl* gradVar = BuildVarDecl(argTy, "_grad", getZeroInit(argTy)); - PreCallStmts.push_back(BuildDeclStmt(gradVar)); - Expr* PushExpr = BuildDeclRef(gradVar); - if (!isCladArrayType(argTy)) - PushExpr = BuildOp(UO_AddrOf, PushExpr); - std::array callArgs = {TapeRef, PushExpr}; - Stmt* PushStmt = - m_Sema - .ActOnCallExpr(getCurrentScope(), PushDRE, noLoc, callArgs, noLoc) - .get(); - PreCallStmts.push_back(PushStmt); - Expr* gradExpr = BuildOp(BO_Mul, dfdx, BuildDeclRef(gradVar)); + Expr* gradRef = BuildDeclRef(VD); + Expr* idx = + ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, i); + Expr* gradElem = BuildArraySubscript(gradRef, {idx}); + Expr* gradExpr = BuildOp(BO_Mul, dfdx, gradElem); PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr)); NumDiffArgs.push_back(args[i]); } diff --git a/test/NumericalDiff/GradientMultiArg.C b/test/NumericalDiff/GradientMultiArg.C index 74fd78a9d..14123f9a2 100644 --- a/test/NumericalDiff/GradientMultiArg.C +++ b/test/NumericalDiff/GradientMultiArg.C @@ -21,14 +21,10 @@ double test_1(double x, double y){ // CHECK-NEXT: { // CHECK-NEXT: double _r0 = 0; // CHECK-NEXT: double _r1 = 0; -// CHECK-NEXT: clad::tape > _t0 = {}; -// CHECK-NEXT: double _grad0 = 0; -// CHECK-NEXT: clad::push(_t0, &_grad0); -// CHECK-NEXT: double _grad1 = 0; -// CHECK-NEXT: clad::push(_t0, &_grad1); -// CHECK-NEXT: numerical_diff::central_difference(std::hypot, _t0, 0, x, y); -// CHECK-NEXT: _r0 += 1 * _grad0; -// CHECK-NEXT: _r1 += 1 * _grad1; +// CHECK-NEXT: double _grad0[2] = {0}; +// CHECK-NEXT: numerical_diff::central_difference(std::hypot, _grad0, 0, x, y); +// CHECK-NEXT: _r0 += 1 * _grad0[0]; +// CHECK-NEXT: _r1 += 1 * _grad0[1]; // CHECK-NEXT: *_d_x += _r0; // CHECK-NEXT: *_d_y += _r1; // CHECK-NEXT: } diff --git a/test/NumericalDiff/NumDiff.C b/test/NumericalDiff/NumDiff.C index c65ff8a39..9649bce91 100644 --- a/test/NumericalDiff/NumDiff.C +++ b/test/NumericalDiff/NumDiff.C @@ -56,14 +56,10 @@ double test_3(double x) { //CHECK-NEXT: { //CHECK-NEXT: double _r0 = 0; //CHECK-NEXT: double _r1 = 0; -//CHECK-NEXT: clad::tape > _t0 = {}; -//CHECK-NEXT: double _grad0 = 0; -//CHECK-NEXT: clad::push(_t0, &_grad0); -//CHECK-NEXT: double _grad1 = 0; -//CHECK-NEXT: clad::push(_t0, &_grad1); -//CHECK-NEXT: numerical_diff::central_difference(std::hypot, _t0, 0, x, constant); -//CHECK-NEXT: _r0 += 1 * _grad0; -//CHECK-NEXT: _r1 += 1 * _grad1; +//CHECK-NEXT: double _grad0[2] = {0}; +//CHECK-NEXT: numerical_diff::central_difference(std::hypot, _grad0, 0, x, constant); +//CHECK-NEXT: _r0 += 1 * _grad0[0]; +//CHECK-NEXT: _r1 += 1 * _grad0[1]; //CHECK-NEXT: *_d_x += _r0; //CHECK-NEXT: _d_constant += _r1; //CHECK-NEXT: }