Skip to content

Commit

Permalink
Add support for computing only the diagonal hessian entries
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Jun 18, 2024
1 parent 6e5ef74 commit 4831ba8
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 35 deletions.
3 changes: 3 additions & 0 deletions include/clad/Differentiator/CladConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ enum opts : unsigned {
// 00 - default, 01 - enable, 10 - disable, 11 - not used / invalid
enable_tbr = 1 << (ORDER_BITS + 2),
disable_tbr = 1 << (ORDER_BITS + 3),

// Specifying whether we only want the diagonal of the hessian.
diagonal_only = 1 << (ORDER_BITS + 4),
}; // enum opts

constexpr unsigned GetDerivativeOrder(const unsigned bitmasked_opts) {
Expand Down
3 changes: 3 additions & 0 deletions include/clad/Differentiator/DiffMode.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ enum class DiffMode {
experimental_vector_pushforward,
reverse,
hessian,
hessian_diagonal,
jacobian,
reverse_mode_forward_pass,
error_estimation
Expand All @@ -33,6 +34,8 @@ inline const char* DiffModeToString(DiffMode mode) {
return "reverse";
case DiffMode::hessian:
return "hessian";
case DiffMode::hessian_diagonal:
return "hessian_diagonal";
case DiffMode::jacobian:
return "jacobian";
case DiffMode::reverse_mode_forward_pass:
Expand Down
3 changes: 2 additions & 1 deletion lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
} else if (request.Mode == DiffMode::reverse_mode_forward_pass) {
ReverseModeForwPassVisitor V(*this, request);
result = V.Derive(FD, request);
} else if (request.Mode == DiffMode::hessian) {
} else if (request.Mode == DiffMode::hessian ||
request.Mode == DiffMode::hessian_diagonal) {
HessianModeVisitor H(*this, request);
result = H.Derive(FD, request);
} else if (request.Mode == DiffMode::jacobian) {
Expand Down
13 changes: 12 additions & 1 deletion lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,14 @@ namespace clad {
} else {
request.EnableTBRAnalysis = m_Options.EnableTBRAnalysis;
}
if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) {
if (!A->getAnnotation().equals("H")) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Diagonal only option is only valid for Hessian "
"mode.");
return true;
}
}
}

if (A->getAnnotation().equals("D")) {
Expand Down Expand Up @@ -651,7 +659,10 @@ namespace clad {
}
}
} else if (A->getAnnotation().equals("H")) {
request.Mode = DiffMode::hessian;
if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only))
request.Mode = DiffMode::hessian_diagonal;
else
request.Mode = DiffMode::hessian;
} else if (A->getAnnotation().equals("J")) {
request.Mode = DiffMode::jacobian;
} else if (A->getAnnotation().equals("G")) {
Expand Down
109 changes: 78 additions & 31 deletions lib/Differentiator/HessianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,24 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
return secondDerivative;
}

/// Derives the function two times with forward mode AD and returns the
/// FunctionDecl obtained.
static FunctionDecl* DeriveUsingForwardModeTwice(
Sema& SemaRef, clad::plugin::CladPlugin& CP,
clad::DerivativeBuilder& Builder, DiffRequest IndependentArgRequest,
const Expr* ForwardModeArgs, DerivedFnCollector& DFC) {
// Set derivative order in the request to 2.
IndependentArgRequest.RequestedDerivativeOrder = 2;
IndependentArgRequest.Args = ForwardModeArgs;
IndependentArgRequest.Mode = DiffMode::forward;
IndependentArgRequest.CallUpdateRequired = false;
IndependentArgRequest.UpdateDiffParamsInfo(SemaRef);
// Derive the function twice in forward mode.
FunctionDecl* secondDerivative =
Builder.HandleNestedDiffRequest(IndependentArgRequest);
return secondDerivative;
}

DerivativeAndOverload
HessianModeVisitor::Derive(const clang::FunctionDecl* FD,
const DiffRequest& request) {
Expand All @@ -91,14 +109,16 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
else
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));

std::vector<FunctionDecl*> secondDerivativeColumns;
std::vector<FunctionDecl*> secondDerivativeFuncs;
llvm::SmallVector<size_t, 16> IndependentArgsSize{};
size_t TotalIndependentArgsSize = 0;

// request.Function is original function passed in from clad::hessian
assert(m_DiffReq == request);

