diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 90aa97d82..9e0e9b102 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,56 +24,81 @@ jobs: matrix: include: - - name: osx-clang-runtime11 - os: macos-11 + - name: osx14-arm-clang-runtime13 + os: macos-14 compiler: clang - clang-runtime: '11' + clang-runtime: '13' + + - name: osx14-arm-clang-runtime14 + os: macos-14 + compiler: clang + clang-runtime: '14' + + - name: osx14-arm-clang-runtime15 + os: macos-14 + compiler: clang + clang-runtime: '15' + + - name: osx14-arm-clang-runtime16 + os: macos-14 + compiler: clang + clang-runtime: '16' - - name: osx-clang-runtime12 - os: macos-latest + - name: osx14-arm-clang-runtime17 + os: macos-14 + compiler: clang + clang-runtime: '17' + + - name: osx13-x86-clang-runtime12 + os: macos-13 compiler: clang clang-runtime: '12' - - name: osx-clang-runtime13 - os: macos-latest + - name: osx13-x86-clang-runtime13 + os: macos-13 compiler: clang clang-runtime: '13' - - name: osx-clang-runtime14 - os: macos-latest + - name: osx13-x86-clang-runtime14 + os: macos-13 compiler: clang clang-runtime: '14' - - name: osx-clang-runtime15 - os: macos-latest + - name: osx13-x86-clang-runtime15 + os: macos-13 compiler: clang clang-runtime: '15' - - name: osx-clang-runtime16 - os: macos-latest + - name: osx13-x86-clang-runtime16 + os: macos-13 compiler: clang clang-runtime: '16' - - name: osx-clang-runtime17 - os: macos-latest + - name: osx13-x86-clang-runtime17 + os: macos-13 compiler: clang clang-runtime: '17' - - name: win-msvc-runtime14 - os: windows-latest + - name: win2022-msvc-runtime14 + os: windows-2022 compiler: msvc clang-runtime: '14' - - name: win-msvc-runtime15 - os: windows-latest + - name: win2022-msvc-runtime15 + os: windows-2022 compiler: msvc clang-runtime: '15' - - name: win-msvc-runtime16 - os: windows-latest + - name: win2022-msvc-runtime16 + os: windows-2022 compiler: msvc clang-runtime: '16' + - name: win2022-msvc-runtime17 + os: 
windows-2022 + compiler: msvc + clang-runtime: '17' + - name: ubu22-clang15-runtime16-debug os: ubuntu-22.04 compiler: clang-15 @@ -559,7 +584,6 @@ jobs: # Update openssl on osx because the current one is deprecated by python. curl -L https://bootstrap.pypa.io/get-pip.py | sudo python3 echo "/usr/local/opt/ccache/libexec" >> $GITHUB_PATH - PATH_TO_LLVM_BUILD=/usr/local/opt/llvm@${{ matrix.clang-runtime }}/ # For now Package llvm@18 is unsuported on brew, llvm <=@11 are deprecated or deleted. # Install llvm from github releases. @@ -591,16 +615,22 @@ jobs: # allowing clang to work with system's SDK. sudo rm -fr /usr/local/opt/llvm*/include/c++ fi - + + PATH_TO_LLVM_BUILD=$(brew --prefix llvm@${{ matrix.clang-runtime }}) + pip3 install lit # LLVM lit is not part of the llvm releases... # We need headers in correct place - for file in $(xcrun --show-sdk-path)/usr/include/* - do - if [ ! -f /usr/local/include/$(basename $file) ]; then - ln -s $file /usr/local/include/$(basename $file) - fi - done + #FIXME: ln solution fails with error message No such file or directory on osx arm, + #Copying over files as a temporary solution + sudo cp -r -n $(xcrun --show-sdk-path)/usr/include/ /usr/local/include/ + #for file in $(xcrun --show-sdk-path)/usr/include/* + #do + # if [ ! 
-f /usr/local/include/$(basename $file) ]; then + # echo ${file} + # ln -s ${file} /usr/local/include/$(basename $file) + # fi + #done # We need PATH_TO_LLVM_BUILD later echo "PATH_TO_LLVM_BUILD=$PATH_TO_LLVM_BUILD" >> $GITHUB_ENV @@ -740,10 +770,10 @@ jobs: if: ${{ runner.os != 'windows' }} run: | mkdir obj && cd obj - cmake -DClang_DIR="$PATH_TO_LLVM_BUILD" \ - -DLLVM_DIR="$PATH_TO_LLVM_BUILD" \ + cmake -DClang_DIR=${{ env.PATH_TO_LLVM_BUILD }} \ + -DLLVM_DIR=${{ env.PATH_TO_LLVM_BUILD }} \ -DCMAKE_BUILD_TYPE=$([[ -z "$BUILD_TYPE" ]] && echo RelWithDebInfo || echo $BUILD_TYPE) \ - -DCLAD_CODE_COVERAGE=${CLAD_CODE_COVERAGE} \ + -DCLAD_CODE_COVERAGE=${{ env.CLAD_CODE_COVERAGE }} \ -DLLVM_EXTERNAL_LIT="`which lit`" \ -DLLVM_ENABLE_WERROR=On \ $GITHUB_WORKSPACE \ diff --git a/CMakeLists.txt b/CMakeLists.txt index 7e6648ddf..8e8e1900e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,24 @@ cmake_minimum_required(VERSION 3.7.0) +enable_language(CXX) +set(CMAKE_CXX_EXTENSIONS NO) + include(GNUInstallDirs) +# MUST be done before call to clad project +get_cmake_property(_cache_vars CACHE_VARIABLES) +foreach(_cache_var ${_cache_vars}) + get_property(_helpstring CACHE ${_cache_var} PROPERTY HELPSTRING) + if(_helpstring STREQUAL + "No help, variable specified on the command line.") + set(CMAKE_ARGS "${CMAKE_ARGS} -D${_cache_var}=\"${${_cache_var}}\"") + endif() +endforeach() + +# Generate CMakeArgs.txt file with source, build dir and command line args +write_file("${CMAKE_CURRENT_BINARY_DIR}/CMakeArgs.txt" + "-S${CMAKE_SOURCE_DIR} -B${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_ARGS}") + if(POLICY CMP0075) cmake_policy(SET CMP0075 NEW) endif() diff --git a/demos/Arrays.cpp b/demos/Arrays.cpp index 7b44145be..53d6a2c2d 100644 --- a/demos/Arrays.cpp +++ b/demos/Arrays.cpp @@ -76,13 +76,13 @@ int main() { // the indexes of the array by using the format arr[0:] auto hessian_all = clad::hessian(weighted_avg, "arr[0:2], weights[0:2]"); // Generates the Hessian matrix for 
weighted_avg w.r.t. to arr. - auto hessian_arr = clad::hessian(weighted_avg, "arr[0:2]"); + // auto hessian_arr = clad::hessian(weighted_avg, "arr[0:2]"); double matrix_all[36] = {0}; - double matrix_arr[9] = {0}; + // double matrix_arr[9] = {0}; clad::array_ref matrix_all_ref(matrix_all, 36); - clad::array_ref matrix_arr_ref(matrix_arr, 9); + // clad::array_ref matrix_arr_ref(matrix_arr, 9); hessian_all.execute(arr, weights, matrix_all_ref); printf("Hessian Mode w.r.t. to all:\n matrix =\n" @@ -102,7 +102,7 @@ int main() { matrix_all[28], matrix_all[29], matrix_all[30], matrix_all[31], matrix_all[32], matrix_all[33], matrix_all[34], matrix_all[35]); - hessian_arr.execute(arr, weights, matrix_arr_ref); + /*hessian_arr.execute(arr, weights, matrix_arr_ref); printf("Hessian Mode w.r.t. to arr:\n matrix =\n" " {%.2g, %.2g, %.2g}\n" " {%.2g, %.2g, %.2g}\n" @@ -110,4 +110,5 @@ int main() { matrix_arr[0], matrix_arr[1], matrix_arr[2], matrix_arr[3], matrix_arr[4], matrix_arr[5], matrix_arr[6], matrix_arr[7], matrix_arr[8]); + */ } diff --git a/docs/userDocs/source/_static/vector-mode.png b/docs/userDocs/source/_static/vector-mode.png new file mode 100644 index 000000000..ea888bf80 Binary files /dev/null and b/docs/userDocs/source/_static/vector-mode.png differ diff --git a/docs/userDocs/source/user/CoreConcepts.rst b/docs/userDocs/source/user/CoreConcepts.rst index 642822900..b1c608dbf 100644 --- a/docs/userDocs/source/user/CoreConcepts.rst +++ b/docs/userDocs/source/user/CoreConcepts.rst @@ -127,6 +127,52 @@ Substituting `s = z` we will get `sz` = 1 Thus we don't need to run the program twice for each input. However, as mentioned above the only drawback is we need to re-run the program for a different output. 
+ +Vectorized Forward Mode Automatic Differentiation +=================================================== + +Vectorized Forward Mode Automatic Differentiation is a computational technique +that combines two powerful concepts: vectorization and forward mode automatic +differentiation. This approach is used to efficiently compute derivatives of +functions with respect to multiple input variables by taking advantage of both +parallel processing capabilities and the structure of the computation graph. + +Working +-------- + +For computing gradient of a function with an n-dimensional input - forward mode +requires n forward passes. + +We can do this in a single forward pass, instead of accumulating a single +scalar value of derivative with respect to a particular node, we maintain a +gradient vector at each node. Although the strategy is pretty similar, it requires +three passes for computing partial derivatives w.r.t. the three scalar inputs of +the function. + +At each node, we maintain a vector, storing the complete gradient of that node's +output w.r.t. all the input parameters. All operations are now vector operations, +for example, applying the sum rule will result in the addition of vectors. +Initialization for input nodes is done using one-hot vectors. + +.. figure:: ../_static/vector-mode.png + :width: 600 + :align: center + :alt: Vectorized Forward Mode Automatic Differentiation + + Vectorized Forward Mode Automatic Differentiation to compute the gradient. + +Benefits +---------- + +We know that each node requires computing a vector, which requires more memory +and more time, which adds to these memory allocation calls. This must be offset +by some improvement in computing efficiency. + +This can prevent the recomputation of some expensive functions, which would have +executed in a non-vectorized version due to multiple forward passes. This approach +can take advantage of the hardware's vectorization and parallelization capabilities +using SIMD techniques. 
+ Derived Function Types and Derivative Types ============================================= diff --git a/include/clad/Differentiator/CladConfig.h b/include/clad/Differentiator/CladConfig.h index 4ad838a31..f2c06bec2 100644 --- a/include/clad/Differentiator/CladConfig.h +++ b/include/clad/Differentiator/CladConfig.h @@ -21,23 +21,28 @@ enum order { third = 3, }; // enum order -enum opts { +enum opts : unsigned { use_enzyme = 1 << ORDER_BITS, vector_mode = 1 << (ORDER_BITS + 1), + + // Storing two bits for tbr analysis. + // 00 - default, 01 - enable, 10 - disable, 11 - not used / invalid + enable_tbr = 1 << (ORDER_BITS + 2), + disable_tbr = 1 << (ORDER_BITS + 3), }; // enum opts -constexpr unsigned GetDerivativeOrder(unsigned const bitmasked_opts) { +constexpr unsigned GetDerivativeOrder(const unsigned bitmasked_opts) { return bitmasked_opts & ORDER_MASK; } -constexpr bool HasOption(unsigned const bitmasked_opts, unsigned const option) { +constexpr bool HasOption(const unsigned bitmasked_opts, const unsigned option) { return (bitmasked_opts & option) == option; } constexpr unsigned GetBitmaskedOpts() { return 0; } -constexpr unsigned GetBitmaskedOpts(unsigned const first) { return first; } +constexpr unsigned GetBitmaskedOpts(const unsigned first) { return first; } template -constexpr unsigned GetBitmaskedOpts(unsigned const first, Opts... opts) { +constexpr unsigned GetBitmaskedOpts(const unsigned first, Opts... opts) { return first | GetBitmaskedOpts(opts...); } diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 2705cc447..7c108fa23 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -86,6 +86,12 @@ namespace clad { using DiffSchedule = llvm::SmallVector; using DiffInterval = std::vector; + struct RequestOptions { + /// This is a flag to indicate the default behaviour to enable/disable + /// TBR analysis during reverse-mode differentiation. 
+ bool EnableTBRAnalysis = false; + }; + class DiffCollector: public clang::RecursiveASTVisitor { /// The source interval where clad was activated. /// @@ -101,9 +107,11 @@ namespace clad { const clang::FunctionDecl* m_TopMostFD = nullptr; clang::Sema& m_Sema; + RequestOptions& m_Options; + public: DiffCollector(clang::DeclGroupRef DGR, DiffInterval& Interval, - DiffSchedule& plans, clang::Sema& S); + DiffSchedule& plans, clang::Sema& S, RequestOptions& opts); bool VisitCallExpr(clang::CallExpr* E); private: diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index a686e35ef..b3e708b18 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -358,9 +358,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { /// \param[in] args independent parameters information /// \returns `CladFunction` object to access the corresponding derived /// function. - template , + template , typename = typename std::enable_if< !std::is_class>::value>::type> CladFunction, true> __attribute__(( @@ -376,9 +375,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { /// Specialization for differentiating functors. /// The specialization is needed because objects have to be passed /// by reference whereas functions have to be passed by value. - template , + template , typename = typename std::enable_if< std::is_class>::value>::type> CladFunction, true> __attribute__(( @@ -397,8 +395,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { /// \param[in] args independent parameters information /// \returns `CladFunction` object to access the corresponding derived /// function. 
- template , + template , typename = typename std::enable_if< !std::is_class>::value>::type> CladFunction> __attribute__(( @@ -406,18 +404,16 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { hessian(F f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(f && "Must pass in a non-0 argument"); - return CladFunction< - DerivedFnType, - ExtractFunctorTraits_t>(derivedFn /* will be replaced by hessian*/, - code); + assert(f && "Must pass in a non-0 argument"); + return CladFunction>( + derivedFn /* will be replaced by hessian*/, code); } /// Specialization for differentiating functors. /// The specialization is needed because objects have to be passed /// by reference whereas functions have to be passed by value. - template , + template , typename = typename std::enable_if< std::is_class>::value>::type> CladFunction> __attribute__(( @@ -425,10 +421,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { hessian(F&& f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - return CladFunction< - DerivedFnType, - ExtractFunctorTraits_t>(derivedFn /* will be replaced by hessian*/, - code, f); + return CladFunction>( + derivedFn /* will be replaced by hessian*/, code, f); } /// Generates function which computes jacobian matrix of the given function @@ -438,8 +432,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { /// \param[in] args independent parameters information /// \returns `CladFunction` object to access the corresponding derived /// function. 
- template , + template , typename = typename std::enable_if< !std::is_class>::value>::type> CladFunction> __attribute__(( @@ -447,18 +441,16 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { jacobian(F f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(f && "Must pass in a non-0 argument"); - return CladFunction< - DerivedFnType, - ExtractFunctorTraits_t>(derivedFn /* will be replaced by Jacobian*/, - code); + assert(f && "Must pass in a non-0 argument"); + return CladFunction>( + derivedFn /* will be replaced by Jacobian*/, code); } /// Specialization for differentiating functors. /// The specialization is needed because objects have to be passed /// by reference whereas functions have to be passed by value. - template , + template , typename = typename std::enable_if< std::is_class>::value>::type> CladFunction> __attribute__(( @@ -466,10 +458,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { jacobian(F&& f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - return CladFunction< - DerivedFnType, - ExtractFunctorTraits_t>(derivedFn /* will be replaced by Jacobian*/, - code, f); + return CladFunction>( + derivedFn /* will be replaced by Jacobian*/, code, f); } template ; - /// Keeps a track of the delta error expression we shouldn't emit. - bool m_DoNotEmitDelta; /// Reference to the final error parameter in the augumented target /// function. clang::Expr* m_FinalError; @@ -42,23 +40,20 @@ class ErrorEstimationHandler : public ExternalRMVSource { Stmts m_ReverseErrorStmts; /// The index expression for emitting final errors for input param errors. clang::Expr* m_IdxExpr; - /// A set of declRefExprs for parameter value replacements. - std::unordered_map m_ParamRepls; /// An expression to match nested function call errors with their /// assignee (if any exists). 
clang::Expr* m_NestedFuncError = nullptr; std::stack m_ShouldEmit; ReverseModeVisitor* m_RMV; - clang::Expr* m_DeltaVar = nullptr; llvm::SmallVectorImpl* m_ParamTypes = nullptr; llvm::SmallVectorImpl* m_Params = nullptr; public: using direction = rmv::direction; ErrorEstimationHandler() - : m_DoNotEmitDelta(false), m_FinalError(nullptr), m_RetErrorExpr(nullptr), - m_EstModel(nullptr), m_IdxExpr(nullptr) {} + : m_FinalError(nullptr), m_RetErrorExpr(nullptr), m_EstModel(nullptr), + m_IdxExpr(nullptr) {} ~ErrorEstimationHandler() override = default; /// Function to set the error estimation model currently in use. @@ -70,33 +65,16 @@ class ErrorEstimationHandler : public ExternalRMVSource { /// \param[in] finErrExpr The final error expression. void SetFinalErrorExpr(clang::Expr* finErrExpr) { m_FinalError = finErrExpr; } - /// Shorthand to get array subscript expressions. - /// - /// \param[in] arrBase The base expression of the array. - /// \param[in] idx The index expression. - /// \param[in] isCladSpType Keeps track of if we have to build a clad - /// special type (i.e. clad::Array or clad::ArrayRef). - /// - /// \returns An expression of the kind arrBase[idx]. - clang::Expr* getArraySubscriptExpr(clang::Expr* arrBase, clang::Expr* idx, - bool isCladSpType = true); - - /// \returns The final error expression so far. - clang::Expr* GetFinalErrorExpr() { return m_FinalError; } - - /// Function to build the final error statemnt of the function. This is the - /// last statement of any target function in error estimation and - /// aggregates the error in all the registered variables. - void BuildFinalErrorStmt(); + /// Function to build the error statement corresponding + /// to the function's return statement. + void BuildReturnErrorStmt(); /// Function to emit error statements into the derivative body. /// - /// \param[in] var The variable whose error statement we want to emit. - /// \param[in] deltaVar The "_delta_" expression of the variable 'var'. 
- /// \param[in] errorExpr The error expression (LHS) of the variable 'var'. - /// \param[in] isInsideLoop A flag to indicate if 'val' is inside a loop. - void AddErrorStmtToBlock(clang::Expr* var, clang::Expr* deltaVar, - clang::Expr* errorExpr, bool isInsideLoop = false); + /// \param[in] errorExpr The error expression (LHS) of the variable. + /// \param[in] addToTheFront A flag to decide whether the error stmts + /// should be added to the beginning of the block or the current position. + void AddErrorStmtToBlock(clang::Expr* errorExpr, bool addToTheFront = true); /// Emit the error estimation related statements that were saved to be /// emitted at later points into specific blocks. @@ -124,44 +102,12 @@ class ErrorEstimationHandler : public ExternalRMVSource { llvm::SmallVectorImpl& CallArgs, llvm::SmallVectorImpl& ArgResultDecls, size_t numArgs); - /// Save values of registered variables so that they can be replaced - /// properly in case of re-assignments. - /// - /// \param[in] val The value to save. - /// \param[in] isInsideLoop A flag to indicate if 'val' is inside a loop. - /// - /// \returns The saved variable and its derivative. - StmtDiff SaveValue(clang::Expr* val, bool isInLoop = false); - - /// Save the orignal values of the input parameters in case of - /// re-assignments. - /// - /// \param[in] paramRef The DeclRefExpr of the input parameter. - void SaveParamValue(clang::DeclRefExpr* paramRef); - - /// Register variables to be used while accumulating error. - /// Register variable declarations so that they may be used while - /// calculating the final error estimates. Any unregistered variables will - /// not be considered for the final estimation. - /// - /// \param[in] VD The variable declaration to be registered. - /// \param[in] toCurrentScope Add the created "_delta_" variable declaration - /// to the current scope instead of the global scope. - /// - /// \returns The Variable declaration of the '_delta_' prefixed variable. 
- clang::Expr* RegisterVariable(clang::VarDecl* VD, - bool toCurrentScope = false); - - /// Checks if a variable can be registered for error estimation. + /// Checks if a variable should be considered in error estimation. /// - /// \param[in] VD The variable declaration to be registered. + /// \param[in] VD The variable declaration. /// - /// \returns True if the variable can be registered, false otherwise. - bool CanRegisterVariable(clang::VarDecl* VD); - - /// Calculate aggregate error from m_EstimateVar. - /// Builds the final error estimation statement. - clang::Stmt* CalculateAggregateError(); + /// \returns true if the variable should be considered, false otherwise. + bool ShouldEstimateErrorFor(clang::VarDecl* VD); /// Get the underlying DeclRefExpr type it it exists. /// @@ -170,14 +116,6 @@ class ErrorEstimationHandler : public ExternalRMVSource { /// \returns The DeclRefExpr of input or null. clang::DeclRefExpr* GetUnderlyingDeclRefOrNull(clang::Expr* expr); - /// Get the parameter replacement (if any). - /// - /// \param[in] VD The parameter variable declaration to get replacement - /// for. - /// - /// \returns The underlying replaced Expr. - clang::Expr* GetParamReplacement(const clang::ParmVarDecl* VD); - /// An abstraction of the error estimation model's AssignError. /// /// \param[in] val The variable to get the error for. @@ -190,16 +128,6 @@ class ErrorEstimationHandler : public ExternalRMVSource { return m_EstModel->AssignError({var, varDiff}, varName); } - /// An abstraction of the error estimation model's IsVariableRegistered. - /// - /// \param[in] VD The variable declaration to check the status of. - /// - /// \returns the reference to the respective '_delta_' expression if the - /// variable is registered, null otherwise. - clang::Expr* IsRegistered(clang::VarDecl* VD) { - return m_EstModel->IsVariableRegistered(VD); - } - /// This function adds the final error and the other parameter errors to the /// forward block. 
/// @@ -215,17 +143,6 @@ class ErrorEstimationHandler : public ExternalRMVSource { /// loop. void EmitUnaryOpErrorStmts(StmtDiff var, bool isInsideLoop); - /// This function registers all LHS declRefExpr in binary operations. - /// - /// \param[in] LExpr The LHS of the operation. - /// \param[in] RExpr The RHS of the operation. - /// \param[in] isAssign A flag to know if the current operation is a simple - /// assignment. - /// - /// \returns The delta value of the input 'var'. - clang::Expr* RegisterBinaryOpLHS(clang::Expr* LExpr, clang::Expr* RExpr, - bool isAssign); - /// This function emits the error in a binary operation. /// /// \param[in] LExpr The LHS of the operation. @@ -234,8 +151,7 @@ class ErrorEstimationHandler : public ExternalRMVSource { /// \param[in] deltaVar The delta value of the LHS. /// \param[in] isInsideLoop A flag to keep track of if we are inside a /// loop. - void EmitBinaryOpErrorStmts(clang::Expr* LExpr, clang::Expr* oldValue, - clang::Expr* deltaVar, bool isInsideLoop); + void EmitBinaryOpErrorStmts(clang::Expr* LExpr, clang::Expr* oldValue); /// This function emits the error in declaration statements. 
/// @@ -256,21 +172,19 @@ class ErrorEstimationHandler : public ExternalRMVSource { void ActBeforeDifferentiatingStmtInVisitCompoundStmt() override; void ActAfterProcessingStmtInVisitCompoundStmt() override; void ActBeforeDifferentiatingSingleStmtBranchInVisitIfStmt() override; - void ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt() override; + void ActBeforeFinalizingVisitBranchSingleStmtInIfVisitStmt() override; void ActBeforeDifferentiatingLoopInitStmt() override; void ActBeforeDifferentiatingSingleStmtLoopBody() override; void ActAfterProcessingSingleStmtBodyInVisitForLoop() override; - void ActBeforeFinalisingVisitReturnStmt(StmtDiff& retExprDiff) override; - void ActBeforeFinalisingPostIncDecOp(StmtDiff& diff) override; + void ActBeforeFinalizingVisitReturnStmt(StmtDiff& retExprDiff) override; + void ActBeforeFinalizingPostIncDecOp(StmtDiff& diff) override; void ActBeforeFinalizingVisitCallExpr( const clang::CallExpr*& CE, clang::Expr*& fnDecl, llvm::SmallVectorImpl& derivedCallArgs, llvm::SmallVectorImpl& ArgResultDecls, bool asGrad) override; - void - ActAfterCloningLHSOfAssignOp(clang::Expr*& LCloned, clang::Expr*& R, - clang::BinaryOperator::Opcode& opCode) override; - void ActBeforeFinalisingAssignOp(clang::Expr*&, clang::Expr*&) override; + void ActBeforeFinalizingAssignOp(clang::Expr*&, clang::Expr*&, clang::Expr*&, + clang::BinaryOperator::Opcode&) override; void ActBeforeFinalizingDifferentiateSingleStmt(const direction& d) override; void ActBeforeFinalizingDifferentiateSingleExpr(const direction& d) override; void ActBeforeDifferentiatingCallExpr( diff --git a/include/clad/Differentiator/EstimationModel.h b/include/clad/Differentiator/EstimationModel.h index 5ddb30715..2cb87fbde 100644 --- a/include/clad/Differentiator/EstimationModel.h +++ b/include/clad/Differentiator/EstimationModel.h @@ -34,20 +34,6 @@ namespace clad { /// Clear the variable estimate map so that we can start afresh. 
void clearEstimationVariables() { m_EstimateVar.clear(); } - /// Check if a variable is registered for estimation. - /// - /// \param[in] VD The variable to check. - /// - /// \returns The delta expression of the variable if it is registered, - /// nullptr otherwise. - clang::Expr* IsVariableRegistered(const clang::VarDecl* VD); - - /// Track the variable declaration and utilize it in error - /// estimation. - /// - /// \param[in] VD The declaration to track. - void AddVarToEstimate(clang::VarDecl* VD, clang::Expr* VDRef); - /// Helper to build a function call expression. /// /// \param[in] funcName The name of the function to build the expression @@ -86,31 +72,6 @@ namespace clad { virtual clang::Expr* AssignError(StmtDiff refExpr, const std::string& name) = 0; - /// Initializes errors for '_delta_' statements. - /// This function returns the initial error assignment. Similar to - /// AssignError, however, this function is only called during declaration of - /// variables. This function is separate from AssignError to keep - /// implementation of different estimation models more flexible. - /// - /// The default definition is as follows: - /// \n \code - /// clang::Expr* SetError(clang::VarDecl* declStmt) { - /// return nullptr; - /// } - /// \endcode - /// The above will return a 0 expression to be assigned to the '_delta_' - /// declaration of input decl. - /// - /// \param[in] decl The declaration to which the error has to be assigned. - /// - /// \returns The error expression for declaration statements. - virtual clang::Expr* SetError(clang::VarDecl* decl); - - /// Calculate aggregate error from m_EstimateVar. - /// - /// \returns the final error estimation statement. 
- clang::Expr* CalculateAggregateError(); - friend class ErrorEstimationHandler; }; diff --git a/include/clad/Differentiator/ExternalRMVSource.h b/include/clad/Differentiator/ExternalRMVSource.h index b5fcde1ab..4879cc18a 100644 --- a/include/clad/Differentiator/ExternalRMVSource.h +++ b/include/clad/Differentiator/ExternalRMVSource.h @@ -104,7 +104,7 @@ class ExternalRMVSource { /// This is called just before finalising processing of Single statement /// branch in `VisitBranch` lambda in - virtual void ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt() {} + virtual void ActBeforeFinalizingVisitBranchSingleStmtInIfVisitStmt() {} /// This is called just before differentiating init statement of loops. virtual void ActBeforeDifferentiatingLoopInitStmt() {} @@ -117,7 +117,7 @@ class ExternalRMVSource { virtual void ActAfterProcessingSingleStmtBodyInVisitForLoop() {} /// This is called just before finalising `VisitReturnStmt`. - virtual void ActBeforeFinalisingVisitReturnStmt(StmtDiff& retExprDiff) {} + virtual void ActBeforeFinalizingVisitReturnStmt(StmtDiff& retExprDiff) {} /// This ic called just before finalising `VisitCallExpr`. /// @@ -131,15 +131,17 @@ class ExternalRMVSource { /// This is called just before finalising processing of post and pre /// increment and decrement operations. - virtual void ActBeforeFinalisingPostIncDecOp(StmtDiff& diff){}; + virtual void ActBeforeFinalizingPostIncDecOp(StmtDiff& diff){}; /// This is called just after cloning of LHS assignment operation. virtual void ActAfterCloningLHSOfAssignOp(clang::Expr*&, clang::Expr*&, clang::BinaryOperatorKind& opCode) { } - /// This is called just after finaising processing of assignment operator. - virtual void ActBeforeFinalisingAssignOp(clang::Expr*&, clang::Expr*&){}; + /// This is called just after finalising processing of assignment operator. 
+ virtual void ActBeforeFinalizingAssignOp(clang::Expr*&, clang::Expr*&, + clang::Expr*&, + clang::BinaryOperator::Opcode&){}; /// This is called at that beginning of /// `ReverseModeVisitor::DifferentiateSingleStmt`. diff --git a/include/clad/Differentiator/MultiplexExternalRMVSource.h b/include/clad/Differentiator/MultiplexExternalRMVSource.h index 7864828b0..2e1f35f8b 100644 --- a/include/clad/Differentiator/MultiplexExternalRMVSource.h +++ b/include/clad/Differentiator/MultiplexExternalRMVSource.h @@ -40,20 +40,21 @@ class MultiplexExternalRMVSource : public ExternalRMVSource { void ActBeforeDifferentiatingStmtInVisitCompoundStmt() override; void ActAfterProcessingStmtInVisitCompoundStmt() override; void ActBeforeDifferentiatingSingleStmtBranchInVisitIfStmt() override; - void ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt() override; + void ActBeforeFinalizingVisitBranchSingleStmtInIfVisitStmt() override; void ActBeforeDifferentiatingLoopInitStmt() override; void ActBeforeDifferentiatingSingleStmtLoopBody() override; void ActAfterProcessingSingleStmtBodyInVisitForLoop() override; - void ActBeforeFinalisingVisitReturnStmt(StmtDiff& retExprDiff) override; + void ActBeforeFinalizingVisitReturnStmt(StmtDiff& retExprDiff) override; void ActBeforeFinalizingVisitCallExpr( const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn, llvm::SmallVectorImpl& derivedCallArgs, llvm::SmallVectorImpl& ArgResultDecls, bool asGrad) override; - void ActBeforeFinalisingPostIncDecOp(StmtDiff& diff) override; + void ActBeforeFinalizingPostIncDecOp(StmtDiff& diff) override; void ActAfterCloningLHSOfAssignOp(clang::Expr*&, clang::Expr*&, clang::BinaryOperatorKind& opCode) override; - void ActBeforeFinalisingAssignOp(clang::Expr*&, clang::Expr*&) override; + void ActBeforeFinalizingAssignOp(clang::Expr*&, clang::Expr*&, clang::Expr*&, + clang::BinaryOperator::Opcode&) override; void ActOnStartOfDifferentiateSingleStmt() override; void 
ActBeforeFinalizingDifferentiateSingleStmt(const direction& d) override; void ActBeforeFinalizingDifferentiateSingleExpr(const direction& d) override; diff --git a/include/clad/Differentiator/Sins.h b/include/clad/Differentiator/Sins.h new file mode 100644 index 000000000..28983d626 --- /dev/null +++ b/include/clad/Differentiator/Sins.h @@ -0,0 +1,29 @@ +#ifndef CLAD_DIFFERENTIATOR_SINS_H +#define CLAD_DIFFERENTIATOR_SINS_H + +#include + +/// Standard-protected facility allowing access into private members in C++. +/// Use with caution! +// NOLINTBEGIN(cppcoreguidelines-macro-usage) +#define CONCATE_(X, Y) X##Y +#define CONCATE(X, Y) CONCATE_(X, Y) +#define ALLOW_ACCESS(CLASS, MEMBER, ...) \ + template \ + struct CONCATE(MEMBER, __LINE__) { \ + friend __VA_ARGS__ CLASS::*Access(Only*) { return Member; } \ + }; \ + template struct Only_##MEMBER; \ + template <> struct Only_##MEMBER { \ + friend __VA_ARGS__ CLASS::*Access(Only_##MEMBER*); \ + }; \ + template struct CONCATE(MEMBER, \ + __LINE__), &CLASS::MEMBER> + +#define ACCESS(OBJECT, MEMBER) \ + (OBJECT).*Access((Only_##MEMBER< \ + std::remove_reference::type>*)nullptr) + +// NOLINTEND(cppcoreguidelines-macro-usage) + +#endif // CLAD_DIFFERENTIATOR_SINS_H diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index b6c95662a..eabcac02f 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -424,6 +424,13 @@ namespace clad { clang::Expr* BuildArraySubscript(clang::Expr* Base, const llvm::SmallVectorImpl& IS); + + /// Build an array subscript expression with a given base expression and + /// one index. + clang::Expr* BuildArraySubscript(clang::Expr* Base, clang::Expr*& Idx) { + llvm::SmallVector IS = {Idx}; + return BuildArraySubscript(Base, IS); + } /// Find namespace clad declaration. 
clang::NamespaceDecl* GetCladNamespace(); /// Find declaration of clad::class templated type diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index 87ec8df05..2654562e0 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -232,9 +232,10 @@ namespace clad { } DiffCollector::DiffCollector(DeclGroupRef DGR, DiffInterval& Interval, - DiffSchedule& plans, clang::Sema& S) + DiffSchedule& plans, clang::Sema& S, + RequestOptions& opts) : m_Interval(Interval), m_DiffPlans(plans), m_TopMostFD(nullptr), - m_Sema(S) { + m_Sema(S), m_Options(opts) { if (Interval.empty()) return; @@ -556,12 +557,13 @@ namespace clad { return true; DiffRequest request{}; - if (A->getAnnotation().equals("D")) { - request.Mode = DiffMode::forward; - - // bitmask_opts is a template pack of unsigned integers, so we need to - // do bitwise or of all the values to get the final value. - unsigned bitmasked_opts_value = 0; + // bitmask_opts is a template pack of unsigned integers, so we need to + // do bitwise or of all the values to get the final value. + unsigned bitmasked_opts_value = 0; + bool enable_tbr_in_req = false; + bool disable_tbr_in_req = false; + if (!A->getAnnotation().equals("E") && + FD->getTemplateSpecializationArgs()) { const auto template_arg = FD->getTemplateSpecializationArgs()->get(0); if (template_arg.getKind() == TemplateArgument::Pack) for (const auto& arg : @@ -569,14 +571,39 @@ namespace clad { bitmasked_opts_value |= arg.getAsIntegral().getExtValue(); else bitmasked_opts_value = template_arg.getAsIntegral().getExtValue(); + + // Set option for TBR analysis. 
+ enable_tbr_in_req = + clad::HasOption(bitmasked_opts_value, clad::opts::enable_tbr); + disable_tbr_in_req = + clad::HasOption(bitmasked_opts_value, clad::opts::disable_tbr); + if (enable_tbr_in_req && disable_tbr_in_req) { + utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, + "Both enable and disable TBR options are specified."); + return true; + } + if (enable_tbr_in_req || disable_tbr_in_req) { + // override the default value of TBR analysis. + request.EnableTBRAnalysis = enable_tbr_in_req && !disable_tbr_in_req; + } else { + request.EnableTBRAnalysis = m_Options.EnableTBRAnalysis; + } + } + + if (A->getAnnotation().equals("D")) { + request.Mode = DiffMode::forward; unsigned derivative_order = clad::GetDerivativeOrder(bitmasked_opts_value); if (derivative_order == 0) { derivative_order = 1; // default to first order derivative. } request.RequestedDerivativeOrder = derivative_order; - if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) { + if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) request.use_enzyme = true; + if (enable_tbr_in_req) { + utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, + "TBR analysis is not meant for forward mode AD."); + return true; } if (clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) { request.Mode = DiffMode::vector_forward_mode; @@ -601,17 +628,6 @@ namespace clad { request.Mode = DiffMode::jacobian; } else if (A->getAnnotation().equals("G")) { request.Mode = DiffMode::reverse; - - // bitmask_opts is a template pack of unsigned integers, so we need to - // do bitwise or of all the values to get the final value. 
- unsigned bitmasked_opts_value = 0; - const auto template_arg = FD->getTemplateSpecializationArgs()->get(0); - if (template_arg.getKind() == TemplateArgument::Pack) - for (const auto& arg : - FD->getTemplateSpecializationArgs()->get(0).pack_elements()) - bitmasked_opts_value |= arg.getAsIntegral().getExtValue(); - else - bitmasked_opts_value = template_arg.getAsIntegral().getExtValue(); if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme)) request.use_enzyme = true; // reverse vector mode is not yet supported. diff --git a/lib/Differentiator/ErrorEstimator.cpp b/lib/Differentiator/ErrorEstimator.cpp index cbab2a49e..358e36cfb 100644 --- a/lib/Differentiator/ErrorEstimator.cpp +++ b/lib/Differentiator/ErrorEstimator.cpp @@ -17,8 +17,7 @@ QualType getUnderlyingArrayType(QualType baseType, ASTContext& C) { } else if (auto PTType = baseType->getAs()) { return PTType->getPointeeType(); } - assert(0 && "Unreachable"); - return {}; + return baseType; } Expr* UpdateErrorForFuncCallAssigns(ErrorEstimationHandler* handler, @@ -39,65 +38,30 @@ void ErrorEstimationHandler::SetErrorEstimationModel( m_EstModel = estModel; } -Expr* ErrorEstimationHandler::getArraySubscriptExpr( - Expr* arrBase, Expr* idx, bool isCladSpType /*=true*/) { - if (isCladSpType) { - return m_RMV->m_Sema - .ActOnArraySubscriptExpr(m_RMV->getCurrentScope(), arrBase, - arrBase->getExprLoc(), idx, noLoc) - .get(); - } else { - return m_RMV->m_Sema - .CreateBuiltinArraySubscriptExpr(arrBase, noLoc, idx, noLoc) - .get(); - } -} - -void ErrorEstimationHandler::BuildFinalErrorStmt() { - Expr* finExpr = nullptr; +void ErrorEstimationHandler::BuildReturnErrorStmt() { // If we encountered any arithmetic expression in the return statement, // we must add its error to the final estimate. 
if (m_RetErrorExpr) { auto flitr = FloatingLiteral::Create(m_RMV->m_Context, llvm::APFloat(1.0), true, m_RMV->m_Context.DoubleTy, noLoc); - finExpr = + Expr* finExpr = m_EstModel->AssignError(StmtDiff(m_RetErrorExpr, flitr), "return_expr"); + m_RMV->addToCurrentBlock( + m_RMV->BuildOp(BO_AddAssign, m_FinalError, finExpr), + direction::forward); } - - // Build the final error statement with the sum of all _delta_*. - Expr* addErrorExpr = m_EstModel->CalculateAggregateError(); - if (addErrorExpr) { - if (finExpr) - addErrorExpr = m_RMV->BuildOp(BO_Add, addErrorExpr, finExpr); - } else if (finExpr) { - addErrorExpr = finExpr; - } - - // Finally add the final error expression to the derivative body. - m_RMV->addToCurrentBlock( - m_RMV->BuildOp(BO_AddAssign, m_FinalError, addErrorExpr), - direction::forward); } -void ErrorEstimationHandler::AddErrorStmtToBlock(Expr* var, Expr* deltaVar, - Expr* errorExpr, - bool isInsideLoop /*=false*/) { - - if (auto ASE = dyn_cast(var)) { - deltaVar = getArraySubscriptExpr(deltaVar, ASE->getIdx()); - m_RMV->addToCurrentBlock(m_RMV->BuildOp(BO_AddAssign, deltaVar, errorExpr), - direction::reverse); - // immediately emit fin_err += delta_[]. - // This is done to avoid adding all errors at the end - // and only add the errors that were calculated. 
- m_RMV->addToCurrentBlock( - m_RMV->BuildOp(BO_AddAssign, m_FinalError, deltaVar), - direction::reverse); - - } else - m_RMV->addToCurrentBlock(m_RMV->BuildOp(BO_AddAssign, deltaVar, errorExpr), - direction::reverse); +void ErrorEstimationHandler::AddErrorStmtToBlock(Expr* errorExpr, + bool addToTheFront) { + Stmt* errorStmt = m_RMV->BuildOp(BO_AddAssign, m_FinalError, errorExpr); + if (addToTheFront) { + auto& block = m_RMV->getCurrentBlock(direction::reverse); + block.insert(block.begin(), errorStmt); + } else { + m_RMV->addToCurrentBlock(errorStmt, direction::reverse); + } } void ErrorEstimationHandler::EmitErrorEstimationStmts( @@ -152,124 +116,11 @@ void ErrorEstimationHandler::EmitNestedFunctionParamError( } } -StmtDiff ErrorEstimationHandler::SaveValue(Expr* val, - bool isInsideLoop /*=false*/) { - // Definite not null. - DeclRefExpr* declRefVal = GetUnderlyingDeclRefOrNull(val); - assert(declRefVal && "Val cannot be null!"); - std::string name = "_EERepl_" + declRefVal->getDecl()->getNameAsString(); - if (isInsideLoop) { - auto tape = m_RMV->MakeCladTapeFor(val, name); - m_ForwardReplStmts.push_back(tape.Push); - // Nice to store pop values becuase user might refer to getExpr - // multiple times in Assign Error. 
- Expr* popVal = m_RMV->StoreAndRef(tape.Pop, direction::reverse); - return StmtDiff(tape.Push, popVal); - } else { - QualType QTval = val->getType(); - if (auto AType = dyn_cast(QTval)) - QTval = AType->getElementType(); - - auto savedVD = m_RMV->GlobalStoreImpl(QTval, name); - auto savedRef = m_RMV->BuildDeclRef(savedVD); - m_ForwardReplStmts.push_back(m_RMV->BuildOp(BO_Assign, savedRef, val)); - return StmtDiff(savedRef, savedRef); - } -} - -void ErrorEstimationHandler::SaveParamValue(DeclRefExpr* paramRef) { - assert(paramRef && "Must have a value"); - VarDecl* paramDecl = cast(paramRef->getDecl()); - QualType paramType = paramRef->getType(); - std::string name = "_EERepl_" + paramDecl->getNameAsString(); - VarDecl* savedDecl; - if (utils::isArrayOrPointerType(paramType)) { - auto diffVar = m_RMV->m_Variables[paramDecl]; - auto QType = m_RMV->GetCladArrayOfType( - getUnderlyingArrayType(paramType, m_RMV->m_Context)); - savedDecl = m_RMV->BuildVarDecl( - QType, name, m_RMV->BuildArrayRefSizeExpr(diffVar), - /*DirectInit=*/false, - /*TSI=*/nullptr, VarDecl::InitializationStyle::CallInit); - m_RMV->AddToGlobalBlock(m_RMV->BuildDeclStmt(savedDecl)); - ReverseModeVisitor::Stmts loopBody; - // Get iter variable. - auto loopIdx = - m_RMV->BuildVarDecl(m_RMV->m_Context.IntTy, "i", - m_RMV->getZeroInit(m_RMV->m_Context.IntTy)); - auto currIdx = m_RMV->BuildDeclRef(loopIdx); - // Build the assign expression. - loopBody.push_back(m_RMV->BuildOp( - BO_Assign, - getArraySubscriptExpr(m_RMV->BuildDeclRef(savedDecl), currIdx), - getArraySubscriptExpr(paramRef, currIdx, - /*isCladSpType=*/false))); - Expr* conditionExpr = - m_RMV->BuildOp(BO_LT, currIdx, m_RMV->BuildArrayRefSizeExpr(diffVar)); - Expr* incExpr = m_RMV->BuildOp(UO_PostInc, currIdx); - // Make for loop. 
- Stmt* ArrayParamLoop = new (m_RMV->m_Context) ForStmt( - m_RMV->m_Context, m_RMV->BuildDeclStmt(loopIdx), conditionExpr, nullptr, - incExpr, m_RMV->MakeCompoundStmt(loopBody), noLoc, noLoc, noLoc); - m_RMV->AddToGlobalBlock(ArrayParamLoop); - } else - savedDecl = m_RMV->GlobalStoreImpl(paramType, name, paramRef); - m_ParamRepls.emplace(paramDecl, m_RMV->BuildDeclRef(savedDecl)); -} - -Expr* ErrorEstimationHandler::RegisterVariable(VarDecl* VD, - bool toCurrentScope /*=false*/) { - if (!CanRegisterVariable(VD)) - return nullptr; - // Get the init error from setError. - Expr* init = m_EstModel->SetError(VD); - auto VDType = VD->getType(); - // The type of the _delta_ value should be customisable. - QualType QType; - Expr* deltaVar = nullptr; - auto diffVar = m_RMV->m_Variables[VD]; - if (m_RMV->isCladArrayType(diffVar->getType())) { - VarDecl* EstVD; - auto sizeExpr = m_RMV->BuildArrayRefSizeExpr(diffVar); - QType = m_RMV->GetCladArrayOfType( - getUnderlyingArrayType(VDType, m_RMV->m_Context)); - EstVD = m_RMV->BuildVarDecl( - QType, "_delta_" + VD->getNameAsString(), sizeExpr, - /*DirectInit=*/false, - /*TSI=*/nullptr, VarDecl::InitializationStyle::CallInit); - if (!toCurrentScope) - m_RMV->AddToGlobalBlock(m_RMV->BuildDeclStmt(EstVD)); - else - m_RMV->addToCurrentBlock(m_RMV->BuildDeclStmt(EstVD), direction::forward); - deltaVar = m_RMV->BuildDeclRef(EstVD); - } else { - QType = utils::isArrayOrPointerType(VDType) ? VDType - : m_RMV->m_Context.DoubleTy; - init = init ? init : m_RMV->getZeroInit(QType); - // Store the "_delta_*" value. - if (!toCurrentScope) { - auto EstVD = m_RMV->GlobalStoreImpl( - QType, "_delta_" + VD->getNameAsString(), init); - deltaVar = m_RMV->BuildDeclRef(EstVD); - } else { - deltaVar = m_RMV->StoreAndRef(init, QType, direction::forward, - "_delta_" + VD->getNameAsString(), - /*forceDeclCreation=*/true); - } - } - // Register the variable for estimate calculation. 
- m_EstModel->AddVarToEstimate(VD, deltaVar); - return deltaVar; -} - -bool ErrorEstimationHandler::CanRegisterVariable(VarDecl* VD) { +bool ErrorEstimationHandler::ShouldEstimateErrorFor(VarDecl* VD) { // Get the types on the declartion and initalization expression. QualType varDeclBase = VD->getType(); - QualType varDeclType = - utils::isArrayOrPointerType(varDeclBase) - ? getUnderlyingArrayType(varDeclBase, m_RMV->m_Context) - : varDeclBase; + QualType varDeclType = getUnderlyingArrayType(varDeclBase, m_RMV->m_Context); const Expr* init = VD->getInit(); // If declarationg type in not floating point type, we want to do two // things. @@ -315,44 +166,19 @@ DeclRefExpr* ErrorEstimationHandler::GetUnderlyingDeclRefOrNull(Expr* expr) { return dyn_cast(expr->IgnoreImplicit()); } -Expr* ErrorEstimationHandler::GetParamReplacement(const ParmVarDecl* VD) { - auto it = m_ParamRepls.find(VD); - if (it != m_ParamRepls.end()) - return it->second; - return nullptr; -} - void ErrorEstimationHandler::EmitFinalErrorStmts( llvm::SmallVectorImpl& params, unsigned numParams) { // Emit error variables of parameters at the end. for (size_t i = 0; i < numParams; i++) { - // Right now, we just ignore them since we have no way of knowing - // the size of an array. - // if (m_RMV->isArrayOrPointerType(params[i]->getType())) - // continue; - - // Check if the declaration was registered - auto decl = dyn_cast(params[i]); - Expr* deltaVar = IsRegistered(decl); - - // If not registered, check if it is eligible for registration and do - // the needful. - if (!deltaVar) { - deltaVar = RegisterVariable(decl, /*toCurrentScope=*/true); - } - - // If till now, we have a delta declaration, emit it into the code. - if (deltaVar) { + auto* decl = cast(params[i]); + if (ShouldEstimateErrorFor(decl)) { if (!m_RMV->isArrayOrPointerType(params[i]->getType())) { - // Since we need the input value of x, check for a replacement. - // If no replacement found, use the actual declRefExpr. 
- auto savedVal = GetParamReplacement(params[i]); - savedVal = savedVal ? savedVal : m_RMV->BuildDeclRef(decl); + auto* paramClone = m_RMV->BuildDeclRef(decl); // Finally emit the error. - auto errorExpr = GetError(savedVal, m_RMV->m_Variables[decl], - params[i]->getNameAsString()); + auto* errorExpr = GetError(paramClone, m_RMV->m_Variables[decl], + params[i]->getNameAsString()); m_RMV->addToCurrentBlock( - m_RMV->BuildOp(BO_AddAssign, deltaVar, errorExpr)); + m_RMV->BuildOp(BO_AddAssign, m_FinalError, errorExpr)); } else { auto LdiffExpr = m_RMV->m_Variables[decl]; VarDecl* idxExprDecl = nullptr; @@ -363,32 +189,20 @@ void ErrorEstimationHandler::EmitFinalErrorStmts( m_RMV->getZeroInit(m_RMV->m_Context.IntTy)); m_IdxExpr = m_RMV->BuildDeclRef(idxExprDecl); } - Expr *Ldiff, *Ldelta; - Ldiff = getArraySubscriptExpr( - LdiffExpr, m_IdxExpr, m_RMV->isCladArrayType(LdiffExpr->getType())); - Ldelta = getArraySubscriptExpr(deltaVar, m_IdxExpr); - auto savedVal = GetParamReplacement(params[i]); - savedVal = savedVal ? savedVal : m_RMV->BuildDeclRef(decl); - auto LRepl = getArraySubscriptExpr(savedVal, m_IdxExpr); + Expr* Ldiff = nullptr; + Ldiff = m_RMV->BuildArraySubscript(LdiffExpr, m_IdxExpr); + auto* paramClone = m_RMV->BuildDeclRef(decl); + auto* LRepl = m_RMV->BuildArraySubscript(paramClone, m_IdxExpr); // Build the loop to put in reverse mode. 
Expr* errorExpr = GetError(LRepl, Ldiff, params[i]->getNameAsString()); - auto commonVarDecl = - m_RMV->BuildVarDecl(errorExpr->getType(), "_t", errorExpr); - Expr* commonVarExpr = m_RMV->BuildDeclRef(commonVarDecl); - Expr* deltaAssignExpr = - m_RMV->BuildOp(BO_AddAssign, Ldelta, commonVarExpr); Expr* finalAssignExpr = - m_RMV->BuildOp(BO_AddAssign, m_FinalError, commonVarExpr); - ReverseModeVisitor::Stmts loopBody; - loopBody.push_back(m_RMV->BuildDeclStmt(commonVarDecl)); - loopBody.push_back(deltaAssignExpr); - loopBody.push_back(finalAssignExpr); + m_RMV->BuildOp(BO_AddAssign, m_FinalError, errorExpr); Expr* conditionExpr = m_RMV->BuildOp( BO_LT, m_IdxExpr, m_RMV->BuildArrayRefSizeExpr(LdiffExpr)); Expr* incExpr = m_RMV->BuildOp(UO_PostInc, m_IdxExpr); Stmt* ArrayParamLoop = new (m_RMV->m_Context) ForStmt(m_RMV->m_Context, nullptr, conditionExpr, nullptr, incExpr, - m_RMV->MakeCompoundStmt(loopBody), noLoc, noLoc, noLoc); + finalAssignExpr, noLoc, noLoc, noLoc); // For multiple array parameters, we want to keep the same // iterative variable, so reset that here in the case that this // is not out first array. @@ -403,7 +217,7 @@ void ErrorEstimationHandler::EmitFinalErrorStmts( } } } - BuildFinalErrorStmt(); + BuildReturnErrorStmt(); } void ErrorEstimationHandler::EmitUnaryOpErrorStmts(StmtDiff var, @@ -412,69 +226,24 @@ void ErrorEstimationHandler::EmitUnaryOpErrorStmts(StmtDiff var, if (DeclRefExpr* DRE = GetUnderlyingDeclRefOrNull(var.getExpr())) { // First check if it was registered. // If not, we don't care about it. - if (auto deltaVar = IsRegistered(cast(DRE->getDecl()))) { - // Create a variable/tape call to store the current value of the - // the sub-expression so that it can be used later. - StmtDiff savedVar = m_RMV->GlobalStoreAndRef( - DRE, "_EERepl_" + DRE->getDecl()->getNameAsString()); - if (isInsideLoop) { - // It is nice to save the pop value. 
- // We do not know how many times the user will use dx, - // hence we should pop values beforehand to avoid unequal pushes - // and and pops. - Expr* popVal = - m_RMV->StoreAndRef(savedVar.getExpr_dx(), direction::reverse); - savedVar = {savedVar.getExpr(), popVal}; - } - Expr* erroExpr = GetError(savedVar.getExpr_dx(), var.getExpr_dx(), - DRE->getDecl()->getNameAsString()); - AddErrorStmtToBlock(var.getExpr(), deltaVar, erroExpr, isInsideLoop); + if (ShouldEstimateErrorFor(cast(DRE->getDecl()))) { + Expr* erroExpr = + GetError(DRE, var.getExpr_dx(), DRE->getDecl()->getNameAsString()); + AddErrorStmtToBlock(erroExpr); } } } -Expr* ErrorEstimationHandler::RegisterBinaryOpLHS(Expr* LExpr, Expr* RExpr, - bool isAssign) { - DeclRefExpr* LRef = GetUnderlyingDeclRefOrNull(LExpr); - DeclRefExpr* RRef = GetUnderlyingDeclRefOrNull(RExpr); - VarDecl* Ldecl = LRef ? dyn_cast(LRef->getDecl()) : nullptr; - // In the case that an RHS expression is a declReference, we do not emit - // any error because the assignment operation entials zero error. - // However, for compound assignment operators, the RHS may be a - // declRefExpr but here we will need to emit its error. - // This variable checks for the above conditions. - bool declRefOk = !RRef || !isAssign; - Expr* deltaVar = nullptr; - // If the LHS can be decayed to a VarDecl and all other requirements - // are met, we should register the variable if it has not been already. - // We also do not support array input types yet. - if (Ldecl && declRefOk) { - deltaVar = IsRegistered(Ldecl); - // Usually we would expect independent variable to qualify for these - // checks. - if (!deltaVar) { - deltaVar = RegisterVariable(Ldecl); - SaveParamValue(LRef); - } - } - return deltaVar; -} - -void ErrorEstimationHandler::EmitBinaryOpErrorStmts(Expr* LExpr, Expr* oldValue, - Expr* deltaVar, - bool isInsideLoop) { - if (!deltaVar) - return; - // For now save all lhs. 
- // FIXME: We can optimize stores here by using the ones created - // previously. - StmtDiff savedExpr = SaveValue(LExpr, isInsideLoop); +void ErrorEstimationHandler::EmitBinaryOpErrorStmts(Expr* LExpr, + Expr* oldValue) { // Assign the error. auto decl = GetUnderlyingDeclRefOrNull(LExpr)->getDecl(); - Expr* errorExpr = - UpdateErrorForFuncCallAssigns(this, savedExpr.getExpr_dx(), oldValue, - m_NestedFuncError, decl->getNameAsString()); - AddErrorStmtToBlock(LExpr, deltaVar, errorExpr, isInsideLoop); + if (!ShouldEstimateErrorFor(cast(decl))) + return; + bool errorFromFunctionCall = (bool)m_NestedFuncError; + Expr* errorExpr = UpdateErrorForFuncCallAssigns( + this, LExpr, oldValue, m_NestedFuncError, decl->getNameAsString()); + AddErrorStmtToBlock(errorExpr, /*addToTheFront=*/!errorFromFunctionCall); // If there are assign statements to emit in reverse, do that. EmitErrorEstimationStmts(direction::reverse); } @@ -482,21 +251,19 @@ void ErrorEstimationHandler::EmitBinaryOpErrorStmts(Expr* LExpr, Expr* oldValue, void ErrorEstimationHandler::EmitDeclErrorStmts(VarDeclDiff VDDiff, bool isInsideLoop) { auto VD = VDDiff.getDecl(); - if (!CanRegisterVariable(VD)) + if (!ShouldEstimateErrorFor(VD)) return; // Build the delta expresion for the variable to be registered. - auto EstVD = RegisterVariable(VD); DeclRefExpr* VDRef = m_RMV->BuildDeclRef(VD); // FIXME: We should do this for arrays too. if (!VD->getType()->isArrayType()) { - StmtDiff savedDecl = SaveValue(VDRef, isInsideLoop); // If the VarDecl has an init, we should assign it with an error. 
if (VD->getInit() && !GetUnderlyingDeclRefOrNull(VD->getInit())) { + bool errorFromFunctionCall = (bool)m_NestedFuncError; Expr* errorExpr = UpdateErrorForFuncCallAssigns( - this, savedDecl.getExpr_dx(), - m_RMV->BuildDeclRef(VDDiff.getDecl_dx()), m_NestedFuncError, - VD->getNameAsString()); - AddErrorStmtToBlock(VDRef, EstVD, errorExpr, isInsideLoop); + this, VDRef, m_RMV->BuildDeclRef(VDDiff.getDecl_dx()), + m_NestedFuncError, VD->getNameAsString()); + AddErrorStmtToBlock(errorExpr, /*addToTheFront=*/!errorFromFunctionCall); } } } @@ -568,7 +335,7 @@ void ErrorEstimationHandler:: } void ErrorEstimationHandler:: - ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt() { + ActBeforeFinalizingVisitBranchSingleStmtInIfVisitStmt() { // In error estimation, manually emit the code here instead of // DifferentiateSingleStmt to maintain correct order. EmitErrorEstimationStmts(direction::forward); @@ -587,7 +354,7 @@ void ErrorEstimationHandler::ActAfterProcessingSingleStmtBodyInVisitForLoop() { EmitErrorEstimationStmts(direction::forward); } -void ErrorEstimationHandler::ActBeforeFinalisingVisitReturnStmt( +void ErrorEstimationHandler::ActBeforeFinalizingVisitReturnStmt( StmtDiff& retExprDiff) { // If the return expression is not a DeclRefExpression and is of type // float, we should add it to the error estimate because returns are @@ -595,7 +362,7 @@ void ErrorEstimationHandler::ActBeforeFinalisingVisitReturnStmt( SaveReturnExpr(retExprDiff.getExpr()); } -void ErrorEstimationHandler::ActBeforeFinalisingPostIncDecOp(StmtDiff& diff) { +void ErrorEstimationHandler::ActBeforeFinalizingPostIncDecOp(StmtDiff& diff) { EmitUnaryOpErrorStmts(diff, m_RMV->isInsideLoop); } @@ -620,18 +387,17 @@ void ErrorEstimationHandler::ActBeforeFinalizingVisitCallExpr( } } -void ErrorEstimationHandler::ActAfterCloningLHSOfAssignOp( - clang::Expr*& LCloned, clang::Expr*& R, +void ErrorEstimationHandler::ActBeforeFinalizingAssignOp( + clang::Expr*& LCloned, clang::Expr*& oldValue, clang::Expr*& 
R, clang::BinaryOperator::Opcode& opCode) { - m_DeltaVar = RegisterBinaryOpLHS(LCloned, R, - /*isAssign=*/opCode == BO_Assign); -} - -void ErrorEstimationHandler::ActBeforeFinalisingAssignOp( - clang::Expr*& LCloned, clang::Expr*& oldValue) { - // Now, we should emit the delta for LHS if it met all the - // requirements previously. - EmitBinaryOpErrorStmts(LCloned, oldValue, m_DeltaVar, m_RMV->isInsideLoop); + DeclRefExpr* RRef = GetUnderlyingDeclRefOrNull(R); + // In the case that an RHS expression is a declReference, we do not emit + // any error because the assignment operation entials zero error. + // However, for compound assignment operators, the RHS may be a + // declRefExpr but here we will need to emit its error. + // This checks for the above conditions. + if (opCode != BO_Assign || !RRef) + EmitBinaryOpErrorStmts(LCloned, oldValue); } void ErrorEstimationHandler::ActBeforeFinalizingDifferentiateSingleStmt( diff --git a/lib/Differentiator/EstimationModel.cpp b/lib/Differentiator/EstimationModel.cpp index e72963c6f..da92afc04 100644 --- a/lib/Differentiator/EstimationModel.cpp +++ b/lib/Differentiator/EstimationModel.cpp @@ -15,44 +15,6 @@ namespace clad { FPErrorEstimationModel::~FPErrorEstimationModel() {} - Expr* FPErrorEstimationModel::IsVariableRegistered(const VarDecl* VD) { - auto it = m_EstimateVar.find(VD); - if (it != m_EstimateVar.end()) - return it->second; - return nullptr; - } - - void FPErrorEstimationModel::AddVarToEstimate(VarDecl* VD, Expr* VDRef) { - m_EstimateVar.emplace(VD, VDRef); - } - - // FIXME: Maybe this should be left to the user too. - Expr* FPErrorEstimationModel::CalculateAggregateError() { - Expr* addExpr = nullptr; - // Loop over all the error variables and form the final error expression of - // the form... _final_error = _delta_var + _delta_var1 +... 
- for (auto var : m_EstimateVar) { - // Errors through array subscript expressions are already captured - // to avoid having long add expression at the end and to only add - // the values to the final error that have a non zero delta. - if (isArrayOrPointerType(var.first->getType())) - continue; - - if (!addExpr) { - addExpr = var.second; - continue; - } - addExpr = BuildOp(BO_Add, addExpr, var.second); - } - // Return an expression that can be directly assigned to final error. - return addExpr; - } - - // Return nullptr here, this is interpreted as 0 internally. - Expr* FPErrorEstimationModel::SetError(VarDecl* declStmt) { - return nullptr; - } - Expr* FPErrorEstimationModel::GetFunctionCall( std::string funcName, std::string nmspace, llvm::SmallVectorImpl& callArgs) { diff --git a/lib/Differentiator/MultiplexExternalRMVSource.cpp b/lib/Differentiator/MultiplexExternalRMVSource.cpp index 1b6f9c10f..837990656 100644 --- a/lib/Differentiator/MultiplexExternalRMVSource.cpp +++ b/lib/Differentiator/MultiplexExternalRMVSource.cpp @@ -114,9 +114,9 @@ void MultiplexExternalRMVSource:: } void MultiplexExternalRMVSource:: - ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt() { + ActBeforeFinalizingVisitBranchSingleStmtInIfVisitStmt() { for (auto source : m_Sources) { - source->ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt(); + source->ActBeforeFinalizingVisitBranchSingleStmtInIfVisitStmt(); } } @@ -139,10 +139,10 @@ void MultiplexExternalRMVSource:: } } -void MultiplexExternalRMVSource::ActBeforeFinalisingVisitReturnStmt( +void MultiplexExternalRMVSource::ActBeforeFinalizingVisitReturnStmt( StmtDiff& retExprDiff) { for (auto source : m_Sources) { - source->ActBeforeFinalisingVisitReturnStmt(retExprDiff); + source->ActBeforeFinalizingVisitReturnStmt(retExprDiff); } } @@ -156,10 +156,10 @@ void MultiplexExternalRMVSource::ActBeforeFinalizingVisitCallExpr( } } -void MultiplexExternalRMVSource::ActBeforeFinalisingPostIncDecOp( +void 
MultiplexExternalRMVSource::ActBeforeFinalizingPostIncDecOp( StmtDiff& diff) { for (auto source : m_Sources) { - source->ActBeforeFinalisingPostIncDecOp(diff); + source->ActBeforeFinalizingPostIncDecOp(diff); } } void MultiplexExternalRMVSource::ActAfterCloningLHSOfAssignOp( @@ -169,10 +169,11 @@ void MultiplexExternalRMVSource::ActAfterCloningLHSOfAssignOp( } } -void MultiplexExternalRMVSource::ActBeforeFinalisingAssignOp( - clang::Expr*& LCloned, clang::Expr*& oldValue) { +void MultiplexExternalRMVSource::ActBeforeFinalizingAssignOp( + clang::Expr*& LCloned, clang::Expr*& oldValue, clang::Expr*& R, + clang::BinaryOperator::Opcode& opCode) { for (auto source : m_Sources) { - source->ActBeforeFinalisingAssignOp(LCloned, oldValue); + source->ActBeforeFinalizingAssignOp(LCloned, oldValue, R, opCode); } } diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 65f9473a5..51d6c8221 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -901,7 +901,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource - ->ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt(); + ->ActBeforeFinalizingVisitBranchSingleStmtInIfVisitStmt(); Stmt* Forward = unwrapIfSingleStmt(endBlock(direction::forward)); Stmt* Reverse = unwrapIfSingleStmt(BranchDiff.getStmt_dx()); @@ -1226,7 +1226,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // ValueAndPushforward. if (!isCladValueAndPushforwardType(type)) { if (m_ExternalSource) - m_ExternalSource->ActBeforeFinalisingVisitReturnStmt(ExprDiff); + m_ExternalSource->ActBeforeFinalizingVisitReturnStmt(ExprDiff); } // Create goto to the label. 
@@ -2179,7 +2179,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, ResultRef = diff_dx; valueForRevPass = diff.getRevSweepAsExpr(); if (m_ExternalSource) - m_ExternalSource->ActBeforeFinalisingPostIncDecOp(diff); + m_ExternalSource->ActBeforeFinalizingPostIncDecOp(diff); } else if (opCode == UO_PreInc || opCode == UO_PreDec) { diff = Visit(E, dfdx()); Expr* diff_dx = diff.getExpr_dx(); @@ -2583,7 +2583,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } else llvm_unreachable("unknown assignment opCode"); if (m_ExternalSource) - m_ExternalSource->ActBeforeFinalisingAssignOp(LCloned, oldValue); + m_ExternalSource->ActBeforeFinalizingAssignOp(LCloned, ResultRef, R, + opCode); // Output statements from Visit(L). for (auto it = Lblock_begin; it != Lblock_end; ++it) diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index eef3e2353..32ab9f161 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -8,10 +8,11 @@ #include "ConstantFolder.h" +#include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/ErrorEstimator.h" +#include "clad/Differentiator/Sins.h" #include "clad/Differentiator/StmtClone.h" -#include "clad/Differentiator/CladUtils.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Expr.h" @@ -59,42 +60,14 @@ namespace clad { return true; } - // A facility allowing us to access the private member CurScope of the Sema - // object using standard-conforming C++. 
- namespace { - template struct Rob { - friend typename Tag::type get(Tag) { return M; } - }; - - template struct TagBase { - using type = Member; -#ifdef MSVC -#pragma warning(push, 0) -#endif // MSVC -#pragma GCC diagnostic push -#ifdef __clang__ -#pragma clang diagnostic ignored "-Wunknown-warning-option" -#endif // __clang__ -#pragma GCC diagnostic ignored "-Wnon-template-friend" - friend type get(Tag); -#pragma GCC diagnostic pop -#ifdef MSVC -#pragma warning(pop) -#endif // MSVC - }; - - // Tag used to access Sema::CurScope. - using namespace clang; - struct Sema_CurScope : TagBase {}; - template struct Rob; - } // namespace + ALLOW_ACCESS(Sema, CurScope, Scope*); clang::Scope*& VisitorBase::getCurrentScope() { - return m_Sema.*get(Sema_CurScope()); + return ACCESS(m_Sema, CurScope); } void VisitorBase::setCurrentScope(clang::Scope* S) { - m_Sema.*get(Sema_CurScope()) = S; + getCurrentScope() = S; assert(getEnclosingNamespaceOrTUScope() && "Lost path to base."); } diff --git a/test/Analyses/TBR.cpp b/test/Analyses/TBR.cpp index 17be557b1..c87dfb93c 100644 --- a/test/Analyses/TBR.cpp +++ b/test/Analyses/TBR.cpp @@ -1,4 +1,4 @@ -// RUN: %cladclang -mllvm -debug-only=clad-tbr -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oReverseLoops.out 2>&1 | FileCheck %s +// RUN: %cladclang -mllvm -debug-only=clad-tbr %s -I%S/../../include -oReverseLoops.out 2>&1 | FileCheck %s // REQUIRES: asserts //CHECK-NOT: {{.*error|warning|note:.*}} @@ -13,7 +13,7 @@ double f1(double x) { #define TEST(F, x) { \ result[0] = 0; \ - auto F##grad = clad::gradient(F);\ + auto F##grad = clad::gradient(F);\ F##grad.execute(x, result);\ printf("{%.2f}\n", result[0]); \ } diff --git a/test/ErrorEstimation/Assignments.C b/test/ErrorEstimation/Assignments.C index 4817fe4c6..f809f32b4 100644 --- a/test/ErrorEstimation/Assignments.C +++ b/test/ErrorEstimation/Assignments.C @@ -14,13 +14,9 @@ float func(float x, float y) { //CHECK: void func_grad(float x, float y, 
clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { //CHECK-NEXT: float _t0; -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: float _EERepl_x0 = x; -//CHECK-NEXT: float _EERepl_x1; //CHECK-NEXT: float _t1; //CHECK-NEXT: _t0 = x; //CHECK-NEXT: x = x + y; -//CHECK-NEXT: _EERepl_x1 = x; //CHECK-NEXT: _t1 = y; //CHECK-NEXT: y = x; //CHECK-NEXT: goto _label0; @@ -33,17 +29,15 @@ float func(float x, float y) { //CHECK-NEXT: * _d_x += _r_d1; //CHECK-NEXT: } //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); //CHECK-NEXT: x = _t0; //CHECK-NEXT: float _r_d0 = * _d_x; //CHECK-NEXT: * _d_x -= _r_d0; //CHECK-NEXT: * _d_x += _r_d0; //CHECK-NEXT: * _d_y += _r_d0; -//CHECK-NEXT: _delta_x += std::abs(_r_d0 * _EERepl_x1 * {{.+}}); //CHECK-NEXT: } -//CHECK-NEXT: _delta_x += std::abs(* _d_x * _EERepl_x0 * {{.+}}); -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: _delta_y += std::abs(* _d_y * y * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{y|x}} + _delta_{{y|x}}; +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); //CHECK-NEXT: } float func2(float x, int y) { @@ -53,16 +47,13 @@ float func2(float x, int y) { //CHECK: void func2_grad(float x, int y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { //CHECK-NEXT: float _t0; -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: float _EERepl_x0 = x; -//CHECK-NEXT: float _EERepl_x1; //CHECK-NEXT: _t0 = x; //CHECK-NEXT: x = y * x + x * x; -//CHECK-NEXT: _EERepl_x1 = x; //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: * _d_x += 1; //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); //CHECK-NEXT: x = _t0; //CHECK-NEXT: float _r_d0 = * _d_x; //CHECK-NEXT: * _d_x -= _r_d0; @@ -70,10 +61,8 @@ float func2(float x, int y) { //CHECK-NEXT: * _d_x += y * _r_d0; //CHECK-NEXT: * _d_x += _r_d0 * x; //CHECK-NEXT: * _d_x += x * _r_d0; -//CHECK-NEXT: _delta_x += std::abs(_r_d0 
* _EERepl_x1 * {{.+}}); //CHECK-NEXT: } -//CHECK-NEXT: _delta_x += std::abs(* _d_x * _EERepl_x0 * {{.+}}); -//CHECK-NEXT: _final_error += _delta_x; +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); //CHECK-NEXT: } float func3(int x, int y) { @@ -104,33 +93,24 @@ float func4(float x, float y) { //CHECK: void func4_grad(float x, float y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { //CHECK-NEXT: double _d_z = 0; -//CHECK-NEXT: double _delta_z = 0; -//CHECK-NEXT: double _EERepl_z0; //CHECK-NEXT: float _t0; -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: float _EERepl_x0 = x; -//CHECK-NEXT: float _EERepl_x1; //CHECK-NEXT: double z = y; -//CHECK-NEXT: _EERepl_z0 = z; //CHECK-NEXT: _t0 = x; //CHECK-NEXT: x = z + y; -//CHECK-NEXT: _EERepl_x1 = x; //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: * _d_x += 1; //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); //CHECK-NEXT: x = _t0; //CHECK-NEXT: float _r_d0 = * _d_x; //CHECK-NEXT: * _d_x -= _r_d0; //CHECK-NEXT: _d_z += _r_d0; //CHECK-NEXT: * _d_y += _r_d0; -//CHECK-NEXT: _delta_x += std::abs(_r_d0 * _EERepl_x1 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: * _d_y += _d_z; -//CHECK-NEXT: _delta_x += std::abs(* _d_x * _EERepl_x0 * {{.+}}); -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: _delta_y += std::abs(* _d_y * y * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{x|y|z}} + _delta_{{x|y|z}} + _delta_{{x|y|z}}; +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); //CHECK-NEXT: } float func5(float x, float y) { @@ -142,28 +122,22 @@ float func5(float x, float y) { //CHECK: void func5_grad(float x, float y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { //CHECK-NEXT: int _d_z = 0; //CHECK-NEXT: float _t0; -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: float _EERepl_x0 = x; -//CHECK-NEXT: float _EERepl_x1; //CHECK-NEXT: int z = 56; //CHECK-NEXT: _t0 = x; 
//CHECK-NEXT: x = z + y; -//CHECK-NEXT: _EERepl_x1 = x; //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: * _d_x += 1; //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); //CHECK-NEXT: x = _t0; //CHECK-NEXT: float _r_d0 = * _d_x; //CHECK-NEXT: * _d_x -= _r_d0; //CHECK-NEXT: _d_z += _r_d0; //CHECK-NEXT: * _d_y += _r_d0; -//CHECK-NEXT: _delta_x += std::abs(_r_d0 * _EERepl_x1 * {{.+}}); //CHECK-NEXT: } -//CHECK-NEXT: _delta_x += std::abs(* _d_x * _EERepl_x0 * {{.+}}); -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: _delta_y += std::abs(* _d_y * y * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{x|y}} + _delta_{{x|y}}; +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); //CHECK-NEXT: } float func6(float x) { return x; } @@ -172,9 +146,7 @@ float func6(float x) { return x; } //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: * _d_x += 1; -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}}); -//CHECK-NEXT: _final_error += _delta_x; +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); //CHECK-NEXT: } float func7(float x, float y) { return (x * y); } @@ -188,11 +160,30 @@ float func7(float x, float y) { return (x * y); } //CHECK-NEXT: * _d_x += 1 * y; //CHECK-NEXT: * _d_y += x * 1; //CHECK-NEXT: } -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}}); -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: _delta_y += std::abs(* _d_y * y * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{x|y}} + _delta_{{x|y}} + std::abs(1. * _ret_value0 * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(1. 
* _ret_value0 * {{.+}}); +//CHECK-NEXT: } + +float func8(int x, int y) { + x = y * y; + return x; +} + +//CHECK: void func8_grad(int x, int y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { +//CHECK-NEXT: int _t0; +//CHECK-NEXT: _t0 = x; +//CHECK-NEXT: x = y * y; +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: * _d_x += 1; +//CHECK-NEXT: { +//CHECK-NEXT: x = _t0; +//CHECK-NEXT: int _r_d0 = * _d_x; +//CHECK-NEXT: * _d_x -= _r_d0; +//CHECK-NEXT: * _d_y += _r_d0 * y; +//CHECK-NEXT: * _d_y += y * _r_d0; +//CHECK-NEXT: } //CHECK-NEXT: } int main() { @@ -203,4 +194,5 @@ int main() { clad::estimate_error(func5); clad::estimate_error(func6); clad::estimate_error(func7); + clad::estimate_error(func8); } diff --git a/test/ErrorEstimation/BasicOps.C b/test/ErrorEstimation/BasicOps.C index 104ca9721..75c4e1f96 100644 --- a/test/ErrorEstimation/BasicOps.C +++ b/test/ErrorEstimation/BasicOps.C @@ -15,56 +15,42 @@ float func(float x, float y) { //CHECK: void func_grad(float x, float y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { //CHECK-NEXT: float _t0; -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: float _EERepl_x0 = x; -//CHECK-NEXT: float _EERepl_x1; //CHECK-NEXT: float _t1; -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: float _EERepl_y0 = y; -//CHECK-NEXT: float _EERepl_y1; -//CHECK-NEXT: float _EERepl_y2; //CHECK-NEXT: float _d_z = 0; -//CHECK-NEXT: double _delta_z = 0; -//CHECK-NEXT: float _EERepl_z0; //CHECK-NEXT: _t0 = x; //CHECK-NEXT: x = x + y; -//CHECK-NEXT: _EERepl_x1 = x; //CHECK-NEXT: _t1 = y; -//CHECK-NEXT: _EERepl_y1 = y; //CHECK-NEXT: y = y + y++ + y; -//CHECK-NEXT: _EERepl_y2 = y; //CHECK-NEXT: float z = y * x; -//CHECK-NEXT: _EERepl_z0 = z; //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: _d_z += 1; //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_z * z * {{.+}}); //CHECK-NEXT: * _d_y += _d_z * x; //CHECK-NEXT: * _d_x += y * _d_z; -//CHECK-NEXT: _delta_z += 
std::abs(_d_z * _EERepl_z0 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); //CHECK-NEXT: y = _t1; //CHECK-NEXT: float _r_d1 = * _d_y; //CHECK-NEXT: * _d_y -= _r_d1; //CHECK-NEXT: * _d_y += _r_d1; //CHECK-NEXT: * _d_y += _r_d1; //CHECK-NEXT: y--; -//CHECK-NEXT: _delta_y += std::abs(* _d_y * _EERepl_y1 * {{.+}}); //CHECK-NEXT: * _d_y += _r_d1; -//CHECK-NEXT: _delta_y += std::abs(_r_d1 * _EERepl_y2 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); //CHECK-NEXT: x = _t0; //CHECK-NEXT: float _r_d0 = * _d_x; //CHECK-NEXT: * _d_x -= _r_d0; //CHECK-NEXT: * _d_x += _r_d0; //CHECK-NEXT: * _d_y += _r_d0; -//CHECK-NEXT: _delta_x += std::abs(_r_d0 * _EERepl_x1 * {{.+}}); //CHECK-NEXT: } -//CHECK-NEXT: _delta_x += std::abs(* _d_x * _EERepl_x0 * {{.+}}); -//CHECK-NEXT: _delta_y += std::abs(* _d_y * _EERepl_y0 * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{x|y|z}} + _delta_{{x|y|z}} + _delta_{{x|y|z}}; +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); //CHECK-NEXT: } // This function may evaluate incorrectly due to absence of usage of @@ -77,27 +63,21 @@ float func2(float x, float y) { //CHECK: void func2_grad(float x, float y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { //CHECK-NEXT: float _t0; -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: float _EERepl_x0 = x; -//CHECK-NEXT: float _EERepl_x1; //CHECK-NEXT: float _d_z = 0; -//CHECK-NEXT: double _delta_z = 0; -//CHECK-NEXT: float _EERepl_z0; //CHECK-NEXT: _t0 = x; //CHECK-NEXT: x = x - y - y * y; -//CHECK-NEXT: _EERepl_x1 = x; //CHECK-NEXT: float z = y / x; -//CHECK-NEXT: _EERepl_z0 = z; //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: _d_z += 1; //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_z * z * {{.+}}); //CHECK-NEXT: * 
_d_y += _d_z / x; //CHECK-NEXT: float _r0 = _d_z * -y / (x * x); //CHECK-NEXT: * _d_x += _r0; -//CHECK-NEXT: _delta_z += std::abs(_d_z * _EERepl_z0 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); //CHECK-NEXT: x = _t0; //CHECK-NEXT: float _r_d0 = * _d_x; //CHECK-NEXT: * _d_x -= _r_d0; @@ -105,12 +85,9 @@ float func2(float x, float y) { //CHECK-NEXT: * _d_y += -_r_d0; //CHECK-NEXT: * _d_y += -_r_d0 * y; //CHECK-NEXT: * _d_y += y * -_r_d0; -//CHECK-NEXT: _delta_x += std::abs(_r_d0 * _EERepl_x1 * {{.+}}); //CHECK-NEXT: } -//CHECK-NEXT: _delta_x += std::abs(* _d_x * _EERepl_x0 * {{.+}}); -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: _delta_y += std::abs(* _d_y * y * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{x|y|z}} + _delta_{{x|y|z}} + _delta_{{x|y|z}}; +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); //CHECK-NEXT: } @@ -124,34 +101,22 @@ float func3(float x, float y) { //CHECK: void func3_grad(float x, float y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { //CHECK-NEXT: float _t0; -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: float _EERepl_x0 = x; -//CHECK-NEXT: float _EERepl_x1; //CHECK-NEXT: float _d_z = 0; -//CHECK-NEXT: double _delta_z = 0; -//CHECK-NEXT: float _EERepl_z0; //CHECK-NEXT: float _t1; //CHECK-NEXT: float _t2; -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: float _EERepl_y0 = y; -//CHECK-NEXT: float _EERepl_y1; //CHECK-NEXT: float _d_t = 0; -//CHECK-NEXT: double _delta_t = 0; -//CHECK-NEXT: float _EERepl_t0; //CHECK-NEXT: _t0 = x; //CHECK-NEXT: x = x - y - y * y; -//CHECK-NEXT: _EERepl_x1 = x; //CHECK-NEXT: float z = y; -//CHECK-NEXT: _EERepl_z0 = z; //CHECK-NEXT: _t2 = y; //CHECK-NEXT: _t1 = (y = x + x); //CHECK-NEXT: float t = x * z * _t1; -//CHECK-NEXT: _EERepl_t0 = t; -//CHECK-NEXT: _EERepl_y1 = y; //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: _d_t += 1; 
//CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_t * t * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); //CHECK-NEXT: * _d_x += _d_t * _t1 * z; //CHECK-NEXT: _d_z += x * _d_t * _t1; //CHECK-NEXT: * _d_y += x * z * _d_t; @@ -160,11 +125,10 @@ float func3(float x, float y) { //CHECK-NEXT: * _d_y -= _r_d1; //CHECK-NEXT: * _d_x += _r_d1; //CHECK-NEXT: * _d_x += _r_d1; -//CHECK-NEXT: _delta_y += std::abs(_r_d1 * _EERepl_y1 * {{.+}}); -//CHECK-NEXT: _delta_t += std::abs(_d_t * _EERepl_t0 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: * _d_y += _d_z; //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); //CHECK-NEXT: x = _t0; //CHECK-NEXT: float _r_d0 = * _d_x; //CHECK-NEXT: * _d_x -= _r_d0; @@ -172,11 +136,9 @@ float func3(float x, float y) { //CHECK-NEXT: * _d_y += -_r_d0; //CHECK-NEXT: * _d_y += -_r_d0 * y; //CHECK-NEXT: * _d_y += y * -_r_d0; -//CHECK-NEXT: _delta_x += std::abs(_r_d0 * _EERepl_x1 * {{.+}}); //CHECK-NEXT: } -//CHECK-NEXT: _delta_x += std::abs(* _d_x * _EERepl_x0 * {{.+}}); -//CHECK-NEXT: _delta_y += std::abs(* _d_y * _EERepl_y0 * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{t|x|y|z}} + _delta_{{t|x|y|z}} + _delta_{{t|x|y|z}} + _delta_{{t|x|y|z}}; +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); //CHECK-NEXT: } // Function call custom derivative exists but no assign expr @@ -196,11 +158,9 @@ float func4(float x, float y) { return std::pow(x, y); } //CHECK-NEXT: float _r1 = _grad1; //CHECK-NEXT: * _d_y += _r1; //CHECK-NEXT: } -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}}); -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: _delta_y += std::abs(* _d_y * y * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{x|y}} + _delta_{{x|y}} + std::abs(1. 
* _ret_value0 * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(1. * _ret_value0 * {{.+}}); //CHECK-NEXT: } // Function call custom derivative exists and is assigned @@ -211,13 +171,9 @@ float func5(float x, float y) { //CHECK: void func5_grad(float x, float y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { //CHECK-NEXT: float _t0; -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: float _EERepl_y0 = y; -//CHECK-NEXT: float _EERepl_y1; //CHECK-NEXT: double _ret_value0 = 0; //CHECK-NEXT: _t0 = y; //CHECK-NEXT: y = std::sin(x); -//CHECK-NEXT: _EERepl_y1 = y; //CHECK-NEXT: _ret_value0 = y * y; //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: @@ -226,17 +182,16 @@ float func5(float x, float y) { //CHECK-NEXT: * _d_y += y * 1; //CHECK-NEXT: } //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); //CHECK-NEXT: y = _t0; //CHECK-NEXT: float _r_d0 = * _d_y; //CHECK-NEXT: * _d_y -= _r_d0; //CHECK-NEXT: float _r0 = _r_d0 * clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, 1.F).pushforward; //CHECK-NEXT: * _d_x += _r0; -//CHECK-NEXT: _delta_y += std::abs(_r_d0 * _EERepl_y1 * {{.+}}); //CHECK-NEXT: } -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}}); -//CHECK-NEXT: _delta_y += std::abs(* _d_y * _EERepl_y0 * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{x|y}} + _delta_{{x|y}} + std::abs(1. * _ret_value0 * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(1. 
* _ret_value0 * {{.+}}); //CHECK-NEXT: } // Function call non custom derivative @@ -251,11 +206,9 @@ double helper(double x, double y) { return x * y; } //CHECK-NEXT: * _d_x += _d_y0 * y; //CHECK-NEXT: * _d_y += x * _d_y0; //CHECK-NEXT: } -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}}); -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: _delta_y += std::abs(* _d_y * y * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{y|x}} + _delta_{{y|x}} + std::abs(1. * _ret_value0 * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(1. * _ret_value0 * {{.+}}); //CHECK-NEXT: } float func6(float x, float y) { @@ -265,11 +218,8 @@ float func6(float x, float y) { //CHECK: void func6_grad(float x, float y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { //CHECK-NEXT: float _d_z = 0; -//CHECK-NEXT: double _delta_z = 0; -//CHECK-NEXT: float _EERepl_z0; //CHECK-NEXT: double _ret_value0 = 0; //CHECK-NEXT: float z = helper(x, y); -//CHECK-NEXT: _EERepl_z0 = z; //CHECK-NEXT: _ret_value0 = z * z; //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: @@ -286,13 +236,11 @@ float func6(float x, float y) { //CHECK-NEXT: * _d_x += _r0; //CHECK-NEXT: double _r1 = _grad1; //CHECK-NEXT: * _d_y += _r1; -//CHECK-NEXT: _delta_z += _t0; +//CHECK-NEXT: _final_error += _t0; //CHECK-NEXT: } -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}}); -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: _delta_y += std::abs(* _d_y * y * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{y|x|z}} + _delta_{{y|x|z}} + _delta_{{y|x|z}} + std::abs(1. * _ret_value0 * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(1. 
* _ret_value0 * {{.+}}); //CHECK-NEXT: } float func7(float x) { @@ -310,9 +258,7 @@ float func7(float x) { //CHECK-NEXT: _d_z += 1; //CHECK-NEXT: } //CHECK-NEXT: * _d_x += _d_z; -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}}); -//CHECK-NEXT: _final_error += _delta_x; +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); //CHECK-NEXT: } @@ -338,17 +284,12 @@ float func8(float x, float y) { //CHECK: void func8_grad(float x, float y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { //CHECK-NEXT: float _d_z = 0; -//CHECK-NEXT: double _delta_z = 0; -//CHECK-NEXT: float _EERepl_z0; //CHECK-NEXT: float _t0; //CHECK-NEXT: float _t1; -//CHECK-NEXT: float _EERepl_z1; //CHECK-NEXT: float z; -//CHECK-NEXT: _EERepl_z0 = z; //CHECK-NEXT: _t0 = z; //CHECK-NEXT: _t1 = x; //CHECK-NEXT: z = y + helper2(x); -//CHECK-NEXT: _EERepl_z1 = z; //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: _d_z += 1; @@ -361,14 +302,11 @@ float func8(float x, float y) { //CHECK-NEXT: double _t2 = 0; //CHECK-NEXT: helper2_pullback(_t1, _r_d0, &* _d_x, _t2); //CHECK-NEXT: float _r0 = * _d_x; -//CHECK-NEXT: _delta_z += _t2; +//CHECK-NEXT: _final_error += _t2; //CHECK-NEXT: _final_error += std::abs(_r0 * _t1 * {{.+}}); //CHECK-NEXT: } -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}}); -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: _delta_y += std::abs(* _d_y * y * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{x|y|z}} + _delta_{{x|y|z}} + _delta_{{x|y|z}}; +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); //CHECK-NEXT: } float func9(float x, float y) { @@ -380,24 +318,19 @@ float func9(float x, float y) { //CHECK: void func9_grad(float x, float y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { //CHECK-NEXT: float _t1; //CHECK-NEXT: float _d_z = 0; -//CHECK-NEXT: double _delta_z = 
0; -//CHECK-NEXT: float _EERepl_z0; //CHECK-NEXT: float _t3; //CHECK-NEXT: double _t4; //CHECK-NEXT: float _t5; //CHECK-NEXT: double _t7; //CHECK-NEXT: float _t8; -//CHECK-NEXT: float _EERepl_z1; //CHECK-NEXT: _t1 = x; //CHECK-NEXT: float z = helper(x, y) + helper2(x); -//CHECK-NEXT: _EERepl_z0 = z; //CHECK-NEXT: _t3 = z; //CHECK-NEXT: _t5 = x; //CHECK-NEXT: _t7 = helper2(x); //CHECK-NEXT: _t8 = y; //CHECK-NEXT: _t4 = helper2(y); //CHECK-NEXT: z += _t7 * _t4; -//CHECK-NEXT: _EERepl_z1 = z; //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: _d_z += 1; @@ -412,7 +345,7 @@ float func9(float x, float y) { //CHECK-NEXT: double _t9 = 0; //CHECK-NEXT: helper2_pullback(_t8, _t7 * _r_d0, &* _d_y, _t9); //CHECK-NEXT: float _r4 = * _d_y; -//CHECK-NEXT: _delta_z += _t6 + _t9; +//CHECK-NEXT: _final_error += _t6 + _t9; //CHECK-NEXT: _final_error += std::abs(_r4 * _t8 * {{.+}}); //CHECK-NEXT: _final_error += std::abs(_r3 * _t5 * {{.+}}); //CHECK-NEXT: } @@ -429,14 +362,11 @@ float func9(float x, float y) { //CHECK-NEXT: double _t2 = 0; //CHECK-NEXT: helper2_pullback(_t1, _d_z, &* _d_x, _t2); //CHECK-NEXT: float _r2 = * _d_x; -//CHECK-NEXT: _delta_z += _t0 + _t2; +//CHECK-NEXT: _final_error += _t0 + _t2; //CHECK-NEXT: _final_error += std::abs(_r2 * _t1 * {{.+}}); //CHECK-NEXT: } -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}}); -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: _delta_y += std::abs(* _d_y * y * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{x|y|z}} + _delta_{{x|y|z}} + _delta_{{x|y|z}}; +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); //CHECK-NEXT: } int main() { diff --git a/test/ErrorEstimation/ConditonalStatements.C b/test/ErrorEstimation/ConditonalStatements.C index f9454134b..aa01c61ea 100644 --- a/test/ErrorEstimation/ConditonalStatements.C +++ b/test/ErrorEstimation/ConditonalStatements.C @@ -20,28 +20,19 @@ float func(float 
x, float y) { //CHECK: void func_grad(float x, float y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { //CHECK-NEXT: bool _cond0; //CHECK-NEXT: float _t0; -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: float _EERepl_y0 = y; -//CHECK-NEXT: float _EERepl_y1; //CHECK-NEXT: float _d_temp = 0; -//CHECK-NEXT: double _delta_temp = 0; -//CHECK-NEXT: float _EERepl_temp0; //CHECK-NEXT: float temp = 0; //CHECK-NEXT: float _t1; -//CHECK-NEXT: float _EERepl_temp1; //CHECK-NEXT: float _t2; //CHECK-NEXT: double _ret_value0 = 0; //CHECK-NEXT: _cond0 = x > y; //CHECK-NEXT: if (_cond0) { //CHECK-NEXT: _t0 = y; //CHECK-NEXT: y = y * x; -//CHECK-NEXT: _EERepl_y1 = y; //CHECK-NEXT: } else { //CHECK-NEXT: temp = y; -//CHECK-NEXT: _EERepl_temp0 = temp; //CHECK-NEXT: _t1 = temp; //CHECK-NEXT: temp = y * y; -//CHECK-NEXT: _EERepl_temp1 = temp; //CHECK-NEXT: _t2 = x; //CHECK-NEXT: x = y; //CHECK-NEXT: } @@ -54,12 +45,12 @@ float func(float x, float y) { //CHECK-NEXT: } //CHECK-NEXT: if (_cond0) { //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); //CHECK-NEXT: y = _t0; //CHECK-NEXT: float _r_d0 = * _d_y; //CHECK-NEXT: * _d_y -= _r_d0; //CHECK-NEXT: * _d_y += _r_d0 * x; //CHECK-NEXT: * _d_x += y * _r_d0; -//CHECK-NEXT: _delta_y += std::abs(_r_d0 * _EERepl_y1 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: } else { //CHECK-NEXT: { @@ -69,22 +60,21 @@ float func(float x, float y) { //CHECK-NEXT: * _d_y += _r_d2; //CHECK-NEXT: } //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_temp * temp * {{.+}}); //CHECK-NEXT: temp = _t1; //CHECK-NEXT: float _r_d1 = _d_temp; //CHECK-NEXT: _d_temp -= _r_d1; //CHECK-NEXT: * _d_y += _r_d1 * y; //CHECK-NEXT: * _d_y += y * _r_d1; -//CHECK-NEXT: _delta_temp += std::abs(_r_d1 * _EERepl_temp1 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_temp * temp * {{.+}}); //CHECK-NEXT: * _d_y += _d_temp; -//CHECK-NEXT: _delta_temp += std::abs(_d_temp * _EERepl_temp0 * {{.+}}); 
//CHECK-NEXT: } //CHECK-NEXT: } -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}}); -//CHECK-NEXT: _delta_y += std::abs(* _d_y * _EERepl_y0 * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{x|y|temp}} + _delta_{{x|y|temp}} + _delta_{{x|y|temp}} + std::abs(1. * _ret_value0 * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(1. * _ret_value0 * {{.+}}); //CHECK-NEXT: } // Single return statement if/else @@ -98,12 +88,9 @@ float func2(float x) { //CHECK: void func2_grad(float x, clad::array_ref _d_x, double &_final_error) { //CHECK-NEXT: float _d_z = 0; -//CHECK-NEXT: double _delta_z = 0; -//CHECK-NEXT: float _EERepl_z0; //CHECK-NEXT: bool _cond0; //CHECK-NEXT: double _ret_value0 = 0; //CHECK-NEXT: float z = x * x; -//CHECK-NEXT: _EERepl_z0 = z; //CHECK-NEXT: _cond0 = z > 9; //CHECK-NEXT: if (_cond0) { //CHECK-NEXT: _ret_value0 = x + x; @@ -125,13 +112,12 @@ float func2(float x) { //CHECK-NEXT: * _d_x += x * 1; //CHECK-NEXT: } //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_z * z * {{.+}}); //CHECK-NEXT: * _d_x += _d_z * x; //CHECK-NEXT: * _d_x += x * _d_z; -//CHECK-NEXT: _delta_z += std::abs(_d_z * _EERepl_z0 * {{.+}}); //CHECK-NEXT: } -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{x|z}} + _delta_{{x|z}} + std::abs(1. * _ret_value0 * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(1. * _ret_value0 * {{.+}}); //CHECK-NEXT: } float func3(float x, float y) { return x > 30 ? x * y : x + y; } @@ -150,11 +136,9 @@ float func3(float x, float y) { return x > 30 ? 
x * y : x + y; } //CHECK-NEXT: * _d_x += 1; //CHECK-NEXT: * _d_y += 1; //CHECK-NEXT: } -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}}); -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: _delta_y += std::abs(* _d_y * y * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{x|y}} + _delta_{{x|y}} + std::abs(1. * _ret_value0 * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(1. * _ret_value0 * {{.+}}); //CHECK-NEXT: } float func4(float x, float y) { @@ -165,11 +149,7 @@ float func4(float x, float y) { //CHECK: void func4_grad(float x, float y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { //CHECK-NEXT: bool _cond0; //CHECK-NEXT: float _t0; -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: float _EERepl_x0 = x; -//CHECK-NEXT: float _EERepl_x1; //CHECK-NEXT: float _t1; -//CHECK-NEXT: float _EERepl_x2; //CHECK-NEXT: double _ret_value0 = 0; //CHECK-NEXT: _cond0 = !x; //CHECK-NEXT: if (_cond0) @@ -177,8 +157,6 @@ float func4(float x, float y) { //CHECK-NEXT: else //CHECK-NEXT: _t1 = x; //CHECK-NEXT: _cond0 ? 
(x += 1) : (x *= x); -//CHECK-NEXT: _EERepl_x2 = x; -//CHECK-NEXT: _EERepl_x1 = x; //CHECK-NEXT: _ret_value0 = y / x; //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: @@ -188,21 +166,20 @@ float func4(float x, float y) { //CHECK-NEXT: * _d_x += _r0; //CHECK-NEXT: } //CHECK-NEXT: if (_cond0) { +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); //CHECK-NEXT: x = _t0; //CHECK-NEXT: float _r_d0 = * _d_x; -//CHECK-NEXT: _delta_x += std::abs(_r_d0 * _EERepl_x1 * {{.+}}); //CHECK-NEXT: } else { +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); //CHECK-NEXT: x = _t1; //CHECK-NEXT: float _r_d1 = * _d_x; //CHECK-NEXT: * _d_x -= _r_d1; //CHECK-NEXT: * _d_x += _r_d1 * x; //CHECK-NEXT: * _d_x += x * _r_d1; -//CHECK-NEXT: _delta_x += std::abs(_r_d1 * _EERepl_x2 * {{.+}}); //CHECK-NEXT: } -//CHECK-NEXT: _delta_x += std::abs(* _d_x * _EERepl_x0 * {{.+}}); -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: _delta_y += std::abs(* _d_y * y * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{x|y}} + _delta_{{x|y}} + std::abs(1. * _ret_value0 * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(1. 
* _ret_value0 * {{.+}}); //CHECK-NEXT: } int main() { diff --git a/test/ErrorEstimation/LoopsAndArrays.C b/test/ErrorEstimation/LoopsAndArrays.C index 253c7a68f..1bce92a9a 100644 --- a/test/ErrorEstimation/LoopsAndArrays.C +++ b/test/ErrorEstimation/LoopsAndArrays.C @@ -17,21 +17,16 @@ float func(float* p, int n) { //CHECK: void func_grad(float *p, int n, clad::array_ref _d_p, clad::array_ref _d_n, double &_final_error) { //CHECK-NEXT: float _d_sum = 0; -//CHECK-NEXT: double _delta_sum = 0; -//CHECK-NEXT: float _EERepl_sum0; //CHECK-NEXT: unsigned long _t0; //CHECK-NEXT: int _d_i = 0; //CHECK-NEXT: int i = 0; //CHECK-NEXT: clad::tape _t1 = {}; -//CHECK-NEXT: clad::tape _EERepl_sum1 = {}; //CHECK-NEXT: float sum = 0; -//CHECK-NEXT: _EERepl_sum0 = sum; //CHECK-NEXT: _t0 = 0; //CHECK-NEXT: for (i = 0; i < n; i++) { //CHECK-NEXT: _t0++; //CHECK-NEXT: clad::push(_t1, sum); //CHECK-NEXT: sum += p[i]; -//CHECK-NEXT: clad::push(_EERepl_sum1, sum); //CHECK-NEXT: } //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: @@ -39,22 +34,16 @@ float func(float* p, int n) { //CHECK-NEXT: for (; _t0; _t0--) { //CHECK-NEXT: i--; //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_sum * sum * {{.+}}); //CHECK-NEXT: sum = clad::pop(_t1); //CHECK-NEXT: float _r_d0 = _d_sum; //CHECK-NEXT: _d_p[i] += _r_d0; -//CHECK-NEXT: float _r0 = clad::pop(_EERepl_sum1); -//CHECK-NEXT: _delta_sum += std::abs(_r_d0 * _r0 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: } -//CHECK-NEXT: _delta_sum += std::abs(_d_sum * _EERepl_sum0 * {{.+}}); -//CHECK-NEXT: clad::array _delta_p(_d_p.size()); +//CHECK-NEXT: _final_error += std::abs(_d_sum * sum * {{.+}}); //CHECK-NEXT: int i0 = 0; -//CHECK-NEXT: for (; i0 < _d_p.size(); i0++) { -//CHECK-NEXT: double _t2 = std::abs(_d_p[i0] * p[i0] * {{.+}}); -//CHECK-NEXT: _delta_p[i0] += _t2; -//CHECK-NEXT: _final_error += _t2; -//CHECK-NEXT: } -//CHECK-NEXT: _final_error += _delta_sum; +//CHECK-NEXT: for (; i0 < _d_p.size(); i0++) +//CHECK-NEXT: _final_error += 
std::abs(_d_p[i0] * p[i0] * {{.+}}); //CHECK-NEXT: } @@ -69,28 +58,20 @@ float func2(float x) { //CHECK: void func2_grad(float x, clad::array_ref _d_x, double &_final_error) { //CHECK-NEXT: float _d_z = 0; -//CHECK-NEXT: double _delta_z = 0; -//CHECK-NEXT: float _EERepl_z0; //CHECK-NEXT: unsigned long _t0; //CHECK-NEXT: int _d_i = 0; //CHECK-NEXT: int i = 0; //CHECK-NEXT: clad::tape _t1 = {}; //CHECK-NEXT: float _d_m = 0; -//CHECK-NEXT: double _delta_m = 0; -//CHECK-NEXT: clad::tape _EERepl_m0 = {}; //CHECK-NEXT: float m = 0; //CHECK-NEXT: clad::tape _t2 = {}; -//CHECK-NEXT: clad::tape _EERepl_z1 = {}; //CHECK-NEXT: float z; -//CHECK-NEXT: _EERepl_z0 = z; //CHECK-NEXT: _t0 = 0; //CHECK-NEXT: for (i = 0; i < 9; i++) { //CHECK-NEXT: _t0++; //CHECK-NEXT: clad::push(_t1, m) , m = x * x; -//CHECK-NEXT: clad::push(_EERepl_m0, m); //CHECK-NEXT: clad::push(_t2, z); //CHECK-NEXT: z = m + m; -//CHECK-NEXT: clad::push(_EERepl_z1, z); //CHECK-NEXT: } //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: @@ -98,26 +79,22 @@ float func2(float x) { //CHECK-NEXT: for (; _t0; _t0--) { //CHECK-NEXT: i--; //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_z * z * {{.+}}); //CHECK-NEXT: z = clad::pop(_t2); //CHECK-NEXT: float _r_d0 = _d_z; //CHECK-NEXT: _d_z -= _r_d0; //CHECK-NEXT: _d_m += _r_d0; //CHECK-NEXT: _d_m += _r_d0; -//CHECK-NEXT: float _r1 = clad::pop(_EERepl_z1); -//CHECK-NEXT: _delta_z += std::abs(_r_d0 * _r1 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_m * m * {{.+}}); //CHECK-NEXT: * _d_x += _d_m * x; //CHECK-NEXT: * _d_x += x * _d_m; //CHECK-NEXT: _d_m = 0; //CHECK-NEXT: m = clad::pop(_t1); -//CHECK-NEXT: float _r0 = clad::pop(_EERepl_m0); -//CHECK-NEXT: _delta_m += std::abs(_d_m * _r0 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: } -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{x|z|m}} + _delta_{{x|z|m}} + _delta_{{x|z|m}}; +//CHECK-NEXT: 
_final_error += std::abs(* _d_x * x * {{.+}}); //CHECK-NEXT: } float func3(float x, float y) { @@ -130,58 +107,45 @@ float func3(float x, float y) { //CHECK: void func3_grad(float x, float y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { //CHECK-NEXT: clad::array _d_arr(3UL); -//CHECK-NEXT: clad::array _delta_arr(_d_arr.size()); //CHECK-NEXT: double _t0; -//CHECK-NEXT: double _EERepl_arr0; //CHECK-NEXT: double _t1; -//CHECK-NEXT: double _EERepl_arr1; //CHECK-NEXT: double _t2; -//CHECK-NEXT: double _EERepl_arr2; //CHECK-NEXT: double arr[3]; //CHECK-NEXT: _t0 = arr[0]; //CHECK-NEXT: arr[0] = x + y; -//CHECK-NEXT: _EERepl_arr0 = arr[0]; //CHECK-NEXT: _t1 = arr[1]; //CHECK-NEXT: arr[1] = x * x; -//CHECK-NEXT: _EERepl_arr1 = arr[1]; //CHECK-NEXT: _t2 = arr[2]; //CHECK-NEXT: arr[2] = arr[0] + arr[1]; -//CHECK-NEXT: _EERepl_arr2 = arr[2]; //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: _d_arr[2] += 1; //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_arr[2] * arr[2] * {{.+}}); //CHECK-NEXT: arr[2] = _t2; //CHECK-NEXT: double _r_d2 = _d_arr[2]; //CHECK-NEXT: _d_arr[2] -= _r_d2; //CHECK-NEXT: _d_arr[0] += _r_d2; //CHECK-NEXT: _d_arr[1] += _r_d2; -//CHECK-NEXT: _delta_arr[2] += std::abs(_r_d2 * _EERepl_arr2 * {{.+}}); -//CHECK-NEXT: _final_error += _delta_arr[2]; //CHECK-NEXT: } //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_arr[1] * arr[1] * {{.+}}); //CHECK-NEXT: arr[1] = _t1; //CHECK-NEXT: double _r_d1 = _d_arr[1]; //CHECK-NEXT: _d_arr[1] -= _r_d1; //CHECK-NEXT: * _d_x += _r_d1 * x; //CHECK-NEXT: * _d_x += x * _r_d1; -//CHECK-NEXT: _delta_arr[1] += std::abs(_r_d1 * _EERepl_arr1 * {{.+}}); -//CHECK-NEXT: _final_error += _delta_arr[1]; //CHECK-NEXT: } //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_arr[0] * arr[0] * {{.+}}); //CHECK-NEXT: arr[0] = _t0; //CHECK-NEXT: double _r_d0 = _d_arr[0]; //CHECK-NEXT: _d_arr[0] -= _r_d0; //CHECK-NEXT: * _d_x += _r_d0; //CHECK-NEXT: * _d_y += _r_d0; -//CHECK-NEXT: 
_delta_arr[0] += std::abs(_r_d0 * _EERepl_arr0 * {{.+}}); -//CHECK-NEXT: _final_error += _delta_arr[0]; //CHECK-NEXT: } -//CHECK-NEXT: double _delta_x = 0; -//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}}); -//CHECK-NEXT: double _delta_y = 0; -//CHECK-NEXT: _delta_y += std::abs(* _d_y * y * {{.+}}); -//CHECK-NEXT: _final_error += _delta_{{y|x}} + _delta_{{y|x}}; +//CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); //CHECK-NEXT: } float func4(float x[10], float y[10]) { @@ -195,31 +159,19 @@ float func4(float x[10], float y[10]) { //CHECK: void func4_grad(float x[10], float y[10], clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { //CHECK-NEXT: float _d_sum = 0; -//CHECK-NEXT: double _delta_sum = 0; -//CHECK-NEXT: float _EERepl_sum0; //CHECK-NEXT: unsigned long _t0; //CHECK-NEXT: int _d_i = 0; //CHECK-NEXT: int i = 0; //CHECK-NEXT: clad::tape _t1 = {}; -//CHECK-NEXT: clad::array _delta_x(_d_x.size()); -//CHECK-NEXT: clad::array _EERepl_x0(_d_x.size()); -//CHECK-NEXT: for (int i0 = 0; i0 < _d_x.size(); i0++) { -//CHECK-NEXT: _EERepl_x0[i0] = x[i0]; -//CHECK-NEXT: } -//CHECK-NEXT: clad::tape _EERepl_x1 = {}; //CHECK-NEXT: clad::tape _t2 = {}; -//CHECK-NEXT: clad::tape _EERepl_sum1 = {}; //CHECK-NEXT: float sum = 0; -//CHECK-NEXT: _EERepl_sum0 = sum; //CHECK-NEXT: _t0 = 0; //CHECK-NEXT: for (i = 0; i < 10; i++) { //CHECK-NEXT: _t0++; //CHECK-NEXT: clad::push(_t1, x[i]); //CHECK-NEXT: x[i] += y[i]; -//CHECK-NEXT: clad::push(_EERepl_x1, x[i]); //CHECK-NEXT: clad::push(_t2, sum); //CHECK-NEXT: sum += x[i]; -//CHECK-NEXT: clad::push(_EERepl_sum1, sum); //CHECK-NEXT: } //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: @@ -227,36 +179,25 @@ float func4(float x[10], float y[10]) { //CHECK-NEXT: for (; _t0; _t0--) { //CHECK-NEXT: i--; //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_sum * sum * {{.+}}); //CHECK-NEXT: sum = clad::pop(_t2); //CHECK-NEXT: float _r_d1 = 
_d_sum; //CHECK-NEXT: _d_x[i] += _r_d1; -//CHECK-NEXT: float _r1 = clad::pop(_EERepl_sum1); -//CHECK-NEXT: _delta_sum += std::abs(_r_d1 * _r1 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_x[i] * x[i] * {{.+}}); //CHECK-NEXT: x[i] = clad::pop(_t1); //CHECK-NEXT: float _r_d0 = _d_x[i]; //CHECK-NEXT: _d_y[i] += _r_d0; -//CHECK-NEXT: float _r0 = clad::pop(_EERepl_x1); -//CHECK-NEXT: _delta_x[i] += std::abs(_r_d0 * _r0 * {{.+}}); -//CHECK-NEXT: _final_error += _delta_x[i]; //CHECK-NEXT: } //CHECK-NEXT: } -//CHECK-NEXT: _delta_sum += std::abs(_d_sum * _EERepl_sum0 * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(_d_sum * sum * {{.+}}); //CHECK-NEXT: int i0 = 0; -//CHECK-NEXT: for (; i0 < _d_x.size(); i0++) { -//CHECK-NEXT: double _t3 = std::abs(_d_x[i0] * _EERepl_x0[i0] * {{.+}}); -//CHECK-NEXT: _delta_x[i0] += _t3; -//CHECK-NEXT: _final_error += _t3; -//CHECK-NEXT: } -//CHECK-NEXT: clad::array _delta_y(_d_y.size()); +//CHECK-NEXT: for (; i0 < _d_x.size(); i0++) +//CHECK-NEXT: _final_error += std::abs(_d_x[i0] * x[i0] * {{.+}}); //CHECK-NEXT: i0 = 0; -//CHECK-NEXT: for (; i0 < _d_y.size(); i0++) { -//CHECK-NEXT: double _t4 = std::abs(_d_y[i0] * y[i0] * {{.+}}); -//CHECK-NEXT: _delta_y[i0] += _t4; -//CHECK-NEXT: _final_error += _t4; -//CHECK-NEXT: } -//CHECK-NEXT: _final_error += _delta_sum; +//CHECK-NEXT: for (; i0 < _d_y.size(); i0++) +//CHECK-NEXT: _final_error += std::abs(_d_y[i0] * y[i0] * {{.+}}); //CHECK-NEXT: } @@ -269,26 +210,15 @@ double func5(double* x, double* y, double* output) { //CHECK: void func5_grad(double *x, double *y, double *output, clad::array_ref _d_x, clad::array_ref _d_y, clad::array_ref _d_output, double &_final_error) { //CHECK-NEXT: double _t0; -//CHECK-NEXT: clad::array _delta_output(_d_output.size()); -//CHECK-NEXT: clad::array _EERepl_output0(_d_output.size()); -//CHECK-NEXT: for (int i = 0; i < _d_output.size(); i++) { -//CHECK-NEXT: _EERepl_output0[i] = output[i]; -//CHECK-NEXT: } -//CHECK-NEXT: 
double _EERepl_output1; //CHECK-NEXT: double _t1; -//CHECK-NEXT: double _EERepl_output2; //CHECK-NEXT: double _t2; -//CHECK-NEXT: double _EERepl_output3; //CHECK-NEXT: double _ret_value0 = 0; //CHECK-NEXT: _t0 = output[0]; //CHECK-NEXT: output[0] = x[1] * y[2] - x[2] * y[1]; -//CHECK-NEXT: _EERepl_output1 = output[0]; //CHECK-NEXT: _t1 = output[1]; //CHECK-NEXT: output[1] = x[2] * y[0] - x[0] * y[2]; -//CHECK-NEXT: _EERepl_output2 = output[1]; //CHECK-NEXT: _t2 = output[2]; //CHECK-NEXT: output[2] = x[0] * y[1] - y[0] * x[1]; -//CHECK-NEXT: _EERepl_output3 = output[2]; //CHECK-NEXT: _ret_value0 = output[0] + output[1] + output[2]; //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: @@ -298,6 +228,7 @@ double func5(double* x, double* y, double* output) { //CHECK-NEXT: _d_output[2] += 1; //CHECK-NEXT: } //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_output[2] * output[2] * {{.+}}); //CHECK-NEXT: output[2] = _t2; //CHECK-NEXT: double _r_d2 = _d_output[2]; //CHECK-NEXT: _d_output[2] -= _r_d2; @@ -305,10 +236,9 @@ double func5(double* x, double* y, double* output) { //CHECK-NEXT: _d_y[1] += x[0] * _r_d2; //CHECK-NEXT: _d_y[0] += -_r_d2 * x[1]; //CHECK-NEXT: _d_x[1] += y[0] * -_r_d2; -//CHECK-NEXT: _delta_output[2] += std::abs(_r_d2 * _EERepl_output3 * {{.+}}); -//CHECK-NEXT: _final_error += _delta_output[2]; //CHECK-NEXT: } //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_output[1] * output[1] * {{.+}}); //CHECK-NEXT: output[1] = _t1; //CHECK-NEXT: double _r_d1 = _d_output[1]; //CHECK-NEXT: _d_output[1] -= _r_d1; @@ -316,10 +246,9 @@ double func5(double* x, double* y, double* output) { //CHECK-NEXT: _d_y[0] += x[2] * _r_d1; //CHECK-NEXT: _d_x[0] += -_r_d1 * y[2]; //CHECK-NEXT: _d_y[2] += x[0] * -_r_d1; -//CHECK-NEXT: _delta_output[1] += std::abs(_r_d1 * _EERepl_output2 * {{.+}}); -//CHECK-NEXT: _final_error += _delta_output[1]; //CHECK-NEXT: } //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_output[0] * output[0] * {{.+}}); //CHECK-NEXT: 
output[0] = _t0; //CHECK-NEXT: double _r_d0 = _d_output[0]; //CHECK-NEXT: _d_output[0] -= _r_d0; @@ -327,29 +256,16 @@ double func5(double* x, double* y, double* output) { //CHECK-NEXT: _d_y[2] += x[1] * _r_d0; //CHECK-NEXT: _d_x[2] += -_r_d0 * y[1]; //CHECK-NEXT: _d_y[1] += x[2] * -_r_d0; -//CHECK-NEXT: _delta_output[0] += std::abs(_r_d0 * _EERepl_output1 * {{.+}}); -//CHECK-NEXT: _final_error += _delta_output[0]; //CHECK-NEXT: } -//CHECK-NEXT: clad::array _delta_x(_d_x.size()); //CHECK-NEXT: int i = 0; -//CHECK-NEXT: for (; i < _d_x.size(); i++) { -//CHECK-NEXT: double _t3 = std::abs(_d_x[i] * x[i] * {{.+}}); -//CHECK-NEXT: _delta_x[i] += _t3; -//CHECK-NEXT: _final_error += _t3; -//CHECK-NEXT: } -//CHECK-NEXT: clad::array _delta_y(_d_y.size()); +//CHECK-NEXT: for (; i < _d_x.size(); i++) +//CHECK-NEXT: _final_error += std::abs(_d_x[i] * x[i] * {{.+}}); //CHECK-NEXT: i = 0; -//CHECK-NEXT: for (; i < _d_y.size(); i++) { -//CHECK-NEXT: double _t4 = std::abs(_d_y[i] * y[i] * {{.+}}); -//CHECK-NEXT: _delta_y[i] += _t4; -//CHECK-NEXT: _final_error += _t4; -//CHECK-NEXT: } +//CHECK-NEXT: for (; i < _d_y.size(); i++) +//CHECK-NEXT: _final_error += std::abs(_d_y[i] * y[i] * {{.+}}); //CHECK-NEXT: i = 0; -//CHECK-NEXT: for (; i < _d_output.size(); i++) { -//CHECK-NEXT: double _t5 = std::abs(_d_output[i] * _EERepl_output0[i] * {{.+}}); -//CHECK-NEXT: _delta_output[i] += _t5; -//CHECK-NEXT: _final_error += _t5; -//CHECK-NEXT: } +//CHECK-NEXT: for (; i < _d_output.size(); i++) +//CHECK-NEXT: _final_error += std::abs(_d_output[i] * output[i] * {{.+}}); //CHECK-NEXT: _final_error += std::abs(1. 
* _ret_value0 * {{.+}}); //CHECK-NEXT: } diff --git a/test/ErrorEstimation/LoopsAndArraysExec.C b/test/ErrorEstimation/LoopsAndArraysExec.C index fb067c942..242413965 100644 --- a/test/ErrorEstimation/LoopsAndArraysExec.C +++ b/test/ErrorEstimation/LoopsAndArraysExec.C @@ -18,21 +18,16 @@ double runningSum(float* f, int n) { //CHECK: void runningSum_grad(float *f, int n, clad::array_ref _d_f, clad::array_ref _d_n, double &_final_error) { //CHECK-NEXT: double _d_sum = 0; -//CHECK-NEXT: double _delta_sum = 0; -//CHECK-NEXT: double _EERepl_sum0; //CHECK-NEXT: unsigned long _t0; //CHECK-NEXT: int _d_i = 0; //CHECK-NEXT: int i = 0; //CHECK-NEXT: clad::tape _t1 = {}; -//CHECK-NEXT: clad::tape _EERepl_sum1 = {}; //CHECK-NEXT: double sum = 0; -//CHECK-NEXT: _EERepl_sum0 = sum; //CHECK-NEXT: _t0 = 0; //CHECK-NEXT: for (i = 1; i < n; i++) { //CHECK-NEXT: _t0++; //CHECK-NEXT: clad::push(_t1, sum); //CHECK-NEXT: sum += f[i] + f[i - 1]; -//CHECK-NEXT: clad::push(_EERepl_sum1, sum); //CHECK-NEXT: } //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: @@ -40,23 +35,17 @@ double runningSum(float* f, int n) { //CHECK-NEXT: for (; _t0; _t0--) { //CHECK-NEXT: i--; //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_sum * sum * {{.+}}); //CHECK-NEXT: sum = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_sum; //CHECK-NEXT: _d_f[i] += _r_d0; //CHECK-NEXT: _d_f[i - 1] += _r_d0; -//CHECK-NEXT: double _r0 = clad::pop(_EERepl_sum1); -//CHECK-NEXT: _delta_sum += std::abs(_r_d0 * _r0 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: } -//CHECK-NEXT: _delta_sum += std::abs(_d_sum * _EERepl_sum0 * {{.+}}); -//CHECK-NEXT: clad::array _delta_f(_d_f.size()); +//CHECK-NEXT: _final_error += std::abs(_d_sum * sum * {{.+}}); //CHECK-NEXT: int i0 = 0; -//CHECK-NEXT: for (; i0 < _d_f.size(); i0++) { -//CHECK-NEXT: double _t2 = std::abs(_d_f[i0] * f[i0] * {{.+}}); -//CHECK-NEXT: _delta_f[i0] += _t2; -//CHECK-NEXT: _final_error += _t2; -//CHECK-NEXT: } -//CHECK-NEXT: _final_error += _delta_sum; +//CHECK-NEXT: 
for (; i0 < _d_f.size(); i0++) +//CHECK-NEXT: _final_error += std::abs(_d_f[i0] * f[i0] * {{.+}}); //CHECK-NEXT: } double mulSum(float* a, float* b, int n) { @@ -70,8 +59,6 @@ double mulSum(float* a, float* b, int n) { //CHECK: void mulSum_grad(float *a, float *b, int n, clad::array_ref _d_a, clad::array_ref _d_b, clad::array_ref _d_n, double &_final_error) { //CHECK-NEXT: double _d_sum = 0; -//CHECK-NEXT: double _delta_sum = 0; -//CHECK-NEXT: double _EERepl_sum0; //CHECK-NEXT: unsigned long _t0; //CHECK-NEXT: int _d_i = 0; //CHECK-NEXT: int i = 0; @@ -80,9 +67,7 @@ double mulSum(float* a, float* b, int n) { //CHECK-NEXT: int _d_j = 0; //CHECK-NEXT: int j = 0; //CHECK-NEXT: clad::tape _t3 = {}; -//CHECK-NEXT: clad::tape _EERepl_sum1 = {}; //CHECK-NEXT: double sum = 0; -//CHECK-NEXT: _EERepl_sum0 = sum; //CHECK-NEXT: _t0 = 0; //CHECK-NEXT: for (i = 0; i < n; i++) { //CHECK-NEXT: _t0++; @@ -91,7 +76,6 @@ double mulSum(float* a, float* b, int n) { //CHECK-NEXT: clad::back(_t1)++; //CHECK-NEXT: clad::push(_t3, sum); //CHECK-NEXT: sum += a[i] * b[j]; -//CHECK-NEXT: clad::push(_EERepl_sum1, sum); //CHECK-NEXT: } //CHECK-NEXT: } //CHECK-NEXT: goto _label0; @@ -102,12 +86,11 @@ double mulSum(float* a, float* b, int n) { //CHECK-NEXT: { //CHECK-NEXT: for (; clad::back(_t1); clad::back(_t1)--) { //CHECK-NEXT: j--; +//CHECK-NEXT: _final_error += std::abs(_d_sum * sum * {{.+}}); //CHECK-NEXT: sum = clad::pop(_t3); //CHECK-NEXT: double _r_d0 = _d_sum; //CHECK-NEXT: _d_a[i] += _r_d0 * b[j]; //CHECK-NEXT: _d_b[j] += a[i] * _r_d0; -//CHECK-NEXT: double _r0 = clad::pop(_EERepl_sum1); -//CHECK-NEXT: _delta_sum += std::abs(_r_d0 * _r0 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: { //CHECK-NEXT: _d_j = 0; @@ -116,22 +99,13 @@ double mulSum(float* a, float* b, int n) { //CHECK-NEXT: clad::pop(_t1); //CHECK-NEXT: } //CHECK-NEXT: } -//CHECK-NEXT: _delta_sum += std::abs(_d_sum * _EERepl_sum0 * {{.+}}); -//CHECK-NEXT: clad::array _delta_a(_d_a.size()); +//CHECK-NEXT: _final_error += 
std::abs(_d_sum * sum * {{.+}}); //CHECK-NEXT: int i0 = 0; -//CHECK-NEXT: for (; i0 < _d_a.size(); i0++) { -//CHECK-NEXT: double _t4 = std::abs(_d_a[i0] * a[i0] * {{.+}}); -//CHECK-NEXT: _delta_a[i0] += _t4; -//CHECK-NEXT: _final_error += _t4; -//CHECK-NEXT: } -//CHECK-NEXT: clad::array _delta_b(_d_b.size()); +//CHECK-NEXT: for (; i0 < _d_a.size(); i0++) +//CHECK-NEXT: _final_error += std::abs(_d_a[i0] * a[i0] * {{.+}}); //CHECK-NEXT: i0 = 0; -//CHECK-NEXT: for (; i0 < _d_b.size(); i0++) { -//CHECK-NEXT: double _t5 = std::abs(_d_b[i0] * b[i0] * {{.+}}); -//CHECK-NEXT: _delta_b[i0] += _t5; -//CHECK-NEXT: _final_error += _t5; -//CHECK-NEXT: } -//CHECK-NEXT: _final_error += _delta_sum; +//CHECK-NEXT: for (; i0 < _d_b.size(); i0++) +//CHECK-NEXT: _final_error += std::abs(_d_b[i0] * b[i0] * {{.+}}); //CHECK-NEXT: } double divSum(float* a, float* b, int n) { @@ -144,21 +118,16 @@ double divSum(float* a, float* b, int n) { //CHECK: void divSum_grad(float *a, float *b, int n, clad::array_ref _d_a, clad::array_ref _d_b, clad::array_ref _d_n, double &_final_error) { //CHECK-NEXT: double _d_sum = 0; -//CHECK-NEXT: double _delta_sum = 0; -//CHECK-NEXT: double _EERepl_sum0; //CHECK-NEXT: unsigned long _t0; //CHECK-NEXT: int _d_i = 0; //CHECK-NEXT: int i = 0; //CHECK-NEXT: clad::tape _t1 = {}; -//CHECK-NEXT: clad::tape _EERepl_sum1 = {}; //CHECK-NEXT: double sum = 0; -//CHECK-NEXT: _EERepl_sum0 = sum; //CHECK-NEXT: _t0 = 0; //CHECK-NEXT: for (i = 0; i < n; i++) { //CHECK-NEXT: _t0++; //CHECK-NEXT: clad::push(_t1, sum); //CHECK-NEXT: sum += a[i] / b[i]; -//CHECK-NEXT: clad::push(_EERepl_sum1, sum); //CHECK-NEXT: } //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: @@ -166,31 +135,21 @@ double divSum(float* a, float* b, int n) { //CHECK-NEXT: for (; _t0; _t0--) { //CHECK-NEXT: i--; //CHECK-NEXT: { +//CHECK-NEXT: _final_error += std::abs(_d_sum * sum * {{.+}}); //CHECK-NEXT: sum = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_sum; //CHECK-NEXT: _d_a[i] += _r_d0 / b[i]; 
//CHECK-NEXT: double _r0 = _r_d0 * -a[i] / (b[i] * b[i]); //CHECK-NEXT: _d_b[i] += _r0; -//CHECK-NEXT: double _r1 = clad::pop(_EERepl_sum1); -//CHECK-NEXT: _delta_sum += std::abs(_r_d0 * _r1 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: } -//CHECK-NEXT: _delta_sum += std::abs(_d_sum * _EERepl_sum0 * {{.+}}); -//CHECK-NEXT: clad::array _delta_a(_d_a.size()); +//CHECK-NEXT: _final_error += std::abs(_d_sum * sum * {{.+}}); //CHECK-NEXT: int i0 = 0; -//CHECK-NEXT: for (; i0 < _d_a.size(); i0++) { -//CHECK-NEXT: double _t2 = std::abs(_d_a[i0] * a[i0] * {{.+}}); -//CHECK-NEXT: _delta_a[i0] += _t2; -//CHECK-NEXT: _final_error += _t2; -//CHECK-NEXT: } -//CHECK-NEXT: clad::array _delta_b(_d_b.size()); +//CHECK-NEXT: for (; i0 < _d_a.size(); i0++) +//CHECK-NEXT: _final_error += std::abs(_d_a[i0] * a[i0] * {{.+}}); //CHECK-NEXT: i0 = 0; -//CHECK-NEXT: for (; i0 < _d_b.size(); i0++) { -//CHECK-NEXT: double _t3 = std::abs(_d_b[i0] * b[i0] * {{.+}}); -//CHECK-NEXT: _delta_b[i0] += _t3; -//CHECK-NEXT: _final_error += _t3; -//CHECK-NEXT: } -//CHECK-NEXT: _final_error += _delta_sum; +//CHECK-NEXT: for (; i0 < _d_b.size(); i0++) +//CHECK-NEXT: _final_error += std::abs(_d_b[i0] * b[i0] * {{.+}}); //CHECK-NEXT: } int main() { diff --git a/test/FirstDerivative/CodeGenSimple.C b/test/FirstDerivative/CodeGenSimple.C index 02a815c92..4ff77e806 100644 --- a/test/FirstDerivative/CodeGenSimple.C +++ b/test/FirstDerivative/CodeGenSimple.C @@ -33,9 +33,17 @@ extern "C" int printf(const char* fmt, ...); int f_1_darg0(int x); +double sq_defined_later(double); + int main() { int x = 4; clad::differentiate(f_1, 0); + auto df = clad::differentiate(sq_defined_later, "x"); printf("Result is = %d\n", f_1_darg0(1)); // CHECK-EXEC: Result is = 2 + printf("Result is = %f\n", df.execute(3)); // CHECK-EXEC: Result is = 6 return 0; } + +double sq_defined_later(double x) { + return x * x; +} diff --git a/test/FirstDerivative/FunctionCalls.C b/test/FirstDerivative/FunctionCalls.C index 953ccc6cf..a3654f503 100644 
--- a/test/FirstDerivative/FunctionCalls.C +++ b/test/FirstDerivative/FunctionCalls.C @@ -175,6 +175,7 @@ int main () { clad::differentiate(test_6, "x"); clad::differentiate(test_7, "i"); clad::differentiate(test_8, "x"); - + clad::differentiate(test_8); // expected-error {{TBR analysis is not meant for forward mode AD.}} + clad::differentiate(test_8); // expected-error {{Both enable and disable TBR options are specified.}} return 0; } diff --git a/test/Gradient/Gradients.C b/test/Gradient/Gradients.C index ac30263e1..c0d180c2d 100644 --- a/test/Gradient/Gradients.C +++ b/test/Gradient/Gradients.C @@ -280,7 +280,7 @@ double f_cond3(double x, double c) { //CHECK-NEXT: } //CHECK-NEXT: } -double f_cond3_grad(double x, double c, clad::array_ref _d_x, clad::array_ref _d_y); +void f_cond3_grad(double x, double c, clad::array_ref _d_x, clad::array_ref _d_y); double f_cond4(double x, double y) { int i = 0; @@ -321,7 +321,7 @@ double f_cond4(double x, double y) { //CHECK-NEXT: } //CHECK-NEXT: } -double f_cond4_grad(double x, double c, clad::array_ref _d_x, clad::array_ref _d_y); +void f_cond4_grad(double x, double c, clad::array_ref _d_x, clad::array_ref _d_y); double f_if1(double x, double y) { if (x > y) @@ -345,7 +345,7 @@ double f_if1(double x, double y) { //CHECK-NEXT: * _d_y += 1; //CHECK-NEXT: } -double f_if1_grad(double x, double y, clad::array_ref _d_x, clad::array_ref _d_y); +void f_if1_grad(double x, double y, clad::array_ref _d_x, clad::array_ref _d_y); double f_if2(double x, double y) { if (x > y) diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C index cf5fd322b..294325ac6 100644 --- a/test/Gradient/Pointers.C +++ b/test/Gradient/Pointers.C @@ -617,7 +617,7 @@ int main() { d_structPointer.execute(5, &d_x); printf("%.2f\n", d_x); // CHECK-EXEC: 1.00 - auto d_cStyleMemoryAlloc = clad::gradient(cStyleMemoryAlloc, "x"); + auto d_cStyleMemoryAlloc = clad::gradient(cStyleMemoryAlloc, "x"); d_x = 0; d_cStyleMemoryAlloc.execute(5, 7, &d_x); 
printf("%.2f\n", d_x); // CHECK-EXEC: 4.00 diff --git a/test/Misc/Args.C b/test/Misc/Args.C index 58e44c751..35b7c3e5f 100644 --- a/test/Misc/Args.C +++ b/test/Misc/Args.C @@ -5,6 +5,9 @@ // CHECK_HELP-NEXT: -fdump-derived-fn // CHECK_HELP-NEXT: -fdump-derived-fn-ast // CHECK_HELP-NEXT: -fgenerate-source-file +// CHECK_HELP-NEXT: -fno-validate-clang-version +// CHECK_HELP-NEXT: -enable-tbr +// CHECK_HELP-NEXT: -disable-tbr // CHECK_HELP-NEXT: -fcustom-estimation-model // CHECK_HELP-NEXT: -fprint-num-diff-errors // CHECK_HELP-NEXT: -help @@ -23,3 +26,7 @@ // RUN: -Xclang %t.so %S/../../demos/ErrorEstimation/CustomModel/test.cpp \ // RUN: -I%S/../../include 2>&1 | FileCheck --check-prefix=CHECK_SO_INVALID %s // CHECK_SO_INVALID: Failed to load '{{.*.so}}', {{.*}}. Aborting. + +// RUN: clang -fsyntax-only -fplugin=%cladlib -Xclang -plugin-arg-clad -Xclang -enable-tbr \ +// RUN: -Xclang -plugin-arg-clad -Xclang -disable-tbr %s 2>&1 | FileCheck --check-prefix=CHECK_TBR %s +// CHECK_TBR: -enable-tbr and -disable-tbr cannot be used together \ No newline at end of file diff --git a/test/Misc/ClangConsumers.cpp b/test/Misc/ClangConsumers.cpp new file mode 100644 index 000000000..210c3060d --- /dev/null +++ b/test/Misc/ClangConsumers.cpp @@ -0,0 +1,81 @@ +// RUN: %cladclang %s -I%S/../../include -oClangConsumers.out \ +// RUN: -fms-compatibility -DMS_COMPAT -std=c++14 -fmodules \ +// RUN: -Xclang -print-stats 2>&1 | FileCheck %s +// CHECK-NOT: {{.*error|warning|note:.*}} +// +// RUN: clang -xc -Xclang -add-plugin -Xclang clad -Xclang -load \ +// RUN: -Xclang %cladlib %s -I%S/../../include -oClangConsumers.out \ +// RUN: -Xclang -debug-info-kind=limited -Xclang -triple -Xclang bpf-linux-gnu \ +// RUN: -S -emit-llvm -Xclang -target-cpu -Xclang generic \ +// RUN: -Xclang -print-stats 2>&1 | \ +// RUN: FileCheck -check-prefix=CHECK_C %s +// CHECK_C-NOT: {{.*error|warning|note:.*}} +// XFAIL: clang-7, clang-8, clang-9, target={{i586.*}}, target=arm64-apple-{{.*}} +// +// RUN: 
clang -xobjective-c -Xclang -add-plugin -Xclang clad -Xclang -load \ +// RUN: -Xclang %cladlib %s -I%S/../../include -oClangConsumers.out \ +// RUN: -Xclang -print-stats 2>&1 | \ +// RUN: FileCheck -check-prefix=CHECK_OBJC %s +// CHECK_OBJC-NOT: {{.*error|warning|note:.*}} + +#ifdef __cplusplus + +#pragma clang module build N + module N {} + #pragma clang module contents + #pragma clang module begin N + struct f { void operator()() const {} }; + template auto vtemplate = f{}; + #pragma clang module end +#pragma clang module endbuild + +#pragma clang module import N + +#ifdef MS_COMPAT +class __single_inheritance IncSingle; +#endif // MS_COMPAT + +struct V { virtual int f(); }; +int V::f() { return 1; } +template T f() { return T(); } +int i = f(); + +// Check if shouldSkipFunctionBody is called. +// RUN: %cladclang -I%S/../../include -fsyntax-only -fmodules \ +// RUN: -Xclang -code-completion-at=%s:%(line-1):1 %s -o - | \ +// RUN: FileCheck -check-prefix=CHECK-CODECOMP %s +// CHECK-CODECOMP: COMPLETION + +// CHECK: HandleImplicitImportDecl +// CHECK: AssignInheritanceModel +// CHECK: HandleTopLevelDecl +// CHECK: HandleCXXImplicitFunctionInstantiation +// CHECK: HandleInterestingDecl +// CHECK: HandleVTable +// CHECK: HandleCXXStaticMemberVarInstantiation + +#endif // __cplusplus + +#ifdef __STDC_VERSION__ // C mode +int i; + +extern char ch; +int test(void) { return ch; } +char ch = 1; + +// CHECK_C: CompleteTentativeDefinition +// CHECK_C: CompleteExternalDeclaration +#endif // __STDC_VERSION__ + +#ifdef __OBJC__ +@interface I +void f(); +@end +// CHECK_OBJC: HandleTopLevelDeclInObjCContainer +#endif // __OBJC__ + +int main() { +#ifdef __cplusplus + vtemplate(); +#endif // __cplusplus +} diff --git a/test/Misc/RunDemos.C b/test/Misc/RunDemos.C index c72587d93..93cee890f 100644 --- a/test/Misc/RunDemos.C +++ b/test/Misc/RunDemos.C @@ -109,21 +109,16 @@ //CHECK_FLOAT_SUM: void vanillaSum_grad(float x, unsigned int n, clad::array_ref _d_x, clad::array_ref _d_n, 
double &_final_error) { //CHECK_FLOAT_SUM: float _d_sum = 0; -//CHECK_FLOAT_SUM: double _delta_sum = 0; -//CHECK_FLOAT_SUM: float _EERepl_sum0; //CHECK_FLOAT_SUM: unsigned long _t0; //CHECK_FLOAT_SUM: unsigned int _d_i = 0; //CHECK_FLOAT_SUM: unsigned int i = 0; //CHECK_FLOAT_SUM: clad::tape _t1 = {}; -//CHECK_FLOAT_SUM: clad::tape _EERepl_sum1 = {}; //CHECK_FLOAT_SUM: float sum = 0.; -//CHECK_FLOAT_SUM: _EERepl_sum0 = sum; //CHECK_FLOAT_SUM: _t0 = 0; //CHECK_FLOAT_SUM: for (i = 0; i < n; i++) { //CHECK_FLOAT_SUM: _t0++; //CHECK_FLOAT_SUM: clad::push(_t1, sum); //CHECK_FLOAT_SUM: sum = sum + x; -//CHECK_FLOAT_SUM: clad::push(_EERepl_sum1, sum); //CHECK_FLOAT_SUM: } //CHECK_FLOAT_SUM: goto _label0; //CHECK_FLOAT_SUM: _label0: @@ -131,19 +126,16 @@ //CHECK_FLOAT_SUM: for (; _t0; _t0--) { //CHECK_FLOAT_SUM: i--; //CHECK_FLOAT_SUM: { +//CHECK_FLOAT_SUM: _final_error += std::abs(_d_sum * sum * 1.1920928955078125E-7); //CHECK_FLOAT_SUM: sum = clad::pop(_t1); //CHECK_FLOAT_SUM: float _r_d0 = _d_sum; //CHECK_FLOAT_SUM: _d_sum -= _r_d0; //CHECK_FLOAT_SUM: _d_sum += _r_d0; //CHECK_FLOAT_SUM: * _d_x += _r_d0; -//CHECK_FLOAT_SUM: float _r0 = clad::pop(_EERepl_sum1); -//CHECK_FLOAT_SUM: _delta_sum += std::abs(_r_d0 * _r0 * 1.1920928955078125E-7); //CHECK_FLOAT_SUM: } //CHECK_FLOAT_SUM: } -//CHECK_FLOAT_SUM: _delta_sum += std::abs(_d_sum * _EERepl_sum0 * 1.1920928955078125E-7); -//CHECK_FLOAT_SUM: double _delta_x = 0; -//CHECK_FLOAT_SUM: _delta_x += std::abs(* _d_x * x * 1.1920928955078125E-7); -//CHECK_FLOAT_SUM: _final_error += _delta_x + _delta_sum; +//CHECK_FLOAT_SUM: _final_error += std::abs(_d_sum * sum * 1.1920928955078125E-7); +//CHECK_FLOAT_SUM: _final_error += std::abs(* _d_x * x * 1.1920928955078125E-7); //CHECK_FLOAT_SUM: } //-----------------------------------------------------------------------------/ @@ -161,31 +153,23 @@ // CHECK_CUSTOM_MODEL_EXEC: The code is: // CHECK_CUSTOM_MODEL_EXEC-NEXT: void func_grad(float x, float y, clad::array_ref _d_x, clad::array_ref 
_d_y, double &_final_error) { // CHECK_CUSTOM_MODEL_EXEC-NEXT: float _d_z = 0; -// CHECK_CUSTOM_MODEL_EXEC-NEXT: double _delta_z = 0; -// CHECK_CUSTOM_MODEL_EXEC-NEXT: float _EERepl_z0; // CHECK_CUSTOM_MODEL_EXEC-NEXT: float _t0; -// CHECK_CUSTOM_MODEL_EXEC-NEXT: float _EERepl_z1; // CHECK_CUSTOM_MODEL_EXEC-NEXT: float z; -// CHECK_CUSTOM_MODEL_EXEC-NEXT: _EERepl_z0 = z; // CHECK_CUSTOM_MODEL_EXEC-NEXT: _t0 = z; // CHECK_CUSTOM_MODEL_EXEC-NEXT: z = x + y; -// CHECK_CUSTOM_MODEL_EXEC-NEXT: _EERepl_z1 = z; // CHECK_CUSTOM_MODEL_EXEC-NEXT: goto _label0; // CHECK_CUSTOM_MODEL_EXEC-NEXT: _label0: // CHECK_CUSTOM_MODEL_EXEC-NEXT: _d_z += 1; // CHECK_CUSTOM_MODEL_EXEC-NEXT: { +// CHECK_CUSTOM_MODEL_EXEC-NEXT: _final_error += _d_z * z; // CHECK_CUSTOM_MODEL_EXEC-NEXT: z = _t0; // CHECK_CUSTOM_MODEL_EXEC-NEXT: float _r_d0 = _d_z; // CHECK_CUSTOM_MODEL_EXEC-NEXT: _d_z -= _r_d0; // CHECK_CUSTOM_MODEL_EXEC-NEXT: * _d_x += _r_d0; // CHECK_CUSTOM_MODEL_EXEC-NEXT: * _d_y += _r_d0; -// CHECK_CUSTOM_MODEL_EXEC-NEXT: _delta_z += _r_d0 * _EERepl_z1; // CHECK_CUSTOM_MODEL_EXEC-NEXT: } -// CHECK_CUSTOM_MODEL_EXEC-NEXT: double _delta_x = 0; -// CHECK_CUSTOM_MODEL_EXEC-NEXT: _delta_x += * _d_x * x; -// CHECK_CUSTOM_MODEL_EXEC-NEXT: double _delta_y = 0; -// CHECK_CUSTOM_MODEL_EXEC-NEXT: _delta_y += * _d_y * y; -// CHECK_CUSTOM_MODEL_EXEC-NEXT: _final_error += _delta_{{x|y|z}} + _delta_{{x|y|z}} + _delta_{{x|y|z}}; +// CHECK_CUSTOM_MODEL_EXEC-NEXT: _final_error += * _d_x * x; +// CHECK_CUSTOM_MODEL_EXEC-NEXT: _final_error += * _d_y * y; // CHECK_CUSTOM_MODEL_EXEC-NEXT: } //-----------------------------------------------------------------------------/ @@ -203,31 +187,23 @@ // CHECK_PRINT_MODEL_EXEC: The code is: // CHECK_PRINT_MODEL_EXEC-NEXT: void func_grad(float x, float y, clad::array_ref _d_x, clad::array_ref _d_y, double &_final_error) { // CHECK_PRINT_MODEL_EXEC-NEXT: float _d_z = 0; -// CHECK_PRINT_MODEL_EXEC-NEXT: double _delta_z = 0; -// CHECK_PRINT_MODEL_EXEC-NEXT: float 
_EERepl_z0; // CHECK_PRINT_MODEL_EXEC-NEXT: float _t0; -// CHECK_PRINT_MODEL_EXEC-NEXT: float _EERepl_z1; // CHECK_PRINT_MODEL_EXEC-NEXT: float z; -// CHECK_PRINT_MODEL_EXEC-NEXT: _EERepl_z0 = z; // CHECK_PRINT_MODEL_EXEC-NEXT: _t0 = z; // CHECK_PRINT_MODEL_EXEC-NEXT: z = x + y; -// CHECK_PRINT_MODEL_EXEC-NEXT: _EERepl_z1 = z; // CHECK_PRINT_MODEL_EXEC-NEXT: goto _label0; // CHECK_PRINT_MODEL_EXEC-NEXT: _label0: // CHECK_PRINT_MODEL_EXEC-NEXT: _d_z += 1; // CHECK_PRINT_MODEL_EXEC-NEXT: { +// CHECK_PRINT_MODEL_EXEC-NEXT: _final_error += clad::getErrorVal(_d_z, z, "z"); // CHECK_PRINT_MODEL_EXEC-NEXT: z = _t0; // CHECK_PRINT_MODEL_EXEC-NEXT: float _r_d0 = _d_z; // CHECK_PRINT_MODEL_EXEC-NEXT: _d_z -= _r_d0; // CHECK_PRINT_MODEL_EXEC-NEXT: * _d_x += _r_d0; // CHECK_PRINT_MODEL_EXEC-NEXT: * _d_y += _r_d0; -// CHECK_PRINT_MODEL_EXEC-NEXT: _delta_z += clad::getErrorVal(_r_d0, _EERepl_z1, "z"); // CHECK_PRINT_MODEL_EXEC-NEXT: } -// CHECK_PRINT_MODEL_EXEC-NEXT: double _delta_x = 0; -// CHECK_PRINT_MODEL_EXEC-NEXT: _delta_x += clad::getErrorVal(* _d_x, x, "x"); -// CHECK_PRINT_MODEL_EXEC-NEXT: double _delta_y = 0; -// CHECK_PRINT_MODEL_EXEC-NEXT: _delta_y += clad::getErrorVal(* _d_y, y, "y"); -// CHECK_PRINT_MODEL_EXEC-NEXT: _final_error += _delta_{{x|y|z}} + _delta_{{x|y|z}} + _delta_{{x|y|z}}; +// CHECK_PRINT_MODEL_EXEC-NEXT: _final_error += clad::getErrorVal(* _d_x, x, "x"); +// CHECK_PRINT_MODEL_EXEC-NEXT: _final_error += clad::getErrorVal(* _d_y, y, "y"); // CHECK_PRINT_MODEL_EXEC-NEXT: } // CHECK_PRINT_MODEL_EXEC: Error in z : {{.+}} // CHECK_PRINT_MODEL_EXEC-NEXT: Error in x : {{.+}} @@ -301,11 +277,11 @@ // CHECK_ARRAYS_EXEC: {0.33, 0, 0, 0, 0, 0} // CHECK_ARRAYS_EXEC: {0, 0.33, 0, 0, 0, 0} // CHECK_ARRAYS_EXEC: {0, 0, 0.33, 0, 0, 0} -// CHECK_ARRAYS_EXEC: Hessian Mode w.r.t. to arr: -// CHECK_ARRAYS_EXEC: matrix = -// CHECK_ARRAYS_EXEC: {0, 0, 0} -// CHECK_ARRAYS_EXEC: {0, 0, 0} -// CHECK_ARRAYS_EXEC: {0, 0, 0} +// CHECK_ARRAYS_EXEC-FAIL: Hessian Mode w.r.t. 
to arr: +// CHECK_ARRAYS_EXEC-FAIL: matrix = +// CHECK_ARRAYS_EXEC-FAIL: {0, 0, 0} +// CHECK_ARRAYS_EXEC-FAIL: {0, 0, 0} +// CHECK_ARRAYS_EXEC-FAIL: {0, 0, 0} //-----------------------------------------------------------------------------/ // Demo: VectorForwardMode.cpp diff --git a/test/lit.cfg b/test/lit.cfg index 014f22024..8c3271983 100644 --- a/test/lit.cfg +++ b/test/lit.cfg @@ -294,6 +294,9 @@ if.*\[ ?(llvm[^ ]*) ([^ ]*) ?\].*{ if platform.system() not in ['Windows'] or lit_config.getBashPath() != '': config.available_features.add('shell') + +config.available_features.add("clang-{0}".format(config.clang_version_major)) + # Loadable module # FIXME: This should be supplied by Makefile or autoconf. #if sys.platform in ['win32', 'cygwin']: diff --git a/test/lit.site.cfg.in b/test/lit.site.cfg.in index 6c36bb63a..868d2db09 100644 --- a/test/lit.site.cfg.in +++ b/test/lit.site.cfg.in @@ -3,6 +3,7 @@ import sys ## Autogenerated by LLVM/clad configuration. # Do not edit! llvm_version_major = @LLVM_VERSION_MAJOR@ +config.clang_version_major = @CLANG_VERSION_MAJOR@ config.llvm_src_root = "@LLVM_SOURCE_DIR@" config.llvm_obj_root = "@LLVM_BINARY_DIR@" config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 0f55cf065..8be825c3f 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -9,6 +9,7 @@ #include "clad/Differentiator/DerivativeBuilder.h" #include "clad/Differentiator/EstimationModel.h" +#include "clad/Differentiator/Sins.h" #include "clad/Differentiator/Version.h" #include "clang/AST/ASTConsumer.h" @@ -91,60 +92,48 @@ namespace clad { CladPlugin::~CladPlugin() {} - // We cannot use HandleTranslationUnit because codegen already emits code on - // HandleTopLevelDecl calls and makes updateCall with no effect. 
- bool CladPlugin::HandleTopLevelDecl(DeclGroupRef DGR) { + ALLOW_ACCESS(MultiplexConsumer, Consumers, + std::vector>); + + void CladPlugin::Initialize(clang::ASTContext& C) { + // We know we have a multiplexer. We commit a sin here by stealing it and + // making the consumer pass-through so that we can delay all operations + // until clad is happy. + + auto& MultiplexC = cast(m_CI.getASTConsumer()); + auto& RobbedCs = ACCESS(MultiplexC, Consumers); + assert(RobbedCs.back().get() == this && "Clad is not the last consumer"); + std::vector> StolenConsumers; + + // The range-based for loop in MultiplexConsumer::Initialize has + // dispatched this call. Generally, it is unsafe to delete elements while + // iterating but we know we are in the end of the loop and ::end() won't + // be invalidated. + std::move(RobbedCs.begin(), RobbedCs.end() - 1, + std::back_inserter(StolenConsumers)); + RobbedCs.erase(RobbedCs.begin(), RobbedCs.end() - 1); + m_Multiplexer.reset(new MultiplexConsumer(std::move(StolenConsumers))); + } + + void CladPlugin::HandleTopLevelDeclForClad(DeclGroupRef DGR) { if (!CheckBuiltins()) - return true; + return; Sema& S = m_CI.getSema(); if (!m_DerivativeBuilder) - m_DerivativeBuilder.reset(new DerivativeBuilder(m_CI.getSema(), *this)); - - // if HandleTopLevelDecl was called through clad we don't need to process - // it for diff requests - if (m_HandleTopLevelDeclInternal) - return true; - - DiffSchedule requests{}; - DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema()); - - if (requests.empty()) - return true; + m_DerivativeBuilder.reset(new DerivativeBuilder(S, *this)); - // FIXME: flags have to be set manually since DiffCollector's constructor - // does not have access to m_DO. - if (m_DO.EnableTBRAnalysis) - for (DiffRequest& request : requests) - request.EnableTBRAnalysis = true; - - // FIXME: Remove the PerformPendingInstantiations altogether. We should - // somehow make the relevant functions referenced. 
- // Instantiate all pending for instantiations templates, because we will - // need the full bodies to produce derivatives. - // FIXME: Confirm if we really need `m_PendingInstantiationsInFlight`? - if (!m_PendingInstantiationsInFlight) { - m_PendingInstantiationsInFlight = true; - S.PerformPendingInstantiations(); - m_PendingInstantiationsInFlight = false; - } - - for (DiffRequest& request : requests) - ProcessDiffRequest(request); - return true; // Happiness - } - - void CladPlugin::ProcessTopLevelDecl(Decl* D) { - m_HandleTopLevelDeclInternal = true; - m_CI.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(D)); - m_HandleTopLevelDeclInternal = false; + RequestOptions opts{}; + SetRequestOptions(opts); + DiffCollector collector(DGR, CladEnabledRange, m_DiffSchedule, S, opts); } FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest& request) { Sema& S = m_CI.getSema(); // Required due to custom derivatives function templates that might be // used in the function that we need to derive. + // FIXME: Remove the call to PerformPendingInstantiations(). S.PerformPendingInstantiations(); if (request.Function->getDefinition()) request.Function = request.Function->getDefinition(); @@ -267,6 +256,8 @@ namespace clad { // Call CodeGen only if the produced Decl is a top-most // decl or is contained in a namespace decl. 
+      // FIXME: We could get rid of this by prepending the produced
+      // derivatives in CladPlugin::HandleTranslationUnitDecl
       DeclContext* derivativeDC = DerivativeDecl->getDeclContext();
       bool isTUorND =
           derivativeDC->isTranslationUnit() || derivativeDC->isNamespace();
@@ -296,6 +287,68 @@ namespace clad {
       return nullptr;
     }

+    void CladPlugin::SendToMultiplexer() {
+      for (auto DelayedCall : m_DelayedCalls) {
+        DeclGroupRef& D = DelayedCall.m_DGR;
+        switch (DelayedCall.m_Kind) {
+        case CallKind::HandleCXXStaticMemberVarInstantiation:
+          m_Multiplexer->HandleCXXStaticMemberVarInstantiation(
+              cast<VarDecl>(D.getSingleDecl()));
+          break;
+        case CallKind::HandleTopLevelDecl:
+          m_Multiplexer->HandleTopLevelDecl(D);
+          break;
+        case CallKind::HandleInlineFunctionDefinition:
+          m_Multiplexer->HandleInlineFunctionDefinition(
+              cast<FunctionDecl>(D.getSingleDecl()));
+          break;
+        case CallKind::HandleInterestingDecl:
+          m_Multiplexer->HandleInterestingDecl(D);
+          break;
+        case CallKind::HandleTagDeclDefinition:
+          m_Multiplexer->HandleTagDeclDefinition(
+              cast<TagDecl>(D.getSingleDecl()));
+          break;
+        case CallKind::HandleTagDeclRequiredDefinition:
+          m_Multiplexer->HandleTagDeclRequiredDefinition(
+              cast<TagDecl>(D.getSingleDecl()));
+          break;
+        case CallKind::HandleCXXImplicitFunctionInstantiation:
+          m_Multiplexer->HandleCXXImplicitFunctionInstantiation(
+              cast<FunctionDecl>(D.getSingleDecl()));
+          break;
+        case CallKind::HandleTopLevelDeclInObjCContainer:
+          m_Multiplexer->HandleTopLevelDeclInObjCContainer(D);
+          break;
+        case CallKind::HandleImplicitImportDecl:
+          m_Multiplexer->HandleImplicitImportDecl(
+              cast<ImportDecl>(D.getSingleDecl()));
+          break;
+        case CallKind::CompleteTentativeDefinition:
+          m_Multiplexer->CompleteTentativeDefinition(
+              cast<VarDecl>(D.getSingleDecl()));
+          break;
+#if CLANG_VERSION_MAJOR > 9
+        case CallKind::CompleteExternalDeclaration:
+          m_Multiplexer->CompleteExternalDeclaration(
+              cast<VarDecl>(D.getSingleDecl()));
+          break;
+#endif
+        case CallKind::AssignInheritanceModel:
+          m_Multiplexer->AssignInheritanceModel(
+              cast<CXXRecordDecl>(D.getSingleDecl()));
+          break;
+        case CallKind::HandleVTable:
+          m_Multiplexer->HandleVTable(cast<CXXRecordDecl>(D.getSingleDecl()));
+          break;
+        case CallKind::InitializeSema:
+          m_Multiplexer->InitializeSema(m_CI.getSema());
+          break;
+        };
+      }
+      m_HasMultiplexerProcessedDelayedCalls = true;
+    }
+
     bool CladPlugin::CheckBuiltins() {
       // If we have included "clad/Differentiator/Differentiator.h" return.
       if (m_HasRuntime)
@@ -318,6 +371,105 @@ namespace clad {
       m_HasRuntime = !R.empty();
       return m_HasRuntime;
     }
+
+    static void SetTBRAnalysisOptions(const DifferentiationOptions& DO,
+                                      RequestOptions& opts) {
+      // If user has explicitly specified the mode for TBR analysis, use it.
+      if (DO.EnableTBRAnalysis || DO.DisableTBRAnalysis)
+        opts.EnableTBRAnalysis = DO.EnableTBRAnalysis && !DO.DisableTBRAnalysis;
+      else
+        opts.EnableTBRAnalysis = false; // Default mode.
+    }
+
+    void CladPlugin::SetRequestOptions(RequestOptions& opts) const {
+      SetTBRAnalysisOptions(m_DO, opts);
+    }
+
+    void CladPlugin::HandleTranslationUnit(ASTContext& C) {
+      Sema& S = m_CI.getSema();
+      // Restore the TUScope that became a 0 in Sema::ActOnEndOfTranslationUnit.
+      S.TUScope = m_StoredTUScope;
+      constexpr bool Enabled = true;
+      Sema::GlobalEagerInstantiationScope GlobalInstantiations(S, Enabled);
+      Sema::LocalEagerInstantiationScope LocalInstantiations(S);
+
+      for (DiffRequest& request : m_DiffSchedule) {
+        // FIXME: flags have to be set manually since DiffCollector's
+        // constructor does not have access to m_DO.
+        request.EnableTBRAnalysis = m_DO.EnableTBRAnalysis;
+        ProcessDiffRequest(request);
+      }
+      // Put the TUScope in a consistent state after clad is done.
+      S.TUScope = nullptr;
+      // Force emission of the produced pending template instantiations.
+      LocalInstantiations.perform();
+      GlobalInstantiations.perform();
+
+      SendToMultiplexer();
+      m_Multiplexer->HandleTranslationUnit(C);
+    }
+
+    void CladPlugin::PrintStats() {
+      llvm::errs() << "*** INFORMATION ABOUT THE DELAYED CALLS\n";
+      for (const DelayedCallInfo& DCI : m_DelayedCalls) {
+        llvm::errs() << "  ";
+        switch (DCI.m_Kind) {
+        case CallKind::HandleCXXStaticMemberVarInstantiation:
+          llvm::errs() << "HandleCXXStaticMemberVarInstantiation";
+          break;
+        case CallKind::HandleTopLevelDecl:
+          llvm::errs() << "HandleTopLevelDecl";
+          break;
+        case CallKind::HandleInlineFunctionDefinition:
+          llvm::errs() << "HandleInlineFunctionDefinition";
+          break;
+        case CallKind::HandleInterestingDecl:
+          llvm::errs() << "HandleInterestingDecl";
+          break;
+        case CallKind::HandleTagDeclDefinition:
+          llvm::errs() << "HandleTagDeclDefinition";
+          break;
+        case CallKind::HandleTagDeclRequiredDefinition:
+          llvm::errs() << "HandleTagDeclRequiredDefinition";
+          break;
+        case CallKind::HandleCXXImplicitFunctionInstantiation:
+          llvm::errs() << "HandleCXXImplicitFunctionInstantiation";
+          break;
+        case CallKind::HandleTopLevelDeclInObjCContainer:
+          llvm::errs() << "HandleTopLevelDeclInObjCContainer";
+          break;
+        case CallKind::HandleImplicitImportDecl:
+          llvm::errs() << "HandleImplicitImportDecl";
+          break;
+        case CallKind::CompleteTentativeDefinition:
+          llvm::errs() << "CompleteTentativeDefinition";
+          break;
+#if CLANG_VERSION_MAJOR > 9
+        case CallKind::CompleteExternalDeclaration:
+          llvm::errs() << "CompleteExternalDeclaration";
+          break;
+#endif
+        case CallKind::AssignInheritanceModel:
+          llvm::errs() << "AssignInheritanceModel";
+          break;
+        case CallKind::HandleVTable:
+          llvm::errs() << "HandleVTable";
+          break;
+        case CallKind::InitializeSema:
+          llvm::errs() << "InitializeSema";
+          break;
+        };
+        for (const clang::Decl* D : DCI.m_DGR) {
+          llvm::errs() << " " << D;
+          if (const auto* ND = dyn_cast<clang::NamedDecl>(D))
+            llvm::errs() << " " << ND->getNameAsString();
+        }
+        llvm::errs() << "\n";
+      }
+
+      m_Multiplexer->PrintStats();
+    }
+
   } // end namespace plugin

 clad::CladTimerGroup::CladTimerGroup()
diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h
index 808443e49..c152a0482 100644
--- a/tools/ClangPlugin.h
+++ b/tools/ClangPlugin.h
@@ -13,11 +13,12 @@
 #include "clad/Differentiator/DiffPlanner.h"
 #include "clad/Differentiator/Version.h"

-#include "clang/AST/ASTConsumer.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/Basic/Version.h"
 #include "clang/Frontend/FrontendPluginRegistry.h"
+#include "clang/Frontend/MultiplexConsumer.h"
+#include "clang/Sema/SemaConsumer.h"

 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
@@ -74,42 +75,212 @@ namespace clad {
   namespace plugin {
     struct DifferentiationOptions {
-      DifferentiationOptions()
-          : DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false),
-            DumpDerivedAST(false), GenerateSourceFile(false),
-            ValidateClangVersion(true), EnableTBRAnalysis(false),
-            CustomEstimationModel(false), PrintNumDiffErrorInfo(false) {}
-
-      bool DumpSourceFn : 1;
-      bool DumpSourceFnAST : 1;
-      bool DumpDerivedFn : 1;
-      bool DumpDerivedAST : 1;
-      bool GenerateSourceFile : 1;
-      bool ValidateClangVersion : 1;
-      bool EnableTBRAnalysis : 1;
-      bool CustomEstimationModel : 1;
-      bool PrintNumDiffErrorInfo : 1;
-      std::string CustomModelName;
+      DifferentiationOptions()
+          : DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false),
+            DumpDerivedAST(false), GenerateSourceFile(false),
+            ValidateClangVersion(true), EnableTBRAnalysis(false),
+            DisableTBRAnalysis(false), CustomEstimationModel(false),
+            PrintNumDiffErrorInfo(false) {}
+
+      bool DumpSourceFn : 1;
+      bool DumpSourceFnAST : 1;
+      bool DumpDerivedFn : 1;
+      bool DumpDerivedAST : 1;
+      bool GenerateSourceFile : 1;
+      bool ValidateClangVersion : 1;
+      bool EnableTBRAnalysis : 1;
+      bool DisableTBRAnalysis : 1;
+      bool CustomEstimationModel : 1;
+      bool PrintNumDiffErrorInfo : 1;
+      std::string CustomModelName;
     };

-    class CladPlugin : public clang::ASTConsumer {
-      clang::CompilerInstance& m_CI;
-      DifferentiationOptions m_DO;
-      std::unique_ptr<DerivativeBuilder> m_DerivativeBuilder;
-      bool m_HasRuntime = false;
-      bool m_PendingInstantiationsInFlight = false;
-      bool m_HandleTopLevelDeclInternal = false;
-      CladTimerGroup m_CTG;
-      DerivedFnCollector m_DFC;
-    public:
-      CladPlugin(clang::CompilerInstance& CI, DifferentiationOptions& DO);
-      ~CladPlugin();
-      bool HandleTopLevelDecl(clang::DeclGroupRef DGR) override;
-      clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request);
+    class CladExternalSource : public clang::ExternalSemaSource {
+      // ExternalSemaSource
+      void ReadUndefinedButUsed(
+          llvm::MapVector<clang::NamedDecl*, clang::SourceLocation>& Undefined)
+          override {
+        // namespace { double f_darg0(double x); } will issue a warning that
+        // f_darg0 has internal linkage but is not defined. This is because we
+        // have not yet started to differentiate it. The warning is triggered by
+        // Sema::ActOnEndOfTranslationUnit before Clad is given control.
+        // To avoid the warning we should remove the entry from here.
+        using namespace clang;
+        Undefined.remove_if([](std::pair<NamedDecl*, SourceLocation> P) {
+          NamedDecl* ND = P.first;

-    private:
-      bool CheckBuiltins();
-      void ProcessTopLevelDecl(clang::Decl* D);
+          if (!ND->getDeclName().isIdentifier())
+            return false;
+
+          // FIXME: We should replace this comparison with the canonical decl
+          // from the differentiation plan...
+          llvm::StringRef Name = ND->getName();
+          return Name.contains("_darg") || Name.contains("_grad") ||
+                 Name.contains("_hessian") || Name.contains("_jacobian");
+        });
+      }
+    };
+    class CladPlugin : public clang::SemaConsumer {
+      clang::CompilerInstance& m_CI;
+      DifferentiationOptions m_DO;
+      std::unique_ptr<DerivativeBuilder> m_DerivativeBuilder;
+      bool m_HasRuntime = false;
+      CladTimerGroup m_CTG;
+      DerivedFnCollector m_DFC;
+      DiffSchedule m_DiffSchedule;
+      enum class CallKind {
+        HandleCXXStaticMemberVarInstantiation,
+        HandleTopLevelDecl,
+        HandleInlineFunctionDefinition,
+        HandleInterestingDecl,
+        HandleTagDeclDefinition,
+        HandleTagDeclRequiredDefinition,
+        HandleCXXImplicitFunctionInstantiation,
+        HandleTopLevelDeclInObjCContainer,
+        HandleImplicitImportDecl,
+        CompleteTentativeDefinition,
+#if CLANG_VERSION_MAJOR > 9
+        CompleteExternalDeclaration,
+#endif
+        AssignInheritanceModel,
+        HandleVTable,
+        InitializeSema,
+      };
+      struct DelayedCallInfo {
+        CallKind m_Kind;
+        clang::DeclGroupRef m_DGR;
+        DelayedCallInfo(CallKind K, clang::DeclGroupRef DGR)
+            : m_Kind(K), m_DGR(DGR) {}
+        DelayedCallInfo(CallKind K, const clang::Decl* D)
+            : m_Kind(K), m_DGR(const_cast<clang::Decl*>(D)) {}
+        bool operator==(const DelayedCallInfo& other) const {
+          if (m_Kind != other.m_Kind)
+            return false;
+
+          if (std::distance(m_DGR.begin(), m_DGR.end()) !=
+              std::distance(other.m_DGR.begin(), other.m_DGR.end()))
+            return false;
+
+          clang::Decl* const* first1 = m_DGR.begin();
+          clang::Decl* const* first2 = other.m_DGR.begin();
+          clang::Decl* const* last1 = m_DGR.end();
+          // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic)
+          for (; first1 != last1; ++first1, ++first2)
+            if (!(*first1 == *first2))
+              return false;
+          return true;
+        }
+      };
+      /// The calls to the main action which clad delayed and will dispatch at
+      /// then end of the translation unit.
+      std::vector<DelayedCallInfo> m_DelayedCalls;
+      /// The default clang consumers which are called after clad is done.
+      std::unique_ptr<clang::MultiplexConsumer> m_Multiplexer;
+
+      /// Have we processed all delayed calls.
+      bool m_HasMultiplexerProcessedDelayedCalls = false;
+
+      /// The Sema::TUScope to restore in CladPlugin::HandleTranslationUnit.
+      clang::Scope* m_StoredTUScope = nullptr;
+
+    public:
+      CladPlugin(clang::CompilerInstance& CI, DifferentiationOptions& DO);
+      ~CladPlugin() override;
+      // ASTConsumer
+      void Initialize(clang::ASTContext& Context) override;
+      void HandleCXXStaticMemberVarInstantiation(clang::VarDecl* D) override {
+        AppendDelayed({CallKind::HandleCXXStaticMemberVarInstantiation, D});
+      }
+      bool HandleTopLevelDecl(clang::DeclGroupRef D) override {
+        HandleTopLevelDeclForClad(D);
+        AppendDelayed({CallKind::HandleTopLevelDecl, D});
+        return true; // happyness, continue parsing
+      }
+      void HandleInlineFunctionDefinition(clang::FunctionDecl* D) override {
+        AppendDelayed({CallKind::HandleInlineFunctionDefinition, D});
+      }
+      void HandleInterestingDecl(clang::DeclGroupRef D) override {
+        AppendDelayed({CallKind::HandleInterestingDecl, D});
+      }
+      void HandleTagDeclDefinition(clang::TagDecl* D) override {
+        AppendDelayed({CallKind::HandleTagDeclDefinition, D});
+      }
+      void HandleTagDeclRequiredDefinition(const clang::TagDecl* D) override {
+        AppendDelayed({CallKind::HandleTagDeclRequiredDefinition, D});
+      }
+      void
+      HandleCXXImplicitFunctionInstantiation(clang::FunctionDecl* D) override {
+        AppendDelayed({CallKind::HandleCXXImplicitFunctionInstantiation, D});
+      }
+      void HandleTopLevelDeclInObjCContainer(clang::DeclGroupRef D) override {
+        AppendDelayed({CallKind::HandleTopLevelDeclInObjCContainer, D});
+      }
+      void HandleImplicitImportDecl(clang::ImportDecl* D) override {
+        AppendDelayed({CallKind::HandleImplicitImportDecl, D});
+      }
+      void CompleteTentativeDefinition(clang::VarDecl* D) override {
+        AppendDelayed({CallKind::CompleteTentativeDefinition, D});
+      }
+#if CLANG_VERSION_MAJOR > 9
+      void CompleteExternalDeclaration(clang::VarDecl* D) override {
+        AppendDelayed({CallKind::CompleteExternalDeclaration, D});
+      }
+#endif
+      void AssignInheritanceModel(clang::CXXRecordDecl* D) override {
+        AppendDelayed({CallKind::AssignInheritanceModel, D});
+      }
+      void HandleVTable(clang::CXXRecordDecl* D) override {
+        AppendDelayed({CallKind::HandleVTable, D});
+      }
+
+      // Not delayed.
+      void HandleTranslationUnit(clang::ASTContext& C) override;
+
+      // No need to handle the listeners, they will be handled non-delayed by
+      // the parent multiplexer.
+      //
+      // clang::ASTMutationListener *GetASTMutationListener() override;
+      // clang::ASTDeserializationListener *GetASTDeserializationListener()
+      // override;
+      void PrintStats() override;
+
+      bool shouldSkipFunctionBody(clang::Decl* D) override {
+        return m_Multiplexer->shouldSkipFunctionBody(D);
+      }
+
+      // SemaConsumer
+      void InitializeSema(clang::Sema& S) override {
+        // We are also a ExternalSemaSource.
+        // NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
+        S.addExternalSource(new CladExternalSource()); // Owned by Sema.
+        m_StoredTUScope = S.TUScope;
+        AppendDelayed({CallKind::InitializeSema, nullptr});
+      }
+      void ForgetSema() override {
+        // ForgetSema is called in the destructor of Sema which is much later
+        // than where we can process anything. We can't delay this call.
+        m_Multiplexer->ForgetSema();
+      }
+
+      // FIXME: We should hide ProcessDiffRequest when we implement proper
+      // handling of the differentiation plans.
+      clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request);
+
+    private:
+      void AppendDelayed(DelayedCallInfo DCI) {
+        assert(!m_HasMultiplexerProcessedDelayedCalls);
+        m_DelayedCalls.push_back(DCI);
+      }
+      void SendToMultiplexer();
+      bool CheckBuiltins();
+      void SetRequestOptions(RequestOptions& opts) const;
+
+      void ProcessTopLevelDecl(clang::Decl* D) {
+        DelayedCallInfo DCI{CallKind::HandleTopLevelDecl, D};
+        assert(!llvm::is_contained(m_DelayedCalls, DCI) && "Already exists!");
+        AppendDelayed(DCI);
+      }
+      void HandleTopLevelDeclForClad(clang::DeclGroupRef DGR);
     };

     clang::FunctionDecl* ProcessDiffRequest(CladPlugin& P,
@@ -146,6 +317,8 @@ namespace clad {
           m_DO.ValidateClangVersion = false;
         } else if (args[i] == "-enable-tbr") {
           m_DO.EnableTBRAnalysis = true;
+        } else if (args[i] == "-disable-tbr") {
+          m_DO.DisableTBRAnalysis = true;
         } else if (args[i] == "-fcustom-estimation-model") {
           m_DO.CustomEstimationModel = true;
           if (++i == e) {
@@ -170,6 +343,14 @@ namespace clad {
                  "derivative.\n"
               << "-fgenerate-source-file - Produces a file containing the "
                  "derivatives.\n"
+              << "-fno-validate-clang-version - Disables the validation of "
+                 "the clang version.\n"
+              << "-enable-tbr - Ensures that TBR analysis is enabled during "
+                 "reverse-mode differentiation unless explicitly specified "
+                 "in an individual request.\n"
+              << "-disable-tbr - Ensures that TBR analysis is disabled "
+                 "during reverse-mode differentiation unless explicitly "
+                 "specified in an individual request.\n"
               << "-fcustom-estimation-model - allows user to send in a "
                  "shared object to use as the custom estimation model.\n"
               << "-fprint-num-diff-errors - allows users to print the "
@@ -186,11 +367,16 @@ namespace clad {
           if (!checkClangVersion())
             return false;
         }
+        if (m_DO.EnableTBRAnalysis && m_DO.DisableTBRAnalysis) {
+          llvm::errs() << "clad: Error: -enable-tbr and -disable-tbr cannot "
+                          "be used together.\n";
+          return false;
+        }
         return true;
       }

       PluginASTAction::ActionType getActionType() override {
-        return AddBeforeMainAction;
+        return AddAfterMainAction;
       }
     };
   } // end namespace plugin