From 10e85f43a64df7eb148c47d3415f543f17bc7450 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Tue, 11 Jun 2024 15:40:51 +0200 Subject: [PATCH] Add support for custom derivatives for top level derivatives fixes #352 --- .../clad/Differentiator/DerivativeBuilder.h | 31 +++++++- include/clad/Differentiator/DynamicGraph.h | 5 +- lib/Differentiator/BaseForwardModeVisitor.cpp | 13 +++- lib/Differentiator/CladUtils.cpp | 5 +- lib/Differentiator/DerivativeBuilder.cpp | 56 +++++++++++---- lib/Differentiator/HessianModeVisitor.cpp | 32 +++++++-- .../ReverseModeForwPassVisitor.cpp | 8 ++- lib/Differentiator/ReverseModeVisitor.cpp | 17 ++++- test/FirstDerivative/BuiltinDerivatives.C | 12 +++- test/Gradient/FunctionCalls.C | 70 +++++++++++++++++-- test/Hessian/BuiltinDerivatives.C | 60 ++++++++++++++++ 11 files changed, 270 insertions(+), 39 deletions(-) diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index f2981438f..d343c103f 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -133,6 +133,32 @@ namespace clad { stream << arg; } + /// Lookup the result of finding a custom derivative or numerical + /// differentiation function. + /// + /// \param[in] Name The name of the function to look up. + /// \param[in] originalFnDC The original function's DeclContext. + /// \param[in] SS The CXXScopeSpec to extend with the namespace of the + /// function. + /// \param[in] forCustomDerv A flag to keep track of which + /// namespace we should look in for the overloads. + /// \param[in] namespaceShouldExist A flag to enforce assertion failure + /// if the overload function namespace was not found. If false and + /// the function containing namespace was not found, + clang::LookupResult LookupCustomDerivativeOrNumericalDiff( + const std::string& Name, clang::DeclContext* originalFnDC, + clang::CXXScopeSpec& SS, bool forCustomDerv = true, + bool namespaceShouldExist = true); + + /// Looks up if the user has defined a custom derivative for the given + /// derivative function. + /// \param[in] D + /// \returns The custom derivative function if found, nullptr otherwise. + clang::FunctionDecl* + LookupCustomDerivativeDecl(const std::string& Name, + clang::DeclContext* originalFnDC, + clang::QualType functionType); + public: DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, DerivedFnCollector& DFC, @@ -175,7 +201,10 @@ namespace clad { /// graph. /// /// \param[in] request The request to add the edge to. - void AddEdgeToGraph(const DiffRequest& request); + /// \param[in] alreadyDerived A flag to keep track of whether the request + /// is already derived or not. + void AddEdgeToGraph(const DiffRequest& request, + bool alreadyDerived = false); }; } // end namespace clad diff --git a/include/clad/Differentiator/DynamicGraph.h b/include/clad/Differentiator/DynamicGraph.h index 55e5ad350..db1113116 100644 --- a/include/clad/Differentiator/DynamicGraph.h +++ b/include/clad/Differentiator/DynamicGraph.h @@ -74,9 +74,12 @@ template class DynamicGraph { /// Add an edge from the current node being processed to the /// destination node. /// \param dest - void addEdgeToCurrentNode(const T& dest) { + /// \param alreadyProcessed If the destination node is already processed. + void addEdgeToCurrentNode(const T& dest, bool alreadyProcessed = false) { if (m_currentId != -1) addEdge(m_nodes[m_currentId], dest); + if (alreadyProcessed) + m_nodeMap[dest].first = true; } /// Set the current node being processed. diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 5a3ed4394..0035bd7ba 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -159,13 +159,20 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, for (auto field : diffVarInfo.fields) argInfo += "_" + field; - IdentifierInfo* II = &m_Context.Idents.get( - request.BaseFunctionName + "_d" + s + "arg" + argInfo + derivativeSuffix); + // Check if the function is already declared as a custom derivative. + std::string gradientName = + request.BaseFunctionName + "_d" + s + "arg" + argInfo + derivativeSuffix; + auto* DC = const_cast(m_Function->getDeclContext()); + if (FunctionDecl* customDerivative = + m_Builder.LookupCustomDerivativeDecl(gradientName, DC, FD->getType())) + return DerivativeAndOverload{customDerivative, nullptr}; + + IdentifierInfo* II = &m_Context.Idents.get(gradientName); SourceLocation validLoc{m_Function->getLocation()}; DeclarationNameInfo name(II, validLoc); llvm::SaveAndRestore SaveContext(m_Sema.CurContext); llvm::SaveAndRestore SaveScope(getCurrentScope()); - DeclContext* DC = const_cast(m_Function->getDeclContext()); + m_Sema.CurContext = DC; DeclWithContext result = m_Builder.cloneFunction(FD, *this, DC, validLoc, name, FD->getType()); diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 6cb5aff43..73fd1c5e8 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -238,8 +238,9 @@ namespace clad { DeclContext* DC = DC1; for (int i = contexts.size() - 1; i >= 0; --i) { NamespaceDecl* ND = cast(contexts[i]); - DC = LookupNSD(semaRef, ND->getIdentifier()->getName(), - /*shouldExist=*/false, DC1); + if (ND->getIdentifier()) + DC = LookupNSD(semaRef, ND->getIdentifier()->getName(), + /*shouldExist=*/false, DC1); if (!DC) return nullptr; DC1 = DC; diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 632e1b65b..ceed8eca1 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -171,10 +171,16 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { return false; } - Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff( - const std::string& Name, llvm::SmallVectorImpl& CallArgs, - clang::Scope* S, clang::DeclContext* originalFnDC, - bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) { + LookupResult DerivativeBuilder::LookupCustomDerivativeOrNumericalDiff( + const std::string& Name, clang::DeclContext* originalFnDC, + CXXScopeSpec& SS, bool forCustomDerv /*=true*/, + bool namespaceShouldExist /*=true*/) { + + IdentifierInfo* II = &m_Context.Idents.get(Name); + DeclarationName name(II); + DeclarationNameInfo DNInfo(name, utils::GetValidSLoc(m_Sema)); + LookupResult R(m_Sema, DNInfo, Sema::LookupOrdinaryName); + NamespaceDecl* NSD = nullptr; std::string namespaceID; if (forCustomDerv) { @@ -201,10 +207,9 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { "flag, this means that every try to numerically differentiate a " "function will fail! Remove the flag to revert to default " "behaviour."); - return nullptr; + return R; } } - CXXScopeSpec SS; DeclContext* DC = NSD; // FIXME: Here `if` branch should be removed once we update @@ -223,13 +228,37 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { } else { SS.Extend(m_Context, NSD, noLoc, noLoc); } - IdentifierInfo* II = &m_Context.Idents.get(Name); - DeclarationName name(II); - DeclarationNameInfo DNInfo(name, utils::GetValidSLoc(m_Sema)); - - LookupResult R(m_Sema, DNInfo, Sema::LookupOrdinaryName); if (DC) m_Sema.LookupQualifiedName(R, DC); + return R; + } + + FunctionDecl* DerivativeBuilder::LookupCustomDerivativeDecl( + const std::string& Name, clang::DeclContext* originalFnDC, + QualType functionType) { + CXXScopeSpec SS; + LookupResult R = + LookupCustomDerivativeOrNumericalDiff(Name, originalFnDC, SS); + + for (NamedDecl* ND : R) + if (auto* FD = dyn_cast(ND)) + // Check if FD and functionType have the same signature. + if (utils::SameCanonicalType(FD->getType(), functionType)) + if (FD->isDefined() || !m_DFC.IsDerivative(FD)) + return FD; + + return nullptr; + } + + Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff( + const std::string& Name, llvm::SmallVectorImpl& CallArgs, + clang::Scope* S, clang::DeclContext* originalFnDC, + bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) { + + CXXScopeSpec SS; + LookupResult R = LookupCustomDerivativeOrNumericalDiff( + Name, originalFnDC, SS, forCustomDerv, namespaceShouldExist); + Expr* OverloadedFn = nullptr; if (!R.empty()) { // FIXME: We should find a way to specify nested name specifier @@ -402,7 +431,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { return nullptr; } - void DerivativeBuilder::AddEdgeToGraph(const DiffRequest& request) { - m_DiffRequestGraph.addEdgeToCurrentNode(request); + void DerivativeBuilder::AddEdgeToGraph(const DiffRequest& request, + bool alreadyDerived /*=false*/) { + m_DiffRequestGraph.addEdgeToCurrentNode(request, alreadyDerived); } }// end namespace clad diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index 6fe5ecbfd..937c43a41 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -59,19 +59,27 @@ namespace clad { IndependentArgRequest.CallUpdateRequired = false; IndependentArgRequest.UpdateDiffParamsInfo(SemaRef); // FIXME: Find a way to do this without accessing plugin namespace functions + bool alreadyDerived = true; FunctionDecl* firstDerivative = Builder.FindDerivedFunction(IndependentArgRequest); if (!firstDerivative) { + alreadyDerived = false; // Derive declaration of the the forward mode derivative. IndependentArgRequest.DeclarationOnly = true; firstDerivative = plugin::ProcessDiffRequest(CP, IndependentArgRequest); + // It is possible that user has provided a custom derivative for the + // derivative function. In that case, we should not derive the definition + // again. + if (firstDerivative->isDefined()) + alreadyDerived = true; + // Add the request to derive the definition of the forward mode derivative // to the schedule. IndependentArgRequest.DeclarationOnly = false; IndependentArgRequest.DerivedFDPrototype = firstDerivative; } - Builder.AddEdgeToGraph(IndependentArgRequest); + Builder.AddEdgeToGraph(IndependentArgRequest, alreadyDerived); // Further derives function w.r.t to ReverseModeArgs DiffRequest ReverseModeRequest{}; @@ -81,20 +89,27 @@ namespace clad { ReverseModeRequest.BaseFunctionName = firstDerivative->getNameAsString(); ReverseModeRequest.UpdateDiffParamsInfo(SemaRef); + alreadyDerived = true; FunctionDecl* secondDerivative = Builder.FindDerivedFunction(ReverseModeRequest); if (!secondDerivative) { + alreadyDerived = false; // Derive declaration of the the reverse mode derivative. ReverseModeRequest.DeclarationOnly = true; secondDerivative = plugin::ProcessDiffRequest(CP, ReverseModeRequest); - // Add the request to derive the definition of the reverse mode derivative - // to the schedule. + // It is possible that user has provided a custom derivative for the + // derivative function. In that case, we should not derive the definition + // again. + if (secondDerivative->isDefined()) + alreadyDerived = true; + + // Add the request to derive the definition of the reverse mode + // derivative to the schedule. ReverseModeRequest.DeclarationOnly = false; ReverseModeRequest.DerivedFDPrototype = secondDerivative; } - Builder.AddEdgeToGraph(ReverseModeRequest); - + Builder.AddEdgeToGraph(ReverseModeRequest, alreadyDerived); return secondDerivative; } @@ -249,8 +264,13 @@ namespace clad { // Cast to function pointer. originalFnProtoType->getExtProtoInfo()); - // Create the gradient function declaration. + // Check if the function is already declared as a custom derivative. DeclContext* DC = const_cast(m_Function->getDeclContext()); + if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl( + hessianFuncName, DC, hessianFunctionType)) + return DerivativeAndOverload{customDerivative, nullptr}; + + // Create the gradient function declaration. llvm::SaveAndRestore SaveContext(m_Sema.CurContext); llvm::SaveAndRestore SaveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp index 7c444415f..73b197379 100644 --- a/lib/Differentiator/ReverseModeForwPassVisitor.cpp +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -42,8 +42,14 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, llvm::SaveAndRestore saveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - m_Sema.CurContext = const_cast(m_Function->getDeclContext()); + auto* DC = const_cast(m_Function->getDeclContext()); + // Check if the function is already declared as a custom derivative. + if (FunctionDecl* customDerivative = + m_Builder.LookupCustomDerivativeDecl(fnName, DC, fnType)) + return DerivativeAndOverload{customDerivative, nullptr}; + + m_Sema.CurContext = DC; SourceLocation validLoc{m_Function->getLocation()}; DeclWithContext fnBuildRes = m_Builder.cloneFunction( m_Function, *this, m_Sema.CurContext, validLoc, fnDNI, fnType); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index c1bbb5a2d..2b4887f8a 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -357,11 +357,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Cast to function pointer. originalFnType->getExtProtoInfo()); + // Check if the function is already declared as a custom derivative. + auto* DC = const_cast(m_Function->getDeclContext()); + if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl( + gradientName, DC, gradientFunctionType)) + return DerivativeAndOverload{customDerivative, nullptr}; + // Create the gradient function declaration. llvm::SaveAndRestore SaveContext(m_Sema.CurContext); llvm::SaveAndRestore SaveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); - auto* DC = const_cast(m_Function->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext result = m_Builder.cloneFunction( m_Function, *this, DC, noLoc, name, gradientFunctionType); @@ -1775,20 +1780,28 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, clad::utils::ComputeEffectiveFnName(FD); calleeFnForwPassReq.VerboseDiags = true; + bool alreadyDerived = true; FunctionDecl* calleeFnForwPassFD = m_Builder.FindDerivedFunction(calleeFnForwPassReq); if (!calleeFnForwPassFD) { + alreadyDerived = false; // Derive declaration of the the forward pass function. calleeFnForwPassReq.DeclarationOnly = true; calleeFnForwPassFD = plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq); + // It is possible that user has provided a custom derivative for the + // derivative function. In that case, we should not derive the + // definition again. + if (calleeFnForwPassFD->getDefinition()) + alreadyDerived = true; + // Add the request to derive the definition of the forward pass // function. calleeFnForwPassReq.DeclarationOnly = false; calleeFnForwPassReq.DerivedFDPrototype = calleeFnForwPassFD; } - m_Builder.AddEdgeToGraph(calleeFnForwPassReq); + m_Builder.AddEdgeToGraph(calleeFnForwPassReq, alreadyDerived); assert(calleeFnForwPassFD && "Clad failed to generate callee function forward pass function"); diff --git a/test/FirstDerivative/BuiltinDerivatives.C b/test/FirstDerivative/BuiltinDerivatives.C index a0f94ad8f..b9ddddf32 100644 --- a/test/FirstDerivative/BuiltinDerivatives.C +++ b/test/FirstDerivative/BuiltinDerivatives.C @@ -9,14 +9,20 @@ #include "../TestUtils.h" extern "C" int printf(const char* fmt, ...); +namespace clad{ + namespace custom_derivatives{ + float f1_darg0(float x) { + return cos(x); + } + } +} + float f1(float x) { return sin(x); } // CHECK: float f1_darg0(float x) { -// CHECK-NEXT: float _d_x = 1; -// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward _t0 = clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, _d_x); -// CHECK-NEXT: return _t0.pushforward; +// CHECK-NEXT: return cos(x); // CHECK-NEXT: } float f2(float x) { diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 0a91b4927..af6d0ab24 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -240,12 +240,25 @@ double& identity(double& i) { return i; } +namespace clad{ +namespace custom_derivatives{ + clad::ValueAndAdjoint custom_identity_forw(double &i, double *d_i) { + return {i, *d_i}; + } +} // namespace custom_derivatives +} // namespace clad + +double& custom_identity(double& i) { + return i; +} + double fn7(double i, double j) { double& k = identity(i); double& l = identity(j); + double& temp = custom_identity(i); k += 7*j; l += 9*i; - return i + j; + return i + j + temp; } // CHECK: void fn6_grad(double i, double j, double *_d_i, double *_d_j) { @@ -261,13 +274,21 @@ double fn7(double i, double j) { // CHECK: clad::ValueAndAdjoint identity_forw(double &i, double *_d_i); +// CHECK: void custom_identity_pullback(double &i, double _d_y, double *_d_i); + +// CHECK: clad::ValueAndAdjoint custom_identity_forw(double &i, double *d_i) { +// CHECK-NEXT: return {i, *d_i}; +// CHECK-NEXT: } + // CHECK: void fn7_grad(double i, double j, double *_d_i, double *_d_j) { // CHECK-NEXT: double _t0; // CHECK-NEXT: double *_d_k = 0; // CHECK-NEXT: double _t2; // CHECK-NEXT: double *_d_l = 0; // CHECK-NEXT: double _t4; -// CHECK-NEXT: double _t5; +// CHECK-NEXT: double *_d_temp = 0; +// CHECK-NEXT: double _t6 +// CHECK-NEXT: double _t7; // CHECK-NEXT: _t0 = i; // CHECK-NEXT: clad::ValueAndAdjoint _t1 = identity_forw(i, &*_d_i); // CHECK-NEXT: _d_k = &_t1.adjoint; @@ -276,27 +297,36 @@ double fn7(double i, double j) { // CHECK-NEXT: clad::ValueAndAdjoint _t3 = identity_forw(j, &*_d_j); // CHECK-NEXT: _d_l = &_t3.adjoint; // CHECK-NEXT: double &l = _t3.value; -// CHECK-NEXT: _t4 = k; +// CHECK-NEXT: _t4 = i; +// CHECK-NEXT: clad::ValueAndAdjoint _t5 = custom_identity_forw(i, &*_d_i); +// CHECK-NEXT: _d_temp = &_t5.adjoint; +// CHECK-NEXT: double &temp = _t5.value; +// CHECK-NEXT: _t6 = k; // CHECK-NEXT: k += 7 * j; -// CHECK-NEXT: _t5 = l; +// CHECK-NEXT: _t7 = l; // CHECK-NEXT: l += 9 * i; // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { // CHECK-NEXT: *_d_i += 1; // CHECK-NEXT: *_d_j += 1; +// CHECK-NEXT: _d_temp += 1; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: l = _t5; +// CHECK-NEXT: l = _t7; // CHECK-NEXT: double _r_d1 = *_d_l; // CHECK-NEXT: *_d_i += 9 * _r_d1; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: k = _t4; +// CHECK-NEXT: k = _t6; // CHECK-NEXT: double _r_d0 = *_d_k; // CHECK-NEXT: *_d_j += 7 * _r_d0; // CHECK-NEXT: } // CHECK-NEXT: { +// CHECK-NEXT: i = _t4; +// CHECK-NEXT: custom_identity_pullback(_t4, 0, &*_d_i); +// CHECK-NEXT: } +// CHECK-NEXT: { // CHECK-NEXT: j = _t2; // CHECK-NEXT: identity_pullback(_t2, 0, &*_d_j); // CHECK-NEXT: } @@ -711,6 +741,21 @@ double fn21(double x) { // CHECK-NEXT: ptr = _t0; // CHECK-NEXT: ptrRef_pullback(_t0, 1, &_d_ptr); // CHECK-NEXT: } + +namespace clad{ +namespace custom_derivatives{ + void fn22_grad_1(double x, double y, double *d_y) { + *d_y += x; + } +} +} + +double fn22(double x, double y) { + return x*y; // fn22 has a custom derivative defined. +} + +// CHECK: void fn22_grad_1(double x, double y, double *d_y) { +// CHECK-NEXT: *d_y += x; // CHECK-NEXT: } template @@ -778,7 +823,7 @@ int main() { TEST_ARR5(fn4, arr, 5); // CHECK-EXEC: {23.00, 3.00, 3.00, 3.00, 3.00} TEST_ARR5(fn5, arr, 5); // CHECK-EXEC: {5.00, 1.00, 0.00, 0.00, 0.00} TEST2(fn6, 3, 5); // CHECK-EXEC: {5.00, 3.00} - TEST2(fn7, 3, 5); // CHECK-EXEC: {10.00, 71.00} + TEST2(fn7, 3, 5); // CHECK-EXEC: {11.00, 78.00} TEST2(fn8, 3, 5); // CHECK-EXEC: {7.62, 4.57} TEST2(fn9, 3, 5); // CHECK-EXEC: {5.00, 3.00} TEST2(fn10, 8, 5); // CHECK-EXEC: {0.00, 7.00} @@ -819,6 +864,11 @@ int main() { INIT(fn21); TEST1(fn21, 8); // CHECK-EXEC: {1.00} + + auto fn22_grad_1 = clad::gradient(fn22, "y"); + double dy = 0; + fn22_grad_1.execute(3, 5, &dy); + printf("{%.2f}\n", dy); // CHECK-EXEC: {3.00} } double sq_defined_later(double x) { @@ -965,6 +1015,12 @@ double sq_defined_later(double x) { // CHECK-NEXT: return {i, *_d_i}; // CHECK-NEXT: } +// CHECK: void custom_identity_pullback(double &i, double _d_y, double *_d_i) { +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: *_d_i += _d_y; +// CHECK-NEXT: } + // CHECK: void check_and_return_pullback(double x, char c, const char *s, double _d_y, double *_d_x, char *_d_c, char *_d_s) { // CHECK-NEXT: bool _cond0; // CHECK-NEXT: { diff --git a/test/Hessian/BuiltinDerivatives.C b/test/Hessian/BuiltinDerivatives.C index 0bc005901..770b73cc0 100644 --- a/test/Hessian/BuiltinDerivatives.C +++ b/test/Hessian/BuiltinDerivatives.C @@ -9,6 +9,25 @@ #include "clad/Differentiator/Differentiator.h" #include +namespace clad { + namespace custom_derivatives { + float f7_darg0(float x, float y) { + return cos(x); + } + + float f7_darg1(float x, float y) { + return exp(y); + } + + void f8_hessian(double x, double y, double *hessianMatrix) { + hessianMatrix[0] = 1.0; + hessianMatrix[1] = 1.0; + hessianMatrix[2] = 1.0; + hessianMatrix[3] = 1.0; + } + } +} + float f1(float x) { return sin(x) + cos(x); } @@ -90,6 +109,26 @@ float f6(float x, float y) { // CHECK-NEXT: f6_darg1_grad(x, y, hessianMatrix + {{2U|2UL}}, hessianMatrix + {{3U|3UL}}); // CHECK-NEXT: } +float f7(float x, float y) { + return sin(x) + exp(y); +} + +// CHECK: float f7_darg0(float x, float y) { +// CHECK-NEXT: return cos(x); +// CHECK-NEXT: } + +// CHECK: void f7_darg0_grad(float x, float y, float *_d_x, float *_d_y); + +// CHECK: float f7_darg1(float x, float y) { +// CHECK-NEXT: return exp(y); +// CHECK-NEXT: } + +// CHECK: void f7_darg1_grad(float x, float y, float *_d_x, float *_d_y); + +// CHECK: void f7_hessian(float x, float y, float *hessianMatrix) { +// CHECK-NEXT: f7_darg0_grad(x, y, hessianMatrix + {{0U|0UL}}, hessianMatrix + {{1U|1UL}}); +// CHECK-NEXT: f7_darg1_grad(x, y, hessianMatrix + {{2U|2UL}}, hessianMatrix + {{3U|3UL}}); +// CHECK-NEXT: } #define TEST1(F, x) { \ result[0] = 0; \ @@ -115,6 +154,7 @@ int main() { TEST1(f4, 3); // CHECK-EXEC: Result is = {108.00} TEST1(f5, 3); // CHECK-EXEC: Result is = {3.84} TEST2(f6, 3, 4); // CHECK-EXEC: Result is = {108.00, 145.65, 145.65, 97.76} + TEST2(f7, 3, 4); // CHECK-EXEC: Result is = {-0.14, 0.00, 0.00, 54.60} // CHECK: float f1_darg0(float x) { // CHECK-NEXT: float _d_x = 1; @@ -318,6 +358,26 @@ int main() { // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK: void f7_darg0_grad(float x, float y, float *_d_x, float *_d_y) { +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: _r0 += 1 * clad::custom_derivatives::std::cos_pushforward(x, 1.F).pushforward; +// CHECK-NEXT: *_d_x += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: void f7_darg1_grad(float x, float y, float *_d_x, float *_d_y) { +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: _r0 += 1 * clad::custom_derivatives::std::exp_pushforward(y, 1.F).pushforward; +// CHECK-NEXT: *_d_y += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + // CHECK: void sin_pushforward_pullback(float x, float d_x, ValueAndPushforward _d_y, float *_d_x, float *_d_d_x) { // CHECK-NEXT: float _t0; // CHECK-NEXT: _t0 = ::std::cos(x);