Skip to content

Commit

Permalink
Add support for verifying Enzyme Gradients with Clad Gradients
Browse files Browse the repository at this point in the history
This commit generates code that will verify the results of Enzyme Gradients with Clad Gradients.
For example, if previously the following code was generated for differentiating with enzyme for a function:
```cpp
void f1_grad_enzyme(double arr[2], clad::array_ref<double> _d_arr) {
    double *d_arr = _d_arr.ptr();
    __enzyme_autodiff_f1(f1, arr, d_arr);
}

```

The above code will be appended with checks to verify the calculated gradients. Thus the newly generated code would be:
```cpp
void f1_grad_enzyme(double arr[2], clad::array_ref<double> _d_arr) {
    double *d_arr = _d_arr.ptr();
    __enzyme_autodiff_f1(f1, arr, d_arr);
    double cladResult1[2];
    f1_grad(arr, cladResult1);
    EssentiallyEqualArrays(cladResult1, _d_arr.ptr(), 2UL);
}
```

`EssentiallyEqualArrays` and `EssentiallyEqual` are functions defined in Differentiator.h

Only functions with primitive type and ConstantArray type parameters can be verified in this manner.

To trigger this verification one must append the following flag to clang while compiling the function to be generated: `-Xclang -plugin-arg-clad -Xclang -fcheck-enzyme-with-clad`
  • Loading branch information
Nirhar committed Aug 25, 2022
1 parent 09e2ed0 commit 2a36e90
Show file tree
Hide file tree
Showing 12 changed files with 199 additions and 8 deletions.
3 changes: 3 additions & 0 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ namespace clad {
// A flag to enable the use of enzyme for backend instead of clad
bool use_enzyme = false;

// A flag to generate code that verifies clad and enzyme
bool checkEnzymeWithClad = false;

/// Recomputes `DiffInputVarsInfo` using the current values of data members.
///
/// Differentiation parameters info is computed by parsing the argument
Expand Down
16 changes: 16 additions & 0 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,22 @@ namespace clad {
code);
}

void EssentiallyEqual(double a, double b) {
//FIXME: We should select epsilon value in a more robust way.
const double epsilon = 1e-12;
// printf("a=%.40f, b=%.40f\n",a,b);
bool ans = std::fabs(a - b) <=
((std::fabs(a > b) ? std::fabs(b) : std::fabs(a)) * epsilon);

assert(ans && "Clad Gradient is not equal to Enzyme Gradient");
}
void EssentiallyEqualArrays(double* a, double* b, unsigned size) {
//FIXME: We should select epsilon value in a more robust way.
for(int i=0;i<size;i++){
EssentiallyEqual(a[i],b[i]);
}
}

