Skip to content

Commit

Permalink
Formatting changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Nov 6, 2023
1 parent b5ce4ac commit fe7310f
Show file tree
Hide file tree
Showing 5 changed files with 314 additions and 312 deletions.
8 changes: 3 additions & 5 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,11 @@ namespace clad {
/// Create new block.
Stmts& beginBlock(direction d = direction::forward) {
if (d == direction::forward)
m_Blocks.push_back({});
m_Blocks.emplace_back();
else if (d == direction::reverse)
m_Reverse.push_back({});
m_Reverse.emplace_back();
else
m_EssentialReverse.push_back({});
m_EssentialReverse.emplace_back();
return getCurrentBlock(d);
}
/// Remove the block from the stack, wrap it in CompoundStmt and return it.
Expand Down Expand Up @@ -616,8 +616,6 @@ namespace clad {

clang::QualType ComputeAdjointType(clang::QualType T);
clang::QualType ComputeParamType(clang::QualType T);

std::vector<clang::Expr*> GetInnermostReturnExpr(clang::Expr* E);
};
} // end namespace clad

Expand Down
59 changes: 36 additions & 23 deletions include/clad/Differentiator/TBRAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
};

/// Stores all the necessary information about one variable. Fundamental type
/// variables need only one bool. An object/array needs a separate VarData for
/// variables need only one bit. An object/array needs a separate VarData for
/// every its field/element. Reference type variables have their own type for
/// convenience reasons and just point to the corresponding VarData.
/// UNDEFINED is used whenever the type of a node cannot be determined.
Expand Down Expand Up @@ -111,32 +111,32 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
struct VarData {
enum VarDataType { UNDEFINED, FUND_TYPE, OBJ_TYPE, ARR_TYPE, REF_TYPE };
union VarDataValue {
bool fundData;
/// objData, arrData are stored as pointers for VarDataValue to take
bool m_FundData;
/// m_ObjData, m_ArrData are stored as pointers for VarDataValue to take
/// less space.
ObjMap* objData;
ArrMap* arrData;
Expr* refData;
VarDataValue() : fundData(false) {}
ObjMap* m_ObjData;
ArrMap* m_ArrData;
Expr* m_RefData;
VarDataValue() : m_FundData(false) {}
};
VarDataType type = UNDEFINED;
VarDataValue val;

VarData() = default;

/// Builds a VarData object (and its children) based on the provided type.
VarData(const QualType QT);
VarData(QualType QT);

/// Erases all children VarData's of this VarData.
void erase() {
if (type == OBJ_TYPE) {
for (auto& pair : *val.objData)
for (auto& pair : *val.m_ObjData)
pair.second.erase();
delete val.objData;
delete val.m_ObjData;
} else if (type == ARR_TYPE) {
for (auto& pair : *val.arrData)
for (auto& pair : *val.m_ArrData)
pair.second.erase();
delete val.arrData;
delete val.m_ArrData;
}
}
};
Expand Down Expand Up @@ -202,7 +202,22 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
VarsData* prev = nullptr;

VarsData() {}
VarsData(VarsData& other) : data(other.data), prev(other.prev) {}
VarsData(const VarsData& other) : data(other.data), prev(other.prev) {}
VarsData(VarsData&& other) : data(std::move(other.data)), prev(other.prev) {}
VarsData& operator=(const VarsData& other) {
for (auto& pair : data)
pair.second.erase();
data = other.data;
prev = other.prev;
return *this;
}
VarsData& operator=(VarsData&& other) {
for (auto& pair : data)
pair.second.erase();
data = std::move(other.data);
prev = other.prev;
return *this;
}

~VarsData() {
for (auto& pair : data)
Expand Down Expand Up @@ -234,11 +249,11 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// Note: the returned VarsData contains original data from
/// the predecessors (NOT copies). It should not be modified.
std::unordered_map<const clang::VarDecl*, VarData*>
collectDataFromPredecessors(VarsData* varsData, VarsData* limit = nullptr);
static collectDataFromPredecessors(VarsData* varsData, VarsData* limit = nullptr);

/// Finds the lowest common ancestor of two VarsData
/// (based on the prev field in VarsData).
VarsData* findLowestCommonAncestor(VarsData* varsData1, VarsData* varsData2);
static VarsData* findLowestCommonAncestor(VarsData* varsData1, VarsData* varsData2);

/// Merges mergeData into targetData. Should be called
/// after mergeData is passed and the corresponding CFG
Expand All @@ -260,9 +275,9 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {

/// Stores modes in a stack (used to retrieve the old mode after entering
/// a new one).
std::vector<short> modeStack;
std::vector<int> modeStack;

ASTContext* m_Context;
ASTContext& m_Context;

/// clang::CFG of the function being analysed.
std::unique_ptr<clang::CFG> m_CFG;
Expand Down Expand Up @@ -297,7 +312,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
void setIsRequired(const clang::Expr* E, bool isReq = true);

/// Returns the VarsData of the CFG block being visited.
VarsData& getCurBranch() { return *blockData[curBlockID]; }
VarsData& getCurBlockVarsData() { return *blockData[curBlockID]; }

//// Modes Setters
/// Sets the mode manually
Expand All @@ -315,16 +330,14 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {

public:
/// Constructor
TBRAnalyzer(ASTContext* m_Context) : m_Context(m_Context) {
TBRAnalyzer(ASTContext& m_Context) : m_Context(m_Context) {
modeStack.push_back(0);
}

/// Destructor
~TBRAnalyzer() {
for (auto varsData : blockData) {
if (varsData) {
delete varsData;
}
delete varsData;
}
}

Expand All @@ -340,7 +353,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// Visitors
void Analyze(const clang::FunctionDecl* FD);

void VisitCFGBlock(clang::CFGBlock* block);
void VisitCFGBlock(const clang::CFGBlock* block);

void Visit(const clang::Stmt* stmt) {
clang::ConstStmtVisitor<TBRAnalyzer, void>::Visit(stmt);
Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ namespace clad {

clang::Stmt* getForwSweepStmt_dx() { return m_DerivativeForForwSweep; }

clang::Expr* getRevSweepExpr() {
clang::Expr* getRevSweepAsExpr() {
return llvm::cast_or_null<clang::Expr>(getRevSweepStmt());
}

Expand Down
Loading

0 comments on commit fe7310f

Please sign in to comment.