Skip to content

Commit

Permalink
Add Support for Differentiating functions with both pointer/array typ…
Browse files Browse the repository at this point in the history
…e and primitive type parameters with Enzyme

This commit adds support to differentiate many more function types with enzyme.
Some example supported function types are:

double f(double* arr)
double f(double x, double y, double z)
double f(double* arr, int n)
double f(double* arr1, int n,  double* arr2, int m)
double f(double arr[], double x,int n,double y)

Tests for these have been written in tests/Enzyme/ReverseMode.C
  • Loading branch information
Nirhar authored and vgvassilev committed Aug 24, 2022
1 parent 828baf7 commit 09e2ed0
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 86 deletions.
3 changes: 3 additions & 0 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,9 @@ namespace clad {
DerivedFnType>(derivedFn /* will be replaced by estimation code*/,
code);
}

// Gradient Structure for Reverse Mode Enzyme
template <unsigned N> struct EnzymeGradient { double d_arr[N]; };
}
#endif // CLAD_DIFFERENTIATOR

Expand Down
9 changes: 6 additions & 3 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -372,15 +372,18 @@ namespace clad {
///
/// \param[in] className name of the class to be found
/// \returns The declaration of the class with the name ClassName
clang::TemplateDecl* GetCladClassDecl(llvm::StringRef ClassName);
clang::TemplateDecl*
LookupTemplateDeclInCladNamespace(llvm::StringRef ClassName);
/// Instantiate clad::class<TemplateArgs> type
///
/// \param[in] CladClassDecl the decl of the class that is going to be used
/// in the creation of the type \param[in] TemplateArgs an array of template
/// arguments \returns The created type clad::class<TemplateArgs>
clang::QualType
GetCladClassOfType(clang::TemplateDecl* CladClassDecl,
llvm::ArrayRef<clang::QualType> TemplateArgs);
InstantiateTemplate(clang::TemplateDecl* CladClassDecl,
llvm::ArrayRef<clang::QualType> TemplateArgs);
clang::QualType InstantiateTemplate(clang::TemplateDecl* CladClassDecl,
clang::TemplateArgumentListInfo& TLI);
/// Find declaration of clad::tape templated type.
clang::TemplateDecl* GetCladTapeDecl();
/// Perform a lookup into clad namespace for an entity with given name.
Expand Down
5 changes: 3 additions & 2 deletions lib/Differentiator/ForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,12 @@ namespace clad {
QualType originalFnRT = m_Function->getReturnType();
if (originalFnRT->isVoidType())
return m_Context.VoidTy;
TemplateDecl* valueAndPushforward = GetCladClassDecl("ValueAndPushforward");
TemplateDecl* valueAndPushforward =
LookupTemplateDeclInCladNamespace("ValueAndPushforward");
assert(valueAndPushforward &&
"clad::ValueAndPushforward template not found!!");
QualType RT =
GetCladClassOfType(valueAndPushforward, {originalFnRT, originalFnRT});
InstantiateTemplate(valueAndPushforward, {originalFnRT, originalFnRT});
return RT;
}

Expand Down
169 changes: 114 additions & 55 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,68 +584,127 @@ namespace clad {
}

void ReverseModeVisitor::DifferentiateWithEnzyme() {
// FIXME: Generalize this function to differentiate other kinds
// of function prototypes
unsigned numParams = m_Function->getNumParams();
auto origParams = m_Function->parameters();
llvm::ArrayRef<ParmVarDecl*> paramsRef = m_Derivative->parameters();
auto originalFnType = dyn_cast<FunctionProtoType>(m_Function->getType());

// Case 1: The function to be differentiated is of type double
// func(double* arr){...};
// or double func(double arr[n]){...};
if (numParams == 1) {
QualType origTy = origParams[0]->getOriginalType();
if (origTy->isConstantArrayType() || origTy->isPointerType()) {
// Extract Pointer from Clad Array Ref
auto arrayRefNameExpr = BuildDeclRef(paramsRef[1]);
auto getPointerExpr = BuildCallExprToMemFn(arrayRefNameExpr, "ptr", {});
auto arrayRefToArrayStmt = BuildVarDecl(
origTy, "d_" + paramsRef[0]->getNameAsString(), getPointerExpr);
addToCurrentBlock(BuildDeclStmt(arrayRefToArrayStmt),
direction::forward);
// Extract Pointer from Clad Array Ref
llvm::SmallVector<VarDecl*, 8> cladRefParams;
for (int i = 0; i < numParams; i++) {
QualType paramType = origParams[i]->getOriginalType();
if (paramType->isRealType()) {
cladRefParams.push_back(nullptr);
continue;
}

paramType = m_Context.getPointerType(
QualType(paramType->getPointeeOrArrayElementType(), 0));
auto arrayRefNameExpr = BuildDeclRef(paramsRef[numParams + i]);
auto getPointerExpr = BuildCallExprToMemFn(arrayRefNameExpr, "ptr", {});
auto arrayRefToArrayStmt = BuildVarDecl(
paramType, "d_" + paramsRef[i]->getNameAsString(), getPointerExpr);
addToCurrentBlock(BuildDeclStmt(arrayRefToArrayStmt), direction::forward);
cladRefParams.push_back(arrayRefToArrayStmt);
}
// Prepare Arguments and Parameters to enzyme_autodiff
llvm::SmallVector<Expr*, 16> enzymeArgs;
llvm::SmallVector<ParmVarDecl*, 16> enzymeParams;
llvm::SmallVector<ParmVarDecl*, 16> enzymeRealParams;
llvm::SmallVector<ParmVarDecl*, 16> enzymeRealParamsRef;

// First add the function itself as a parameter/argument
enzymeArgs.push_back(BuildDeclRef(const_cast<FunctionDecl*>(m_Function)));
DeclContext* fdDeclContext =
const_cast<DeclContext*>(m_Function->getDeclContext());
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, m_Function->getType()));

// Add rest of the parameters/arguments
for (int i = 0; i < numParams; i++) {
// First Add the original parameter
enzymeArgs.push_back(BuildDeclRef(paramsRef[i]));
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, paramsRef[i]->getType()));

// If the original parameter is not of array/pointer type, then we don't
// have to extract its pointer from clad array_ref and add it to the
// enzyme parameters, so we can skip the rest of the code
if (!cladRefParams[i]) {
// If original parameter is of a differentiable real type(but not
// array/pointer), then add it to the list of params whose gradient must
// be extracted later from the EnzymeGradient structure
if (paramsRef[i]->getOriginalType()->isRealFloatingType()) {
enzymeRealParams.push_back(paramsRef[i]);
enzymeRealParamsRef.push_back(paramsRef[numParams + i]);
}
continue;
}
// Then add the corresponding clad array ref pointer variable
enzymeArgs.push_back(BuildDeclRef(cladRefParams[i]));
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, cladRefParams[i]->getType()));
}

