diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 5e9d54ac2..8a70fc6f3 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -13,6 +13,7 @@ #include "clang/Sema/Sema.h" #include "clad/Differentiator/DerivedFnCollector.h" #include "clad/Differentiator/DiffPlanner.h" +#include "clad/Differentiator/DynamicGraph.h" #include #include @@ -72,6 +73,7 @@ namespace clad { class DerivativeBuilder { private: friend class VisitorBase; + friend class DiffRequest; friend class BaseForwardModeVisitor; friend class PushForwardModeVisitor; friend class VectorForwardModeVisitor; diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index d2b74592b..b025b5c8f 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -3,12 +3,13 @@ #include "clang/AST/RecursiveASTVisitor.h" #include "llvm/ADT/SmallSet.h" +#include +#include +#include "clad/Differentiator/DerivativeBuilder.h" #include "clad/Differentiator/DiffMode.h" #include "clad/Differentiator/DynamicGraph.h" #include "clad/Differentiator/ParseDiffArgsTypes.h" -#include -#include namespace clang { class CallExpr; class CompilerInstance; @@ -21,6 +22,7 @@ class Type; } // namespace clang namespace clad { +class DerivativeBuilder; /// A struct containing information about request to differentiate a function. struct DiffRequest { @@ -34,11 +36,13 @@ struct DiffRequest { } m_TbrRunInfo; mutable struct ActivityRunInfo { - std::set ToBeRecorded; + std::set VariedDecls; bool HasAnalysisRun = false; } m_ActivityRunInfo; public: + const DerivativeBuilder* Builder = nullptr; + // static std::set AllVariedDecls; /// Function to be differentiated. const clang::FunctionDecl* Function = nullptr; /// Name of the base function to be differentiated. Can be different from @@ -127,7 +131,8 @@ struct DiffRequest { Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis && EnableVariedAnalysis == other.EnableVariedAnalysis && DVI == other.DVI && use_enzyme == other.use_enzyme && - DeclarationOnly == other.DeclarationOnly; + DeclarationOnly == other.DeclarationOnly && + getVariedDecls() == other.getVariedDecls(); } const clang::FunctionDecl* operator->() const { return Function; } @@ -144,6 +149,16 @@ struct DiffRequest { bool shouldBeRecorded(clang::Expr* E) const; bool shouldHaveAdjoint(const clang::VarDecl* VD) const; + + void setVariedDecls(std::set init) { + for (auto* vd : init) + this->m_ActivityRunInfo.VariedDecls.insert(vd); + } + std::set getVariedDecls() const { + return this->m_ActivityRunInfo.VariedDecls; + } + DiffRequest() {} + DiffRequest(DerivativeBuilder& builder) : Builder(&builder) {} }; using DiffInterval = std::vector; diff --git a/lib/Differentiator/ActivityAnalyzer.cpp b/lib/Differentiator/ActivityAnalyzer.cpp index eb4f2d2a7..38d65f037 100644 --- a/lib/Differentiator/ActivityAnalyzer.cpp +++ b/lib/Differentiator/ActivityAnalyzer.cpp @@ -121,24 +121,56 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { FunctionDecl* FD = CE->getDirectCallee(); bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams()); if (noHiddenParam) { + bool restoreMarking = m_Marking; + bool restoreVaried = m_Varied; MutableArrayRef FDparam = FD->parameters(); for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { clang::Expr* par = CE->getArg(i); + + QualType parType = FDparam[i]->getType(); + QualType innermostType = parType; + while (innermostType->isPointerType()) + innermostType = innermostType->getPointeeType(); + + m_Varied = false; + m_Marking = false; TraverseStmt(par); - m_VariedDecls.insert(FDparam[i]); + if (m_Varied) + m_VariedDecls.insert(FDparam[i]); + else if ((parType->isReferenceType() || + (utils::isArrayOrPointerType(parType) && + !innermostType.isConstQualified()))) { + m_Varied = true; + m_Marking = true; + TraverseStmt(par); + m_VariedDecls.insert(FDparam[i]); + } } + m_Varied = restoreVaried; + m_Marking = restoreMarking; } + return true; } bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { for (Decl* D : DS->decls()) { + + QualType VDTy = cast(D)->getType(); + QualType innermost = VDTy; + while (innermost->isPointerType()) + innermost = innermost->getPointeeType(); + if (VDTy->isArrayType() || + (VDTy->isPointerType() && !innermost.isConstQualified())) { + copyVarToCurBlock(cast(D)); + m_Varied = true; + } + if (Expr* init = cast(D)->getInit()) { m_Varied = false; TraverseStmt(init); m_Marking = true; - QualType VDTy = cast(D)->getType(); - if (m_Varied || utils::isArrayOrPointerType(VDTy)) + if (m_Varied) copyVarToCurBlock(cast(D)); m_Marking = false; } @@ -147,12 +179,7 @@ bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { } bool VariedAnalyzer::VisitUnaryOperator(UnaryOperator* UnOp) { - const auto opCode = UnOp->getOpcode(); Expr* E = UnOp->getSubExpr(); - if (opCode == UO_AddrOf || opCode == UO_Deref) { - m_Varied = true; - m_Marking = true; - } TraverseStmt(E); m_Marking = false; return true; @@ -163,7 +190,7 @@ bool VariedAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) { if (!VD) return true; - if (isVaried(VD)) + if (isVaried(VD) || VD->getType()->isArrayType()) m_Varied = true; if (m_Varied && m_Marking) diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index 1f1fe761f..ec5aaf6ed 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -22,158 +22,146 @@ using namespace clang; namespace clad { - static SourceLocation noLoc; - - /// Returns `DeclRefExpr` node corresponding to the function, method or - /// functor argument which is to be differentiated. - /// - /// \param[in] call A clad differentiation function call expression - /// \param SemaRef Reference to Sema - DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { - struct Finder : - RecursiveASTVisitor { - Sema& m_SemaRef; - SourceLocation m_BeginLoc; - DeclRefExpr* m_FnDRE = nullptr; - Finder(Sema& SemaRef, SourceLocation beginLoc) - : m_SemaRef(SemaRef), m_BeginLoc(beginLoc) {} - - // Required for visiting lambda declarations. - bool shouldVisitImplicitCode() const { return true; } - - bool VisitDeclRefExpr(DeclRefExpr* DRE) { - if (auto VD = dyn_cast(DRE->getDecl())) { - auto varType = VD->getType().getTypePtr(); - // If variable is of class type, set `m_FnDRE` to - // `DeclRefExpr` of overloaded call operator method of - // the class type. - if (varType->isStructureOrClassType()) { - auto RD = varType->getAsCXXRecordDecl(); - TraverseDecl(RD); - } else { - TraverseStmt(VD->getInit()); - } - } else if (isa(DRE->getDecl())) - m_FnDRE = DRE; - return false; +// std::set DiffRequest::AllVariedDecls; +static SourceLocation noloc; + +/// Returns `DeclRefExpr` node corresponding to the function, method or +/// functor argument which is to be differentiated. +/// +/// \param[in] call A clad differentiation function call expression +/// \param SemaRef Reference to Sema +DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { + struct Finder : RecursiveASTVisitor { + Sema& m_SemaRef; + SourceLocation m_BeginLoc; + DeclRefExpr* m_FnDRE = nullptr; + Finder(Sema& SemaRef, SourceLocation beginLoc) + : m_SemaRef(SemaRef), m_BeginLoc(beginLoc) {} + + // Required for visiting lambda declarations. + bool shouldVisitImplicitCode() const { return true; } + + bool VisitDeclRefExpr(DeclRefExpr* DRE) { + if (auto* VD = dyn_cast(DRE->getDecl())) { + const auto* varType = VD->getType().getTypePtr(); + // If variable is of class type, set `m_FnDRE` to + // `DeclRefExpr` of overloaded call operator method of + // the class type. + if (varType->isStructureOrClassType()) { + auto* RD = varType->getAsCXXRecordDecl(); + TraverseDecl(RD); + } else { + TraverseStmt(VD->getInit()); } + } else if (isa(DRE->getDecl())) + m_FnDRE = DRE; + return false; + } - bool VisitCXXRecordDecl(CXXRecordDecl* RD) { - auto callOperatorDeclName = - m_SemaRef.getASTContext().DeclarationNames.getCXXOperatorName( - OverloadedOperatorKind::OO_Call); - LookupResult R(m_SemaRef, - callOperatorDeclName, - noLoc, - Sema::LookupNameKind::LookupMemberName); - // We do not want diagnostics that would fire because of this lookup. - R.suppressDiagnostics(); - m_SemaRef.LookupQualifiedName(R, RD); - - // Emit error diagnostics - if (R.empty()) { - const char diagFmt[] = "'%0' has no defined operator()"; - auto diagId = - m_SemaRef.Diags.getCustomDiagID(DiagnosticsEngine::Level::Error, - diagFmt); - m_SemaRef.Diag(m_BeginLoc, diagId) << RD->getName(); - return false; - } else if (!R.isSingleResult()) { - const char diagFmt[] = - "'%0' has multiple definitions of operator(). " - "Multiple definitions of call operators are not supported."; - auto diagId = - m_SemaRef.Diags.getCustomDiagID(DiagnosticsEngine::Level::Error, - diagFmt); - m_SemaRef.Diag(m_BeginLoc, diagId) << RD->getName(); - - // Emit diagnostics for candidate functions - for (auto oper = R.begin(), operEnd = R.end(); oper != operEnd; - ++oper) { - auto candidateFn = cast(oper.getDecl()); - m_SemaRef.NoteOverloadCandidate(candidateFn, - cast(candidateFn)); - } - return false; - } else if (R.isSingleResult() == 1 && - cast(R.getFoundDecl())->getAccess() != - AccessSpecifier::AS_public) { - const char diagFmt[] = - "'%0' contains %1 call operator. Differentiation of " - "private/protected call operator is not supported."; - - auto diagId = - m_SemaRef.Diags.getCustomDiagID(DiagnosticsEngine::Level::Error, - diagFmt); - // Compute access specifier name so that it can be used in - // diagnostic message. - const char* callOperatorAS = - (cast(R.getFoundDecl())->getAccess() == - AccessSpecifier::AS_private - ? "private" - : "protected"); - m_SemaRef.Diag(m_BeginLoc, diagId) - << RD->getName() << callOperatorAS; - auto callOperator = cast(R.getFoundDecl()); - - bool isImplicit = true; - - // compute if the corresponding access specifier of the found - // call operator is implicit or explicit. - for (auto decl : RD->decls()) { - if (decl == callOperator) - break; - if (isa(decl)) { - isImplicit = false; - break; - } - } - - // Emit diagnostics for the found call operator - m_SemaRef.Diag(callOperator->getBeginLoc(), - diag::note_access_natural) - << (unsigned)(callOperator->getAccess() == - AccessSpecifier::AS_protected) - << isImplicit; - - return false; + bool VisitCXXRecordDecl(CXXRecordDecl* RD) { + auto callOperatorDeclName = + m_SemaRef.getASTContext().DeclarationNames.getCXXOperatorName( + OverloadedOperatorKind::OO_Call); + LookupResult R(m_SemaRef, callOperatorDeclName, noloc, + Sema::LookupNameKind::LookupMemberName); + // We do not want diagnostics that would fire because of this lookup. + R.suppressDiagnostics(); + m_SemaRef.LookupQualifiedName(R, RD); + + // Emit error diagnostics + if (R.empty()) { + const char diagFmt[] = "'%0' has no defined operator()"; + auto diagId = m_SemaRef.Diags.getCustomDiagID( + DiagnosticsEngine::Level::Error, diagFmt); + m_SemaRef.Diag(m_BeginLoc, diagId) << RD->getName(); + } else if (!R.isSingleResult()) { + const char diagFmt[] = + "'%0' has multiple definitions of operator(). " + "Multiple definitions of call operators are not supported."; + auto diagId = m_SemaRef.Diags.getCustomDiagID( + DiagnosticsEngine::Level::Error, diagFmt); + m_SemaRef.Diag(m_BeginLoc, diagId) << RD->getName(); + + // Emit diagnostics for candidate functions + for (auto oper = R.begin(), operEnd = R.end(); oper != operEnd; + ++oper) { + auto* candidateFn = cast(oper.getDecl()); + m_SemaRef.NoteOverloadCandidate(candidateFn, + cast(candidateFn)); + } + } else if (R.isSingleResult() == 1 && + cast(R.getFoundDecl())->getAccess() != + AccessSpecifier::AS_public) { + const char diagFmt[] = + "'%0' contains %1 call operator. Differentiation of " + "private/protected call operator is not supported."; + + auto diagId = m_SemaRef.Diags.getCustomDiagID( + DiagnosticsEngine::Level::Error, diagFmt); + // Compute access specifier name so that it can be used in + // diagnostic message. + const char* callOperatorAS = + (cast(R.getFoundDecl())->getAccess() == + AccessSpecifier::AS_private + ? "private" + : "protected"); + m_SemaRef.Diag(m_BeginLoc, diagId) << RD->getName() << callOperatorAS; + auto* callOperator = cast(R.getFoundDecl()); + + bool isImplicit = true; + + // compute if the corresponding access specifier of the found + // call operator is implicit or explicit. + for (auto* decl : RD->decls()) { + if (decl == callOperator) + break; + if (isa(decl)) { + isImplicit = false; + break; } - - assert(R.isSingleResult() && - "Multiple definitions of call operators are not supported"); - assert(R.isSingleResult() == 1 && - cast(R.getFoundDecl())->getAccess() == - AccessSpecifier::AS_public && - "Differentiation of private/protected call operators are " - "not supported"); - auto callOperator = cast(R.getFoundDecl()); - // Creating `DeclRefExpr` of the found overloaded call operator - // method, to maintain consistency with member function - // differentiation. - CXXScopeSpec CSS; - utils::BuildNNS(m_SemaRef, callOperator->getDeclContext(), CSS, - /*addGlobalNS=*/true); - - // `ExprValueKind::VK_RValue` is used because functions are - // decomposed to function pointers and thus a temporary is - // created for the function pointer. - auto newFnDRE = clad_compat::GetResult( - m_SemaRef.BuildDeclRefExpr(callOperator, - callOperator->getType(), - CLAD_COMPAT_ExprValueKind_R_or_PR_Value, - noLoc, - &CSS)); - m_FnDRE = cast(newFnDRE); - return false; } - } finder(SemaRef, call->getArg(0)->getBeginLoc()); - finder.TraverseStmt(call->getArg(0)); - assert(cast(call->getDirectCallee()->getDeclContext()) - ->getName() == "clad" && - "Should be called for clad:: special functions!"); - return finder.m_FnDRE; - } + // Emit diagnostics for the found call operator + m_SemaRef.Diag(callOperator->getBeginLoc(), diag::note_access_natural) + << (unsigned)(callOperator->getAccess() == + AccessSpecifier::AS_protected) + << isImplicit; + + } else { + assert(R.isSingleResult() && + "Multiple definitions of call operators are not supported"); + assert(R.isSingleResult() == 1 && + cast(R.getFoundDecl())->getAccess() == + AccessSpecifier::AS_public && + "Differentiation of private/protected call operators are " + "not supported"); + auto* callOperator = cast(R.getFoundDecl()); + // Creating `DeclRefExpr` of the found overloaded call operator + // method, to maintain consistency with member function + // differentiation. + CXXScopeSpec CSS; + utils::BuildNNS(m_SemaRef, callOperator->getDeclContext(), CSS, + /*addGlobalNS=*/true); + + // `ExprValueKind::VK_RValue` is used because functions are + // decomposed to function pointers and thus a temporary is + // created for the function pointer. + auto* newFnDRE = + clad_compat::GetResult(m_SemaRef.BuildDeclRefExpr( + callOperator, callOperator->getType(), + CLAD_COMPAT_ExprValueKind_R_or_PR_Value, noloc, &CSS)); + m_FnDRE = cast(newFnDRE); + } + return false; + } + } finder(SemaRef, call->getArg(0)->getBeginLoc()); + finder.TraverseStmt(call->getArg(0)); + + assert(cast(call->getDirectCallee()->getDeclContext()) + ->getName() == "clad" && + "Should be called for clad:: special functions!"); + return finder.m_FnDRE; +} void DiffRequest::updateCall(FunctionDecl* FD, FunctionDecl* OverloadedFD, Sema& SemaRef) { @@ -210,7 +198,7 @@ namespace clad { auto kernelArgIdx = numArgs - 1; auto* cudaKernelFlag = SemaRef - .ActOnCXXBoolLiteral(noLoc, + .ActOnCXXBoolLiteral(noloc, replacementFD->hasAttr() ? tok::kw_true : tok::kw_false) @@ -221,7 +209,7 @@ namespace clad { // Create ref to generated FD. DeclRefExpr* DRE = - DeclRefExpr::Create(C, oldDRE->getQualifierLoc(), noLoc, replacementFD, + DeclRefExpr::Create(C, oldDRE->getQualifierLoc(), noloc, replacementFD, /*RefersToEnclosingVariableOrCapture=*/false, replacementFD->getNameInfo(), replacementFD->getType(), oldDRE->getValueKind()); @@ -237,7 +225,7 @@ namespace clad { // Add the "&" operator auto* newUnOp = SemaRef - .BuildUnaryOp(nullptr, noLoc, UnaryOperatorKind::UO_AddrOf, DRE) + .BuildUnaryOp(nullptr, noloc, UnaryOperatorKind::UO_AddrOf, DRE) .get(); call->setArg(derivedFnArgIdx, newUnOp); } @@ -630,28 +618,25 @@ namespace clad { return true; if (!m_ActivityRunInfo.HasAnalysisRun) { - ArrayRef FDparam = Function->parameters(); - std::vector derivedParam; - - for (auto* parameter : FDparam) { - QualType parType = parameter->getType(); - while (parType->isPointerType()) - parType = parType->getPointeeType(); - if (!parType.isConstQualified()) - derivedParam.push_back(parameter); - } - - std::copy(derivedParam.begin(), derivedParam.end(), - std::inserter(m_ActivityRunInfo.ToBeRecorded, - m_ActivityRunInfo.ToBeRecorded.end())); - + if (Builder) + for (auto diffreq : this->Builder->m_DiffRequestGraph.getNodes()) + for (auto vd : diffreq.getVariedDecls()) + m_ActivityRunInfo.VariedDecls.insert(vd); + + if (Args) + for (const auto& dParam : DVI) + m_ActivityRunInfo.VariedDecls.insert(cast(dParam.param)); VariedAnalyzer analyzer(Function->getASTContext(), - m_ActivityRunInfo.ToBeRecorded); + m_ActivityRunInfo.VariedDecls); analyzer.Analyze(Function); m_ActivityRunInfo.HasAnalysisRun = true; + if (Builder) + this->Builder->m_DiffRequestGraph.addNode(*this); } - auto found = m_ActivityRunInfo.ToBeRecorded.find(VD); - return found != m_ActivityRunInfo.ToBeRecorded.end(); + auto found = m_ActivityRunInfo.VariedDecls.find(VD); + return found != m_ActivityRunInfo.VariedDecls.end(); + + return false; } bool DiffCollector::VisitCallExpr(CallExpr* E) { diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index a7d4dc6cb..7b8f9966c 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -216,6 +216,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_ExternalSource->ActAfterParsingDiffArgs(m_DiffReq, args); auto derivativeBaseName = m_DiffReq.BaseFunctionName; + // llvm::errs() << "\nBaseFunctionName: " << derivativeBaseName << "\n"; std::string gradientName = derivativeBaseName + funcPostfix(); // To be consistent with older tests, nothing is appended to 'f_grad' if // we differentiate w.r.t. all the parameters at once. @@ -1672,8 +1673,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // We do not need to create result arg for arguments passed by reference // because the derivatives of arguments passed by reference are directly // modified by the derived callee function. - if (utils::IsReferenceOrPointerArg(arg) || - !m_DiffReq.shouldHaveAdjoint(PVD)) { + if (utils::IsReferenceOrPointerArg(arg)) { argDiff = Visit(arg); CallArgDx.push_back(argDiff.getExpr_dx()); } else { @@ -1947,7 +1947,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Overloaded derivative was not found, request the CladPlugin to // derive the called function. - DiffRequest pullbackRequest{}; + DiffRequest pullbackRequest(m_Builder); pullbackRequest.Function = FD; // Mark the indexes of the global args. Necessary if the argument of the @@ -1961,6 +1961,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullbackRequest.VerboseDiags = false; pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis; + pullbackRequest.setVariedDecls(m_DiffReq.getVariedDecls()); bool isaMethod = isa(FD); for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) if (MD && isLambdaCallOperator(MD)) { diff --git a/test/Analyses/ActivityReverse.cpp b/test/Analyses/ActivityReverse.cpp index 9979c85ce..490bafede 100644 --- a/test/Analyses/ActivityReverse.cpp +++ b/test/Analyses/ActivityReverse.cpp @@ -6,6 +6,26 @@ #include "clad/Differentiator/Differentiator.h" +inline double interpolate1d(double low, double high, double val, unsigned int numBins, double const* vals) +{ + double binWidth = (high - low) / numBins; + int idx = val >= high ? numBins - 1 : std::abs((val - low) / binWidth); + + // interpolation + double central = low + (idx + 0.5) * binWidth; + if (val > low + 0.5 * binWidth && val < high - 0.5 * binWidth) { + double slope; + if (val < central) { + slope = vals[idx] - vals[idx - 1]; + } else { + slope = vals[idx + 1] - vals[idx]; + } + return vals[idx] + slope * (val - central) / binWidth; + } + + return vals[idx]; +} + double f1(double x){ double a = x*x; double b = 1; @@ -264,7 +284,7 @@ double f8(double x){ // CHECK-NEXT: } // CHECK-NEXT: } -double fn9(double x, double const *obs) +double f9(double x, double const *obs) { double res = 0.0; for (int loopIdx0 = 0; loopIdx0 < 2; loopIdx0++) { @@ -273,7 +293,7 @@ double fn9(double x, double const *obs) return res; } -// CHECK: void fn9_grad(double x, const double *obs, double *_d_x, double *_d_obs) { +// CHECK: void f9_grad_0(double x, const double *obs, double *_d_x) { // CHECK-NEXT: int loopIdx0 = 0; // CHECK-NEXT: clad::tape _t1 = {}; // CHECK-NEXT: double _d_res = 0.; @@ -304,6 +324,81 @@ double fn9(double x, double const *obs) // CHECK-NEXT: } +void f10_1(double x, double* t){ + t[0] = x; +} + +double f10(double x){ + double t[3]; + f10_1(x, t); + return t[0]; +} +// CHECK: void f10_1_pullback(double x, double *t, double *_d_x, double *_d_t); +// CHECK-NEXT: void f10_grad(double x, double *_d_x) { +// CHECK-NEXT: double _d_t[3] = {0}; +// CHECK-NEXT: double t[3]; +// CHECK-NEXT: f10_1(x, t); +// CHECK-NEXT: _d_t[0] += 1; +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 0.; +// CHECK-NEXT: f10_1_pullback(x, t, &_r0, _d_t); +// CHECK-NEXT: *_d_x += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double f11_1(double v, double& u){ + u = v; + return u; +} + +double f11(double x){ + double y; + double c = f11_1(x, y); + return y; +} + +// CHECK: void f11_1_pullback(double v, double &u, double _d_y, double *_d_v, double *_d_u); +// CHECK-NEXT: void f11_grad(double x, double *_d_x) { +// CHECK-NEXT: double _d_y = 0.; +// CHECK-NEXT: double y; +// CHECK-NEXT: double _t0 = y; +// CHECK-NEXT: double _d_c = 0.; +// CHECK-NEXT: double c = f11_1(x, y); +// CHECK-NEXT: _d_y += 1; +// CHECK-NEXT: { +// CHECK-NEXT: y = _t0; +// CHECK-NEXT: double _r0 = 0.; +// CHECK-NEXT: f11_1_pullback(x, _t0, _d_c, &_r0, &_d_y); +// CHECK-NEXT: *_d_x += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double f12_1(double y, const double* obs){ + double nopull = interpolate1d(1.0, 3.0, 5.0, 5, obs); + return nopull; +} +double f12(double x, const double* obs){ + double pull = f12_1(x, obs); + return pull*x; +} + +// CHECK: void f12_1_pullback(double y, const double *obs, double _d_y0, double *_d_y); +// CHECK-NEXT: void f12_grad_0(double x, const double *obs, double *_d_x) { +// CHECK-NEXT: double _d_pull = 0.; +// CHECK-NEXT: double pull = f12_1(x, obs); +// CHECK-NEXT: { +// CHECK-NEXT: _d_pull += 1 * x; +// CHECK-NEXT: *_d_x += pull * 1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 0.; +// CHECK-NEXT: f12_1_pullback(x, obs, _d_pull, &_r0); +// CHECK-NEXT: *_d_x += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + + + #define TEST(F, x) { \ result[0] = 0; \ auto F##grad = clad::gradient(F);\ @@ -315,7 +410,6 @@ int main(){ double arr[] = {1,2,3,4,5}; double darr[] = {0,0,0,0,0}; double result[3] = {}; - double dx; TEST(f1, 3);// CHECK-EXEC: {6.00} TEST(f2, 3);// CHECK-EXEC: {6.00} TEST(f3, 3);// CHECK-EXEC: {0.00} @@ -324,24 +418,53 @@ int main(){ TEST(f6, 3);// CHECK-EXEC: {0.00} TEST(f7, 3);// CHECK-EXEC: {1.00} TEST(f8, 3);// CHECK-EXEC: {1.00} - auto grad = clad::gradient(fn9); - grad.execute(3, arr, &dx, darr); - printf("%.2f\n", dx);// CHECK-EXEC: 2.00 + double dx9 = 0; + auto grad9 = clad::gradient(f9, "x"); + grad9.execute(3, arr, &dx9, darr); + printf("%.2f\n", dx9);// CHECK-EXEC: 2.00 + TEST(f10, 3);// CHECK-EXEC: {1.00} + TEST(f11, 3);// CHECK-EXEC: {1.00} + double dx12 = 0; + auto grad = clad::gradient(f12, "x"); + grad.execute(3, arr, &dx12, darr); + printf("%.2f\n", dx12);// CHECK-EXEC: 5.00 } // CHECK: void f4_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) { -// CHECK-NEXT: double _d_k = 0.; // CHECK-NEXT: double k = 2 * u; // CHECK-NEXT: double _d_n = 0.; // CHECK-NEXT: double n = 2 * v; -// CHECK-NEXT: { -// CHECK-NEXT: _d_n += _d_y * k; -// CHECK-NEXT: _d_k += n * _d_y; -// CHECK-NEXT: } +// CHECK-NEXT: _d_n += _d_y * k; // CHECK-NEXT: *_d_v += 2 * _d_n; -// CHECK-NEXT: *_d_u += 2 * _d_k; // CHECK-NEXT: } // CHECK: void f8_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) { // CHECK-NEXT: *_d_v += _d_y; +// CHECK-NEXT: } + +// CHECK: void f10_1_pullback(double x, double *t, double *_d_x, double *_d_t) { +// CHECK-NEXT: double _t0 = t[0]; +// CHECK-NEXT: t[0] = x; +// CHECK-NEXT: { +// CHECK-NEXT: t[0] = _t0; +// CHECK-NEXT: double _r_d0 = _d_t[0]; +// CHECK-NEXT: _d_t[0] = 0.; +// CHECK-NEXT: *_d_x += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: void f11_1_pullback(double v, double &u, double _d_y, double *_d_v, double *_d_u) { +// CHECK-NEXT: double _t0 = u; +// CHECK-NEXT: u = v; +// CHECK-NEXT: *_d_u += _d_y; +// CHECK-NEXT: { +// CHECK-NEXT: u = _t0; +// CHECK-NEXT: double _r_d0 = *_d_u; +// CHECK-NEXT: *_d_u = 0.; +// CHECK-NEXT: *_d_v += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: void f12_1_pullback(double y, const double *obs, double _d_y0, double *_d_y) { +// CHECK-NEXT: double nopull = interpolate1d(1., 3., 5., 5, obs); // CHECK-NEXT: } \ No newline at end of file