From 72379866345298b70b9850cf743c4a01b253a4cf Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Thu, 20 Jun 2024 18:11:41 +0200 Subject: [PATCH] Add support for computing only the diagonal hessian entries fixes #509 --- include/clad/Differentiator/CladConfig.h | 3 + include/clad/Differentiator/DiffMode.h | 3 + include/clad/Differentiator/DiffPlanner.h | 8 +- lib/Differentiator/BaseForwardModeVisitor.cpp | 10 +- lib/Differentiator/DerivativeBuilder.cpp | 13 ++- lib/Differentiator/DiffPlanner.cpp | 13 ++- lib/Differentiator/HessianModeVisitor.cpp | 109 +++++++++++++----- .../ReverseModeForwPassVisitor.cpp | 5 +- lib/Differentiator/ReverseModeVisitor.cpp | 12 +- test/FirstDerivative/FunctionCalls.C | 1 + test/Hessian/Arrays.C | 28 +++++ tools/ClangPlugin.cpp | 3 + 12 files changed, 160 insertions(+), 48 deletions(-) diff --git a/include/clad/Differentiator/CladConfig.h b/include/clad/Differentiator/CladConfig.h index f2c06bec2..8c0eb3b5f 100644 --- a/include/clad/Differentiator/CladConfig.h +++ b/include/clad/Differentiator/CladConfig.h @@ -29,6 +29,9 @@ enum opts : unsigned { // 00 - default, 01 - enable, 10 - disable, 11 - not used / invalid enable_tbr = 1 << (ORDER_BITS + 2), disable_tbr = 1 << (ORDER_BITS + 3), + + // Specifying whether we only want the diagonal of the hessian. + diagonal_only = 1 << (ORDER_BITS + 4), }; // enum opts constexpr unsigned GetDerivativeOrder(const unsigned bitmasked_opts) { diff --git a/include/clad/Differentiator/DiffMode.h b/include/clad/Differentiator/DiffMode.h index b079f554d..6c7e6f6c3 100644 --- a/include/clad/Differentiator/DiffMode.h +++ b/include/clad/Differentiator/DiffMode.h @@ -11,6 +11,7 @@ enum class DiffMode { experimental_vector_pushforward, reverse, hessian, + hessian_diagonal, jacobian, reverse_mode_forward_pass, error_estimation @@ -33,6 +34,8 @@ inline const char* DiffModeToString(DiffMode mode) { return "reverse"; case DiffMode::hessian: return "hessian"; + case DiffMode::hessian_diagonal: + return "hessian_diagonal"; case DiffMode::jacobian: return "jacobian"; case DiffMode::reverse_mode_forward_pass: diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 58359a14b..5da49f13f 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -65,9 +65,11 @@ struct DiffRequest { // A flag to enable the use of enzyme for backend instead of clad bool use_enzyme = false; - /// A pointer to keep track of the prototype of the derived function. - /// This will be particularly useful for pushforward and pullback functions. - clang::FunctionDecl* DerivedFDPrototype = nullptr; + /// A pointer to keep track of the prototype of the derived functions. + /// For higher order derivatives, we store the entire sequence of + /// prototypes declared for all orders of derivatives. + /// This will be useful for forward declaration of the derived functions. + llvm::SmallVector DerivedFDPrototypes; /// A boolean to indicate if only the declaration of the derived function /// is required (and not the definition or body). diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 75122c3e8..ae49c0bcd 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -364,8 +364,9 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, endScope(); // Function body scope - if (request.DerivedFDPrototype) - m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype); + if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder) + m_Derivative->setPreviousDeclaration( + request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]); } m_Sema.PopFunctionScopeInfo(); m_Sema.PopDeclContext(); @@ -529,8 +530,9 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, endScope(); // Function body scope - if (request.DerivedFDPrototype) - m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype); + if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder) + m_Derivative->setPreviousDeclaration( + request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]); } m_Sema.PopFunctionScopeInfo(); diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 33ba9b2e0..9e28775cc 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -295,6 +295,11 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { FunctionDecl* derivative = this->FindDerivedFunction(request); if (!derivative) { alreadyDerived = false; + + // Store the function and its order before processing the nested request. + const FunctionDecl* origFn = request.Function; + unsigned origFnOrder = request.CurrentDerivativeOrder; + // Derive declaration of the the forward mode derivative. request.DeclarationOnly = true; derivative = plugin::ProcessDiffRequest(m_CladPlugin, request); @@ -306,10 +311,13 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { (derivative->isDefined() || m_DFC.IsCustomDerivative(derivative))) alreadyDerived = true; + // Restore the original function and its order. + request.CurrentDerivativeOrder = origFnOrder; + request.Function = origFn; + // Add the request to derive the definition of the forward mode derivative // to the schedule. request.DeclarationOnly = false; - request.DerivedFDPrototype = derivative; } this->AddEdgeToGraph(request, alreadyDerived); return derivative; @@ -423,7 +431,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { } else if (request.Mode == DiffMode::reverse_mode_forward_pass) { ReverseModeForwPassVisitor V(*this, request); result = V.Derive(FD, request); - } else if (request.Mode == DiffMode::hessian) { + } else if (request.Mode == DiffMode::hessian || + request.Mode == DiffMode::hessian_diagonal) { HessianModeVisitor H(*this, request); result = H.Derive(FD, request); } else if (request.Mode == DiffMode::jacobian) { diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index 07a7993c1..b5e7e1ffd 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -616,6 +616,14 @@ namespace clad { } else { request.EnableTBRAnalysis = m_Options.EnableTBRAnalysis; } + if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) { + if (!A->getAnnotation().equals("H")) { + utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, + "Diagonal only option is only valid for Hessian " + "mode."); + return true; + } + } } if (A->getAnnotation().equals("D")) { @@ -651,7 +659,10 @@ namespace clad { } } } else if (A->getAnnotation().equals("H")) { - request.Mode = DiffMode::hessian; + if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) + request.Mode = DiffMode::hessian_diagonal; + else + request.Mode = DiffMode::hessian; } else if (A->getAnnotation().equals("J")) { request.Mode = DiffMode::jacobian; } else if (A->getAnnotation().equals("G")) { diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index 3bcdfa51b..55161e6e8 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -75,6 +75,24 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( return secondDerivative; } +/// Derives the function two times with forward mode AD and returns the +/// FunctionDecl obtained. +static FunctionDecl* DeriveUsingForwardModeTwice( + Sema& SemaRef, clad::plugin::CladPlugin& CP, + clad::DerivativeBuilder& Builder, DiffRequest IndependentArgRequest, + const Expr* ForwardModeArgs, DerivedFnCollector& DFC) { + // Set derivative order in the request to 2. + IndependentArgRequest.RequestedDerivativeOrder = 2; + IndependentArgRequest.Args = ForwardModeArgs; + IndependentArgRequest.Mode = DiffMode::forward; + IndependentArgRequest.CallUpdateRequired = false; + IndependentArgRequest.UpdateDiffParamsInfo(SemaRef); + // Derive the function twice in forward mode. + FunctionDecl* secondDerivative = + Builder.HandleNestedDiffRequest(IndependentArgRequest); + return secondDerivative; +} + DerivativeAndOverload HessianModeVisitor::Derive(const clang::FunctionDecl* FD, const DiffRequest& request) { @@ -91,7 +109,7 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( else std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); - std::vector secondDerivativeColumns; + std::vector secondDerivativeFuncs; llvm::SmallVector IndependentArgsSize{}; size_t TotalIndependentArgsSize = 0; @@ -99,6 +117,8 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( assert(m_DiffReq == request); std::string hessianFuncName = request.BaseFunctionName + "_hessian"; + if (request.Mode == DiffMode::hessian_diagonal) + hessianFuncName += "_diagonal"; // To be consistent with older tests, nothing is appended to 'f_hessian' if // we differentiate w.r.t. all the parameters at once. if (args.size() != FD->getNumParams() || @@ -192,12 +212,17 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( PVD->getNameAsString() + "[" + std::to_string(i) + "]"; auto ForwardModeIASL = CreateStringLiteral(m_Context, independentArgString); - auto* DFD = DeriveUsingForwardAndReverseMode( - m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL, - request.Args, m_Builder.m_DFC); - secondDerivativeColumns.push_back(DFD); + FunctionDecl* DFD = nullptr; + if (request.Mode == DiffMode::hessian_diagonal) + DFD = DeriveUsingForwardModeTwice(m_Sema, m_CladPlugin, m_Builder, + request, ForwardModeIASL, + m_Builder.m_DFC); + else + DFD = DeriveUsingForwardAndReverseMode( + m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL, + request.Args, m_Builder.m_DFC); + secondDerivativeFuncs.push_back(DFD); } - } else { IndependentArgsSize.push_back(1); TotalIndependentArgsSize++; @@ -205,14 +230,20 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( // then in reverse mode w.r.t to all requested args auto ForwardModeIASL = CreateStringLiteral(m_Context, PVD->getNameAsString()); - auto* DFD = DeriveUsingForwardAndReverseMode( - m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL, - request.Args, m_Builder.m_DFC); - secondDerivativeColumns.push_back(DFD); + FunctionDecl* DFD = nullptr; + if (request.Mode == DiffMode::hessian_diagonal) + DFD = DeriveUsingForwardModeTwice(m_Sema, m_CladPlugin, m_Builder, + request, ForwardModeIASL, + m_Builder.m_DFC); + else + DFD = DeriveUsingForwardAndReverseMode( + m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL, + request.Args, m_Builder.m_DFC); + secondDerivativeFuncs.push_back(DFD); } } } - return Merge(secondDerivativeColumns, IndependentArgsSize, + return Merge(secondDerivativeFuncs, IndependentArgsSize, TotalIndependentArgsSize, hessianFuncName, DC, hessianFunctionType, paramTypes); } @@ -272,14 +303,13 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( return VD; }); - // The output parameter "hessianMatrix". + // The output parameter "hessianMatrix" or "diagonalHessianVector" + std::string outputParamName = "hessianMatrix"; + if (m_DiffReq.Mode == DiffMode::hessian_diagonal) + outputParamName = "diagonalHessianVector"; params.back() = ParmVarDecl::Create( - m_Context, - hessianFD, - noLoc, - noLoc, - &m_Context.Idents.get("hessianMatrix"), - paramTypes.back(), + m_Context, hessianFD, noLoc, noLoc, + &m_Context.Idents.get(outputParamName), paramTypes.back(), m_Context.getTrivialTypeSourceInfo(paramTypes.back(), noLoc), params.front()->getStorageClass(), /* No default value */ nullptr); @@ -301,7 +331,6 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( // Creates callExprs to the second derivative functions genereated // and creates maps array elements to input array. for (size_t i = 0, e = secDerivFuncs.size(); i < e; ++i) { - const size_t HessianMatrixStartIndex = i * TotalIndependentArgsSize; auto size_type = m_Context.getSizeType(); auto size_type_bits = m_Context.getIntWidth(size_type); @@ -345,22 +374,40 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( } } - size_t columnIndex = 0; - // Create Expr parameters for each independent arg in the CallExpr - for (size_t indArgSize : IndependentArgsSize) { - llvm::APInt offsetValue(size_type_bits, - HessianMatrixStartIndex + columnIndex); + if (m_DiffReq.Mode == DiffMode::hessian_diagonal) { + const size_t HessianMatrixStartIndex = i; + // Call the derived function for second derivative. + Expr* call = BuildCallExprToFunction(secDerivFuncs[i], DeclRefToParams); + // Create the offset argument. + llvm::APInt offsetValue(size_type_bits, HessianMatrixStartIndex); Expr* OffsetArg = IntegerLiteral::Create(m_Context, offsetValue, size_type, noLoc); - // Create the hessianMatrix + OffsetArg expression. - Expr* SliceExpr = BuildOp(BO_Add, m_Result, OffsetArg); - - DeclRefToParams.push_back(SliceExpr); - columnIndex += indArgSize; + // Create a assignment expression to store the value of call expression + // into the diagonalHessianVector with index HessianMatrixStartIndex. + Expr* SliceExprLHS = BuildOp(BO_Add, m_Result, OffsetArg); + Expr* DerefExpr = BuildOp(UO_Deref, BuildParens(SliceExprLHS)); + Expr* AssignExpr = BuildOp(BO_Assign, DerefExpr, call); + CompStmtSave.push_back(AssignExpr); + } else { + const size_t HessianMatrixStartIndex = i * TotalIndependentArgsSize; + size_t columnIndex = 0; + // Create Expr parameters for each independent arg in the CallExpr + for (size_t indArgSize : IndependentArgsSize) { + llvm::APInt offsetValue(size_type_bits, + HessianMatrixStartIndex + columnIndex); + // Create the offset argument. + Expr* OffsetArg = + IntegerLiteral::Create(m_Context, offsetValue, size_type, noLoc); + // Create the hessianMatrix + OffsetArg expression. + Expr* SliceExpr = BuildOp(BO_Add, m_Result, OffsetArg); + + DeclRefToParams.push_back(SliceExpr); + columnIndex += indArgSize; + } + Expr* call = BuildCallExprToFunction(secDerivFuncs[i], DeclRefToParams); + CompStmtSave.push_back(call); } - Expr* call = BuildCallExprToFunction(secDerivFuncs[i], DeclRefToParams); - CompStmtSave.push_back(call); } auto StmtsRef = diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp index 3c20f60d6..ea0c0b1f0 100644 --- a/lib/Differentiator/ReverseModeForwPassVisitor.cpp +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -86,8 +86,9 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, m_Derivative->setBody(fnBody); endScope(); - if (request.DerivedFDPrototype) - m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype); + if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder) + m_Derivative->setPreviousDeclaration( + request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]); } m_Sema.PopFunctionScopeInfo(); m_Sema.PopDeclContext(); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 6b946d94a..1a505bcf4 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -336,7 +336,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // added by the plugins yet. if (request.Mode != DiffMode::jacobian && numExtraParam == 0) shouldCreateOverload = true; - if (request.DerivedFDPrototype) + if (!request.DeclarationOnly && !request.DerivedFDPrototypes.empty()) // If the overload is already created, we don't need to create it again. shouldCreateOverload = false; @@ -452,8 +452,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Derivative->setBody(gradientBody); endScope(); // Function body scope - if (request.DerivedFDPrototype) - m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype); + if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder) + m_Derivative->setPreviousDeclaration( + request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]); } m_Sema.PopFunctionScopeInfo(); m_Sema.PopDeclContext(); @@ -585,8 +586,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Derivative->setBody(fnBody); endScope(); // Function body scope - if (request.DerivedFDPrototype) - m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype); + if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder) + m_Derivative->setPreviousDeclaration( + request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]); } m_Sema.PopFunctionScopeInfo(); m_Sema.PopDeclContext(); diff --git a/test/FirstDerivative/FunctionCalls.C b/test/FirstDerivative/FunctionCalls.C index 8222d7333..33e15c2c6 100644 --- a/test/FirstDerivative/FunctionCalls.C +++ b/test/FirstDerivative/FunctionCalls.C @@ -194,6 +194,7 @@ int main () { clad::differentiate(test_8, "x"); clad::differentiate(test_8); // expected-error {{TBR analysis is not meant for forward mode AD.}} clad::differentiate(test_8); // expected-error {{Both enable and disable TBR options are specified.}} + clad::differentiate(test_8); // expected-error {{Diagonal only option is only valid for Hessian mode.}} clad::differentiate(test_9); return 0; diff --git a/test/Hessian/Arrays.C b/test/Hessian/Arrays.C index 67e6d6c4b..78dbb7ea5 100644 --- a/test/Hessian/Arrays.C +++ b/test/Hessian/Arrays.C @@ -21,6 +21,24 @@ double g(double i, double j[2]) { return i * (j[0] + j[1]); } // CHECK-NEXT: g_darg1_1_grad(i, j, hessianMatrix + {{6U|6UL}}, hessianMatrix + {{7U|7UL}}); // CHECK-NEXT: } +double h(double arr[3], double weights[3], double multiplier) { + // return square of weighted sum. + double weightedSum = arr[0] * weights[0]; + weightedSum += arr[1] * weights[1]; + weightedSum += arr[2] * weights[2]; + weightedSum *= multiplier; + return weightedSum * weightedSum; +} +// CHECK: void h_hessian_diagonal(double arr[3], double weights[3], double multiplier, double *diagonalHessianVector) { +// CHECK-NEXT: *(diagonalHessianVector + 0{{U|UL}}) = h_d2arg0_0(arr, weights, multiplier); +// CHECK-NEXT: *(diagonalHessianVector + 1{{U|UL}}) = h_d2arg0_1(arr, weights, multiplier); +// CHECK-NEXT: *(diagonalHessianVector + 2{{U|UL}}) = h_d2arg0_2(arr, weights, multiplier); +// CHECK-NEXT: *(diagonalHessianVector + 3{{U|UL}}) = h_d2arg1_0(arr, weights, multiplier); +// CHECK-NEXT: *(diagonalHessianVector + 4{{U|UL}}) = h_d2arg1_1(arr, weights, multiplier); +// CHECK-NEXT: *(diagonalHessianVector + 5{{U|UL}}) = h_d2arg1_2(arr, weights, multiplier); +// CHECK-NEXT: *(diagonalHessianVector + 6{{U|UL}}) = h_d2arg2(arr, weights, multiplier); +// CHECK-NEXT: } + #define TEST(var, i, j) \ result[0] = result[1] = result[2] = result[3] = result[4] = result[5] = \ result[6] = result[7] = result[8] = 0; \ @@ -44,4 +62,14 @@ int main() { TEST(h1, 2, x); // CHECK-EXEC: Result = {0.00 4.00 3.00 4.00 0.00 2.00 3.00 2.00 0.00} auto h2 = clad::hessian(g, "i, j[0:1]"); TEST(h2, 2, x); // CHECK-EXEC: Result = {0.00 1.00 1.00 1.00 0.00 0.00 1.00 0.00 0.00} + + double arr[] = {1, 2, 3}; + double weights[] = {4, 5, 6}; + double diag[7]; // result will be the diagonal of the Hessian matrix. + double multiplier = 2.0; + auto h3 = clad::hessian(h, "arr[0:2], weights[0:2], multiplier"); + h3.execute(arr, weights, multiplier, diag); + printf("Diagonal (arr) = {%.2f %.2f %.2f},\n", diag[0], diag[1], diag[2]); // CHECK-EXEC: Diagonal (arr) = {128.00 200.00 288.00}, + printf("Diagonal (weights) = {%.2f %.2f %.2f}\n", diag[3], diag[4], diag[5]); // CHECK-EXEC: Diagonal (weights) = {8.00 32.00 72.00} + printf("Diagonal (multiplier) = %.2f\n", diag[6]); // CHECK-EXEC: Diagonal (multiplier) = 2048.00 } diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 5cf698a36..e62966429 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -286,6 +286,9 @@ namespace clad { request.updateCall(DerivativeDecl, OverloadedDerivativeDecl, m_CI.getSema()); + if (request.DeclarationOnly) + request.DerivedFDPrototypes.push_back(DerivativeDecl); + // Last requested order was computed, return the result. if (lastDerivativeOrder) return DerivativeDecl;