Skip to content

Commit

Permalink
Add support for device functions as pullback functions
Browse files Browse the repository at this point in the history
For this purpose, a deeper look into atomic ops had to be taken. Atomic ops can only be applied on global or shared GPU memory.
Hence, we needed to identify which call args of the device function pullback are actually kernel args and, thus, global.
The indexes of those args are stored in a vector in the differentiation request for the internal device function and appended to the name of the pullback function.
Later on, when deriving the encountered device function, the global call args are matched with the function's params based on their stored indexes.
This way, the number of atomic ops is kept to the minimum necessary, and no error arises.
  • Loading branch information
kchristin22 committed Oct 13, 2024
1 parent f86eede commit 871d538
Show file tree
Hide file tree
Showing 6 changed files with 368 additions and 28 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ struct DiffRequest {
clang::CallExpr* CallContext = nullptr;
/// Args provided to the call to clad::gradient/differentiate.
const clang::Expr* Args = nullptr;
/// Indexes of the global GPU args of the function, as a subset of Args.
std::vector<size_t> GlobalArgsIndexes;
/// Requested differentiation mode, forward or reverse.
DiffMode Mode = DiffMode::unknown;
/// If function appears in the call to clad::gradient/differentiate,
Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/ExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class ExternalRMVSource {
/// This is called just before finalising `VisitReturnStmt`.
virtual void ActBeforeFinalizingVisitReturnStmt(StmtDiff& retExprDiff) {}

/// This ic called just before finalising `VisitCallExpr`.
/// This is called just before finalising `VisitCallExpr`.
///
/// \param CE call expression that is being visited.
/// \param CallArgs
Expand Down
9 changes: 8 additions & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ namespace clad {
// several private/protected members of the visitor classes.
friend class ErrorEstimationHandler;
llvm::SmallVector<const clang::ValueDecl*, 16> m_IndependentVars;
/// Set used to keep track of parameter variables w.r.t. which the
/// derivative (gradient) is being computed. This is separate from the
/// m_Variables map because all other intermediate variables will
/// not be stored here.
std::unordered_set<const clang::ValueDecl*> m_ParamVarsWithDiff;
/// In addition to a sequence of forward-accumulated Stmts (m_Blocks), in
/// the reverse mode we also accumulate Stmts for the reverse pass which
/// will be executed on return.
Expand All @@ -51,6 +56,8 @@ namespace clad {
/// that will be put immediately in the beginning of derivative function
/// block.
Stmts m_Globals;
/// Global GPU args of the function.
std::unordered_set<const clang::ParmVarDecl*> m_GlobalArgs;
//// A reference to the output parameter of the gradient function.
clang::Expr* m_Result;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
Expand Down Expand Up @@ -432,7 +439,7 @@ namespace clad {

/// Helper function that checks whether the function to be derived
/// is meant to be executed only by the GPU
bool shouldUseCudaAtomicOps();
bool shouldUseCudaAtomicOps(const clang::Expr* E);

/// Add call to cuda::atomicAdd for the given LHS and RHS expressions.
///
Expand Down
113 changes: 92 additions & 21 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return CladTapeResult{*this, PushExpr, PopExpr, TapeRef};
}

bool ReverseModeVisitor::shouldUseCudaAtomicOps() {
return m_DiffReq->hasAttr<clang::CUDAGlobalAttr>() ||
(m_DiffReq->hasAttr<clang::CUDADeviceAttr>() &&
!m_DiffReq->hasAttr<clang::CUDAHostAttr>());
bool ReverseModeVisitor::shouldUseCudaAtomicOps(const Expr* E) {
  // Atomic ops are required only when the expression refers to a function
  // parameter that resides in global GPU memory (m_GlobalArgs is populated
  // only for functions executed by the GPU, so a non-empty set doubles as
  // that check).
  if (m_GlobalArgs.empty())
    return false;

  const auto* DRE = dyn_cast<DeclRefExpr>(E);
  if (!DRE)
    return false;

  const auto* PVD = dyn_cast<ParmVarDecl>(DRE->getDecl());
  if (!PVD)
    return false;

  // The param must be registered as living in global GPU memory.
  return m_GlobalArgs.find(PVD) != m_GlobalArgs.end();
}

clang::Expr* ReverseModeVisitor::BuildCallToCudaAtomicAdd(clang::Expr* LHS,
Expand All @@ -123,8 +129,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Sema.BuildDeclarationNameExpr(SS, lookupResult, /*ADL=*/true).get();

Expr* finalLHS = LHS;
if (isa<ArraySubscriptExpr>(LHS))
if (auto* UO = dyn_cast<UnaryOperator>(LHS)) {
if (UO->getOpcode() == UnaryOperatorKind::UO_Deref)
finalLHS = UO->getSubExpr()->IgnoreImplicit();
} else if (!LHS->getType()->isPointerType() &&
!LHS->getType()->isReferenceType())
finalLHS = BuildOp(UnaryOperatorKind::UO_AddrOf, LHS);

llvm::SmallVector<Expr*, 2> atomicArgs = {finalLHS, RHS};

assert(!m_Builder.noOverloadExists(UnresolvedLookup, atomicArgs) &&
Expand Down Expand Up @@ -440,6 +451,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnParams(params);

// if the function is a global kernel, all its parameters reside in the
// global memory of the GPU
if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
for (auto param : params)
m_GlobalArgs.emplace(param);

llvm::ArrayRef<ParmVarDecl*> paramsRef =
clad_compat::makeArrayRef(params.data(), params.size());
gradientFD->setParams(paramsRef);
Expand Down Expand Up @@ -546,6 +563,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

auto derivativeName =
utils::ComputeEffectiveFnName(m_DiffReq.Function) + "_pullback";
for (auto index : m_DiffReq.GlobalArgsIndexes)
derivativeName += "_" + std::to_string(index);
auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName);

auto paramTypes = ComputeParamTypes(args);
Expand Down Expand Up @@ -587,6 +606,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_ExternalSource->ActAfterCreatingDerivedFnParams(params);

m_Derivative->setParams(params);
// Match the global arguments of the call to the device function to the
// pullback function's parameters.
if (!m_DiffReq.GlobalArgsIndexes.empty())
for (auto index : m_DiffReq.GlobalArgsIndexes)
m_GlobalArgs.emplace(m_Derivative->getParamDecl(index));
// If the function is a global kernel, all its parameters reside in the
// global memory of the GPU
else if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
for (auto param : params)
m_GlobalArgs.emplace(param);

Check warning on line 618 in lib/Differentiator/ReverseModeVisitor.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/ReverseModeVisitor.cpp#L617-L618

Added lines #L617 - L618 were not covered by tests
m_Derivative->setBody(nullptr);

if (!m_DiffReq.DeclarationOnly) {
Expand Down Expand Up @@ -1519,7 +1548,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
BuildArraySubscript(target, forwSweepDerivativeIndices);
// Create the (target += dfdx) statement.
if (dfdx()) {
if (shouldUseCudaAtomicOps()) {
if (shouldUseCudaAtomicOps(target)) {
Expr* atomicCall = BuildCallToCudaAtomicAdd(result, dfdx());
// Add it to the body statements.
addToCurrentBlock(atomicCall, direction::reverse);
Expand Down Expand Up @@ -1583,9 +1612,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// FIXME: not sure if this is generic.
// Don't update derivatives of record types.
if (!VD->getType()->isRecordType()) {
auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx());
// Add it to the body statements.
addToCurrentBlock(add_assign, direction::reverse);
Expr* base = it->second;
if (auto* UO = dyn_cast<UnaryOperator>(it->second))
base = UO->getSubExpr()->IgnoreImpCasts();
if (shouldUseCudaAtomicOps(base)) {
Expr* atomicCall = BuildCallToCudaAtomicAdd(it->second, dfdx());
// Add it to the body statements.
addToCurrentBlock(atomicCall, direction::reverse);
} else {
auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx());
addToCurrentBlock(add_assign, direction::reverse);
}
}
}
return StmtDiff(clonedDRE, it->second, it->second);
Expand Down Expand Up @@ -1728,20 +1765,31 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
for (const Expr* Arg : CE->arguments()) {
StmtDiff ArgDiff = Visit(Arg, dfdx());
CallArgs.push_back(ArgDiff.getExpr());
DerivedCallArgs.push_back(ArgDiff.getExpr_dx());
if (auto* DRE = dyn_cast<DeclRefExpr>(ArgDiff.getExpr())) {
// If the arg is used for differentiation of the function, then we
// cannot free it in the end as it's the result to be returned to the
// user.
if (m_ParamVarsWithDiff.find(DRE->getDecl()) ==
m_ParamVarsWithDiff.end())
DerivedCallArgs.push_back(ArgDiff.getExpr_dx());
}
}
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(CallArgs), Loc)
.get();
Expr* call_dx =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs), Loc)
.get();
m_DeallocExprs.push_back(call);
m_DeallocExprs.push_back(call_dx);

if (!DerivedCallArgs.empty()) {
Expr* call_dx =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs),
Loc)
.get();
m_DeallocExprs.push_back(call_dx);
}
return StmtDiff();
}

Expand Down Expand Up @@ -1887,6 +1935,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If it has more args or f_darg0 was not found, we look for its pullback
// function.
const auto* MD = dyn_cast<CXXMethodDecl>(FD);
std::vector<size_t> globalCallArgs;
if (!OverloadedDerivedFn) {
size_t idx = 0;

Expand Down Expand Up @@ -1952,12 +2001,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullback);

// Try to find it in builtin derivatives
std::string customPullback =
clad::utils::ComputeEffectiveFnName(FD) + "_pullback";
// Add the indexes of the global args to the custom pullback name
if (!m_GlobalArgs.empty())
for (size_t i = 0; i < pullbackCallArgs.size(); i++)
if (auto* DRE = dyn_cast<DeclRefExpr>(pullbackCallArgs[i]))
if (auto* param = dyn_cast<ParmVarDecl>(DRE->getDecl()))
if (m_GlobalArgs.find(param) != m_GlobalArgs.end()) {
customPullback += "_" + std::to_string(i);
globalCallArgs.emplace_back(i);
}

if (baseDiff.getExpr())
pullbackCallArgs.insert(
pullbackCallArgs.begin(),
BuildOp(UnaryOperatorKind::UO_AddrOf, baseDiff.getExpr()));
std::string customPullback =
clad::utils::ComputeEffectiveFnName(FD) + "_pullback";

OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
Expand Down Expand Up @@ -1990,6 +2050,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// derive the called function.
DiffRequest pullbackRequest{};
pullbackRequest.Function = FD;

// Mark the indexes of the global args. Necessary if the argument of the
// call has a different name than the function's signature parameter.
pullbackRequest.GlobalArgsIndexes = globalCallArgs;

pullbackRequest.BaseFunctionName =
clad::utils::ComputeEffectiveFnName(FD);
pullbackRequest.Mode = DiffMode::experimental_pullback;
Expand Down Expand Up @@ -2237,12 +2302,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, i);
Expr* gradElem = BuildArraySubscript(gradRef, {idx});
Expr* gradExpr = BuildOp(BO_Mul, dfdx, gradElem);
PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr));
if (shouldUseCudaAtomicOps(outputArgs[i]))
PostCallStmts.push_back(
BuildCallToCudaAtomicAdd(outputArgs[i], gradExpr));

