Skip to content

Commit

Permalink
Fix appendage of nullptrs to args of a CUDA kernel (#1102)
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 authored Sep 23, 2024
1 parent b9a390d commit e2b8e35
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,18 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
CUDA_ARGS CUDA_REST_ARGS Args&&... args) {
#if defined(__CUDACC__) && !defined(__CUDA_ARCH__)
if (CUDAkernel) {
void* argPtrs[] = {(void*)&args..., (void*)static_cast<Rest>(nullptr)...};
cudaLaunchKernel((void*)f, grid, block, argPtrs, shared_mem, stream);
constexpr size_t totalArgs = sizeof...(args) + sizeof...(Rest);
std::vector<void*> argPtrs;
argPtrs.reserve(totalArgs);
(argPtrs.push_back(static_cast<void*>(&args)), ...);

void* null_param = nullptr;
for (size_t i = sizeof...(args); i < totalArgs; ++i)
argPtrs[i] = &null_param;

cudaLaunchKernel((void*)f, grid, block, argPtrs.data(), shared_mem,
stream);
return return_type_t<F>();
} else {
return f(static_cast<Args>(args)..., static_cast<Rest>(nullptr)...);
}
Expand Down

0 comments on commit e2b8e35

Please sign in to comment.