Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Basic Forward Mode Differentiation with Enzyme #496

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/clad/Differentiator/ForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"

#include "clad/Differentiator/DiffPlanner.h"


#include <array>
#include <stack>
#include <unordered_map>
Expand All @@ -28,6 +31,14 @@ namespace clad {
unsigned m_IndependentVarIndex = ~0;
unsigned m_DerivativeOrder = ~0;
unsigned m_ArgIndex = ~0;
DiffInputVarsInfo m_DVI;
bool use_enzyme = false;

// Function to Differentiate with Clad as Backend
void DifferentiateWithClad();

// Function to Differentiate with Enzyme as Backend
void DifferentiateWithEnzyme();

public:
ForwardModeVisitor(DerivativeBuilder& builder);
Expand Down
116 changes: 103 additions & 13 deletions lib/Differentiator/ForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "ConstantFolder.h"

#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/StmtClone.h"

Expand Down Expand Up @@ -214,8 +213,7 @@ namespace clad {
m_DerivativeInFlight = true;

DiffInputVarsInfo DVI = request.DVI;

DVI = request.DVI;
m_DVI=DVI;
Copy link
Collaborator

@parth-07 parth-07 Oct 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are saving most of the information of DiffRequest request as separate members in the visitor classes. @vgvassilev Should we directly save the DiffRequest as a member in the visitor classes?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we can add it as a const reference where we could read it. Makes sense, thanks for the suggestion Parth.

@Nirhar, do you have the bandwidth to continue with this PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I wont be able to work on this now, I can contribute to this after mid-december


// FIXME: Shouldn't we give error here that no arg is specified?
if (DVI.empty())
Expand All @@ -239,6 +237,10 @@ namespace clad {
// FIXME: implement gradient-vector products to fix the issue.
assert((DVI.size() == 1) &&
"nested forward mode differentiation for several args is broken");

// Check if DiffRequest asks for use of enzyme as backend
if (request.use_enzyme)
use_enzyme = true;

// FIXME: Differentiation variable cannot always be represented just by
// `ValueDecl*` variable. For example -- `u.mem1.mem2,`, `arr[7]` etc.
Expand Down Expand Up @@ -307,6 +309,9 @@ namespace clad {
for (auto field : diffVarInfo.fields)
argInfo += "_" + field;

if(use_enzyme)
derivativeSuffix+="_enzyme";

IdentifierInfo* II =
&m_Context.Idents.get(request.BaseFunctionName + "_d" + s + "arg" +
argInfo + derivativeSuffix);
Expand Down Expand Up @@ -371,6 +376,33 @@ namespace clad {
beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
beginBlock();

if(!use_enzyme)
DifferentiateWithClad();
else
DifferentiateWithEnzyme();


Stmt* derivativeBody = endBlock();
derivedFD->setBody(derivativeBody);

endScope(); // Function body scope
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope(); // Function decl scope

m_DerivativeInFlight = false;

return DerivativeAndOverload{result.first,
/*OverloadFunctionDecl=*/nullptr};
}

void ForwardModeVisitor::DifferentiateWithClad(){
const clang::FunctionDecl* FD = m_Function;
llvm::ArrayRef<ParmVarDecl*> params = m_Derivative->parameters();
DiffInputVarInfo diffVarInfo = m_DVI.back();


// For each function parameter variable, store its derivative value.
for (auto param : params) {
// We cannot create derivatives of reference type since seed value is
Expand Down Expand Up @@ -507,18 +539,76 @@ namespace clad {
addToCurrentBlock(S);
else
addToCurrentBlock(BodyDiff);
Stmt* derivativeBody = endBlock();
derivedFD->setBody(derivativeBody);

endScope(); // Function body scope
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope(); // Function decl scope
}

m_DerivativeInFlight = false;
void ForwardModeVisitor::DifferentiateWithEnzyme(){
llvm::ArrayRef<ParmVarDecl*> paramsRef = m_Derivative->parameters();
unsigned numParams = m_Function->getNumParams();
auto originalFnType = dyn_cast<FunctionProtoType>(m_Function->getType());
auto returnType = m_Derivative->getReturnType();


// Prepare Arguments and Parameters to enzyme_autodiff
llvm::SmallVector<Expr*, 16> enzymeArgs;
llvm::SmallVector<ParmVarDecl*, 16> enzymeParams;

// First add the function itself as a parameter/argument
enzymeArgs.push_back(BuildDeclRef(const_cast<FunctionDecl*>(m_Function)));
DeclContext* fdDeclContext =
const_cast<DeclContext*>(m_Function->getDeclContext());
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, m_Function->getType()));

// Add rest of the parameters/arguments
for (int i = 0; i < numParams; i++) {
auto paramType = paramsRef[i]->getType().getNonReferenceType();

// First Add the original parameter
enzymeArgs.push_back(BuildDeclRef(paramsRef[i]));
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, paramType));

float fDx=0;
if(paramsRef[i]==m_IndependentVar)
fDx=1;
Expr* dx = dyn_cast<Expr>(
FloatingLiteral::Create(m_Context, llvm::APFloat(fDx),
false,paramType, noLoc));

// Then add the dx argument
enzymeArgs.push_back(dx);
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, paramType));
}

return DerivativeAndOverload{result.first,
/*OverloadFunctionDecl=*/nullptr};
llvm::SmallVector<QualType, 16> enzymeParamsType;
for (auto i : enzymeParams)
enzymeParamsType.push_back(i->getType());

// Prepare Function call
std::string enzymeCallName =
"__enzyme_fwddiff_" + m_Function->getNameAsString()+"_"+m_IndependentVar->getNameAsString();
IdentifierInfo* IIEnzyme = &m_Context.Idents.get(enzymeCallName);
DeclarationName nameEnzyme(IIEnzyme);
QualType enzymeFunctionType =
m_Sema.BuildFunctionType(returnType, enzymeParamsType, noLoc, nameEnzyme,
originalFnType->getExtProtoInfo());
FunctionDecl* enzymeCallFD = FunctionDecl::Create(
m_Context, fdDeclContext, noLoc, noLoc, nameEnzyme, enzymeFunctionType,
m_Function->getTypeSourceInfo(), SC_Extern);
enzymeCallFD->setParams(enzymeParams);
Expr* enzymeCall = BuildCallExprToFunction(enzymeCallFD, enzymeArgs);

auto diffDecl = BuildVarDecl(returnType, "diff", enzymeCall, true);
addToCurrentBlock(BuildDeclStmt(diffDecl));

// auto type = m_IndependentVar->getType().getNonReferenceType();
// Expr* retExpr = dyn_cast<Expr>(
// FloatingLiteral::Create(m_Context, llvm::APFloat(5.01),
// true,type, noLoc));
Stmt* returnStmt = m_Sema.BuildReturnStmt(noLoc,BuildDeclRef(diffDecl)).get();

addToCurrentBlock(returnStmt);
}

