From 431791fedfa9a57af4e21619ec51c1f10431f124 Mon Sep 17 00:00:00 2001 From: Christina Koutsou <74819775+kchristin22@users.noreply.github.com> Date: Thu, 5 Sep 2024 13:20:28 +0300 Subject: [PATCH] Enable computation of CUDA global kernels derivative in reverse mode (#1059) --- include/clad/Differentiator/Differentiator.h | 133 +++++++++++++++---- lib/Differentiator/DiffPlanner.cpp | 21 ++- lib/Differentiator/ReverseModeVisitor.cpp | 8 ++ test/CUDA/GradientKernels.cu | 79 +++++++++++ test/lit.cfg | 6 +- 5 files changed, 215 insertions(+), 32 deletions(-) create mode 100644 test/CUDA/GradientKernels.cu diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index 647428423..db9b699f2 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -38,16 +38,23 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { return count; } - /// Tape type used for storing values in reverse-mode AD inside loops. - template - using tape = tape_impl; +#ifdef __CUDACC__ +#define CUDA_ARGS bool CUDAkernel, dim3 grid, dim3 block, +#define CUDA_REST_ARGS size_t shared_mem, cudaStream_t stream, +#else +#define CUDA_ARGS +#define CUDA_REST_ARGS +#endif - /// Add value to the end of the tape, return the same value. - template - CUDA_HOST_DEVICE T push(tape& to, ArgsT... val) { - to.emplace_back(std::forward(val)...); - return to.back(); - } +/// Tape type used for storing values in reverse-mode AD inside loops. +template using tape = tape_impl; + +/// Add value to the end of the tape, return the same value. +template +CUDA_HOST_DEVICE T push(tape& to, ArgsT... val) { + to.emplace_back(std::forward(val)...); + return to.back(); +} /// Add value to the end of the tape, return the same value. /// A specialization for clad::array_ref types to use in reverse mode. @@ -115,17 +122,35 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { typename std::enable_if::type = true> CUDA_HOST_DEVICE return_type_t execute_with_default_args(list, F f, list, - Args&&... args) { + CUDA_ARGS CUDA_REST_ARGS Args&&... args) { +#if defined(__CUDACC__) && !defined(__CUDA_ARCH__) + if (CUDAkernel) { + void* argPtrs[] = {(void*)&args..., (void*)static_cast(nullptr)...}; + cudaLaunchKernel((void*)f, grid, block, argPtrs, shared_mem, stream); + } else { + return f(static_cast(args)..., static_cast(nullptr)...); + } +#else return f(static_cast(args)..., static_cast(nullptr)...); +#endif } template ::type = true> - return_type_t execute_with_default_args(list, F f, - list, - Args&&... args) { + return_type_t + execute_with_default_args(list, F f, list, + CUDA_ARGS CUDA_REST_ARGS Args&&... 
args) { +#if defined(__CUDACC__) && !defined(__CUDA_ARCH__) + if (CUDAkernel) { + void* argPtrs[] = {(void*)&args...}; + cudaLaunchKernel((void*)f, grid, block, argPtrs, shared_mem, stream); + } else { + return f(static_cast(args)...); + } +#else return f(static_cast(args)...); +#endif } // for executing member-functions @@ -167,12 +192,13 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { CladFunctionType m_Function; char* m_Code; FunctorType *m_Functor = nullptr; + bool m_CUDAkernel = false; public: - CUDA_HOST_DEVICE CladFunction(CladFunctionType f, - const char* code, - FunctorType* functor = nullptr) - : m_Functor(functor) { + CUDA_HOST_DEVICE CladFunction(CladFunctionType f, const char* code, + FunctorType* functor = nullptr, + bool CUDAkernel = false) + : m_Functor(functor), m_CUDAkernel(CUDAkernel) { assert(f && "Must pass a non-0 argument."); if (size_t length = GetLength(code)) { m_Function = f; @@ -210,9 +236,37 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { printf("CladFunction is invalid\n"); return static_cast>(return_type_t()); } + if (m_CUDAkernel) { + printf("Use execute_kernel() for global CUDA kernels\n"); + return static_cast>(return_type_t()); + } // here static_cast is used to achieve perfect forwarding +#ifdef __CUDACC__ + return execute_helper(m_Function, m_CUDAkernel, dim3(0), dim3(0), + static_cast(args)...); +#else return execute_helper(m_Function, static_cast(args)...); +#endif + } + +#ifdef __CUDACC__ + template + typename std::enable_if::value, + return_type_t>::type + execute_kernel(dim3 grid, dim3 block, Args&&... args) CUDA_HOST_DEVICE { + if (!m_Function) { + printf("CladFunction is invalid\n"); + return static_cast>(return_type_t()); + } + if (!m_CUDAkernel) { + printf("Use execute() for non-global CUDA kernels\n"); + return static_cast>(return_type_t()); + } + + return execute_helper(m_Function, m_CUDAkernel, grid, block, + static_cast(args)...); } +#endif /// `Execute` overload to be used when derived function type cannot be /// deduced. One reason for this can be when user tries to differentiate @@ -258,12 +312,39 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { /// Helper function for executing non-member derived functions. template CUDA_HOST_DEVICE return_type_t - execute_helper(Fn f, Args&&... args) { + execute_helper(Fn f, CUDA_ARGS Args&&... args) { // `static_cast` is required here for perfect forwarding. - return execute_with_default_args( - DropArgs_t{}, f, - TakeNFirstArgs_t{}, - static_cast(args)...); +#if defined(__CUDACC__) + if constexpr (sizeof...(Args) >= 2) { + auto secondArg = + std::get<1>(std::forward_as_tuple(std::forward(args)...)); + if constexpr (std::is_same, + cudaStream_t>::value) { + return [&](auto shared_mem, cudaStream_t stream, auto&&... args_) { + return execute_with_default_args( + DropArgs_t{}, f, + TakeNFirstArgs_t{}, + CUDAkernel, grid, block, shared_mem, stream, + static_cast(args_)...); + }(static_cast(args)...); + } else { + return execute_with_default_args( + DropArgs_t{}, f, + TakeNFirstArgs_t{}, CUDAkernel, + grid, block, 0, nullptr, static_cast(args)...); + } + } else { + return execute_with_default_args( + DropArgs_t{}, f, + TakeNFirstArgs_t{}, CUDAkernel, + grid, block, 0, nullptr, static_cast(args)...); + } +#else + return execute_with_default_args( + DropArgs_t{}, f, + TakeNFirstArgs_t{}, + static_cast(args)...); +#endif } /// Helper functions for executing member derived functions. 
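
// --- Editor's illustration (not part of this patch) --------------------------
// The execute_with_default_args overloads above forward __global__ launches to
// cudaLaunchKernel, which takes the kernel address plus an array of pointers
// to each kernel argument. A minimal standalone sketch of that mechanism,
// using a hypothetical kernel `square` and wrapper `launch`:

#include <cuda_runtime.h>

__global__ void square(int* a) { *a *= *a; }

void launch(int* d_a, dim3 grid, dim3 block, size_t shared_mem = 0,
            cudaStream_t stream = nullptr) {
  // cudaLaunchKernel receives the address of every kernel argument, which is
  // what the argPtrs[] arrays built above supply for the forwarded pack.
  void* argPtrs[] = {(void*)&d_a};
  cudaLaunchKernel((void*)square, grid, block, argPtrs, shared_mem, stream);
}

// launch(d_a, grid, block) is equivalent to square<<<grid, block>>>(d_a); the
// pointer-array form is what lets CladFunction::execute_kernel forward an
// arbitrary argument pack to a __global__ derivative at run time.
// ------------------------------------------------------------------------------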
@@ -393,10 +474,10 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
                                        annotate("G"))) CUDA_HOST_DEVICE
   gradient(F f, ArgSpec args = "",
            DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
-           const char* code = "") {
-    assert(f && "Must pass in a non-0 argument");
-    return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
-        derivedFn /* will be replaced by gradient*/, code);
+           const char* code = "", bool CUDAkernel = false) {
+    assert(f && "Must pass in a non-0 argument");
+    return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
+        derivedFn /* will be replaced by gradient*/, code, nullptr, CUDAkernel);
   }
 
   /// Specialization for differentiating functors.
diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp
index 0f7829a20..e25cabf40 100644
--- a/lib/Differentiator/DiffPlanner.cpp
+++ b/lib/Differentiator/DiffPlanner.cpp
@@ -177,9 +177,6 @@ namespace clad {
   void DiffRequest::updateCall(FunctionDecl* FD, FunctionDecl* OverloadedFD,
                                Sema& SemaRef) {
     CallExpr* call = this->CallContext;
-    // Index of "code" parameter:
-    auto codeArgIdx = static_cast<int>(call->getNumArgs()) - 1;
-    auto derivedFnArgIdx = codeArgIdx - 1;
 
     assert(call && "Must be set");
     assert(FD && "Trying to update with null FunctionDecl");
@@ -191,6 +188,24 @@ namespace clad {
     ASTContext& C = SemaRef.getASTContext();
     FunctionDecl* replacementFD = OverloadedFD ? OverloadedFD : FD;
+
+    // Index of "CUDAkernel" parameter:
+    int numArgs = static_cast<int>(call->getNumArgs());
+    if (numArgs > 4) {
+      auto kernelArgIdx = numArgs - 1;
+      auto* cudaKernelFlag =
+          SemaRef
+              .ActOnCXXBoolLiteral(noLoc,
+                                   replacementFD->hasAttr<CUDAGlobalAttr>()
+                                       ? tok::kw_true
+                                       : tok::kw_false)
+              .get();
+      call->setArg(kernelArgIdx, cudaKernelFlag);
+      numArgs--;
+    }
+    auto codeArgIdx = numArgs - 1;
+    auto derivedFnArgIdx = numArgs - 2;
+
     // Create ref to generated FD.
     DeclRefExpr* DRE = DeclRefExpr::Create(C, oldDRE->getQualifierLoc(), noLoc,
                                            replacementFD,
diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp
index ff5020ef5..c1e5442f2 100644
--- a/lib/Differentiator/ReverseModeVisitor.cpp
+++ b/lib/Differentiator/ReverseModeVisitor.cpp
@@ -240,6 +240,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
       addToCurrentBlock(BuildDeclStmt(gradientVD));
     }
 
+    // If the function is a global kernel, we need to transform it
+    // into a device function when calling it inside the overload function
+    // which is the final global kernel returned.
+    if (m_Derivative->hasAttr<clang::CUDAGlobalAttr>()) {
+      m_Derivative->dropAttr<clang::CUDAGlobalAttr>();
+      m_Derivative->addAttr(clang::CUDADeviceAttr::CreateImplicit(m_Context));
+    }
+
     Expr* callExpr = BuildCallExprToFunction(m_Derivative, callArgs,
                                              /*UseRefQualifiedThisObj=*/true);
     addToCurrentBlock(callExpr);
diff --git a/test/CUDA/GradientKernels.cu b/test/CUDA/GradientKernels.cu
new file mode 100644
index 000000000..0a5e40d54
--- /dev/null
+++ b/test/CUDA/GradientKernels.cu
@@ -0,0 +1,79 @@
+// RUN: %cladclang_cuda -I%S/../../include %s -fsyntax-only \
+// RUN: %cudasmlevel --cuda-path=%cudapath -Xclang -verify 2>&1 | %filecheck %s
+
+// RUN: %cladclang_cuda -I%S/../../include %s -xc++ %cudasmlevel \
+// RUN: --cuda-path=%cudapath -L/usr/local/cuda/lib64 -lcudart_static \
+// RUN: -L%cudapath/lib64/stubs \
+// RUN: -ldl -lrt -pthread -lm -lstdc++ -lcuda -lnvrtc
+
+// REQUIRES: cuda-runtime
+
+// expected-no-diagnostics
+
+// XFAIL: clang-15
+
+#include "clad/Differentiator/Differentiator.h"
+
+__global__ void kernel(int *a) {
+  *a *= *a;
+}
+
+// CHECK: void kernel_grad(int *a, int *_d_a) {
+//CHECK-NEXT: int _t0 = *a;
+//CHECK-NEXT: *a *= *a;
+//CHECK-NEXT: {
+//CHECK-NEXT: *a = _t0;
+//CHECK-NEXT: int _r_d0 = *_d_a;
+//CHECK-NEXT: *_d_a = 0;
+//CHECK-NEXT: *_d_a += _r_d0 * *a;
+//CHECK-NEXT: *_d_a += *a * _r_d0;
+//CHECK-NEXT: }
+//CHECK-NEXT: }
+
+void fake_kernel(int *a) {
+  *a *= *a;
+}
+
+int main(void) {
+  int *a = (int*)malloc(sizeof(int));
+  *a = 2;
+  int *d_a;
+  cudaMalloc(&d_a, sizeof(int));
+  cudaMemcpy(d_a, a, sizeof(int), cudaMemcpyHostToDevice);
+
+  int *asquare = (int*)malloc(sizeof(int));
+  *asquare = 1;
+  int *d_square;
+  cudaMalloc(&d_square, sizeof(int));
+  cudaMemcpy(d_square, asquare, sizeof(int), cudaMemcpyHostToDevice);
+
+  auto test = clad::gradient(kernel);
+  dim3 grid(1);
+  dim3 block(1);
+  cudaStream_t cudaStream;
+  cudaStreamCreate(&cudaStream);
+  test.execute_kernel(grid, block, 0, cudaStream, d_a, d_square);
+
+  cudaDeviceSynchronize();
+
+  cudaMemcpy(asquare, d_square, sizeof(int), cudaMemcpyDeviceToHost);
+  cudaMemcpy(a, d_a, sizeof(int), cudaMemcpyDeviceToHost);
+  printf("a = %d, a^2 = %d\n", *a, *asquare); // CHECK-EXEC: a = 2, a^2 = 4
+
+  auto error = clad::gradient(fake_kernel);
+  error.execute_kernel(grid, block, d_a, d_square); // CHECK-EXEC: Use execute() for non-global CUDA kernels
+
+  test.execute(d_a, d_square); // CHECK-EXEC: Use execute_kernel() for global CUDA kernels
+
+  cudaMemset(d_a, 5, 1); // first byte is set to 5
+  cudaMemset(d_square, 1, 1);
+
+  test.execute_kernel(grid, block, d_a, d_square);
+  cudaDeviceSynchronize();
+
+  cudaMemcpy(asquare, d_square, sizeof(int), cudaMemcpyDeviceToHost);
+  cudaMemcpy(a, d_a, sizeof(int), cudaMemcpyDeviceToHost);
+  printf("a = %d, a^2 = %d\n", *a, *asquare); // CHECK-EXEC: a = 5, a^2 = 10
+
+  return 0;
+}
\ No newline at end of file
diff --git a/test/lit.cfg b/test/lit.cfg
index 19ae580f8..59baa61b8 100644
--- a/test/lit.cfg
+++ b/test/lit.cfg
@@ -257,13 +257,13 @@ lit.util.usePlatformSdkOnDarwin(config, lit_config)
 #\ -plugin-arg-ad -Xclang -fdump-derived-fn -Xclang -load -Xclang../../Debug+Asserts/lib/libclad.so
 #FIXME: we need to introduce a better way to check compatible version of clang, propagating
 #-fvalidate-clang-version flag is not enough.
-flags = ' -std=c++11 -Xclang -add-plugin -Xclang clad -Xclang \
+flags = ' -Xclang -add-plugin -Xclang clad -Xclang \
 -plugin-arg-clad -Xclang -fdump-derived-fn -Xclang \
 -load -Xclang ' + config.cladlib
-config.substitutions.append( ('%cladclang_cuda', config.clang + flags) )
+config.substitutions.append( ('%cladclang_cuda', config.clang + ' -std=c++17' + flags) )
 
-config.substitutions.append( ('%cladclang', config.clang + '++ -DCLAD_NO_NUM_DIFF ' + flags) )
+config.substitutions.append( ('%cladclang', config.clang + '++ -DCLAD_NO_NUM_DIFF ' + ' -std=c++11' + flags) )
 
 config.substitutions.append( ('%cladlib', config.cladlib) )
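
Usage note (editor's illustration, condensed from test/CUDA/GradientKernels.cu
above; `run` and `d_grad` are placeholder names): with this patch a __global__
kernel is differentiated with clad::gradient as before, but the derivative must
be launched through execute_kernel, which takes the launch configuration first
and, optionally, a shared-memory size and stream ahead of the kernel arguments:

    __global__ void kernel(int* a) { *a *= *a; }

    void run(int* d_a, int* d_grad, cudaStream_t stream) {
      auto grad = clad::gradient(kernel); // the CUDAkernel flag is set automatically
      dim3 grid(1), block(1);
      // Launch with the default shared-memory size and stream.
      grad.execute_kernel(grid, block, d_a, d_grad);
      // Or pass the shared-memory size and stream explicitly.
      grad.execute_kernel(grid, block, 0, stream, d_a, d_grad);
      // grad.execute(d_a, d_grad) would only print
      // "Use execute_kernel() for global CUDA kernels".
    }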