Skip to content

Commit

Permalink
Stop exporting infer_k kernel as that causes ODR violations (#6021)
Browse files Browse the repository at this point in the history
Removes the usage of `fill_k` in `cpp/src/fil/fil.cu` as that breaks the ODR requirements of CUDA whole compilation. To allow setting the shared memory of the kernel we move the logic over to `cpp/src/fil/infer.cu` and provide a c++ interface.

Authors:
  - Robert Maynard (https://github.com/robertmaynard)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #6021
  • Loading branch information
robertmaynard authored Aug 15, 2024
1 parent 0fbd919 commit f677791
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 32 deletions.
11 changes: 3 additions & 8 deletions cpp/src/fil/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -309,17 +309,12 @@ struct compute_smem_footprint : dispatch_functor<int> {
int run(predict_params);
};

template <int NITEMS,
leaf_algo_t leaf_algo,
bool cols_in_shmem,
bool CATS_SUPPORTED,
class storage_type>
__attribute__((visibility("hidden"))) __global__ void infer_k(storage_type forest,
predict_params params);

// infer() calls the inference kernel with the parameters on the stream
template <typename storage_type>
void infer(storage_type forest, predict_params params, cudaStream_t stream);

// opts every infer_k variant reachable from `params` into up to `max_shm` bytes
// of arch-dependent dynamic shared memory (via cudaFuncSetAttribute); defined in
// infer.cu so the kernel symbol itself is never exported
template <typename storage_type>
void infer_shared_mem_size(predict_params params, int max_shm);

} // namespace fil
} // namespace ML
29 changes: 5 additions & 24 deletions cpp/src/fil/fil.cu
Original file line number Diff line number Diff line change
Expand Up @@ -349,26 +349,6 @@ struct forest {
cat_sets_device_owner cat_sets_;
};

// Dispatch functor that opts each selected infer_k instantiation into
// arch-dependent dynamic shared memory: requesting more than the default
// 48 KB per block requires an explicit cudaFuncSetAttribute opt-in.
// Invoked for every template-parameter combination by
// dispatch_on_fil_template_params.
template <typename storage_type>
struct opt_into_arch_dependent_shmem : dispatch_functor<void> {
// upper bound (bytes) of dynamic shared memory the kernels may request
const int max_shm;
opt_into_arch_dependent_shmem(int max_shm_) : max_shm(max_shm_) {}

// Raises cudaFuncAttributeMaxDynamicSharedMemorySize to max_shm on the
// infer_k instantiation selected by KernelParams.
template <typename KernelParams = KernelTemplateParams<>>
void run(predict_params p)
{
auto kernel = infer_k<KernelParams::N_ITEMS,
KernelParams::LEAF_ALGO,
KernelParams::COLS_IN_SHMEM,
KernelParams::CATS_SUPPORTED,
storage_type>;
// p.shm_sz might be > max_shm or < MAX_SHM_STD, but we should not check for either, because
// we don't run on both proba_ssp_ and class_ssp_ (only class_ssp_). This should be quick.
RAFT_CUDA_TRY(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shm));
}
};

template <typename real_t>
struct dense_forest<dense_node<real_t>> : forest<real_t> {
using node_t = dense_node<real_t>;
Expand Down Expand Up @@ -427,8 +407,9 @@ struct dense_forest<dense_node<real_t>> : forest<real_t> {
h.get_stream()));

// predict_proba is a runtime parameter, and opt-in is unconditional
dispatch_on_fil_template_params(opt_into_arch_dependent_shmem<storage<node_t>>(this->max_shm_),
static_cast<predict_params>(this->class_ssp_));
fil::infer_shared_mem_size<storage<node_t>>(static_cast<predict_params>(this->class_ssp_),
this->max_shm_);

// copy must be finished before freeing the host data
h.sync_stream();
h_nodes_.clear();
Expand Down Expand Up @@ -491,8 +472,8 @@ struct sparse_forest : forest<typename node_t::real_type> {
nodes_.data(), nodes, sizeof(node_t) * num_nodes_, cudaMemcpyHostToDevice, h.get_stream()));

