-
Notifications
You must be signed in to change notification settings - Fork 122
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Treat Enzyme Request to be different from Clad Request for same Function
Prior to this commit, If a user differentiates the same function with both Clad and Enzyme, only one derivative is generated, because both Derivatives get stored in the same `DerivedFnInfo` object. Hence if the derivative of a function was already found by using Enzyme or Clad, using the other method would not succeed. This is because the new `DiffRequest` would be taken as a repetition of the old one, and hence the old derivative would be returned. This commit adds a new field `m_UsesEnzyme` to `DerivedFnInfo` that helps the program differentiate between a Function derived with Enzyme and Clad respectively.
- Loading branch information
1 parent
d7fcd51
commit 5be9796
Showing
3 changed files
with
52 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
// RUN: %cladclang %s -I%S/../../include -oDifferentCladEnzymeDerivatives.out | FileCheck %s | ||
// RUN: ./DifferentCladEnzymeDerivatives.out | ||
// CHECK-NOT: {{.*error|warning|note:.*}} | ||
// REQUIRES: Enzyme | ||
|
||
#include "clad/Differentiator/Differentiator.h" | ||
|
||
|
||
double foo(double x, double y){ | ||
return x*y; | ||
} | ||
|
||
// CHECK: void foo_grad(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) { | ||
// CHECK-NEXT: double _t0; | ||
// CHECK-NEXT: double _t1; | ||
// CHECK-NEXT: _t1 = x; | ||
// CHECK-NEXT: _t0 = y; | ||
// CHECK-NEXT: double foo_return = _t1 * _t0; | ||
// CHECK-NEXT: goto _label0; | ||
// CHECK-NEXT: _label0: | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: double _r0 = 1 * _t0; | ||
// CHECK-NEXT: * _d_x += _r0; | ||
// CHECK-NEXT: double _r1 = _t1 * 1; | ||
// CHECK-NEXT: * _d_y += _r1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void foo_grad_enzyme(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) { | ||
// CHECK-NEXT: clad::EnzymeGradient<2> grad = __enzyme_autodiff_foo(foo, x, y); | ||
// CHECK-NEXT: * _d_x = grad.d_arr[0U]; | ||
// CHECK-NEXT: * _d_y = grad.d_arr[1U]; | ||
// CHECK-NEXT: } | ||
|
||
int main(){ | ||
auto grad = clad::gradient(foo); | ||
auto gradEnzyme = clad::gradient<clad::opts::use_enzyme>(foo); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters