Skip to content

Commit

Permalink
Add support for computing only the diagonal hessian entries
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Jun 20, 2024
1 parent ac24906 commit 7237986
Show file tree
Hide file tree
Showing 12 changed files with 160 additions and 48 deletions.
3 changes: 3 additions & 0 deletions include/clad/Differentiator/CladConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ enum opts : unsigned {
// 00 - default, 01 - enable, 10 - disable, 11 - not used / invalid
enable_tbr = 1 << (ORDER_BITS + 2),
disable_tbr = 1 << (ORDER_BITS + 3),

// Specifying whether we only want the diagonal of the hessian.
diagonal_only = 1 << (ORDER_BITS + 4),
}; // enum opts

constexpr unsigned GetDerivativeOrder(const unsigned bitmasked_opts) {
Expand Down
3 changes: 3 additions & 0 deletions include/clad/Differentiator/DiffMode.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ enum class DiffMode {
experimental_vector_pushforward,
reverse,
hessian,
hessian_diagonal,
jacobian,
reverse_mode_forward_pass,
error_estimation
Expand All @@ -33,6 +34,8 @@ inline const char* DiffModeToString(DiffMode mode) {
return "reverse";
case DiffMode::hessian:
return "hessian";
case DiffMode::hessian_diagonal:
return "hessian_diagonal";
case DiffMode::jacobian:
return "jacobian";
case DiffMode::reverse_mode_forward_pass:
Expand Down
8 changes: 5 additions & 3 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ struct DiffRequest {
// A flag to enable the use of enzyme for backend instead of clad
bool use_enzyme = false;

/// A pointer to keep track of the prototype of the derived function.
/// This will be particularly useful for pushforward and pullback functions.
clang::FunctionDecl* DerivedFDPrototype = nullptr;
/// A pointer to keep track of the prototype of the derived functions.
/// For higher order derivatives, we store the entire sequence of
/// prototypes declared for all orders of derivatives.
/// This will be useful for forward declaration of the derived functions.
llvm::SmallVector<clang::FunctionDecl*, 2> DerivedFDPrototypes;

/// A boolean to indicate if only the declaration of the derived function
/// is required (and not the definition or body).
Expand Down
10 changes: 6 additions & 4 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,9 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,

endScope(); // Function body scope

if (request.DerivedFDPrototype)
m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype);
if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]);
}
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
Expand Down Expand Up @@ -529,8 +530,9 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,

endScope(); // Function body scope

if (request.DerivedFDPrototype)
m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype);
if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]);
}

m_Sema.PopFunctionScopeInfo();
Expand Down
13 changes: 11 additions & 2 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,11 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
FunctionDecl* derivative = this->FindDerivedFunction(request);
if (!derivative) {
alreadyDerived = false;

// Store the function and its order before processing the nested request.
const FunctionDecl* origFn = request.Function;
unsigned origFnOrder = request.CurrentDerivativeOrder;

// Derive declaration of the the forward mode derivative.
request.DeclarationOnly = true;
derivative = plugin::ProcessDiffRequest(m_CladPlugin, request);
Expand All @@ -306,10 +311,13 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
(derivative->isDefined() || m_DFC.IsCustomDerivative(derivative)))
alreadyDerived = true;

// Restore the original function and its order.
request.CurrentDerivativeOrder = origFnOrder;
request.Function = origFn;

// Add the request to derive the definition of the forward mode derivative
// to the schedule.
request.DeclarationOnly = false;
request.DerivedFDPrototype = derivative;
}
this->AddEdgeToGraph(request, alreadyDerived);
return derivative;
Expand Down Expand Up @@ -423,7 +431,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
} else if (request.Mode == DiffMode::reverse_mode_forward_pass) {
ReverseModeForwPassVisitor V(*this, request);
result = V.Derive(FD, request);
} else if (request.Mode == DiffMode::hessian) {
} else if (request.Mode == DiffMode::hessian ||
request.Mode == DiffMode::hessian_diagonal) {
HessianModeVisitor H(*this, request);
result = H.Derive(FD, request);
} else if (request.Mode == DiffMode::jacobian) {
Expand Down
13 changes: 12 additions & 1 deletion lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,14 @@ namespace clad {
} else {
request.EnableTBRAnalysis = m_Options.EnableTBRAnalysis;
}
if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) {
if (!A->getAnnotation().equals("H")) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Diagonal only option is only valid for Hessian "
"mode.");
return true;
}
}
}

