Skip to content

Commit

Permalink
Generalise the handler to reduce the specific implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mconcas committed Sep 13, 2024
1 parent 6fecdd1 commit 8a5900f
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 191 deletions.
119 changes: 45 additions & 74 deletions Common/DCAFitter/GPU/cuda/DCAFitterN.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,113 +36,84 @@ namespace o2::vertexing::device
{
namespace kernel
{
GPUg() void printKernel(o2::vertexing::DCAFitterN<2>* ft)
template <int N = 2>
GPUg() void printKernel(o2::vertexing::DCAFitterN<N>* ft)
{
if (threadIdx.x == 0) {
printf(" =============== GPU DCA Fitter ================\n");
printf(" =============== GPU DCA Fitter %d prongs ================\n", N);
ft->print();
printf(" ===============================================\n");
printf(" =========================================================\n\n");
}
}

GPUg() void processKernel(o2::vertexing::DCAFitterN<2>* ft, o2::track::TrackParCov* t1, o2::track::TrackParCov* t2, int* res)
template <typename Fitter, typename... Tr>
GPUg() void processKernel(Fitter* ft, int* res, Tr*... tracks)
{
*res = ft->process(*t1, *t2);
*res = ft->process(*tracks...);
}
} // namespace kernel

void print(o2::vertexing::DCAFitterN<2>& ft,
const int nBlocks,
const int nThreads)
/// CPU handlers
template <typename Fitter>
void print(const int nBlocks,
const int nThreads,
Fitter& ft)
{
DCAFitterN<2>* ft_device;
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&ft_device), sizeof(o2::vertexing::DCAFitterN<2>)));
gpuCheckError(cudaMemcpy(ft_device, &ft, sizeof(o2::vertexing::DCAFitterN<2>), cudaMemcpyHostToDevice));
Fitter* ft_device;
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&ft_device), sizeof(Fitter)));
gpuCheckError(cudaMemcpy(ft_device, &ft, sizeof(Fitter), cudaMemcpyHostToDevice));

kernel::printKernel<<<nBlocks, nThreads>>>(ft_device);
kernel::printKernel<Fitter::getNProngs()><<<nBlocks, nThreads>>>(ft_device);

gpuCheckError(cudaPeekAtLastError());
gpuCheckError(cudaDeviceSynchronize());
}

int process(o2::vertexing::DCAFitterN<2>& fitter,
o2::track::TrackParCov& track1,
o2::track::TrackParCov& track2,
const int nBlocks,
const int nThreads)
{
DCAFitterN<2>* ft_device;
o2::track::TrackParCov* t1_device;
o2::track::TrackParCov* t2_device;
int result, *result_device;

gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&ft_device), sizeof(o2::vertexing::DCAFitterN<2>)));
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&t1_device), sizeof(o2::track::TrackParCov)));
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&t2_device), sizeof(o2::track::TrackParCov)));
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&result_device), sizeof(int)));

gpuCheckError(cudaMemcpy(ft_device, &fitter, sizeof(o2::vertexing::DCAFitterN<2>), cudaMemcpyHostToDevice));
gpuCheckError(cudaMemcpy(t1_device, &track1, sizeof(o2::track::TrackParCov), cudaMemcpyHostToDevice));
gpuCheckError(cudaMemcpy(t2_device, &track2, sizeof(o2::track::TrackParCov), cudaMemcpyHostToDevice));

kernel::processKernel<<<nBlocks, nThreads>>>(ft_device, t1_device, t2_device, result_device);

gpuCheckError(cudaPeekAtLastError());
gpuCheckError(cudaDeviceSynchronize());

gpuCheckError(cudaMemcpy(&result, result_device, sizeof(int), cudaMemcpyDeviceToHost));
gpuCheckError(cudaMemcpy(&fitter, ft_device, sizeof(o2::vertexing::DCAFitterN<2>), cudaMemcpyDeviceToHost));
gpuCheckError(cudaMemcpy(&track1, t1_device, sizeof(o2::track::TrackParCov), cudaMemcpyDeviceToHost));
gpuCheckError(cudaMemcpy(&track2, t2_device, sizeof(o2::track::TrackParCov), cudaMemcpyDeviceToHost));
gpuCheckError(cudaFree(ft_device));
gpuCheckError(cudaFree(t1_device));
gpuCheckError(cudaFree(t2_device));