std::string hessianFuncName = request.BaseFunctionName + "_hessian";
if (request.Mode == DiffMode::hessian_diagonal)
hessianFuncName += "_diagonal";
// To be consistent with older tests, nothing is appended to 'f_hessian' if
// we differentiate w.r.t. all the parameters at once.
if (args.size() != FD->getNumParams() ||
Expand Down Expand Up @@ -192,27 +212,38 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
PVD->getNameAsString() + "[" + std::to_string(i) + "]";
auto ForwardModeIASL =
CreateStringLiteral(m_Context, independentArgString);
auto* DFD = DeriveUsingForwardAndReverseMode(
m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL,
request.Args, m_Builder.m_DFC);
secondDerivativeColumns.push_back(DFD);
FunctionDecl* DFD = nullptr;
if (request.Mode == DiffMode::hessian_diagonal)
DFD = DeriveUsingForwardModeTwice(m_Sema, m_CladPlugin, m_Builder,
request, ForwardModeIASL,
m_Builder.m_DFC);
else
DFD = DeriveUsingForwardAndReverseMode(
m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL,
request.Args, m_Builder.m_DFC);
secondDerivativeFuncs.push_back(DFD);
}

} else {
IndependentArgsSize.push_back(1);
TotalIndependentArgsSize++;
// Derive the function w.r.t. to the current arg in forward mode and
// then in reverse mode w.r.t to all requested args
auto ForwardModeIASL =
CreateStringLiteral(m_Context, PVD->getNameAsString());
auto* DFD = DeriveUsingForwardAndReverseMode(
m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL,
request.Args, m_Builder.m_DFC);
secondDerivativeColumns.push_back(DFD);
FunctionDecl* DFD = nullptr;
if (request.Mode == DiffMode::hessian_diagonal)
DFD = DeriveUsingForwardModeTwice(m_Sema, m_CladPlugin, m_Builder,
request, ForwardModeIASL,
m_Builder.m_DFC);
else
DFD = DeriveUsingForwardAndReverseMode(
m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL,
request.Args, m_Builder.m_DFC);
secondDerivativeFuncs.push_back(DFD);
}
}
}
return Merge(secondDerivativeColumns, IndependentArgsSize,
return Merge(secondDerivativeFuncs, IndependentArgsSize,
TotalIndependentArgsSize, hessianFuncName, DC,
hessianFunctionType, paramTypes);
}
Expand Down Expand Up @@ -272,14 +303,13 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
return VD;
});

// The output parameter "hessianMatrix".
// The output parameter "hessianMatrix" or "diagonalHessianVector"
std::string outputParamName = "hessianMatrix";
if (m_DiffReq.Mode == DiffMode::hessian_diagonal)
outputParamName = "diagonalHessianVector";
params.back() = ParmVarDecl::Create(
m_Context,
hessianFD,
noLoc,
noLoc,
&m_Context.Idents.get("hessianMatrix"),
paramTypes.back(),
m_Context, hessianFD, noLoc, noLoc,
&m_Context.Idents.get(outputParamName), paramTypes.back(),
m_Context.getTrivialTypeSourceInfo(paramTypes.back(), noLoc),
params.front()->getStorageClass(),
/* No default value */ nullptr);
Expand All @@ -301,7 +331,6 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
// Creates callExprs to the second derivative functions genereated
// and creates maps array elements to input array.
for (size_t i = 0, e = secDerivFuncs.size(); i < e; ++i) {
const size_t HessianMatrixStartIndex = i * TotalIndependentArgsSize;
auto size_type = m_Context.getSizeType();
auto size_type_bits = m_Context.getIntWidth(size_type);

Expand Down Expand Up @@ -345,22 +374,40 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
}
}

size_t columnIndex = 0;
// Create Expr parameters for each independent arg in the CallExpr
for (size_t indArgSize : IndependentArgsSize) {
llvm::APInt offsetValue(size_type_bits,
HessianMatrixStartIndex + columnIndex);
if (m_DiffReq.Mode == DiffMode::hessian_diagonal) {
const size_t HessianMatrixStartIndex = i;
// Call the derived function for second derivative.
Expr* call = BuildCallExprToFunction(secDerivFuncs[i], DeclRefToParams);

// Create the offset argument.
llvm::APInt offsetValue(size_type_bits, HessianMatrixStartIndex);
Expr* OffsetArg =
IntegerLiteral::Create(m_Context, offsetValue, size_type, noLoc);
// Create the hessianMatrix + OffsetArg expression.
Expr* SliceExpr = BuildOp(BO_Add, m_Result, OffsetArg);

DeclRefToParams.push_back(SliceExpr);
columnIndex += indArgSize;
// Create a assignment expression to store the value of call expression
// into the diagonalHessianVector with index HessianMatrixStartIndex.
Expr* SliceExprLHS = BuildOp(BO_Add, m_Result, OffsetArg);
Expr* DerefExpr = BuildOp(UO_Deref, BuildParens(SliceExprLHS));
Expr* AssignExpr = BuildOp(BO_Assign, DerefExpr, call);
CompStmtSave.push_back(AssignExpr);
} else {
const size_t HessianMatrixStartIndex = i * TotalIndependentArgsSize;
size_t columnIndex = 0;
// Create Expr parameters for each independent arg in the CallExpr
for (size_t indArgSize : IndependentArgsSize) {
llvm::APInt offsetValue(size_type_bits,
HessianMatrixStartIndex + columnIndex);
// Create the offset argument.
Expr* OffsetArg =
IntegerLiteral::Create(m_Context, offsetValue, size_type, noLoc);
// Create the hessianMatrix + OffsetArg expression.
Expr* SliceExpr = BuildOp(BO_Add, m_Result, OffsetArg);

DeclRefToParams.push_back(SliceExpr);
columnIndex += indArgSize;
}
Expr* call = BuildCallExprToFunction(secDerivFuncs[i], DeclRefToParams);
CompStmtSave.push_back(call);
}
Expr* call = BuildCallExprToFunction(secDerivFuncs[i], DeclRefToParams);
CompStmtSave.push_back(call);
}

