Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Operator overload in reverse mode #619

Merged
merged 1 commit into from
Oct 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeForwPassVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class ReverseModeForwPassVisitor : public ReverseModeVisitor {
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS) override;
StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE) override;
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp) override;
};
} // namespace clad

Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ namespace clad {
StmtDiff VisitParenExpr(const clang::ParenExpr* PE);
virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS);
StmtDiff VisitStmt(const clang::Stmt* S);
StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp);
virtual StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp);
StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups* EWC);
/// Decl is not Stmt, so it cannot be visited directly.
StmtDiff VisitWhileStmt(const clang::WhileStmt* WS);
Expand Down
30 changes: 27 additions & 3 deletions lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
DiffParams args{};
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));

auto fnName = m_Function->getNameAsString() + "_forw";
auto fnName = clad::utils::ComputeEffectiveFnName(m_Function) + "_forw";
auto fnDNI = utils::BuildDeclarationNameInfo(m_Sema, fnName);

auto paramTypes = ComputeParamTypes(args);
Expand Down Expand Up @@ -86,8 +86,6 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
QualType
ReverseModeForwPassVisitor::GetParameterDerivativeType(QualType yType,
QualType xType) {
assert(yType.getNonReferenceType()->isRealType() &&
"yType should be a builtin-numerical scalar type!!");
QualType xValueType = utils::GetValueType(xType);
// derivative variables should always be of non-const type.
xValueType.removeLocalConst();
Expand Down Expand Up @@ -240,4 +238,30 @@ ReverseModeForwPassVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) {
Stmt* newRS = m_Sema.BuildReturnStmt(noLoc, returnInitList).get();
return {newRS};
}

StmtDiff
ReverseModeForwPassVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
auto opCode = UnOp->getOpcode();
StmtDiff diff{};
// If it is a post-increment/decrement operator, its result is a reference
// and we should return it.
Expr* ResultRef = nullptr;
if (opCode == UnaryOperatorKind::UO_Deref) {
if (const auto* MD = dyn_cast<CXXMethodDecl>(m_Function)) {
if (MD->isInstance()) {
diff = Visit(UnOp->getSubExpr());
Expr* cloneE = BuildOp(UnaryOperatorKind::UO_Deref, diff.getExpr());
Expr* derivedE = diff.getExpr_dx();
return {cloneE, derivedE};
}
}
} else if (opCode == UO_Plus)
diff = Visit(UnOp->getSubExpr(), dfdx());
PhrygianGates marked this conversation as resolved.
Show resolved Hide resolved
else if (opCode == UO_Minus) {
auto d = BuildOp(UO_Minus, dfdx());
PhrygianGates marked this conversation as resolved.
Show resolved Hide resolved
diff = Visit(UnOp->getSubExpr(), d);
}
Expr* op = BuildOp(opCode, diff.getExpr());
return StmtDiff(op, ResultRef);
}
} // namespace clad
63 changes: 39 additions & 24 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnScope();

auto params = BuildParams(args);

if (m_ExternalSource)
Expand Down Expand Up @@ -411,7 +411,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_IndependentVars.push_back(arg);
}
}

if (m_ExternalSource)
m_ExternalSource->ActBeforeCreatingDerivedFnBodyScope();

Expand Down Expand Up @@ -743,7 +743,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff SDiff = DifferentiateSingleStmt(S);
addToCurrentBlock(SDiff.getStmt(), direction::forward);
addToCurrentBlock(SDiff.getStmt_dx(), direction::reverse);

if (m_ExternalSource)
m_ExternalSource->ActAfterProcessingStmtInVisitCompoundStmt();
}
Expand Down Expand Up @@ -861,7 +861,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_ExternalSource->ActBeforeDifferentiatingSingleStmtBranchInVisitIfStmt();
StmtDiff BranchDiff = DifferentiateSingleStmt(Branch, /*dfdS=*/nullptr);
addToCurrentBlock(BranchDiff.getStmt(), direction::forward);

if (m_ExternalSource)
m_ExternalSource->ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt();

Expand Down Expand Up @@ -1372,7 +1372,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If the function has no args and is not a member function call then we
// assume that it is not related to independent variables and does not
// contribute to gradient.
if (!NArgs && !isa<CXXMemberCallExpr>(CE))
if ((NArgs == 0U) && !isa<CXXMemberCallExpr>(CE) &&
!isa<CXXOperatorCallExpr>(CE))
return StmtDiff(Clone(CE));

