Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support of cuda kernels as pullback functions #1114

Merged
merged 11 commits into from
Oct 29, 2024
49 changes: 49 additions & 0 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,55 @@ ValueAndPushforward<int, int> cudaDeviceSynchronize_pushforward()
__attribute__((host)) {
return {cudaDeviceSynchronize(), 0};
}

template <typename T>
__global__ void atomicAdd_kernel(T* destPtr, T* srcPtr, size_t N) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x)
atomicAdd(&destPtr[i], srcPtr[i]);
}

template <typename T>
void cudaMemcpy_pullback(T* destPtr, T* srcPtr, size_t count,
cudaMemcpyKind kind, T* d_destPtr, T* d_srcPtr,
size_t* d_count, cudaMemcpyKind* d_kind)
__attribute__((host)) {
T* aux_destPtr = nullptr;
if (kind == cudaMemcpyDeviceToHost) {
*d_kind = cudaMemcpyHostToDevice;
cudaMalloc(&aux_destPtr, count);
} else if (kind == cudaMemcpyHostToDevice) {
*d_kind = cudaMemcpyDeviceToHost;
aux_destPtr = (T*)malloc(count);
}
cudaDeviceSynchronize(); // needed in case user uses another stream for
// kernel execution besides the default one
cudaMemcpy(aux_destPtr, d_destPtr, count, *d_kind);
size_t N = count / sizeof(T);
if (kind == cudaMemcpyDeviceToHost) {
// d_kind is host to device, so d_srcPtr is a device pointer
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, 0);
size_t maxThreads = deviceProp.maxThreadsPerBlock;
size_t maxBlocks = deviceProp.maxGridSize[0];

size_t numThreads = std::min(maxThreads, N);
size_t numBlocks = std::min(maxBlocks, (N + numThreads - 1) / numThreads);
custom_derivatives::atomicAdd_kernel<<<numBlocks, numThreads>>>(
d_srcPtr, aux_destPtr, N);
cudaDeviceSynchronize(); // needed in case the user uses another stream for
// kernel execution besides the default one, so we
// need to make sure the data are updated before
// continuing with the rest of the code
cudaFree(aux_destPtr);
} else if (kind == cudaMemcpyHostToDevice) {
// d_kind is device to host, so d_srcPtr is a host pointer
for (size_t i = 0; i < N; i++)
d_srcPtr[i] += aux_destPtr[i];
free(aux_destPtr);
}
}

#endif

CUDA_HOST_DEVICE inline ValueAndPushforward<float, float>
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ namespace clad {
clang::Expr* BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv = true, bool namespaceShouldExist = true);
bool forCustomDerv = true, bool namespaceShouldExist = true,
clang::Expr* CUDAExecConfig = nullptr);
bool noOverloadExists(clang::Expr* UnresolvedLookup,
llvm::MutableArrayRef<clang::Expr*> ARargs);
/// Shorthand to issues a warning or error.
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ namespace clad {
clang::Expr* dfdx, llvm::SmallVectorImpl<clang::Stmt*>& PreCallStmts,
llvm::SmallVectorImpl<clang::Stmt*>& PostCallStmts,
llvm::SmallVectorImpl<clang::Expr*>& args,
llvm::SmallVectorImpl<clang::Expr*>& outputArgs);
llvm::SmallVectorImpl<clang::Expr*>& outputArgs,
clang::Expr* CUDAExecConfig = nullptr);

public:
ReverseModeVisitor(DerivativeBuilder& builder, const DiffRequest& request);
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,8 @@ namespace clad {
/// \returns The derivative function call.
clang::Expr* GetSingleArgCentralDiffCall(
clang::Expr* targetFuncCall, clang::Expr* targetArg, unsigned targetPos,
unsigned numArgs, llvm::SmallVectorImpl<clang::Expr*>& args);
unsigned numArgs, llvm::SmallVectorImpl<clang::Expr*>& args,
clang::Expr* CUDAExecConfig = nullptr);

/// Emits diagnostic messages on differentiation (or lack thereof) for
/// call expressions.
Expand Down
5 changes: 4 additions & 1 deletion lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,8 @@ namespace clad {
}

