Skip to content

Commit

Permalink
Add support for Basic Forward Mode Differentiation with Enzyme
Browse files Browse the repository at this point in the history
This commit adds support for differentiation of a function of type:
```cpp
double func(double x, double y){
	return x*y;
}
```
  • Loading branch information
Nirhar committed Sep 3, 2022
1 parent 5be9796 commit 0ccb8a4
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 18 deletions.
11 changes: 11 additions & 0 deletions include/clad/Differentiator/ForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"

#include "clad/Differentiator/DiffPlanner.h"


#include <array>
#include <stack>
#include <unordered_map>
Expand All @@ -28,6 +31,14 @@ namespace clad {
unsigned m_IndependentVarIndex = ~0;
unsigned m_DerivativeOrder = ~0;
unsigned m_ArgIndex = ~0;
DiffInputVarsInfo m_DVI;
bool use_enzyme = false;

// Function to Differentiate with Clad as Backend
void DifferentiateWithClad();

// Function to Differentiate with Enzyme as Backend
void DifferentiateWithEnzyme();

public:
ForwardModeVisitor(DerivativeBuilder& builder);
Expand Down
116 changes: 103 additions & 13 deletions lib/Differentiator/ForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "ConstantFolder.h"

#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/StmtClone.h"

Expand Down Expand Up @@ -214,8 +213,7 @@ namespace clad {
m_DerivativeInFlight = true;

DiffInputVarsInfo DVI = request.DVI;

DVI = request.DVI;
m_DVI=DVI;

// FIXME: Shouldn't we give error here that no arg is specified?
if (DVI.empty())
Expand All @@ -239,6 +237,10 @@ namespace clad {
// FIXME: implement gradient-vector products to fix the issue.
assert((DVI.size() == 1) &&
"nested forward mode differentiation for several args is broken");

// Check if DiffRequest asks for use of enzyme as backend
if (request.use_enzyme)
use_enzyme = true;

// FIXME: Differentiation variable cannot always be represented just by
// `ValueDecl*` variable. For example -- `u.mem1.mem2,`, `arr[7]` etc.
Expand Down Expand Up @@ -307,6 +309,9 @@ namespace clad {
for (auto field : diffVarInfo.fields)
argInfo += "_" + field;

if(use_enzyme)
derivativeSuffix+="_enzyme";

IdentifierInfo* II =
&m_Context.Idents.get(request.BaseFunctionName + "_d" + s + "arg" +
argInfo + derivativeSuffix);
Expand Down Expand Up @@ -371,6 +376,33 @@ namespace clad {
beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
beginBlock();

if(!use_enzyme)
DifferentiateWithClad();
else
DifferentiateWithEnzyme();


Stmt* derivativeBody = endBlock();
derivedFD->setBody(derivativeBody);

endScope(); // Function body scope
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope(); // Function decl scope

m_DerivativeInFlight = false;

return DerivativeAndOverload{result.first,
/*OverloadFunctionDecl=*/nullptr};
}

void ForwardModeVisitor::DifferentiateWithClad(){
const clang::FunctionDecl* FD = m_Function;
llvm::ArrayRef<ParmVarDecl*> params = m_Derivative->parameters();
DiffInputVarInfo diffVarInfo = m_DVI.back();


// For each function parameter variable, store its derivative value.
for (auto param : params) {
// We cannot create derivatives of reference type since seed value is
Expand Down Expand Up @@ -507,18 +539,76 @@ namespace clad {
addToCurrentBlock(S);
else
addToCurrentBlock(BodyDiff);
Stmt* derivativeBody = endBlock();
derivedFD->setBody(derivativeBody);

endScope(); // Function body scope
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope(); // Function decl scope
}

m_DerivativeInFlight = false;
void ForwardModeVisitor::DifferentiateWithEnzyme(){
llvm::ArrayRef<ParmVarDecl*> paramsRef = m_Derivative->parameters();
unsigned numParams = m_Function->getNumParams();
auto originalFnType = dyn_cast<FunctionProtoType>(m_Function->getType());
auto returnType = m_Derivative->getReturnType();


// Prepare Arguments and Parameters to enzyme_autodiff
llvm::SmallVector<Expr*, 16> enzymeArgs;
llvm::SmallVector<ParmVarDecl*, 16> enzymeParams;

// First add the function itself as a parameter/argument
enzymeArgs.push_back(BuildDeclRef(const_cast<FunctionDecl*>(m_Function)));
DeclContext* fdDeclContext =
const_cast<DeclContext*>(m_Function->getDeclContext());
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, m_Function->getType()));

// Add rest of the parameters/arguments
for (int i = 0; i < numParams; i++) {
auto paramType = paramsRef[i]->getType().getNonReferenceType();

// First Add the original parameter
enzymeArgs.push_back(BuildDeclRef(paramsRef[i]));
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, paramType));

float fDx=0;
if(paramsRef[i]==m_IndependentVar)
fDx=1;
Expr* dx = dyn_cast<Expr>(
FloatingLiteral::Create(m_Context, llvm::APFloat(fDx),
false,paramType, noLoc));

// Then add the dx argument
enzymeArgs.push_back(dx);
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, paramType));
}

return DerivativeAndOverload{result.first,
/*OverloadFunctionDecl=*/nullptr};
llvm::SmallVector<QualType, 16> enzymeParamsType;
for (auto i : enzymeParams)
enzymeParamsType.push_back(i->getType());

// Prepare Function call
std::string enzymeCallName =
"__enzyme_fwddiff_" + m_Function->getNameAsString()+"_"+m_IndependentVar->getNameAsString();
IdentifierInfo* IIEnzyme = &m_Context.Idents.get(enzymeCallName);
DeclarationName nameEnzyme(IIEnzyme);
QualType enzymeFunctionType =
m_Sema.BuildFunctionType(returnType, enzymeParamsType, noLoc, nameEnzyme,
originalFnType->getExtProtoInfo());
FunctionDecl* enzymeCallFD = FunctionDecl::Create(
m_Context, fdDeclContext, noLoc, noLoc, nameEnzyme, enzymeFunctionType,
m_Function->getTypeSourceInfo(), SC_Extern);
enzymeCallFD->setParams(enzymeParams);
Expr* enzymeCall = BuildCallExprToFunction(enzymeCallFD, enzymeArgs);

auto diffDecl = BuildVarDecl(returnType, "diff", enzymeCall, true);
addToCurrentBlock(BuildDeclStmt(diffDecl));

// auto type = m_IndependentVar->getType().getNonReferenceType();
// Expr* retExpr = dyn_cast<Expr>(
// FloatingLiteral::Create(m_Context, llvm::APFloat(5.01),
// true,type, noLoc));
Stmt* returnStmt = m_Sema.BuildReturnStmt(noLoc,BuildDeclRef(diffDecl)).get();

addToCurrentBlock(returnStmt);
}

