From e69ceed6c8b3b6178346930e367086c58d2d1a20 Mon Sep 17 00:00:00 2001 From: Nirhar Date: Thu, 25 Aug 2022 19:22:38 +0530 Subject: [PATCH] Add support for verifying Enzyme Gradients with Clad Gradients This commit generates code that will verify the results of Enzyme Gradients with Clad Gradients. For example, if previously the following code was generated for differentiating with enzyme for a function: ```cpp void f1_grad_enzyme(double arr[2], clad::array_ref _d_arr) { double *d_arr = _d_arr.ptr(); __enzyme_autodiff_f1(f1, arr, d_arr); } ``` The above code will be appended with checks to verify the calculated gradients. Thus the newly generated code would be: ```cpp void f1_grad_enzyme(double arr[2], clad::array_ref _d_arr) { double *d_arr = _d_arr.ptr(); __enzyme_autodiff_f1(f1, arr, d_arr); double cladResult1[2]; f1_grad(arr, cladResult1); EssentiallyEqualArrays(cladResult1, _d_arr.ptr(), 2UL); } ``` `EssentiallyEqualArrays` and `EssentiallyEqual` are functions defined in Differentiator.h Only functions with primitive type and ConstantArray type parameters can be verified in this manner. To trigger this verification one must append the following flag to clang while compiling the function to be generated: `-Xclang -plugin-arg-clad -Xclang -fcheck-enzyme-with-clad` --- include/clad/Differentiator/DiffPlanner.h | 3 + include/clad/Differentiator/Differentiator.h | 16 ++++ .../clad/Differentiator/ReverseModeVisitor.h | 4 + include/clad/Differentiator/VisitorBase.h | 5 ++ lib/Differentiator/DiffPlanner.cpp | 5 +- lib/Differentiator/ReverseModeVisitor.cpp | 82 +++++++++++++++++++ lib/Differentiator/VisitorBase.cpp | 21 ++++- test/Enzyme/ReverseModeWithCladCheck.C | 51 ++++++++++++ tools/ClangPlugin.cpp | 6 ++ tools/ClangPlugin.h | 6 +- tools/DerivedFnInfo.cpp | 7 +- tools/DerivedFnInfo.h | 1 + 12 files changed, 199 insertions(+), 8 deletions(-) create mode 100644 test/Enzyme/ReverseModeWithCladCheck.C diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 03ad56c07..bdccd8605 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -62,6 +62,9 @@ namespace clad { // A flag to enable the use of enzyme for backend instead of clad bool use_enzyme = false; + // A flag to generate code that verifies clad and enzyme + bool checkEnzymeWithClad = false; + /// Recomputes `DiffInputVarsInfo` using the current values of data members. /// /// Differentiation parameters info is computed by parsing the argument diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index 131e8ce15..1e4e2ed4a 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -450,6 +450,22 @@ namespace clad { code); } + void EssentiallyEqual(long double a, long double b) { + //FIXME: We should select epsilon value in a more robust way. + const double epsilon = 1e-12; + // printf("a=%.40f, b=%.40f\n",a,b); + bool ans = std::fabs(a - b) <= + ((std::fabs(a > b) ? std::fabs(b) : std::fabs(a)) * epsilon); + + assert(ans && "Clad Gradient is not equal to Enzyme Gradient"); + } + void EssentiallyEqualArrays(long double* a, long double* b, unsigned size) { + //FIXME: We should select epsilon value in a more robust way. + for(int i=0;i struct EnzymeGradient { double d_arr[N]; }; } diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index f305b4f96..ec557d546 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -64,6 +64,7 @@ namespace clad { unsigned numParams = 0; bool isVectorValued = false; bool use_enzyme = false; + bool checkEnzymeWithClad = false; // FIXME: Should we make this an object instead of a pointer? // Downside of making it an object: We will need to include // 'MultiplexExternalRMVSource.h' file @@ -92,6 +93,9 @@ namespace clad { // Function to Differentiate with Enzyme as Backend void DifferentiateWithEnzyme(); + //Function that inserts code to verify Enzyme Results with Clad Results + void CheckEnzymeResultsWithClad(clang::FunctionDecl* cladFD); + public: using direction = rmv::direction; clang::Expr* dfdx() { diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 4fa9e4cc9..c735e83be 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -374,6 +374,11 @@ namespace clad { /// \returns The declaration of the class with the name ClassName clang::TemplateDecl* LookupTemplateDeclInCladNamespace(llvm::StringRef ClassName); + /// Find declaration of clad::function templated type + /// + /// \param[in] FunctionName name of the function to be found + /// \returns The declaration of the function with the name FunctionName + clang::FunctionDecl* LookupFunctionDeclInCladNamespace(llvm::StringRef FunctionName); /// Instantiate clad::class type /// /// \param[in] CladClassDecl the decl of the class that is going to be used diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index 63c772286..43ac9e3d3 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -209,8 +209,9 @@ namespace clad { call->setArg(derivedFnArgIdx, newUnOp); // Update the code parameter. - if (CXXDefaultArgExpr* Arg - = dyn_cast(call->getArg(codeArgIdx))) { + CXXDefaultArgExpr* Arg + = dyn_cast(call->getArg(codeArgIdx)); + if (Arg) { clang::LangOptions LangOpts; LangOpts.CPlusPlus = true; clang::PrintingPolicy Policy(LangOpts); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 3b3f311d1..34c97f8e5 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -28,6 +28,7 @@ #include #include +#include #include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/Compatibility.h" @@ -268,6 +269,10 @@ namespace clad { if (request.use_enzyme) use_enzyme = true; + if(request.checkEnzymeWithClad){ + checkEnzymeWithClad = true; + } + auto derivativeBaseName = request.BaseFunctionName; std::string gradientName = derivativeBaseName + funcPostfix(); // To be consistent with older tests, nothing is appended to 'f_grad' if @@ -413,6 +418,15 @@ namespace clad { else DifferentiateWithEnzyme(); + if(use_enzyme && checkEnzymeWithClad){ + DiffRequest newRequest = const_cast(request); + newRequest.checkEnzymeWithClad = false; + newRequest.use_enzyme = false; + FunctionDecl* cladFD = plugin::ProcessDiffRequest(m_CladPlugin,newRequest); + + CheckEnzymeResultsWithClad(cladFD); + } + gradientBody = endBlock(); m_Derivative->setBody(gradientBody); endScope(); // Function body scope @@ -707,6 +721,74 @@ namespace clad { addToCurrentBlock(enzymeCall); } } + void ReverseModeVisitor::CheckEnzymeResultsWithClad(FunctionDecl* cladFD){ + // Prepare Arguments for the clad derivative function + llvm::SmallVector cladGradArgs; + llvm::SmallVector cladResultDecls; + unsigned numParams = m_Function->getNumParams(); + llvm::ArrayRef paramsRef = m_Derivative->parameters(); + + for(int i=0;igetOriginalType(); + + // FIX-ME: Non Constant Array/pointer type parameters can't be dealt with as of now because we don't know the size of the array + // This code will break if we use array type parameters. This can be fixed if the ReverseModeVisitor keeps track + // of the maximum index of the array seen so far. + + if(isArrayOrPointerType(paramType)){ + assert(paramType->isConstantArrayType() && + "Only Constant type arrays are allowed to be parameters of " + "functions whose gradients we want to verify with clad"); + + auto resultVar = BuildVarDecl(paramType,finalVarName,nullptr,false); + addToCurrentBlock(BuildDeclStmt(resultVar),direction::forward); + cladGradArgs.push_back(BuildDeclRef(resultVar)); + cladResultDecls.push_back(resultVar); + }else{ + auto resultVar = BuildVarDecl(paramType,finalVarName,nullptr,false); + addToCurrentBlock(BuildDeclStmt(resultVar),direction::forward); + cladGradArgs.push_back(BuildOp(UO_AddrOf,BuildDeclRef(resultVar))); + cladResultDecls.push_back(resultVar); + } + } + + Expr* cladCall = BuildCallExprToFunction(cladFD,cladGradArgs); + addToCurrentBlock(cladCall); + + //Compare the values + FunctionDecl* equalityFD = LookupFunctionDeclInCladNamespace("EssentiallyEqual"); + FunctionDecl* equalityFDForArrays = LookupFunctionDeclInCladNamespace("EssentiallyEqualArrays"); + auto size_type = m_Context.getSizeType(); + unsigned size_type_bits = m_Context.getIntWidth(size_type); + for(int i=0;igetOriginalType(); + llvm::SmallVector equalityCheckArguments; + equalityCheckArguments.push_back(BuildDeclRef(cladResultDecls[i])); + if(paramType->isFloatingType()){ + equalityCheckArguments.push_back(BuildOp(UO_Deref, BuildDeclRef(paramsRef[i+numParams]))); + Expr* checkCall = BuildCallExprToFunction(equalityFD,equalityCheckArguments); + addToCurrentBlock(checkCall); + }else if(paramType->isConstantArrayType()){ + equalityCheckArguments.push_back(BuildCallExprToMemFn(BuildDeclRef(paramsRef[i+numParams]),"ptr",{})); + ConstantArrayType* t = + dyn_cast(const_cast(paramType.getTypePtr())); + int sizeOfArray = (int)(t->getSize().roundToDouble(false)); + llvm::APInt idxValue(size_type_bits, sizeOfArray); + auto idx = IntegerLiteral::Create(m_Context, idxValue, size_type, noLoc); + equalityCheckArguments.push_back(idx); + + Expr* checkCall = BuildCallExprToFunction(equalityFDForArrays,equalityCheckArguments); + addToCurrentBlock(checkCall); + } + } + } + StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) { diag( DiagnosticsEngine::Warning, diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 127829aeb..e382c70f5 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -24,6 +24,7 @@ #include #include +#include #include "clad/Differentiator/Compatibility.h" @@ -366,8 +367,24 @@ namespace clad { return cast(TapeR.getFoundDecl()); } - QualType VisitorBase::InstantiateTemplate(TemplateDecl* CladClassDecl, - TemplateArgumentListInfo& TLI) { + FunctionDecl* VisitorBase::LookupFunctionDeclInCladNamespace(llvm::StringRef FunctionName){ + NamespaceDecl* CladNS = GetCladNamespace(); + CXXScopeSpec CSS; + CSS.Extend(m_Context, CladNS, noLoc, noLoc); + DeclarationName TapeName = &m_Context.Idents.get(FunctionName); + LookupResult TapeR(m_Sema, + TapeName, + noLoc, + Sema::LookupUsingDeclName, + clad_compat::Sema_ForVisibleRedeclaration); + m_Sema.LookupQualifiedName(TapeR, CladNS, CSS); + assert(!TapeR.empty() && isa(TapeR.getFoundDecl()) && + "cannot find clad::tape"); + return cast(TapeR.getFoundDecl()); + } + +QualType VisitorBase::InstantiateTemplate(TemplateDecl* CladClassDecl, + TemplateArgumentListInfo& TLI) { // This will instantiate tape type and return it. QualType TT = m_Sema.CheckTemplateIdType(TemplateName(CladClassDecl), noLoc, TLI); diff --git a/test/Enzyme/ReverseModeWithCladCheck.C b/test/Enzyme/ReverseModeWithCladCheck.C new file mode 100644 index 000000000..f35b8c745 --- /dev/null +++ b/test/Enzyme/ReverseModeWithCladCheck.C @@ -0,0 +1,51 @@ +// RUN: %cladclang %s -I%S/../../include -Xclang -plugin-arg-clad -Xclang -fcheck-enzyme-with-clad -oReverseModeWithCladCheck.out | FileCheck %s +// RUN: ./ReverseModeWithCladCheck.out | FileCheck -check-prefix=CHECK-EXEC %s +// CHECK-NOT: {{.*error|warning|note:.*}} +// REQUIRES: Enzyme +// XFAIL:* + +#include "clad/Differentiator/Differentiator.h" + +double f1(double arr[2]) { return arr[0] * arr[1]; } + +// CHECK: void f1_grad_enzyme(double arr[2], clad::array_ref _d_arr) { +// CHECK-NEXT: double *d_arr = _d_arr.ptr(); +// CHECK-NEXT: __enzyme_autodiff_f1(f1, arr, d_arr); +// CHECK-NEXT: double cladResult1[2]; +// CHECK-NEXT: f1_grad(arr, cladResult1); +// CHECK-NEXT: EssentiallyEqualArrays(cladResult1, _d_arr.ptr(), 2UL); +// CHECK-NEXT:} + +double f2(double x, double y, double z){ + return x * y * z; +} + +// CHECK: void f2_grad_enzyme(double x, double y, double z, clad::array_ref _d_x, clad::array_ref _d_y, clad::array_ref _d_z) { +// CHECK-NEXT: clad::EnzymeGradient<3> grad = __enzyme_autodiff_f2(f2, x, y, z); +// CHECK-NEXT: * _d_x = grad.d_arr[0U]; +// CHECK-NEXT: * _d_y = grad.d_arr[1U]; +// CHECK-NEXT: * _d_z = grad.d_arr[2U]; +// CHECK-NEXT: double cladResult1; +// CHECK-NEXT: double cladResult2; +// CHECK-NEXT: double cladResult3; +// CHECK-NEXT: f2_grad(x, y, z, &cladResult1, &cladResult2, &cladResult3); +// CHECK-NEXT: EssentiallyEqual(cladResult1, * _d_x); +// CHECK-NEXT: EssentiallyEqual(cladResult2, * _d_y); +// CHECK-NEXT: EssentiallyEqual(cladResult3, * _d_z); +// CHECK-NEXT:} + +int main() { + auto f1_grad = clad::gradient(f1); + double f1_v[2] = {3, 4}; + double f1_g[2] = {0}; + f1_grad.execute(f1_v, f1_g); + printf("d_x = %.2f, d_y = %.2f\n", f1_g[0], f1_g[1]); + // CHECK-EXEC: d_x = 4.00, d_y = 3.00 + + auto f2_grad=clad::gradient(f2); + double f2_res[3]; + double f2_x=3,f2_y=4,f2_z=5; + f2_grad.execute(f2_x,f2_y,f2_z,&f2_res[0],&f2_res[1],&f2_res[2]); + printf("d_x = %.2f, d_y = %.2f, d_z = %.2f\n", f2_res[0], f2_res[1], f2_res[2]); + //CHECK-EXEC: d_x = 20.00, d_y = 15.00, d_z = 12.00 +} diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 0fb32d832..c41881474 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -182,6 +182,12 @@ namespace clad { if (m_DO.DumpSourceFnAST) { FD->dumpColor(); } + + // If enabled, update request to also compare enzyme and clad results + if(m_DO.CheckEnzymeWithClad){ + request.checkEnzymeWithClad = true; + } + // if enabled, load the dynamic library input from user to use // as a custom estimation model. if (m_DO.CustomEstimationModel) { diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 6d5e4a82a..ee3117f3c 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -66,7 +66,8 @@ namespace clad { : DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false), DumpDerivedAST(false), GenerateSourceFile(false), ValidateClangVersion(false), CustomEstimationModel(false), - PrintNumDiffErrorInfo(false), CustomModelName("") {} + PrintNumDiffErrorInfo(false), CheckEnzymeWithClad(false), + CustomModelName("") {} bool DumpSourceFn : 1; bool DumpSourceFnAST : 1; @@ -76,6 +77,7 @@ namespace clad { bool ValidateClangVersion : 1; bool CustomEstimationModel : 1; bool PrintNumDiffErrorInfo : 1; + bool CheckEnzymeWithClad : 1; std::string CustomModelName; }; @@ -157,6 +159,8 @@ namespace clad { m_DO.CustomModelName = args[i]; } else if (args[i] == "-fprint-num-diff-errors") { m_DO.PrintNumDiffErrorInfo = true; + } else if(args[i] == "-fcheck-enzyme-with-clad"){ + m_DO.CheckEnzymeWithClad = true; } else if (args[i] == "-help") { // Print some help info. llvm::errs() diff --git a/tools/DerivedFnInfo.cpp b/tools/DerivedFnInfo.cpp index d0f8251a4..8c50ef623 100644 --- a/tools/DerivedFnInfo.cpp +++ b/tools/DerivedFnInfo.cpp @@ -11,12 +11,12 @@ namespace clad { : m_OriginalFn(request.Function), m_DerivedFn(derivedFn), m_OverloadedDerivedFn(overloadedDerivedFn), m_Mode(request.Mode), m_DerivativeOrder(request.CurrentDerivativeOrder), - m_DiffVarsInfo(request.DVI) {} + m_DiffVarsInfo(request.DVI), m_UsesEnzyme(request.use_enzyme) {} bool DerivedFnInfo::SatisfiesRequest(const DiffRequest& request) const { return (request.Function == m_OriginalFn && request.Mode == m_Mode && request.CurrentDerivativeOrder == m_DerivativeOrder && - request.DVI == m_DiffVarsInfo); + request.DVI == m_DiffVarsInfo && request.use_enzyme == m_UsesEnzyme); } bool DerivedFnInfo::IsValid() const { return m_OriginalFn && m_DerivedFn; } @@ -26,6 +26,7 @@ namespace clad { return lhs.m_OriginalFn == rhs.m_OriginalFn && lhs.m_DerivativeOrder == rhs.m_DerivativeOrder && lhs.m_Mode == rhs.m_Mode && - lhs.m_DiffVarsInfo == rhs.m_DiffVarsInfo; + lhs.m_DiffVarsInfo == rhs.m_DiffVarsInfo && + lhs.m_UsesEnzyme == rhs.m_UsesEnzyme; } } // namespace clad \ No newline at end of file diff --git a/tools/DerivedFnInfo.h b/tools/DerivedFnInfo.h index 5b45201fc..637e1a15a 100644 --- a/tools/DerivedFnInfo.h +++ b/tools/DerivedFnInfo.h @@ -17,6 +17,7 @@ namespace clad { DiffMode m_Mode = DiffMode::unknown; unsigned m_DerivativeOrder = 0; DiffInputVarsInfo m_DiffVarsInfo; + bool m_UsesEnzyme= false; DerivedFnInfo() {} DerivedFnInfo(const DiffRequest& request, clang::FunctionDecl* derivedFn,