Skip to content

Commit

Permalink
Fix hessian with new schedule plan and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Apr 3, 2024
1 parent 993a719 commit 09522b6
Show file tree
Hide file tree
Showing 10 changed files with 475 additions and 392 deletions.
7 changes: 3 additions & 4 deletions demos/Arrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ int main() {
// the indexes of the array by using the format arr[0:<last index of arr>]
auto hessian_all = clad::hessian(weighted_avg, "arr[0:2], weights[0:2]");
// Generates the Hessian matrix for weighted_avg w.r.t. arr.
// auto hessian_arr = clad::hessian(weighted_avg, "arr[0:2]");
auto hessian_arr = clad::hessian(weighted_avg, "arr[0:2]");

double matrix_all[36] = {0};
// double matrix_arr[9] = {0};
double matrix_arr[9] = {0};

hessian_all.execute(arr, weights, matrix_all);
printf("Hessian Mode w.r.t. to all:\n matrix =\n"
Expand All @@ -93,13 +93,12 @@ int main() {
matrix_all[28], matrix_all[29], matrix_all[30], matrix_all[31],
matrix_all[32], matrix_all[33], matrix_all[34], matrix_all[35]);

/*hessian_arr.execute(arr, weights, matrix_arr);
hessian_arr.execute(arr, weights, matrix_arr);
printf("Hessian Mode w.r.t. to arr:\n matrix =\n"
" {%.2g, %.2g, %.2g}\n"
" {%.2g, %.2g, %.2g}\n"
" {%.2g, %.2g, %.2g}\n",
matrix_arr[0], matrix_arr[1], matrix_arr[2], matrix_arr[3],
matrix_arr[4], matrix_arr[5], matrix_arr[6], matrix_arr[7],
matrix_arr[8]);
*/
}
261 changes: 133 additions & 128 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,143 +212,148 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
derivedFD->setParams(paramsRef);
derivedFD->setBody(nullptr);

// Function body scope
beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
beginBlock();
// For each function parameter variable, store its derivative value.
for (auto param : params) {
// We cannot create derivatives of reference type since seed value is
// always a constant (r-value). We assume that all the arguments have no
// relation among them, thus it is safe (correct) to use the corresponding
// non-reference type for creating the derivatives.
QualType dParamType = param->getType().getNonReferenceType();
// We do not create derived variable for array/pointer parameters.
if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType) ||
utils::isArrayOrPointerType(dParamType))
continue;
Expr* dParam = nullptr;
if (dParamType->isRealType()) {
// If param is independent variable, its derivative is 1, otherwise 0.
int dValue = (param == m_IndependentVar);
dParam =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, dValue);
}
// For each function arg, create a variable _d_arg to store derivatives
// of potential reassignments, e.g.:
// double f_darg0(double x, double y) {
// double _d_x = 1;
// double _d_y = 0;
// ...
auto dParamDecl =
BuildVarDecl(dParamType, "_d_" + param->getNameAsString(), dParam);
addToCurrentBlock(BuildDeclStmt(dParamDecl));
dParam = BuildDeclRef(dParamDecl);
if (dParamType->isRecordType() && param == m_IndependentVar) {
llvm::SmallVector<llvm::StringRef, 4> ref(diffVarInfo.fields.begin(),
diffVarInfo.fields.end());
Expr* memRef =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), dParam, ref);
assert(memRef->getType()->isRealType() &&
"Forward mode can only differentiate w.r.t builtin scalar "
"numerical types.");
addToCurrentBlock(BuildOp(
BinaryOperatorKind::BO_Assign, memRef,
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 1)));
}
// Memorize the derivative of param, i.e. whenever the param is visited
// in the future, its derivative dParam is found (unless reassigned with
// something new).
m_Variables[param] = dParam;
}

