Skip to content

Commit

Permalink
Merge with master
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Aug 13, 2024
2 parents 6130c91 + 6cc83ee commit 98352a8
Show file tree
Hide file tree
Showing 22 changed files with 500 additions and 236 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ Reverse-mode AD allows computing the gradient of `f` using *at most* a constant
1. `f` is a pointer to a function or a method to be differentiated
2. `ARGS` is either:
* not provided, then `f` is differentiated w.r.t. its every argument
* a string literal with comma-separated names of independent variables (e.g. `"x"` or `"y"` or `"x, y"` or `"y, x"`)
* a string literal with comma-separated names/indices of independent variables (e.g. `"x"`, `"y"`, `"x, y"`, `"y, x"`, "0, 1", "0, y", etc.)
* a SINGLE number representing the index of the independent variable
Since a vector of derivatives must be returned from a function generated by the reverse mode, its signature is slightly different. The generated function has `void` return type and same input arguments. The function has additional `n` arguments (where `n` refers to the number of arguments whose gradient was requested) of type `T*`, where `T` is the type of the corresponding original variable. Each of these variables stores the derivative of the elements as they appear in the orignal function signature. *The caller is responsible for allocating and zeroing-out the gradient storage*. Example:
```cpp
Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.7~dev
1.8~dev
8 changes: 4 additions & 4 deletions docs/internalDocs/ReleaseNotes.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Introduction
============

This document contains the release notes for the automatic differentiation
plugin for clang Clad, release 1.7. Clad is built on top of
plugin for clang Clad, release 1.8. Clad is built on top of
[Clang](http://clang.llvm.org) and [LLVM](http://llvm.org>) compiler
infrastructure. Here we describe the status of Clad in some detail, including
major improvements from the previous release and new feature work.
Expand All @@ -11,7 +11,7 @@ Note that if you are reading this file from a git checkout,
this document applies to the *next* release, not the current one.


What's New in Clad 1.7?
What's New in Clad 1.8?
========================

Some of the major new features and improvements to Clad are listed here. Generic
Expand Down Expand Up @@ -54,7 +54,7 @@ Fixed Bugs
[XXX](https://github.com/vgvassilev/clad/issues/XXX)

<!---Get release bugs. Check for close, fix, resolve
git log v1.6..master | grep -i "close" | grep '#' | sed -E 's,.*\#([0-9]*).*,\[\1\]\(https://github.com/vgvassilev/clad/issues/\1\),g' | sort
git log v1.7..master | grep -i "close" | grep '#' | sed -E 's,.*\#([0-9]*).*,\[\1\]\(https://github.com/vgvassilev/clad/issues/\1\),g' | sort
--->

Special Kudos
Expand All @@ -68,5 +68,5 @@ FirstName LastName (#commits)
A B (N)

<!---Find contributor list for this release
git log --pretty=format:"%an" v1.6...master | sort | uniq -c | sort -rn | sed -E 's,^ *([0-9]+) (.*)$,\2 \(\1\),'
git log --pretty=format:"%an" v1.7...master | sort | uniq -c | sort -rn | sed -E 's,^ *([0-9]+) (.*)$,\2 \(\1\),'
--->
2 changes: 2 additions & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ class BaseForwardModeVisitor
StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE);
StmtDiff VisitCXXNewExpr(const clang::CXXNewExpr* CNE);
StmtDiff VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE);
StmtDiff
VisitCXXScalarValueInitExpr(const clang::CXXScalarValueInitExpr* SVIE);
StmtDiff VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* CSE);
StmtDiff VisitCXXFunctionalCastExpr(const clang::CXXFunctionalCastExpr* FCE);
StmtDiff VisitCXXBindTemporaryExpr(const clang::CXXBindTemporaryExpr* BTE);
Expand Down
4 changes: 2 additions & 2 deletions include/clad/Differentiator/DynamicGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ template <typename T> class DynamicGraph {
bool isProcessingNode() { return m_currentId != -1; }

/// Get the nodes in the graph.
std::vector<T> getNodes() { return m_nodes; }
const std::vector<T>& getNodes() { return m_nodes; }

/// Print the nodes and edges in the graph.
void print() {
Expand Down Expand Up @@ -140,4 +140,4 @@ template <typename T> class DynamicGraph {
};
} // end namespace clad

#endif // CLAD_DIFFERENTIATOR_DYNAMICGRAPH_H
#endif // CLAD_DIFFERENTIATOR_DYNAMICGRAPH_H
3 changes: 2 additions & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ namespace clad {
StmtDiff VisitSwitchStmt(const clang::SwitchStmt* SS);
StmtDiff VisitCaseStmt(const clang::CaseStmt* CS);
StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS);
DeclDiff<clang::VarDecl> DifferentiateVarDecl(const clang::VarDecl* VD);
DeclDiff<clang::VarDecl> DifferentiateVarDecl(const clang::VarDecl* VD,
bool AddToBlock = true);
StmtDiff VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff
Expand Down
15 changes: 5 additions & 10 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -394,14 +394,12 @@ namespace clad {
/// to avoid recomputation.
static bool UsefulToStore(clang::Expr* E);
/// A flag for silencing warnings/errors output by diag function.
bool silenceDiags = false;
/// Shorthand to issues a warning or error.
template <std::size_t N>
void diag(clang::DiagnosticsEngine::Level level, // Warning or Error
clang::SourceLocation loc, const char (&format)[N],
llvm::ArrayRef<llvm::StringRef> args = {}) {
if (!silenceDiags)
m_Builder.diag(level, loc, format, args);
m_Builder.diag(level, loc, format, args);
}

/// Creates unique identifier of the form "_nameBase<number>" that is
Expand Down Expand Up @@ -584,17 +582,14 @@ namespace clad {
clang::Expr* GetSingleArgCentralDiffCall(
clang::Expr* targetFuncCall, clang::Expr* targetArg, unsigned targetPos,
unsigned numArgs, llvm::SmallVectorImpl<clang::Expr*>& args);

/// Emits diagnostic messages on differentiation (or lack thereof) for
/// call expressions.
///
/// \param[in] \c funcName The name of the underlying function of the
/// call expression.
/// \param[in] \c FD - The function declaration.
/// \param[in] \c srcLoc Any associated source location information.
/// \param[in] \c isDerived A flag to determine if differentiation of the
/// call expression was successful.
void CallExprDiffDiagnostics(llvm::StringRef funcName,
clang::SourceLocation srcLoc,
bool isDerived);
void CallExprDiffDiagnostics(const clang::FunctionDecl* FD,
clang::SourceLocation srcLoc);

clang::QualType DetermineCladArrayValueType(clang::QualType T);

Expand Down
32 changes: 24 additions & 8 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ DerivativeAndOverload
BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
const DiffRequest& request) {
assert(m_DiffReq == request && "Can't pass two different requests!");
silenceDiags = !request.VerboseDiags;
m_Functor = request.Functor;
assert(m_DiffReq.Mode == DiffMode::forward);
assert(!m_DerivativeInFlight &&
Expand Down Expand Up @@ -801,8 +800,7 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) {
if ((condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) ||
condUO) {
condDiff = Visit(cond);
if (condDiff.getExpr_dx() &&
(!isUnusedResult(condDiff.getExpr_dx()) || condUO))
if (condDiff.getExpr_dx() && (!isUnusedResult(condDiff.getExpr_dx())))
cond = BuildOp(BO_Comma, BuildParens(condDiff.getExpr_dx()),
BuildParens(condDiff.getExpr()));
else
Expand Down Expand Up @@ -1335,7 +1333,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
GetSingleArgCentralDiffCall(fnCallee, CallArgs[0],
/*targetPos=*/0, /*numArgs=*/1, CallArgs);
}
CallExprDiffDiagnostics(FD->getNameAsString(), CE->getBeginLoc(), callDiff);
CallExprDiffDiagnostics(FD, CE->getBeginLoc());
if (!callDiff) {
auto zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Expand Down Expand Up @@ -1388,7 +1386,15 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
} else if (opKind == UnaryOperatorKind::UO_AddrOf) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
} else if (opKind == UnaryOperatorKind::UO_LNot) {
return StmtDiff(op, diff.getExpr_dx());
Expr* zero = getZeroInit(UnOp->getType());
if (diff.getExpr_dx() && !isUnusedResult(diff.getExpr_dx()))
return {BuildOp(BO_Comma, BuildParens(diff.getExpr_dx()), op), zero};
return {op, zero};
} else if (opKind == UnaryOperatorKind::UO_Not) {
// ~x is 2^n - 1 - x for unsigned types and -x - 1 for the signed ones.
// Either way, taking a derivative gives us -_d_x.
Expr* derivedOp = BuildOp(UO_Minus, diff.getExpr_dx());
return {op, derivedOp};
} else {
unsupportedOpWarn(UnOp->getEndLoc());
auto zero =
Expand Down Expand Up @@ -1504,7 +1510,8 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {
} else
opDiff = BuildOp(BO_Comma, BuildParens(Ldiff.getExpr()),
BuildParens(Rdiff.getExpr_dx()));
} else if (BinOp->isLogicalOp()) {
} else if (BinOp->isLogicalOp() || BinOp->isBitwiseOp() ||
BinOp->isComparisonOp() || opCode == BO_Rem) {
// For (A && B) return ((dA, A) && (dB, B)) to ensure correct evaluation and
// correct derivative execution.
auto buildOneSide = [this](StmtDiff& Xdiff) {
Expand All @@ -1521,8 +1528,12 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {

// Since the both parts are included in the opDiff, there's no point in
// including it as a Stmt_dx. Moreover, the fact that Stmt_dx is left
// nullptr is used for treating expressions like ((A && B) && C) correctly.
return StmtDiff(opDiff, nullptr);
// zero is used for treating expressions like ((A && B) && C) correctly.
return StmtDiff(opDiff, getZeroInit(BinOp->getType()));
} else if (BinOp->isShiftOp()) {
// Shifting is essentially multiplicating the LHS by 2^RHS (or 2^-RHS).
// We should do the same to the derivarive.
opDiff = BuildOp(opCode, Ldiff.getExpr_dx(), Rdiff.getExpr());
} else {
// FIXME: add support for other binary operators
unsupportedOpWarn(BinOp->getEndLoc());
Expand Down Expand Up @@ -2278,6 +2289,11 @@ StmtDiff BaseForwardModeVisitor::VisitCXXStdInitializerListExpr(
return Visit(ILE->getSubExpr());
}

StmtDiff BaseForwardModeVisitor::VisitCXXScalarValueInitExpr(
const CXXScalarValueInitExpr* SVIE) {
return {Clone(SVIE), Clone(SVIE)};
}

clang::Expr* BaseForwardModeVisitor::BuildCustomDerivativeConstructorPFCall(
const clang::CXXConstructExpr* CE,
llvm::SmallVectorImpl<clang::Expr*>& clonedArgs,
Expand Down
9 changes: 1 addition & 8 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,15 +194,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
}
if (!NSD) {
NSD = utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist);
if (!forCustomDerv && !NSD) {
diag(DiagnosticsEngine::Warning, noLoc,
"Numerical differentiation is diabled using the "
"-DCLAD_NO_NUM_DIFF "
"flag, this means that every try to numerically differentiate a "
"function will fail! Remove the flag to revert to default "
"behaviour.");
if (!NSD)
return R;
}
}
DeclContext* DC = NSD;

Expand Down
15 changes: 15 additions & 0 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,21 @@ namespace clad {
DiffInputVarInfo dVarInfo;

dVarInfo.source = diffSpec.str();
// Check if diffSpec represents an index of an independent variable.
if ('0' <= diffSpec[0] && diffSpec[0] <= '9') {
unsigned idx = std::stoi(dVarInfo.source);
// Fail if the specified index is invalid.
if (idx >= FD->getNumParams()) {
utils::EmitDiag(
semaRef, DiagnosticsEngine::Error, diffArgs->getEndLoc(),
"Invalid argument index '%0' of '%1' argument(s)",
{std::to_string(idx), std::to_string(FD->getNumParams())});
return;
}
dVarInfo.param = FD->getParamDecl(idx);
DVI.push_back(dVarInfo);
continue;
}
llvm::StringRef pName = computeParamName(diffSpec);
auto it = std::find_if(std::begin(candidates), std::end(candidates),
[&pName](
Expand Down
1 change: 0 additions & 1 deletion lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ DerivativeAndOverload
ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
const DiffRequest& request) {
assert(m_DiffReq == request);
silenceDiags = !request.VerboseDiags;

assert(m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass);

Expand Down
Loading

0 comments on commit 98352a8

Please sign in to comment.