Skip to content

Commit

Permalink
Merge branch 'master' into cuda-compilation-support
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Sep 5, 2024
2 parents d5e9a98 + 685bcbf commit c42fd53
Show file tree
Hide file tree
Showing 62 changed files with 1,765 additions and 1,132 deletions.
12 changes: 12 additions & 0 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ template <typename T, typename U> struct ValueAndAdjoint {
/// class for which constructor pushforward is defined.
template <class T> class ConstructorPushforwardTag {};

template <class T> class ConstructorReverseForwTag {};

namespace custom_derivatives {
#ifdef __CUDACC__
template <typename T>
Expand Down Expand Up @@ -329,6 +331,16 @@ using std::pow_pullback;
using std::pow_pushforward;
using std::sin_pushforward;
using std::sqrt_pushforward;

namespace class_functions {
template <typename T, typename U>
void constructor_pullback(ValueAndPushforward<T, U>* lhs,
ValueAndPushforward<T, U> rhs,
ValueAndPushforward<T, U>* d_lhs,
ValueAndPushforward<T, U>* d_rhs) {
d_rhs->pushforward += d_lhs->pushforward;
}
} // namespace class_functions
} // namespace custom_derivatives
} // namespace clad

Expand Down
487 changes: 433 additions & 54 deletions include/clad/Differentiator/KokkosBuiltins.h

Large diffs are not rendered by default.

30 changes: 30 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "clang/Sema/Sema.h"

#include <array>
#include <limits>
#include <memory>
#include <stack>
#include <unordered_map>
Expand Down Expand Up @@ -689,6 +690,35 @@ namespace clad {
}

void PopSwitchStmtInfo() { m_SwitchStmtsData.pop_back(); }

struct ConstructorPullbackCallInfo {
clang::CallExpr* pullbackCE = nullptr;
size_t thisAdjointArgIdx = std::numeric_limits<size_t>::max();
void updateThisParmArgs(clang::Expr* thisE, clang::Expr* dThisE) const;
ConstructorPullbackCallInfo() = default;
ConstructorPullbackCallInfo(clang::CallExpr* pPullbackCE,
size_t pThisAdjointArgIdx)
: pullbackCE(pPullbackCE), thisAdjointArgIdx(pThisAdjointArgIdx) {}

bool empty() const { return !pullbackCE; }
};

void setConstructorPullbackCallInfo(clang::CallExpr* pullbackCE,
size_t thisAdjointArgIdx) {
m_ConstructorPullbackCallInfo = {pullbackCE, thisAdjointArgIdx};
}

ConstructorPullbackCallInfo getConstructorPullbackCallInfo() {
return m_ConstructorPullbackCallInfo;
}

void resetConstructorPullbackCallInfo() {
m_ConstructorPullbackCallInfo = ConstructorPullbackCallInfo{};
}

private:
ConstructorPullbackCallInfo m_ConstructorPullbackCallInfo;
bool m_TrackConstructorPullbackInfo = false;
};
} // end namespace clad

Expand Down
22 changes: 22 additions & 0 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,28 @@ void operator_subscript_pullback(::std::vector<T>* vec,
(*d_vec)[idx] += d_y;
}

template <typename T, typename S, typename U>
::clad::ValueAndAdjoint<::std::vector<T>, ::std::vector<T>>
constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::vector<T>>,
S count, U val,
typename ::std::vector<T>::allocator_type alloc,
S d_count, U d_val,
typename ::std::vector<T>::allocator_type d_alloc) {
::std::vector<T> v(count, val);
::std::vector<T> d_v(count, 0);
return {v, d_v};
}

template <typename T, typename S, typename U>
void constructor_pullback(::std::vector<T>* v, S count, U val,
typename ::std::vector<T>::allocator_type alloc,
::std::vector<T>* d_v, S* d_count, U* d_val,
typename ::std::vector<T>::allocator_type* d_alloc) {
for (unsigned i = 0; i < count; ++i)
*d_val += (*d_v)[i];
d_v->clear();
}

} // namespace class_functions
} // namespace custom_derivatives
} // namespace clad
Expand Down
7 changes: 7 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,12 @@ namespace clad {
/// Returns type clad::Identify<T>
clang::QualType GetCladConstructorPushforwardTagOfType(clang::QualType T);

/// Returns clad::ConstructorReverseForwTag template declaration.
clang::TemplateDecl* GetCladConstructorReverseForwTag();

/// Returns type clad::ConstructorReverseForwTag<T>
clang::QualType GetCladConstructorReverseForwTagOfType(clang::QualType T);

public:
/// Rebuild a sequence of nested namespaces ending with DC.
clang::NamespaceDecl* RebuildEnclosingNamespaces(clang::DeclContext* DC);
Expand Down Expand Up @@ -661,6 +667,7 @@ namespace clad {

private:
clang::TemplateDecl* m_CladConstructorPushforwardTag = nullptr;
clang::TemplateDecl* m_CladConstructorReverseForwTag = nullptr;
};
} // end namespace clad

