From cc33a2cada5395dc71a7c05e6f3afdddeb2d2e43 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Wed, 12 Jun 2024 16:06:35 +0200 Subject: [PATCH] Separate sets for custom and clad-generated derivatives This is required for allowing users to link custom derivatives from a separate translation unit. --- .../clad/Differentiator/DerivedFnCollector.h | 16 +- .../clad/Differentiator/HessianModeVisitor.h | 4 +- lib/Differentiator/DerivativeBuilder.cpp | 25 ++- lib/Differentiator/DerivedFnCollector.cpp | 12 +- lib/Differentiator/HessianModeVisitor.cpp | 179 +++++++++--------- test/Hessian/BuiltinDerivatives.C | 70 ++++--- tools/ClangPlugin.h | 3 +- unittests/Misc/CallDeclOnly.cpp | 29 +++ unittests/Misc/Defs.cpp | 7 + 9 files changed, 213 insertions(+), 132 deletions(-) diff --git a/include/clad/Differentiator/DerivedFnCollector.h b/include/clad/Differentiator/DerivedFnCollector.h index 909285e99..24e013fbd 100644 --- a/include/clad/Differentiator/DerivedFnCollector.h +++ b/include/clad/Differentiator/DerivedFnCollector.h @@ -20,9 +20,14 @@ class DerivedFnCollector { /// a function. llvm::DenseMap m_DerivedFnInfoCollection; - /// Set to keep track of all the functions that are derivatives. + /// Set to keep track of all the functions that are derivatives + /// functions produced by Clad. DerivativeSet m_DerivativeSet; + /// Set to keep track of all the functions that are custom derivatives + /// functions provided by the user. + DerivativeSet m_CustomDerivativeSet; + public: /// Adds a derived function to the collection. void Add(const DerivedFnInfo& DFI); @@ -30,11 +35,18 @@ class DerivedFnCollector { /// Adds a function to derivative set. void AddToDerivativeSet(const clang::FunctionDecl* FD); + /// Adds a function to custom derivative set. + void AddToCustomDerivativeSet(const clang::FunctionDecl* FD); + /// Finds a `DerivedFnInfo` object in the collection that satisfies the /// given differentiation request. DerivedFnInfo Find(const DiffRequest& request) const; - bool IsDerivative(const clang::FunctionDecl* FD) const; + /// Returns true if the function is a Clad-generated derivative. + bool IsCladDerivative(const clang::FunctionDecl* FD) const; + + /// Returns true if the function is a custom derivative. + bool IsCustomDerivative(const clang::FunctionDecl* FD) const; private: /// Returns true if the collection already contains a `DerivedFnInfo` diff --git a/include/clad/Differentiator/HessianModeVisitor.h b/include/clad/Differentiator/HessianModeVisitor.h index d1ca65c1c..b9b768d30 100644 --- a/include/clad/Differentiator/HessianModeVisitor.h +++ b/include/clad/Differentiator/HessianModeVisitor.h @@ -30,7 +30,9 @@ namespace clad { DerivativeAndOverload Merge(std::vector secDerivFuncs, llvm::SmallVector IndependentArgsSize, - size_t TotalIndependentArgsSize, std::string hessianFuncName); + size_t TotalIndependentArgsSize, const std::string& hessianFuncName, + clang::DeclContext* FD, clang::QualType hessianFuncType, + llvm::SmallVector paramTypes); public: HessianModeVisitor(DerivativeBuilder& builder, const DiffRequest& request); diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 9f02f4602..43fd1ccaa 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -243,8 +243,14 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { if (auto* FD = dyn_cast(ND)) // Check if FD and functionType have the same signature. if (utils::SameCanonicalType(FD->getType(), functionType)) - if (FD->isDefined() || !m_DFC.IsDerivative(FD)) + // Make sure that it is not the case that FD is the forward + // declaration generated by Clad. It should be user defined custom + // derivative (either within the same translation unit or linked in + // from another translation unit). + if (FD->isDefined() || !m_DFC.IsCladDerivative(FD)) { + m_DFC.AddToCustomDerivativeSet(FD); return FD; + } return nullptr; } @@ -284,7 +290,7 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { // differentiation due to unavailable definition. if (auto* CE = dyn_cast(OverloadedFn)) if (FunctionDecl* FD = CE->getDirectCallee()) - m_DFC.AddToDerivativeSet(FD); + m_DFC.AddToCustomDerivativeSet(FD); } return OverloadedFn; } @@ -328,8 +334,9 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { // If FD is only a declaration, try to find its definition. if (!FD->getDefinition()) { // If only declaration is requested, allow this for clad-generated - // functions. - if (!request.DeclarationOnly || !m_DFC.IsDerivative(FD)) { + // functions or custom derivatives. + if (!request.DeclarationOnly || + !(m_DFC.IsCladDerivative(FD) || m_DFC.IsCustomDerivative(FD))) { if (request.VerboseDiags) diag(DiagnosticsEngine::Error, request.CallContext ? request.CallContext->getBeginLoc() : noLoc, @@ -415,10 +422,12 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { // FIXME: if the derivatives aren't registered in this order and the // derivative is a member function it goes into an infinite loop - if (auto FD = result.derivative) - registerDerivative(FD, m_Sema); - if (auto OFD = result.overload) - registerDerivative(OFD, m_Sema); + if (!m_DFC.IsCustomDerivative(result.derivative)) { + if (auto FD = result.derivative) + registerDerivative(FD, m_Sema); + if (auto OFD = result.overload) + registerDerivative(OFD, m_Sema); + } return result; } diff --git a/lib/Differentiator/DerivedFnCollector.cpp b/lib/Differentiator/DerivedFnCollector.cpp index 71bc22258..4b236a867 100644 --- a/lib/Differentiator/DerivedFnCollector.cpp +++ b/lib/Differentiator/DerivedFnCollector.cpp @@ -15,6 +15,11 @@ void DerivedFnCollector::AddToDerivativeSet(const clang::FunctionDecl* FD) { m_DerivativeSet.insert(FD); } +void DerivedFnCollector::AddToCustomDerivativeSet( + const clang::FunctionDecl* FD) { + m_CustomDerivativeSet.insert(FD); +} + bool DerivedFnCollector::AlreadyExists(const DerivedFnInfo& DFI) const { auto subCollectionIt = m_DerivedFnInfoCollection.find(DFI.OriginalFn()); if (subCollectionIt == m_DerivedFnInfoCollection.end()) @@ -42,7 +47,12 @@ DerivedFnInfo DerivedFnCollector::Find(const DiffRequest& request) const { return *it; } -bool DerivedFnCollector::IsDerivative(const clang::FunctionDecl* FD) const { +bool DerivedFnCollector::IsCladDerivative(const clang::FunctionDecl* FD) const { return m_DerivativeSet.count(FD); } + +bool DerivedFnCollector::IsCustomDerivative( + const clang::FunctionDecl* FD) const { + return m_CustomDerivativeSet.count(FD); +} } // namespace clad \ No newline at end of file diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index ddd472345..a442f6e88 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -48,69 +48,71 @@ static const StringLiteral* CreateStringLiteral(ASTContext& C, /// Derives the function w.r.t both forward and reverse mode and returns the /// FunctionDecl obtained from reverse mode differentiation - static FunctionDecl* DeriveUsingForwardAndReverseMode( - Sema& SemaRef, clad::plugin::CladPlugin& CP, - clad::DerivativeBuilder& Builder, DiffRequest IndependentArgRequest, - const Expr* ForwardModeArgs, const Expr* ReverseModeArgs) { - // Derives function once in forward mode w.r.t to ForwardModeArgs - IndependentArgRequest.Args = ForwardModeArgs; - IndependentArgRequest.Mode = DiffMode::forward; - IndependentArgRequest.CallUpdateRequired = false; - IndependentArgRequest.UpdateDiffParamsInfo(SemaRef); - // FIXME: Find a way to do this without accessing plugin namespace functions - bool alreadyDerived = true; - FunctionDecl* firstDerivative = - Builder.FindDerivedFunction(IndependentArgRequest); - if (!firstDerivative) { - alreadyDerived = false; - // Derive declaration of the the forward mode derivative. - IndependentArgRequest.DeclarationOnly = true; - firstDerivative = plugin::ProcessDiffRequest(CP, IndependentArgRequest); - - // It is possible that user has provided a custom derivative for the - // derivative function. In that case, we should not derive the definition - // again. - if (firstDerivative->isDefined()) - alreadyDerived = true; - - // Add the request to derive the definition of the forward mode derivative - // to the schedule. - IndependentArgRequest.DeclarationOnly = false; - IndependentArgRequest.DerivedFDPrototype = firstDerivative; - } - Builder.AddEdgeToGraph(IndependentArgRequest, alreadyDerived); - - // Further derives function w.r.t to ReverseModeArgs - DiffRequest ReverseModeRequest{}; - ReverseModeRequest.Mode = DiffMode::reverse; - ReverseModeRequest.Function = firstDerivative; - ReverseModeRequest.Args = ReverseModeArgs; - ReverseModeRequest.BaseFunctionName = firstDerivative->getNameAsString(); - ReverseModeRequest.UpdateDiffParamsInfo(SemaRef); - - alreadyDerived = true; - FunctionDecl* secondDerivative = - Builder.FindDerivedFunction(ReverseModeRequest); - if (!secondDerivative) { - alreadyDerived = false; - // Derive declaration of the the reverse mode derivative. - ReverseModeRequest.DeclarationOnly = true; - secondDerivative = plugin::ProcessDiffRequest(CP, ReverseModeRequest); - - // It is possible that user has provided a custom derivative for the - // derivative function. In that case, we should not derive the definition - // again. - if (secondDerivative->isDefined()) - alreadyDerived = true; - - // Add the request to derive the definition of the reverse mode - // derivative to the schedule. - ReverseModeRequest.DeclarationOnly = false; - ReverseModeRequest.DerivedFDPrototype = secondDerivative; - } - Builder.AddEdgeToGraph(ReverseModeRequest, alreadyDerived); - return secondDerivative; +static FunctionDecl* DeriveUsingForwardAndReverseMode( + Sema& SemaRef, clad::plugin::CladPlugin& CP, + clad::DerivativeBuilder& Builder, DiffRequest IndependentArgRequest, + const Expr* ForwardModeArgs, const Expr* ReverseModeArgs, + DerivedFnCollector& DFC) { + // Derives function once in forward mode w.r.t to ForwardModeArgs + IndependentArgRequest.Args = ForwardModeArgs; + IndependentArgRequest.Mode = DiffMode::forward; + IndependentArgRequest.CallUpdateRequired = false; + IndependentArgRequest.UpdateDiffParamsInfo(SemaRef); + // FIXME: Find a way to do this without accessing plugin namespace functions + bool alreadyDerived = true; + FunctionDecl* firstDerivative = + Builder.FindDerivedFunction(IndependentArgRequest); + if (!firstDerivative) { + alreadyDerived = false; + // Derive declaration of the the forward mode derivative. + IndependentArgRequest.DeclarationOnly = true; + firstDerivative = plugin::ProcessDiffRequest(CP, IndependentArgRequest); + + // It is possible that user has provided a custom derivative for the + // derivative function. In that case, we should not derive the definition + // again. + if (firstDerivative->isDefined() || DFC.IsCustomDerivative(firstDerivative)) + alreadyDerived = true; + + // Add the request to derive the definition of the forward mode derivative + // to the schedule. + IndependentArgRequest.DeclarationOnly = false; + IndependentArgRequest.DerivedFDPrototype = firstDerivative; + } + Builder.AddEdgeToGraph(IndependentArgRequest, alreadyDerived); + + // Further derives function w.r.t to ReverseModeArgs + DiffRequest ReverseModeRequest{}; + ReverseModeRequest.Mode = DiffMode::reverse; + ReverseModeRequest.Function = firstDerivative; + ReverseModeRequest.Args = ReverseModeArgs; + ReverseModeRequest.BaseFunctionName = firstDerivative->getNameAsString(); + ReverseModeRequest.UpdateDiffParamsInfo(SemaRef); + + alreadyDerived = true; + FunctionDecl* secondDerivative = + Builder.FindDerivedFunction(ReverseModeRequest); + if (!secondDerivative) { + alreadyDerived = false; + // Derive declaration of the the reverse mode derivative. + ReverseModeRequest.DeclarationOnly = true; + secondDerivative = plugin::ProcessDiffRequest(CP, ReverseModeRequest); + + // It is possible that user has provided a custom derivative for the + // derivative function. In that case, we should not derive the definition + // again. + if (secondDerivative->isDefined() || + DFC.IsCustomDerivative(secondDerivative)) + alreadyDerived = true; + + // Add the request to derive the definition of the reverse mode + // derivative to the schedule. + ReverseModeRequest.DeclarationOnly = false; + ReverseModeRequest.DerivedFDPrototype = secondDerivative; } + Builder.AddEdgeToGraph(ReverseModeRequest, alreadyDerived); + return secondDerivative; +} DerivativeAndOverload HessianModeVisitor::Derive(const clang::FunctionDecl* FD, @@ -149,6 +151,26 @@ static const StringLiteral* CreateStringLiteral(ASTContext& C, } } + llvm::SmallVector paramTypes(m_DiffReq->getNumParams() + 1); + std::transform(m_DiffReq->param_begin(), m_DiffReq->param_end(), + std::begin(paramTypes), + [](const ParmVarDecl* PVD) { return PVD->getType(); }); + paramTypes.back() = m_Context.getPointerType(m_DiffReq->getReturnType()); + + const auto* originalFnProtoType = + cast(m_DiffReq->getType()); + QualType hessianFunctionType = m_Context.getFunctionType( + m_Context.VoidTy, + llvm::ArrayRef(paramTypes.data(), paramTypes.size()), + // Cast to function pointer. + originalFnProtoType->getExtProtoInfo()); + + // Check if the function is already declared as a custom derivative. + auto* DC = const_cast(m_DiffReq->getDeclContext()); + if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl( + hessianFuncName, DC, hessianFunctionType)) + return DerivativeAndOverload{customDerivative, nullptr}; + // Ascertains the independent arguments and differentiates the function // in forward and reverse mode by calling ProcessDiffRequest twice each // iteration, storing each generated second derivative function @@ -209,7 +231,7 @@ static const StringLiteral* CreateStringLiteral(ASTContext& C, CreateStringLiteral(m_Context, independentArgString); auto* DFD = DeriveUsingForwardAndReverseMode( m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL, - request.Args); + request.Args, m_Builder.m_DFC); secondDerivativeColumns.push_back(DFD); } @@ -222,13 +244,14 @@ static const StringLiteral* CreateStringLiteral(ASTContext& C, CreateStringLiteral(m_Context, PVD->getNameAsString()); auto* DFD = DeriveUsingForwardAndReverseMode( m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL, - request.Args); + request.Args, m_Builder.m_DFC); secondDerivativeColumns.push_back(DFD); } } } return Merge(secondDerivativeColumns, IndependentArgsSize, - TotalIndependentArgsSize, hessianFuncName); + TotalIndependentArgsSize, hessianFuncName, DC, + hessianFunctionType, paramTypes); } // Combines all generated second derivative functions into a @@ -238,7 +261,9 @@ static const StringLiteral* CreateStringLiteral(ASTContext& C, HessianModeVisitor::Merge(std::vector secDerivFuncs, SmallVector IndependentArgsSize, size_t TotalIndependentArgsSize, - std::string hessianFuncName) { + const std::string& hessianFuncName, DeclContext* DC, + QualType hessianFunctionType, + llvm::SmallVector paramTypes) { DiffParams args; std::copy(m_DiffReq->param_begin(), m_DiffReq->param_end(), std::back_inserter(args)); @@ -246,28 +271,6 @@ static const StringLiteral* CreateStringLiteral(ASTContext& C, IdentifierInfo* II = &m_Context.Idents.get(hessianFuncName); DeclarationNameInfo name(II, noLoc); - llvm::SmallVector paramTypes(m_DiffReq->getNumParams() + 1); - - std::transform(m_DiffReq->param_begin(), m_DiffReq->param_end(), - std::begin(paramTypes), - [](const ParmVarDecl* PVD) { return PVD->getType(); }); - - paramTypes.back() = m_Context.getPointerType(m_DiffReq->getReturnType()); - - const auto* originalFnProtoType = - cast(m_DiffReq->getType()); - QualType hessianFunctionType = m_Context.getFunctionType( - m_Context.VoidTy, - llvm::ArrayRef(paramTypes.data(), paramTypes.size()), - // Cast to function pointer. - originalFnProtoType->getExtProtoInfo()); - - // Check if the function is already declared as a custom derivative. - auto* DC = const_cast(m_DiffReq->getDeclContext()); - if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl( - hessianFuncName, DC, hessianFunctionType)) - return DerivativeAndOverload{customDerivative, nullptr}; - // Create the gradient function declaration. llvm::SaveAndRestore SaveContext(m_Sema.CurContext); llvm::SaveAndRestore SaveScope(getCurrentScope(), diff --git a/test/Hessian/BuiltinDerivatives.C b/test/Hessian/BuiltinDerivatives.C index 770b73cc0..40f9a19f1 100644 --- a/test/Hessian/BuiltinDerivatives.C +++ b/test/Hessian/BuiltinDerivatives.C @@ -9,25 +9,6 @@ #include "clad/Differentiator/Differentiator.h" #include -namespace clad { - namespace custom_derivatives { - float f7_darg0(float x, float y) { - return cos(x); - } - - float f7_darg1(float x, float y) { - return exp(y); - } - - void f8_hessian(double x, double y, double *hessianMatrix) { - hessianMatrix[0] = 1.0; - hessianMatrix[1] = 1.0; - hessianMatrix[2] = 1.0; - hessianMatrix[3] = 1.0; - } - } -} - float f1(float x) { return sin(x) + cos(x); } @@ -109,6 +90,29 @@ float f6(float x, float y) { // CHECK-NEXT: f6_darg1_grad(x, y, hessianMatrix + {{2U|2UL}}, hessianMatrix + {{3U|3UL}}); // CHECK-NEXT: } +namespace clad { + namespace custom_derivatives { + float f7_darg0(float x, float y) { + return cos(x); + } + + float f7_darg1(float x, float y) { + return exp(y); + } + + void f7_darg1_grad(float x, float y, float *d_x, float *d_y) { + *d_y += exp(y); + } + + void f8_hessian(float x, float y, float *hessianMatrix) { + hessianMatrix[0] = 1.0; + hessianMatrix[1] = 1.0; + hessianMatrix[2] = 1.0; + hessianMatrix[3] = 1.0; + } + } +} + float f7(float x, float y) { return sin(x) + exp(y); } @@ -123,13 +127,26 @@ float f7(float x, float y) { // CHECK-NEXT: return exp(y); // CHECK-NEXT: } -// CHECK: void f7_darg1_grad(float x, float y, float *_d_x, float *_d_y); +// CHECK: void f7_darg1_grad(float x, float y, float *d_x, float *d_y) { +// CHECK-NEXT: *d_y += exp(y); +// CHECK-NEXT: } // CHECK: void f7_hessian(float x, float y, float *hessianMatrix) { // CHECK-NEXT: f7_darg0_grad(x, y, hessianMatrix + {{0U|0UL}}, hessianMatrix + {{1U|1UL}}); // CHECK-NEXT: f7_darg1_grad(x, y, hessianMatrix + {{2U|2UL}}, hessianMatrix + {{3U|3UL}}); // CHECK-NEXT: } +float f8(float x, float y) { + return (x*x + y*y)/2 + x*y; +} + +// CHECK: void f8_hessian(float x, float y, float *hessianMatrix) { +// CHECK-NEXT: hessianMatrix[0] = 1.; +// CHECK-NEXT: hessianMatrix[1] = 1.; +// CHECK-NEXT: hessianMatrix[2] = 1.; +// CHECK-NEXT: hessianMatrix[3] = 1.; +// CHECK-NEXT: } + #define TEST1(F, x) { \ result[0] = 0; \ auto h = clad::hessian(F); \ @@ -155,6 +172,7 @@ int main() { TEST1(f5, 3); // CHECK-EXEC: Result is = {3.84} TEST2(f6, 3, 4); // CHECK-EXEC: Result is = {108.00, 145.65, 145.65, 97.76} TEST2(f7, 3, 4); // CHECK-EXEC: Result is = {-0.14, 0.00, 0.00, 54.60} + TEST2(f8, 3, 4); // CHECK-EXEC: Result is = {1.00, 1.00, 1.00, 1.00} // CHECK: float f1_darg0(float x) { // CHECK-NEXT: float _d_x = 1; @@ -363,21 +381,11 @@ int main() { // CHECK-NEXT: _label0: // CHECK-NEXT: { // CHECK-NEXT: float _r0 = 0; -// CHECK-NEXT: _r0 += 1 * clad::custom_derivatives::std::cos_pushforward(x, 1.F).pushforward; +// CHECK-NEXT: _r0 += 1 * clad::custom_derivatives{{(::std)?}}::cos_pushforward(x, 1.F).pushforward; // CHECK-NEXT: *_d_x += _r0; // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK: void f7_darg1_grad(float x, float y, float *_d_x, float *_d_y) { -// CHECK-NEXT: goto _label0; -// CHECK-NEXT: _label0: -// CHECK-NEXT: { -// CHECK-NEXT: float _r0 = 0; -// CHECK-NEXT: _r0 += 1 * clad::custom_derivatives::std::exp_pushforward(y, 1.F).pushforward; -// CHECK-NEXT: *_d_y += _r0; -// CHECK-NEXT: } -// CHECK-NEXT: } - // CHECK: void sin_pushforward_pullback(float x, float d_x, ValueAndPushforward _d_y, float *_d_x, float *_d_d_x) { // CHECK-NEXT: float _t0; // CHECK-NEXT: _t0 = ::std::cos(x); diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 3ba226f31..2f4e23694 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -174,7 +174,8 @@ class CladTimerGroup { // setup, we exit early to give control to the non-standard setup for // code generation. // FIXME: This should go away if Cling starts using the clang driver. - if (!m_Multiplexer && m_DFC.IsDerivative(FD)) + if (!m_Multiplexer && + (m_DFC.IsCladDerivative(FD) || m_DFC.IsCustomDerivative(FD))) return true; HandleTopLevelDeclForClad(D); diff --git a/unittests/Misc/CallDeclOnly.cpp b/unittests/Misc/CallDeclOnly.cpp index 4ad44d29c..7925cb730 100644 --- a/unittests/Misc/CallDeclOnly.cpp +++ b/unittests/Misc/CallDeclOnly.cpp @@ -66,4 +66,33 @@ TEST(CallDeclOnly, CheckCustomDiff) { double dx = 0.0; grad.execute(&x, &dx); EXPECT_DOUBLE_EQ(dx, 2.0); +} + +namespace clad { +namespace custom_derivatives { +float custom_fn_darg0(float x, float y); + +void custom_fn_darg0_grad(float x, float y, float* d_x, float* d_y); + +float custom_fn_darg1(float x, float y) { return exp(y); } +} // namespace custom_derivatives +} // namespace clad + +float custom_fn(float x, float y) { + // This is to test that Clad actual doesn't generate a derivative for sin(x) + // as it is commented out, but use the user provided derivatives, which + // assumes function is sin(x) + exp(y). + return /*sin(x)*/ +exp(y); +} + +TEST(CallDeclOnly, CheckCustomDiff2) { + auto hessian = clad::hessian(custom_fn); + float result[4] = {0.0, 0.0, 0.0, 0.0}; + float x = 1.0; + float y = 2.0; + hessian.execute(x, y, result); + EXPECT_FLOAT_EQ(result[0], -sin(x)); + EXPECT_FLOAT_EQ(result[1], 0.0); + EXPECT_FLOAT_EQ(result[2], 0.0); + EXPECT_FLOAT_EQ(result[3], exp(y)); } \ No newline at end of file diff --git a/unittests/Misc/Defs.cpp b/unittests/Misc/Defs.cpp index bf0a2aa62..8d2efe910 100644 --- a/unittests/Misc/Defs.cpp +++ b/unittests/Misc/Defs.cpp @@ -21,5 +21,12 @@ _label0: { *_d_x += 2 * _d_y.pushforward; } } + +float custom_fn_darg0(float x, float y) { return cos(x); } + +void custom_fn_darg0_grad(float x, float y, float* d_x, float* d_y) { + *d_x -= sin(x); +} + } // namespace custom_derivatives } // namespace clad \ No newline at end of file