// Prepare Arguments to enzyme_autodiff
llvm::SmallVector<Expr*, 16> enzymeArgs;
enzymeArgs.push_back(
BuildDeclRef(const_cast<FunctionDecl*>(m_Function)));
enzymeArgs.push_back(BuildDeclRef(paramsRef[0]));
enzymeArgs.push_back(BuildDeclRef(arrayRefToArrayStmt));

// Prepare Parameters for Function Signature
llvm::SmallVector<ParmVarDecl*, 16> enzymeParams;
DeclContext* fdDeclContext =
const_cast<DeclContext*>(m_Function->getDeclContext());
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, m_Function->getType()));
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, paramsRef[0]->getType()));
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, arrayRefToArrayStmt->getType()));

// Get the Parameter Types in Function Signature
llvm::SmallVector<QualType, 8> enzymeParamsType;
for (auto i : enzymeParams)
enzymeParamsType.push_back(i->getType());

// Prepare Function call
std::string enzymeCallName =
"__enzyme_autodiff_" + m_Function->getNameAsString();
IdentifierInfo* IIEnzyme = &m_Context.Idents.get(enzymeCallName);
DeclarationName nameEnzyme(IIEnzyme);
QualType enzymeFunctionType = m_Sema.BuildFunctionType(
m_Context.VoidTy, enzymeParamsType, noLoc, nameEnzyme,
originalFnType->getExtProtoInfo());
FunctionDecl* enzymeCallFD = FunctionDecl::Create(
m_Context, const_cast<DeclContext*>(m_Function->getDeclContext()),
noLoc, noLoc, nameEnzyme, enzymeFunctionType,
m_Function->getTypeSourceInfo(), SC_Extern);
enzymeCallFD->setParams(enzymeParams);

