Skip to content

Commit

Permalink
Merge branch 'master' into cuda-bug
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Sep 2, 2024
2 parents 6816d9a + 8f7d247 commit f7b3caa
Show file tree
Hide file tree
Showing 21 changed files with 1,039 additions and 482 deletions.
12 changes: 12 additions & 0 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ template <typename T, typename U> struct ValueAndAdjoint {
/// class for which constructor pushforward is defined.
template <class T> class ConstructorPushforwardTag {};

/// Tag type passed as the first argument of `constructor_reverse_forw`
/// overloads. The template argument `T` names the class whose constructor is
/// being handled, so that overloads can be selected per class.
template <class T> class ConstructorReverseForwTag {};

namespace custom_derivatives {
#ifdef __CUDACC__
template <typename T>
Expand Down Expand Up @@ -329,6 +331,16 @@ using std::pow_pullback;
using std::pow_pushforward;
using std::sin_pushforward;
using std::sqrt_pushforward;

namespace class_functions {
/// Pullback of constructing a `ValueAndPushforward<T, U>` from `rhs`.
///
/// Accumulates the `pushforward` member of the adjoint of the constructed
/// object (`d_lhs`) into the `pushforward` member of the adjoint of the
/// source object (`d_rhs`). `lhs` and `rhs` are not read.
template <typename T, typename U>
void constructor_pullback(ValueAndPushforward<T, U>* lhs,
                          ValueAndPushforward<T, U> rhs,
                          ValueAndPushforward<T, U>* d_lhs,
                          ValueAndPushforward<T, U>* d_rhs) {
  auto& sourceAdjoint = d_rhs->pushforward;
  const auto& constructedAdjoint = d_lhs->pushforward;
  sourceAdjoint += constructedAdjoint;
}
} // namespace class_functions
} // namespace custom_derivatives
} // namespace clad

Expand Down
487 changes: 433 additions & 54 deletions include/clad/Differentiator/KokkosBuiltins.h

Large diffs are not rendered by default.

30 changes: 30 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "clang/Sema/Sema.h"

#include <array>
#include <limits>
#include <memory>
#include <stack>
#include <unordered_map>
Expand Down Expand Up @@ -689,6 +690,35 @@ namespace clad {
}

void PopSwitchStmtInfo() { m_SwitchStmtsData.pop_back(); }

/// Bookkeeping for a constructor pullback call whose `this` arguments are
/// filled in after the call expression itself has been built.
struct ConstructorPullbackCallInfo {
  /// The recorded pullback call; null when nothing has been recorded.
  clang::CallExpr* pullbackCE = nullptr;
  /// Position, in the argument list of `pullbackCE`, of the adjoint of the
  /// object under construction.
  size_t thisAdjointArgIdx = std::numeric_limits<size_t>::max();

  ConstructorPullbackCallInfo() = default;
  ConstructorPullbackCallInfo(clang::CallExpr* pPullbackCE,
                              size_t pThisAdjointArgIdx)
      : pullbackCE(pPullbackCE), thisAdjointArgIdx(pThisAdjointArgIdx) {}

  /// Replaces argument 0 of the call with `thisE` and argument
  /// `thisAdjointArgIdx` with `dThisE`.
  void updateThisParmArgs(clang::Expr* thisE, clang::Expr* dThisE) const;

  /// True when no pullback call has been recorded.
  bool empty() const { return pullbackCE == nullptr; }
};

/// Records `pullbackCE` together with the index of its `this`-adjoint
/// argument, overwriting whatever was recorded before.
void setConstructorPullbackCallInfo(clang::CallExpr* pullbackCE,
                                    size_t thisAdjointArgIdx) {
  m_ConstructorPullbackCallInfo =
      ConstructorPullbackCallInfo(pullbackCE, thisAdjointArgIdx);
}

/// Returns a copy of the currently recorded constructor pullback call info.
/// The copy is `empty()` when nothing has been recorded since the last reset.
/// Const-qualified because it only reads the stored state.
ConstructorPullbackCallInfo getConstructorPullbackCallInfo() const {
  return m_ConstructorPullbackCallInfo;
}