if (auto MD = dyn_cast<CXXMethodDecl>(FD)) {
// We cannot create derivative of lambda yet because lambda's default
// constructor is deleted.
if (MD->isInstance() && !MD->getParent()->isLambda()) {
QualType thisObjectType =
clad_compat::CXXMethodDecl_GetThisObjectType(m_Sema, MD);
QualType thisType = MD->getThisType();
// Here we are effectively doing:
// ```
// Class _d_this_obj;
// Class* _d_this = &_d_this_obj;
// ```
// We are not creating `this` expression derivative using `new` because
// then we would be responsible for freeing the memory as well and it's
// more convenient to let compiler handle the object lifecycle.
VarDecl* derivativeVD = BuildVarDecl(thisObjectType, "_d_this_obj");
DeclRefExpr* derivativeE = BuildDeclRef(derivativeVD);
VarDecl* thisExprDerivativeVD =
BuildVarDecl(thisType, "_d_this",
BuildOp(UnaryOperatorKind::UO_AddrOf, derivativeE));
addToCurrentBlock(BuildDeclStmt(derivativeVD));
addToCurrentBlock(BuildDeclStmt(thisExprDerivativeVD));
m_ThisExprDerivative = BuildDeclRef(thisExprDerivativeVD);
if (!request.DeclarationOnly) {
// Function body scope
beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
beginBlock();
// For each function parameter variable, store its derivative value.
for (auto param : params) {
// We cannot create derivatives of reference type since seed value is
// always a constant (r-value). We assume that all the arguments have no
// relation among them, thus it is safe (correct) to use the corresponding
// non-reference type for creating the derivatives.
QualType dParamType = param->getType().getNonReferenceType();
// We do not create derived variable for array/pointer parameters.
if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType) ||
utils::isArrayOrPointerType(dParamType))
continue;
Expr* dParam = nullptr;
if (dParamType->isRealType()) {
// If param is independent variable, its derivative is 1, otherwise 0.
int dValue = (param == m_IndependentVar);
dParam =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, dValue);
}
// For each function arg, create a variable _d_arg to store derivatives
// of potential reassignments, e.g.:
// double f_darg0(double x, double y) {
// double _d_x = 1;
// double _d_y = 0;
// ...
auto dParamDecl =
BuildVarDecl(dParamType, "_d_" + param->getNameAsString(), dParam);
addToCurrentBlock(BuildDeclStmt(dParamDecl));
dParam = BuildDeclRef(dParamDecl);
if (dParamType->isRecordType() && param == m_IndependentVar) {
llvm::SmallVector<llvm::StringRef, 4> ref(diffVarInfo.fields.begin(),
diffVarInfo.fields.end());
Expr* memRef =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), dParam, ref);
assert(memRef->getType()->isRealType() &&
"Forward mode can only differentiate w.r.t builtin scalar "
"numerical types.");
addToCurrentBlock(BuildOp(
BinaryOperatorKind::BO_Assign, memRef,
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 1)));
}
// Memorize the derivative of param, i.e. whenever the param is visited
// in the future, its derivative dParam is found (unless reassigned with
// something new).
m_Variables[param] = dParam;
}
}