// Add Function call to block
Expr* enzymeCall = BuildCallExprToFunction(enzymeCallFD, enzymeArgs);
addToCurrentBlock(enzymeCall);
llvm::SmallVector<QualType, 16> enzymeParamsType;
for (auto i : enzymeParams)
enzymeParamsType.push_back(i->getType());

QualType QT;
if (enzymeRealParams.size()) {
// Find the EnzymeGradient datastructure
auto gradDecl = LookupTemplateDeclInCladNamespace("EnzymeGradient");

TemplateArgumentListInfo TLI{};
llvm::APSInt argValue(std::to_string(enzymeRealParams.size()));
TemplateArgument TA(m_Context, argValue, m_Context.UnsignedIntTy);
TLI.addArgument(TemplateArgumentLoc(TA, TemplateArgumentLocInfo()));

QT = InstantiateTemplate(gradDecl, TLI);
} else {
QT = m_Context.VoidTy;
}

// Prepare Function call
std::string enzymeCallName =
"__enzyme_autodiff_" + m_Function->getNameAsString();
IdentifierInfo* IIEnzyme = &m_Context.Idents.get(enzymeCallName);
DeclarationName nameEnzyme(IIEnzyme);
QualType enzymeFunctionType =
m_Sema.BuildFunctionType(QT, enzymeParamsType, noLoc, nameEnzyme,
originalFnType->getExtProtoInfo());
FunctionDecl* enzymeCallFD = FunctionDecl::Create(
m_Context, fdDeclContext, noLoc, noLoc, nameEnzyme, enzymeFunctionType,
m_Function->getTypeSourceInfo(), SC_Extern);
enzymeCallFD->setParams(enzymeParams);
Expr* enzymeCall = BuildCallExprToFunction(enzymeCallFD, enzymeArgs);