/// Discards the recorded constructor pullback call info, returning it to the
/// default-constructed (empty) state.
void resetConstructorPullbackCallInfo() {
  m_ConstructorPullbackCallInfo = {};
}

private:
/// Pullback call most recently stored by setConstructorPullbackCallInfo();
/// default-constructed (empty) after resetConstructorPullbackCallInfo().
ConstructorPullbackCallInfo m_ConstructorPullbackCallInfo;
/// While true, VisitCXXConstructExpr stores the custom constructor pullback
/// call it builds into m_ConstructorPullbackCallInfo and then clears this
/// flag.
bool m_TrackConstructorPullbackInfo = false;
};
} // end namespace clad

Expand Down
22 changes: 22 additions & 0 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,28 @@ void operator_subscript_pullback(::std::vector<T>* vec,
(*d_vec)[idx] += d_y;
}

/// Forward-pass replacement for the `std::vector<T>(count, val, alloc)`
/// constructor.
///
/// Returns the constructed vector (`count` copies of `val`) together with its
/// adjoint, a vector of `count` zeros. The allocator arguments and the
/// adjoint inputs `d_count`, `d_val` and `d_alloc` are not used.
///
/// Both vectors are built directly inside the returned object, so they are
/// not copied a second time on return.
template <typename T, typename S, typename U>
::clad::ValueAndAdjoint<::std::vector<T>, ::std::vector<T>>
constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::vector<T>>,
                         S count, U val,
                         typename ::std::vector<T>::allocator_type alloc,
                         S d_count, U d_val,
                         typename ::std::vector<T>::allocator_type d_alloc) {
  return {::std::vector<T>(count, val), ::std::vector<T>(count, 0)};
}

/// Pullback of the `std::vector<T>(count, val, alloc)` constructor.
///
/// Every element of the constructed vector is a copy of `val`, so the adjoint
/// of `val` receives the sum of the adjoints of the first `count` elements of
/// `*d_v`. The adjoint vector is cleared afterwards. `v`, `val`, `alloc`,
/// `d_count` and `d_alloc` are not touched.
///
/// The loop index uses the vector's own `size_type` rather than `unsigned`,
/// which avoids a mixed-width comparison against `count` and an index that
/// can never reach counts above `UINT_MAX`. The trip count is additionally
/// capped at `d_v->size()` so the adjoint vector is never read out of bounds.
template <typename T, typename S, typename U>
void constructor_pullback(::std::vector<T>* v, S count, U val,
                          typename ::std::vector<T>::allocator_type alloc,
                          ::std::vector<T>* d_v, S* d_count, U* d_val,
                          typename ::std::vector<T>::allocator_type* d_alloc) {
  using size_type = typename ::std::vector<T>::size_type;
  const size_type requested = static_cast<size_type>(count);
  const size_type available = d_v->size();
  const size_type n = requested < available ? requested : available;
  for (size_type i = 0; i < n; ++i)
    *d_val += (*d_v)[i];
  d_v->clear();
}

} // namespace class_functions
} // namespace custom_derivatives
} // namespace clad
Expand Down
7 changes: 7 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,12 @@ namespace clad {
/// Returns type clad::ConstructorPushforwardTag<T>
clang::QualType GetCladConstructorPushforwardTagOfType(clang::QualType T);

/// Returns clad::ConstructorReverseForwTag template declaration.
clang::TemplateDecl* GetCladConstructorReverseForwTag();

/// Returns type clad::ConstructorReverseForwTag<T>
clang::QualType GetCladConstructorReverseForwTagOfType(clang::QualType T);

public:
/// Rebuild a sequence of nested namespaces ending with DC.
clang::NamespaceDecl* RebuildEnclosingNamespaces(clang::DeclContext* DC);
Expand Down Expand Up @@ -661,6 +667,7 @@ namespace clad {

private:
/// Declaration of the clad::ConstructorPushforwardTag template. Starts out
/// null; presumably filled in by GetCladConstructorPushforwardTag() -- confirm
/// against its definition.
clang::TemplateDecl* m_CladConstructorPushforwardTag = nullptr;
/// Declaration of the clad::ConstructorReverseForwTag template. Starts out
/// null; presumably filled in by GetCladConstructorReverseForwTag() -- confirm
/// against its definition.
clang::TemplateDecl* m_CladConstructorReverseForwTag = nullptr;
};
} // end namespace clad

