Add support for shared memory keyword and skip overload creation for kernels
kchristin22 committed Sep 9, 2024
1 parent ae71945 commit 1faffc4
Showing 4 changed files with 49 additions and 27 deletions.
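Note on the user-facing effect: after this commit, a CUDA kernel that declares __shared__ memory can be handed to clad::gradient, and kernels no longer get a separate gradient overload. A minimal sketch of such a kernel, adapted from the test added below (the names and the "in, out" parameter list mirror that test; this is an illustration, not generated output from the commit):

#include "clad/Differentiator/Differentiator.h"

// Kernel that uses a statically sized __shared__ variable.
__global__ void add_kernel_3(int *out, int *in) {
  __shared__ int shared;   // __shared__ variables cannot carry an initializer,
  shared = 1;              // so the value is assigned in a separate statement
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  out[index] += in[index] + shared;
}

int main() {
  // Request the reverse-mode gradient with respect to both pointer parameters;
  // the result is launched with execute_kernel (see the updated test below).
  auto add_3 = clad::gradient(add_kernel_3, "in, out");
  return 0;
}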
6 changes: 4 additions & 2 deletions include/clad/Differentiator/VisitorBase.h
@@ -295,7 +295,8 @@ namespace clad {
                clang::Scope* scope, clang::Expr* Init = nullptr,
                bool DirectInit = false, clang::TypeSourceInfo* TSI = nullptr,
                clang::VarDecl::InitializationStyle IS =
-                   clang::VarDecl::InitializationStyle::CInit);
+                   clang::VarDecl::InitializationStyle::CInit,
+               clang::StorageClass SC = clang::SC_None);
  /// Builds variable declaration to be used inside the derivative
  /// body.
  /// \param[in] Type The type of variable declaration to build.
@@ -336,7 +337,8 @@ namespace clad {
                      clang::Expr* Init = nullptr, bool DirectInit = false,
                      clang::TypeSourceInfo* TSI = nullptr,
                      clang::VarDecl::InitializationStyle IS =
-                         clang::VarDecl::InitializationStyle::CInit);
+                         clang::VarDecl::InitializationStyle::CInit,
+                     clang::StorageClass SC = clang::SC_None);
  /// Creates a namespace declaration and enters its context. All subsequent
  /// Stmts are built inside that namespace, until
  /// m_Sema.PopDeclContextIsUsed.
50 changes: 33 additions & 17 deletions lib/Differentiator/ReverseModeVisitor.cpp
@@ -240,14 +240,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
     addToCurrentBlock(BuildDeclStmt(gradientVD));
   }
 
-  // If the function is a global kernel, we need to transform it
-  // into a device function when calling it inside the overload function
-  // which is the final global kernel returned.
-  if (m_Derivative->hasAttr<clang::CUDAGlobalAttr>()) {
-    m_Derivative->dropAttr<clang::CUDAGlobalAttr>();
-    m_Derivative->addAttr(clang::CUDADeviceAttr::CreateImplicit(m_Context));
-  }
-
   Expr* callExpr = BuildCallExprToFunction(m_Derivative, callArgs,
                                            /*UseRefQualifiedThisObj=*/true);
   addToCurrentBlock(callExpr);
@@ -346,7 +338,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
   bool shouldCreateOverload = false;
   // FIXME: Gradient overload doesn't know how to handle additional parameters
   // added by the plugins yet.
-  if (request.Mode != DiffMode::jacobian && numExtraParam == 0)
+  if (request.Mode != DiffMode::jacobian && numExtraParam == 0 &&
+      !FD->hasAttr<CUDAGlobalAttr>())
     shouldCreateOverload = true;
   if (!request.DeclarationOnly && !request.DerivedFDPrototypes.empty())
     // If the overload is already created, we don't need to create it again.
@@ -2884,10 +2877,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
         VDDerivedInit = getZeroInit(VDDerivedType);
       }
     }
-    if (initializeDerivedVar)
-      VDDerived = BuildGlobalVarDecl(
-          VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false,
-          nullptr, VD->getInitStyle());
+    if (initializeDerivedVar) {
+      if (VD->hasAttr<clang::CUDASharedAttr>()) {
+        VDDerived = BuildGlobalVarDecl(
+            VDDerivedType, "_d_" + VD->getNameAsString(), nullptr, false,
+            nullptr, VD->getInitStyle(), VD->getStorageClass());
+        VDDerived->addAttr(
+            VD->getAttr<clang::CUDASharedAttr>()->clone(m_Context));
+      } else {
+        VDDerived = BuildGlobalVarDecl(
+            VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit,
+            false, nullptr, VD->getInitStyle());
+      }
+    }
   }
 
   // If `VD` is a reference to a local variable, then it is already
@@ -2970,11 +2972,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
       VDClone = BuildGlobalVarDecl(
           VDCloneType, VD->getNameAsString(),
           BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getExpr()),
-          VD->isDirectInit());
+          VD->isDirectInit(), nullptr,
+          clang::VarDecl::InitializationStyle::CInit, VD->getStorageClass());
     else
-      VDClone = BuildGlobalVarDecl(VDCloneType, VD->getNameAsString(),
-                                   initDiff.getExpr(), VD->isDirectInit(),
-                                   nullptr, VD->getInitStyle());
+      VDClone =
+          BuildGlobalVarDecl(VDCloneType, VD->getNameAsString(),
+                             initDiff.getExpr(), VD->isDirectInit(), nullptr,
+                             VD->getInitStyle(), VD->getStorageClass());
+    if (VD->hasAttr<clang::CUDASharedAttr>())
+      VDClone->addAttr(VD->getAttr<clang::CUDASharedAttr>()->clone(m_Context));
     if (isPointerType && derivedVDE) {
       if (promoteToFnScope) {
         Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign,
@@ -3083,6 +3089,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
   llvm::SmallVector<Decl*, 4> declsDiff;
   // Need to put array decls inlined.
   llvm::SmallVector<Decl*, 4> localDeclsDiff;
+  llvm::SmallVector<Stmt*, 16> sharedMemInits;
   // reverse_mode_forward_pass does not have a reverse pass so declarations
   // don't have to be moved to the function global scope.
   bool promoteToFnScope =
@@ -3172,6 +3179,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
         localDeclsDiff.push_back(VDDiff.getDecl_dx());
       else
         declsDiff.push_back(VDDiff.getDecl_dx());
+      if (VD->hasAttr<clang::CUDASharedAttr>()) {
+        VarDecl* VDDerived = VDDiff.getDecl_dx();
+        Expr* declRef = BuildDeclRef(VDDerived);
+        Stmt* assignToZero = BuildOp(BinaryOperatorKind::BO_Assign, declRef,
+                                     getZeroInit(VDDerived->getType()));
+        sharedMemInits.push_back(assignToZero);
+      }
     }
   } else if (auto* SAD = dyn_cast<StaticAssertDecl>(D)) {
     DeclDiff<StaticAssertDecl> SADDiff = DifferentiateStaticAssertDecl(SAD);
@@ -3210,6 +3224,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
     Stmts& block =
         promoteToFnScope ? m_Globals : getCurrentBlock(direction::forward);
     addToBlock(DSDiff, block);
+    for (Stmt* sharedMemInitsStmt : sharedMemInits)
+      addToBlock(sharedMemInitsStmt, block);
   }
 
   if (m_ExternalSource) {
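Why the derived __shared__ variable is built without an initializer and zeroed through a separate statement collected in sharedMemInits: CUDA does not allow initializers on __shared__ variables in device code, so the adjoint has to be declared first and assigned afterwards. A rough sketch of the pattern the generated kernel body must follow (the _d_shared name follows the "_d_" prefix used elsewhere in this diff; the exact emitted code is not shown here):

__global__ void grad_body_sketch() {
  // __shared__ int _d_shared = 0;  // invalid: initializer not allowed on __shared__
  __shared__ int _d_shared;         // declare the adjoint first ...
  _d_shared = 0;                    // ... then zero it with the assignment that
                                    // sharedMemInits appends to the forward block
}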
14 changes: 8 additions & 6 deletions lib/Differentiator/VisitorBase.cpp
@@ -113,12 +113,13 @@ namespace clad {
   VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier,
                                      Scope* Scope, Expr* Init, bool DirectInit,
                                      TypeSourceInfo* TSI,
-                                     VarDecl::InitializationStyle IS) {
+                                     VarDecl::InitializationStyle IS,
+                                     StorageClass SC) {
     // add namespace specifier in variable declaration if needed.
     Type = utils::AddNamespaceSpecifier(m_Sema, m_Context, Type);
-    auto* VD = VarDecl::Create(
-        m_Context, m_Sema.CurContext, m_DiffReq->getLocation(),
-        m_DiffReq->getLocation(), Identifier, Type, TSI, SC_None);
+    auto* VD =
+        VarDecl::Create(m_Context, m_Sema.CurContext, m_DiffReq->getLocation(),
+                        m_DiffReq->getLocation(), Identifier, Type, TSI, SC);
 
     if (Init) {
       m_Sema.AddInitializerToDecl(VD, Init, DirectInit);
@@ -149,9 +150,10 @@ namespace clad {
   VarDecl* VisitorBase::BuildGlobalVarDecl(QualType Type,
                                            llvm::StringRef prefix, Expr* Init,
                                            bool DirectInit, TypeSourceInfo* TSI,
-                                           VarDecl::InitializationStyle IS) {
+                                           VarDecl::InitializationStyle IS,
+                                           StorageClass SC) {
     return BuildVarDecl(Type, CreateUniqueIdentifier(prefix),
-                        m_DerivativeFnScope, Init, DirectInit, TSI, IS);
+                        m_DerivativeFnScope, Init, DirectInit, TSI, IS, SC);
   }
 
   NamespaceDecl* VisitorBase::BuildNamespaceDecl(IdentifierInfo* II,
6 changes: 4 additions & 2 deletions test/CUDA/GradientKernels.cu
@@ -66,8 +66,10 @@ __global__ void add_kernel_2(int *out, int *in) {
 //CHECK-NEXT: }
 
 __global__ void add_kernel_3(int *out, int *in) {
+  __shared__ int shared;
+  shared = 1;
   int index = threadIdx.x + blockIdx.x * blockDim.x;
-  out[index] += in[index];
+  out[index] += in[index] + shared;
 }
 
 // CHECK: void add_kernel_3_grad(int *out, int *in, int *_d_out, int *_d_in) {
@@ -168,7 +170,7 @@ int main(void) {
 
   cudaMemset(d_in, 0, 10 * sizeof(int));
   auto add_3 = clad::gradient(add_kernel_3, "in, out");
-  add_3.execute_kernel(dim3(10), dim3(1), dummy_out, dummy_in, d_out, d_in);
+  add_3.execute_kernel(dim3(10), dim3(1), sizeof(int), cudaStream, dummy_out, dummy_in, d_out, d_in);
   cudaDeviceSynchronize();
 
   cudaMemcpy(res, d_in, 10 * sizeof(int), cudaMemcpyDeviceToHost);
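Note on the launch above: execute_kernel now receives the dynamic shared-memory size and the stream between the block dimensions and the kernel arguments. The cudaStream variable is not defined in this hunk; a sketch of the assumed host-side setup around the updated call (stream handling is illustrative, not taken from the test file):

cudaStream_t cudaStream;
cudaStreamCreate(&cudaStream);

// grid, block, shared-memory bytes, stream, then the original and adjoint arguments.
add_3.execute_kernel(dim3(10), dim3(1), sizeof(int), cudaStream,
                     dummy_out, dummy_in, d_out, d_in);
cudaDeviceSynchronize();
cudaStreamDestroy(cudaStream);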
