diff --git a/include/clad/Differentiator/ForwardModeVisitor.h b/include/clad/Differentiator/ForwardModeVisitor.h index 765b9bed7..7fe16c404 100644 --- a/include/clad/Differentiator/ForwardModeVisitor.h +++ b/include/clad/Differentiator/ForwardModeVisitor.h @@ -13,6 +13,9 @@ #include "clang/AST/StmtVisitor.h" #include "clang/Sema/Sema.h" +#include "clad/Differentiator/DiffPlanner.h" + + #include #include #include @@ -28,6 +31,14 @@ namespace clad { unsigned m_IndependentVarIndex = ~0; unsigned m_DerivativeOrder = ~0; unsigned m_ArgIndex = ~0; + DiffInputVarsInfo m_DVI; + bool use_enzyme = false; + + // Function to Differentiate with Clad as Backend + void DifferentiateWithClad(); + + // Function to Differentiate with Enzyme as Backend + void DifferentiateWithEnzyme(); public: ForwardModeVisitor(DerivativeBuilder& builder); diff --git a/lib/Differentiator/ForwardModeVisitor.cpp b/lib/Differentiator/ForwardModeVisitor.cpp index b774737b1..f0bb59f70 100644 --- a/lib/Differentiator/ForwardModeVisitor.cpp +++ b/lib/Differentiator/ForwardModeVisitor.cpp @@ -9,7 +9,6 @@ #include "ConstantFolder.h" #include "clad/Differentiator/CladUtils.h" -#include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/ErrorEstimator.h" #include "clad/Differentiator/StmtClone.h" @@ -214,8 +213,7 @@ namespace clad { m_DerivativeInFlight = true; DiffInputVarsInfo DVI = request.DVI; - - DVI = request.DVI; + m_DVI=DVI; // FIXME: Shouldn't we give error here that no arg is specified? if (DVI.empty()) @@ -239,6 +237,10 @@ namespace clad { // FIXME: implement gradient-vector products to fix the issue. assert((DVI.size() == 1) && "nested forward mode differentiation for several args is broken"); + + // Check if DiffRequest asks for use of enzyme as backend + if (request.use_enzyme) + use_enzyme = true; // FIXME: Differentiation variable cannot always be represented just by // `ValueDecl*` variable. For example -- `u.mem1.mem2,`, `arr[7]` etc. @@ -307,6 +309,9 @@ namespace clad { for (auto field : diffVarInfo.fields) argInfo += "_" + field; + if(use_enzyme) + derivativeSuffix+="_enzyme"; + IdentifierInfo* II = &m_Context.Idents.get(request.BaseFunctionName + "_d" + s + "arg" + argInfo + derivativeSuffix); @@ -371,6 +376,33 @@ namespace clad { beginScope(Scope::FnScope | Scope::DeclScope); m_DerivativeFnScope = getCurrentScope(); beginBlock(); + + if(!use_enzyme) + DifferentiateWithClad(); + else + DifferentiateWithEnzyme(); + + + Stmt* derivativeBody = endBlock(); + derivedFD->setBody(derivativeBody); + + endScope(); // Function body scope + m_Sema.PopFunctionScopeInfo(); + m_Sema.PopDeclContext(); + endScope(); // Function decl scope + + m_DerivativeInFlight = false; + + return DerivativeAndOverload{result.first, + /*OverloadFunctionDecl=*/nullptr}; + } + + void ForwardModeVisitor::DifferentiateWithClad(){ + const clang::FunctionDecl* FD = m_Function; + llvm::ArrayRef params = m_Derivative->parameters(); + DiffInputVarInfo diffVarInfo = m_DVI.back(); + + // For each function parameter variable, store its derivative value. for (auto param : params) { // We cannot create derivatives of reference type since seed value is @@ -507,18 +539,76 @@ namespace clad { addToCurrentBlock(S); else addToCurrentBlock(BodyDiff); - Stmt* derivativeBody = endBlock(); - derivedFD->setBody(derivativeBody); - - endScope(); // Function body scope - m_Sema.PopFunctionScopeInfo(); - m_Sema.PopDeclContext(); - endScope(); // Function decl scope + } - m_DerivativeInFlight = false; + void ForwardModeVisitor::DifferentiateWithEnzyme(){ + llvm::ArrayRef paramsRef = m_Derivative->parameters(); + unsigned numParams = m_Function->getNumParams(); + auto originalFnType = dyn_cast(m_Function->getType()); + auto returnType = m_Derivative->getReturnType(); + + + // Prepare Arguments and Parameters to enzyme_autodiff + llvm::SmallVector enzymeArgs; + llvm::SmallVector enzymeParams; + + // First add the function itself as a parameter/argument + enzymeArgs.push_back(BuildDeclRef(const_cast(m_Function))); + DeclContext* fdDeclContext = + const_cast(m_Function->getDeclContext()); + enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( + fdDeclContext, noLoc, m_Function->getType())); + + // Add rest of the parameters/arguments + for (int i = 0; i < numParams; i++) { + auto paramType = paramsRef[i]->getType().getNonReferenceType(); + + // First Add the original parameter + enzymeArgs.push_back(BuildDeclRef(paramsRef[i])); + enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( + fdDeclContext, noLoc, paramType)); + + float fDx=0; + if(paramsRef[i]==m_IndependentVar) + fDx=1; + Expr* dx = dyn_cast( + FloatingLiteral::Create(m_Context, llvm::APFloat(fDx), + false,paramType, noLoc)); + + // Then add the dx argument + enzymeArgs.push_back(dx); + enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( + fdDeclContext, noLoc, paramType)); + } - return DerivativeAndOverload{result.first, - /*OverloadFunctionDecl=*/nullptr}; + llvm::SmallVector enzymeParamsType; + for (auto i : enzymeParams) + enzymeParamsType.push_back(i->getType()); + + // Prepare Function call + std::string enzymeCallName = + "__enzyme_fwddiff_" + m_Function->getNameAsString()+"_"+m_IndependentVar->getNameAsString(); + IdentifierInfo* IIEnzyme = &m_Context.Idents.get(enzymeCallName); + DeclarationName nameEnzyme(IIEnzyme); + QualType enzymeFunctionType = + m_Sema.BuildFunctionType(returnType, enzymeParamsType, noLoc, nameEnzyme, + originalFnType->getExtProtoInfo()); + FunctionDecl* enzymeCallFD = FunctionDecl::Create( + m_Context, fdDeclContext, noLoc, noLoc, nameEnzyme, enzymeFunctionType, + m_Function->getTypeSourceInfo(), SC_Extern); + enzymeCallFD->setParams(enzymeParams); + Expr* enzymeCall = BuildCallExprToFunction(enzymeCallFD, enzymeArgs); + + auto diffDecl = BuildVarDecl(returnType, "diff", enzymeCall, true); + addToCurrentBlock(BuildDeclStmt(diffDecl)); + + // auto type = m_IndependentVar->getType().getNonReferenceType(); + // Expr* retExpr = dyn_cast( + // FloatingLiteral::Create(m_Context, llvm::APFloat(5.01), + // true,type, noLoc)); + Stmt* returnStmt = m_Sema.BuildReturnStmt(noLoc,BuildDeclRef(diffDecl)).get(); + + addToCurrentBlock(returnStmt); } StmtDiff ForwardModeVisitor::VisitStmt(const Stmt* S) { diff --git a/test/Enzyme/ForwardMode.C b/test/Enzyme/ForwardMode.C index dce7dbc27..b22470ae2 100644 --- a/test/Enzyme/ForwardMode.C +++ b/test/Enzyme/ForwardMode.C @@ -1,17 +1,22 @@ // RUN: %cladclang %s -lstdc++ -I%S/../../include -oEnzyme.out 2>&1 | FileCheck %s -// RUN: ./Enzyme.out +// RUN: ./Enzyme.out | FileCheck -check-prefix=CHECK-EXEC %s // CHECK-NOT: {{.*error|warning|note:.*}} // REQUIRES: Enzyme -// XFAIL:* // Forward mode is not implemented yet #include "clad/Differentiator/Differentiator.h" -double f(double x, double y) { return x * y; } +double f(double x, double y) { + return x * y; +} -// CHECK: double f_darg0_enzyme(double x, double y) { -// CHECK-NEXT:} +// CHECK: double f_darg0_enzyme(double x, double y) { +// CHECK-NEXT: double diff = __enzyme_fwddiff_f_x(f, x, 1., y, 0.); +// CHECK-NEXT: return diff; +// CHECK-NEXT: } int main(){ auto f_dx = clad::differentiate(f, "x"); + double ans = f_dx.execute(3,4); + printf("Ans = %.2f\n",ans); // CHECK-EXEC: Ans = 4.00 } \ No newline at end of file