From e2b8e354d13193041a1210888061f12f950b1154 Mon Sep 17 00:00:00 2001 From: Christina Koutsou <74819775+kchristin22@users.noreply.github.com> Date: Mon, 23 Sep 2024 08:21:21 +0300 Subject: [PATCH] Fix appendage of nullptrs to args of a CUDA kernel (#1102) --- include/clad/Differentiator/Differentiator.h | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index 3a8f35faf..72b4d62ab 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -125,8 +125,18 @@ CUDA_HOST_DEVICE T push(tape& to, ArgsT... val) { CUDA_ARGS CUDA_REST_ARGS Args&&... args) { #if defined(__CUDACC__) && !defined(__CUDA_ARCH__) if (CUDAkernel) { - void* argPtrs[] = {(void*)&args..., (void*)static_cast(nullptr)...}; - cudaLaunchKernel((void*)f, grid, block, argPtrs, shared_mem, stream); + constexpr size_t totalArgs = sizeof...(args) + sizeof...(Rest); + std::vector argPtrs; + argPtrs.reserve(totalArgs); + (argPtrs.push_back(static_cast(&args)), ...); + + void* null_param = nullptr; + for (size_t i = sizeof...(args); i < totalArgs; ++i) + argPtrs[i] = &null_param; + + cudaLaunchKernel((void*)f, grid, block, argPtrs.data(), shared_mem, + stream); + return return_type_t(); } else { return f(static_cast(args)..., static_cast(nullptr)...); }