Check warning on line 2307 in lib/Differentiator/ReverseModeVisitor.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/ReverseModeVisitor.cpp#L2306-L2307

Added lines #L2306 - L2307 were not covered by tests
else
PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr));
NumDiffArgs.push_back(args[i]);
}
std::string Name = "central_difference";
return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr,
Name, NumDiffArgs, getCurrentScope(),
/*OriginalFnDC=*/nullptr,
/*forCustomDerv=*/false,
/*namespaceShouldExist=*/false);
}
Expand Down Expand Up @@ -2344,7 +2414,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, diff_dx);
// Create the (target += dfdx) statement.
if (dfdx()) {
if (shouldUseCudaAtomicOps()) {
if (shouldUseCudaAtomicOps(diff_dx)) {
Expr* atomicCall = BuildCallToCudaAtomicAdd(diff_dx, dfdx());
// Add it to the body statements.
addToCurrentBlock(atomicCall, direction::reverse);
Expand Down Expand Up @@ -4556,6 +4626,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Variables[*it] =
utils::BuildParenExpr(m_Sema, m_Variables[*it]);
}
m_ParamVarsWithDiff.emplace(*it);
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,8 @@ namespace clad {
// Return the found overload.
std::string Name = "forward_central_difference";
return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr,
Name, NumDiffArgs, getCurrentScope(),
/*OriginalFnDC=*/nullptr,
/*forCustomDerv=*/false,
/*namespaceShouldExist=*/false);
}
Expand Down
Loading

0 comments on commit 871d538

Please sign in to comment.