auto StmtsRef =
Expand Down
29 changes: 29 additions & 0 deletions test/Hessian/Arrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,25 @@ double g(double i, double j[2]) { return i * (j[0] + j[1]); }
// CHECK-NEXT: g_darg1_1_grad(i, j, hessianMatrix + {{6U|6UL}}, hessianMatrix + {{7U|7UL}});
// CHECK-NEXT: }

double h(double arr[3], double weights[3], double multiplier) {
// return square of weighted sum.
double weightedSum = 0;
for (int i = 0; i < 3; ++i) {
weightedSum += arr[i] * weights[i];
}
weightedSum *= multiplier;
return weightedSum * weightedSum;
}
// CHECK: void h_hessian_diagonal(double arr[3], double weights[3], double multiplier, double *diagonalHessianVector) {
// CHECK-NEXT: *(diagonalHessianVector + 0UL) = h_d2arg0_0(arr, weights, multiplier);
// CHECK-NEXT: *(diagonalHessianVector + 1UL) = h_d2arg0_1(arr, weights, multiplier);
// CHECK-NEXT: *(diagonalHessianVector + 2UL) = h_d2arg0_2(arr, weights, multiplier);
// CHECK-NEXT: *(diagonalHessianVector + 3UL) = h_d2arg1_0(arr, weights, multiplier);
// CHECK-NEXT: *(diagonalHessianVector + 4UL) = h_d2arg1_1(arr, weights, multiplier);
// CHECK-NEXT: *(diagonalHessianVector + 5UL) = h_d2arg1_2(arr, weights, multiplier);
// CHECK-NEXT: *(diagonalHessianVector + 6UL) = h_d2arg2(arr, weights, multiplier);
// CHECK-NEXT: }

#define TEST(var, i, j) \
result[0] = result[1] = result[2] = result[3] = result[4] = result[5] = \
result[6] = result[7] = result[8] = 0; \
Expand All @@ -44,4 +63,14 @@ int main() {
TEST(h1, 2, x); // CHECK-EXEC: Result = {0.00 4.00 3.00 4.00 0.00 2.00 3.00 2.00 0.00}
auto h2 = clad::hessian(g, "i, j[0:1]");
TEST(h2, 2, x); // CHECK-EXEC: Result = {0.00 1.00 1.00 1.00 0.00 0.00 1.00 0.00 0.00}

double arr[] = {1, 2, 3};
double weights[] = {4, 5, 6};
double diag[7]; // result will be the diagonal of the Hessian matrix.
double multiplier = 2.0;
auto h3 = clad::hessian<clad::opts::diagonal_only>(h, "arr[0:2], weights[0:2], multiplier");
h3.execute(arr, weights, multiplier, diag);
printf("Diagonal (arr) = {%.2f %.2f %.2f},\n", diag[0], diag[1], diag[2]); // CHECK-EXEC: Diagonal (arr) = {128.00 200.00 288.00},
printf("Diagonal (weights) = {%.2f %.2f %.2f}\n", diag[3], diag[4], diag[5]); // CHECK-EXEC: Diagonal (weights) = {8.00 32.00 72.00}
printf("Diagonal (multiplier) = %.2f\n", diag[6]); // CHECK-EXEC: Diagonal (multiplier) = 2048.00
}
2 changes: 1 addition & 1 deletion tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ namespace clad {
FinalizeTranslationUnit();
}

FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest& request) {
FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest request) {
Sema& S = m_CI.getSema();
// Required due to custom derivatives function templates that might be
// used in the function that we need to derive.
Expand Down
2 changes: 1 addition & 1 deletion tools/ClangPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ class CladTimerGroup {

// FIXME: We should hide ProcessDiffRequest when we implement proper
// handling of the differentiation plans.
clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request);
clang::FunctionDecl* ProcessDiffRequest(DiffRequest request);

private:
void AppendDelayed(DelayedCallInfo DCI) {
Expand Down

0 comments on commit 4831ba8

Please sign in to comment.