Skip to content

Commit

Permalink
Optimize memory usage in analysis.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Sep 5, 2023
1 parent 0230d3a commit 914023e
Show file tree
Hide file tree
Showing 8 changed files with 485 additions and 330 deletions.
12 changes: 6 additions & 6 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace clad {
std::string ComputeEffectiveFnName(const clang::FunctionDecl* FD);

/// Creates and returns a compound statement having statements as follows:
/// {`S`, all the statements of `initial` in sequence}
/// {`S`, all the statements of `initial` in sequence}
clang::CompoundStmt* PrependAndCreateCompoundStmt(clang::ASTContext& C,
clang::Stmt* initial,
clang::Stmt* S);
Expand All @@ -38,7 +38,7 @@ namespace clad {
clang::CompoundStmt* AppendAndCreateCompoundStmt(clang::ASTContext& C,
clang::Stmt* initial,
clang::Stmt* S);

/// Shorthand to issue a warning or error.
template <std::size_t N>
void EmitDiag(clang::Sema& semaRef,
Expand Down Expand Up @@ -126,8 +126,8 @@ namespace clad {
///
/// \param S
/// \param namespc
/// \param shouldExist If true, then asserts that the specified namespace
/// is found.
/// \param shouldExist If true, then asserts that the specified namespace
/// is found.
/// \param DC
clang::NamespaceDecl* LookupNSD(clang::Sema& S, llvm::StringRef namespc,
bool shouldExist,
Expand Down Expand Up @@ -234,7 +234,7 @@ namespace clad {

bool IsCladValueAndPushforwardType(clang::QualType T);

/// Returns a valid `SourceRange` to be used in places where clang
/// Returns a valid `SourceRange` to be used in places where clang
/// requires a valid `SourceRange`.
clang::SourceRange GetValidSRange(clang::Sema& semaRef);

Expand Down Expand Up @@ -314,7 +314,7 @@ namespace clad {

bool hasNonDifferentiableAttribute(const clang::Expr* E);
/// FIXME: add documentation
std::vector<clang::Expr*> GetInnermostReturnExpr(clang::Expr* E);
std::vector<clang::Expr*> GetInnermostReturnExpr(const clang::Expr* E);
} // namespace utils
}

Expand Down
29 changes: 14 additions & 15 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ namespace clad {
Stmts m_Globals;
//// A reference to the output parameter of the gradient function.
clang::Expr* m_Result;
/// Based on To-Be-Recorded analysis performed before differentiation,
/// tells UsefulToStoreGlobal whether a variable with a given
/// SourceLocation has to be stored before being changed or not.
std::map<clang::SourceLocation, bool> m_ToBeRecorded;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
bool isInsideLoop = false;
/// Output variable of vector-valued function
Expand Down Expand Up @@ -135,7 +139,7 @@ namespace clad {
return m_Blocks.back();
else if (d == direction::reverse)
return m_Reverse.back();
else
else
return m_EssentialReverse.back();
}
/// Create new block.
Expand All @@ -144,7 +148,7 @@ namespace clad {
m_Blocks.push_back({});
else if (d == direction::reverse)
m_Reverse.push_back({});
else
else
m_EssentialReverse.push_back({});
return getCurrentBlock(d);
}
Expand Down Expand Up @@ -227,11 +231,6 @@ namespace clad {
forceDeclCreation, IS);
}

/// Based on To-Be-Recorded analysis performed before differentiation,
/// tells UsefulToStoreGlobal whether a variable with a given
/// SourceLocation has to be stored before changed or not.
std::map<clang::SourceLocation, bool> m_ToBeRecorded;

/// For an expr E, decides if it is useful to store it in a global temporary
/// variable and replace E's further usage by a reference to that variable
/// to avoid recomputation.
Expand Down Expand Up @@ -428,7 +427,7 @@ namespace clad {
clang::QualType xType);

/// Makes it easy to create and manage a counter for counting the number of
/// executed iterations of a loop.
/// executed iterations of a loop.
///
/// It is required to save the number of executed iterations to use the
/// same number of iterations in the reverse pass.
Expand All @@ -447,11 +446,11 @@ namespace clad {
/// for counter; otherwise, returns nullptr.
clang::Expr* getPush() const { return m_Push; }

/// Returns `clad::pop(_t)` expression if clad tape is used for
/// Returns `clad::pop(_t)` expression if clad tape is used for
/// the counter; otherwise, returns nullptr.
clang::Expr* getPop() const { return m_Pop; }

/// Returns reference to the last object of the clad tape if clad tape
/// Returns reference to the last object of the clad tape if clad tape
/// is used as the counter; otherwise returns reference to the counter
/// variable.
clang::Expr* getRef() const { return m_Ref; }
Expand Down Expand Up @@ -493,11 +492,11 @@ namespace clad {

/// This class modifies forward and reverse blocks of the loop
/// body so that `break` and `continue` statements are correctly
/// handled. `break` and `continue` statements are handled by
/// handled. `break` and `continue` statements are handled by
/// enclosing entire reverse block loop body in a switch statement
/// and only executing the statements, with the help of case labels,
/// that were executed in the associated forward iteration. This is
/// determined by keeping track of which `break`/`continue` statement
/// that were executed in the associated forward iteration. This is
/// determined by keeping track of which `break`/`continue` statement
/// was hit in which iteration and that in turn helps to determine which
/// case label should be selected.
///
Expand Down Expand Up @@ -525,7 +524,7 @@ namespace clad {
/// \note `m_ControlFlowTape` is only initialized if the body contains
/// `continue` or `break` statement.
std::unique_ptr<CladTapeResult> m_ControlFlowTape;

/// Each `break` and `continue` statement is assigned a unique number,
/// starting from 1, that is used as the case label corresponding to that `break`/`continue`
/// statement. `m_CaseCounter` stores the value that was used for last
Expand Down Expand Up @@ -564,7 +563,7 @@ namespace clad {
/// control flow switch statement.
clang::CaseStmt* GetNextCFCaseStmt();

/// Builds and returns `clad::push(TapeRef, m_CurrentCounter)`
/// Builds and returns `clad::push(TapeRef, m_CurrentCounter)`
/// expression, where `TapeRef` and `m_CurrentCounter` are replaced
/// by their actual values respectively.
clang::Stmt* CreateCFTapePushExprToCurrentCase();
Expand Down
128 changes: 82 additions & 46 deletions include/clad/Differentiator/TBRAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,41 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
union IdxOrMemberValue {
const clang::FieldDecl* field;
llvm::APInt index;
IdxOrMemberValue() {}
IdxOrMemberValue() : field(nullptr) {}
~IdxOrMemberValue() {}
IdxOrMemberValue(const IdxOrMemberValue&) = delete;
IdxOrMemberValue& operator=(const IdxOrMemberValue&) = delete;
IdxOrMemberValue(const IdxOrMemberValue&&) = delete;
IdxOrMemberValue& operator=(const IdxOrMemberValue&&) = delete;
};
IdxOrMemberType type;
IdxOrMemberValue val;
IdxOrMember(const clang::FieldDecl* field) : type(IdxOrMemberType::FIELD) {
val.field = field;
}
IdxOrMember(llvm::APInt index) : type(IdxOrMemberType::INDEX) {
IdxOrMember(llvm::APInt&& index) : type(IdxOrMemberType::INDEX) {
new (&val.index) llvm::APInt(index);
}
IdxOrMember(const IdxOrMember& other) : type(other.type) {
IdxOrMember(const IdxOrMember& other) {
new (&val.index) llvm::APInt();
*this = other;
}
IdxOrMember(const IdxOrMember&& other) noexcept {
new (&val.index) llvm::APInt();
*this = other;
}
IdxOrMember& operator=(const IdxOrMember& other) {
type = other.type;
if (type == IdxOrMemberType::FIELD)
val.field = other.val.field;
else
new (&val.index) llvm::APInt(other.val.index);
val.index = other.val.index;
return *this;
}
IdxOrMember& operator=(const IdxOrMember&& other) noexcept {
return *this = other;
}
~IdxOrMember() = default;
};

/// Stores all the necessary information about one variable. Fundamental type
Expand All @@ -71,39 +89,47 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// a row, which seems uncommon. It's worth considering analysing arrays as
/// whole structures instead (just one VarData for the whole array).

struct VarData;
using ObjMap = std::unordered_map<const clang::FieldDecl*, VarData*>;
using ArrMap = std::unordered_map<const llvm::APInt, VarData*, APIntHash>;

struct VarData {
enum VarDataType { UNDEFINED, FUND_TYPE, OBJ_TYPE, ARR_TYPE, REF_TYPE };
union VarDataValue {
bool fundData;
std::unordered_map<const clang::FieldDecl*, VarData*> objData;
std::unordered_map<const llvm::APInt, VarData*, APIntHash> arrData;
/// objData, arrData are stored as pointers for VarDataValue to take
/// less space.
ObjMap* objData;
ArrMap* arrData;
VarData* refData;
VarDataValue() {}
~VarDataValue() {}
VarDataValue() : fundData(false) {}
};
VarDataType type;
VarDataValue val;
bool isReferenced = false;

/// For non-fundamental type variables, all the child nodes have to be
/// deleted.
VarData() = default;
VarData(const VarData&) = delete;
VarData& operator=(const VarData&) = delete;
VarData(const VarData&&) = delete;
VarData& operator=(const VarData&&) = delete;

~VarData() {
if (type == OBJ_TYPE) {
for (auto pair : val.objData) {
for (auto& pair : *val.objData) {
delete pair.second;
}
} else if (type == ARR_TYPE) {
for (auto pair : val.arrData) {
for (auto& pair : *val.arrData) {
delete pair.second;
}
}
}

/// Recursively sets all the leaves' bools to isReq.
void setIsRequired(bool isReq = true);
/// Returns true if there is at least one required to store node among
/// child nodes.
bool findReq();
bool findReq() const;
/// Whenever an array element with a non-constant index is set to required
/// this function is used to set to required all the array elements that
/// could match that element (e.g. set 'a[1].y' and 'a[6].y' to required
Expand All @@ -122,7 +148,9 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// corresponding original nodes in case those are referenced (a referenced
/// node is a child to multiple nodes, therefore, we need to make sure we
/// don't make multiple copies of it).
VarData* copy();
VarData* copy(std::unordered_map<VarData*, VarData*>& refVars);
void restoreRefs(std::unordered_map<VarData*, VarData*>& refVars);
};

/// Given a MemberExpr*/ArraySubscriptExpr* return a pointer to its
Expand All @@ -138,7 +166,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// Given an Expr* returns its corresponding VarData.
VarData* getExprVarData(const clang::Expr* E, bool addNonConstIdx = false);
/// Adds the field FD to objData.
void addField(std::unordered_map<const clang::FieldDecl*, VarData*>& objData,
void addField(std::unordered_map<const clang::FieldDecl*, VarData*>* objData,
const FieldDecl* FD);
/// Whenever an array element with a non-constant index is set to required
/// this function is used to set to required all the array elements that
Expand All @@ -164,7 +192,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
std::map<clang::SourceLocation, bool> TBRLocs;
/// Stores VarsData for every branch in control flow (e.g. if-else statements,
/// loops).
std::vector<VarsData> reqStack;
std::vector<std::vector<VarsData>> reqStack;
/// Stores modes in a stack (used to retrieve the old mode after entering
/// a new one).
std::vector<int> modeStack;
Expand All @@ -179,7 +207,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {

/// The index of the innermost branch corresponding to a loop (used to handle
/// break/continue statements).
size_t innermostLoopBranch = 0;
size_t innermostLoopLayer = 0;
/// Tells if the current branch should be deleted instead of merged with
/// others. This happens when the branch has a break/continue statement or a
/// return expression in it.
Expand All @@ -204,22 +232,28 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
void setIsRequired(const clang::Expr* E, bool isReq = true);

//// Control Flow
/// Creates a new branch as a copy of the last used branch.
void addBranch();
/// Merges the last into the one right before it and deletes it.
/// If keepNewVars==false, it removes all the variables that are present
/// in the last branch but not the other. If keepNewVars==true, all the new
/// variables are kept.
/// Note: The branch we are merging into is not supposed to have its own
/// local variables (this doesn't matter to the branch being merged).
void mergeAndDelete(bool keepNewVars = false);
/// Swaps the last two branches in the stack.
void swapLastPairOfBranches();
/// Merges the current branch to a branch with a given index in the stack.
/// Current branch is NOT deleted.
/// Note: The branch we are merging into is not supposed to have its own
/// local variables (this doesn't matter to the branch being merged).
void mergeCurBranchTo(size_t targetBranchNum);
/// Returns the current branch.
VarsData& getCurBranch() { return reqStack.back().back(); }
/// Adds a new layer.
void addLayer() { reqStack.emplace_back(); }
/// Creates a new empty branch.
void addBranch() { reqStack.back().emplace_back(); }
/// Deletes the last branch.
void deleteBranch() {
for (auto& pair : getCurBranch())
delete pair.second;
reqStack.back().pop_back();
}
/// Merges the last layer into the last branch of the previous layer
/// and deletes the last layer.
void mergeLayer();
/// Merges the last layer but, unlike the previous method, basically replaces
/// the last branch on the previous layer with the result of merging. After
/// that, removes the last layer.
void mergeLayerOnTop();
/// Merges the branch with index sourceBranchNum into targetBranch.
/// No branches are deleted.
void mergeBranchTo(size_t sourceBranchNum, VarsData& targetBranch);
/// Removes local variables from the current branch (uses localVarsStack).
/// This is necessary when merging if-else branches.
void removeLocalVars();
Expand All @@ -239,29 +273,31 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
void resetMode() { modeStack.pop_back(); }

public:
// Constructor
/// Constructor
TBRAnalyzer(ASTContext* m_Context) : m_Context(m_Context) {
modeStack.push_back(0);
reqStack.push_back(VarsData());
addLayer();
addBranch();
}

// Destructor
/// Destructor
~TBRAnalyzer() {
/// By the end of analysis, reqStack is supposed to have just one branch
/// but it's better to iterate through it just to make sure there's no
/// memory leak.
for (auto& branch : reqStack) {
for (auto pair : branch) {
delete pair.second;
}
}
for (auto& layer : reqStack)
for (auto& branch : layer)
for (auto& pair : branch)
delete pair.second;
}

/// Delete copy/move operators and constructors.
TBRAnalyzer(const TBRAnalyzer&) = delete;
TBRAnalyzer& operator=(const TBRAnalyzer&) = delete;
TBRAnalyzer(const TBRAnalyzer&&) = delete;
TBRAnalyzer& operator=(const TBRAnalyzer&&) = delete;

/// Returns the result of the whole analysis
const std::map<clang::SourceLocation, bool> getResult() { return TBRLocs; }
std::map<clang::SourceLocation, bool> getResult() { return TBRLocs; }

/// Visitors

void Analyze(const clang::FunctionDecl* FD);

void Visit(const clang::Stmt* stmt) {
Expand Down
Loading

0 comments on commit 914023e

Please sign in to comment.