From c335823c383581e13c489e56080b4a58ab25f285 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Tue, 11 Jun 2024 17:29:05 +0200 Subject: [PATCH] improve testing, comments and hessians --- .../clad/Differentiator/HessianModeVisitor.h | 4 +- lib/Differentiator/DerivativeBuilder.cpp | 4 ++ lib/Differentiator/HessianModeVisitor.cpp | 49 ++++++------- test/Hessian/BuiltinDerivatives.C | 70 +++++++++++-------- 4 files changed, 71 insertions(+), 56 deletions(-) diff --git a/include/clad/Differentiator/HessianModeVisitor.h b/include/clad/Differentiator/HessianModeVisitor.h index 7a8f2f88c..f56266fb8 100644 --- a/include/clad/Differentiator/HessianModeVisitor.h +++ b/include/clad/Differentiator/HessianModeVisitor.h @@ -30,7 +30,9 @@ namespace clad { DerivativeAndOverload Merge(std::vector secDerivFuncs, llvm::SmallVector IndependentArgsSize, - size_t TotalIndependentArgsSize, std::string hessianFuncName); + size_t TotalIndependentArgsSize, const std::string& hessianFuncName, + clang::DeclContext* FD, clang::QualType hessianFuncType, + llvm::SmallVector paramTypes); public: HessianModeVisitor(DerivativeBuilder& builder); diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index ceed8eca1..c0b70ac73 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -244,6 +244,10 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { if (auto* FD = dyn_cast(ND)) // Check if FD and functionType have the same signature. if (utils::SameCanonicalType(FD->getType(), functionType)) + // Make sure that it is not the case that FD is the forward + // declaration generated by Clad. It should be user defined custom + // derivative (either within the same translation unit or linked in + // from another translation unit). if (FD->isDefined() || !m_DFC.IsDerivative(FD)) return FD; diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index 937c43a41..ae98563ea 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -150,6 +150,26 @@ namespace clad { } } + llvm::SmallVector paramTypes(m_Function->getNumParams() + 1); + std::transform(m_Function->param_begin(), m_Function->param_end(), + std::begin(paramTypes), + [](const ParmVarDecl* PVD) { return PVD->getType(); }); + paramTypes.back() = m_Context.getPointerType(m_Function->getReturnType()); + + const auto* originalFnProtoType = + cast(m_Function->getType()); + QualType hessianFunctionType = m_Context.getFunctionType( + m_Context.VoidTy, + llvm::ArrayRef(paramTypes.data(), paramTypes.size()), + // Cast to function pointer. + originalFnProtoType->getExtProtoInfo()); + + // Check if the function is already declared as a custom derivative. + auto* DC = const_cast(m_Function->getDeclContext()); + if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl( + hessianFuncName, DC, hessianFunctionType)) + return DerivativeAndOverload{customDerivative, nullptr}; + // Ascertains the independent arguments and differentiates the function // in forward and reverse mode by calling ProcessDiffRequest twice each // iteration, storing each generated second derivative function @@ -229,7 +249,8 @@ namespace clad { } } return Merge(secondDerivativeColumns, IndependentArgsSize, - TotalIndependentArgsSize, hessianFuncName); + TotalIndependentArgsSize, hessianFuncName, DC, + hessianFunctionType, paramTypes); } // Combines all generated second derivative functions into a @@ -239,7 +260,9 @@ namespace clad { HessianModeVisitor::Merge(std::vector secDerivFuncs, SmallVector IndependentArgsSize, size_t TotalIndependentArgsSize, - std::string hessianFuncName) { + const std::string& hessianFuncName, DeclContext* DC, + QualType hessianFunctionType, + llvm::SmallVector paramTypes) { DiffParams args; std::copy(m_Function->param_begin(), m_Function->param_end(), @@ -248,28 +271,6 @@ namespace clad { IdentifierInfo* II = &m_Context.Idents.get(hessianFuncName); DeclarationNameInfo name(II, noLoc); - llvm::SmallVector paramTypes(m_Function->getNumParams() + 1); - - std::transform(m_Function->param_begin(), - m_Function->param_end(), - std::begin(paramTypes), - [](const ParmVarDecl* PVD) { return PVD->getType(); }); - - paramTypes.back() = m_Context.getPointerType(m_Function->getReturnType()); - - auto originalFnProtoType = cast(m_Function->getType()); - QualType hessianFunctionType = m_Context.getFunctionType( - m_Context.VoidTy, - llvm::ArrayRef(paramTypes.data(), paramTypes.size()), - // Cast to function pointer. - originalFnProtoType->getExtProtoInfo()); - - // Check if the function is already declared as a custom derivative. - DeclContext* DC = const_cast(m_Function->getDeclContext()); - if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl( - hessianFuncName, DC, hessianFunctionType)) - return DerivativeAndOverload{customDerivative, nullptr}; - // Create the gradient function declaration. llvm::SaveAndRestore SaveContext(m_Sema.CurContext); llvm::SaveAndRestore SaveScope(getCurrentScope(), diff --git a/test/Hessian/BuiltinDerivatives.C b/test/Hessian/BuiltinDerivatives.C index 770b73cc0..40f9a19f1 100644 --- a/test/Hessian/BuiltinDerivatives.C +++ b/test/Hessian/BuiltinDerivatives.C @@ -9,25 +9,6 @@ #include "clad/Differentiator/Differentiator.h" #include -namespace clad { - namespace custom_derivatives { - float f7_darg0(float x, float y) { - return cos(x); - } - - float f7_darg1(float x, float y) { - return exp(y); - } - - void f8_hessian(double x, double y, double *hessianMatrix) { - hessianMatrix[0] = 1.0; - hessianMatrix[1] = 1.0; - hessianMatrix[2] = 1.0; - hessianMatrix[3] = 1.0; - } - } -} - float f1(float x) { return sin(x) + cos(x); } @@ -109,6 +90,29 @@ float f6(float x, float y) { // CHECK-NEXT: f6_darg1_grad(x, y, hessianMatrix + {{2U|2UL}}, hessianMatrix + {{3U|3UL}}); // CHECK-NEXT: } +namespace clad { + namespace custom_derivatives { + float f7_darg0(float x, float y) { + return cos(x); + } + + float f7_darg1(float x, float y) { + return exp(y); + } + + void f7_darg1_grad(float x, float y, float *d_x, float *d_y) { + *d_y += exp(y); + } + + void f8_hessian(float x, float y, float *hessianMatrix) { + hessianMatrix[0] = 1.0; + hessianMatrix[1] = 1.0; + hessianMatrix[2] = 1.0; + hessianMatrix[3] = 1.0; + } + } +} + float f7(float x, float y) { return sin(x) + exp(y); } @@ -123,13 +127,26 @@ float f7(float x, float y) { // CHECK-NEXT: return exp(y); // CHECK-NEXT: } -// CHECK: void f7_darg1_grad(float x, float y, float *_d_x, float *_d_y); +// CHECK: void f7_darg1_grad(float x, float y, float *d_x, float *d_y) { +// CHECK-NEXT: *d_y += exp(y); +// CHECK-NEXT: } // CHECK: void f7_hessian(float x, float y, float *hessianMatrix) { // CHECK-NEXT: f7_darg0_grad(x, y, hessianMatrix + {{0U|0UL}}, hessianMatrix + {{1U|1UL}}); // CHECK-NEXT: f7_darg1_grad(x, y, hessianMatrix + {{2U|2UL}}, hessianMatrix + {{3U|3UL}}); // CHECK-NEXT: } +float f8(float x, float y) { + return (x*x + y*y)/2 + x*y; +} + +// CHECK: void f8_hessian(float x, float y, float *hessianMatrix) { +// CHECK-NEXT: hessianMatrix[0] = 1.; +// CHECK-NEXT: hessianMatrix[1] = 1.; +// CHECK-NEXT: hessianMatrix[2] = 1.; +// CHECK-NEXT: hessianMatrix[3] = 1.; +// CHECK-NEXT: } + #define TEST1(F, x) { \ result[0] = 0; \ auto h = clad::hessian(F); \ @@ -155,6 +172,7 @@ int main() { TEST1(f5, 3); // CHECK-EXEC: Result is = {3.84} TEST2(f6, 3, 4); // CHECK-EXEC: Result is = {108.00, 145.65, 145.65, 97.76} TEST2(f7, 3, 4); // CHECK-EXEC: Result is = {-0.14, 0.00, 0.00, 54.60} + TEST2(f8, 3, 4); // CHECK-EXEC: Result is = {1.00, 1.00, 1.00, 1.00} // CHECK: float f1_darg0(float x) { // CHECK-NEXT: float _d_x = 1; @@ -363,21 +381,11 @@ int main() { // CHECK-NEXT: _label0: // CHECK-NEXT: { // CHECK-NEXT: float _r0 = 0; -// CHECK-NEXT: _r0 += 1 * clad::custom_derivatives::std::cos_pushforward(x, 1.F).pushforward; +// CHECK-NEXT: _r0 += 1 * clad::custom_derivatives{{(::std)?}}::cos_pushforward(x, 1.F).pushforward; // CHECK-NEXT: *_d_x += _r0; // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK: void f7_darg1_grad(float x, float y, float *_d_x, float *_d_y) { -// CHECK-NEXT: goto _label0; -// CHECK-NEXT: _label0: -// CHECK-NEXT: { -// CHECK-NEXT: float _r0 = 0; -// CHECK-NEXT: _r0 += 1 * clad::custom_derivatives::std::exp_pushforward(y, 1.F).pushforward; -// CHECK-NEXT: *_d_y += _r0; -// CHECK-NEXT: } -// CHECK-NEXT: } - // CHECK: void sin_pushforward_pullback(float x, float d_x, ValueAndPushforward _d_y, float *_d_x, float *_d_d_x) { // CHECK-NEXT: float _t0; // CHECK-NEXT: _t0 = ::std::cos(x);