StmtDiff ForwardModeVisitor::VisitStmt(const Stmt* S) {
Expand Down
15 changes: 10 additions & 5 deletions test/Enzyme/ForwardMode.C
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
// RUN: %cladclang %s -lstdc++ -I%S/../../include -oEnzyme.out 2>&1 | FileCheck %s
// RUN: ./Enzyme.out
// RUN: ./Enzyme.out | FileCheck -check-prefix=CHECK-EXEC %s
// CHECK-NOT: {{.*error|warning|note:.*}}
// REQUIRES: Enzyme
// XFAIL:*
// Forward mode is not implemented yet

#include "clad/Differentiator/Differentiator.h"

double f(double x, double y) { return x * y; }
double f(double x, double y) {
return x * y;
}

// CHECK: double f_darg0_enzyme(double x, double y) {
// CHECK-NEXT:}
// CHECK: double f_darg0_enzyme(double x, double y) {
// CHECK-NEXT: double diff = __enzyme_fwddiff_f_x(f, x, 1., y, 0.);
// CHECK-NEXT: return diff;
// CHECK-NEXT: }

int main(){
auto f_dx = clad::differentiate<clad::opts::use_enzyme>(f, "x");
double ans = f_dx.execute(3,4);
printf("Ans = %.2f\n",ans); // CHECK-EXEC: Ans = 4.00
}