From 09e2ed0968fe6fbb142be155269372e0e2c50584 Mon Sep 17 00:00:00 2001 From: Nirhar Date: Tue, 16 Aug 2022 11:07:39 +0530 Subject: [PATCH] Add Support for Differentiating functions with both pointer/array type and primitive type parameters with Enzyme This commit adds support to differentiate many more function types with enzyme. Some example supported function types are: double f(double* arr) double f(double x, double y, double z) double f(double* arr, int n) double f(double* arr1, int n, double* arr2, int m) double f(double arr[], double x,int n,double y) Tests for these have been written in tests/Enzyme/ReverseMode.C --- include/clad/Differentiator/Differentiator.h | 3 + include/clad/Differentiator/VisitorBase.h | 9 +- lib/Differentiator/ForwardModeVisitor.cpp | 5 +- lib/Differentiator/ReverseModeVisitor.cpp | 169 +++++++++++++------ lib/Differentiator/VisitorBase.cpp | 40 +++-- test/Enzyme/ReverseMode.C | 111 +++++++++++- tools/CMakeLists.txt | 2 +- 7 files changed, 253 insertions(+), 86 deletions(-) diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index ac9019743..131e8ce15 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -449,6 +449,9 @@ namespace clad { DerivedFnType>(derivedFn /* will be replaced by estimation code*/, code); } + + // Gradient Structure for Reverse Mode Enzyme + template struct EnzymeGradient { double d_arr[N]; }; } #endif // CLAD_DIFFERENTIATOR diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 13ed6339f..4fa9e4cc9 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -372,15 +372,18 @@ namespace clad { /// /// \param[in] className name of the class to be found /// \returns The declaration of the class with the name ClassName - clang::TemplateDecl* GetCladClassDecl(llvm::StringRef ClassName); + clang::TemplateDecl* + LookupTemplateDeclInCladNamespace(llvm::StringRef ClassName); /// Instantiate clad::class type /// /// \param[in] CladClassDecl the decl of the class that is going to be used /// in the creation of the type \param[in] TemplateArgs an array of template /// arguments \returns The created type clad::class clang::QualType - GetCladClassOfType(clang::TemplateDecl* CladClassDecl, - llvm::ArrayRef TemplateArgs); + InstantiateTemplate(clang::TemplateDecl* CladClassDecl, + llvm::ArrayRef TemplateArgs); + clang::QualType InstantiateTemplate(clang::TemplateDecl* CladClassDecl, + clang::TemplateArgumentListInfo& TLI); /// Find declaration of clad::tape templated type. clang::TemplateDecl* GetCladTapeDecl(); /// Perform a lookup into clad namespace for an entity with given name. diff --git a/lib/Differentiator/ForwardModeVisitor.cpp b/lib/Differentiator/ForwardModeVisitor.cpp index 6077a2d97..b774737b1 100644 --- a/lib/Differentiator/ForwardModeVisitor.cpp +++ b/lib/Differentiator/ForwardModeVisitor.cpp @@ -58,11 +58,12 @@ namespace clad { QualType originalFnRT = m_Function->getReturnType(); if (originalFnRT->isVoidType()) return m_Context.VoidTy; - TemplateDecl* valueAndPushforward = GetCladClassDecl("ValueAndPushforward"); + TemplateDecl* valueAndPushforward = + LookupTemplateDeclInCladNamespace("ValueAndPushforward"); assert(valueAndPushforward && "clad::ValueAndPushforward template not found!!"); QualType RT = - GetCladClassOfType(valueAndPushforward, {originalFnRT, originalFnRT}); + InstantiateTemplate(valueAndPushforward, {originalFnRT, originalFnRT}); return RT; } diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index fe650c41a..3b3f311d1 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -584,68 +584,127 @@ namespace clad { } void ReverseModeVisitor::DifferentiateWithEnzyme() { - // FIXME: Generalize this function to differentiate other kinds - // of function prototypes unsigned numParams = m_Function->getNumParams(); auto origParams = m_Function->parameters(); llvm::ArrayRef paramsRef = m_Derivative->parameters(); auto originalFnType = dyn_cast(m_Function->getType()); - // Case 1: The function to be differentiated is of type double - // func(double* arr){...}; - // or double func(double arr[n]){...}; - if (numParams == 1) { - QualType origTy = origParams[0]->getOriginalType(); - if (origTy->isConstantArrayType() || origTy->isPointerType()) { - // Extract Pointer from Clad Array Ref - auto arrayRefNameExpr = BuildDeclRef(paramsRef[1]); - auto getPointerExpr = BuildCallExprToMemFn(arrayRefNameExpr, "ptr", {}); - auto arrayRefToArrayStmt = BuildVarDecl( - origTy, "d_" + paramsRef[0]->getNameAsString(), getPointerExpr); - addToCurrentBlock(BuildDeclStmt(arrayRefToArrayStmt), - direction::forward); + // Extract Pointer from Clad Array Ref + llvm::SmallVector cladRefParams; + for (int i = 0; i < numParams; i++) { + QualType paramType = origParams[i]->getOriginalType(); + if (paramType->isRealType()) { + cladRefParams.push_back(nullptr); + continue; + } + + paramType = m_Context.getPointerType( + QualType(paramType->getPointeeOrArrayElementType(), 0)); + auto arrayRefNameExpr = BuildDeclRef(paramsRef[numParams + i]); + auto getPointerExpr = BuildCallExprToMemFn(arrayRefNameExpr, "ptr", {}); + auto arrayRefToArrayStmt = BuildVarDecl( + paramType, "d_" + paramsRef[i]->getNameAsString(), getPointerExpr); + addToCurrentBlock(BuildDeclStmt(arrayRefToArrayStmt), direction::forward); + cladRefParams.push_back(arrayRefToArrayStmt); + } + // Prepare Arguments and Parameters to enzyme_autodiff + llvm::SmallVector enzymeArgs; + llvm::SmallVector enzymeParams; + llvm::SmallVector enzymeRealParams; + llvm::SmallVector enzymeRealParamsRef; + + // First add the function itself as a parameter/argument + enzymeArgs.push_back(BuildDeclRef(const_cast(m_Function))); + DeclContext* fdDeclContext = + const_cast(m_Function->getDeclContext()); + enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( + fdDeclContext, noLoc, m_Function->getType())); + + // Add rest of the parameters/arguments + for (int i = 0; i < numParams; i++) { + // First Add the original parameter + enzymeArgs.push_back(BuildDeclRef(paramsRef[i])); + enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( + fdDeclContext, noLoc, paramsRef[i]->getType())); + + // If the original parameter is not of array/pointer type, then we don't + // have to extract its pointer from clad array_ref and add it to the + // enzyme parameters, so we can skip the rest of the code + if (!cladRefParams[i]) { + // If original parameter is of a differentiable real type(but not + // array/pointer), then add it to the list of params whose gradient must + // be extracted later from the EnzymeGradient structure + if (paramsRef[i]->getOriginalType()->isRealFloatingType()) { + enzymeRealParams.push_back(paramsRef[i]); + enzymeRealParamsRef.push_back(paramsRef[numParams + i]); + } + continue; + } + // Then add the corresponding clad array ref pointer variable + enzymeArgs.push_back(BuildDeclRef(cladRefParams[i])); + enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( + fdDeclContext, noLoc, cladRefParams[i]->getType())); + } - // Prepare Arguments to enzyme_autodiff - llvm::SmallVector enzymeArgs; - enzymeArgs.push_back( - BuildDeclRef(const_cast(m_Function))); - enzymeArgs.push_back(BuildDeclRef(paramsRef[0])); - enzymeArgs.push_back(BuildDeclRef(arrayRefToArrayStmt)); - - // Prepare Parameters for Function Signature - llvm::SmallVector enzymeParams; - DeclContext* fdDeclContext = - const_cast(m_Function->getDeclContext()); - enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( - fdDeclContext, noLoc, m_Function->getType())); - enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( - fdDeclContext, noLoc, paramsRef[0]->getType())); - enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( - fdDeclContext, noLoc, arrayRefToArrayStmt->getType())); - - // Get the Parameter Types in Function Signature - llvm::SmallVector enzymeParamsType; - for (auto i : enzymeParams) - enzymeParamsType.push_back(i->getType()); - - // Prepare Function call - std::string enzymeCallName = - "__enzyme_autodiff_" + m_Function->getNameAsString(); - IdentifierInfo* IIEnzyme = &m_Context.Idents.get(enzymeCallName); - DeclarationName nameEnzyme(IIEnzyme); - QualType enzymeFunctionType = m_Sema.BuildFunctionType( - m_Context.VoidTy, enzymeParamsType, noLoc, nameEnzyme, - originalFnType->getExtProtoInfo()); - FunctionDecl* enzymeCallFD = FunctionDecl::Create( - m_Context, const_cast(m_Function->getDeclContext()), - noLoc, noLoc, nameEnzyme, enzymeFunctionType, - m_Function->getTypeSourceInfo(), SC_Extern); - enzymeCallFD->setParams(enzymeParams); - - // Add Function call to block - Expr* enzymeCall = BuildCallExprToFunction(enzymeCallFD, enzymeArgs); - addToCurrentBlock(enzymeCall); + llvm::SmallVector enzymeParamsType; + for (auto i : enzymeParams) + enzymeParamsType.push_back(i->getType()); + + QualType QT; + if (enzymeRealParams.size()) { + // Find the EnzymeGradient datastructure + auto gradDecl = LookupTemplateDeclInCladNamespace("EnzymeGradient"); + + TemplateArgumentListInfo TLI{}; + llvm::APSInt argValue(std::to_string(enzymeRealParams.size())); + TemplateArgument TA(m_Context, argValue, m_Context.UnsignedIntTy); + TLI.addArgument(TemplateArgumentLoc(TA, TemplateArgumentLocInfo())); + + QT = InstantiateTemplate(gradDecl, TLI); + } else { + QT = m_Context.VoidTy; + } + + // Prepare Function call + std::string enzymeCallName = + "__enzyme_autodiff_" + m_Function->getNameAsString(); + IdentifierInfo* IIEnzyme = &m_Context.Idents.get(enzymeCallName); + DeclarationName nameEnzyme(IIEnzyme); + QualType enzymeFunctionType = + m_Sema.BuildFunctionType(QT, enzymeParamsType, noLoc, nameEnzyme, + originalFnType->getExtProtoInfo()); + FunctionDecl* enzymeCallFD = FunctionDecl::Create( + m_Context, fdDeclContext, noLoc, noLoc, nameEnzyme, enzymeFunctionType, + m_Function->getTypeSourceInfo(), SC_Extern); + enzymeCallFD->setParams(enzymeParams); + Expr* enzymeCall = BuildCallExprToFunction(enzymeCallFD, enzymeArgs); + + // Prepare the statements that assign the gradients to + // non array/pointer type parameters of the original function + if (enzymeRealParams.size() != 0) { + auto gradDeclStmt = BuildVarDecl(QT, "grad", enzymeCall, true); + addToCurrentBlock(BuildDeclStmt(gradDeclStmt), direction::forward); + + for (int i = 0; i < enzymeRealParams.size(); i++) { + auto LHSExpr = BuildOp(UO_Deref, BuildDeclRef(enzymeRealParamsRef[i])); + + auto ME = utils::BuildMemberExpr(m_Sema, getCurrentScope(), + BuildDeclRef(gradDeclStmt), "d_arr"); + + Expr* gradIndex = dyn_cast( + IntegerLiteral::Create(m_Context, llvm::APSInt(std::to_string(i)), + m_Context.UnsignedIntTy, noLoc)); + Expr* RHSExpr = + m_Sema.CreateBuiltinArraySubscriptExpr(ME, noLoc, gradIndex, noLoc) + .get(); + + auto assignExpr = BuildOp(BO_Assign, LHSExpr, RHSExpr); + addToCurrentBlock(assignExpr, direction::forward); } + } else { + // Add Function call to block + Expr* enzymeCall = BuildCallExprToFunction(enzymeCallFD, enzymeArgs); + addToCurrentBlock(enzymeCall); } } StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) { diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 2f7cf1876..127829aeb 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -349,7 +349,8 @@ namespace clad { return Result; } - TemplateDecl* VisitorBase::GetCladClassDecl(llvm::StringRef ClassName) { + TemplateDecl* + VisitorBase::LookupTemplateDeclInCladNamespace(llvm::StringRef ClassName) { NamespaceDecl* CladNS = GetCladNamespace(); CXXScopeSpec CSS; CSS.Extend(m_Context, CladNS, noLoc, noLoc); @@ -365,16 +366,8 @@ namespace clad { return cast(TapeR.getFoundDecl()); } - QualType - VisitorBase::GetCladClassOfType(TemplateDecl* CladClassDecl, - ArrayRef TemplateArgs) { - // Create a list of template arguments. - TemplateArgumentListInfo TLI{}; - for (auto T : TemplateArgs) { - TemplateArgument TA = T; - TLI.addArgument( - TemplateArgumentLoc(TA, m_Context.getTrivialTypeSourceInfo(T))); - } + QualType VisitorBase::InstantiateTemplate(TemplateDecl* CladClassDecl, + TemplateArgumentListInfo& TLI) { // This will instantiate tape type and return it. QualType TT = m_Sema.CheckTemplateIdType(TemplateName(CladClassDecl), noLoc, TLI); @@ -388,10 +381,23 @@ namespace clad { return m_Context.getElaboratedType(ETK_None, NS, TT); } + QualType VisitorBase::InstantiateTemplate(TemplateDecl* CladClassDecl, + ArrayRef TemplateArgs) { + // Create a list of template arguments. + TemplateArgumentListInfo TLI{}; + for (auto T : TemplateArgs) { + TemplateArgument TA = T; + TLI.addArgument( + TemplateArgumentLoc(TA, m_Context.getTrivialTypeSourceInfo(T))); + } + + return VisitorBase::InstantiateTemplate(CladClassDecl, TLI); + } + TemplateDecl* VisitorBase::GetCladTapeDecl() { static TemplateDecl* Result = nullptr; if (!Result) - Result = GetCladClassDecl(/*ClassName=*/"tape"); + Result = LookupTemplateDeclInCladNamespace(/*ClassName=*/"tape"); return Result; } @@ -441,7 +447,7 @@ namespace clad { } QualType VisitorBase::GetCladTapeOfType(QualType T) { - return GetCladClassOfType(GetCladTapeDecl(), {T}); + return InstantiateTemplate(GetCladTapeDecl(), {T}); } Expr* VisitorBase::BuildCallExprToMemFn(Expr* Base, @@ -555,23 +561,23 @@ namespace clad { TemplateDecl* VisitorBase::GetCladArrayRefDecl() { static TemplateDecl* Result = nullptr; if (!Result) - Result = GetCladClassDecl(/*ClassName=*/"array_ref"); + Result = LookupTemplateDeclInCladNamespace(/*ClassName=*/"array_ref"); return Result; } QualType VisitorBase::GetCladArrayRefOfType(clang::QualType T) { - return GetCladClassOfType(GetCladArrayRefDecl(), {T}); + return InstantiateTemplate(GetCladArrayRefDecl(), {T}); } TemplateDecl* VisitorBase::GetCladArrayDecl() { static TemplateDecl* Result = nullptr; if (!Result) - Result = GetCladClassDecl(/*ClassName=*/"array"); + Result = LookupTemplateDeclInCladNamespace(/*ClassName=*/"array"); return Result; } QualType VisitorBase::GetCladArrayOfType(clang::QualType T) { - return GetCladClassOfType(GetCladArrayDecl(), {T}); + return InstantiateTemplate(GetCladArrayDecl(), {T}); } Expr* VisitorBase::BuildArrayRefSizeExpr(Expr* Base) { diff --git a/test/Enzyme/ReverseMode.C b/test/Enzyme/ReverseMode.C index fd72df17e..98aa4a877 100644 --- a/test/Enzyme/ReverseMode.C +++ b/test/Enzyme/ReverseMode.C @@ -5,17 +5,112 @@ #include "clad/Differentiator/Differentiator.h" -double f(double* arr) { return arr[0] * arr[1]; } +double f1(double* arr) { return arr[0] * arr[1]; } -// CHECK: void f_grad_enzyme(double *arr, clad::array_ref _d_arr) { +// CHECK: void f1_grad_enzyme(double *arr, clad::array_ref _d_arr) { // CHECK-NEXT: double *d_arr = _d_arr.ptr(); -// CHECK-NEXT: __enzyme_autodiff_f(f, arr, d_arr); +// CHECK-NEXT: __enzyme_autodiff_f1(f1, arr, d_arr); // CHECK-NEXT:} +double f2(double x, double y, double z){ + return x * y * z; +} + +// CHECK: void f2_grad_enzyme(double x, double y, double z, clad::array_ref _d_x, clad::array_ref _d_y, clad::array_ref _d_z) { +// CHECK-NEXT: clad::EnzymeGradient<3> grad = __enzyme_autodiff_f2(f2, x, y, z); +// CHECK-NEXT: * _d_x = grad.d_arr[0U]; +// CHECK-NEXT: * _d_y = grad.d_arr[1U]; +// CHECK-NEXT: * _d_z = grad.d_arr[2U]; +// CHECK-NEXT:} + +double f3(double* arr, int n){ + double sum=0; + for(int i=0;i _d_arr, clad::array_ref _d_n) { +// CHECK-NEXT: double *d_arr = _d_arr.ptr(); +// CHECK-NEXT: __enzyme_autodiff_f3(f3, arr, d_arr, n); +// CHECK-NEXT: } + +double f4(double* arr1, int n, double*arr2, int m){ + double sum=0; + for(int i=0;i _d_arr1, clad::array_ref _d_n, clad::array_ref _d_arr2, clad::array_ref _d_m) { +// CHECK-NEXT: double *d_arr1 = _d_arr1.ptr(); +// CHECK-NEXT: double *d_arr2 = _d_arr2.ptr(); +// CHECK-NEXT: __enzyme_autodiff_f4(f4, arr1, d_arr1, n, arr2, d_arr2, m); +// CHECK-NEXT: } + +double f5(double arr[], double x,int n,double y){ + double res=0; + for(int i=0;i _d_arr, clad::array_ref _d_x, clad::array_ref _d_n, clad::array_ref _d_y) { +// CHECK-NEXT: double *d_arr = _d_arr.ptr(); +// CHECK-NEXT: clad::EnzymeGradient<2> grad = __enzyme_autodiff_f5(f5, arr, d_arr, x, n, y); +// CHECK-NEXT: * _d_x = grad.d_arr[0U]; +// CHECK-NEXT: * _d_y = grad.d_arr[1U]; +// CHECK-NEXT: } + + int main() { - auto f_grad = clad::gradient(f); - double v[2] = {3, 4}; - double g[2] = {0}; - f_grad.execute(v, g); - printf("d_x = %.2f, d_y = %.2f", g[0], g[1]); // CHECK-EXEC: d_x = 4.00, d_y = 3.00 + auto f1_grad = clad::gradient(f1); + double f1_v[2] = {3, 4}; + double f1_g[2] = {0}; + f1_grad.execute(f1_v, f1_g); + printf("d_x = %.2f, d_y = %.2f\n", f1_g[0], f1_g[1]); + // CHECK-EXEC: d_x = 4.00, d_y = 3.00 + + auto f2_grad=clad::gradient(f2); + double f2_res[3]; + double f2_x=3,f2_y=4,f2_z=5; + f2_grad.execute(f2_x,f2_y,f2_z,&f2_res[0],&f2_res[1],&f2_res[2]); + printf("d_x = %.2f, d_y = %.2f, d_z = %.2f\n", f2_res[0], f2_res[1], f2_res[2]); + //CHECK-EXEC: d_x = 20.00, d_y = 15.00, d_z = 12.00 + + auto f3_grad=clad::gradient(f3); + double f3_list[3]={3,4,5}; + double f3_res[3]={0}; + int f3_dn=0; + f3_grad.execute(f3_list,3,f3_res,&f3_dn); + printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_n = %d\n",f3_res[0],f3_res[1],f3_res[2],f3_dn); + //CHECK-EXEC: d_x1 = 6.00, d_x2 = 8.00, d_x3 = 10.00, d_n = 0 + + auto f4_grad=clad::gradient(f4); + double f4_list1[3]={3,4,5}; + double f4_list2[2]={1,2}; + double f4_res1[3]={0}; + double f4_res2[2]={0}; + int f4_dn1=0,f4_dn2=0; + f4_grad.execute(f4_list1,3,f4_list2,2,f4_res1,&f4_dn1,f4_res2,&f4_dn2); + printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_n1 = %d\n",f4_res1[0],f4_res1[1],f4_res1[2],f4_dn1); + //CHECK-EXEC: d_x1 = 6.00, d_x2 = 8.00, d_x3 = 10.00, d_n1 = 0 + printf("d_y1 = %.2f, d_y2 = %.2f, d_n2 = %d\n",f4_res2[0],f4_res2[1],f4_dn2); + //CHECK-EXEC: d_y1 = 2.00, d_y2 = 4.00, d_n2 = 0 + + auto f5_grad=clad::gradient(f5); + double f5_list[3]={3,4,5}; + double f5_res[3]={0}; + double f5_x=10.0,f5_dx=0,f5_y=5,f5_dy; + int f5_dn=0; + f5_grad.execute(f5_list,f5_x,3,f5_y,f5_res,&f5_dx,&f5_dn,&f5_dy); + printf("d_x1 = %.2f, d_x2 = %.2f, d_x3 = %.2f, d_n1 = %d, d_x = %.2f, d_y = %.2f\n",f5_res[0],f5_res[1],f5_res[2],f5_dn, f5_dx, f5_dy); + //CHECK-EXEC: d_x1 = 50.00, d_x2 = 50.00, d_x3 = 50.00, d_n1 = 0, d_x = 60.00, d_y = 120.00 + } diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 33b824c87..62c60d3b6 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -48,7 +48,7 @@ if (NOT CLAD_BUILD_STATIC_ONLY) ExternalProject_Add( Enzyme GIT_REPOSITORY https://github.com/wsmoses/Enzyme - GIT_TAG v0.0.33 + GIT_TAG v0.0.36 GIT_SHALLOW 1 # Do not clone the history PATCH_COMMAND ${_enzyme_patch_command} UPDATE_COMMAND ""