StmtDiff ForwardModeVisitor::VisitStmt(const Stmt* S) {
Expand Down
15 changes: 10 additions & 5 deletions test/Enzyme/ForwardMode.C
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
// RUN: %cladclang %s -lstdc++ -I%S/../../include -oEnzyme.out 2>&1 | FileCheck %s
// RUN: ./Enzyme.out
// RUN: ./Enzyme.out | FileCheck -check-prefix=CHECK-EXEC %s
// CHECK-NOT: {{.*error|warning|note:.*}}
// REQUIRES: Enzyme
// XFAIL:*
// Forward mode is not implemented yet

#include "clad/Differentiator/Differentiator.h"

double f(double x, double y) { return x * y; }
double f(double x, double y) {
return x * y;
}

// CHECK: double f_darg0_enzyme(double x, double y) {
// CHECK-NEXT:}
// CHECK: double f_darg0_enzyme(double x, double y) {
// CHECK-NEXT: double diff = __enzyme_fwddiff_f_x(f, x, 1., y, 0.);
// CHECK-NEXT: return diff;
// CHECK-NEXT: }

int main(){
auto f_dx = clad::differentiate<clad::opts::use_enzyme>(f, "x");
double ans = f_dx.execute(3,4);
printf("Ans = %.2f\n",ans); // CHECK-EXEC: Ans = 4.00
}

0 comments on commit 0ccb8a4

Please sign in to comment.