From 2734b033dad8dea317b8c766e117fb205228d164 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Fri, 27 Oct 2023 02:49:01 +0530 Subject: [PATCH] include algorithm --- .github/workflows/ci.yml | 732 +++++++++--------- .../clad/Differentiator/BuiltinDerivatives.h | 55 +- lib/Differentiator/BaseForwardModeVisitor.cpp | 18 +- lib/Differentiator/ReverseModeVisitor.cpp | 14 + test/FirstDerivative/FunctionCalls.C | 1 + 5 files changed, 429 insertions(+), 391 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0c1689bde..10d3f6292 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,377 +24,377 @@ jobs: matrix: include: - - name: osx-clang-runtime9 - os: macos-11 - compiler: clang - clang-runtime: '9' - - - name: osx-clang-runtime10 - os: macos-11 - compiler: clang - clang-runtime: '10' - - - name: osx-clang-runtime11 - os: macos-latest - compiler: clang - clang-runtime: '11' - - - name: osx-clang-runtime12 - os: macos-latest - compiler: clang - clang-runtime: '12' - - - name: osx-clang-runtime13 - os: macos-latest - compiler: clang - clang-runtime: '13' - - - name: osx-clang-runtime14 - os: macos-latest - compiler: clang - clang-runtime: '14' - - - name: osx-clang-runtime15 - os: macos-latest - compiler: clang - clang-runtime: '15' + # - name: osx-clang-runtime9 + # os: macos-11 + # compiler: clang + # clang-runtime: '9' + + # - name: osx-clang-runtime10 + # os: macos-11 + # compiler: clang + # clang-runtime: '10' + + # - name: osx-clang-runtime11 + # os: macos-latest + # compiler: clang + # clang-runtime: '11' + + # - name: osx-clang-runtime12 + # os: macos-latest + # compiler: clang + # clang-runtime: '12' + + # - name: osx-clang-runtime13 + # os: macos-latest + # compiler: clang + # clang-runtime: '13' + + # - name: osx-clang-runtime14 + # os: macos-latest + # compiler: clang + # clang-runtime: '14' + + # - name: osx-clang-runtime15 + # os: macos-latest + # compiler: clang + # clang-runtime: '15' - name: osx-clang-runtime16 os: macos-latest compiler: clang clang-runtime: '16' - - name: win-msvc-runtime14 - os: windows-latest - compiler: msvc - clang-runtime: '14' - - - name: win-msvc-runtime15 - os: windows-latest - compiler: msvc - clang-runtime: '15' - - - name: win-msvc-runtime16 - os: windows-latest - compiler: msvc - clang-runtime: '16' - - - name: ubu20-gcc7-runtime8 - os: ubuntu-20.04 - compiler: gcc-7 - clang-runtime: '8' - - - name: ubu20-gcc7-runtime11-analyzers - os: ubuntu-20.04 - compiler: gcc-7 - clang-runtime: '11' - coverage: true - cuda: true - extra_cmake_options: '-DLLVM_ENABLE_WERROR=On -DENABLE_ENZYME_BACKEND=On' - #clang-format: true - - - name: ubu20-gcc7-runtime11-benchmarks - os: ubuntu-20.04 - compiler: gcc-7 - clang-runtime: '11' - extra_cmake_options: '-DCLAD_ENABLE_BENCHMARKS=On -DENABLE_ENZYME_BACKEND=On' - benchmark: true - - - name: ubu20-gcc8-runtime11-coverity - os: ubuntu-20.04 - compiler: gcc-8 - clang-runtime: '11' - coverity: true - - - name: ubu20-gcc9-runtime9 - os: ubuntu-20.04 - compiler: gcc-9 - clang-runtime: '9' - - - name: ubu20-clang8-runtime9 - os: ubuntu-20.04 - compiler: clang-8 - clang-runtime: '9' - - - name: ubu20-clang9-runtime9 - os: ubuntu-20.04 - compiler: clang-9 - clang-runtime: '9' - - - name: ubu20-gcc7-runtime10 - os: ubuntu-20.04 - compiler: gcc-7 - clang-runtime: '10' - - - name: ubu20-gcc8-runtime10 - os: ubuntu-20.04 - compiler: gcc-8 - clang-runtime: '10' - - - name: ubu20-gcc9-runtime10 - os: ubuntu-20.04 - compiler: gcc-9 - clang-runtime: '10' - - - name: ubu20-clang8-runtime10 - os: ubuntu-20.04 - compiler: clang-8 - clang-runtime: '10' - - - name: ubu20-clang9-runtime10 - os: ubuntu-20.04 - compiler: clang-9 - clang-runtime: '10' - - - name: ubu20-gcc7-runtime11 - os: ubuntu-20.04 - compiler: gcc-7 - clang-runtime: '11' - - - name: ubu20-gcc8-runtime11 - os: ubuntu-20.04 - compiler: gcc-8 - clang-runtime: '11' - - - name: ubu20-gcc9-runtime11 - os: ubuntu-20.04 - compiler: gcc-9 - clang-runtime: '11' - - - name: ubu20-clang8-runtime11 - os: ubuntu-20.04 - compiler: clang-8 - clang-runtime: '11' - - - name: ubu20-clang9-runtime11 - os: ubuntu-20.04 - compiler: clang-9 - clang-runtime: '11' - - - name: ubu20-gcc7-runtime12 - os: ubuntu-20.04 - compiler: gcc-7 - clang-runtime: '12' - - - name: ubu20-gcc8-runtime12 - os: ubuntu-20.04 - compiler: gcc-8 - clang-runtime: '12' - - - name: ubu20-gcc9-runtime12 - os: ubuntu-20.04 - compiler: gcc-9 - clang-runtime: '12' - - name: ubu20-clang8-runtime12 - os: ubuntu-20.04 - compiler: clang-8 - clang-runtime: '12' - - - name: ubu20-clang9-runtime12 - os: ubuntu-20.04 - compiler: clang-9 - clang-runtime: '12' - - - name: ubu20-clang9-runtime10-cuda - os: ubuntu-20.04 - compiler: clang-9 - clang-runtime: '10' - cuda: true - - - name: ubu20-clang9-runtime15 - os: ubuntu-20.04 - compiler: clang-9 - clang-runtime: '15' - - - name: ubu20-clang9-runtime16 - os: ubuntu-20.04 - compiler: clang-9 - clang-runtime: '16' - - - name: ubu22-gcc9-runtime11 - os: ubuntu-22.04 - compiler: gcc-9 - clang-runtime: '11' - - - name: ubu22-gcc10-runtime11 - os: ubuntu-22.04 - compiler: gcc-10 - clang-runtime: '11' - - - name: ubu22-gcc11-runtime11 - os: ubuntu-22.04 - compiler: gcc-11 - clang-runtime: '11' - - - name: ubu22-clang11-runtime11 - os: ubuntu-22.04 - compiler: 'clang-11' - clang-runtime: '11' - - - name: ubu22-clang12-runtime11 - os: ubuntu-22.04 - compiler: clang-12 - clang-runtime: '11' - - - name: ubu22-clang13-runtime11 - os: ubuntu-22.04 - compiler: clang-13 - clang-runtime: '11' - - - name: ubu22-clang14-runtime11 - os: ubuntu-22.04 - compiler: clang-14 - clang-runtime: '11' - - - name: ubu22-clang15-runtime11 - os: ubuntu-22.04 - compiler: clang-15 - clang-runtime: '11' - - - name: ubu22-gcc9-runtime12 - os: ubuntu-22.04 - compiler: gcc-9 - clang-runtime: '12' - - - name: ubu22-gcc10-runtime12 - os: ubuntu-22.04 - compiler: gcc-10 - clang-runtime: '12' - - - name: ubu22-gcc11-runtime12 - os: ubuntu-22.04 - compiler: gcc-11 - clang-runtime: '12' - - - name: ubu22-clang11-runtime12 - os: ubuntu-22.04 - compiler: 'clang-11' - clang-runtime: '12' - - - name: ubu22-clang12-runtime12 - os: ubuntu-22.04 - compiler: clang-12 - clang-runtime: '12' - - - name: ubu22-clang13-runtime12 - os: ubuntu-22.04 - compiler: clang-13 - clang-runtime: '12' - - - name: ubu22-clang14-runtime12 - os: ubuntu-22.04 - compiler: clang-14 - clang-runtime: '12' - - - name: ubu22-clang15-runtime12 - os: ubuntu-22.04 - compiler: clang-15 - clang-runtime: '12' - - - name: ubu22-gcc10-runtime13 - os: ubuntu-22.04 - compiler: gcc-10 - clang-runtime: '13' - - - name: ubu22-gcc11-runtime13 - os: ubuntu-22.04 - compiler: gcc-11 - clang-runtime: '13' - - - name: ubu22-gcc12-runtime13 - os: ubuntu-22.04 - compiler: gcc-12 - clang-runtime: '13' - - - name: ubu22-clang11-runtime13 - os: ubuntu-22.04 - compiler: clang-11 - clang-runtime: '13' - - - name: ubu22-clang12-runtime13 - os: ubuntu-22.04 - compiler: clang-12 - clang-runtime: '13' - - - name: ubu22-clang13-runtime13 - os: ubuntu-22.04 - compiler: clang-13 - clang-runtime: '13' - - - name: ubu22-clang14-runtime13 - os: ubuntu-22.04 - compiler: clang-14 - clang-runtime: '13' - - - name: ubu22-clang15-runtime13 - os: ubuntu-22.04 - compiler: clang-15 - clang-runtime: '13' - - - name: ubu22-gcc10-runtime14 - os: ubuntu-22.04 - compiler: gcc-10 - clang-runtime: '14' - - - name: ubu22-gcc11-runtime14 - os: ubuntu-22.04 - compiler: gcc-11 - clang-runtime: '14' - - - name: ubu22-gcc12-runtime14 - os: ubuntu-22.04 - compiler: gcc-12 - clang-runtime: '14' - - - name: ubu22-gcc12-runtime15 - os: ubuntu-22.04 - compiler: gcc-12 - clang-runtime: '15' - - - name: ubu22-gcc12-runtime16 - os: ubuntu-22.04 - compiler: gcc-12 - clang-runtime: '16' - - - name: ubu22-clang11-runtime14 - os: ubuntu-22.04 - compiler: clang-11 - clang-runtime: '14' - - - name: ubu22-clang12-runtime14 - os: ubuntu-22.04 - compiler: clang-12 - clang-runtime: '14' - - - name: ubu22-clang13-runtime14 - os: ubuntu-22.04 - compiler: clang-13 - clang-runtime: '14' - - - name: ubu22-clang14-runtime14 - os: ubuntu-22.04 - compiler: clang-14 - clang-runtime: '14' - - - name: ubu22-clang14-runtime15 - os: ubuntu-22.04 - compiler: clang-14 - clang-runtime: '15' - - - name: ubu22-clang15-runtime14 - os: ubuntu-22.04 - compiler: clang-15 - clang-runtime: '14' - - - name: ubu22-clang15-runtime15 - os: ubuntu-22.04 - compiler: clang-15 - clang-runtime: '15' - - - name: ubu22-clang15-runtime16 - os: ubuntu-22.04 - compiler: clang-15 - clang-runtime: '16' + # - name: win-msvc-runtime14 + # os: windows-latest + # compiler: msvc + # clang-runtime: '14' + + # - name: win-msvc-runtime15 + # os: windows-latest + # compiler: msvc + # clang-runtime: '15' + + # - name: win-msvc-runtime16 + # os: windows-latest + # compiler: msvc + # clang-runtime: '16' + + # - name: ubu20-gcc7-runtime8 + # os: ubuntu-20.04 + # compiler: gcc-7 + # clang-runtime: '8' + + # - name: ubu20-gcc7-runtime11-analyzers + # os: ubuntu-20.04 + # compiler: gcc-7 + # clang-runtime: '11' + # coverage: true + # cuda: true + # extra_cmake_options: '-DLLVM_ENABLE_WERROR=On -DENABLE_ENZYME_BACKEND=On' + # #clang-format: true + + # - name: ubu20-gcc7-runtime11-benchmarks + # os: ubuntu-20.04 + # compiler: gcc-7 + # clang-runtime: '11' + # extra_cmake_options: '-DCLAD_ENABLE_BENCHMARKS=On -DENABLE_ENZYME_BACKEND=On' + # benchmark: true + + # - name: ubu20-gcc8-runtime11-coverity + # os: ubuntu-20.04 + # compiler: gcc-8 + # clang-runtime: '11' + # coverity: true + + # - name: ubu20-gcc9-runtime9 + # os: ubuntu-20.04 + # compiler: gcc-9 + # clang-runtime: '9' + + # - name: ubu20-clang8-runtime9 + # os: ubuntu-20.04 + # compiler: clang-8 + # clang-runtime: '9' + + # - name: ubu20-clang9-runtime9 + # os: ubuntu-20.04 + # compiler: clang-9 + # clang-runtime: '9' + + # - name: ubu20-gcc7-runtime10 + # os: ubuntu-20.04 + # compiler: gcc-7 + # clang-runtime: '10' + + # - name: ubu20-gcc8-runtime10 + # os: ubuntu-20.04 + # compiler: gcc-8 + # clang-runtime: '10' + + # - name: ubu20-gcc9-runtime10 + # os: ubuntu-20.04 + # compiler: gcc-9 + # clang-runtime: '10' + + # - name: ubu20-clang8-runtime10 + # os: ubuntu-20.04 + # compiler: clang-8 + # clang-runtime: '10' + + # - name: ubu20-clang9-runtime10 + # os: ubuntu-20.04 + # compiler: clang-9 + # clang-runtime: '10' + + # - name: ubu20-gcc7-runtime11 + # os: ubuntu-20.04 + # compiler: gcc-7 + # clang-runtime: '11' + + # - name: ubu20-gcc8-runtime11 + # os: ubuntu-20.04 + # compiler: gcc-8 + # clang-runtime: '11' + + # - name: ubu20-gcc9-runtime11 + # os: ubuntu-20.04 + # compiler: gcc-9 + # clang-runtime: '11' + + # - name: ubu20-clang8-runtime11 + # os: ubuntu-20.04 + # compiler: clang-8 + # clang-runtime: '11' + + # - name: ubu20-clang9-runtime11 + # os: ubuntu-20.04 + # compiler: clang-9 + # clang-runtime: '11' + + # - name: ubu20-gcc7-runtime12 + # os: ubuntu-20.04 + # compiler: gcc-7 + # clang-runtime: '12' + + # - name: ubu20-gcc8-runtime12 + # os: ubuntu-20.04 + # compiler: gcc-8 + # clang-runtime: '12' + + # - name: ubu20-gcc9-runtime12 + # os: ubuntu-20.04 + # compiler: gcc-9 + # clang-runtime: '12' + # - name: ubu20-clang8-runtime12 + # os: ubuntu-20.04 + # compiler: clang-8 + # clang-runtime: '12' + + # - name: ubu20-clang9-runtime12 + # os: ubuntu-20.04 + # compiler: clang-9 + # clang-runtime: '12' + + # - name: ubu20-clang9-runtime10-cuda + # os: ubuntu-20.04 + # compiler: clang-9 + # clang-runtime: '10' + # cuda: true + + # - name: ubu20-clang9-runtime15 + # os: ubuntu-20.04 + # compiler: clang-9 + # clang-runtime: '15' + + # - name: ubu20-clang9-runtime16 + # os: ubuntu-20.04 + # compiler: clang-9 + # clang-runtime: '16' + + # - name: ubu22-gcc9-runtime11 + # os: ubuntu-22.04 + # compiler: gcc-9 + # clang-runtime: '11' + + # - name: ubu22-gcc10-runtime11 + # os: ubuntu-22.04 + # compiler: gcc-10 + # clang-runtime: '11' + + # - name: ubu22-gcc11-runtime11 + # os: ubuntu-22.04 + # compiler: gcc-11 + # clang-runtime: '11' + + # - name: ubu22-clang11-runtime11 + # os: ubuntu-22.04 + # compiler: 'clang-11' + # clang-runtime: '11' + + # - name: ubu22-clang12-runtime11 + # os: ubuntu-22.04 + # compiler: clang-12 + # clang-runtime: '11' + + # - name: ubu22-clang13-runtime11 + # os: ubuntu-22.04 + # compiler: clang-13 + # clang-runtime: '11' + + # - name: ubu22-clang14-runtime11 + # os: ubuntu-22.04 + # compiler: clang-14 + # clang-runtime: '11' + + # - name: ubu22-clang15-runtime11 + # os: ubuntu-22.04 + # compiler: clang-15 + # clang-runtime: '11' + + # - name: ubu22-gcc9-runtime12 + # os: ubuntu-22.04 + # compiler: gcc-9 + # clang-runtime: '12' + + # - name: ubu22-gcc10-runtime12 + # os: ubuntu-22.04 + # compiler: gcc-10 + # clang-runtime: '12' + + # - name: ubu22-gcc11-runtime12 + # os: ubuntu-22.04 + # compiler: gcc-11 + # clang-runtime: '12' + + # - name: ubu22-clang11-runtime12 + # os: ubuntu-22.04 + # compiler: 'clang-11' + # clang-runtime: '12' + + # - name: ubu22-clang12-runtime12 + # os: ubuntu-22.04 + # compiler: clang-12 + # clang-runtime: '12' + + # - name: ubu22-clang13-runtime12 + # os: ubuntu-22.04 + # compiler: clang-13 + # clang-runtime: '12' + + # - name: ubu22-clang14-runtime12 + # os: ubuntu-22.04 + # compiler: clang-14 + # clang-runtime: '12' + + # - name: ubu22-clang15-runtime12 + # os: ubuntu-22.04 + # compiler: clang-15 + # clang-runtime: '12' + + # - name: ubu22-gcc10-runtime13 + # os: ubuntu-22.04 + # compiler: gcc-10 + # clang-runtime: '13' + + # - name: ubu22-gcc11-runtime13 + # os: ubuntu-22.04 + # compiler: gcc-11 + # clang-runtime: '13' + + # - name: ubu22-gcc12-runtime13 + # os: ubuntu-22.04 + # compiler: gcc-12 + # clang-runtime: '13' + + # - name: ubu22-clang11-runtime13 + # os: ubuntu-22.04 + # compiler: clang-11 + # clang-runtime: '13' + + # - name: ubu22-clang12-runtime13 + # os: ubuntu-22.04 + # compiler: clang-12 + # clang-runtime: '13' + + # - name: ubu22-clang13-runtime13 + # os: ubuntu-22.04 + # compiler: clang-13 + # clang-runtime: '13' + + # - name: ubu22-clang14-runtime13 + # os: ubuntu-22.04 + # compiler: clang-14 + # clang-runtime: '13' + + # - name: ubu22-clang15-runtime13 + # os: ubuntu-22.04 + # compiler: clang-15 + # clang-runtime: '13' + + # - name: ubu22-gcc10-runtime14 + # os: ubuntu-22.04 + # compiler: gcc-10 + # clang-runtime: '14' + + # - name: ubu22-gcc11-runtime14 + # os: ubuntu-22.04 + # compiler: gcc-11 + # clang-runtime: '14' + + # - name: ubu22-gcc12-runtime14 + # os: ubuntu-22.04 + # compiler: gcc-12 + # clang-runtime: '14' + + # - name: ubu22-gcc12-runtime15 + # os: ubuntu-22.04 + # compiler: gcc-12 + # clang-runtime: '15' + + # - name: ubu22-gcc12-runtime16 + # os: ubuntu-22.04 + # compiler: gcc-12 + # clang-runtime: '16' + + # - name: ubu22-clang11-runtime14 + # os: ubuntu-22.04 + # compiler: clang-11 + # clang-runtime: '14' + + # - name: ubu22-clang12-runtime14 + # os: ubuntu-22.04 + # compiler: clang-12 + # clang-runtime: '14' + + # - name: ubu22-clang13-runtime14 + # os: ubuntu-22.04 + # compiler: clang-13 + # clang-runtime: '14' + + # - name: ubu22-clang14-runtime14 + # os: ubuntu-22.04 + # compiler: clang-14 + # clang-runtime: '14' + + # - name: ubu22-clang14-runtime15 + # os: ubuntu-22.04 + # compiler: clang-14 + # clang-runtime: '15' + + # - name: ubu22-clang15-runtime14 + # os: ubuntu-22.04 + # compiler: clang-15 + # clang-runtime: '14' + + # - name: ubu22-clang15-runtime15 + # os: ubuntu-22.04 + # compiler: clang-15 + # clang-runtime: '15' + + # - name: ubu22-clang15-runtime16 + # os: ubuntu-22.04 + # compiler: clang-15 + # clang-runtime: '16' steps: - uses: actions/checkout@v3 @@ -748,7 +748,7 @@ jobs: if: ${{ failure() }} uses: mxschmitt/action-tmate@v3 # When debugging increase to a suitable value! - timeout-minutes: ${{ github.event.pull_request && 1 || 20 }} + timeout-minutes: ${{ github.event.pull_request && 30 || 20 }} - name: Prepare code coverage report if: ${{ success() && (matrix.coverage == true) }} run: | diff --git a/include/clad/Differentiator/BuiltinDerivatives.h b/include/clad/Differentiator/BuiltinDerivatives.h index d7dd9d859..f3172b53f 100644 --- a/include/clad/Differentiator/BuiltinDerivatives.h +++ b/include/clad/Differentiator/BuiltinDerivatives.h @@ -13,6 +13,7 @@ namespace custom_derivatives{} #include "clad/Differentiator/ArrayRef.h" #include "clad/Differentiator/CladConfig.h" +#include #include namespace clad { @@ -139,32 +140,34 @@ CUDA_HOST_DEVICE void fma_pullback(T1 a, T2 b, T3 c, T4 d_y, *d_c += d_y; } -template -CUDA_HOST_DEVICE ValueAndPushforward -min_pushforward(const T& a, const T& b, const T& d_a, const T& d_b) { +template +CUDA_HOST_DEVICE ValueAndPushforward +min_pushforward(const T1& a, const T2& b, const T1& d_a, const T2& d_b) { return {::std::min(a, b), a < b ? d_a : d_b}; } -template -CUDA_HOST_DEVICE ValueAndPushforward -max_pushforward(const T& a, const T& b, const T& d_a, const T& d_b) { +template +CUDA_HOST_DEVICE ValueAndPushforward +max_pushforward(const T1& a, const T2& b, const T1& d_a, const T2& d_b) { return {::std::max(a, b), a < b ? d_b : d_a}; } -template -CUDA_HOST_DEVICE void min_pullback(const T& a, const T& b, U d_y, - clad::array_ref d_a, - clad::array_ref d_b) { +template +CUDA_HOST_DEVICE void min_pullback(const T1& a, const T2& b, U d_y, + clad::array_ref d_a, + clad::array_ref d_b) { if (a < b) *d_a += d_y; else *d_b += d_y; } -template -CUDA_HOST_DEVICE void max_pullback(const T& a, const T& b, U d_y, - clad::array_ref d_a, - clad::array_ref d_b) { +template +CUDA_HOST_DEVICE void max_pullback(const T1& a, const T2& b, U d_y, + clad::array_ref d_a, + clad::array_ref d_b) { if (a < b) *d_b += d_y; else @@ -172,19 +175,20 @@ CUDA_HOST_DEVICE void max_pullback(const T& a, const T& b, U d_y, } #if __cplusplus >= 201703L -template -CUDA_HOST_DEVICE ValueAndPushforward -clamp_pushforward(const T& v, const T& lo, const T& hi, const T& d_v, - const T& d_lo, const T& d_hi) { +template +CUDA_HOST_DEVICE ValueAndPushforward +clamp_pushforward(const T1& v, const T2& lo, const T3& hi, const T1& d_v, + const T2& d_lo, const T3& d_hi) { return {::std::clamp(v, lo, hi), v < lo ? d_lo : hi < v ? d_hi : d_v}; } -template -CUDA_HOST_DEVICE void clamp_pullback(const T& v, const T& lo, const T& hi, +template +CUDA_HOST_DEVICE void clamp_pullback(const T1& v, const T2& lo, const T3& hi, const U& d_y, - clad::array_ref d_v, - clad::array_ref d_lo, - clad::array_ref d_hi) { + clad::array_ref d_v, + clad::array_ref d_lo, + clad::array_ref d_hi) { if (v < lo) *d_lo += d_y; else if (hi < v) @@ -213,6 +217,11 @@ using std::pow_pullback; using std::pow_pushforward; using std::sin_pushforward; using std::sqrt_pushforward; + +#if __cplusplus >= 201703L +using std::clamp_pullback; +using std::clamp_pushforward; +#endif } // namespace custom_derivatives } // namespace clad diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index dac3e77d5..7cf6261ed 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -866,21 +866,29 @@ Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff( // numerical diff to use correct declaration context. if (forCustomDerv) { DeclContext* outermostDC = utils::GetOutermostDC(m_Sema, originalFnDC); + // llvm :: outs () << "outermostDC: " << outermostDC->getPrimaryContext()->getDeclKindName() << "\n"; + // print name of decl context + // llvm :: outs () << "originalFnDC: " << originalFnDC->getPrimaryContext()->getDeclKindName() << "\n"; + // llvm :: outs () << "NSD: " << NSD->getPrimaryContext()->getDeclKindName() << "\n"; // FIXME: We should ideally construct nested name specifier from the // found custom derivative function. Current way will compute incorrect // nested name specifier in some cases. if (outermostDC && outermostDC->getPrimaryContext() == NSD->getPrimaryContext()) { + llvm :: outs () << "reached here 1: \n"; utils::BuildNNS(m_Sema, originalFnDC, SS); DC = originalFnDC; } else { + llvm :: outs () << "reached here 2: \n"; if (isa(originalFnDC)) DC = utils::LookupNSD(m_Sema, "class_functions", /*shouldExist=*/false, NSD); else DC = utils::FindDeclContext(m_Sema, NSD, originalFnDC); - if (DC) + if (DC) { + llvm :: outs () << "reached here 2.1: \n"; utils::BuildNNS(m_Sema, DC, SS); + } } } else { SS.Extend(m_Context, NSD, noLoc, noLoc); @@ -890,10 +898,13 @@ Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff( DeclarationNameInfo DNInfo(name, utils::GetValidSLoc(m_Sema)); LookupResult R(m_Sema, DNInfo, Sema::LookupOrdinaryName); - if (DC) + if (DC) { + llvm :: outs () << "reached here 3: \n"; m_Sema.LookupQualifiedName(R, DC); + } Expr* OverloadedFn = 0; if (!R.empty()) { + llvm :: outs () << "reached here 4: \n"; // FIXME: We should find a way to specify nested name specifier // after finding the custom derivative. Expr* UnresolvedLookup = @@ -905,12 +916,15 @@ Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff( SourceLocation Loc; if (noOverloadExists(UnresolvedLookup, MARargs)) { + llvm :: outs () << "reached here 5: \n"; return 0; } + llvm :: outs () << "reached here 6: \n"; OverloadedFn = m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get(); } + llvm :: outs () << "reached here 7: \n"; return OverloadedFn; } diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 4e5d7020e..639d7c064 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1768,6 +1768,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, BuildOp(UnaryOperatorKind::UO_AddrOf, baseDiff.getExpr())); std::string customPullback = clad::utils::ComputeEffectiveFnName(FD) + "_pullback"; + if (customPullback.find("max") != std::string::npos || + customPullback.find("min") != std::string::npos || + customPullback.find("clamp") != std::string::npos) { + // print signature of FD + llvm ::outs() << "Function signature: \n"; + FD->print(llvm::outs(), 0, false); + // print the name of pullback and its arguments + llvm ::outs() << "Pullback function name: " << customPullback << "\n"; + for (auto arg : pullbackCallArgs) { + llvm ::outs() << "Pullback function argument: "; + arg->printPrettyControlled(llvm ::outs(), nullptr, + m_Context.getPrintingPolicy(), 0); + } + } OverloadedDerivedFn = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPullback, pullbackCallArgs, getCurrentScope(), diff --git a/test/FirstDerivative/FunctionCalls.C b/test/FirstDerivative/FunctionCalls.C index 276d44fca..639eb69bf 100644 --- a/test/FirstDerivative/FunctionCalls.C +++ b/test/FirstDerivative/FunctionCalls.C @@ -4,6 +4,7 @@ #include "clad/Differentiator/Differentiator.h" +#include #include int printf(const char* fmt, ...);