From 4831ba843bbd9aa0432b339f8b0494181fd0e571 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Tue, 18 Jun 2024 15:27:27 +0200 Subject: [PATCH] Add support for computing only the diagonal hessian entries --- include/clad/Differentiator/CladConfig.h | 3 + include/clad/Differentiator/DiffMode.h | 3 + lib/Differentiator/DerivativeBuilder.cpp | 3 +- lib/Differentiator/DiffPlanner.cpp | 13 ++- lib/Differentiator/HessianModeVisitor.cpp | 109 ++++++++++++++++------ test/Hessian/Arrays.C | 29 ++++++ tools/ClangPlugin.cpp | 2 +- tools/ClangPlugin.h | 2 +- 8 files changed, 129 insertions(+), 35 deletions(-) diff --git a/include/clad/Differentiator/CladConfig.h b/include/clad/Differentiator/CladConfig.h index f2c06bec2..8c0eb3b5f 100644 --- a/include/clad/Differentiator/CladConfig.h +++ b/include/clad/Differentiator/CladConfig.h @@ -29,6 +29,9 @@ enum opts : unsigned { // 00 - default, 01 - enable, 10 - disable, 11 - not used / invalid enable_tbr = 1 << (ORDER_BITS + 2), disable_tbr = 1 << (ORDER_BITS + 3), + + // Specifying whether we only want the diagonal of the hessian. 
+ diagonal_only = 1 << (ORDER_BITS + 4), }; // enum opts constexpr unsigned GetDerivativeOrder(const unsigned bitmasked_opts) { diff --git a/include/clad/Differentiator/DiffMode.h b/include/clad/Differentiator/DiffMode.h index b079f554d..6c7e6f6c3 100644 --- a/include/clad/Differentiator/DiffMode.h +++ b/include/clad/Differentiator/DiffMode.h @@ -11,6 +11,7 @@ enum class DiffMode { experimental_vector_pushforward, reverse, hessian, + hessian_diagonal, jacobian, reverse_mode_forward_pass, error_estimation @@ -33,6 +34,8 @@ inline const char* DiffModeToString(DiffMode mode) { return "reverse"; case DiffMode::hessian: return "hessian"; + case DiffMode::hessian_diagonal: + return "hessian_diagonal"; case DiffMode::jacobian: return "jacobian"; case DiffMode::reverse_mode_forward_pass: diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 24cd53b0d..92898bd98 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -430,7 +430,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { } else if (request.Mode == DiffMode::reverse_mode_forward_pass) { ReverseModeForwPassVisitor V(*this, request); result = V.Derive(FD, request); - } else if (request.Mode == DiffMode::hessian) { + } else if (request.Mode == DiffMode::hessian || + request.Mode == DiffMode::hessian_diagonal) { HessianModeVisitor H(*this, request); result = H.Derive(FD, request); } else if (request.Mode == DiffMode::jacobian) { diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index 07a7993c1..b5e7e1ffd 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -616,6 +616,14 @@ namespace clad { } else { request.EnableTBRAnalysis = m_Options.EnableTBRAnalysis; } + if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) { + if (!A->getAnnotation().equals("H")) { + utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, 
endLoc, + "Diagonal only option is only valid for Hessian " + "mode."); + return true; + } + } } if (A->getAnnotation().equals("D")) { @@ -651,7 +659,10 @@ namespace clad { } } } else if (A->getAnnotation().equals("H")) { - request.Mode = DiffMode::hessian; + if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) + request.Mode = DiffMode::hessian_diagonal; + else + request.Mode = DiffMode::hessian; } else if (A->getAnnotation().equals("J")) { request.Mode = DiffMode::jacobian; } else if (A->getAnnotation().equals("G")) { diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index 3bcdfa51b..55161e6e8 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -75,6 +75,24 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( return secondDerivative; } +/// Derives the function two times with forward mode AD and returns the +/// FunctionDecl obtained. +static FunctionDecl* DeriveUsingForwardModeTwice( + Sema& SemaRef, clad::plugin::CladPlugin& CP, + clad::DerivativeBuilder& Builder, DiffRequest IndependentArgRequest, + const Expr* ForwardModeArgs, DerivedFnCollector& DFC) { + // Set derivative order in the request to 2. + IndependentArgRequest.RequestedDerivativeOrder = 2; + IndependentArgRequest.Args = ForwardModeArgs; + IndependentArgRequest.Mode = DiffMode::forward; + IndependentArgRequest.CallUpdateRequired = false; + IndependentArgRequest.UpdateDiffParamsInfo(SemaRef); + // Derive the function twice in forward mode. 
+ FunctionDecl* secondDerivative = + Builder.HandleNestedDiffRequest(IndependentArgRequest); + return secondDerivative; +} + DerivativeAndOverload HessianModeVisitor::Derive(const clang::FunctionDecl* FD, const DiffRequest& request) { @@ -91,7 +109,7 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( else std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); - std::vector<FunctionDecl*> secondDerivativeColumns; + std::vector<FunctionDecl*> secondDerivativeFuncs; llvm::SmallVector IndependentArgsSize{}; size_t TotalIndependentArgsSize = 0; @@ -99,6 +117,8 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( assert(m_DiffReq == request); std::string hessianFuncName = request.BaseFunctionName + "_hessian"; + if (request.Mode == DiffMode::hessian_diagonal) + hessianFuncName += "_diagonal"; // To be consistent with older tests, nothing is appended to 'f_hessian' if // we differentiate w.r.t. all the parameters at once. if (args.size() != FD->getNumParams() || @@ -192,12 +212,17 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( PVD->getNameAsString() + "[" + std::to_string(i) + "]"; auto ForwardModeIASL = CreateStringLiteral(m_Context, independentArgString); - auto* DFD = DeriveUsingForwardAndReverseMode( - m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL, - request.Args, m_Builder.m_DFC); - secondDerivativeColumns.push_back(DFD); + FunctionDecl* DFD = nullptr; + if (request.Mode == DiffMode::hessian_diagonal) + DFD = DeriveUsingForwardModeTwice(m_Sema, m_CladPlugin, m_Builder, + request, ForwardModeIASL, + m_Builder.m_DFC); + else + DFD = DeriveUsingForwardAndReverseMode( + m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL, + request.Args, m_Builder.m_DFC); + secondDerivativeFuncs.push_back(DFD); } - } else { IndependentArgsSize.push_back(1); TotalIndependentArgsSize++; @@ -205,14 +230,20 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( // then in reverse mode w.r.t to all requested args auto ForwardModeIASL = 
CreateStringLiteral(m_Context, PVD->getNameAsString()); - auto* DFD = DeriveUsingForwardAndReverseMode( - m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL, - request.Args, m_Builder.m_DFC); - secondDerivativeColumns.push_back(DFD); + FunctionDecl* DFD = nullptr; + if (request.Mode == DiffMode::hessian_diagonal) + DFD = DeriveUsingForwardModeTwice(m_Sema, m_CladPlugin, m_Builder, + request, ForwardModeIASL, + m_Builder.m_DFC); + else + DFD = DeriveUsingForwardAndReverseMode( + m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL, + request.Args, m_Builder.m_DFC); + secondDerivativeFuncs.push_back(DFD); } } } - return Merge(secondDerivativeColumns, IndependentArgsSize, + return Merge(secondDerivativeFuncs, IndependentArgsSize, TotalIndependentArgsSize, hessianFuncName, DC, hessianFunctionType, paramTypes); } @@ -272,14 +303,13 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( return VD; }); - // The output parameter "hessianMatrix". + // The output parameter "hessianMatrix" or "diagonalHessianVector" + std::string outputParamName = "hessianMatrix"; + if (m_DiffReq.Mode == DiffMode::hessian_diagonal) + outputParamName = "diagonalHessianVector"; params.back() = ParmVarDecl::Create( - m_Context, - hessianFD, - noLoc, - noLoc, - &m_Context.Idents.get("hessianMatrix"), - paramTypes.back(), + m_Context, hessianFD, noLoc, noLoc, + &m_Context.Idents.get(outputParamName), paramTypes.back(), m_Context.getTrivialTypeSourceInfo(paramTypes.back(), noLoc), params.front()->getStorageClass(), /* No default value */ nullptr); @@ -301,7 +331,6 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( // Creates callExprs to the second derivative functions genereated // and creates maps array elements to input array. 
for (size_t i = 0, e = secDerivFuncs.size(); i < e; ++i) { - const size_t HessianMatrixStartIndex = i * TotalIndependentArgsSize; auto size_type = m_Context.getSizeType(); auto size_type_bits = m_Context.getIntWidth(size_type); @@ -345,22 +374,40 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( } } - size_t columnIndex = 0; - // Create Expr parameters for each independent arg in the CallExpr - for (size_t indArgSize : IndependentArgsSize) { - llvm::APInt offsetValue(size_type_bits, - HessianMatrixStartIndex + columnIndex); + if (m_DiffReq.Mode == DiffMode::hessian_diagonal) { + const size_t HessianMatrixStartIndex = i; + // Call the derived function for second derivative. + Expr* call = BuildCallExprToFunction(secDerivFuncs[i], DeclRefToParams); + // Create the offset argument. + llvm::APInt offsetValue(size_type_bits, HessianMatrixStartIndex); Expr* OffsetArg = IntegerLiteral::Create(m_Context, offsetValue, size_type, noLoc); - // Create the hessianMatrix + OffsetArg expression. - Expr* SliceExpr = BuildOp(BO_Add, m_Result, OffsetArg); - - DeclRefToParams.push_back(SliceExpr); - columnIndex += indArgSize; + // Create an assignment expression to store the value of call expression + // into the diagonalHessianVector with index HessianMatrixStartIndex. + Expr* SliceExprLHS = BuildOp(BO_Add, m_Result, OffsetArg); + Expr* DerefExpr = BuildOp(UO_Deref, BuildParens(SliceExprLHS)); + Expr* AssignExpr = BuildOp(BO_Assign, DerefExpr, call); + CompStmtSave.push_back(AssignExpr); + } else { + const size_t HessianMatrixStartIndex = i * TotalIndependentArgsSize; + size_t columnIndex = 0; + // Create Expr parameters for each independent arg in the CallExpr + for (size_t indArgSize : IndependentArgsSize) { + llvm::APInt offsetValue(size_type_bits, + HessianMatrixStartIndex + columnIndex); + // Create the offset argument. + Expr* OffsetArg = + IntegerLiteral::Create(m_Context, offsetValue, size_type, noLoc); + // Create the hessianMatrix + OffsetArg expression. 
+ Expr* SliceExpr = BuildOp(BO_Add, m_Result, OffsetArg); + + DeclRefToParams.push_back(SliceExpr); + columnIndex += indArgSize; + } + Expr* call = BuildCallExprToFunction(secDerivFuncs[i], DeclRefToParams); + CompStmtSave.push_back(call); } - Expr* call = BuildCallExprToFunction(secDerivFuncs[i], DeclRefToParams); - CompStmtSave.push_back(call); } auto StmtsRef = diff --git a/test/Hessian/Arrays.C b/test/Hessian/Arrays.C index 67e6d6c4b..4f2a7f402 100644 --- a/test/Hessian/Arrays.C +++ b/test/Hessian/Arrays.C @@ -21,6 +21,25 @@ double g(double i, double j[2]) { return i * (j[0] + j[1]); } // CHECK-NEXT: g_darg1_1_grad(i, j, hessianMatrix + {{6U|6UL}}, hessianMatrix + {{7U|7UL}}); // CHECK-NEXT: } +double h(double arr[3], double weights[3], double multiplier) { + // return square of weighted sum. + double weightedSum = 0; + for (int i = 0; i < 3; ++i) { + weightedSum += arr[i] * weights[i]; + } + weightedSum *= multiplier; + return weightedSum * weightedSum; +} +// CHECK: void h_hessian_diagonal(double arr[3], double weights[3], double multiplier, double *diagonalHessianVector) { +// CHECK-NEXT: *(diagonalHessianVector + {{0U|0UL}}) = h_d2arg0_0(arr, weights, multiplier); +// CHECK-NEXT: *(diagonalHessianVector + {{1U|1UL}}) = h_d2arg0_1(arr, weights, multiplier); +// CHECK-NEXT: *(diagonalHessianVector + {{2U|2UL}}) = h_d2arg0_2(arr, weights, multiplier); +// CHECK-NEXT: *(diagonalHessianVector + {{3U|3UL}}) = h_d2arg1_0(arr, weights, multiplier); +// CHECK-NEXT: *(diagonalHessianVector + {{4U|4UL}}) = h_d2arg1_1(arr, weights, multiplier); +// CHECK-NEXT: *(diagonalHessianVector + {{5U|5UL}}) = h_d2arg1_2(arr, weights, multiplier); +// CHECK-NEXT: *(diagonalHessianVector + {{6U|6UL}}) = h_d2arg2(arr, weights, multiplier); +// CHECK-NEXT: } + #define TEST(var, i, j) \ result[0] = result[1] = result[2] = result[3] = result[4] = result[5] = \ result[6] = result[7] = result[8] = 0; \ @@ -44,4 +63,14 @@ int main() { TEST(h1, 2, x); // CHECK-EXEC: Result = {0.00 4.00 3.00 4.00 0.00 2.00 3.00 2.00 0.00} auto h2 = 
clad::hessian(g, "i, j[0:1]"); TEST(h2, 2, x); // CHECK-EXEC: Result = {0.00 1.00 1.00 1.00 0.00 0.00 1.00 0.00 0.00} + + double arr[] = {1, 2, 3}; + double weights[] = {4, 5, 6}; + double diag[7]; // result will be the diagonal of the Hessian matrix. + double multiplier = 2.0; + auto h3 = clad::hessian<clad::opts::diagonal_only>(h, "arr[0:2], weights[0:2], multiplier"); + h3.execute(arr, weights, multiplier, diag); + printf("Diagonal (arr) = {%.2f %.2f %.2f},\n", diag[0], diag[1], diag[2]); // CHECK-EXEC: Diagonal (arr) = {128.00 200.00 288.00}, + printf("Diagonal (weights) = {%.2f %.2f %.2f}\n", diag[3], diag[4], diag[5]); // CHECK-EXEC: Diagonal (weights) = {8.00 32.00 72.00} + printf("Diagonal (multiplier) = %.2f\n", diag[6]); // CHECK-EXEC: Diagonal (multiplier) = 2048.00 } diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 5cf698a36..7dffb4256 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -137,7 +137,7 @@ namespace clad { FinalizeTranslationUnit(); } - FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest& request) { + FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest request) { Sema& S = m_CI.getSema(); // Required due to custom derivatives function templates that might be // used in the function that we need to derive. diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 2f4e23694..808cf1ec6 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -250,7 +250,7 @@ class CladTimerGroup { // FIXME: We should hide ProcessDiffRequest when we implement proper // handling of the differentiation plans. - clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request); + clang::FunctionDecl* ProcessDiffRequest(DiffRequest request); private: void AppendDelayed(DelayedCallInfo DCI) {