
Commit

avoid creating functions that compile for all types and be more explicit when naming
jnke2016 committed Nov 7, 2024
1 parent 114bf56 commit 63a59ca
Showing 6 changed files with 36 additions and 71 deletions.
24 changes: 0 additions & 24 deletions cpp/include/cugraph/detail/shuffle_wrappers.hpp
@@ -142,30 +142,6 @@ shuffle_ext_vertex_value_pairs_to_local_gpu_by_vertex_partitioning(
   rmm::device_uvector<vertex_t>&& vertices,
   rmm::device_uvector<value_t>&& values);
 
-/**
- * @brief Shuffle external (i.e. before renumbering) vertex & values pairs to their local GPU based
- * on vertex partitioning.
- *
- * @tparam vertex_t Type of vertex identifiers. Needs to be an integral type.
- * @tparam value_t Type of values.
- *
- * @param[in] handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator,
- * and handles to various CUDA libraries) to run graph algorithms.
- * @param[in] vertices Vertices to shuffle.
- * @param[in] values_0 First values to shuffle.
- * @param[in] values_1 Second values to shuffle.
- *
- * @return Tuple of vectors storing shuffled vertex & value pairs.
- */
-template <typename vertex_t, typename value0_t, typename value1_t>
-std::
-  tuple<rmm::device_uvector<vertex_t>, rmm::device_uvector<value0_t>, rmm::device_uvector<value1_t>>
-  shuffle_ext_vertex_values_pairs_to_local_gpu_by_vertex_partitioning(
-    raft::handle_t const& handle,
-    rmm::device_uvector<vertex_t>&& vertices,
-    rmm::device_uvector<value0_t>&& values_0,
-    rmm::device_uvector<value1_t>&& values_1);
-
 /**
  * @brief Permute a range.
  *
21 changes: 10 additions & 11 deletions cpp/include/cugraph/detail/utility_wrappers.hpp
@@ -68,20 +68,20 @@ void scalar_fill(raft::handle_t const& handle, value_t* d_value, size_t size, va
 /**
  * @brief Sort a device span
  *
- * @tparam value_t type of the value to operate on
+ * @tparam value_t type of the value to operate on. Must be either int32_t or int64_t.
  *
  * @param [in] handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator,
  * and handles to various CUDA libraries) to run graph algorithms.
  * @param[out] values device span to sort
  *
  */
 template <typename value_t>
-void sort(raft::handle_t const& handle, raft::device_span<value_t> values);
+void sort_ints(raft::handle_t const& handle, raft::device_span<value_t> values);
 
 /**
  * @brief Keep unique element from a device span
  *
- * @tparam value_t type of the value to operate on
+ * @tparam value_t type of the value to operate on. Must be either int32_t or int64_t.
  *
  * @param [in] handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator,
  * and handles to various CUDA libraries) to run graph algorithms.
@@ -90,23 +90,22 @@ void sort(raft::handle_t const& handle, raft::device_span<value_t> values);
  *
  */
 template <typename value_t>
-size_t unique(raft::handle_t const& handle, raft::device_span<value_t> values);
+size_t unique_ints(raft::handle_t const& handle, raft::device_span<value_t> values);
 
 /**
  * @brief Increment the values of a device span by a constant value
  *
- * @tparam value_t type of the value to operate on
+ * @tparam value_t type of the value to operate on. Must be either int32_t or int64_t.
  *
  * @param [in] handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator,
  * and handles to various CUDA libraries) to run graph algorithms.
  * @param[out] values device span to update
  * @param[in] value value to be added to each element of the buffer
- * @param[in] stream_view stream view
  *
  */
 template <typename value_t>
-void transform_increment(rmm::cuda_stream_view const& stream_view,
-                         raft::device_span<value_t> values,
-                         value_t value);
+void transform_increment_ints(raft::device_span<value_t> values,
+                              value_t value,
+                              rmm::cuda_stream_view const& stream_view);
 
 /**
  * @brief Fill a buffer with a sequence of values
@@ -116,7 +115,7 @@ void transform_increment(rmm::cuda_stream_view const& stream_view,
  *
  * Similar to the function std::iota, wraps the function thrust::sequence
  *
- * @tparam value_t type of the value to operate on
+ * @tparam value_t type of the value to operate on.
  *
  * @param[in] stream_view stream view
  * @param[out] d_value device array to fill
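For reference, the three renamed helpers are thin wrappers over Thrust primitives (thrust::sort, thrust::unique, thrust::transform; see utility_wrappers_impl.cuh below). The following standalone sketch is not part of the commit: it approximates the same sort, unique, increment flow with plain Thrust so it compiles with only the CUDA toolkit. The file name demo.cu, the h_labels data, and the offset 10 are illustrative assumptions.

// demo.cu -- hypothetical standalone sketch, not cugraph code.
// Mirrors what sort_ints / unique_ints / transform_increment_ints do.
// Build: nvcc demo.cu -o demo
#include <thrust/device_vector.h>
#include <thrust/distance.h>
#include <thrust/functional.h>
#include <thrust/sort.h>
#include <thrust/transform.h>
#include <thrust/unique.h>

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
  std::vector<int32_t> h_labels{3, 1, 3, 2, 1};
  thrust::device_vector<int32_t> labels(h_labels.begin(), h_labels.end());

  // sort_ints: sort the span in place.
  thrust::sort(labels.begin(), labels.end());

  // unique_ints: drop adjacent duplicates and return the unique count.
  auto unique_last        = thrust::unique(labels.begin(), labels.end());
  std::size_t num_unique  = thrust::distance(labels.begin(), unique_last);

  // transform_increment_ints: add a constant offset to every element.
  using namespace thrust::placeholders;
  thrust::transform(
    labels.begin(), labels.begin() + num_unique, labels.begin(), _1 + 10);

  std::cout << num_unique << " unique labels\n";  // prints "3 unique labels"
  return 0;
}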
10 changes: 5 additions & 5 deletions cpp/src/c_api/neighbor_sampling.cpp
@@ -885,10 +885,10 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor {
 
       // Compute the global start_vertex_label_offsets
 
-      cugraph::detail::transform_increment(
-        handle_.get_stream(),
+      cugraph::detail::transform_increment_ints(
         raft::device_span<label_t>{(*start_vertex_labels).data(), (*start_vertex_labels).size()},
-        (label_t)global_labels[handle_.get_comms().get_rank()]);
+        (label_t)global_labels[handle_.get_comms().get_rank()],
+        handle_.get_stream());
     }
 
     if constexpr (multi_gpu) {
@@ -902,11 +902,11 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor {
 
       // Get unique labels
       // sort the start_vertex_labels
-      cugraph::detail::sort(
+      cugraph::detail::sort_ints(
         handle_.get_stream(),
         raft::device_span<label_t>{unique_labels.data(), unique_labels.size()});
 
-      auto num_unique_labels = cugraph::detail::unique(
+      auto num_unique_labels = cugraph::detail::unique_ints(
         handle_.get_stream(),
         raft::device_span<label_t>{unique_labels.data(), unique_labels.size()});
 
15 changes: 5 additions & 10 deletions cpp/src/detail/utility_wrappers_32.cu
@@ -63,10 +63,9 @@ template void scalar_fill(raft::handle_t const& handle, size_t* d_value, size_t
 
 template void scalar_fill(raft::handle_t const& handle, float* d_value, size_t size, float value);
 
-template void sort(raft::handle_t const& handle, raft::device_span<int32_t> d_span);
+template void sort_ints(raft::handle_t const& handle, raft::device_span<int32_t> values);
 
-template size_t unique(raft::handle_t const& handle, raft::device_span<int32_t> d_span);
-template size_t unique(raft::handle_t const& handle, raft::device_span<uint32_t> d_span);
+template size_t unique_ints(raft::handle_t const& handle, raft::device_span<int32_t> values);
 
 template void sequence_fill(rmm::cuda_stream_view const& stream_view,
                             int32_t* d_value,
@@ -78,13 +77,9 @@ template void sequence_fill(rmm::cuda_stream_view const& stream_view,
                             size_t size,
                             uint32_t start_value);
 
-template void transform_increment(rmm::cuda_stream_view const& stream_view,
-                                  raft::device_span<int32_t> d_span,
-                                  int32_t value);
-
-template void transform_increment(rmm::cuda_stream_view const& stream_view,
-                                  raft::device_span<uint32_t> d_span,
-                                  uint32_t value);
+template void transform_increment_ints(raft::device_span<int32_t> values,
+                                       int32_t value,
+                                       rmm::cuda_stream_view const& stream_view);
 
 template void stride_fill(rmm::cuda_stream_view const& stream_view,
                           int32_t* d_value,
15 changes: 5 additions & 10 deletions cpp/src/detail/utility_wrappers_64.cu
@@ -61,10 +61,9 @@ template void scalar_fill(raft::handle_t const& handle,
 
 template void scalar_fill(raft::handle_t const& handle, double* d_value, size_t size, double value);
 
-template void sort(raft::handle_t const& handle, raft::device_span<int64_t> d_span);
+template void sort_ints(raft::handle_t const& handle, raft::device_span<int64_t> values);
 
-template size_t unique(raft::handle_t const& handle, raft::device_span<int64_t> d_span);
-template size_t unique(raft::handle_t const& handle, raft::device_span<uint64_t> d_span);
+template size_t unique_ints(raft::handle_t const& handle, raft::device_span<int64_t> values);
 
 template void sequence_fill(rmm::cuda_stream_view const& stream_view,
                             int64_t* d_value,
@@ -76,13 +75,9 @@ template void sequence_fill(rmm::cuda_stream_view const& stream_view,
                             size_t size,
                             uint64_t start_value);
 
-template void transform_increment(rmm::cuda_stream_view const& stream_view,
-                                  raft::device_span<int64_t> d_span,
-                                  int64_t value);
-
-template void transform_increment(rmm::cuda_stream_view const& stream_view,
-                                  raft::device_span<uint64_t> d_span,
-                                  uint64_t value);
+template void transform_increment_ints(raft::device_span<int64_t> values,
+                                       int64_t value,
+                                       rmm::cuda_stream_view const& stream_view);
 
 template void stride_fill(rmm::cuda_stream_view const& stream_view,
                           int64_t* d_value,
22 changes: 11 additions & 11 deletions cpp/src/detail/utility_wrappers_impl.cuh
@@ -65,17 +65,17 @@ void scalar_fill(raft::handle_t const& handle, value_t* d_value, size_t size, va
 }
 
 template <typename value_t>
-void sort(raft::handle_t const& handle, raft::device_span<value_t> d_span)
+void sort_ints(raft::handle_t const& handle, raft::device_span<value_t> values)
 {
-  thrust::sort(handle.get_thrust_policy(), d_span.begin(), d_span.end());
+  thrust::sort(handle.get_thrust_policy(), values.begin(), values.end());
 }
 
 template <typename value_t>
-size_t unique(raft::handle_t const& handle, raft::device_span<value_t> d_span)
+size_t unique_ints(raft::handle_t const& handle, raft::device_span<value_t> values)
 {
   auto unique_element_last =
-    thrust::unique(handle.get_thrust_policy(), d_span.begin(), d_span.end());
-  return thrust::distance(d_span.begin(), unique_element_last);
+    thrust::unique(handle.get_thrust_policy(), values.begin(), values.end());
+  return thrust::distance(values.begin(), unique_element_last);
 }
 
 template <typename value_t>
@@ -88,14 +88,14 @@ void sequence_fill(rmm::cuda_stream_view const& stream_view,
 }
 
 template <typename value_t>
-void transform_increment(rmm::cuda_stream_view const& stream_view,
-                         raft::device_span<value_t> d_span,
-                         value_t incr)
+void transform_increment_ints(raft::device_span<value_t> values,
+                              value_t incr,
+                              rmm::cuda_stream_view const& stream_view)
 {
   thrust::transform(rmm::exec_policy(stream_view),
-                    d_span.begin(),
-                    d_span.end(),
-                    d_span.begin(),
+                    values.begin(),
+                    values.end(),
+                    values.begin(),
                     cuda::proclaim_return_type<value_t>([incr] __device__(value_t value) {
                       return static_cast<value_t>(value + incr);
                     }));
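The _32.cu and _64.cu files above are the only translation units that instantiate these templates, which is what the commit message means by avoiding functions that compile for all types. Below is a minimal single-file sketch of that explicit-instantiation pattern; the simplified raw-pointer signature and the loop body are assumptions for illustration, not code from the repository.

// Sketch of the header/.cu split used by utility_wrappers_32.cu and
// utility_wrappers_64.cu; comments mark the file boundaries.
#include <cstddef>
#include <cstdint>

// -- header (declaration only): callers never see the template body --------
template <typename value_t>
void transform_increment_ints(value_t* values, std::size_t n, value_t incr);

// -- .cu file: definition plus the only instantiations that get compiled ---
template <typename value_t>
void transform_increment_ints(value_t* values, std::size_t n, value_t incr)
{
  // Stand-in for the thrust::transform call in utility_wrappers_impl.cuh.
  for (std::size_t i = 0; i < n; ++i) {
    values[i] = static_cast<value_t>(values[i] + incr);
  }
}

template void transform_increment_ints(int32_t*, std::size_t, int32_t);
template void transform_increment_ints(int64_t*, std::size_t, int64_t);

// A caller using, say, uint32_t compiles against the header but fails at
// link time, since no uint32_t instantiation is ever built.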