Expand Down
9 changes: 6 additions & 3 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1070,8 +1070,9 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
// If DRE is of type pointer, then the derivative is a null pointer.
if (clonedDRE->getType()->isPointerType())
return StmtDiff(clonedDRE, nullptr);
QualType literalTy = utils::GetValueType(clonedDRE->getType());
return StmtDiff(clonedDRE, ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, /*val=*/0));
literalTy, m_Context, /*val=*/0));
}

StmtDiff BaseForwardModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) {
Expand Down Expand Up @@ -1374,8 +1375,10 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
} else if (opKind == UnaryOperatorKind::UO_Deref) {
if (Expr* dx = diff.getExpr_dx())
return StmtDiff(op, BuildOp(opKind, dx));
return StmtDiff(op, ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, /*val=*/0));
QualType literalTy =
utils::GetValueType(UnOp->getSubExpr()->getType()->getPointeeType());
return StmtDiff(
op, ConstantFolder::synthesizeLiteral(literalTy, m_Context, /*val=*/0));
} else if (opKind == UnaryOperatorKind::UO_AddrOf) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
} else if (opKind == UnaryOperatorKind::UO_LNot) {
Expand Down
4 changes: 4 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,10 @@ namespace clad {
else if (T->isArrayType())
valueType =
T->getPointeeOrArrayElementType()->getCanonicalTypeInternal();
else if (T->isEnumeralType()) {
if (const auto* ET = dyn_cast<EnumType>(T))
valueType = ET->getDecl()->getIntegerType();
}
valueType.removeLocalConst();
return valueType;
}
Expand Down
32 changes: 27 additions & 5 deletions lib/Differentiator/ConstantFolder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ namespace clad {
return FloatingLiteral::Create(C, val, /*isexact*/true, QT, noLoc);
}

static Expr* synthesizeLiteral(QualType QT, ASTContext& C, bool val) {
assert(QT->isBooleanType() && "Not a boolean type.");
SourceLocation noLoc;
return new (C) CXXBoolLiteralExpr(val, QT, noLoc);
}

static Expr* synthesizeLiteral(QualType QT, ASTContext& C) {
assert(QT->isPointerType() && "Not a pointer type.");
SourceLocation noLoc;
return new (C) CXXNullPtrLiteralExpr(QT, noLoc);
}

Expr* ConstantFolder::trivialFold(Expr* E) {
Expr::EvalResult Result;
if (E->EvaluateAsRValue(Result, m_Context)) {
Expand Down Expand Up @@ -126,18 +138,28 @@ namespace clad {

Expr* ConstantFolder::synthesizeLiteral(QualType QT, ASTContext& C,
uint64_t val) {
//SourceLocation noLoc;
// SourceLocation noLoc;
Expr* Result = 0;
if (QT->isIntegralType(C)) {
QT = QT.getCanonicalType();
if (QT->isPointerType()) {
Result = clad::synthesizeLiteral(QT, C);
} else if (QT->isBooleanType()) {
Result = clad::synthesizeLiteral(QT, C, (bool)val);
} else if (QT->isIntegralType(C)) {
if (QT->isAnyCharacterType())
QT = C.IntTy;
llvm::APInt APVal(C.getIntWidth(QT), val,
QT->isSignedIntegerOrEnumerationType());
Result = clad::synthesizeLiteral(QT, C, APVal);
}
else {
} else if (QT->isRealFloatingType()) {
llvm::APFloat APVal(C.getFloatTypeSemantics(QT), val);
Result = clad::synthesizeLiteral(QT, C, APVal);
} else {
// FIXME: Handle other types, like Complex, Structs, typedefs, etc.
// typecasting may be needed right now
Result = ConstantFolder::synthesizeLiteral(C.IntTy, C, val);
}
assert(Result && "Must not be zero.");
assert(Result && "Unsupported type for constant folding.");
return Result;
}
} // end namespace clad
1 change: 0 additions & 1 deletion lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,6 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
const std::string& Name, llvm::SmallVectorImpl<Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) {

CXXScopeSpec SS;
LookupResult R = LookupCustomDerivativeOrNumericalDiff(
Name, originalFnDC, SS, forCustomDerv, namespaceShouldExist);
Expand Down
Loading

0 comments on commit c42fd53

Please sign in to comment.