-
Notifications
You must be signed in to change notification settings - Fork 122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for verifying Enzyme Gradients with Clad Gradients #488
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -450,6 +450,22 @@ namespace clad { | |
code); | ||
} | ||
|
||
/// Asserts that two gradient values agree to within a relative tolerance.
///
/// The check passes when |a - b| <= epsilon * min(|a|, |b|), i.e. the
/// difference is small relative to the smaller of the two magnitudes.
/// Two exact zeros compare equal; a NaN in either argument fails the check
/// because every comparison involving NaN is false.
///
/// \param[in] a first value (the generated code passes the Clad result).
/// \param[in] b second value (the generated code passes the Enzyme result).
void EssentiallyEqual(double a, double b) {
  // FIXME: We should select epsilon value in a more robust way.
  const double epsilon = 1e-12;

  // Scale the tolerance by the smaller magnitude. The magnitudes themselves
  // must be compared: comparing the signed values would pick the wrong
  // operand whenever the inputs are negative or of mixed sign.
  const double magA = std::fabs(a);
  const double magB = std::fabs(b);
  const double smallerMagnitude = magA > magB ? magB : magA;
  const bool ans = std::fabs(a - b) <= smallerMagnitude * epsilon;

  // TODO(review): a reviewer asked whether this could report a non-fatal
  // error instead of asserting.
  assert(ans && "Clad Gradient is not equal to Enzyme Gradient");
  (void)ans; // Keeps builds with NDEBUG free of an unused-variable warning.
}
void EssentiallyEqualArrays(double* a, double* b, unsigned size) { | ||
// FIXME: We should select epsilon value in a more robust way. | ||
for (int i = 0; i < size; i++) { | ||
EssentiallyEqual(a[i], b[i]); | ||
} | ||
} | ||
|
||
/// Aggregate that carries the gradient values produced by a reverse-mode
/// Enzyme call; `d_arr` holds `N` doubles.
template <unsigned N>
struct EnzymeGradient {
  double d_arr[N];
};
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -268,6 +268,10 @@ namespace clad { | |
if (request.use_enzyme) | ||
use_enzyme = true; | ||
|
||
if (request.checkEnzymeWithClad) { | ||
checkEnzymeWithClad = true; | ||
} | ||
|
||
auto derivativeBaseName = request.BaseFunctionName; | ||
std::string gradientName = derivativeBaseName + funcPostfix(); | ||
// To be consistent with older tests, nothing is appended to 'f_grad' if | ||
|
@@ -413,6 +417,16 @@ namespace clad { | |
else | ||
DifferentiateWithEnzyme(); | ||
|
||
if (use_enzyme && checkEnzymeWithClad) { | ||
DiffRequest newRequest = const_cast<DiffRequest&>(request); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do you need `const_cast` here? |
||
newRequest.checkEnzymeWithClad = false; | ||
newRequest.use_enzyme = false; | ||
FunctionDecl* cladFD = | ||
plugin::ProcessDiffRequest(m_CladPlugin, newRequest); | ||
|
||
CheckEnzymeResultsWithClad(cladFD); | ||
} | ||
|
||
gradientBody = endBlock(); | ||
m_Derivative->setBody(gradientBody); | ||
endScope(); // Function body scope | ||
|
@@ -707,6 +721,101 @@ namespace clad { | |
addToCurrentBlock(enzymeCall); | ||
} | ||
} | ||
/// Appends statements to the derivative under construction that recompute
/// the gradient with Clad and compare it with the values already stored in
/// the derivative's output parameters.
///
/// Three groups of statements are emitted:
///  1. one zero-initialised `cladResult<k>` local per original parameter
///     (non-floating scalars are left uninitialised),
///  2. a call to \p cladFD with the original arguments followed by those
///     locals (arrays by name, scalars by address),
///  3. a call to `EssentiallyEqual` for each floating-point parameter and to
///     `EssentiallyEqualArrays` for each constant-array parameter.
///
/// \param[in] cladFD the Clad-generated gradient function to call.
void ReverseModeVisitor::CheckEnzymeResultsWithClad(FunctionDecl* cladFD) {
  // Prepare Arguments for the clad derivative function
  llvm::SmallVector<Expr*, 16> cladGradArgs;
  llvm::SmallVector<VarDecl*, 16> cladResultDecls;
  const unsigned numParams = m_Function->getNumParams();
  llvm::ArrayRef<ParmVarDecl*> paramsRef = m_Derivative->parameters();

  // The first numParams parameters of the derivative are forwarded unchanged;
  // the parameters at [numParams, 2 * numParams) are read back further down
  // as the values to compare against.
  for (unsigned i = 0; i < numParams; ++i)
    cladGradArgs.push_back(BuildDeclRef(paramsRef[i]));

  const std::string varNames = "cladResult";
  auto size_type = m_Context.getSizeType();
  const unsigned size_type_bits = m_Context.getIntWidth(size_type);

  for (unsigned i = 0; i < numParams; ++i) {
    // Result variables are numbered from 1: cladResult1, cladResult2, ...
    const std::string finalVarName = varNames + std::to_string(i + 1);
    auto paramType = paramsRef[i]->getOriginalType();

    // FIXME: Non-constant array/pointer type parameters can't be dealt with
    // as of now because we don't know the size of the array. This code will
    // break if we use such parameters. This can be fixed if the
    // ReverseModeVisitor keeps track of the maximum index of the array seen
    // so far.
    const bool isArrayLike = isArrayOrPointerType(paramType);
    VarDecl* resultVar = nullptr;

    if (isArrayLike) {
      const ConstantArrayType* arrayType =
          m_Context.getAsConstantArrayType(paramType);
      assert(arrayType &&
             "Only Constant type arrays are allowed to be parameters of "
             "functions whose gradients we want to verify with clad");
      auto elementType = arrayType->getElementType();

      // Create InitList to set all elements of the result array to zero
      Expr* zero = FloatingLiteral::Create(m_Context, llvm::APFloat(0.0), true,
                                           elementType, noLoc);
      llvm::SmallVector<Expr*, 2> initListElement{zero};
      auto* initList = cast<InitListExpr>(
          m_Sema.ActOnInitList(noLoc, initListElement, noLoc).get());
      // The filler is stored by pointer inside the init list, so it has to
      // be allocated in the ASTContext; a stack object would dangle as soon
      // as this scope ends.
      initList->setArrayFiller(new (m_Context)
                                   ImplicitValueInitExpr(elementType));

      resultVar = BuildVarDecl(paramType, finalVarName, initList, true);
    } else if (paramType->isFloatingType()) {
      Expr* zero = FloatingLiteral::Create(m_Context, llvm::APFloat(0.0), true,
                                           paramType, noLoc);
      resultVar = BuildVarDecl(paramType, finalVarName, zero, true);
    } else {
      // Non-floating scalars get no initializer; they are passed to cladFD
      // but never compared below.
      resultVar = BuildVarDecl(paramType, finalVarName, nullptr, true);
    }

    // Shared tail of every branch: declare the variable, then hand it to the
    // Clad gradient (arrays by name, scalars by address).
    addToCurrentBlock(BuildDeclStmt(resultVar), direction::forward);
    Expr* resultRef = BuildDeclRef(resultVar);
    cladGradArgs.push_back(isArrayLike ? resultRef
                                       : BuildOp(UO_AddrOf, resultRef));
    cladResultDecls.push_back(resultVar);
  }

  Expr* cladCall = BuildCallExprToFunction(cladFD, cladGradArgs);
  addToCurrentBlock(cladCall);

  // Compare the values
  FunctionDecl* equalityFD =
      LookupFunctionDeclInCladNamespace("EssentiallyEqual");
  FunctionDecl* equalityFDForArrays =
      LookupFunctionDeclInCladNamespace("EssentiallyEqualArrays");
  for (unsigned i = 0; i < numParams; ++i) {
    auto paramType = paramsRef[i]->getOriginalType();
    llvm::SmallVector<Expr*, 2> equalityCheckArguments;
    equalityCheckArguments.push_back(BuildDeclRef(cladResultDecls[i]));

    if (paramType->isFloatingType()) {
      // Scalar: compare against the dereferenced derivative parameter.
      equalityCheckArguments.push_back(
          BuildOp(UO_Deref, BuildDeclRef(paramsRef[i + numParams])));
      Expr* checkCall =
          BuildCallExprToFunction(equalityFD, equalityCheckArguments);
      addToCurrentBlock(checkCall);
    } else if (const ConstantArrayType* arrayType =
                   m_Context.getAsConstantArrayType(paramType)) {
      // Array: compare against the pointer returned by the derivative
      // parameter's `ptr()` member, over the statically known element count.
      equalityCheckArguments.push_back(BuildCallExprToMemFn(
          BuildDeclRef(paramsRef[i + numParams]), "ptr", {}));
      // Read the element count as an integer; no round trip through double
      // and no narrowing cast.
      llvm::APInt idxValue(size_type_bits,
                           arrayType->getSize().getZExtValue());
      auto idx = IntegerLiteral::Create(m_Context, idxValue, size_type, noLoc);
      equalityCheckArguments.push_back(idx);

      Expr* checkCall = BuildCallExprToFunction(equalityFDForArrays,
                                                equalityCheckArguments);
      addToCurrentBlock(checkCall);
    }
    // Parameters of any other type (e.g. integers) are not compared.
  }
}
|
||
StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) { | ||
diag( | ||
DiagnosticsEngine::Warning, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
// RUN: %cladclang %s -I%S/../../include -Xclang -plugin-arg-clad -Xclang -fcheck-enzyme-with-clad -lstdc++ -oReverseModeWithCladCheck.out | FileCheck %s | ||
// RUN: ./ReverseModeWithCladCheck.out | FileCheck -check-prefix=CHECK-EXEC %s | ||
// CHECK-NOT: {{.*error|warning|note:.*}} | ||
// REQUIRES: Enzyme | ||
|
||
#include "clad/Differentiator/Differentiator.h" | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you please add a test containing nested function calls? |
||
// Test fixture: returns the product of the two elements of `arr`.
double f1(double arr[2]) { return arr[0] * arr[1]; } | ||
|
||
// CHECK: void f1_grad_enzyme(double arr[2], clad::array_ref<double> _d_arr) { | ||
// CHECK-NEXT: double *d_arr = _d_arr.ptr(); | ||
// CHECK-NEXT: __enzyme_autodiff_f1(f1, arr, d_arr); | ||
// CHECK-NEXT: double cladResult1[2] = {0.}; | ||
// CHECK-NEXT: f1_grad(arr, cladResult1); | ||
// CHECK-NEXT: EssentiallyEqualArrays(cladResult1, _d_arr.ptr(), 2UL); | ||
// CHECK-NEXT:} | ||
|
||
// Test fixture: returns the product of its three scalar arguments.
double f2(double x, double y, double z){ | ||
return x * y * z; | ||
} | ||
|
||
// CHECK: void f2_grad_enzyme(double x, double y, double z, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y, clad::array_ref<double> _d_z) { | ||
// CHECK-NEXT: clad::EnzymeGradient<3> grad = __enzyme_autodiff_f2(f2, x, y, z); | ||
// CHECK-NEXT: * _d_x = grad.d_arr[0U]; | ||
// CHECK-NEXT: * _d_y = grad.d_arr[1U]; | ||
// CHECK-NEXT: * _d_z = grad.d_arr[2U]; | ||
// CHECK-NEXT: double cladResult1 = 0.; | ||
// CHECK-NEXT: double cladResult2 = 0.; | ||
// CHECK-NEXT: double cladResult3 = 0.; | ||
// CHECK-NEXT: f2_grad(x, y, z, &cladResult1, &cladResult2, &cladResult3); | ||
// CHECK-NEXT: EssentiallyEqual(cladResult1, * _d_x); | ||
// CHECK-NEXT: EssentiallyEqual(cladResult2, * _d_y); | ||
// CHECK-NEXT: EssentiallyEqual(cladResult3, * _d_z); | ||
// CHECK-NEXT:} | ||
|
||
// Test fixture: returns the sum of the squares of the first `n` elements of
// `arr`. Mixes an array parameter with an integer parameter.
double f3(double arr[3], int n){ | ||
double sum=0; | ||
for(int i=0;i<n;i++){ | ||
sum+=arr[i]*arr[i]; | ||
} | ||
return sum; | ||
} | ||
|
||
// CHECK: void f3_grad_enzyme(double arr[3], int n, clad::array_ref<double> _d_arr, clad::array_ref<int> _d_n) { | ||
// CHECK-NEXT: double *d_arr = _d_arr.ptr(); | ||
// CHECK-NEXT: __enzyme_autodiff_f3(f3, arr, d_arr, n); | ||
// CHECK-NEXT: double cladResult1[3] = {0.}; | ||
// CHECK-NEXT: int cladResult2; | ||
// CHECK-NEXT: f3_grad(arr, n, cladResult1, &cladResult2); | ||
// CHECK-NEXT: EssentiallyEqualArrays(cladResult1, _d_arr.ptr(), 3UL); | ||
// CHECK-NEXT: } | ||
|
||
// Test fixture: returns the sum of the squares of the first `n` elements of
// `arr1` plus the sum of the squares of the first `m` elements of `arr2`.
double f4(double arr1[3], int n, double arr2[2], int m){ | ||
double sum=0; | ||
// Accumulate the squares of arr1[0 .. n-1].
for(int i=0;i<n;i++){ | ||
sum+=arr1[i]*arr1[i]; | ||
} | ||
// Accumulate the squares of arr2[0 .. m-1].
for(int i=0;i<m;i++){ | ||
sum+=arr2[i]*arr2[i]; | ||
} | ||
return sum; | ||
} | ||
|
||
// CHECK: void f4_grad_enzyme(double arr1[3], int n, double arr2[2], int m, clad::array_ref<double> _d_arr1, clad::array_ref<int> _d_n, clad::array_ref<double> _d_arr2, clad::array_ref<int> _d_m) { | ||
// CHECK-NEXT: double *d_arr1 = _d_arr1.ptr(); | ||
// CHECK-NEXT: double *d_arr2 = _d_arr2.ptr(); | ||
// CHECK-NEXT: __enzyme_autodiff_f4(f4, arr1, d_arr1, n, arr2, d_arr2, m); | ||
// CHECK-NEXT: double cladResult1[3] = {0.}; | ||
// CHECK-NEXT: int cladResult2; | ||
// CHECK-NEXT: double cladResult3[2] = {0.}; | ||
// CHECK-NEXT: int cladResult4; | ||
// CHECK-NEXT: f4_grad(arr1, n, arr2, m, cladResult1, &cladResult2, cladResult3, &cladResult4); | ||
// CHECK-NEXT: EssentiallyEqualArrays(cladResult1, _d_arr1.ptr(), 3UL); | ||
// CHECK-NEXT: EssentiallyEqualArrays(cladResult3, _d_arr2.ptr(), 2UL); | ||
// CHECK-NEXT: } | ||
|
||
// Test fixture: returns the sum over the first `n` elements of `arr` of
// arr[i] * x * y. Interleaves array, floating-point and integer parameters.
double f5(double arr[3], double x,int n,double y){ | ||
double res=0; | ||
for(int i=0;i<n;i++){ | ||
res+=(arr[i]*x*y); | ||
} | ||
return res; | ||
} | ||
|
||
// CHECK: void f5_grad_enzyme(double arr[3], double x, int n, double y, clad::array_ref<double> _d_arr, clad::array_ref<double> _d_x, clad::array_ref<int> _d_n, clad::array_ref<double> _d_y) { | ||
// CHECK-NEXT: double *d_arr = _d_arr.ptr(); | ||
// CHECK-NEXT: clad::EnzymeGradient<2> grad = __enzyme_autodiff_f5(f5, arr, d_arr, x, n, y); | ||
// CHECK-NEXT: * _d_x = grad.d_arr[0U]; | ||
// CHECK-NEXT: * _d_y = grad.d_arr[1U]; | ||
// CHECK-NEXT: double cladResult1[3] = {0.}; | ||
// CHECK-NEXT: double cladResult2 = 0.; | ||
// CHECK-NEXT: int cladResult3; | ||
// CHECK-NEXT: double cladResult4 = 0.; | ||
// CHECK-NEXT: f5_grad(arr, x, n, y, cladResult1, &cladResult2, &cladResult3, &cladResult4); | ||
// CHECK-NEXT: EssentiallyEqualArrays(cladResult1, _d_arr.ptr(), 3UL); | ||
// CHECK-NEXT: EssentiallyEqual(cladResult2, * _d_x); | ||
// CHECK-NEXT: EssentiallyEqual(cladResult4, * _d_y); | ||
// CHECK-NEXT: } | ||
|
||
|
||
// Test driver: for each fixture f1..f5, requests a gradient with the
// `use_enzyme` option, executes it on fixed inputs and prints the resulting
// derivatives so they can be matched against the expected-output comments.
int main() { | ||
// f1: product of two array elements, evaluated at {3, 4}.
auto f1_grad = clad::gradient<clad::opts::use_enzyme>(f1); | ||
double f1_v[2] = {3, 4}; | ||
double f1_g[2] = {0}; | ||
f1_grad.execute(f1_v, f1_g); | ||
printf("d_x = %.2f, d_y = %.2f\n", f1_g[0], f1_g[1]); | ||
// CHECK-EXEC: d_x = 4.00, d_y = 3.00 | ||
|
||
// f2: product of three scalars, evaluated at (3, 4, 5); each derivative is
// written to its own slot of f2_res.
auto f2_grad=clad::gradient<clad::opts::use_enzyme>(f2); | ||
double f2_res[3]={0}; | ||
double f2_x=3,f2_y=4,f2_z=5; | ||
f2_grad.execute(f2_x,f2_y,f2_z,&f2_res[0],&f2_res[1],&f2_res[2]); | ||
printf("d_x = %.2f, d_y = %.2f, d_z = %.2f\n", f2_res[0], f2_res[1], f2_res[2]); | ||
//CHECK-EXEC: d_x = 20.00, d_y = 15.00, d_z = 12.00 | ||
|
||
// f3: sum of squares over all 3 elements of {3, 4, 5}; f3_dn receives the
// derivative slot for the integer parameter.
auto f3_grad=clad::gradient<clad::opts::use_enzyme>(f3); | ||
double f3_list[3]={3,4,5}; | ||
double f3_res[3]={0}; | ||
int f3_dn=0; | ||
f3_grad.execute(f3_list,3,f3_res,&f3_dn); | ||
printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_n = %d\n",f3_res[0],f3_res[1],f3_res[2],f3_dn); | ||
//CHECK-EXEC: d_x1 = 6.00, d_x2 = 8.00, d_x3 = 10.00, d_n = 0 | ||
|
||
// f4: two arrays, {3, 4, 5} and {1, 2}, each with its own length argument.
auto f4_grad=clad::gradient<clad::opts::use_enzyme>(f4); | ||
double f4_list1[3]={3,4,5}; | ||
double f4_list2[2]={1,2}; | ||
double f4_res1[3]={0}; | ||
double f4_res2[2]={0}; | ||
int f4_dn1=0,f4_dn2=0; | ||
f4_grad.execute(f4_list1,3,f4_list2,2,f4_res1,&f4_dn1,f4_res2,&f4_dn2); | ||
printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_n1 = %d\n",f4_res1[0],f4_res1[1],f4_res1[2],f4_dn1); | ||
//CHECK-EXEC: d_x1 = 6.00, d_x2 = 8.00, d_x3 = 10.00, d_n1 = 0 | ||
printf("d_y1 = %.2f, d_y2 = %.2f, d_n2 = %d\n",f4_res2[0],f4_res2[1],f4_dn2); | ||
//CHECK-EXEC: d_y1 = 2.00, d_y2 = 4.00, d_n2 = 0 | ||
|
||
// f5: array {3, 4, 5} scaled by x = 10 and y = 5, summed over 3 elements.
auto f5_grad=clad::gradient<clad::opts::use_enzyme>(f5); | ||
double f5_list[3]={3,4,5}; | ||
double f5_res[3]={0}; | ||
double f5_x=10.0,f5_dx=0,f5_y=5,f5_dy=0; | ||
int f5_dn=0; | ||
f5_grad.execute(f5_list,f5_x,3,f5_y,f5_res,&f5_dx,&f5_dn,&f5_dy); | ||
printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_n1 = %d, d_x = %.2f, d_y = %.2f\n",f5_res[0],f5_res[1],f5_res[2],f5_dn, f5_dx, f5_dy); | ||
//CHECK-EXEC: d_x1 = 50.00, d_x2 = 50.00, d_x3 = 50.00, d_n1 = 0, d_x = 60.00, d_y = 120.00 | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@grimmmyshini, can you take a look?