Skip to content

Commit

Permalink
Treat Enzyme Request to be different from Clad Request for same Function
Browse files Browse the repository at this point in the history
Prior to this commit, If a user differentiates the same function with both Clad and
Enzyme, only one derivative is generated, because both Derivatives get stored in the
same `DerivedFnInfo` object.

Hence if the derivative of a function was already found by using Enzyme or Clad, using
the other method would not succeed. This is because the new `DiffRequest` would be
taken as a repetition of the old one, and hence the old derivative would be returned.

This commit adds a new field `m_UsesEnzyme` to `DerivedFnInfo` that helps the program
differentiate between a Function derived with Enzyme and Clad respectively.
  • Loading branch information
Nirhar authored and vgvassilev committed Sep 1, 2022
1 parent d7fcd51 commit 5be9796
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 12 deletions.
38 changes: 38 additions & 0 deletions test/Enzyme/DifferentCladEnzymeDerivatives.C
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: %cladclang %s -I%S/../../include -oDifferentCladEnzymeDerivatives.out | FileCheck %s
// RUN: ./DifferentCladEnzymeDerivatives.out
// CHECK-NOT: {{.*error|warning|note:.*}}
// REQUIRES: Enzyme

#include "clad/Differentiator/Differentiator.h"


double foo(double x, double y){
return x*y;
}

// CHECK: void foo_grad(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: _t1 = x;
// CHECK-NEXT: _t0 = y;
// CHECK-NEXT: double foo_return = _t1 * _t0;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 1 * _t0;
// CHECK-NEXT: * _d_x += _r0;
// CHECK-NEXT: double _r1 = _t1 * 1;
// CHECK-NEXT: * _d_y += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void foo_grad_enzyme(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) {
// CHECK-NEXT: clad::EnzymeGradient<2> grad = __enzyme_autodiff_foo(foo, x, y);
// CHECK-NEXT: * _d_x = grad.d_arr[0U];
// CHECK-NEXT: * _d_y = grad.d_arr[1U];
// CHECK-NEXT: }

int main(){
auto grad = clad::gradient(foo);
auto gradEnzyme = clad::gradient<clad::opts::use_enzyme>(foo);
}
25 changes: 13 additions & 12 deletions tools/DerivedFnInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@
using namespace clang;

namespace clad {
DerivedFnInfo::DerivedFnInfo(const DiffRequest& request,
FunctionDecl* derivedFn,
FunctionDecl* overloadedDerivedFn)
: m_OriginalFn(request.Function), m_DerivedFn(derivedFn),
m_OverloadedDerivedFn(overloadedDerivedFn), m_Mode(request.Mode),
m_DerivativeOrder(request.CurrentDerivativeOrder),
m_DiffVarsInfo(request.DVI) {}
DerivedFnInfo::DerivedFnInfo(const DiffRequest& request,
FunctionDecl* derivedFn,
FunctionDecl* overloadedDerivedFn)
: m_OriginalFn(request.Function), m_DerivedFn(derivedFn),
m_OverloadedDerivedFn(overloadedDerivedFn), m_Mode(request.Mode),
m_DerivativeOrder(request.CurrentDerivativeOrder),
m_DiffVarsInfo(request.DVI), m_UsesEnzyme(request.use_enzyme) {}

bool DerivedFnInfo::SatisfiesRequest(const DiffRequest& request) const {
return (request.Function == m_OriginalFn && request.Mode == m_Mode &&
request.CurrentDerivativeOrder == m_DerivativeOrder &&
request.DVI == m_DiffVarsInfo);
bool DerivedFnInfo::SatisfiesRequest(const DiffRequest& request) const {
return (request.Function == m_OriginalFn && request.Mode == m_Mode &&
request.CurrentDerivativeOrder == m_DerivativeOrder &&
request.DVI == m_DiffVarsInfo && request.use_enzyme == m_UsesEnzyme);
}

bool DerivedFnInfo::IsValid() const { return m_OriginalFn && m_DerivedFn; }
Expand All @@ -26,6 +26,7 @@ namespace clad {
return lhs.m_OriginalFn == rhs.m_OriginalFn &&
lhs.m_DerivativeOrder == rhs.m_DerivativeOrder &&
lhs.m_Mode == rhs.m_Mode &&
lhs.m_DiffVarsInfo == rhs.m_DiffVarsInfo;
lhs.m_DiffVarsInfo == rhs.m_DiffVarsInfo &&
lhs.m_UsesEnzyme == rhs.m_UsesEnzyme;
}
} // namespace clad
1 change: 1 addition & 0 deletions tools/DerivedFnInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace clad {
DiffMode m_Mode = DiffMode::unknown;
unsigned m_DerivativeOrder = 0;
DiffInputVarsInfo m_DiffVarsInfo;
bool m_UsesEnzyme = false;

DerivedFnInfo() {}
DerivedFnInfo(const DiffRequest& request, clang::FunctionDecl* derivedFn,
Expand Down

0 comments on commit 5be9796

Please sign in to comment.