Skip to content

Commit

Permalink
Add support of CUDA device pullbacks (#1111)
Browse files Browse the repository at this point in the history
For this purpose, a deeper look into atomic ops had to be taken. Atomic ops can only be applied on global or shared GPU memory.
Hence, we needed to identify which call args of the device function pullback are actually kernel args and, thus, global.
The indexes of those args are stored in a vector in the differentiation request for the internal device function and appended to the name of the pullback function.
Later on, when deriving the encountered device function, the global call args are matched with the function's params based on their stored indexes.
This way, the atomic ops are minimized to the absolute necessary number and no error arises.
  • Loading branch information
kchristin22 authored Oct 15, 2024
1 parent f86eede commit fc6d311
Show file tree
Hide file tree
Showing 6 changed files with 361 additions and 28 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ struct DiffRequest {
clang::CallExpr* CallContext = nullptr;
/// Args provided to the call to clad::gradient/differentiate.
const clang::Expr* Args = nullptr;
/// Indexes of global GPU args of function as a subset of Args.
std::vector<size_t> CUDAGlobalArgsIndexes;
/// Requested differentiation mode, forward or reverse.
DiffMode Mode = DiffMode::unknown;
/// If function appears in the call to clad::gradient/differentiate,
Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/ExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class ExternalRMVSource {
/// This is called just before finalising `VisitReturnStmt`.
virtual void ActBeforeFinalizingVisitReturnStmt(StmtDiff& retExprDiff) {}

/// This ic called just before finalising `VisitCallExpr`.
/// This is called just before finalising `VisitCallExpr`.
///
/// \param CE call expression that is being visited.
/// \param CallArgs
Expand Down
9 changes: 8 additions & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ namespace clad {
// several private/protected members of the visitor classes.
friend class ErrorEstimationHandler;
llvm::SmallVector<const clang::ValueDecl*, 16> m_IndependentVars;
/// Set used to keep track of parameter variables w.r.t which the
/// the derivative (gradient) is being computed. This is separate from the
/// m_Variables map because all other intermediate variables will
/// not be stored here.
std::unordered_set<const clang::ValueDecl*> m_ParamVarsWithDiff;
/// In addition to a sequence of forward-accumulated Stmts (m_Blocks), in
/// the reverse mode we also accumulate Stmts for the reverse pass which
/// will be executed on return.
Expand All @@ -51,6 +56,8 @@ namespace clad {
/// that will be put immediately in the beginning of derivative function
/// block.
Stmts m_Globals;
/// Global GPU args of the function.
std::unordered_set<const clang::ParmVarDecl*> m_CUDAGlobalArgs;
//// A reference to the output parameter of the gradient function.
clang::Expr* m_Result;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
Expand Down Expand Up @@ -432,7 +439,7 @@ namespace clad {

/// Helper function that checks whether the function to be derived
/// is meant to be executed only by the GPU
bool shouldUseCudaAtomicOps();
bool shouldUseCudaAtomicOps(const clang::Expr* E);

/// Add call to cuda::atomicAdd for the given LHS and RHS expressions.
///
Expand Down
106 changes: 85 additions & 21 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return CladTapeResult{*this, PushExpr, PopExpr, TapeRef};
}

bool ReverseModeVisitor::shouldUseCudaAtomicOps() {
return m_DiffReq->hasAttr<clang::CUDAGlobalAttr>() ||
(m_DiffReq->hasAttr<clang::CUDADeviceAttr>() &&
!m_DiffReq->hasAttr<clang::CUDAHostAttr>());
bool ReverseModeVisitor::shouldUseCudaAtomicOps(const Expr* E) {
// Same as checking whether this is a function executed by the GPU
if (!m_CUDAGlobalArgs.empty())
if (const auto* DRE = dyn_cast<DeclRefExpr>(E))
if (const auto* PVD = dyn_cast<ParmVarDecl>(DRE->getDecl()))
// Check whether this param is in the global memory of the GPU
return m_CUDAGlobalArgs.find(PVD) != m_CUDAGlobalArgs.end();

return false;
}

clang::Expr* ReverseModeVisitor::BuildCallToCudaAtomicAdd(clang::Expr* LHS,
Expand All @@ -123,8 +128,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Sema.BuildDeclarationNameExpr(SS, lookupResult, /*ADL=*/true).get();

Expr* finalLHS = LHS;
if (isa<ArraySubscriptExpr>(LHS))
if (auto* UO = dyn_cast<UnaryOperator>(LHS)) {
if (UO->getOpcode() == UnaryOperatorKind::UO_Deref)
finalLHS = UO->getSubExpr()->IgnoreImplicit();
} else if (!LHS->getType()->isPointerType() &&
!LHS->getType()->isReferenceType())
finalLHS = BuildOp(UnaryOperatorKind::UO_AddrOf, LHS);

llvm::SmallVector<Expr*, 2> atomicArgs = {finalLHS, RHS};

assert(!m_Builder.noOverloadExists(UnresolvedLookup, atomicArgs) &&
Expand Down Expand Up @@ -440,6 +450,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnParams(params);

// if the function is a global kernel, all its parameters reside in the
// global memory of the GPU
if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
for (auto* param : params)
m_CUDAGlobalArgs.emplace(param);

llvm::ArrayRef<ParmVarDecl*> paramsRef =
clad_compat::makeArrayRef(params.data(), params.size());
gradientFD->setParams(paramsRef);
Expand Down Expand Up @@ -546,6 +562,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

auto derivativeName =
utils::ComputeEffectiveFnName(m_DiffReq.Function) + "_pullback";
for (auto index : m_DiffReq.CUDAGlobalArgsIndexes)
derivativeName += "_" + std::to_string(index);
auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName);

auto paramTypes = ComputeParamTypes(args);
Expand Down Expand Up @@ -587,6 +605,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_ExternalSource->ActAfterCreatingDerivedFnParams(params);

m_Derivative->setParams(params);
// Match the global arguments of the call to the device function to the
// pullback function's parameters.
if (!m_DiffReq.CUDAGlobalArgsIndexes.empty())
for (auto index : m_DiffReq.CUDAGlobalArgsIndexes)
m_CUDAGlobalArgs.emplace(m_Derivative->getParamDecl(index));

m_Derivative->setBody(nullptr);

if (!m_DiffReq.DeclarationOnly) {
Expand Down Expand Up @@ -1519,7 +1543,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
BuildArraySubscript(target, forwSweepDerivativeIndices);
// Create the (target += dfdx) statement.
if (dfdx()) {
if (shouldUseCudaAtomicOps()) {
if (shouldUseCudaAtomicOps(target)) {
Expr* atomicCall = BuildCallToCudaAtomicAdd(result, dfdx());
// Add it to the body statements.
addToCurrentBlock(atomicCall, direction::reverse);
Expand Down Expand Up @@ -1583,9 +1607,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// FIXME: not sure if this is generic.
// Don't update derivatives of record types.
if (!VD->getType()->isRecordType()) {
auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx());
// Add it to the body statements.
addToCurrentBlock(add_assign, direction::reverse);
Expr* base = it->second;
if (auto* UO = dyn_cast<UnaryOperator>(it->second))
base = UO->getSubExpr()->IgnoreImpCasts();
if (shouldUseCudaAtomicOps(base)) {
Expr* atomicCall = BuildCallToCudaAtomicAdd(it->second, dfdx());
// Add it to the body statements.
addToCurrentBlock(atomicCall, direction::reverse);
} else {
auto* add_assign = BuildOp(BO_AddAssign, it->second, dfdx());
addToCurrentBlock(add_assign, direction::reverse);
}
}
}
return StmtDiff(clonedDRE, it->second, it->second);
Expand Down Expand Up @@ -1728,20 +1760,31 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
for (const Expr* Arg : CE->arguments()) {
StmtDiff ArgDiff = Visit(Arg, dfdx());
CallArgs.push_back(ArgDiff.getExpr());
DerivedCallArgs.push_back(ArgDiff.getExpr_dx());
if (auto* DRE = dyn_cast<DeclRefExpr>(ArgDiff.getExpr())) {
// If the arg is used for differentiation of the function, then we
// cannot free it in the end as it's the result to be returned to the
// user.
if (m_ParamVarsWithDiff.find(DRE->getDecl()) ==
m_ParamVarsWithDiff.end())
DerivedCallArgs.push_back(ArgDiff.getExpr_dx());
}
}
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(CallArgs), Loc)
.get();
Expr* call_dx =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs), Loc)
.get();
m_DeallocExprs.push_back(call);
m_DeallocExprs.push_back(call_dx);

if (!DerivedCallArgs.empty()) {
Expr* call_dx =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs),
Loc)
.get();
m_DeallocExprs.push_back(call_dx);
}
return StmtDiff();
}

Expand Down Expand Up @@ -1887,6 +1930,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If it has more args or f_darg0 was not found, we look for its pullback
// function.
const auto* MD = dyn_cast<CXXMethodDecl>(FD);
std::vector<size_t> globalCallArgs;
if (!OverloadedDerivedFn) {
size_t idx = 0;

Expand Down Expand Up @@ -1952,12 +1996,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullback);

// Try to find it in builtin derivatives
std::string customPullback =
clad::utils::ComputeEffectiveFnName(FD) + "_pullback";
// Add the indexes of the global args to the custom pullback name
if (!m_CUDAGlobalArgs.empty())
for (size_t i = 0; i < pullbackCallArgs.size(); i++)
if (auto* DRE = dyn_cast<DeclRefExpr>(pullbackCallArgs[i]))
if (auto* param = dyn_cast<ParmVarDecl>(DRE->getDecl()))
if (m_CUDAGlobalArgs.find(param) != m_CUDAGlobalArgs.end()) {
customPullback += "_" + std::to_string(i);
globalCallArgs.emplace_back(i);
}

if (baseDiff.getExpr())
pullbackCallArgs.insert(
pullbackCallArgs.begin(),
BuildOp(UnaryOperatorKind::UO_AddrOf, baseDiff.getExpr()));
std::string customPullback =
clad::utils::ComputeEffectiveFnName(FD) + "_pullback";

OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
Expand Down Expand Up @@ -1990,6 +2045,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// derive the called function.
DiffRequest pullbackRequest{};
pullbackRequest.Function = FD;

// Mark the indexes of the global args. Necessary if the argument of the
// call has a different name than the function's signature parameter.
pullbackRequest.CUDAGlobalArgsIndexes = globalCallArgs;

pullbackRequest.BaseFunctionName =
clad::utils::ComputeEffectiveFnName(FD);
pullbackRequest.Mode = DiffMode::experimental_pullback;
Expand Down Expand Up @@ -2237,12 +2297,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, i);
Expr* gradElem = BuildArraySubscript(gradRef, {idx});
Expr* gradExpr = BuildOp(BO_Mul, dfdx, gradElem);
// Inputs were not pointers, so the output args are not in global GPU
// memory. Hence, no need to use atomic ops.
PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr));
NumDiffArgs.push_back(args[i]);
}
std::string Name = "central_difference";
return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr,
Name, NumDiffArgs, getCurrentScope(),
/*OriginalFnDC=*/nullptr,
/*forCustomDerv=*/false,
/*namespaceShouldExist=*/false);
}
Expand Down Expand Up @@ -2343,8 +2406,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
else {
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, diff_dx);
// Create the (target += dfdx) statement.
if (dfdx()) {
if (shouldUseCudaAtomicOps()) {
if (dfdx() && derivedE) {
if (shouldUseCudaAtomicOps(diff_dx)) {
Expr* atomicCall = BuildCallToCudaAtomicAdd(diff_dx, dfdx());
// Add it to the body statements.
addToCurrentBlock(atomicCall, direction::reverse);
Expand Down Expand Up @@ -4556,6 +4619,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Variables[*it] =
utils::BuildParenExpr(m_Sema, m_Variables[*it]);
}
m_ParamVarsWithDiff.emplace(*it);
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,8 @@ namespace clad {
// Return the found overload.
std::string Name = "forward_central_difference";
return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr,
Name, NumDiffArgs, getCurrentScope(),
/*OriginalFnDC=*/nullptr,
/*forCustomDerv=*/false,
/*namespaceShouldExist=*/false);
}
Expand Down
Loading

0 comments on commit fc6d311

Please sign in to comment.