Expand Down
1 change: 0 additions & 1 deletion lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,6 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
const std::string& Name, llvm::SmallVectorImpl<Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) {

CXXScopeSpec SS;
LookupResult R = LookupCustomDerivativeOrNumericalDiff(
Name, originalFnDC, SS, forCustomDerv, namespaceShouldExist);
Expand Down
187 changes: 176 additions & 11 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
#include "clang/Sema/Sema.h"
#include "clang/Sema/SemaInternal.h"
#include "clang/Sema/Template.h"
#include <clang/AST/DeclCXX.h>
#include <clang/AST/ExprCXX.h>
#include <clang/AST/OperationKinds.h>
#include <clang/Sema/Ownership.h>

#include "llvm/Support/SaveAndRestore.h"

Expand Down Expand Up @@ -2763,6 +2767,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (isPointerType && VD->getInit() && isa<CXXNewExpr>(VD->getInit()))
isInitializedByNewExpr = true;

ConstructorPullbackCallInfo constructorPullbackInfo;

// VDDerivedInit now serves two purposes -- as the initial derivative value
// or the size of the derivative array -- depending on the primal type.
if (const auto* AT = dyn_cast<ArrayType>(VDType)) {
Expand Down Expand Up @@ -2819,6 +2825,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VDDerivedInit = initDiff.getForwSweepExpr_dx();
}

if (VDType->isStructureOrClassType()) {
m_TrackConstructorPullbackInfo = true;
initDiff = Visit(VD->getInit());
m_TrackConstructorPullbackInfo = false;
constructorPullbackInfo = getConstructorPullbackCallInfo();
resetConstructorPullbackCallInfo();
if (initDiff.getForwSweepExpr_dx())
VDDerivedInit = initDiff.getForwSweepExpr_dx();
}

// FIXME: Remove the special cases introduced by `specialThisDiffCase`
// once reverse mode supports pointers. `specialThisDiffCase` is only
// required for correctly differentiating the following code:
Expand Down Expand Up @@ -2880,9 +2896,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

if (VD->getInit()) {
if (isa<CXXConstructExpr>(VD->getInit()))
initDiff = Visit(VD->getInit());
else
if (VDType->isStructureOrClassType()) {
if (!initDiff.getExpr())
initDiff = Visit(VD->getInit());
} else
initDiff = Visit(VD->getInit(), derivedE);
}

Expand Down Expand Up @@ -2994,6 +3011,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VDType != VDClone->getType()))
m_DeclReplacements[VD] = VDClone;

if (!constructorPullbackInfo.empty()) {
Expr* thisE =
BuildOp(UnaryOperatorKind::UO_AddrOf, BuildDeclRef(VDClone));
Expr* dThisE =
BuildOp(UnaryOperatorKind::UO_AddrOf, BuildDeclRef(VDDerived));
constructorPullbackInfo.updateThisParmArgs(thisE, dThisE);
}
return DeclDiff<VarDecl>(VDClone, VDDerived);
}

Expand Down Expand Up @@ -4111,31 +4135,151 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return {nullptr, nullptr};
}