// Stores the call arguments for the function to be derived
Expand All @@ -1392,7 +1393,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// derived function. In the case of member functions, `implicit`
// this object is always passed by reference.
if (!dfdx() && !utils::HasAnyReferenceOrPointerArgument(FD) &&
!isa<CXXMemberCallExpr>(CE)) {
!isa<CXXMemberCallExpr>(CE) && !isa<CXXOperatorCallExpr>(CE)) {
for (const Expr* Arg : CE->arguments()) {
StmtDiff ArgDiff = Visit(Arg, dfdx());
CallArgs.push_back(ArgDiff.getExpr());
Expand Down Expand Up @@ -1424,9 +1425,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// FIXME: We should add instructions for handling non-differentiable
// arguments. Currently we are implicitly assuming function call only
// contains differentiable arguments.
for (std::size_t i = skipFirstArg, e = CE->getNumArgs(); i != e; ++i) {
bool isCXXOperatorCall = isa<CXXOperatorCallExpr>(CE);

for (std::size_t i = static_cast<std::size_t>(isCXXOperatorCall),
e = CE->getNumArgs();
i != e; ++i) {
const Expr* arg = CE->getArg(i);
const auto* PVD = FD->getParamDecl(i - skipFirstArg);
const auto* PVD =
FD->getParamDecl(i - static_cast<unsigned long>(isCXXOperatorCall));
StmtDiff argDiff{};
bool passByRef = utils::IsReferenceOrPointerType(PVD->getType());
// We do not need to create result arg for arguments passed by reference
Expand Down Expand Up @@ -1714,11 +1720,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullback);

// Try to find it in builtin derivatives
std::string customPullback = FD->getNameAsString() + "_pullback";
if (baseDiff.getExpr())
pullbackCallArgs.insert(
pullbackCallArgs.begin(),
BuildOp(UnaryOperatorKind::UO_AddrOf, baseDiff.getExpr()));
std::string customPullback =
clad::utils::ComputeEffectiveFnName(FD) + "_pullback";
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
if (baseDiff.getExpr())
pullbackCallArgs.erase(pullbackCallArgs.begin());
}

// should be true if we are using numerical differentiation to differentiate
Expand Down Expand Up @@ -1749,7 +1762,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// derive the called function.
DiffRequest pullbackRequest{};
pullbackRequest.Function = FD;
pullbackRequest.BaseFunctionName = FD->getNameAsString();
pullbackRequest.BaseFunctionName =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: implicit conversion 'clang::Expr *' -> bool [readability-implicit-bool-conversion]

Suggested change
pullbackRequest.BaseFunctionName =
pullbackCallArgs, ArgDeclStmts, dfdx() != nullptr);

clad::utils::ComputeEffectiveFnName(FD);
pullbackRequest.Mode = DiffMode::experimental_pullback;
// Silence diag outputs in nested derivation process.
pullbackRequest.VerboseDiags = false;
Expand Down Expand Up @@ -1882,7 +1896,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
DiffRequest calleeFnForwPassReq;
calleeFnForwPassReq.Function = FD;
calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass;
calleeFnForwPassReq.BaseFunctionName = FD->getNameAsString();
calleeFnForwPassReq.BaseFunctionName =
clad::utils::ComputeEffectiveFnName(FD);
calleeFnForwPassReq.VerboseDiags = true;
FunctionDecl* calleeFnForwPassFD =
plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq);
Expand All @@ -1906,13 +1921,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// (isCladArrayType(derivedBase->getType()))
// CallArgs.push_back(derivedBase);
// else
// Currently derivedBase `*d_this` can never be CladArrayType
CallArgs.push_back(
BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, noLoc));
}

for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
for (std::size_t i = static_cast<std::size_t>(isCXXOperatorCall),
e = CE->getNumArgs();
i != e; ++i) {
const Expr* arg = CE->getArg(i);
const ParmVarDecl* PVD = FD->getParamDecl(i);
const ParmVarDecl* PVD =
FD->getParamDecl(i - static_cast<unsigned long>(isCXXOperatorCall));
StmtDiff argDiff = Visit(arg);
if ((argDiff.getExpr_dx() != nullptr) &&
PVD->getType()->isReferenceType()) {
Expand Down Expand Up @@ -1988,8 +2007,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Add it to the body statements.
addToCurrentBlock(add_assign, direction::reverse);
}
}
else {
} else {
// FIXME: This is not adding 'address-of' operator support.
// This is just making this special case differentiable that is required
// for computing hessian:
Expand Down Expand Up @@ -2382,13 +2400,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VDDerivedInit = getZeroInit(VD->getType());

// `specialThisDiffCase` is only required for correctly differentiating
// the following code:
// the following code:
// ```
// Class _d_this_obj;
// Class* _d_this = &_d_this_obj;
// ```
// Computation of hessian requires this code to be correctly
// differentiated.
// differentiated.
bool specialThisDiffCase = false;
if (auto MD = dyn_cast<CXXMethodDecl>(m_Function)) {
if (VDDerivedType->isPointerType() && MD->isInstance()) {
Expand Down Expand Up @@ -2507,10 +2525,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

return VarDeclDiff(VDClone, VDDerived);
}

// TODO: 'shouldEmit' parameter should be removed after converting
// Error estimation framework to callback style. Some more research
// need to be done to
// need to be done to
StmtDiff
ReverseModeVisitor::DifferentiateSingleStmt(const Stmt* S, Expr* dfdS) {
if (m_ExternalSource)
Expand Down Expand Up @@ -3121,7 +3139,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

bodyDiff = {bodyDiff.getStmt(), CFSS};
}

void ReverseModeVisitor::AddExternalSource(ExternalRMVSource& source) {
if (!m_ExternalSource)
m_ExternalSource = new MultiplexExternalRMVSource();
Expand Down Expand Up @@ -3177,13 +3195,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

QualType ReverseModeVisitor::GetParameterDerivativeType(QualType yType,
QualType xType) {

if (m_Mode == DiffMode::reverse)
assert(yType->isRealType() &&
"yType should be a non-reference builtin-numerical scalar type!!");
else if (m_Mode == DiffMode::experimental_pullback)
assert(yType.getNonReferenceType()->isRealType() &&
"yType should be a builtin-numerical scalar type!!");
QualType xValueType = utils::GetValueType(xType);
// derivative variables should always be of non-const type.
xValueType.removeLocalConst();
Expand Down
Loading
Loading