diff --git a/test/Enzyme/DifferentCladEnzymeDerivatives.C b/test/Enzyme/DifferentCladEnzymeDerivatives.C new file mode 100644 index 000000000..921bf306a --- /dev/null +++ b/test/Enzyme/DifferentCladEnzymeDerivatives.C @@ -0,0 +1,38 @@ +// RUN: %cladclang %s -I%S/../../include -oDifferentCladEnzymeDerivatives.out | FileCheck %s +// RUN: ./DifferentCladEnzymeDerivatives.out +// CHECK-NOT: {{.*error|warning|note:.*}} +// REQUIRES: Enzyme + +#include "clad/Differentiator/Differentiator.h" + + +double foo(double x, double y){ + return x*y; +} + +// CHECK: void foo_grad(double x, double y, clad::array_ref _d_x, clad::array_ref _d_y) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double _t1; +// CHECK-NEXT: _t1 = x; +// CHECK-NEXT: _t0 = y; +// CHECK-NEXT: double foo_return = _t1 * _t0; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 1 * _t0; +// CHECK-NEXT: * _d_x += _r0; +// CHECK-NEXT: double _r1 = _t1 * 1; +// CHECK-NEXT: * _d_y += _r1; +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: void foo_grad_enzyme(double x, double y, clad::array_ref _d_x, clad::array_ref _d_y) { +// CHECK-NEXT: clad::EnzymeGradient<2> grad = __enzyme_autodiff_foo(foo, x, y); +// CHECK-NEXT: * _d_x = grad.d_arr[0U]; +// CHECK-NEXT: * _d_y = grad.d_arr[1U]; +// CHECK-NEXT: } + +int main(){ + auto grad = clad::gradient(foo); + auto gradEnzyme = clad::gradient(foo); +} \ No newline at end of file diff --git a/tools/DerivedFnInfo.cpp b/tools/DerivedFnInfo.cpp index d0f8251a4..a5c708dfa 100644 --- a/tools/DerivedFnInfo.cpp +++ b/tools/DerivedFnInfo.cpp @@ -5,18 +5,18 @@ using namespace clang; namespace clad { - DerivedFnInfo::DerivedFnInfo(const DiffRequest& request, - FunctionDecl* derivedFn, - FunctionDecl* overloadedDerivedFn) - : m_OriginalFn(request.Function), m_DerivedFn(derivedFn), - m_OverloadedDerivedFn(overloadedDerivedFn), m_Mode(request.Mode), - m_DerivativeOrder(request.CurrentDerivativeOrder), - m_DiffVarsInfo(request.DVI) {} +DerivedFnInfo::DerivedFnInfo(const DiffRequest& request, + FunctionDecl* derivedFn, + FunctionDecl* overloadedDerivedFn) + : m_OriginalFn(request.Function), m_DerivedFn(derivedFn), + m_OverloadedDerivedFn(overloadedDerivedFn), m_Mode(request.Mode), + m_DerivativeOrder(request.CurrentDerivativeOrder), + m_DiffVarsInfo(request.DVI), m_UsesEnzyme(request.use_enzyme) {} - bool DerivedFnInfo::SatisfiesRequest(const DiffRequest& request) const { - return (request.Function == m_OriginalFn && request.Mode == m_Mode && - request.CurrentDerivativeOrder == m_DerivativeOrder && - request.DVI == m_DiffVarsInfo); +bool DerivedFnInfo::SatisfiesRequest(const DiffRequest& request) const { + return (request.Function == m_OriginalFn && request.Mode == m_Mode && + request.CurrentDerivativeOrder == m_DerivativeOrder && + request.DVI == m_DiffVarsInfo && request.use_enzyme == m_UsesEnzyme); } bool DerivedFnInfo::IsValid() const { return m_OriginalFn && m_DerivedFn; } @@ -26,6 +26,7 @@ namespace clad { return lhs.m_OriginalFn == rhs.m_OriginalFn && lhs.m_DerivativeOrder == rhs.m_DerivativeOrder && lhs.m_Mode == rhs.m_Mode && - lhs.m_DiffVarsInfo == rhs.m_DiffVarsInfo; + lhs.m_DiffVarsInfo == rhs.m_DiffVarsInfo && + lhs.m_UsesEnzyme == rhs.m_UsesEnzyme; } } // namespace clad \ No newline at end of file diff --git a/tools/DerivedFnInfo.h b/tools/DerivedFnInfo.h index 5b45201fc..03d1a6763 100644 --- a/tools/DerivedFnInfo.h +++ b/tools/DerivedFnInfo.h @@ -17,6 +17,7 @@ namespace clad { DiffMode m_Mode = DiffMode::unknown; unsigned m_DerivativeOrder = 0; DiffInputVarsInfo m_DiffVarsInfo; + bool m_UsesEnzyme = false; DerivedFnInfo() {} DerivedFnInfo(const DiffRequest& request, clang::FunctionDecl* derivedFn,