Revert "Add support for shared memory keyword and skip overload creat…
Browse files Browse the repository at this point in the history
…ion for kernels"

This reverts commit 1faffc4.
kchristin22 committed Sep 12, 2024
1 parent 1faffc4 commit 52dd28b
Showing 4 changed files with 27 additions and 49 deletions.
6 changes: 2 additions & 4 deletions include/clad/Differentiator/VisitorBase.h
@@ -295,8 +295,7 @@ namespace clad {
                  clang::Scope* scope, clang::Expr* Init = nullptr,
                  bool DirectInit = false, clang::TypeSourceInfo* TSI = nullptr,
                  clang::VarDecl::InitializationStyle IS =
-                     clang::VarDecl::InitializationStyle::CInit,
-                 clang::StorageClass SC = clang::SC_None);
+                     clang::VarDecl::InitializationStyle::CInit);
     /// Builds variable declaration to be used inside the derivative
     /// body.
     /// \param[in] Type The type of variable declaration to build.
@@ -337,8 +336,7 @@ namespace clad {
                        clang::Expr* Init = nullptr, bool DirectInit = false,
                        clang::TypeSourceInfo* TSI = nullptr,
                        clang::VarDecl::InitializationStyle IS =
-                           clang::VarDecl::InitializationStyle::CInit,
-                       clang::StorageClass SC = clang::SC_None);
+                           clang::VarDecl::InitializationStyle::CInit);
     /// Creates a namespace declaration and enters its context. All subsequent
     /// Stmts are built inside that namespace, until
     /// m_Sema.PopDeclContextIsUsed.
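A usage note, not part of the commit: with the StorageClass parameter removed again, every declaration built through these helpers receives the default storage class. A minimal sketch of a call site under the restored interface (the variable name and type are made up for illustration; getZeroInit is the helper used elsewhere in this diff):

    // Illustrative only: the result is a plain VarDecl with SC_None; there is
    // no longer a way to request e.g. __shared__ storage through this helper.
    clang::VarDecl* DX =
        BuildGlobalVarDecl(m_Context.DoubleTy, "_d_x",
                           /*Init=*/getZeroInit(m_Context.DoubleTy));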
50 changes: 17 additions & 33 deletions lib/Differentiator/ReverseModeVisitor.cpp
@@ -240,6 +240,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
       addToCurrentBlock(BuildDeclStmt(gradientVD));
     }
 
+    // If the function is a global kernel, we need to transform it
+    // into a device function when calling it inside the overload function
+    // which is the final global kernel returned.
+    if (m_Derivative->hasAttr<clang::CUDAGlobalAttr>()) {
+      m_Derivative->dropAttr<clang::CUDAGlobalAttr>();
+      m_Derivative->addAttr(clang::CUDADeviceAttr::CreateImplicit(m_Context));
+    }
+
     Expr* callExpr = BuildCallExprToFunction(m_Derivative, callArgs,
                                              /*UseRefQualifiedThisObj=*/true);
     addToCurrentBlock(callExpr);
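Background, not part of the diff: CUDA does not allow a __global__ kernel to be invoked through an ordinary call expression from device code, which is why the restored block above demotes the derivative to __device__ before it is called inside the overload, while the overload itself stays __global__. A rough sketch of the resulting shape (names and parameter lists are illustrative, not the exact code clad emits):

    // The derivative loses __global__ and becomes callable from device code.
    __device__ void kernel_grad(int *out, int *in, int *_d_out, int *_d_in) {
      /* reverse-mode body */
    }

    // The overload handed back to the user keeps __global__ and forwards to
    // the derivative with a plain call.
    __global__ void kernel_grad_overload(int *out, int *in, int *_d_out,
                                         int *_d_in) {
      kernel_grad(out, in, _d_out, _d_in);
    }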
@@ -338,8 +346,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
   bool shouldCreateOverload = false;
   // FIXME: Gradient overload doesn't know how to handle additional parameters
   // added by the plugins yet.
-  if (request.Mode != DiffMode::jacobian && numExtraParam == 0 &&
-      !FD->hasAttr<CUDAGlobalAttr>())
+  if (request.Mode != DiffMode::jacobian && numExtraParam == 0)
     shouldCreateOverload = true;
   if (!request.DeclarationOnly && !request.DerivedFDPrototypes.empty())
     // If the overload is already created, we don't need to create it again.
@@ -2877,19 +2884,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
         VDDerivedInit = getZeroInit(VDDerivedType);
       }
     }
-    if (initializeDerivedVar) {
-      if (VD->hasAttr<clang::CUDASharedAttr>()) {
-        VDDerived = BuildGlobalVarDecl(
-            VDDerivedType, "_d_" + VD->getNameAsString(), nullptr, false,
-            nullptr, VD->getInitStyle(), VD->getStorageClass());
-        VDDerived->addAttr(
-            VD->getAttr<clang::CUDASharedAttr>()->clone(m_Context));
-      } else {
-        VDDerived = BuildGlobalVarDecl(
-            VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit,
-            false, nullptr, VD->getInitStyle());
-      }
-    }
+    if (initializeDerivedVar)
+      VDDerived = BuildGlobalVarDecl(
+          VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false,
+          nullptr, VD->getInitStyle());
   }
 
   // If `VD` is a reference to a local variable, then it is already
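To make the removed branch concrete: a __shared__ variable cannot carry an initializer, so the reverted code built the adjoint without one, cloned the CUDASharedAttr onto it, and zeroed it later through a separately emitted assignment (the sharedMemInits machinery removed further down). After the revert the helper only produces an ordinary local with an inline zero initializer, which is why the __shared__ case is also dropped from the test file below. A sketch of the two generated shapes for "__shared__ int shared;" (assumed output, following clad's _d_ naming convention):

    // Before the revert (shared-memory aware derivative):
    __global__ void grad_before_revert() {
      __shared__ int _d_shared;   // no initializer allowed on __shared__
      _d_shared = 0;              // zeroed by a separately emitted statement
    }

    // After the revert (plain local adjoint):
    __global__ void grad_after_revert() {
      int _d_shared = 0;          // inline zero initialization
    }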
@@ -2972,15 +2970,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
       VDClone = BuildGlobalVarDecl(
           VDCloneType, VD->getNameAsString(),
           BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getExpr()),
-          VD->isDirectInit(), nullptr,
-          clang::VarDecl::InitializationStyle::CInit, VD->getStorageClass());
+          VD->isDirectInit());
     else
-      VDClone =
-          BuildGlobalVarDecl(VDCloneType, VD->getNameAsString(),
-                             initDiff.getExpr(), VD->isDirectInit(), nullptr,
-                             VD->getInitStyle(), VD->getStorageClass());
-    if (VD->hasAttr<clang::CUDASharedAttr>())
-      VDClone->addAttr(VD->getAttr<clang::CUDASharedAttr>()->clone(m_Context));
+      VDClone = BuildGlobalVarDecl(VDCloneType, VD->getNameAsString(),
+                                   initDiff.getExpr(), VD->isDirectInit(),
+                                   nullptr, VD->getInitStyle());
     if (isPointerType && derivedVDE) {
       if (promoteToFnScope) {
         Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign,
@@ -3089,7 +3083,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
   llvm::SmallVector<Decl*, 4> declsDiff;
   // Need to put array decls inlined.
   llvm::SmallVector<Decl*, 4> localDeclsDiff;
-  llvm::SmallVector<Stmt*, 16> sharedMemInits;
   // reverse_mode_forward_pass does not have a reverse pass so declarations
   // don't have to be moved to the function global scope.
   bool promoteToFnScope =
@@ -3179,13 +3172,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
           localDeclsDiff.push_back(VDDiff.getDecl_dx());
         else
           declsDiff.push_back(VDDiff.getDecl_dx());
-        if (VD->hasAttr<clang::CUDASharedAttr>()) {
-          VarDecl* VDDerived = VDDiff.getDecl_dx();
-          Expr* declRef = BuildDeclRef(VDDerived);
-          Stmt* assignToZero = BuildOp(BinaryOperatorKind::BO_Assign, declRef,
-                                       getZeroInit(VDDerived->getType()));
-          sharedMemInits.push_back(assignToZero);
-        }
       }
     } else if (auto* SAD = dyn_cast<StaticAssertDecl>(D)) {
       DeclDiff<StaticAssertDecl> SADDiff = DifferentiateStaticAssertDecl(SAD);
@@ -3224,8 +3210,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
     Stmts& block =
         promoteToFnScope ? m_Globals : getCurrentBlock(direction::forward);
     addToBlock(DSDiff, block);
-    for (Stmt* sharedMemInitsStmt : sharedMemInits)
-      addToBlock(sharedMemInitsStmt, block);
   }
 
   if (m_ExternalSource) {
14 changes: 6 additions & 8 deletions lib/Differentiator/VisitorBase.cpp
@@ -113,13 +113,12 @@ namespace clad {
   VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier,
                                      Scope* Scope, Expr* Init, bool DirectInit,
                                      TypeSourceInfo* TSI,
-                                     VarDecl::InitializationStyle IS,
-                                     StorageClass SC) {
+                                     VarDecl::InitializationStyle IS) {
     // add namespace specifier in variable declaration if needed.
     Type = utils::AddNamespaceSpecifier(m_Sema, m_Context, Type);
-    auto* VD =
-        VarDecl::Create(m_Context, m_Sema.CurContext, m_DiffReq->getLocation(),
-                        m_DiffReq->getLocation(), Identifier, Type, TSI, SC);
+    auto* VD = VarDecl::Create(
+        m_Context, m_Sema.CurContext, m_DiffReq->getLocation(),
+        m_DiffReq->getLocation(), Identifier, Type, TSI, SC_None);
 
     if (Init) {
       m_Sema.AddInitializerToDecl(VD, Init, DirectInit);
@@ -150,10 +149,9 @@ namespace clad {
   VarDecl* VisitorBase::BuildGlobalVarDecl(QualType Type,
                                            llvm::StringRef prefix, Expr* Init,
                                            bool DirectInit, TypeSourceInfo* TSI,
-                                           VarDecl::InitializationStyle IS,
-                                           StorageClass SC) {
+                                           VarDecl::InitializationStyle IS) {
     return BuildVarDecl(Type, CreateUniqueIdentifier(prefix),
-                        m_DerivativeFnScope, Init, DirectInit, TSI, IS, SC);
+                        m_DerivativeFnScope, Init, DirectInit, TSI, IS);
   }
 
   NamespaceDecl* VisitorBase::BuildNamespaceDecl(IdentifierInfo* II,
6 changes: 2 additions & 4 deletions test/CUDA/GradientKernels.cu
@@ -66,10 +66,8 @@ __global__ void add_kernel_2(int *out, int *in) {
 //CHECK-NEXT: }
 
 __global__ void add_kernel_3(int *out, int *in) {
-  __shared__ int shared;
-  shared = 1;
   int index = threadIdx.x + blockIdx.x * blockDim.x;
-  out[index] += in[index] + shared;
+  out[index] += in[index];
 }
 
 // CHECK: void add_kernel_3_grad(int *out, int *in, int *_d_out, int *_d_in) {
@@ -170,7 +168,7 @@ int main(void) {
 
   cudaMemset(d_in, 0, 10 * sizeof(int));
   auto add_3 = clad::gradient(add_kernel_3, "in, out");
-  add_3.execute_kernel(dim3(10), dim3(1), sizeof(int), cudaStream, dummy_out, dummy_in, d_out, d_in);
+  add_3.execute_kernel(dim3(10), dim3(1), dummy_out, dummy_in, d_out, d_in);
   cudaDeviceSynchronize();
 
   cudaMemcpy(res, d_in, 10 * sizeof(int), cudaMemcpyDeviceToHost);
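Usage note, not part of the test: after the revert, execute_kernel takes only the launch configuration (grid and block dimensions) followed by the kernel arguments and their adjoint buffers; the shared-memory-size and stream arguments accepted by the reverted change are gone. A minimal host-side sketch mirroring the calls in the test above (error checking omitted; buffer names are hypothetical):

    int *out, *in, *d_out, *d_in;
    cudaMalloc(&out, 10 * sizeof(int));
    cudaMalloc(&in, 10 * sizeof(int));
    cudaMalloc(&d_out, 10 * sizeof(int));
    cudaMalloc(&d_in, 10 * sizeof(int));
    cudaMemset(d_in, 0, 10 * sizeof(int));

    auto grad = clad::gradient(add_kernel_3, "in, out");
    // grid dims, block dims, then the original args and the adjoint buffers
    grad.execute_kernel(dim3(10), dim3(1), out, in, d_out, d_in);
    cudaDeviceSynchronize();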