// predict_proba is a runtime parameter, and opt-in is unconditional
dispatch_on_fil_template_params(opt_into_arch_dependent_shmem<storage<node_t>>(this->max_shm_),
static_cast<predict_params>(this->class_ssp_));
fil::infer_shared_mem_size<storage<node_t>>(static_cast<predict_params>(this->class_ssp_),
this->max_shm_);
}

virtual void infer(predict_params params, cudaStream_t stream) override
Expand Down
32 changes: 32 additions & 0 deletions cpp/src/fil/infer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -908,12 +908,38 @@ struct infer_k_storage_template : dispatch_functor<void> {
}
};

// Dispatch functor that opts each selected infer_k instantiation into
// arch-dependent dynamic shared memory: requesting more than the default
// 48 KB per block requires an explicit cudaFuncSetAttribute opt-in.
// Lives in this translation unit (rather than fil.cu) so that infer_k
// never has to be exported, which would break ODR under CUDA whole
// compilation.
template <typename storage_type>
struct opt_into_arch_dependent_shmem : dispatch_functor<void> {
// upper bound (bytes) of dynamic shared memory the kernels may request
const int max_shm;
opt_into_arch_dependent_shmem(int max_shm_) : max_shm(max_shm_) {}

// Raises cudaFuncAttributeMaxDynamicSharedMemorySize to max_shm on the
// infer_k instantiation selected by KernelParams.
template <typename KernelParams = KernelTemplateParams<>>
void run(predict_params p)
{
auto kernel = infer_k<KernelParams::N_ITEMS,
KernelParams::LEAF_ALGO,
KernelParams::COLS_IN_SHMEM,
KernelParams::CATS_SUPPORTED,
storage_type>;
// p.shm_sz might be > max_shm or < MAX_SHM_STD, but we should not check for either, because
// we don't run on both proba_ssp_ and class_ssp_ (only class_ssp_). This should be quick.
RAFT_CUDA_TRY(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shm));
}
};

// Runs inference for `forest` on `stream`: selects the infer_k template
// instantiation matching the runtime `params` and launches it.
template <typename storage_type>
void infer(storage_type forest, predict_params params, cudaStream_t stream)
{
auto launcher = infer_k_storage_template<storage_type>(forest, stream);
dispatch_on_fil_template_params(launcher, params);
}

// C++ entry point (no kernel symbol exported): opts every infer_k variant
// reachable from the runtime `params` into `max_shm` bytes of arch-dependent
// dynamic shared memory.
template <typename storage_type>
void infer_shared_mem_size(predict_params params, int max_shm)
{
auto opt_in = opt_into_arch_dependent_shmem<storage_type>(max_shm);
dispatch_on_fil_template_params(opt_in, params);
}

template void infer<dense_storage_f32>(dense_storage_f32 forest,
predict_params params,
cudaStream_t stream);
Expand All @@ -930,5 +956,11 @@ template void infer<sparse_storage8>(sparse_storage8 forest,
predict_params params,
cudaStream_t stream);

// Explicit instantiations for every storage type fil.cu dispatches on.
// Keeping these in this translation unit — and exposing only this C++
// interface — avoids exporting the kernel itself, which would violate ODR
// under CUDA whole compilation.
template void infer_shared_mem_size<dense_storage_f32>(predict_params params, int max_shm);
template void infer_shared_mem_size<dense_storage_f64>(predict_params params, int max_shm);
template void infer_shared_mem_size<sparse_storage16_f32>(predict_params params, int max_shm);
template void infer_shared_mem_size<sparse_storage16_f64>(predict_params params, int max_shm);
template void infer_shared_mem_size<sparse_storage8>(predict_params params, int max_shm);

} // namespace fil
} // namespace ML

0 comments on commit f677791

Please sign in to comment.