if (A->getAnnotation().equals("D")) {
Expand Down Expand Up @@ -651,7 +659,10 @@ namespace clad {
}
}
} else if (A->getAnnotation().equals("H")) {
request.Mode = DiffMode::hessian;
if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only))
request.Mode = DiffMode::hessian_diagonal;
else
request.Mode = DiffMode::hessian;
} else if (A->getAnnotation().equals("J")) {
request.Mode = DiffMode::jacobian;
} else if (A->getAnnotation().equals("G")) {
Expand Down
109 changes: 78 additions & 31 deletions lib/Differentiator/HessianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,24 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
return secondDerivative;
}

/// Derives the function two times with forward mode AD and returns the
/// FunctionDecl obtained.
static FunctionDecl* DeriveUsingForwardModeTwice(
Sema& SemaRef, clad::plugin::CladPlugin& CP,
clad::DerivativeBuilder& Builder, DiffRequest IndependentArgRequest,
const Expr* ForwardModeArgs, DerivedFnCollector& DFC) {
// Set derivative order in the request to 2.
IndependentArgRequest.RequestedDerivativeOrder = 2;
IndependentArgRequest.Args = ForwardModeArgs;
IndependentArgRequest.Mode = DiffMode::forward;
IndependentArgRequest.CallUpdateRequired = false;
IndependentArgRequest.UpdateDiffParamsInfo(SemaRef);
// Derive the function twice in forward mode.
FunctionDecl* secondDerivative =
Builder.HandleNestedDiffRequest(IndependentArgRequest);
return secondDerivative;
}

DerivativeAndOverload
HessianModeVisitor::Derive(const clang::FunctionDecl* FD,
const DiffRequest& request) {
Expand All @@ -91,14 +109,16 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
else
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));

std::vector<FunctionDecl*> secondDerivativeColumns;
std::vector<FunctionDecl*> secondDerivativeFuncs;
llvm::SmallVector<size_t, 16> IndependentArgsSize{};
size_t TotalIndependentArgsSize = 0;

// request.Function is original function passed in from clad::hessian
assert(m_DiffReq == request);

std::string hessianFuncName = request.BaseFunctionName + "_hessian";
if (request.Mode == DiffMode::hessian_diagonal)
hessianFuncName += "_diagonal";
// To be consistent with older tests, nothing is appended to 'f_hessian' if
// we differentiate w.r.t. all the parameters at once.
if (args.size() != FD->getNumParams() ||
Expand Down Expand Up @@ -192,27 +212,38 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
PVD->getNameAsString() + "[" + std::to_string(i) + "]";
auto ForwardModeIASL =
CreateStringLiteral(m_Context, independentArgString);
auto* DFD = DeriveUsingForwardAndReverseMode(
m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL,
request.Args, m_Builder.m_DFC);
secondDerivativeColumns.push_back(DFD);
FunctionDecl* DFD = nullptr;
if (request.Mode == DiffMode::hessian_diagonal)
DFD = DeriveUsingForwardModeTwice(m_Sema, m_CladPlugin, m_Builder,
request, ForwardModeIASL,
m_Builder.m_DFC);
else
DFD = DeriveUsingForwardAndReverseMode(
m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL,
request.Args, m_Builder.m_DFC);
secondDerivativeFuncs.push_back(DFD);
}

} else {
IndependentArgsSize.push_back(1);
TotalIndependentArgsSize++;
// Derive the function w.r.t. to the current arg in forward mode and
// then in reverse mode w.r.t to all requested args
auto ForwardModeIASL =
CreateStringLiteral(m_Context, PVD->getNameAsString());
auto* DFD = DeriveUsingForwardAndReverseMode(
m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL,
request.Args, m_Builder.m_DFC);
secondDerivativeColumns.push_back(DFD);
FunctionDecl* DFD = nullptr;
if (request.Mode == DiffMode::hessian_diagonal)
DFD = DeriveUsingForwardModeTwice(m_Sema, m_CladPlugin, m_Builder,
request, ForwardModeIASL,
m_Builder.m_DFC);
else
DFD = DeriveUsingForwardAndReverseMode(
m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL,
request.Args, m_Builder.m_DFC);
secondDerivativeFuncs.push_back(DFD);
}
}
}
return Merge(secondDerivativeColumns, IndependentArgsSize,
return Merge(secondDerivativeFuncs, IndependentArgsSize,
TotalIndependentArgsSize, hessianFuncName, DC,
hessianFunctionType, paramTypes);
}
Expand Down Expand Up @@ -272,14 +303,13 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
return VD;
});