// Prepare the statements that assign the gradients to
// non array/pointer type parameters of the original function
if (enzymeRealParams.size() != 0) {
auto gradDeclStmt = BuildVarDecl(QT, "grad", enzymeCall, true);
addToCurrentBlock(BuildDeclStmt(gradDeclStmt), direction::forward);

for (int i = 0; i < enzymeRealParams.size(); i++) {
auto LHSExpr = BuildOp(UO_Deref, BuildDeclRef(enzymeRealParamsRef[i]));

auto ME = utils::BuildMemberExpr(m_Sema, getCurrentScope(),
BuildDeclRef(gradDeclStmt), "d_arr");

Expr* gradIndex = dyn_cast<Expr>(
IntegerLiteral::Create(m_Context, llvm::APSInt(std::to_string(i)),
m_Context.UnsignedIntTy, noLoc));
Expr* RHSExpr =
m_Sema.CreateBuiltinArraySubscriptExpr(ME, noLoc, gradIndex, noLoc)
.get();

auto assignExpr = BuildOp(BO_Assign, LHSExpr, RHSExpr);
addToCurrentBlock(assignExpr, direction::forward);
}
} else {
// Add Function call to block
Expr* enzymeCall = BuildCallExprToFunction(enzymeCallFD, enzymeArgs);
addToCurrentBlock(enzymeCall);
}
}
StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) {
Expand Down
40 changes: 23 additions & 17 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,8 @@ namespace clad {
return Result;
}

TemplateDecl* VisitorBase::GetCladClassDecl(llvm::StringRef ClassName) {
TemplateDecl*
VisitorBase::LookupTemplateDeclInCladNamespace(llvm::StringRef ClassName) {
NamespaceDecl* CladNS = GetCladNamespace();
CXXScopeSpec CSS;
CSS.Extend(m_Context, CladNS, noLoc, noLoc);
Expand All @@ -365,16 +366,8 @@ namespace clad {
return cast<TemplateDecl>(TapeR.getFoundDecl());
}

QualType
VisitorBase::GetCladClassOfType(TemplateDecl* CladClassDecl,
ArrayRef<QualType> TemplateArgs) {
// Create a list of template arguments.
TemplateArgumentListInfo TLI{};
for (auto T : TemplateArgs) {
TemplateArgument TA = T;
TLI.addArgument(
TemplateArgumentLoc(TA, m_Context.getTrivialTypeSourceInfo(T)));
}
QualType VisitorBase::InstantiateTemplate(TemplateDecl* CladClassDecl,
TemplateArgumentListInfo& TLI) {
// This will instantiate tape<T> type and return it.
QualType TT =
m_Sema.CheckTemplateIdType(TemplateName(CladClassDecl), noLoc, TLI);
Expand All @@ -388,10 +381,23 @@ namespace clad {
return m_Context.getElaboratedType(ETK_None, NS, TT);
}

QualType VisitorBase::InstantiateTemplate(TemplateDecl* CladClassDecl,
ArrayRef<QualType> TemplateArgs) {
// Create a list of template arguments.
TemplateArgumentListInfo TLI{};
for (auto T : TemplateArgs) {
TemplateArgument TA = T;
TLI.addArgument(
TemplateArgumentLoc(TA, m_Context.getTrivialTypeSourceInfo(T)));
}

return VisitorBase::InstantiateTemplate(CladClassDecl, TLI);
}

TemplateDecl* VisitorBase::GetCladTapeDecl() {
static TemplateDecl* Result = nullptr;
if (!Result)
Result = GetCladClassDecl(/*ClassName=*/"tape");
Result = LookupTemplateDeclInCladNamespace(/*ClassName=*/"tape");
return Result;
}

Expand Down Expand Up @@ -441,7 +447,7 @@ namespace clad {
}

QualType VisitorBase::GetCladTapeOfType(QualType T) {
return GetCladClassOfType(GetCladTapeDecl(), {T});
return InstantiateTemplate(GetCladTapeDecl(), {T});
}

Expr* VisitorBase::BuildCallExprToMemFn(Expr* Base,
Expand Down Expand Up @@ -555,23 +561,23 @@ namespace clad {
TemplateDecl* VisitorBase::GetCladArrayRefDecl() {
static TemplateDecl* Result = nullptr;
if (!Result)
Result = GetCladClassDecl(/*ClassName=*/"array_ref");
Result = LookupTemplateDeclInCladNamespace(/*ClassName=*/"array_ref");
return Result;
}

QualType VisitorBase::GetCladArrayRefOfType(clang::QualType T) {
return GetCladClassOfType(GetCladArrayRefDecl(), {T});
return InstantiateTemplate(GetCladArrayRefDecl(), {T});
}

TemplateDecl* VisitorBase::GetCladArrayDecl() {
static TemplateDecl* Result = nullptr;
if (!Result)
Result = GetCladClassDecl(/*ClassName=*/"array");
Result = LookupTemplateDeclInCladNamespace(/*ClassName=*/"array");
return Result;
}

QualType VisitorBase::GetCladArrayOfType(clang::QualType T) {
return GetCladClassOfType(GetCladArrayDecl(), {T});
return InstantiateTemplate(GetCladArrayDecl(), {T});
}

Expr* VisitorBase::BuildArrayRefSizeExpr(Expr* Base) {
Expand Down
Loading

0 comments on commit 09e2ed0

Please sign in to comment.