Skip to content

Commit

Permalink
Simplify array TBR analysis.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Sep 8, 2023
1 parent 914023e commit 1b3a74d
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 183 deletions.
3 changes: 3 additions & 0 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ namespace clad {
/// Returns true if `QT` is Array or Pointer Type, otherwise returns false.
bool isArrayOrPointerType(const clang::QualType QT);

/// Returns true if `T` is auto or auto* type, otherwise returns false.
bool IsAutoOrAutoPtrType(const clang::Type* T);

clang::DeclarationNameInfo BuildDeclarationNameInfo(clang::Sema& S,
llvm::StringRef name);

Expand Down
31 changes: 21 additions & 10 deletions include/clad/Differentiator/TBRAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,20 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// Used to provide a hash function for an unordered_map with llvm::APInt
/// type keys.
struct APIntHash {
size_t operator()(const llvm::APInt& apint) const {
return llvm::hash_value(apint);
size_t operator()(const llvm::APInt& x) const {
return llvm::hash_value(x);
}
};

static bool eqAPInt(const llvm::APInt& x, const llvm::APInt& y) {
if (x.getBitWidth() != y.getBitWidth())
return false;
return x == y;
}

struct APIntComp {
bool operator()(const llvm::APInt& x, const llvm::APInt& y) const {
return eqAPInt(x, y);
}
};

Expand Down Expand Up @@ -91,7 +103,8 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {

struct VarData;
using ObjMap = std::unordered_map<const clang::FieldDecl*, VarData*>;
using ArrMap = std::unordered_map<const llvm::APInt, VarData*, APIntHash>;
using ArrMap =
std::unordered_map<const llvm::APInt, VarData*, APIntHash, APIntComp>;

struct VarData {
enum VarDataType { UNDEFINED, FUND_TYPE, OBJ_TYPE, ARR_TYPE, REF_TYPE };
Expand All @@ -114,15 +127,16 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
VarData(const VarData&&) = delete;
VarData& operator=(const VarData&&) = delete;

/// Builds a VarData object (and its children) based on the provided type.
VarData(const QualType QT);

~VarData() {
if (type == OBJ_TYPE) {
for (auto& pair : *val.objData) {
for (auto& pair : *val.objData)
delete pair.second;
}
} else if (type == ARR_TYPE) {
for (auto& pair : *val.arrData) {
for (auto& pair : *val.arrData)
delete pair.second;
}
}
}
/// Recursively sets all the leaves' bools to isReq.
Expand Down Expand Up @@ -165,9 +179,6 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
bool addNonConstIdx = false);
/// Given an Expr* returns its corresponding VarData.
VarData* getExprVarData(const clang::Expr* E, bool addNonConstIdx = false);
/// Adds the field FD to objData.
void addField(std::unordered_map<const clang::FieldDecl*, VarData*>* objData,
const FieldDecl* FD);
/// Whenever an array element with a non-constant index is set to required
/// this function is used to set to required all the array elements that
/// could match that element (e.g. set 'a[1].y' and 'a[6].y' to required when
Expand Down
12 changes: 12 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -590,5 +590,17 @@ namespace clad {
Finder finder;
return finder.Find(E);
}

bool IsAutoOrAutoPtrType(const clang::Type* T) {
if (isa<clang::AutoType>(T))
return true;

if (const auto pointerType = dyn_cast<clang::PointerType>(T)) {
return IsAutoOrAutoPtrType(
pointerType->getPointeeType().getTypePtrOrNull());
}

return false;
}
} // namespace utils
} // namespace clad
Loading

0 comments on commit 1b3a74d

Please sign in to comment.