// The output parameter "hessianMatrix".
// The output parameter "hessianMatrix" or "diagonalHessianVector"
std::string outputParamName = "hessianMatrix";
if (m_DiffReq.Mode == DiffMode::hessian_diagonal)
outputParamName = "diagonalHessianVector";
params.back() = ParmVarDecl::Create(
m_Context,
hessianFD,
noLoc,
noLoc,
&m_Context.Idents.get("hessianMatrix"),
paramTypes.back(),
m_Context, hessianFD, noLoc, noLoc,
&m_Context.Idents.get(outputParamName), paramTypes.back(),
m_Context.getTrivialTypeSourceInfo(paramTypes.back(), noLoc),
params.front()->getStorageClass(),
/* No default value */ nullptr);
Expand All @@ -301,7 +331,6 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
// Creates callExprs to the second derivative functions genereated
// and creates maps array elements to input array.
for (size_t i = 0, e = secDerivFuncs.size(); i < e; ++i) {
const size_t HessianMatrixStartIndex = i * TotalIndependentArgsSize;
auto size_type = m_Context.getSizeType();
auto size_type_bits = m_Context.getIntWidth(size_type);

Expand Down Expand Up @@ -345,22 +374,40 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
}
}

size_t columnIndex = 0;
// Create Expr parameters for each independent arg in the CallExpr
for (size_t indArgSize : IndependentArgsSize) {
llvm::APInt offsetValue(size_type_bits,
HessianMatrixStartIndex + columnIndex);
if (m_DiffReq.Mode == DiffMode::hessian_diagonal) {
const size_t HessianMatrixStartIndex = i;
// Call the derived function for second derivative.
Expr* call = BuildCallExprToFunction(secDerivFuncs[i], DeclRefToParams);

// Create the offset argument.
llvm::APInt offsetValue(size_type_bits, HessianMatrixStartIndex);
Expr* OffsetArg =
IntegerLiteral::Create(m_Context, offsetValue, size_type, noLoc);
// Create the hessianMatrix + OffsetArg expression.
Expr* SliceExpr = BuildOp(BO_Add, m_Result, OffsetArg);

DeclRefToParams.push_back(SliceExpr);
columnIndex += indArgSize;
// Create a assignment expression to store the value of call expression
// into the diagonalHessianVector with index HessianMatrixStartIndex.
Expr* SliceExprLHS = BuildOp(BO_Add, m_Result, OffsetArg);
Expr* DerefExpr = BuildOp(UO_Deref, BuildParens(SliceExprLHS));
Expr* AssignExpr = BuildOp(BO_Assign, DerefExpr, call);
CompStmtSave.push_back(AssignExpr);
} else {
const size_t HessianMatrixStartIndex = i * TotalIndependentArgsSize;
size_t columnIndex = 0;
// Create Expr parameters for each independent arg in the CallExpr
for (size_t indArgSize : IndependentArgsSize) {
llvm::APInt offsetValue(size_type_bits,
HessianMatrixStartIndex + columnIndex);
// Create the offset argument.
Expr* OffsetArg =
IntegerLiteral::Create(m_Context, offsetValue, size_type, noLoc);
// Create the hessianMatrix + OffsetArg expression.
Expr* SliceExpr = BuildOp(BO_Add, m_Result, OffsetArg);

DeclRefToParams.push_back(SliceExpr);
columnIndex += indArgSize;
}
Expr* call = BuildCallExprToFunction(secDerivFuncs[i], DeclRefToParams);
CompStmtSave.push_back(call);
}
Expr* call = BuildCallExprToFunction(secDerivFuncs[i], DeclRefToParams);
CompStmtSave.push_back(call);
}

auto StmtsRef =
Expand Down
5 changes: 3 additions & 2 deletions lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
m_Derivative->setBody(fnBody);
endScope();

if (request.DerivedFDPrototype)
m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype);
if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]);
}
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
Expand Down
12 changes: 7 additions & 5 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// added by the plugins yet.
if (request.Mode != DiffMode::jacobian && numExtraParam == 0)
shouldCreateOverload = true;
if (request.DerivedFDPrototype)
if (!request.DeclarationOnly && !request.DerivedFDPrototypes.empty())
// If the overload is already created, we don't need to create it again.
shouldCreateOverload = false;