bool IsMemoryFunction(const clang::FunctionDecl* FD) {

if (FD->getNameAsString() == "cudaMalloc")
return true;
#if CLANG_VERSION_MAJOR > 12
if (FD->getBuiltinID() == Builtin::BImalloc)
return true;
Expand All @@ -703,6 +704,8 @@ namespace clad {
}

bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD) {
if (FD->getNameAsString() == "cudaFree")
return true;
#if CLANG_VERSION_MAJOR > 12
return FD->getBuiltinID() == Builtin::ID::BIfree;
#else
Expand Down
9 changes: 6 additions & 3 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) {
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/,
Expr* CUDAExecConfig /*=nullptr*/) {
CXXScopeSpec SS;
LookupResult R = LookupCustomDerivativeOrNumericalDiff(
Name, originalFnDC, SS, forCustomDerv, namespaceShouldExist);
Expand All @@ -265,8 +266,10 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
if (noOverloadExists(UnresolvedLookup, MARargs))
return nullptr;

OverloadedFn =
m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get();
OverloadedFn = m_Sema
.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc,
CUDAExecConfig)
.get();

// Add the custom derivative to the set of derivatives.
// This is required in case the definition of the custom derivative
Expand Down
60 changes: 39 additions & 21 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnParams(params);

// if the function is a global kernel, all its parameters reside in the
// global memory of the GPU
// if the function is a global kernel, all the adjoint parameters reside in
// the global memory of the GPU. To facilitate the process, all the params
// of the kernel are added to the set.
if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
for (auto* param : params)
m_CUDAGlobalArgs.emplace(param);
Expand Down Expand Up @@ -631,7 +632,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (!m_DiffReq.CUDAGlobalArgsIndexes.empty())
for (auto index : m_DiffReq.CUDAGlobalArgsIndexes)
m_CUDAGlobalArgs.emplace(m_Derivative->getParamDecl(index));

// if the function is a global kernel, all the adjoint parameters reside in
// the global memory of the GPU. To facilitate the process, all the params
// of the kernel are added to the set.
else if (m_DiffReq->hasAttr<clang::CUDAGlobalAttr>())
for (auto* param : params)
m_CUDAGlobalArgs.emplace(param);
m_Derivative->setBody(nullptr);

if (!m_DiffReq.DeclarationOnly) {
Expand Down Expand Up @@ -1667,6 +1673,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(Clone(CE));
}

Expr* CUDAExecConfig = nullptr;
if (const auto* KCE = dyn_cast<CUDAKernelCallExpr>(CE))
CUDAExecConfig = Clone(KCE->getConfig());

// If the function is non_differentiable, return zero derivative.
if (clad::utils::hasNonDifferentiableAttribute(CE)) {
// Calling the function without computing derivatives
Expand All @@ -1675,10 +1685,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ClonedArgs.push_back(Clone(CE->getArg(i)));

SourceLocation validLoc = clad::utils::GetValidSLoc(m_Sema);
Expr* Call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
validLoc, ClonedArgs, validLoc)
.get();
Expr* Call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
validLoc, ClonedArgs, validLoc, CUDAExecConfig)
.get();
// Creating a zero derivative
auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
/*val=*/0);
Expand Down Expand Up @@ -1825,7 +1836,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(CallArgs), Loc)
llvm::MutableArrayRef<Expr*>(CallArgs), Loc,
CUDAExecConfig)
.get();
return call;
}
Expand Down Expand Up @@ -1940,7 +1952,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, pushforwardCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
const_cast<DeclContext*>(FD->getDeclContext()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: do not use const_cast [cppcoreguidelines-pro-type-const-cast]

              const_cast<DeclContext*>(FD->getDeclContext()),
              ^

/*forCustomDerv=*/true, /*namespaceShouldExist=*/true,
CUDAExecConfig);
if (OverloadedDerivedFn)
asGrad = false;
}
Expand Down Expand Up @@ -2041,7 +2055,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
const_cast<DeclContext*>(FD->getDeclContext()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: do not use const_cast [cppcoreguidelines-pro-type-const-cast]

              const_cast<DeclContext*>(FD->getDeclContext()),
              ^

/*forCustomDerv=*/true, /*namespaceShouldExist=*/true,
CUDAExecConfig);
if (baseDiff.getExpr())
pullbackCallArgs.erase(pullbackCallArgs.begin());
}
Expand All @@ -2057,10 +2073,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative)
.get();