gpuCheckError(cudaFree(result_device));

return result;
}

template <int N, class... Tr>
int process(o2::vertexing::DCAFitterN<2>&,
const int nBlocks = 1,
const int nThreads = 1,
template <typename Fitter, class... Tr>
int process(const int nBlocks,
const int nThreads,
Fitter& fitter,
Tr&... args)
{
DCAFitterN<N>* ft_device;
std::array<o2::track::TrackParCov*, N> tracks_device;
// o2::track::TrackParCov* t1_device;
// o2::track::TrackParCov* t2_device;
Fitter* ft_device;
std::array<o2::track::TrackParCov*, Fitter::getNProngs()> tracks_device;
int result, *result_device;

gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&ft_device), sizeof(o2::vertexing::DCAFitterN<N>)));
// gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&t1_device), sizeof(o2::track::TrackParCov)));
// gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&t2_device), sizeof(o2::track::TrackParCov)));
for (int iT{0}; iT < N; ++iT) {
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&(tracks_device[iT])), sizeof(o2::track::TrackParCov)));
}
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&ft_device), sizeof(Fitter)));
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&result_device), sizeof(int)));

gpuCheckError(cudaMemcpy(ft_device, &fitter, sizeof(o2::vertexing::DCAFitterN<2>), cudaMemcpyHostToDevice));
gpuCheckError(cudaMemcpy(t1_device, &track1, sizeof(o2::track::TrackParCov), cudaMemcpyHostToDevice));
gpuCheckError(cudaMemcpy(t2_device, &track2, sizeof(o2::track::TrackParCov), cudaMemcpyHostToDevice));
int iArg{0};
([&] {
gpuCheckError(cudaMalloc(reinterpret_cast<void**>(&(tracks_device[iArg])), sizeof(o2::track::TrackParCov)));
gpuCheckError(cudaMemcpy(tracks_device[iArg], &args, sizeof(o2::track::TrackParCov), cudaMemcpyHostToDevice));
++iArg;
}(),
...);

gpuCheckError(cudaMemcpy(ft_device, &fitter, sizeof(Fitter), cudaMemcpyHostToDevice));

kernel::processKernel<<<nBlocks, nThreads>>>(ft_device, t1_device, t2_device, result_device);
std::apply([&](auto&&... args) { kernel::processKernel<<<nBlocks, nThreads>>>(ft_device, result_device, args...); }, tracks_device);

gpuCheckError(cudaPeekAtLastError());
gpuCheckError(cudaDeviceSynchronize());

gpuCheckError(cudaMemcpy(&result, result_device, sizeof(int), cudaMemcpyDeviceToHost));
gpuCheckError(cudaMemcpy(&fitter, ft_device, sizeof(o2::vertexing::DCAFitterN<2>), cudaMemcpyDeviceToHost));
gpuCheckError(cudaMemcpy(&track1, t1_device, sizeof(o2::track::TrackParCov), cudaMemcpyDeviceToHost));
gpuCheckError(cudaMemcpy(&track2, t2_device, sizeof(o2::track::TrackParCov), cudaMemcpyDeviceToHost));
gpuCheckError(cudaFree(ft_device));
gpuCheckError(cudaFree(t1_device));
gpuCheckError(cudaFree(t2_device));
gpuCheckError(cudaMemcpy(&fitter, ft_device, sizeof(Fitter), cudaMemcpyDeviceToHost));
iArg = 0;
([&] {
gpuCheckError(cudaMemcpy(&args, tracks_device[iArg], sizeof(o2::track::TrackParCov), cudaMemcpyDeviceToHost));
gpuCheckError(cudaFree(tracks_device[iArg]));
++iArg;
}(),
...);

gpuCheckError(cudaFree(result_device));

return result;
}

template int process(const int, const int, o2::vertexing::DCAFitterN<2>&, o2::track::TrackParCov&, o2::track::TrackParCov&);
template int process(const int, const int, o2::vertexing::DCAFitterN<3>&, o2::track::TrackParCov&, o2::track::TrackParCov&, o2::track::TrackParCov&);
template void print(const int, const int, o2::vertexing::DCAFitterN<2>&);
template void print(const int, const int, o2::vertexing::DCAFitterN<3>&);
} // namespace o2::vertexing::device
Loading

0 comments on commit 8a5900f

Please sign in to comment.