diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 03ad56c07..bdccd8605 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -62,6 +62,9 @@ namespace clad { // A flag to enable the use of enzyme for backend instead of clad bool use_enzyme = false; + // A flag to generate code that verifies clad and enzyme + bool checkEnzymeWithClad = false; + /// Recomputes `DiffInputVarsInfo` using the current values of data members. /// /// Differentiation parameters info is computed by parsing the argument diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index 131e8ce15..1e4e2ed4a 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -450,6 +450,22 @@ namespace clad { code); } + void EssentiallyEqual(long double a, long double b) { + //FIXME: We should select epsilon value in a more robust way. + const double epsilon = 1e-12; + // printf("a=%.40f, b=%.40f\n",a,b); + bool ans = std::fabs(a - b) <= + ((std::fabs(a > b) ? std::fabs(b) : std::fabs(a)) * epsilon); + + assert(ans && "Clad Gradient is not equal to Enzyme Gradient"); + } + void EssentiallyEqualArrays(long double* a, long double* b, unsigned size) { + //FIXME: We should select epsilon value in a more robust way. + for(int i=0;i struct EnzymeGradient { double d_arr[N]; }; } diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index f305b4f96..ec557d546 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -64,6 +64,7 @@ namespace clad { unsigned numParams = 0; bool isVectorValued = false; bool use_enzyme = false; + bool checkEnzymeWithClad = false; // FIXME: Should we make this an object instead of a pointer? // Downside of making it an object: We will need to include // 'MultiplexExternalRMVSource.h' file @@ -92,6 +93,9 @@ namespace clad { // Function to Differentiate with Enzyme as Backend void DifferentiateWithEnzyme(); + //Function that inserts code to verify Enzyme Results with Clad Results + void CheckEnzymeResultsWithClad(clang::FunctionDecl* cladFD); + public: using direction = rmv::direction; clang::Expr* dfdx() { diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 4fa9e4cc9..c735e83be 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -374,6 +374,11 @@ namespace clad { /// \returns The declaration of the class with the name ClassName clang::TemplateDecl* LookupTemplateDeclInCladNamespace(llvm::StringRef ClassName); + /// Find declaration of clad::function templated type + /// + /// \param[in] FunctionName name of the function to be found + /// \returns The declaration of the function with the name FunctionName + clang::FunctionDecl* LookupFunctionDeclInCladNamespace(llvm::StringRef FunctionName); /// Instantiate clad::class type /// /// \param[in] CladClassDecl the decl of the class that is going to be used diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index 63c772286..43ac9e3d3 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -209,8 +209,9 @@ namespace clad { call->setArg(derivedFnArgIdx, newUnOp); // Update the code parameter. - if (CXXDefaultArgExpr* Arg - = dyn_cast(call->getArg(codeArgIdx))) { + CXXDefaultArgExpr* Arg + = dyn_cast(call->getArg(codeArgIdx)); + if (Arg) { clang::LangOptions LangOpts; LangOpts.CPlusPlus = true; clang::PrintingPolicy Policy(LangOpts); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 3b3f311d1..34c97f8e5 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -28,6 +28,7 @@ #include #include +#include #include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/Compatibility.h" @@ -268,6 +269,10 @@ namespace clad { if (request.use_enzyme) use_enzyme = true; + if(request.checkEnzymeWithClad){ + checkEnzymeWithClad = true; + } + auto derivativeBaseName = request.BaseFunctionName; std::string gradientName = derivativeBaseName + funcPostfix(); // To be consistent with older tests, nothing is appended to 'f_grad' if @@ -413,6 +418,15 @@ namespace clad { else DifferentiateWithEnzyme(); + if(use_enzyme && checkEnzymeWithClad){ + DiffRequest newRequest = const_cast(request); + newRequest.checkEnzymeWithClad = false; + newRequest.use_enzyme = false; + FunctionDecl* cladFD = plugin::ProcessDiffRequest(m_CladPlugin,newRequest); + + CheckEnzymeResultsWithClad(cladFD); + } + gradientBody = endBlock(); m_Derivative->setBody(gradientBody); endScope(); // Function body scope @@ -707,6 +721,74 @@ namespace clad { addToCurrentBlock(enzymeCall); } } + void ReverseModeVisitor::CheckEnzymeResultsWithClad(FunctionDecl* cladFD){ + // Prepare Arguments for the clad derivative function + llvm::SmallVector cladGradArgs; + llvm::SmallVector cladResultDecls; + unsigned numParams = m_Function->getNumParams(); + llvm::ArrayRef paramsRef = m_Derivative->parameters(); + + for(int i=0;igetOriginalType(); + + // FIX-ME: Non Constant Array/pointer type parameters can't be dealt with as of now because we don't know the size of the array + // This code will break if we use array type parameters. This can be fixed if the ReverseModeVisitor keeps track + // of the maximum index of the array seen so far. + + if(isArrayOrPointerType(paramType)){ + assert(paramType->isConstantArrayType() && + "Only Constant type arrays are allowed to be parameters of " + "functions whose gradients we want to verify with clad"); + + auto resultVar = BuildVarDecl(paramType,finalVarName,nullptr,false); + addToCurrentBlock(BuildDeclStmt(resultVar),direction::forward); + cladGradArgs.push_back(BuildDeclRef(resultVar)); + cladResultDecls.push_back(resultVar); + }else{ + auto resultVar = BuildVarDecl(paramType,finalVarName,nullptr,false); + addToCurrentBlock(BuildDeclStmt(resultVar),direction::forward); + cladGradArgs.push_back(BuildOp(UO_AddrOf,BuildDeclRef(resultVar))); + cladResultDecls.push_back(resultVar); + } + } + + Expr* cladCall = BuildCallExprToFunction(cladFD,cladGradArgs); + addToCurrentBlock(cladCall); + + //Compare the values + FunctionDecl* equalityFD = LookupFunctionDeclInCladNamespace("EssentiallyEqual"); + FunctionDecl* equalityFDForArrays = LookupFunctionDeclInCladNamespace("EssentiallyEqualArrays"); + auto size_type = m_Context.getSizeType(); + unsigned size_type_bits = m_Context.getIntWidth(size_type); + for(int i=0;igetOriginalType(); + llvm::SmallVector equalityCheckArguments; + equalityCheckArguments.push_back(BuildDeclRef(cladResultDecls[i])); + if(paramType->isFloatingType()){ + equalityCheckArguments.push_back(BuildOp(UO_Deref, BuildDeclRef(paramsRef[i+numParams]))); + Expr* checkCall = BuildCallExprToFunction(equalityFD,equalityCheckArguments); + addToCurrentBlock(checkCall); + }else if(paramType->isConstantArrayType()){ + equalityCheckArguments.push_back(BuildCallExprToMemFn(BuildDeclRef(paramsRef[i+numParams]),"ptr",{})); + ConstantArrayType* t = + dyn_cast(const_cast(paramType.getTypePtr())); + int sizeOfArray = (int)(t->getSize().roundToDouble(false)); + llvm::APInt idxValue(size_type_bits, sizeOfArray); + auto idx = IntegerLiteral::Create(m_Context, idxValue, size_type, noLoc); + equalityCheckArguments.push_back(idx); + + Expr* checkCall = BuildCallExprToFunction(equalityFDForArrays,equalityCheckArguments); + addToCurrentBlock(checkCall); + } + } + } + StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) { diag( DiagnosticsEngine::Warning, diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 127829aeb..e382c70f5 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -24,6 +24,7 @@ #include #include +#include #include "clad/Differentiator/Compatibility.h" @@ -366,8 +367,24 @@ namespace clad { return cast(TapeR.getFoundDecl()); } - QualType VisitorBase::InstantiateTemplate(TemplateDecl* CladClassDecl, - TemplateArgumentListInfo& TLI) { + FunctionDecl* VisitorBase::LookupFunctionDeclInCladNamespace(llvm::StringRef FunctionName){ + NamespaceDecl* CladNS = GetCladNamespace(); + CXXScopeSpec CSS; + CSS.Extend(m_Context, CladNS, noLoc, noLoc); + DeclarationName TapeName = &m_Context.Idents.get(FunctionName); + LookupResult TapeR(m_Sema, + TapeName, + noLoc, + Sema::LookupUsingDeclName, + clad_compat::Sema_ForVisibleRedeclaration); + m_Sema.LookupQualifiedName(TapeR, CladNS, CSS); + assert(!TapeR.empty() && isa(TapeR.getFoundDecl()) && + "cannot find clad::tape"); + return cast(TapeR.getFoundDecl()); + } + +QualType VisitorBase::InstantiateTemplate(TemplateDecl* CladClassDecl, + TemplateArgumentListInfo& TLI) { // This will instantiate tape type and return it. QualType TT = m_Sema.CheckTemplateIdType(TemplateName(CladClassDecl), noLoc, TLI); diff --git a/test/Enzyme/ReverseModeWithCladCheck.C b/test/Enzyme/ReverseModeWithCladCheck.C new file mode 100644 index 000000000..f35b8c745 --- /dev/null +++ b/test/Enzyme/ReverseModeWithCladCheck.C @@ -0,0 +1,51 @@ +// RUN: %cladclang %s -I%S/../../include -Xclang -plugin-arg-clad -Xclang -fcheck-enzyme-with-clad -oReverseModeWithCladCheck.out | FileCheck %s +// RUN: ./ReverseModeWithCladCheck.out | FileCheck -check-prefix=CHECK-EXEC %s +// CHECK-NOT: {{.*error|warning|note:.*}} +// REQUIRES: Enzyme +// XFAIL:* + +#include "clad/Differentiator/Differentiator.h" + +double f1(double arr[2]) { return arr[0] * arr[1]; } + +// CHECK: void f1_grad_enzyme(double arr[2], clad::array_ref _d_arr) { +// CHECK-NEXT: double *d_arr = _d_arr.ptr(); +// CHECK-NEXT: __enzyme_autodiff_f1(f1, arr, d_arr); +// CHECK-NEXT: double cladResult1[2]; +// CHECK-NEXT: f1_grad(arr, cladResult1); +// CHECK-NEXT: EssentiallyEqualArrays(cladResult1, _d_arr.ptr(), 2UL); +// CHECK-NEXT:} + +double f2(double x, double y, double z){ + return x * y * z; +} + +// CHECK: void f2_grad_enzyme(double x, double y, double z, clad::array_ref _d_x, clad::array_ref _d_y, clad::array_ref _d_z) { +// CHECK-NEXT: clad::EnzymeGradient<3> grad = __enzyme_autodiff_f2(f2, x, y, z); +// CHECK-NEXT: * _d_x = grad.d_arr[0U]; +// CHECK-NEXT: * _d_y = grad.d_arr[1U]; +// CHECK-NEXT: * _d_z = grad.d_arr[2U]; +// CHECK-NEXT: double cladResult1; +// CHECK-NEXT: double cladResult2; +// CHECK-NEXT: double cladResult3; +// CHECK-NEXT: f2_grad(x, y, z, &cladResult1, &cladResult2, &cladResult3); +// CHECK-NEXT: EssentiallyEqual(cladResult1, * _d_x); +// CHECK-NEXT: EssentiallyEqual(cladResult2, * _d_y); +// CHECK-NEXT: EssentiallyEqual(cladResult3, * _d_z); +// CHECK-NEXT:} + +int main() { + auto f1_grad = clad::gradient(f1); + double f1_v[2] = {3, 4}; + double f1_g[2] = {0}; + f1_grad.execute(f1_v, f1_g); + printf("d_x = %.2f, d_y = %.2f\n", f1_g[0], f1_g[1]); + // CHECK-EXEC: d_x = 4.00, d_y = 3.00 + + auto f2_grad=clad::gradient(f2); + double f2_res[3]; + double f2_x=3,f2_y=4,f2_z=5; + f2_grad.execute(f2_x,f2_y,f2_z,&f2_res[0],&f2_res[1],&f2_res[2]); + printf("d_x = %.2f, d_y = %.2f, d_z = %.2f\n", f2_res[0], f2_res[1], f2_res[2]); + //CHECK-EXEC: d_x = 20.00, d_y = 15.00, d_z = 12.00 +} diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 0fb32d832..c41881474 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -182,6 +182,12 @@ namespace clad { if (m_DO.DumpSourceFnAST) { FD->dumpColor(); } + + // If enabled, update request to also compare enzyme and clad results + if(m_DO.CheckEnzymeWithClad){ + request.checkEnzymeWithClad = true; + } + // if enabled, load the dynamic library input from user to use // as a custom estimation model. if (m_DO.CustomEstimationModel) { diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 6d5e4a82a..ee3117f3c 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -66,7 +66,8 @@ namespace clad { : DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false), DumpDerivedAST(false), GenerateSourceFile(false), ValidateClangVersion(false), CustomEstimationModel(false), - PrintNumDiffErrorInfo(false), CustomModelName("") {} + PrintNumDiffErrorInfo(false), CheckEnzymeWithClad(false), + CustomModelName("") {} bool DumpSourceFn : 1; bool DumpSourceFnAST : 1; @@ -76,6 +77,7 @@ namespace clad { bool ValidateClangVersion : 1; bool CustomEstimationModel : 1; bool PrintNumDiffErrorInfo : 1; + bool CheckEnzymeWithClad : 1; std::string CustomModelName; }; @@ -157,6 +159,8 @@ namespace clad { m_DO.CustomModelName = args[i]; } else if (args[i] == "-fprint-num-diff-errors") { m_DO.PrintNumDiffErrorInfo = true; + } else if(args[i] == "-fcheck-enzyme-with-clad"){ + m_DO.CheckEnzymeWithClad = true; } else if (args[i] == "-help") { // Print some help info. llvm::errs() diff --git a/tools/DerivedFnInfo.cpp b/tools/DerivedFnInfo.cpp index d0f8251a4..8c50ef623 100644 --- a/tools/DerivedFnInfo.cpp +++ b/tools/DerivedFnInfo.cpp @@ -11,12 +11,12 @@ namespace clad { : m_OriginalFn(request.Function), m_DerivedFn(derivedFn), m_OverloadedDerivedFn(overloadedDerivedFn), m_Mode(request.Mode), m_DerivativeOrder(request.CurrentDerivativeOrder), - m_DiffVarsInfo(request.DVI) {} + m_DiffVarsInfo(request.DVI), m_UsesEnzyme(request.use_enzyme) {} bool DerivedFnInfo::SatisfiesRequest(const DiffRequest& request) const { return (request.Function == m_OriginalFn && request.Mode == m_Mode && request.CurrentDerivativeOrder == m_DerivativeOrder && - request.DVI == m_DiffVarsInfo); + request.DVI == m_DiffVarsInfo && request.use_enzyme == m_UsesEnzyme); } bool DerivedFnInfo::IsValid() const { return m_OriginalFn && m_DerivedFn; } @@ -26,6 +26,7 @@ namespace clad { return lhs.m_OriginalFn == rhs.m_OriginalFn && lhs.m_DerivativeOrder == rhs.m_DerivativeOrder && lhs.m_Mode == rhs.m_Mode && - lhs.m_DiffVarsInfo == rhs.m_DiffVarsInfo; + lhs.m_DiffVarsInfo == rhs.m_DiffVarsInfo && + lhs.m_UsesEnzyme == rhs.m_UsesEnzyme; } } // namespace clad \ No newline at end of file diff --git a/tools/DerivedFnInfo.h b/tools/DerivedFnInfo.h index 5b45201fc..637e1a15a 100644 --- a/tools/DerivedFnInfo.h +++ b/tools/DerivedFnInfo.h @@ -17,6 +17,7 @@ namespace clad { DiffMode m_Mode = DiffMode::unknown; unsigned m_DerivativeOrder = 0; DiffInputVarsInfo m_DiffVarsInfo; + bool m_UsesEnzyme= false; DerivedFnInfo() {} DerivedFnInfo(const DiffRequest& request, clang::FunctionDecl* derivedFn,