Skip to content

Commit

Permalink
Store VarData by value instead of pointers.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Oct 31, 2023
1 parent 8e32d5b commit bbd7280
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 127 deletions.
63 changes: 33 additions & 30 deletions include/clad/Differentiator/TBRAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// whole structures instead (just one VarData for the whole array).

struct VarData;
using ObjMap = std::unordered_map<const clang::FieldDecl*, VarData*>;
using ObjMap = std::unordered_map<const clang::FieldDecl*, VarData>;
using ArrMap =
std::unordered_map<const llvm::APInt, VarData*, APIntHash, APIntComp>;
std::unordered_map<const llvm::APInt, VarData, APIntHash, APIntComp>;

struct VarData {
enum VarDataType { UNDEFINED, FUND_TYPE, OBJ_TYPE, ARR_TYPE, REF_TYPE };
Expand All @@ -119,57 +119,57 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
Expr* refData;
VarDataValue() : fundData(false) {}
};
VarDataType type;
VarDataType type = UNDEFINED;
VarDataValue val;

VarData() = default;
VarData(const VarData&) = delete;
VarData& operator=(const VarData&) = delete;
VarData(const VarData&&) = delete;
VarData& operator=(const VarData&&) = delete;

/// Builds a VarData object (and its children) based on the provided type.
VarData(const QualType QT);

~VarData() {
if (type == OBJ_TYPE)
/// Erases all children VarData's of this VarData.
void erase() {
if (type == OBJ_TYPE) {
for (auto& pair : *val.objData)
delete pair.second;
else if (type == ARR_TYPE)
pair.second.erase();
delete val.objData;
} else if (type == ARR_TYPE) {
for (auto& pair : *val.arrData)
delete pair.second;
pair.second.erase();
delete val.arrData;
}
}
};
/// Recursively sets all the leaves' bools to isReq.
void setIsRequired(VarData* varData, bool isReq = true);
void setIsRequired(VarData& varData, bool isReq = true);
/// Whenever an array element with a non-constant index is set to required
/// this function is used to set to required all the array elements that
/// could match that element (e.g. set 'a[1].y' and 'a[6].y' to required
/// when 'a[k].y' is set to required). Takes unwrapped sequence of
/// indices/members of the expression being overlaid and the index of of the
/// current index/member.
void overlay(VarData* targetData,
void overlay(VarData& targetData,
llvm::SmallVector<IdxOrMember, 2>& IdxAndMemberSequence,
size_t i);
/// Returns true if there is at least one required to store node among
/// child nodes.
bool findReq(const VarData* varData);
bool findReq(const VarData& varData);
/// Used to merge together VarData for one variable from two branches
/// (e.g. after an if-else statements). Look at the Control Flow section for
/// more information.
void merge(VarData* targetData, VarData* mergeData);
void merge(VarData& targetData, VarData& mergeData);
/// Used to recursively copy VarData when separating into different branches
/// (e.g. when entering an if-else statements). Look at the Control Flow
/// section for more information.
VarData* copy(VarData* copyData);
VarData copy(VarData& copyData);

clang::CFGBlock* getCFGBlockByID(unsigned ID);

/// Given a MemberExpr*/ArraySubscriptExpr* return a pointer to its
/// corresponding VarData. If the given element of an array does not have a
/// VarData* yet it will be added automatically. If addNonConstIdx==false this
/// will return the last VarData* before the non-constant index
/// (e.g. for 'x.arr[k+1].y' the return value will be the VarData* of x.arr).
/// VarData yet it will be added automatically. If addNonConstIdx==false this
/// will return the last VarData before the non-constant index
/// (e.g. for 'x.arr[k+1].y' the return value will be the VarData of x.arr).
/// Otherwise, non-const indices will be represented as index -1.
VarData* getMemberVarData(const clang::MemberExpr* ME,
bool addNonConstIdx = false);
Expand Down Expand Up @@ -197,23 +197,28 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// Note: 'this' pointer does not have a declaration so nullptr is used as
/// its key instead.
struct VarsData {
std::unordered_map<const clang::VarDecl*, VarData*> data =
std::unordered_map<const clang::VarDecl*, VarData*>();
std::unordered_map<const clang::VarDecl*, VarData> data =
std::unordered_map<const clang::VarDecl*, VarData>();
VarsData* prev = nullptr;

VarsData() {}
VarsData(VarsData& other) : data(other.data), prev(other.prev) {}

~VarsData() {
for (auto& pair : data)
pair.second.erase();
}

using iterator =
std::unordered_map<const clang::VarDecl*, VarData*>::iterator;
std::unordered_map<const clang::VarDecl*, VarData>::iterator;
iterator begin() { return data.begin(); }
iterator end() { return data.end(); }
VarData*& operator[](const clang::VarDecl* VD) { return data[VD]; }
VarData& operator[](const clang::VarDecl* VD) { return data[VD]; }
iterator find(const clang::VarDecl* VD) { return data.find(VD); }
void emplace(const clang::VarDecl* VD, VarData* varsData) {
void emplace(const clang::VarDecl* VD, VarData varsData) {
data.emplace(VD, varsData);
}
void emplace(std::pair<const clang::VarDecl*, VarData*> pair) {
void emplace(std::pair<const clang::VarDecl*, VarData> pair) {
data.emplace(pair);
}
void clear() {
Expand All @@ -223,12 +228,12 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {


/// Collects the data from 'varsData' and its predecessors until
/// 'limit' into one VarsData ('limit' VarsData is not included).
/// 'limit' into one map ('limit' VarsData is not included).
/// If 'limit' is 'nullptr', data is collected starting with
/// the entry CFG block.
/// Note: the returned VarsData contains original data from
/// the predecessors (NOT copies). It should not be modified.
std::unique_ptr<VarsData>
std::unordered_map<const clang::VarDecl*, VarData*>
collectDataFromPredecessors(VarsData* varsData, VarsData* limit = nullptr);

/// Finds the lowest common ancestor of two VarsData
Expand Down Expand Up @@ -318,8 +323,6 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
~TBRAnalyzer() {
for (auto varsData : blockData) {
if (varsData) {
for (auto pair : *varsData)
delete pair.second;
delete varsData;
}
}
Expand Down
Loading

0 comments on commit bbd7280

Please sign in to comment.