diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index 3a8f35faf..72b4d62ab 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -125,8 +125,18 @@ CUDA_HOST_DEVICE T push(tape& to, ArgsT... val) { CUDA_ARGS CUDA_REST_ARGS Args&&... args) { #if defined(__CUDACC__) && !defined(__CUDA_ARCH__) if (CUDAkernel) { - void* argPtrs[] = {(void*)&args..., (void*)static_cast(nullptr)...}; - cudaLaunchKernel((void*)f, grid, block, argPtrs, shared_mem, stream); + constexpr size_t totalArgs = sizeof...(args) + sizeof...(Rest); + std::vector argPtrs; + argPtrs.reserve(totalArgs); + (argPtrs.push_back(static_cast(&args)), ...); + + void* null_param = nullptr; + for (size_t i = sizeof...(args); i < totalArgs; ++i) + argPtrs[i] = &null_param; + + cudaLaunchKernel((void*)f, grid, block, argPtrs.data(), shared_mem, + stream); + return return_type_t(); } else { return f(static_cast(args)..., static_cast(nullptr)...); }