// FIXME: Add support for differentiating calls to constructors.
// We currently assume that constructor arguments are non-differentiable.
StmtDiff
ReverseModeVisitor::VisitCXXConstructExpr(const CXXConstructExpr* CE) {
llvm::SmallVector<Expr*, 4> clonedArgs;

llvm::SmallVector<Expr*, 4> primalArgs;
llvm::SmallVector<Expr*, 4> adjointArgs;
llvm::SmallVector<Expr*, 4> reverseForwAdjointArgs;
// It is used to store '_r0' temporary gradient variables that are used for
// differentiating non-reference args.
llvm::SmallVector<Stmt*, 4> prePullbackCallStmts;

// Insertion point is required because we need to insert pullback call
// before the statements inserted by 'Visit(arg, ...)' calls for arguments.
std::size_t insertionPoint = getCurrentBlock(direction::reverse).size();

// FIXME: Restore arguments passed as non-const reference.
for (const auto* arg : CE->arguments()) {
auto argDiff = Visit(arg, dfdx());
clonedArgs.push_back(argDiff.getExpr());
QualType ArgTy = arg->getType();
StmtDiff argDiff{};
Expr* adjointArg = nullptr;
if (utils::IsReferenceOrPointerArg(arg->IgnoreParenImpCasts())) {
argDiff = Visit(arg);
adjointArg = argDiff.getExpr_dx();
} else {
// non-reference arguments are differentiated as follows:
//
// primal code:
// ```
// SomeClass c(u, ...);
// ```
//
// Derivative code:
// ```
// // forward pass
// ...
// // reverse pass
// double _r0 = 0;
// constructor_pullback(&c, u, ..., &_d_c, &_r0, ...);
// _d_u += _r0;
QualType dArgTy = getNonConstType(ArgTy, m_Context, m_Sema);
VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy));
prePullbackCallStmts.push_back(BuildDeclStmt(dArgDecl));
adjointArg = BuildDeclRef(dArgDecl);
argDiff = Visit(arg, BuildDeclRef(dArgDecl));
}

if (utils::isArrayOrPointerType(ArgTy)) {
reverseForwAdjointArgs.push_back(adjointArg);
adjointArgs.push_back(adjointArg);
} else {
if (utils::IsReferenceOrPointerArg(arg->IgnoreParenImpCasts()))
reverseForwAdjointArgs.push_back(adjointArg);
else
reverseForwAdjointArgs.push_back(getZeroInit(ArgTy));
adjointArgs.push_back(BuildOp(UnaryOperatorKind::UO_AddrOf, adjointArg,
m_DiffReq->getLocation()));
}
primalArgs.push_back(argDiff.getExpr());
}

// Try to create a pullback constructor call
llvm::SmallVector<Expr*, 4> pullbackArgs;
QualType recordType =
m_Context.getRecordType(CE->getConstructor()->getParent());
QualType recordPointerType = m_Context.getPointerType(recordType);
// thisE = object being created by this constructor call.
// dThisE = adjoint of the object being created by this constructor call.
//
// We cannot fill these args yet because these objects have not yet been
// created. The caller which triggers 'VisitCXXConstructExpr' is
// responsible for updating these args.
Expr* thisE = getZeroInit(recordPointerType);
Expr* dThisE = getZeroInit(recordPointerType);

pullbackArgs.push_back(thisE);
pullbackArgs.append(primalArgs.begin(), primalArgs.end());
pullbackArgs.push_back(dThisE);
pullbackArgs.append(adjointArgs.begin(), adjointArgs.end());

Stmts& curRevBlock = getCurrentBlock(direction::reverse);
Stmts::iterator it = std::begin(curRevBlock) + insertionPoint;
curRevBlock.insert(it, prePullbackCallStmts.begin(),
prePullbackCallStmts.end());
it += prePullbackCallStmts.size();
std::string customPullbackName = "constructor_pullback";
if (Expr* customPullbackCall =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullbackName, pullbackArgs, getCurrentScope(),
const_cast<DeclContext*>(
CE->getConstructor()->getDeclContext()))) {
curRevBlock.insert(it, customPullbackCall);
if (m_TrackConstructorPullbackInfo) {
setConstructorPullbackCallInfo(llvm::cast<CallExpr>(customPullbackCall),
primalArgs.size() + 1);
m_TrackConstructorPullbackInfo = false;
}
}
// FIXME: If no compatible custom constructor pullback is found then try
// to automatically differentiate the constructor.

