Skip to content

Commit

Permalink
Synthesize booleans separately from integers and cast type to int whe…
Browse files Browse the repository at this point in the history
…n necessary
  • Loading branch information
kchristin22 committed Aug 26, 2024
2 parents ac042f9 + 6f4b081 commit 69075f1
Show file tree
Hide file tree
Showing 20 changed files with 791 additions and 225 deletions.
5 changes: 5 additions & 0 deletions include/clad/Differentiator/Array.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ template <typename T> class array {

/// Returns the size of the underlying array
CUDA_HOST_DEVICE std::size_t size() const { return m_size; }
/// Iterator functions
CUDA_HOST_DEVICE T* begin() { return m_arr; }
CUDA_HOST_DEVICE const T* begin() const { return m_arr; }
CUDA_HOST_DEVICE T* end() { return m_arr + m_size; }
CUDA_HOST_DEVICE const T* end() const { return m_arr + m_size; }
/// Returns the ptr of the underlying array
CUDA_HOST_DEVICE PUREFUNC T* ptr() const { return m_arr; }
CUDA_HOST_DEVICE PUREFUNC T*& ptr_ref() { return m_arr; }
Expand Down
14 changes: 0 additions & 14 deletions include/clad/Differentiator/EstimationModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,6 @@ namespace clad {
/// Clear the variable estimate map so that we can start afresh.
void clearEstimationVariables() { m_EstimateVar.clear(); }

/// Helper to build a function call expression.
///
/// \param[in] funcName The name of the function to build the expression
/// for.
/// \param[in] nmspace The name of the namespace for the function,
/// currently does not support nested namespaces.
/// \param[in] callArgs A vector of \c clang::Expr of all the parameters
/// to the function call.
///
/// \return The function call expression that can be used to emit into
/// code.
clang::Expr* GetFunctionCall(std::string funcName, std::string nmspace,
llvm::SmallVectorImpl<clang::Expr*>& callArgs);

/// User overridden function to return the error expression of a
/// specific estimation model. The error expression is returned in the form
/// of a clang::Expr, the user may use BuildOp() to build the final
Expand Down
5 changes: 5 additions & 0 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ inline void resize_pushforward(const I& arg, View& v, const size_t n0,
::Kokkos::resize(arg, d_v, n0, n1, n2, n3, n4, n5, n6, n7);
}

/// Fence
template <typename S> void fence_pushforward(const S& s, const S& /*d_s*/) {
::Kokkos::fence(s);
}

/// Parallel for
template <class... PolicyParams, class FunctorType> // range policy
void parallel_for_pushforward(
Expand Down
20 changes: 13 additions & 7 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,14 @@ namespace clad {
bool isInsideLoop;
bool isFnScope;
bool needsUpdate;
clang::Expr* Placeholder;
DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult,
clang::VarDecl* pDeclaration, bool pIsConstant,
bool pIsInsideLoop, bool pIsFnScope,
bool pNeedsUpdate = false)
clang::VarDecl* pDeclaration, bool pIsInsideLoop,
bool pIsFnScope, bool pNeedsUpdate = false,
clang::Expr* pPlaceholder = nullptr)
: V(pV), Result(pResult), Declaration(pDeclaration),
isConstant(pIsConstant), isInsideLoop(pIsInsideLoop),
isFnScope(pIsFnScope), needsUpdate(pNeedsUpdate) {}
isInsideLoop(pIsInsideLoop), isFnScope(pIsFnScope),
needsUpdate(pNeedsUpdate), Placeholder(pPlaceholder) {}
void Finalize(clang::Expr* New);
};

Expand All @@ -297,7 +298,8 @@ namespace clad {
/// This is what DelayedGlobalStoreAndRef does. E is expected to be the
/// original (uncloned) expression.
DelayedStoreResult DelayedGlobalStoreAndRef(clang::Expr* E,
llvm::StringRef prefix = "_t");
llvm::StringRef prefix = "_t",
bool forceStore = false);

struct CladTapeResult {
ReverseModeVisitor& V;
Expand Down Expand Up @@ -400,12 +402,16 @@ namespace clad {
virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS);
StmtDiff VisitStmt(const clang::Stmt* S);
virtual StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp);
StmtDiff
VisitUnaryExprOrTypeTraitExpr(const clang::UnaryExprOrTypeTraitExpr* UE);
StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups* EWC);
/// Decl is not Stmt, so it cannot be visited directly.
StmtDiff VisitWhileStmt(const clang::WhileStmt* WS);
StmtDiff VisitDoStmt(const clang::DoStmt* DS);
StmtDiff VisitContinueStmt(const clang::ContinueStmt* CS);
StmtDiff VisitBreakStmt(const clang::BreakStmt* BS);
StmtDiff
VisitCXXStdInitializerListExpr(const clang::CXXStdInitializerListExpr* ILE);
StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE);
StmtDiff VisitCXXNewExpr(const clang::CXXNewExpr* CNE);
StmtDiff VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE);
Expand All @@ -417,7 +423,7 @@ namespace clad {
StmtDiff VisitCaseStmt(const clang::CaseStmt* CS);
StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS);
DeclDiff<clang::VarDecl> DifferentiateVarDecl(const clang::VarDecl* VD,
bool AddToBlock = true);
bool keepLocal = false);
StmtDiff VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff
Expand Down
15 changes: 15 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,21 @@ namespace clad {
/// Instantiate clad::tape<T> type.
clang::QualType GetCladTapeOfType(clang::QualType T);

/// Helper to build a function call expression.
///
/// \param[in] funcName The name of the function to build the expression
/// for.
/// \param[in] nmspace The name of the namespace for the function,
/// currently does not support nested namespaces.
/// \param[in] callArgs A vector of \c clang::Expr of all the parameters
/// to the function call.
///
/// \return The function call expression that can be used to emit into
/// code.
clang::Expr* GetFunctionCall(const std::string& funcName,
const std::string& nmspace,
llvm::SmallVectorImpl<clang::Expr*>& callArgs);

clang::DeclRefExpr* GetCladTapePushDRE();

clang::Stmt* GetCladZeroInit(llvm::MutableArrayRef<clang::Expr*> args);
Expand Down
16 changes: 13 additions & 3 deletions lib/Differentiator/ConstantFolder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ namespace clad {
return FloatingLiteral::Create(C, val, /*isexact*/true, QT, noLoc);
}

static Expr* synthesizeLiteral(QualType QT, ASTContext& C, bool val) {
assert(QT->isBooleanType() && "Not a boolean type.");
SourceLocation noLoc;
return CXXBoolLiteralExpr::Create(C, val, QT, noLoc);
}

Expr* ConstantFolder::trivialFold(Expr* E) {
Expr::EvalResult Result;
if (E->EvaluateAsRValue(Result, m_Context)) {
Expand Down Expand Up @@ -128,14 +134,18 @@ namespace clad {
uint64_t val) {
//SourceLocation noLoc;
Expr* Result = 0;
if (QT->isIntegralType(C)) {
if (QT->isBooleanType()) {
printf("synthesizing boolean literal\n");
Result = clad::synthesizeLiteral(QT, C, (bool)val);
} else if (QT->isIntegralType(C)) {
llvm::APInt APVal(C.getIntWidth(QT), val,
QT->isSignedIntegerOrEnumerationType());
Result = clad::synthesizeLiteral(QT, C, APVal);
}
else {
} else if (QT->isRealFloatingType()) {
llvm::APFloat APVal(C.getFloatTypeSemantics(QT), val);
Result = clad::synthesizeLiteral(QT, C, APVal);
} else {
Result = ConstantFolder::synthesizeLiteral(C.IntTy, C, 0);
}
assert(Result && "Must not be zero.");
return Result;
Expand Down
28 changes: 0 additions & 28 deletions lib/Differentiator/EstimationModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,6 @@ namespace clad {

FPErrorEstimationModel::~FPErrorEstimationModel() {}

Expr* FPErrorEstimationModel::GetFunctionCall(
std::string funcName, std::string nmspace,
llvm::SmallVectorImpl<Expr*>& callArgs) {
NamespaceDecl* NSD =
utils::LookupNSD(m_Sema, nmspace, /*shouldExist=*/true);
DeclContext* DC = NSD;
CXXScopeSpec SS;
SS.Extend(m_Context, NSD, noLoc, noLoc);

IdentifierInfo* II = &m_Context.Idents.get(funcName);
DeclarationName name(II);
DeclarationNameInfo DNI(name, noLoc);
LookupResult R(m_Sema, DNI, Sema::LookupOrdinaryName);

if (DC)
m_Sema.LookupQualifiedName(R, DC);
Expr* UnresolvedLookup = nullptr;
if (!R.empty())
UnresolvedLookup =
m_Sema.BuildDeclarationNameExpr(SS, R, /*ADL=*/false).get();
llvm::MutableArrayRef<Expr*> MARargs =
llvm::MutableArrayRef<Expr*>(callArgs);
SourceLocation Loc;
return m_Sema
.ActOnCallExpr(getCurrentScope(), UnresolvedLookup, Loc, MARargs, Loc)
.get();
}

Expr* TaylorApprox::AssignError(StmtDiff refExpr,
const std::string& varName) {
// Get the machine epsilon value.
Expand Down
Loading

0 comments on commit 69075f1

Please sign in to comment.