// Create derived variable for each member variable if we are
// differentiating a call operator.
if (m_Functor) {
for (FieldDecl* fieldDecl : m_Functor->fields()) {
Expr* dInitializer = nullptr;
QualType fieldType = fieldDecl->getType();

if (auto arrType = dyn_cast<ConstantArrayType>(fieldType.getTypePtr())) {
if (!arrType->getElementType()->isRealType())
continue;

auto arrSize = arrType->getSize().getZExtValue();
std::vector<Expr*> dArrVal;

// Create an initializer list to initialize derived variable created
// for array member variable.
// For example, if we are differentiating wrt arr[3], then
// ```
// double arr[7];
// ```
// will get differentiated to,
//
if (auto MD = dyn_cast<CXXMethodDecl>(FD)) {
// We cannot create derivative of lambda yet because lambda's default
// constructor is deleted.
if (MD->isInstance() && !MD->getParent()->isLambda()) {
QualType thisObjectType =
clad_compat::CXXMethodDecl_GetThisObjectType(m_Sema, MD);
QualType thisType = MD->getThisType();
// Here we are effectively doing:
// ```
// double _d_arr[7] = {0, 0, 0, 1, 0, 0, 0};
// Class _d_this_obj;
// Class* _d_this = &_d_this_obj;
// ```
for (size_t i = 0; i < arrSize; ++i) {
int dValue =
(fieldDecl == m_IndependentVar && i == m_IndependentVarIndex);
auto dValueLiteral = ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, dValue);
dArrVal.push_back(dValueLiteral);
// We are not creating `this` expression derivative using `new` because
// then we would be responsible for freeing the memory as well and it's
// more convenient to let compiler handle the object lifecycle.
VarDecl* derivativeVD = BuildVarDecl(thisObjectType, "_d_this_obj");
DeclRefExpr* derivativeE = BuildDeclRef(derivativeVD);
VarDecl* thisExprDerivativeVD =
BuildVarDecl(thisType, "_d_this",
BuildOp(UnaryOperatorKind::UO_AddrOf, derivativeE));
addToCurrentBlock(BuildDeclStmt(derivativeVD));
addToCurrentBlock(BuildDeclStmt(thisExprDerivativeVD));
m_ThisExprDerivative = BuildDeclRef(thisExprDerivativeVD);
}
}

// Create derived variable for each member variable if we are
// differentiating a call operator.
if (m_Functor) {
for (FieldDecl* fieldDecl : m_Functor->fields()) {
Expr* dInitializer = nullptr;
QualType fieldType = fieldDecl->getType();

if (auto arrType = dyn_cast<ConstantArrayType>(fieldType.getTypePtr())) {
if (!arrType->getElementType()->isRealType())
continue;

auto arrSize = arrType->getSize().getZExtValue();
std::vector<Expr*> dArrVal;

// Create an initializer list to initialize derived variable created
// for array member variable.
// For example, if we are differentiating wrt arr[3], then
// ```
// double arr[7];
// ```
// will get differentiated to,
//
// ```
// double _d_arr[7] = {0, 0, 0, 1, 0, 0, 0};
// ```
for (size_t i = 0; i < arrSize; ++i) {
int dValue =
(fieldDecl == m_IndependentVar && i == m_IndependentVarIndex);
auto dValueLiteral = ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, dValue);
dArrVal.push_back(dValueLiteral);
}
dInitializer = m_Sema.ActOnInitList(noLoc, dArrVal, noLoc).get();
} else if (auto ptrType = dyn_cast<PointerType>(fieldType.getTypePtr())) {
if (!ptrType->getPointeeType()->isRealType())
continue;
// Pointer member variables should be initialised by `nullptr`.
dInitializer = m_Sema.ActOnCXXNullPtrLiteral(noLoc).get();
} else {
int dValue = (fieldDecl == m_IndependentVar);
dInitializer = ConstantFolder::synthesizeLiteral(m_Context.IntTy,
m_Context, dValue);
}
dInitializer = m_Sema.ActOnInitList(noLoc, dArrVal, noLoc).get();
} else if (auto ptrType = dyn_cast<PointerType>(fieldType.getTypePtr())) {
if (!ptrType->getPointeeType()->isRealType())
continue;
// Pointer member variables should be initialised by `nullptr`.
dInitializer = m_Sema.ActOnCXXNullPtrLiteral(noLoc).get();
} else {
int dValue = (fieldDecl == m_IndependentVar);
dInitializer = ConstantFolder::synthesizeLiteral(m_Context.IntTy,
m_Context, dValue);
VarDecl* derivedFieldDecl =
BuildVarDecl(fieldType.getNonReferenceType(),
"_d_" + fieldDecl->getNameAsString(), dInitializer);
addToCurrentBlock(BuildDeclStmt(derivedFieldDecl));
m_Variables.emplace(fieldDecl, BuildDeclRef(derivedFieldDecl));
}
VarDecl* derivedFieldDecl =
BuildVarDecl(fieldType.getNonReferenceType(),
"_d_" + fieldDecl->getNameAsString(), dInitializer);
addToCurrentBlock(BuildDeclStmt(derivedFieldDecl));
m_Variables.emplace(fieldDecl, BuildDeclRef(derivedFieldDecl));
}
}

Stmt* BodyDiff = Visit(FD->getBody()).getStmt();
if (auto CS = dyn_cast<CompoundStmt>(BodyDiff))
for (Stmt* S : CS->body())
addToCurrentBlock(S);
else
addToCurrentBlock(BodyDiff);
Stmt* derivativeBody = endBlock();
derivedFD->setBody(derivativeBody);
Stmt* BodyDiff = Visit(FD->getBody()).getStmt();
if (auto CS = dyn_cast<CompoundStmt>(BodyDiff))
for (Stmt* S : CS->body())
addToCurrentBlock(S);
else
addToCurrentBlock(BodyDiff);
Stmt* derivativeBody = endBlock();
derivedFD->setBody(derivativeBody);

endScope(); // Function body scope
endScope(); // Function body scope

if (request.DerivedFDPrototype)
m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype);
}
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope(); // Function decl scope
Expand Down
13 changes: 11 additions & 2 deletions lib/Differentiator/HessianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,18 @@ namespace clad {
IndependentArgRequest.Args = ReverseModeArgs;
IndependentArgRequest.BaseFunctionName = firstDerivative->getNameAsString();
IndependentArgRequest.UpdateDiffParamsInfo(SemaRef);

// Derive declaration of the forward mode derivative.
IndependentArgRequest.DeclarationOnly = true;
FunctionDecl* secondDerivative =
plugin::ProcessDiffRequest(CP, IndependentArgRequest);

// Add the request to derive the definition of the forward mode derivative
// to the schedule.
IndependentArgRequest.DeclarationOnly = false;
IndependentArgRequest.DerivedFDPrototype = secondDerivative;
plugin::AddRequestToSchedule(CP, IndependentArgRequest);

return secondDerivative;
}

Expand Down Expand Up @@ -99,8 +108,8 @@ namespace clad {
std::string hessianFuncName = request.BaseFunctionName + "_hessian";
// To be consistent with older tests, nothing is appended to 'f_hessian' if
// we differentiate w.r.t. all the parameters at once.
if (!std::equal(m_Function->param_begin(), m_Function->param_end(),
std::begin(args))) {
if (!(args.size() == FD->getNumParams() && std::equal(m_Function->param_begin(), m_Function->param_end(),
args.begin()))) {
for (auto arg : args) {
auto it =
std::find(m_Function->param_begin(), m_Function->param_end(), arg);
Expand Down
Loading

0 comments on commit 09522b6

Please sign in to comment.