// Gradient Structure for Reverse Mode Enzyme
template <unsigned N> struct EnzymeGradient { double d_arr[N]; };
}
Expand Down
4 changes: 4 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ namespace clad {
unsigned numParams = 0;
bool isVectorValued = false;
bool use_enzyme = false;
bool checkEnzymeWithClad = false;
// FIXME: Should we make this an object instead of a pointer?
// Downside of making it an object: We will need to include
// 'MultiplexExternalRMVSource.h' file
Expand Down Expand Up @@ -92,6 +93,9 @@ namespace clad {
// Function to Differentiate with Enzyme as Backend
void DifferentiateWithEnzyme();

//Function that inserts code to verify Enzyme Results with Clad Results
void CheckEnzymeResultsWithClad(clang::FunctionDecl* cladFD);

public:
using direction = rmv::direction;
clang::Expr* dfdx() {
Expand Down
5 changes: 5 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,11 @@ namespace clad {
/// \returns The declaration of the class with the name ClassName
clang::TemplateDecl*
LookupTemplateDeclInCladNamespace(llvm::StringRef ClassName);
/// Find declaration of clad::function templated type
///
/// \param[in] FunctionName name of the function to be found
/// \returns The declaration of the function with the name FunctionName
clang::FunctionDecl* LookupFunctionDeclInCladNamespace(llvm::StringRef FunctionName);
/// Instantiate clad::class<TemplateArgs> type
///
/// \param[in] CladClassDecl the decl of the class that is going to be used
Expand Down
5 changes: 3 additions & 2 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,9 @@ namespace clad {
call->setArg(derivedFnArgIdx, newUnOp);

// Update the code parameter.
if (CXXDefaultArgExpr* Arg
= dyn_cast<CXXDefaultArgExpr>(call->getArg(codeArgIdx))) {
CXXDefaultArgExpr* Arg
= dyn_cast<CXXDefaultArgExpr>(call->getArg(codeArgIdx));
if (Arg) {
clang::LangOptions LangOpts;
LangOpts.CPlusPlus = true;
clang::PrintingPolicy Policy(LangOpts);
Expand Down
82 changes: 82 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include <algorithm>
#include <numeric>
#include <iostream>

#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/Compatibility.h"
Expand Down Expand Up @@ -268,6 +269,10 @@ namespace clad {
if (request.use_enzyme)
use_enzyme = true;

if(request.checkEnzymeWithClad){
checkEnzymeWithClad = true;
}

auto derivativeBaseName = request.BaseFunctionName;
std::string gradientName = derivativeBaseName + funcPostfix();
// To be consistent with older tests, nothing is appended to 'f_grad' if
Expand Down Expand Up @@ -413,6 +418,15 @@ namespace clad {
else
DifferentiateWithEnzyme();

if(use_enzyme && checkEnzymeWithClad){
DiffRequest newRequest = const_cast<DiffRequest&>(request);
newRequest.checkEnzymeWithClad = false;
newRequest.use_enzyme = false;
FunctionDecl* cladFD = plugin::ProcessDiffRequest(m_CladPlugin,newRequest);

CheckEnzymeResultsWithClad(cladFD);
}

gradientBody = endBlock();
m_Derivative->setBody(gradientBody);
endScope(); // Function body scope
Expand Down Expand Up @@ -707,6 +721,74 @@ namespace clad {
addToCurrentBlock(enzymeCall);
}
}
void ReverseModeVisitor::CheckEnzymeResultsWithClad(FunctionDecl* cladFD){
// Prepare Arguments for the clad derivative function
llvm::SmallVector<Expr*, 16> cladGradArgs;
llvm::SmallVector<VarDecl*,16> cladResultDecls;
unsigned numParams = m_Function->getNumParams();
llvm::ArrayRef<ParmVarDecl*> paramsRef = m_Derivative->parameters();

for(int i=0;i<numParams;i++){
cladGradArgs.push_back(BuildDeclRef(paramsRef[i]));
}
std::string varNames = "cladResult";
int varNo = 1;
for(int i=0;i<numParams;i++){
std::string finalVarName = varNames + std::to_string(varNo++);
auto paramType = paramsRef[i]->getOriginalType();

// FIX-ME: Non Constant Array/pointer type parameters can't be dealt with as of now because we don't know the size of the array
// This code will break if we use array type parameters. This can be fixed if the ReverseModeVisitor keeps track
// of the maximum index of the array seen so far.

if(isArrayOrPointerType(paramType)){
assert(paramType->isConstantArrayType() &&
"Only Constant type arrays are allowed to be parameters of "
"functions whose gradients we want to verify with clad");

auto resultVar = BuildVarDecl(paramType,finalVarName,nullptr,false);
addToCurrentBlock(BuildDeclStmt(resultVar),direction::forward);
cladGradArgs.push_back(BuildDeclRef(resultVar));
cladResultDecls.push_back(resultVar);
}else{
auto resultVar = BuildVarDecl(paramType,finalVarName,nullptr,false);
addToCurrentBlock(BuildDeclStmt(resultVar),direction::forward);
cladGradArgs.push_back(BuildOp(UO_AddrOf,BuildDeclRef(resultVar)));
cladResultDecls.push_back(resultVar);
}
}

Expr* cladCall = BuildCallExprToFunction(cladFD,cladGradArgs);
addToCurrentBlock(cladCall);

//Compare the values
FunctionDecl* equalityFD = LookupFunctionDeclInCladNamespace("EssentiallyEqual");
FunctionDecl* equalityFDForArrays = LookupFunctionDeclInCladNamespace("EssentiallyEqualArrays");
auto size_type = m_Context.getSizeType();
unsigned size_type_bits = m_Context.getIntWidth(size_type);
for(int i=0;i<numParams;i++){
auto paramType = paramsRef[i]->getOriginalType();
llvm::SmallVector<Expr*, 2> equalityCheckArguments;
equalityCheckArguments.push_back(BuildDeclRef(cladResultDecls[i]));
if(paramType->isFloatingType()){
equalityCheckArguments.push_back(BuildOp(UO_Deref, BuildDeclRef(paramsRef[i+numParams])));
Expr* checkCall = BuildCallExprToFunction(equalityFD,equalityCheckArguments);
addToCurrentBlock(checkCall);
}else if(paramType->isConstantArrayType()){
equalityCheckArguments.push_back(BuildCallExprToMemFn(BuildDeclRef(paramsRef[i+numParams]),"ptr",{}));
ConstantArrayType* t =
dyn_cast<ConstantArrayType>(const_cast<Type*>(paramType.getTypePtr()));
int sizeOfArray = (int)(t->getSize().roundToDouble(false));
llvm::APInt idxValue(size_type_bits, sizeOfArray);
auto idx = IntegerLiteral::Create(m_Context, idxValue, size_type, noLoc);
equalityCheckArguments.push_back(idx);

Expr* checkCall = BuildCallExprToFunction(equalityFDForArrays,equalityCheckArguments);
addToCurrentBlock(checkCall);
}
}
}

StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) {
diag(
DiagnosticsEngine::Warning,
Expand Down
21 changes: 19 additions & 2 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <algorithm>
#include <numeric>
#include <iostream>

#include "clad/Differentiator/Compatibility.h"

Expand Down Expand Up @@ -366,8 +367,24 @@ namespace clad {
return cast<TemplateDecl>(TapeR.getFoundDecl());
}

QualType VisitorBase::InstantiateTemplate(TemplateDecl* CladClassDecl,
TemplateArgumentListInfo& TLI) {
FunctionDecl* VisitorBase::LookupFunctionDeclInCladNamespace(llvm::StringRef FunctionName){
NamespaceDecl* CladNS = GetCladNamespace();
CXXScopeSpec CSS;
CSS.Extend(m_Context, CladNS, noLoc, noLoc);
DeclarationName TapeName = &m_Context.Idents.get(FunctionName);
LookupResult TapeR(m_Sema,
TapeName,
noLoc,
Sema::LookupUsingDeclName,
clad_compat::Sema_ForVisibleRedeclaration);
m_Sema.LookupQualifiedName(TapeR, CladNS, CSS);
assert(!TapeR.empty() && isa<FunctionDecl>(TapeR.getFoundDecl()) &&
"cannot find clad::tape");
return cast<FunctionDecl>(TapeR.getFoundDecl());
}

QualType VisitorBase::InstantiateTemplate(TemplateDecl* CladClassDecl,
TemplateArgumentListInfo& TLI) {
// This will instantiate tape<T> type and return it.
QualType TT =
m_Sema.CheckTemplateIdType(TemplateName(CladClassDecl), noLoc, TLI);
Expand Down
51 changes: 51 additions & 0 deletions test/Enzyme/ReverseModeWithCladCheck.C
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// RUN: %cladclang %s -I%S/../../include -Xclang -plugin-arg-clad -Xclang -fcheck-enzyme-with-clad -oReverseModeWithCladCheck.out | FileCheck %s
// RUN: ./ReverseModeWithCladCheck.out | FileCheck -check-prefix=CHECK-EXEC %s
// CHECK-NOT: {{.*error|warning|note:.*}}
// REQUIRES: Enzyme
// XFAIL:*

#include "clad/Differentiator/Differentiator.h"

double f1(double arr[2]) { return arr[0] * arr[1]; }

// CHECK: void f1_grad_enzyme(double arr[2], clad::array_ref<double> _d_arr) {
// CHECK-NEXT: double *d_arr = _d_arr.ptr();
// CHECK-NEXT: __enzyme_autodiff_f1(f1, arr, d_arr);
// CHECK-NEXT: double cladResult1[2];
// CHECK-NEXT: f1_grad(arr, cladResult1);
// CHECK-NEXT: EssentiallyEqualArrays(cladResult1, _d_arr.ptr(), 2UL);
// CHECK-NEXT:}

double f2(double x, double y, double z){
return x * y * z;
}

// CHECK: void f2_grad_enzyme(double x, double y, double z, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y, clad::array_ref<double> _d_z) {
// CHECK-NEXT: clad::EnzymeGradient<3> grad = __enzyme_autodiff_f2(f2, x, y, z);
// CHECK-NEXT: * _d_x = grad.d_arr[0U];
// CHECK-NEXT: * _d_y = grad.d_arr[1U];
// CHECK-NEXT: * _d_z = grad.d_arr[2U];
// CHECK-NEXT: double cladResult1;
// CHECK-NEXT: double cladResult2;
// CHECK-NEXT: double cladResult3;
// CHECK-NEXT: f2_grad(x, y, z, &cladResult1, &cladResult2, &cladResult3);
// CHECK-NEXT: EssentiallyEqual(cladResult1, * _d_x);
// CHECK-NEXT: EssentiallyEqual(cladResult2, * _d_y);
// CHECK-NEXT: EssentiallyEqual(cladResult3, * _d_z);
// CHECK-NEXT:}

int main() {
auto f1_grad = clad::gradient<clad::opts::use_enzyme>(f1);
double f1_v[2] = {3, 4};
double f1_g[2] = {0};
f1_grad.execute(f1_v, f1_g);
printf("d_x = %.2f, d_y = %.2f\n", f1_g[0], f1_g[1]);
// CHECK-EXEC: d_x = 4.00, d_y = 3.00

auto f2_grad=clad::gradient<clad::opts::use_enzyme>(f2);
double f2_res[3];
double f2_x=3,f2_y=4,f2_z=5;
f2_grad.execute(f2_x,f2_y,f2_z,&f2_res[0],&f2_res[1],&f2_res[2]);
printf("d_x = %.2f, d_y = %.2f, d_z = %.2f\n", f2_res[0], f2_res[1], f2_res[2]);
//CHECK-EXEC: d_x = 20.00, d_y = 15.00, d_z = 12.00
}
6 changes: 6 additions & 0 deletions tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ namespace clad {
if (m_DO.DumpSourceFnAST) {
FD->dumpColor();
}

// If enabled, update request to also compare enzyme and clad results
if(m_DO.CheckEnzymeWithClad){
request.checkEnzymeWithClad = true;
}

// if enabled, load the dynamic library input from user to use
// as a custom estimation model.
if (m_DO.CustomEstimationModel) {
Expand Down
6 changes: 5 additions & 1 deletion tools/ClangPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ namespace clad {
: DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false),
DumpDerivedAST(false), GenerateSourceFile(false),
ValidateClangVersion(false), CustomEstimationModel(false),
PrintNumDiffErrorInfo(false), CustomModelName("") {}
PrintNumDiffErrorInfo(false), CheckEnzymeWithClad(false),
CustomModelName("") {}

bool DumpSourceFn : 1;
bool DumpSourceFnAST : 1;
Expand All @@ -76,6 +77,7 @@ namespace clad {
bool ValidateClangVersion : 1;
bool CustomEstimationModel : 1;
bool PrintNumDiffErrorInfo : 1;
bool CheckEnzymeWithClad : 1;
std::string CustomModelName;
};

Expand Down Expand Up @@ -157,6 +159,8 @@ namespace clad {
m_DO.CustomModelName = args[i];
} else if (args[i] == "-fprint-num-diff-errors") {
m_DO.PrintNumDiffErrorInfo = true;
} else if(args[i] == "-fcheck-enzyme-with-clad"){
m_DO.CheckEnzymeWithClad = true;
} else if (args[i] == "-help") {
// Print some help info.
llvm::errs()
Expand Down
7 changes: 4 additions & 3 deletions tools/DerivedFnInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ namespace clad {
: m_OriginalFn(request.Function), m_DerivedFn(derivedFn),
m_OverloadedDerivedFn(overloadedDerivedFn), m_Mode(request.Mode),
m_DerivativeOrder(request.CurrentDerivativeOrder),
m_DiffVarsInfo(request.DVI) {}
m_DiffVarsInfo(request.DVI), m_UsesEnzyme(request.use_enzyme) {}

bool DerivedFnInfo::SatisfiesRequest(const DiffRequest& request) const {
return (request.Function == m_OriginalFn && request.Mode == m_Mode &&
request.CurrentDerivativeOrder == m_DerivativeOrder &&
request.DVI == m_DiffVarsInfo);
request.DVI == m_DiffVarsInfo && request.use_enzyme == m_UsesEnzyme);
}

bool DerivedFnInfo::IsValid() const { return m_OriginalFn && m_DerivedFn; }
Expand All @@ -26,6 +26,7 @@ namespace clad {
return lhs.m_OriginalFn == rhs.m_OriginalFn &&
lhs.m_DerivativeOrder == rhs.m_DerivativeOrder &&
lhs.m_Mode == rhs.m_Mode &&
lhs.m_DiffVarsInfo == rhs.m_DiffVarsInfo;
lhs.m_DiffVarsInfo == rhs.m_DiffVarsInfo &&
lhs.m_UsesEnzyme == rhs.m_UsesEnzyme;
}
} // namespace clad
1 change: 1 addition & 0 deletions tools/DerivedFnInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace clad {
DiffMode m_Mode = DiffMode::unknown;
unsigned m_DerivativeOrder = 0;
DiffInputVarsInfo m_DiffVarsInfo;
bool m_UsesEnzyme= false;

DerivedFnInfo() {}
DerivedFnInfo(const DiffRequest& request, clang::FunctionDecl* derivedFn,
Expand Down

0 comments on commit 2a36e90

Please sign in to comment.