Skip to content

Commit

Permalink
Simplify error estimation by removing _EERepl_ and _delta_
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Feb 16, 2024
1 parent 24ff43f commit e1b3715
Show file tree
Hide file tree
Showing 13 changed files with 204 additions and 780 deletions.
116 changes: 15 additions & 101 deletions include/clad/Differentiator/ErrorEstimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ class ErrorEstimationHandler : public ExternalRMVSource {
// multiple header files.
// `Stmts` is originally defined in `VisitorBase`.
using Stmts = llvm::SmallVector<clang::Stmt*, 16>;
/// Keeps a track of the delta error expression we shouldn't emit.
bool m_DoNotEmitDelta;
/// Reference to the final error parameter in the augumented target
/// function.
clang::Expr* m_FinalError;
Expand All @@ -42,22 +40,19 @@ class ErrorEstimationHandler : public ExternalRMVSource {
Stmts m_ReverseErrorStmts;
/// The index expression for emitting final errors for input param errors.
clang::Expr* m_IdxExpr;
/// A set of declRefExprs for parameter value replacements.
std::unordered_map<const clang::VarDecl*, clang::Expr*> m_ParamRepls;
/// An expression to match nested function call errors with their
/// assignee (if any exists).
clang::Expr* m_NestedFuncError = nullptr;

std::stack<bool> m_ShouldEmit;
ReverseModeVisitor* m_RMV;
clang::Expr* m_DeltaVar = nullptr;
llvm::SmallVectorImpl<clang::QualType>* m_ParamTypes = nullptr;
llvm::SmallVectorImpl<clang::ParmVarDecl*>* m_Params = nullptr;

public:
using direction = rmv::direction;
ErrorEstimationHandler()
: m_DoNotEmitDelta(false), m_FinalError(nullptr), m_RetErrorExpr(nullptr),
: m_FinalError(nullptr), m_RetErrorExpr(nullptr),
m_EstModel(nullptr), m_IdxExpr(nullptr) {}
~ErrorEstimationHandler() override = default;

Expand All @@ -70,33 +65,16 @@ class ErrorEstimationHandler : public ExternalRMVSource {
/// \param[in] finErrExpr The final error expression.
void SetFinalErrorExpr(clang::Expr* finErrExpr) { m_FinalError = finErrExpr; }

/// Shorthand to get array subscript expressions.
///
/// \param[in] arrBase The base expression of the array.
/// \param[in] idx The index expression.
/// \param[in] isCladSpType Keeps track of if we have to build a clad
/// special type (i.e. clad::Array or clad::ArrayRef).
///
/// \returns An expression of the kind arrBase[idx].
clang::Expr* getArraySubscriptExpr(clang::Expr* arrBase, clang::Expr* idx,
bool isCladSpType = true);

/// \returns The final error expression so far.
clang::Expr* GetFinalErrorExpr() { return m_FinalError; }

/// Function to build the final error statemnt of the function. This is the
/// last statement of any target function in error estimation and
/// aggregates the error in all the registered variables.
void BuildFinalErrorStmt();
/// Function to build the error statement corresponding
/// to the function's return statement.
void BuildReturnErrorStmt();

/// Function to emit error statements into the derivative body.
///
/// \param[in] var The variable whose error statement we want to emit.
/// \param[in] deltaVar The "_delta_" expression of the variable 'var'.
/// \param[in] errorExpr The error expression (LHS) of the variable 'var'.
/// \param[in] isInsideLoop A flag to indicate if 'val' is inside a loop.
void AddErrorStmtToBlock(clang::Expr* var, clang::Expr* deltaVar,
clang::Expr* errorExpr, bool isInsideLoop = false);
/// \param[in] errorExpr The error expression (LHS) of the variable.
/// \param[in] addToTheFront A flag to decide whether the error stmts
/// should be added to the beginning of the block or the current position.
void AddErrorStmtToBlock(clang::Expr* errorExpr, bool addToTheFront=true);

/// Emit the error estimation related statements that were saved to be
/// emitted at later points into specific blocks.
Expand Down Expand Up @@ -124,44 +102,12 @@ class ErrorEstimationHandler : public ExternalRMVSource {
llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls, size_t numArgs);

/// Save values of registered variables so that they can be replaced
/// properly in case of re-assignments.
///
/// \param[in] val The value to save.
/// \param[in] isInsideLoop A flag to indicate if 'val' is inside a loop.
///
/// \returns The saved variable and its derivative.
StmtDiff SaveValue(clang::Expr* val, bool isInLoop = false);

/// Save the orignal values of the input parameters in case of
/// re-assignments.
///
/// \param[in] paramRef The DeclRefExpr of the input parameter.
void SaveParamValue(clang::DeclRefExpr* paramRef);

/// Register variables to be used while accumulating error.
/// Register variable declarations so that they may be used while
/// calculating the final error estimates. Any unregistered variables will
/// not be considered for the final estimation.
///
/// \param[in] VD The variable declaration to be registered.
/// \param[in] toCurrentScope Add the created "_delta_" variable declaration
/// to the current scope instead of the global scope.
///
/// \returns The Variable declaration of the '_delta_' prefixed variable.
clang::Expr* RegisterVariable(clang::VarDecl* VD,
bool toCurrentScope = false);

/// Checks if a variable can be registered for error estimation.
/// Checks if a variable should be considered in error estimation.
///
/// \param[in] VD The variable declaration to be registered.
/// \param[in] VD The variable declaration.
///
/// \returns True if the variable can be registered, false otherwise.
bool CanRegisterVariable(clang::VarDecl* VD);

/// Calculate aggregate error from m_EstimateVar.
/// Builds the final error estimation statement.
clang::Stmt* CalculateAggregateError();
/// \returns true if the variable should be considered, false otherwise.
bool ShouldEstimateErrorFor(clang::VarDecl* VD);

/// Get the underlying DeclRefExpr type it it exists.
///
Expand All @@ -170,14 +116,6 @@ class ErrorEstimationHandler : public ExternalRMVSource {
/// \returns The DeclRefExpr of input or null.
clang::DeclRefExpr* GetUnderlyingDeclRefOrNull(clang::Expr* expr);

/// Get the parameter replacement (if any).
///
/// \param[in] VD The parameter variable declaration to get replacement
/// for.
///
/// \returns The underlying replaced Expr.
clang::Expr* GetParamReplacement(const clang::ParmVarDecl* VD);

/// An abstraction of the error estimation model's AssignError.
///
/// \param[in] val The variable to get the error for.
Expand All @@ -190,16 +128,6 @@ class ErrorEstimationHandler : public ExternalRMVSource {
return m_EstModel->AssignError({var, varDiff}, varName);
}

/// An abstraction of the error estimation model's IsVariableRegistered.
///
/// \param[in] VD The variable declaration to check the status of.
///
/// \returns the reference to the respective '_delta_' expression if the
/// variable is registered, null otherwise.
clang::Expr* IsRegistered(clang::VarDecl* VD) {
return m_EstModel->IsVariableRegistered(VD);
}

/// This function adds the final error and the other parameter errors to the
/// forward block.
///
Expand All @@ -215,17 +143,6 @@ class ErrorEstimationHandler : public ExternalRMVSource {
/// loop.
void EmitUnaryOpErrorStmts(StmtDiff var, bool isInsideLoop);

/// This function registers all LHS declRefExpr in binary operations.
///
/// \param[in] LExpr The LHS of the operation.
/// \param[in] RExpr The RHS of the operation.
/// \param[in] isAssign A flag to know if the current operation is a simple
/// assignment.
///
/// \returns The delta value of the input 'var'.
clang::Expr* RegisterBinaryOpLHS(clang::Expr* LExpr, clang::Expr* RExpr,
bool isAssign);

/// This function emits the error in a binary operation.
///
/// \param[in] LExpr The LHS of the operation.
Expand All @@ -234,8 +151,7 @@ class ErrorEstimationHandler : public ExternalRMVSource {
/// \param[in] deltaVar The delta value of the LHS.
/// \param[in] isInsideLoop A flag to keep track of if we are inside a
/// loop.
void EmitBinaryOpErrorStmts(clang::Expr* LExpr, clang::Expr* oldValue,
clang::Expr* deltaVar, bool isInsideLoop);
void EmitBinaryOpErrorStmts(clang::Expr* LExpr, clang::Expr* oldValue);

/// This function emits the error in declaration statements.
///
Expand Down Expand Up @@ -267,10 +183,8 @@ class ErrorEstimationHandler : public ExternalRMVSource {
llvm::SmallVectorImpl<clang::Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls,
bool asGrad) override;
void
ActAfterCloningLHSOfAssignOp(clang::Expr*& LCloned, clang::Expr*& R,
clang::BinaryOperator::Opcode& opCode) override;
void ActBeforeFinalisingAssignOp(clang::Expr*&, clang::Expr*&) override;
void ActBeforeFinalisingAssignOp(clang::Expr*&, clang::Expr*&, clang::Expr*&,
clang::BinaryOperator::Opcode&) override;
void ActBeforeFinalizingDifferentiateSingleStmt(const direction& d) override;
void ActBeforeFinalizingDifferentiateSingleExpr(const direction& d) override;
void ActBeforeDifferentiatingCallExpr(
Expand Down
5 changes: 3 additions & 2 deletions include/clad/Differentiator/ExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,9 @@ class ExternalRMVSource {
clang::BinaryOperatorKind& opCode) {
}

/// This is called just after finaising processing of assignment operator.
virtual void ActBeforeFinalisingAssignOp(clang::Expr*&, clang::Expr*&){};
/// This is called just after finalising processing of assignment operator.
virtual void ActBeforeFinalisingAssignOp(clang::Expr*&, clang::Expr*&, clang::Expr*&,
clang::BinaryOperator::Opcode&){};

/// This is called at that beginning of
/// `ReverseModeVisitor::DifferentiateSingleStmt`.
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/MultiplexExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ class MultiplexExternalRMVSource : public ExternalRMVSource {
void ActBeforeFinalisingPostIncDecOp(StmtDiff& diff) override;
void ActAfterCloningLHSOfAssignOp(clang::Expr*&, clang::Expr*&,
clang::BinaryOperatorKind& opCode) override;
void ActBeforeFinalisingAssignOp(clang::Expr*&, clang::Expr*&) override;
void ActBeforeFinalisingAssignOp(clang::Expr*&, clang::Expr*&, clang::Expr*&,
clang::BinaryOperator::Opcode&) override;
void ActOnStartOfDifferentiateSingleStmt() override;
void ActBeforeFinalizingDifferentiateSingleStmt(const direction& d) override;
void ActBeforeFinalizingDifferentiateSingleExpr(const direction& d) override;
Expand Down
9 changes: 9 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,15 @@ namespace clad {
clang::Expr*
BuildArraySubscript(clang::Expr* Base,
const llvm::SmallVectorImpl<clang::Expr*>& IS);

/// Build an array subscript expression with a given base expression and
/// one index.
clang::Expr*
BuildArraySubscript(clang::Expr* Base,
clang::Expr*& Idx) {
llvm::SmallVector<clang::Expr*> IS = {Idx};
return BuildArraySubscript(Base, IS);
}
/// Find namespace clad declaration.
clang::NamespaceDecl* GetCladNamespace();
/// Find declaration of clad::class templated type
Expand Down
Loading

0 comments on commit e1b3715

Please sign in to comment.