/// Asserts that two floating point values are "essentially equal", i.e. that
/// they differ by no more than a relative tolerance scaled by the smaller of
/// their magnitudes:
///
///   |a - b| <= min(|a|, |b|) * epsilon
///
/// The generated verification code passes the Clad result as \p a and the
/// Enzyme result as \p b, but the comparison itself is symmetric.
///
/// Defined `inline` because it lives in a header: a non-inline definition
/// would be emitted in every translation unit that includes it.
///
/// \param[in] a first value to compare
/// \param[in] b second value to compare
inline void EssentiallyEqual(double a, double b) {
  // FIXME: We should select epsilon value in a more robust way.
  const double epsilon = 1e-12;

  const double absA = std::fabs(a);
  const double absB = std::fabs(b);
  // Scale by the smaller magnitude. Comparing the magnitudes (rather than the
  // signed values) keeps the check equally strict for negative inputs.
  const double scale = (absA > absB) ? absB : absA;
  const bool ans = std::fabs(a - b) <= scale * epsilon;

  assert(ans && "Clad Gradient is not equal to Enzyme Gradient");
  (void)ans; // Silences the unused-variable warning when NDEBUG is defined.
}

/// Asserts element-wise that the first \p size entries of \p a and \p b are
/// essentially equal (see EssentiallyEqual).
///
/// \param[in] a pointer to the first array
/// \param[in] b pointer to the second array
/// \param[in] size number of elements to compare; 0 performs no checks
inline void EssentiallyEqualArrays(double* a, double* b, unsigned size) {
  for (unsigned i = 0; i < size; ++i)
    EssentiallyEqual(a[i], b[i]);
}
3b3f311d1..e5a108bc9 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -268,6 +268,10 @@ namespace clad { if (request.use_enzyme) use_enzyme = true; + if (request.checkEnzymeWithClad) { + checkEnzymeWithClad = true; + } + auto derivativeBaseName = request.BaseFunctionName; std::string gradientName = derivativeBaseName + funcPostfix(); // To be consistent with older tests, nothing is appended to 'f_grad' if @@ -413,6 +417,16 @@ namespace clad { else DifferentiateWithEnzyme(); + if (use_enzyme && checkEnzymeWithClad) { + DiffRequest newRequest = const_cast(request); + newRequest.checkEnzymeWithClad = false; + newRequest.use_enzyme = false; + FunctionDecl* cladFD = + plugin::ProcessDiffRequest(m_CladPlugin, newRequest); + + CheckEnzymeResultsWithClad(cladFD); + } + gradientBody = endBlock(); m_Derivative->setBody(gradientBody); endScope(); // Function body scope @@ -707,6 +721,101 @@ namespace clad { addToCurrentBlock(enzymeCall); } } + void ReverseModeVisitor::CheckEnzymeResultsWithClad(FunctionDecl* cladFD) { + // Prepare Arguments for the clad derivative function + llvm::SmallVector cladGradArgs; + llvm::SmallVector cladResultDecls; + unsigned numParams = m_Function->getNumParams(); + llvm::ArrayRef paramsRef = m_Derivative->parameters(); + + for (int i = 0; i < numParams; i++) { + cladGradArgs.push_back(BuildDeclRef(paramsRef[i])); + } + std::string varNames = "cladResult"; + int varNo = 1; + auto size_type = m_Context.getSizeType(); + unsigned size_type_bits = m_Context.getIntWidth(size_type); + for (int i = 0; i < numParams; i++) { + std::string finalVarName = varNames + std::to_string(varNo++); + auto paramType = paramsRef[i]->getOriginalType(); + + // FIX-ME: Non Constant Array/pointer type parameters can't be dealt with + // as of now because we don't know the size of the array This code will + // break if we use array type parameters. 
This can be fixed if the + // ReverseModeVisitor keeps track of the maximum index of the array seen + // so far. + + if (isArrayOrPointerType(paramType)) { + assert(paramType->isConstantArrayType() && + "Only Constant type arrays are allowed to be parameters of " + "functions whose gradients we want to verify with clad"); + + // Create InitList to set all elements of the result array to zero + auto init = FloatingLiteral::Create( + m_Context, llvm::APFloat(0.0), true, + dyn_cast(paramType)->getElementType(), noLoc); + llvm::SmallVector initListElement{init}; + auto initList = dyn_cast( + m_Sema.BuildInitList(noLoc, initListElement, noLoc).get()); + ImplicitValueInitExpr imp( + dyn_cast(paramType)->getElementType()); + initList->setArrayFiller(&imp); + + auto resultVar = BuildVarDecl(paramType, finalVarName, initList, true); + addToCurrentBlock(BuildDeclStmt(resultVar), direction::forward); + cladGradArgs.push_back(BuildDeclRef(resultVar)); + cladResultDecls.push_back(resultVar); + } else { + VarDecl* resultVar; + if (paramType->isFloatingType()) { + auto init = FloatingLiteral::Create(m_Context, llvm::APFloat(0.0), + true, paramType, noLoc); + resultVar = BuildVarDecl(paramType, finalVarName, init, true); + } else { + resultVar = BuildVarDecl(paramType, finalVarName, nullptr, true); + } + addToCurrentBlock(BuildDeclStmt(resultVar), direction::forward); + cladGradArgs.push_back(BuildOp(UO_AddrOf, BuildDeclRef(resultVar))); + cladResultDecls.push_back(resultVar); + } + } + + Expr* cladCall = BuildCallExprToFunction(cladFD, cladGradArgs); + addToCurrentBlock(cladCall); + + // Compare the values + FunctionDecl* equalityFD = + LookupFunctionDeclInCladNamespace("EssentiallyEqual"); + FunctionDecl* equalityFDForArrays = + LookupFunctionDeclInCladNamespace("EssentiallyEqualArrays"); + for (int i = 0; i < numParams; i++) { + auto paramType = paramsRef[i]->getOriginalType(); + llvm::SmallVector equalityCheckArguments; + 
equalityCheckArguments.push_back(BuildDeclRef(cladResultDecls[i])); + if (paramType->isFloatingType()) { + equalityCheckArguments.push_back( + BuildOp(UO_Deref, BuildDeclRef(paramsRef[i + numParams]))); + Expr* checkCall = + BuildCallExprToFunction(equalityFD, equalityCheckArguments); + addToCurrentBlock(checkCall); + } else if (paramType->isConstantArrayType()) { + equalityCheckArguments.push_back(BuildCallExprToMemFn( + BuildDeclRef(paramsRef[i + numParams]), "ptr", {})); + ConstantArrayType* t = dyn_cast( + const_cast(paramType.getTypePtr())); + int sizeOfArray = (int)(t->getSize().roundToDouble(false)); + llvm::APInt idxValue(size_type_bits, sizeOfArray); + auto idx = + IntegerLiteral::Create(m_Context, idxValue, size_type, noLoc); + equalityCheckArguments.push_back(idx); + + Expr* checkCall = BuildCallExprToFunction(equalityFDForArrays, + equalityCheckArguments); + addToCurrentBlock(checkCall); + } + } + } + StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) { diag( DiagnosticsEngine::Warning, diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 127829aeb..0123df5ed 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -366,6 +366,20 @@ namespace clad { return cast(TapeR.getFoundDecl()); } + FunctionDecl* + VisitorBase::LookupFunctionDeclInCladNamespace(llvm::StringRef FunctionName) { + NamespaceDecl* CladNS = GetCladNamespace(); + CXXScopeSpec CSS; + CSS.Extend(m_Context, CladNS, noLoc, noLoc); + DeclarationName TapeName = &m_Context.Idents.get(FunctionName); + LookupResult TapeR(m_Sema, TapeName, noLoc, Sema::LookupUsingDeclName, + clad_compat::Sema_ForVisibleRedeclaration); + m_Sema.LookupQualifiedName(TapeR, CladNS, CSS); + assert(!TapeR.empty() && isa(TapeR.getFoundDecl()) && + "cannot find clad::tape"); + return cast(TapeR.getFoundDecl()); + } + QualType VisitorBase::InstantiateTemplate(TemplateDecl* CladClassDecl, TemplateArgumentListInfo& TLI) { // This will instantiate tape 
type and return it. diff --git a/test/Enzyme/ReverseModeWithCladCheck.C b/test/Enzyme/ReverseModeWithCladCheck.C new file mode 100644 index 000000000..da7bd127e --- /dev/null +++ b/test/Enzyme/ReverseModeWithCladCheck.C @@ -0,0 +1,144 @@ +// RUN: %cladclang %s -I%S/../../include -Xclang -plugin-arg-clad -Xclang -fcheck-enzyme-with-clad -lstdc++ -oReverseModeWithCladCheck.out | FileCheck %s +// RUN: ./ReverseModeWithCladCheck.out | FileCheck -check-prefix=CHECK-EXEC %s +// CHECK-NOT: {{.*error|warning|note:.*}} +// REQUIRES: Enzyme + +#include "clad/Differentiator/Differentiator.h" + +double f1(double arr[2]) { return arr[0] * arr[1]; } + +// CHECK: void f1_grad_enzyme(double arr[2], clad::array_ref _d_arr) { +// CHECK-NEXT: double *d_arr = _d_arr.ptr(); +// CHECK-NEXT: __enzyme_autodiff_f1(f1, arr, d_arr); +// CHECK-NEXT: double cladResult1[2] = {0.}; +// CHECK-NEXT: f1_grad(arr, cladResult1); +// CHECK-NEXT: EssentiallyEqualArrays(cladResult1, _d_arr.ptr(), 2UL); +// CHECK-NEXT:} + +double f2(double x, double y, double z){ + return x * y * z; +} + +// CHECK: void f2_grad_enzyme(double x, double y, double z, clad::array_ref _d_x, clad::array_ref _d_y, clad::array_ref _d_z) { +// CHECK-NEXT: clad::EnzymeGradient<3> grad = __enzyme_autodiff_f2(f2, x, y, z); +// CHECK-NEXT: * _d_x = grad.d_arr[0U]; +// CHECK-NEXT: * _d_y = grad.d_arr[1U]; +// CHECK-NEXT: * _d_z = grad.d_arr[2U]; +// CHECK-NEXT: double cladResult1 = 0.; +// CHECK-NEXT: double cladResult2 = 0.; +// CHECK-NEXT: double cladResult3 = 0.; +// CHECK-NEXT: f2_grad(x, y, z, &cladResult1, &cladResult2, &cladResult3); +// CHECK-NEXT: EssentiallyEqual(cladResult1, * _d_x); +// CHECK-NEXT: EssentiallyEqual(cladResult2, * _d_y); +// CHECK-NEXT: EssentiallyEqual(cladResult3, * _d_z); +// CHECK-NEXT:} + +double f3(double arr[3], int n){ + double sum=0; + for(int i=0;i _d_arr, clad::array_ref _d_n) { +// CHECK-NEXT: double *d_arr = _d_arr.ptr(); +// CHECK-NEXT: __enzyme_autodiff_f3(f3, arr, d_arr, n); +// CHECK-NEXT: 
double cladResult1[3] = {0.}; +// CHECK-NEXT: int cladResult2; +// CHECK-NEXT: f3_grad(arr, n, cladResult1, &cladResult2); +// CHECK-NEXT: EssentiallyEqualArrays(cladResult1, _d_arr.ptr(), 3UL); +// CHECK-NEXT: } + +double f4(double arr1[3], int n, double arr2[2], int m){ + double sum=0; + for(int i=0;i _d_arr1, clad::array_ref _d_n, clad::array_ref _d_arr2, clad::array_ref _d_m) { +// CHECK-NEXT: double *d_arr1 = _d_arr1.ptr(); +// CHECK-NEXT: double *d_arr2 = _d_arr2.ptr(); +// CHECK-NEXT: __enzyme_autodiff_f4(f4, arr1, d_arr1, n, arr2, d_arr2, m); +// CHECK-NEXT: double cladResult1[3] = {0.}; +// CHECK-NEXT: int cladResult2; +// CHECK-NEXT: double cladResult3[2] = {0.}; +// CHECK-NEXT: int cladResult4; +// CHECK-NEXT: f4_grad(arr1, n, arr2, m, cladResult1, &cladResult2, cladResult3, &cladResult4); +// CHECK-NEXT: EssentiallyEqualArrays(cladResult1, _d_arr1.ptr(), 3UL); +// CHECK-NEXT: EssentiallyEqualArrays(cladResult3, _d_arr2.ptr(), 2UL); +// CHECK-NEXT: } + +double f5(double arr[3], double x,int n,double y){ + double res=0; + for(int i=0;i _d_arr, clad::array_ref _d_x, clad::array_ref _d_n, clad::array_ref _d_y) { +// CHECK-NEXT: double *d_arr = _d_arr.ptr(); +// CHECK-NEXT: clad::EnzymeGradient<2> grad = __enzyme_autodiff_f5(f5, arr, d_arr, x, n, y); +// CHECK-NEXT: * _d_x = grad.d_arr[0U]; +// CHECK-NEXT: * _d_y = grad.d_arr[1U]; +// CHECK-NEXT: double cladResult1[3] = {0.}; +// CHECK-NEXT: double cladResult2 = 0.; +// CHECK-NEXT: int cladResult3; +// CHECK-NEXT: double cladResult4 = 0.; +// CHECK-NEXT: f5_grad(arr, x, n, y, cladResult1, &cladResult2, &cladResult3, &cladResult4); +// CHECK-NEXT: EssentiallyEqualArrays(cladResult1, _d_arr.ptr(), 3UL); +// CHECK-NEXT: EssentiallyEqual(cladResult2, * _d_x); +// CHECK-NEXT: EssentiallyEqual(cladResult4, * _d_y); +// CHECK-NEXT: } + + +int main() { + auto f1_grad = clad::gradient(f1); + double f1_v[2] = {3, 4}; + double f1_g[2] = {0}; + f1_grad.execute(f1_v, f1_g); + printf("d_x = %.2f, d_y = %.2f\n", f1_g[0], 
f1_g[1]); + // CHECK-EXEC: d_x = 4.00, d_y = 3.00 + + auto f2_grad=clad::gradient(f2); + double f2_res[3]={0}; + double f2_x=3,f2_y=4,f2_z=5; + f2_grad.execute(f2_x,f2_y,f2_z,&f2_res[0],&f2_res[1],&f2_res[2]); + printf("d_x = %.2f, d_y = %.2f, d_z = %.2f\n", f2_res[0], f2_res[1], f2_res[2]); + //CHECK-EXEC: d_x = 20.00, d_y = 15.00, d_z = 12.00 + + auto f3_grad=clad::gradient(f3); + double f3_list[3]={3,4,5}; + double f3_res[3]={0}; + int f3_dn=0; + f3_grad.execute(f3_list,3,f3_res,&f3_dn); + printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_n = %d\n",f3_res[0],f3_res[1],f3_res[2],f3_dn); + //CHECK-EXEC: d_x1 = 6.00, d_x2 = 8.00, d_x3 = 10.00, d_n = 0 + + auto f4_grad=clad::gradient(f4); + double f4_list1[3]={3,4,5}; + double f4_list2[2]={1,2}; + double f4_res1[3]={0}; + double f4_res2[2]={0}; + int f4_dn1=0,f4_dn2=0; + f4_grad.execute(f4_list1,3,f4_list2,2,f4_res1,&f4_dn1,f4_res2,&f4_dn2); + printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_n1 = %d\n",f4_res1[0],f4_res1[1],f4_res1[2],f4_dn1); + //CHECK-EXEC: d_x1 = 6.00, d_x2 = 8.00, d_x3 = 10.00, d_n1 = 0 + printf("d_y1 = %.2f, d_y2 = %.2f, d_n2 = %d\n",f4_res2[0],f4_res2[1],f4_dn2); + //CHECK-EXEC: d_y1 = 2.00, d_y2 = 4.00, d_n2 = 0 + + auto f5_grad=clad::gradient(f5); + double f5_list[3]={3,4,5}; + double f5_res[3]={0}; + double f5_x=10.0,f5_dx=0,f5_y=5,f5_dy=0; + int f5_dn=0; + f5_grad.execute(f5_list,f5_x,3,f5_y,f5_res,&f5_dx,&f5_dn,&f5_dy); + printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_n1 = %d, d_x = %.2f, d_y = %.2f\n",f5_res[0],f5_res[1],f5_res[2],f5_dn, f5_dx, f5_dy); + //CHECK-EXEC: d_x1 = 50.00, d_x2 = 50.00, d_x3 = 50.00, d_n1 = 0, d_x = 60.00, d_y = 120.00 +} diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 0fb32d832..0e8d43e0c 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -182,6 +182,12 @@ namespace clad { if (m_DO.DumpSourceFnAST) { FD->dumpColor(); } + + // If enabled, update request to also compare enzyme and clad results + if (m_DO.CheckEnzymeWithClad) { + 
request.checkEnzymeWithClad = true; + } + // if enabled, load the dynamic library input from user to use // as a custom estimation model. if (m_DO.CustomEstimationModel) { diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 6d5e4a82a..7e016970f 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -66,7 +66,8 @@ namespace clad { : DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false), DumpDerivedAST(false), GenerateSourceFile(false), ValidateClangVersion(false), CustomEstimationModel(false), - PrintNumDiffErrorInfo(false), CustomModelName("") {} + PrintNumDiffErrorInfo(false), CheckEnzymeWithClad(false), + CustomModelName("") {} bool DumpSourceFn : 1; bool DumpSourceFnAST : 1; @@ -76,6 +77,7 @@ namespace clad { bool ValidateClangVersion : 1; bool CustomEstimationModel : 1; bool PrintNumDiffErrorInfo : 1; + bool CheckEnzymeWithClad : 1; std::string CustomModelName; }; @@ -157,6 +159,8 @@ namespace clad { m_DO.CustomModelName = args[i]; } else if (args[i] == "-fprint-num-diff-errors") { m_DO.PrintNumDiffErrorInfo = true; + } else if (args[i] == "-fcheck-enzyme-with-clad") { + m_DO.CheckEnzymeWithClad = true; } else if (args[i] == "-help") { // Print some help info. 
llvm::errs() diff --git a/tools/DerivedFnInfo.cpp b/tools/DerivedFnInfo.cpp index d0f8251a4..a5c708dfa 100644 --- a/tools/DerivedFnInfo.cpp +++ b/tools/DerivedFnInfo.cpp @@ -5,18 +5,18 @@ using namespace clang; namespace clad { - DerivedFnInfo::DerivedFnInfo(const DiffRequest& request, - FunctionDecl* derivedFn, - FunctionDecl* overloadedDerivedFn) - : m_OriginalFn(request.Function), m_DerivedFn(derivedFn), - m_OverloadedDerivedFn(overloadedDerivedFn), m_Mode(request.Mode), - m_DerivativeOrder(request.CurrentDerivativeOrder), - m_DiffVarsInfo(request.DVI) {} +DerivedFnInfo::DerivedFnInfo(const DiffRequest& request, + FunctionDecl* derivedFn, + FunctionDecl* overloadedDerivedFn) + : m_OriginalFn(request.Function), m_DerivedFn(derivedFn), + m_OverloadedDerivedFn(overloadedDerivedFn), m_Mode(request.Mode), + m_DerivativeOrder(request.CurrentDerivativeOrder), + m_DiffVarsInfo(request.DVI), m_UsesEnzyme(request.use_enzyme) {} - bool DerivedFnInfo::SatisfiesRequest(const DiffRequest& request) const { - return (request.Function == m_OriginalFn && request.Mode == m_Mode && - request.CurrentDerivativeOrder == m_DerivativeOrder && - request.DVI == m_DiffVarsInfo); +bool DerivedFnInfo::SatisfiesRequest(const DiffRequest& request) const { + return (request.Function == m_OriginalFn && request.Mode == m_Mode && + request.CurrentDerivativeOrder == m_DerivativeOrder && + request.DVI == m_DiffVarsInfo && request.use_enzyme == m_UsesEnzyme); } bool DerivedFnInfo::IsValid() const { return m_OriginalFn && m_DerivedFn; } @@ -26,6 +26,7 @@ namespace clad { return lhs.m_OriginalFn == rhs.m_OriginalFn && lhs.m_DerivativeOrder == rhs.m_DerivativeOrder && lhs.m_Mode == rhs.m_Mode && - lhs.m_DiffVarsInfo == rhs.m_DiffVarsInfo; + lhs.m_DiffVarsInfo == rhs.m_DiffVarsInfo && + lhs.m_UsesEnzyme == rhs.m_UsesEnzyme; } } // namespace clad \ No newline at end of file diff --git a/tools/DerivedFnInfo.h b/tools/DerivedFnInfo.h index 5b45201fc..03d1a6763 100644 --- a/tools/DerivedFnInfo.h +++ 
b/tools/DerivedFnInfo.h @@ -17,6 +17,7 @@ namespace clad { DiffMode m_Mode = DiffMode::unknown; unsigned m_DerivativeOrder = 0; DiffInputVarsInfo m_DiffVarsInfo; + bool m_UsesEnzyme = false; DerivedFnInfo() {} DerivedFnInfo(const DiffRequest& request, clang::FunctionDecl* derivedFn,