Expand Down Expand Up @@ -452,8 +452,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Derivative->setBody(gradientBody);
endScope(); // Function body scope

if (request.DerivedFDPrototype)
m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype);
if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]);
}
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
Expand Down Expand Up @@ -585,8 +586,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Derivative->setBody(fnBody);
endScope(); // Function body scope

if (request.DerivedFDPrototype)
m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype);
if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]);
}
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
Expand Down
1 change: 1 addition & 0 deletions test/FirstDerivative/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ int main () {
clad::differentiate(test_8, "x");
clad::differentiate<clad::opts::enable_tbr>(test_8); // expected-error {{TBR analysis is not meant for forward mode AD.}}
clad::differentiate<clad::opts::enable_tbr, clad::opts::disable_tbr>(test_8); // expected-error {{Both enable and disable TBR options are specified.}}
clad::differentiate<clad::opts::diagonal_only>(test_8); // expected-error {{Diagonal only option is only valid for Hessian mode.}}
clad::differentiate(test_9);
return 0;

Expand Down
28 changes: 28 additions & 0 deletions test/Hessian/Arrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,24 @@ double g(double i, double j[2]) { return i * (j[0] + j[1]); }
// CHECK-NEXT: g_darg1_1_grad(i, j, hessianMatrix + {{6U|6UL}}, hessianMatrix + {{7U|7UL}});
// CHECK-NEXT: }

double h(double arr[3], double weights[3], double multiplier) {
// return square of weighted sum.
double weightedSum = arr[0] * weights[0];
weightedSum += arr[1] * weights[1];
weightedSum += arr[2] * weights[2];
weightedSum *= multiplier;
return weightedSum * weightedSum;
}
// CHECK: void h_hessian_diagonal(double arr[3], double weights[3], double multiplier, double *diagonalHessianVector) {
// CHECK-NEXT: *(diagonalHessianVector + 0{{U|UL}}) = h_d2arg0_0(arr, weights, multiplier);
// CHECK-NEXT: *(diagonalHessianVector + 1{{U|UL}}) = h_d2arg0_1(arr, weights, multiplier);
// CHECK-NEXT: *(diagonalHessianVector + 2{{U|UL}}) = h_d2arg0_2(arr, weights, multiplier);
// CHECK-NEXT: *(diagonalHessianVector + 3{{U|UL}}) = h_d2arg1_0(arr, weights, multiplier);
// CHECK-NEXT: *(diagonalHessianVector + 4{{U|UL}}) = h_d2arg1_1(arr, weights, multiplier);
// CHECK-NEXT: *(diagonalHessianVector + 5{{U|UL}}) = h_d2arg1_2(arr, weights, multiplier);
// CHECK-NEXT: *(diagonalHessianVector + 6{{U|UL}}) = h_d2arg2(arr, weights, multiplier);
// CHECK-NEXT: }

#define TEST(var, i, j) \
result[0] = result[1] = result[2] = result[3] = result[4] = result[5] = \
result[6] = result[7] = result[8] = 0; \
Expand All @@ -44,4 +62,14 @@ int main() {
TEST(h1, 2, x); // CHECK-EXEC: Result = {0.00 4.00 3.00 4.00 0.00 2.00 3.00 2.00 0.00}
auto h2 = clad::hessian(g, "i, j[0:1]");
TEST(h2, 2, x); // CHECK-EXEC: Result = {0.00 1.00 1.00 1.00 0.00 0.00 1.00 0.00 0.00}

double arr[] = {1, 2, 3};
double weights[] = {4, 5, 6};
double diag[7]; // result will be the diagonal of the Hessian matrix.
double multiplier = 2.0;
auto h3 = clad::hessian<clad::opts::diagonal_only>(h, "arr[0:2], weights[0:2], multiplier");
h3.execute(arr, weights, multiplier, diag);
printf("Diagonal (arr) = {%.2f %.2f %.2f},\n", diag[0], diag[1], diag[2]); // CHECK-EXEC: Diagonal (arr) = {128.00 200.00 288.00},
printf("Diagonal (weights) = {%.2f %.2f %.2f}\n", diag[3], diag[4], diag[5]); // CHECK-EXEC: Diagonal (weights) = {8.00 32.00 72.00}
printf("Diagonal (multiplier) = %.2f\n", diag[6]); // CHECK-EXEC: Diagonal (multiplier) = 2048.00
}
Loading

0 comments on commit 7237986

Please sign in to comment.