OverloadedDerivedFn = m_Sema
.ActOnCallExpr(getCurrentScope(), selfRef,
Loc, pullbackCallArgs, Loc)
.get();
OverloadedDerivedFn =
m_Sema
.ActOnCallExpr(getCurrentScope(), selfRef, Loc,
pullbackCallArgs, Loc, CUDAExecConfig)
.get();
} else {
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingCallExpr(
Expand Down Expand Up @@ -2112,14 +2129,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
OverloadedDerivedFn = GetSingleArgCentralDiffCall(
Clone(CE->getCallee()), DerivedCallArgs[0],
/*targetPos=*/0,
/*numArgs=*/1, DerivedCallArgs);
/*numArgs=*/1, DerivedCallArgs, CUDAExecConfig);
asGrad = !OverloadedDerivedFn;
} else {
auto CEType = getNonConstType(CE->getType(), m_Context, m_Sema);
OverloadedDerivedFn = GetMultiArgCentralDiffCall(
Clone(CE->getCallee()), CEType.getCanonicalType(),
CE->getNumArgs(), dfdx(), PreCallStmts, PostCallStmts,
DerivedCallArgs, CallArgDx);
DerivedCallArgs, CallArgDx, CUDAExecConfig);
}
CallExprDiffDiagnostics(FD, CE->getBeginLoc());
if (!OverloadedDerivedFn) {
Expand All @@ -2137,7 +2154,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
OverloadedDerivedFn =
m_Sema
.ActOnCallExpr(getCurrentScope(), BuildDeclRef(pullbackFD),
Loc, pullbackCallArgs, Loc)
Loc, pullbackCallArgs, Loc, CUDAExecConfig)
.get();
}
}
Expand Down Expand Up @@ -2250,7 +2267,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
call = m_Sema
.ActOnCallExpr(getCurrentScope(),
BuildDeclRef(calleeFnForwPassFD), Loc,
CallArgs, Loc)
CallArgs, Loc, CUDAExecConfig)
.get();
}
auto* callRes = StoreAndRef(call);
Expand Down Expand Up @@ -2285,7 +2302,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
CallArgs, Loc)
CallArgs, Loc, CUDAExecConfig)
.get();
return StmtDiff(call);
}
Expand All @@ -2295,7 +2312,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
llvm::SmallVectorImpl<Stmt*>& PreCallStmts,
llvm::SmallVectorImpl<Stmt*>& PostCallStmts,
llvm::SmallVectorImpl<Expr*>& args,
llvm::SmallVectorImpl<Expr*>& outputArgs) {
llvm::SmallVectorImpl<Expr*>& outputArgs,
Expr* CUDAExecConfig /*=nullptr*/) {
int printErrorInf = m_Builder.shouldPrintNumDiffErrs();
llvm::SmallVector<Expr*, 16U> NumDiffArgs = {};
NumDiffArgs.push_back(targetFuncCall);
Expand Down Expand Up @@ -2336,7 +2354,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Name, NumDiffArgs, getCurrentScope(),
/*OriginalFnDC=*/nullptr,
/*forCustomDerv=*/false,
/*namespaceShouldExist=*/false);
/*namespaceShouldExist=*/false, CUDAExecConfig);
}

StmtDiff ReverseModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
Expand Down
5 changes: 3 additions & 2 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,8 @@ namespace clad {

Expr* VisitorBase::GetSingleArgCentralDiffCall(
Expr* targetFuncCall, Expr* targetArg, unsigned targetPos,
unsigned numArgs, llvm::SmallVectorImpl<Expr*>& args) {
unsigned numArgs, llvm::SmallVectorImpl<Expr*>& args,
Expr* CUDAExecConfig /*=nullptr*/) {
QualType argType = targetArg->getType();
int printErrorInf = m_Builder.shouldPrintNumDiffErrs();
bool isSupported = argType->isArithmeticType();
Expand All @@ -788,7 +789,7 @@ namespace clad {
Name, NumDiffArgs, getCurrentScope(),
/*OriginalFnDC=*/nullptr,
/*forCustomDerv=*/false,
/*namespaceShouldExist=*/false);
/*namespaceShouldExist=*/false, CUDAExecConfig);
}

void VisitorBase::CallExprDiffDiagnostics(const clang::FunctionDecl* FD,
Expand Down
Loading
Loading