// Create the constructor call in the forward pass, or a call to a custom
// 'constructor_reverse_forw' function if one is available.

// This works as follows:
//
// primal code:
// ```
// SomeClass c(u, v);
// ```
//
// adjoint code:
// ```
// // forward-pass
// clad::ValueAndAdjoint<SomeClass, SomeClass> _t0 =
// constructor_reverse_forw(clad::ConstructorReverseForwTag<SomeClass>{}, u, v,
// _d_u, _d_v);
// SomeClass _d_c = _t0.adjoint;
// SomeClass c = _t0.value;
// ```
if (Expr* customReverseForwFnCall = BuildCallToCustomForwPassFn(
CE->getConstructor(), primalArgs, reverseForwAdjointArgs,
/*baseExpr=*/nullptr)) {
Expr* callRes = StoreAndRef(customReverseForwFnCall);
Expr* val =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value");
Expr* adjoint =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint");
return {val, nullptr, adjoint};
}

Expr* clonedArgsE = nullptr;

if (CE->getNumArgs() != 1) {
if (CE->isListInitialization()) {
clonedArgsE = m_Sema.ActOnInitList(noLoc, clonedArgs, noLoc).get();
clonedArgsE = m_Sema.ActOnInitList(noLoc, primalArgs, noLoc).get();
} else {
if (CE->getNumArgs() == 0) {
// ParenList is empty -- default initialisation.
// Passing empty parenList here will silently cause 'most vexing
// parse' issue.
return StmtDiff();
}
clonedArgsE = m_Sema.ActOnParenListExpr(noLoc, noLoc, clonedArgs).get();
clonedArgsE = m_Sema.ActOnParenListExpr(noLoc, noLoc, primalArgs).get();
}
} else {
clonedArgsE = clonedArgs[0];
clonedArgsE = primalArgs[0];
}
// `CXXConstructExpr` node will be created automatically by passing these
// initialiser to higher level `ActOn`/`Build` Sema functions.
Expand Down Expand Up @@ -4384,6 +4528,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_DiffReq->getLocation());
args.push_back(baseExpr);
}
if (auto CD = llvm::dyn_cast<CXXConstructorDecl>(FD)) {
const RecordDecl* RD = CD->getParent();
QualType constructorReverseForwTagT =
GetCladConstructorReverseForwTagOfType(m_Context.getRecordType(RD));
Expr* constructorReverseForwTagArg =
m_Sema
.BuildCXXTypeConstructExpr(
m_Context.getTrivialTypeSourceInfo(
constructorReverseForwTagT, utils::GetValidSLoc(m_Sema)),
utils::GetValidSLoc(m_Sema), MultiExprArg{},
utils::GetValidSLoc(m_Sema),
/*ListInitialization=*/false)
.get();
args.push_back(constructorReverseForwTagArg);
}
args.append(primalArgs.begin(), primalArgs.end());
args.append(derivedArgs.begin(), derivedArgs.end());
Expr* customForwPassCE =
Expand All @@ -4392,4 +4551,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
const_cast<DeclContext*>(FD->getDeclContext()));
return customForwPassCE;
}

/// Patches the recorded pullback call in place: argument 0 becomes `thisE`
/// and the argument at `thisAdjointArgIdx` becomes `dThisE`.
///
/// Dereferences `pullbackCE` unconditionally, so it must only be called on
/// info that is not `empty()`.
void ReverseModeVisitor::ConstructorPullbackCallInfo::updateThisParmArgs(
    Expr* thisE, Expr* dThisE) const {
  auto* call = pullbackCE;
  const size_t adjointSlot = thisAdjointArgIdx;
  call->setArg(0, thisE);
  call->setArg(adjointSlot, dThisE);
}
} // end namespace clad
Loading

0 comments on commit f7b3caa

Please sign in to comment.