From 1c9354964facd978fd204e83b2590b461184bf8c Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Thu, 25 Apr 2024 00:23:46 -0700 Subject: [PATCH 01/53] add __host__ to host callable functions --- .../cugraph/edge_partition_device_view.cuh | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/cpp/include/cugraph/edge_partition_device_view.cuh b/cpp/include/cugraph/edge_partition_device_view.cuh index fc19a8f68dd..7c4b74e1ce0 100644 --- a/cpp/include/cugraph/edge_partition_device_view.cuh +++ b/cpp/include/cugraph/edge_partition_device_view.cuh @@ -214,7 +214,7 @@ class edge_partition_device_view_t - size_t compute_number_of_edges(MajorIterator major_first, + __host__ size_t compute_number_of_edges(MajorIterator major_first, MajorIterator major_last, rmm::cuda_stream_view stream) const { @@ -250,7 +250,7 @@ class edge_partition_device_view_t()); } - rmm::device_uvector compute_local_degrees(rmm::cuda_stream_view stream) const + __host__ rmm::device_uvector compute_local_degrees(rmm::cuda_stream_view stream) const { rmm::device_uvector local_degrees(this->major_range_size(), stream); if (dcs_nzd_vertices_) { @@ -277,7 +277,7 @@ class edge_partition_device_view_t - rmm::device_uvector compute_local_degrees(MajorIterator major_first, + __host__ rmm::device_uvector compute_local_degrees(MajorIterator major_first, MajorIterator major_last, rmm::cuda_stream_view stream) const { @@ -306,7 +306,7 @@ class edge_partition_device_view_t - size_t compute_number_of_edges_with_mask(MaskIterator mask_first, + __host__ size_t compute_number_of_edges_with_mask(MaskIterator mask_first, MajorIterator major_first, MajorIterator major_last, rmm::cuda_stream_view stream) const @@ -348,7 +348,7 @@ class edge_partition_device_view_t - rmm::device_uvector compute_local_degrees_with_mask(MaskIterator mask_first, + __host__ rmm::device_uvector compute_local_degrees_with_mask(MaskIterator mask_first, rmm::cuda_stream_view stream) const { rmm::device_uvector local_degrees(this->major_range_size(), stream); @@ -384,7 +384,7 @@ class edge_partition_device_view_t - rmm::device_uvector compute_local_degrees_with_mask(MaskIterator mask_first, + __host__ rmm::device_uvector compute_local_degrees_with_mask(MaskIterator mask_first, MajorIterator major_first, MajorIterator major_last, rmm::cuda_stream_view stream) const @@ -553,7 +553,7 @@ class edge_partition_device_view_t - size_t compute_number_of_edges(MajorIterator major_first, + __host__ size_t compute_number_of_edges(MajorIterator major_first, MajorIterator major_last, rmm::cuda_stream_view stream) const { @@ -573,7 +573,7 @@ class edge_partition_device_view_t()); } - rmm::device_uvector compute_local_degrees(rmm::cuda_stream_view stream) const + __host__ rmm::device_uvector compute_local_degrees(rmm::cuda_stream_view stream) const { rmm::device_uvector local_degrees(this->major_range_size(), stream); thrust::transform(rmm::exec_policy(stream), @@ -589,7 +589,7 @@ class edge_partition_device_view_t - rmm::device_uvector compute_local_degrees(MajorIterator major_first, + __host__ rmm::device_uvector compute_local_degrees(MajorIterator major_first, MajorIterator major_last, rmm::cuda_stream_view stream) const { @@ -607,7 +607,7 @@ class edge_partition_device_view_t - size_t compute_number_of_edges_with_mask(MaskIterator mask_first, + __host__ size_t compute_number_of_edges_with_mask(MaskIterator mask_first, MajorIterator major_first, MajorIterator major_last, rmm::cuda_stream_view stream) const @@ -632,7 +632,7 @@ class 
edge_partition_device_view_t - rmm::device_uvector compute_local_degrees_with_mask(MaskIterator mask_first, + __host__ rmm::device_uvector compute_local_degrees_with_mask(MaskIterator mask_first, rmm::cuda_stream_view stream) const { rmm::device_uvector local_degrees(this->major_range_size(), stream); @@ -651,7 +651,7 @@ class edge_partition_device_view_t - rmm::device_uvector compute_local_degrees_with_mask(MaskIterator mask_first, + __host__ rmm::device_uvector compute_local_degrees_with_mask(MaskIterator mask_first, MajorIterator major_first, MajorIterator major_last, rmm::cuda_stream_view stream) const From 14f194a0b41b10b87ebf5a896af305bf2621d90c Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Thu, 25 Apr 2024 00:24:14 -0700 Subject: [PATCH 02/53] refactor sampling primitive --- .../sample_and_compute_local_nbr_indices.cuh | 1725 +++++++++++++++++ ...r_v_random_select_transform_outgoing_e.cuh | 1177 +---------- 2 files changed, 1829 insertions(+), 1073 deletions(-) create mode 100644 cpp/src/prims/detail/sample_and_compute_local_nbr_indices.cuh diff --git a/cpp/src/prims/detail/sample_and_compute_local_nbr_indices.cuh b/cpp/src/prims/detail/sample_and_compute_local_nbr_indices.cuh new file mode 100644 index 00000000000..1e15fddb1c7 --- /dev/null +++ b/cpp/src/prims/detail/sample_and_compute_local_nbr_indices.cuh @@ -0,0 +1,1725 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "prims/property_op_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#ifndef NO_CUGRAPH_OPS +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace cugraph { + +namespace detail { + +int32_t constexpr per_v_random_select_transform_outgoing_e_block_size = 256; + +size_t constexpr compute_valid_local_nbr_count_inclusive_sum_local_degree_threshold = + packed_bools_per_word() * + size_t{4} /* tuning parameter */; // minimum local degree to compute inclusive sums of valid + // local neighbors per word to accelerate finding n'th local + // neighbor vertex +size_t constexpr compute_valid_local_nbr_count_inclusive_sum_mid_local_degree_threshold = + packed_bools_per_word() * static_cast(raft::warp_size()) * + size_t{4} /* tuning parameter */; // minimum local degree to use a CUDA warp +size_t constexpr compute_valid_local_nbr_count_inclusive_sum_high_local_degree_threshold = + packed_bools_per_word() * + static_cast(per_v_random_select_transform_outgoing_e_block_size) * + size_t{4} /* tuning parameter */; // minimum local degree to use a CUDA block + +template +struct constant_e_bias_op_t { + __device__ float operator()(key_t, + typename GraphViewType::vertex_type, + typename EdgeSrcValueInputWrapper::value_type, + typename EdgeDstValueInputWrapper::value_type, + typename EdgeValueInputWrapper::value_type) const + { + return 1.0; + } +}; + +template +struct compute_local_degree_displacements_and_global_degree_t { + raft::device_span gathered_local_degrees{}; + raft::device_span + partitioned_local_degree_displacements{}; // one partition per gpu in the same minor_comm + raft::device_span global_degrees{}; + int minor_comm_size{}; + + __device__ void operator()(size_t i) const + { + constexpr int buffer_size = 8; // tuning parameter + edge_t displacements[buffer_size]; + edge_t sum{0}; + for (int round = 0; round < (minor_comm_size + buffer_size - 1) / buffer_size; ++round) { + auto loop_count = std::min(buffer_size, minor_comm_size - round * buffer_size); + for (int j = 0; j < loop_count; ++j) { + displacements[j] = sum; + sum += gathered_local_degrees[i + (round * buffer_size + j) * global_degrees.size()]; + } + thrust::copy( + thrust::seq, + displacements, + displacements + loop_count, + partitioned_local_degree_displacements.begin() + i * minor_comm_size + round * buffer_size); + } + global_degrees[i] = sum; + } +}; + +// convert a (neighbor index, key index) pair to a (minor_comm_rank, intra-partition offset, +// neighbor index, key index) quadruplet, minor_comm_rank is set to -1 if an neighbor index is +// invalid +template +struct convert_pair_to_quadruplet_t { + raft::device_span + partitioned_local_degree_displacements{}; // one partition per gpu in the same minor_comm + raft::device_span tx_counts{}; + size_t stride{}; + int minor_comm_size{}; + edge_t invalid_idx{}; + + __device__ thrust::tuple operator()( + thrust::tuple index_pair) const + { + auto nbr_idx = thrust::get<0>(index_pair); + auto key_idx = thrust::get<1>(index_pair); + auto local_nbr_idx = nbr_idx; + int minor_comm_rank{-1}; + size_t intra_partition_offset{}; + if (nbr_idx != invalid_idx) { + auto displacement_first = + partitioned_local_degree_displacements.begin() + key_idx * minor_comm_size; + minor_comm_rank = + static_cast(thrust::distance( + displacement_first, + thrust::upper_bound( + 
thrust::seq, displacement_first, displacement_first + minor_comm_size, nbr_idx))) - + 1; + local_nbr_idx -= *(displacement_first + minor_comm_rank); + cuda::atomic_ref counter(tx_counts[minor_comm_rank]); + intra_partition_offset = counter.fetch_add(size_t{1}, cuda::std::memory_order_relaxed); + } + return thrust::make_tuple(minor_comm_rank, intra_partition_offset, local_nbr_idx, key_idx); + } +}; + +struct shuffle_index_compute_offset_t { + raft::device_span minor_comm_ranks{}; + raft::device_span intra_partition_displacements{}; + raft::device_span tx_displacements{}; + + __device__ size_t operator()(size_t i) const + { + auto minor_comm_rank = minor_comm_ranks[i]; + assert(minor_comm_rank != -1); + return tx_displacements[minor_comm_rank] + intra_partition_displacements[i]; + } +}; + +template +struct find_nth_valid_nbr_idx_t { + using key_t = typename thrust::iterator_traits::value_type; + using vertex_t = typename GraphViewType::vertex_type; + using edge_t = typename GraphViewType::edge_type; + + edge_partition_device_view_t edge_partition{}; + EdgePartitionEdgeMaskWrapper edge_partition_e_mask; + KeyIterator key_first{}; + thrust::tuple, raft::device_span> + key_valid_local_nbr_count_inclusive_sums{}; + + __device__ edge_t operator()(thrust::tuple pair) const + { + edge_t local_nbr_idx = thrust::get<0>(pair); + size_t key_idx = thrust::get<1>(pair); + auto key = *(key_first + key_idx); + vertex_t major{}; + if constexpr (std::is_same_v) { + major = key; + } else { + major = thrust::get<0>(key); + } + auto major_offset = edge_partition.major_offset_from_major_nocheck(major); + vertex_t const* indices{nullptr}; + edge_t edge_offset{0}; + [[maybe_unused]] edge_t local_degree{0}; + if constexpr (GraphViewType::is_multi_gpu) { + auto major_hypersparse_first = edge_partition.major_hypersparse_first(); + if (major_hypersparse_first && (major >= *major_hypersparse_first)) { + auto major_hypersparse_idx = edge_partition.major_hypersparse_idx_from_major_nocheck(major); + if (major_hypersparse_idx) { + thrust::tie(indices, edge_offset, local_degree) = edge_partition.local_edges( + edge_partition.major_offset_from_major_nocheck(*major_hypersparse_first) + + *major_hypersparse_idx); + } + } else { + thrust::tie(indices, edge_offset, local_degree) = edge_partition.local_edges(major_offset); + } + } else { + thrust::tie(indices, edge_offset, local_degree) = edge_partition.local_edges(major_offset); + } + + if (local_degree < compute_valid_local_nbr_count_inclusive_sum_local_degree_threshold) { + local_nbr_idx = find_nth_set_bits( + (*edge_partition_e_mask).value_first(), edge_offset, local_degree, local_nbr_idx + 1); + } else { + auto inclusive_sum_first = thrust::get<1>(key_valid_local_nbr_count_inclusive_sums).begin(); + auto start_offset = thrust::get<0>(key_valid_local_nbr_count_inclusive_sums)[key_idx]; + auto end_offset = thrust::get<0>(key_valid_local_nbr_count_inclusive_sums)[key_idx + 1]; + auto word_idx = + static_cast(thrust::distance(inclusive_sum_first + start_offset, + thrust::upper_bound(thrust::seq, + inclusive_sum_first + start_offset, + inclusive_sum_first + end_offset, + local_nbr_idx))); + local_nbr_idx = + word_idx * packed_bools_per_word() + + find_nth_set_bits( + (*edge_partition_e_mask).value_first(), + edge_offset + word_idx * packed_bools_per_word(), + local_degree - word_idx * packed_bools_per_word(), + (local_nbr_idx + 1) - + ((word_idx > 0) ? 
*(inclusive_sum_first + start_offset + word_idx - 1) : edge_t{0})); + } + return local_nbr_idx; + } +}; + +template +__global__ static void compute_valid_local_nbr_count_inclusive_sums_mid_local_degree( + edge_partition_device_view_t edge_partition, + edge_partition_edge_property_device_view_t edge_partition_e_mask, + raft::device_span edge_partition_frontier_majors, + raft::device_span inclusive_sum_offsets, + raft::device_span frontier_indices, + raft::device_span inclusive_sums) +{ + static_assert(per_v_random_select_transform_outgoing_e_block_size % raft::warp_size() == 0); + + auto const tid = threadIdx.x + blockIdx.x * blockDim.x; + auto const lane_id = tid % raft::warp_size(); + + auto idx = static_cast(tid / raft::warp_size()); + + using WarpScan = cub::WarpScan; + __shared__ typename WarpScan::TempStorage temp_storage; + + while (idx < frontier_indices.size()) { + auto frontier_idx = frontier_indices[idx]; + auto major = edge_partition_frontier_majors[frontier_idx]; + vertex_t major_idx{}; + if constexpr (multi_gpu) { + major_idx = *(edge_partition.major_idx_from_major_nocheck(major)); + } else { + major_idx = edge_partition.major_offset_from_major_nocheck(major); + } + auto edge_offset = edge_partition.local_offset(major_idx); + auto local_degree = edge_partition.local_degree(major_idx); + + auto start_offset = inclusive_sum_offsets[frontier_idx]; + auto end_offset = inclusive_sum_offsets[frontier_idx + 1]; + auto num_inclusive_sums = end_offset - start_offset; + auto rounded_up_num_inclusive_sums = + ((num_inclusive_sums + raft::warp_size() - 1) / raft::warp_size()) * raft::warp_size(); + edge_t sum{0}; + for (size_t j = lane_id; j <= rounded_up_num_inclusive_sums; j += raft::warp_size()) { + auto inc = + (j < num_inclusive_sums) + ? static_cast(count_set_bits( + edge_partition_e_mask.value_first(), + edge_offset + packed_bools_per_word() * j, + cuda::std::min(packed_bools_per_word(), local_degree - packed_bools_per_word() * j))) + : edge_t{0}; + WarpScan(temp_storage).InclusiveSum(inc, inc); + inclusive_sums[start_offset + j] = sum + inc; + sum += __shfl_sync(raft::warp_full_mask(), inc, raft::warp_size() - 1); + } + + idx += gridDim.x * (blockDim.x / raft::warp_size()); + } +} + +template +__global__ static void compute_valid_local_nbr_count_inclusive_sums_high_local_degree( + edge_partition_device_view_t edge_partition, + edge_partition_edge_property_device_view_t edge_partition_e_mask, + raft::device_span edge_partition_frontier_majors, + raft::device_span inclusive_sum_offsets, + raft::device_span frontier_indices, + raft::device_span inclusive_sums) +{ + static_assert(per_v_random_select_transform_outgoing_e_block_size % raft::warp_size() == 0); + + auto idx = static_cast(blockIdx.x); + + using BlockScan = cub::BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + __shared__ edge_t sum; + + while (idx < frontier_indices.size()) { + auto frontier_idx = frontier_indices[idx]; + auto major = edge_partition_frontier_majors[frontier_idx]; + vertex_t major_idx{}; + if constexpr (multi_gpu) { + major_idx = *(edge_partition.major_idx_from_major_nocheck(major)); + } else { + major_idx = edge_partition.major_offset_from_major_nocheck(major); + } + auto edge_offset = edge_partition.local_offset(major_idx); + auto local_degree = edge_partition.local_degree(major_idx); + + auto start_offset = inclusive_sum_offsets[frontier_idx]; + auto end_offset = inclusive_sum_offsets[frontier_idx + 1]; + auto num_inclusive_sums = end_offset - start_offset; + auto 
rounded_up_num_inclusive_sums =
+      ((num_inclusive_sums + per_v_random_select_transform_outgoing_e_block_size - 1) /
+       per_v_random_select_transform_outgoing_e_block_size) *
+      per_v_random_select_transform_outgoing_e_block_size;
+    if (threadIdx.x == per_v_random_select_transform_outgoing_e_block_size - 1) { sum = 0; }
+    for (size_t j = threadIdx.x; j <= rounded_up_num_inclusive_sums; j += blockDim.x) {
+      auto inc =
+        (j < num_inclusive_sums)
+          ? static_cast<edge_t>(count_set_bits(
+              edge_partition_e_mask.value_first(),
+              edge_offset + packed_bools_per_word() * j,
+              cuda::std::min(packed_bools_per_word(), local_degree - packed_bools_per_word() * j)))
+          : edge_t{0};
+      BlockScan(temp_storage).InclusiveSum(inc, inc);
+      inclusive_sums[start_offset + j] = sum + inc;
+      __syncthreads();
+      if (threadIdx.x == per_v_random_select_transform_outgoing_e_block_size - 1) { sum += inc; }
+    }
+
+    idx += gridDim.x;
+  }
+}
+
+// divide the frontier into three partitions: the low-degree partition has vertices with degree in
+// [min_low_partition_degree_threshold, min_mid_partition_degree_threshold), the mid-degree
+// partition has vertices with degree in [min_mid_partition_degree_threshold,
+// min_high_partition_degree_threshold), and the high-degree partition has vertices with degree in
+// [min_high_partition_degree_threshold, infinity).
+template <typename edge_t>
+std::tuple<rmm::device_uvector<size_t>, std::vector<size_t> /* size = 3 (# partitions) + 1 */>
+partition_frontier(raft::handle_t const& handle,
+                   raft::device_span<edge_t const> frontier_degrees,
+                   edge_t min_low_partition_degree_threshold,
+                   edge_t min_mid_partition_degree_threshold,
+                   edge_t min_high_partition_degree_threshold)
+{
+  size_t constexpr num_partitions = 3;  // low, mid, high
+  std::vector<size_t> offsets(num_partitions + 1);
+  offsets[0] = size_t{0};
+
+  rmm::device_uvector<size_t> indices(frontier_degrees.size(), handle.get_stream());
+  indices.resize(
+    thrust::distance(indices.begin(),
+                     thrust::copy_if(handle.get_thrust_policy(),
+                                     thrust::make_counting_iterator(size_t{0}),
+                                     thrust::make_counting_iterator(frontier_degrees.size()),
+                                     frontier_degrees.begin(),
+                                     indices.begin(),
+                                     [threshold = min_low_partition_degree_threshold] __device__(
+                                       edge_t degree) { return degree >= threshold; })),
+    handle.get_stream());
+
+  auto mid_first =
+    thrust::partition(handle.get_thrust_policy(),
+                      indices.begin(),
+                      indices.end(),
+                      [frontier_degrees, threshold = min_mid_partition_degree_threshold] __device__(
+                        auto idx) { return frontier_degrees[idx] < threshold; });
+  offsets[1] = static_cast<size_t>(thrust::distance(indices.begin(), mid_first));
+  auto high_first = thrust::partition(
+    handle.get_thrust_policy(),
+    mid_first,
+    indices.end(),
+    [frontier_degrees, threshold = min_high_partition_degree_threshold] __device__(auto idx) {
+      return frontier_degrees[idx] < threshold;
+    });
+  offsets[2] = static_cast<size_t>(thrust::distance(indices.begin(), high_first));
+  offsets[3] = indices.size();
+
+  return std::make_tuple(std::move(indices), std::move(offsets));
+}
+
+template <typename vertex_t, typename edge_t, bool multi_gpu>
+std::tuple<rmm::device_uvector<size_t>, rmm::device_uvector<edge_t>>
+compute_valid_local_nbr_count_inclusive_sums(
+  raft::handle_t const& handle,
+  edge_partition_device_view_t<vertex_t, edge_t, multi_gpu> const& edge_partition,
+  edge_partition_edge_property_device_view_t<edge_t, uint32_t const*, bool> const&
+    edge_partition_e_mask,
+  raft::device_span<vertex_t const> edge_partition_frontier_majors)
+{
+  auto edge_partition_local_degrees =
+    edge_partition.compute_local_degrees(edge_partition_frontier_majors.begin(),
+                                         edge_partition_frontier_majors.end(),
+                                         handle.get_stream());
+  auto inclusive_sum_offsets =
+    rmm::device_uvector<size_t>(edge_partition_frontier_majors.size() +
1, handle.get_stream()); + inclusive_sum_offsets.set_element_to_zero_async(0, handle.get_stream()); + auto size_first = thrust::make_transform_iterator( + edge_partition_local_degrees.begin(), + cuda::proclaim_return_type([] __device__(edge_t local_degree) { + return static_cast((local_degree + packed_bools_per_word() - 1) / + packed_bools_per_word()); + })); + thrust::inclusive_scan(handle.get_thrust_policy(), + size_first, + size_first + edge_partition_local_degrees.size(), + inclusive_sum_offsets.begin() + 1); + + auto [edge_partition_frontier_indices, frontier_partition_offsets] = partition_frontier( + handle, + raft::device_span(edge_partition_local_degrees.data(), + edge_partition_local_degrees.size()), + static_cast(compute_valid_local_nbr_count_inclusive_sum_local_degree_threshold), + static_cast(compute_valid_local_nbr_count_inclusive_sum_mid_local_degree_threshold), + static_cast(compute_valid_local_nbr_count_inclusive_sum_high_local_degree_threshold)); + + rmm::device_uvector inclusive_sums( + inclusive_sum_offsets.back_element(handle.get_stream()), handle.get_stream()); + + thrust::for_each( + handle.get_thrust_policy(), + edge_partition_frontier_indices.begin(), + edge_partition_frontier_indices.begin() + frontier_partition_offsets[1], + [edge_partition, + edge_partition_e_mask, + edge_partition_frontier_majors, + inclusive_sum_offsets = + raft::device_span(inclusive_sum_offsets.data(), inclusive_sum_offsets.size()), + inclusive_sums = raft::device_span(inclusive_sums.data(), + inclusive_sums.size())] __device__(size_t i) { + auto major = edge_partition_frontier_majors[i]; + vertex_t major_idx{}; + if constexpr (multi_gpu) { + major_idx = *(edge_partition.major_idx_from_major_nocheck(major)); + } else { + major_idx = edge_partition.major_offset_from_major_nocheck(major); + } + auto edge_offset = edge_partition.local_offset(major_idx); + auto local_degree = edge_partition.local_degree(major_idx); + edge_t sum{0}; + auto start_offset = inclusive_sum_offsets[i]; + auto end_offset = inclusive_sum_offsets[i + 1]; + for (size_t j = 0; j < end_offset - start_offset; ++j) { + sum += count_set_bits( + edge_partition_e_mask.value_first(), + edge_offset + packed_bools_per_word() * j, + cuda::std::min(packed_bools_per_word(), local_degree - packed_bools_per_word() * j)); + inclusive_sums[start_offset + j] = sum; + } + }); + + auto mid_partition_size = frontier_partition_offsets[2] - frontier_partition_offsets[1]; + if (mid_partition_size > 0) { + raft::grid_1d_warp_t update_grid(mid_partition_size, + per_v_random_select_transform_outgoing_e_block_size, + handle.get_device_properties().maxGridSize[0]); + compute_valid_local_nbr_count_inclusive_sums_mid_local_degree<<>>( + edge_partition, + edge_partition_e_mask, + edge_partition_frontier_majors, + raft::device_span(inclusive_sum_offsets.data(), inclusive_sum_offsets.size()), + raft::device_span( + edge_partition_frontier_indices.data() + frontier_partition_offsets[1], + frontier_partition_offsets[2] - frontier_partition_offsets[1]), + raft::device_span(inclusive_sums.data(), inclusive_sums.size())); + } + + auto high_partition_size = frontier_partition_offsets[3] - frontier_partition_offsets[2]; + if (high_partition_size > 0) { + raft::grid_1d_block_t update_grid(high_partition_size, + per_v_random_select_transform_outgoing_e_block_size, + handle.get_device_properties().maxGridSize[0]); + compute_valid_local_nbr_count_inclusive_sums_high_local_degree<<>>( + edge_partition, + edge_partition_e_mask, + edge_partition_frontier_majors, + 
raft::device_span(inclusive_sum_offsets.data(), inclusive_sum_offsets.size()), + raft::device_span( + edge_partition_frontier_indices.data() + frontier_partition_offsets[2], + frontier_partition_offsets[3] - frontier_partition_offsets[2]), + raft::device_span(inclusive_sums.data(), inclusive_sums.size())); + } + + return std::make_tuple(std::move(inclusive_sum_offsets), std::move(inclusive_sums)); +} + +template +rmm::device_uvector get_sampling_index_without_replacement( + raft::handle_t const& handle, + rmm::device_uvector&& frontier_degrees, + raft::random::RngState& rng_state, + size_t K) +{ +#ifndef NO_CUGRAPH_OPS + edge_t mid_partition_degree_range_last = static_cast(K * 10); // tuning parameter + assert(mid_partition_degree_range_last > K); + size_t high_partition_oversampling_K = K * 2; // tuning parameter + assert(high_partition_oversampling_K > K); + + auto [frontier_indices, frontier_partition_offsets] = partition_frontier( + handle, + raft::device_span(frontier_degrees.data(), frontier_degrees.size()), + edge_t{0}, + static_cast(K + 1), + mid_partition_degree_range_last + 1); + + rmm::device_uvector sample_nbr_indices(frontier_degrees.size() * K, handle.get_stream()); + + auto low_partition_size = frontier_partition_offsets[1]; + if (low_partition_size > 0) { + thrust::for_each(handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator(low_partition_size * K), + [K, + frontier_index_first = frontier_indices.begin(), + frontier_degrees = raft::device_span(frontier_degrees.data(), + frontier_degrees.size()), + sample_nbr_indices = raft::device_span(sample_nbr_indices.data(), + sample_nbr_indices.size()), + invalid_idx = cugraph::ops::graph::INVALID_ID] __device__(size_t i) { + auto frontier_idx = *(frontier_index_first + i); + auto degree = frontier_degrees[frontier_idx]; + auto sample_idx = static_cast(i % K); + sample_nbr_indices[frontier_idx * K + sample_idx] = + (sample_idx < degree) ? 
sample_idx : invalid_idx; + }); + } + + auto mid_partition_size = frontier_partition_offsets[2] - frontier_partition_offsets[1]; + if (mid_partition_size > 0) { + // FIXME: tmp_degrees & tmp_sample_nbr_indices can be avoided if we customize + // cugraph::ops::get_sampling_index + rmm::device_uvector tmp_degrees(mid_partition_size, handle.get_stream()); + rmm::device_uvector tmp_sample_nbr_indices(mid_partition_size * K, handle.get_stream()); + thrust::gather(handle.get_thrust_policy(), + frontier_indices.begin() + frontier_partition_offsets[1], + frontier_indices.begin() + frontier_partition_offsets[2], + frontier_degrees.begin(), + tmp_degrees.begin()); + cugraph::ops::graph::get_sampling_index(tmp_sample_nbr_indices.data(), + rng_state, + tmp_degrees.data(), + mid_partition_size, + static_cast(K), + false, + handle.get_stream()); + thrust::for_each(handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator(mid_partition_size * K), + [K, + seed_index_first = frontier_indices.begin() + frontier_partition_offsets[1], + tmp_sample_nbr_indices = tmp_sample_nbr_indices.data(), + sample_nbr_indices = sample_nbr_indices.data()] __device__(size_t i) { + auto seed_idx = *(seed_index_first + i / K); + auto sample_idx = static_cast(i % K); + sample_nbr_indices[seed_idx * K + sample_idx] = tmp_sample_nbr_indices[i]; + }); + } + + auto high_partition_size = frontier_partition_offsets[3] - frontier_partition_offsets[2]; + if (high_partition_size > 0) { + // to limit memory footprint ((1 << 20) is a tuning parameter), std::max for forward progress + // guarantee when high_partition_oversampling_K is exorbitantly large + auto seeds_to_sort_per_iteration = + std::max(static_cast(handle.get_device_properties().multiProcessorCount * (1 << 20)) / + high_partition_oversampling_K, + size_t{1}); + + rmm::device_uvector tmp_sample_nbr_indices( + seeds_to_sort_per_iteration * high_partition_oversampling_K, handle.get_stream()); + assert(high_partition_oversampling_K * 2 <= + static_cast(std::numeric_limits::max())); + rmm::device_uvector tmp_sample_indices( + tmp_sample_nbr_indices.size(), + handle.get_stream()); // sample indices ([0, high_partition_oversampling_K)) within a segment + // (one segment per seed) + + rmm::device_uvector segment_sorted_tmp_sample_nbr_indices(tmp_sample_nbr_indices.size(), + handle.get_stream()); + rmm::device_uvector segment_sorted_tmp_sample_indices(tmp_sample_nbr_indices.size(), + handle.get_stream()); + + rmm::device_uvector d_tmp_storage(0, handle.get_stream()); + size_t tmp_storage_bytes{0}; + + auto num_chunks = + (high_partition_size + seeds_to_sort_per_iteration - 1) / seeds_to_sort_per_iteration; + for (size_t i = 0; i < num_chunks; ++i) { + size_t num_segments = std::min(seeds_to_sort_per_iteration, + high_partition_size - seeds_to_sort_per_iteration * i); + + rmm::device_uvector unique_counts(num_segments, handle.get_stream()); + + std::optional> retry_segment_indices{std::nullopt}; + std::optional> retry_degrees{std::nullopt}; + std::optional> retry_sample_nbr_indices{std::nullopt}; + std::optional> retry_sample_indices{std::nullopt}; + std::optional> retry_segment_sorted_sample_nbr_indices{ + std::nullopt}; + std::optional> retry_segment_sorted_sample_indices{std::nullopt}; + while (true) { + auto segment_frontier_index_first = frontier_indices.begin() + + frontier_partition_offsets[2] + + seeds_to_sort_per_iteration * i; + auto segment_frontier_degree_first = thrust::make_transform_iterator( + segment_frontier_index_first, 
+ indirection_t{frontier_degrees.begin()}); + + if (retry_segment_indices) { + retry_degrees = + rmm::device_uvector((*retry_segment_indices).size(), handle.get_stream()); + thrust::gather(handle.get_thrust_policy(), + (*retry_segment_indices).begin(), + (*retry_segment_indices).end(), + segment_frontier_degree_first, + (*retry_degrees).begin()); + retry_sample_nbr_indices = rmm::device_uvector( + (*retry_segment_indices).size() * high_partition_oversampling_K, handle.get_stream()); + retry_sample_indices = + rmm::device_uvector((*retry_sample_nbr_indices).size(), handle.get_stream()); + retry_segment_sorted_sample_nbr_indices = + rmm::device_uvector((*retry_sample_nbr_indices).size(), handle.get_stream()); + retry_segment_sorted_sample_indices = + rmm::device_uvector((*retry_sample_nbr_indices).size(), handle.get_stream()); + } + + if (retry_segment_indices) { + cugraph::ops::graph::get_sampling_index( + (*retry_sample_nbr_indices).data(), + rng_state, + (*retry_degrees).begin(), + (*retry_degrees).size(), + static_cast(high_partition_oversampling_K), + true, + handle.get_stream()); + } else { + // FIXME: this temporary is unnecessary if we update get_sampling_index to take a thrust + // iterator + rmm::device_uvector tmp_degrees(num_segments, handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + segment_frontier_degree_first, + segment_frontier_degree_first + num_segments, + tmp_degrees.begin()); + cugraph::ops::graph::get_sampling_index( + tmp_sample_nbr_indices.data(), + rng_state, + tmp_degrees.data(), + num_segments, + static_cast(high_partition_oversampling_K), + true, + handle.get_stream()); + } + + if (retry_segment_indices) { + thrust::for_each( + handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator((*retry_segment_indices).size() * + high_partition_oversampling_K), + [high_partition_oversampling_K, + unique_counts = unique_counts.data(), + segment_sorted_tmp_sample_nbr_indices = segment_sorted_tmp_sample_nbr_indices.data(), + retry_segment_indices = (*retry_segment_indices).data(), + retry_sample_nbr_indices = (*retry_sample_nbr_indices).data(), + retry_sample_indices = (*retry_sample_indices).data()] __device__(size_t i) { + auto segment_idx = retry_segment_indices[i / high_partition_oversampling_K]; + auto sample_idx = static_cast(i % high_partition_oversampling_K); + auto unique_count = unique_counts[segment_idx]; + auto output_first = thrust::make_zip_iterator( + thrust::make_tuple(retry_sample_nbr_indices, retry_sample_indices)); + // sample index for the previously selected neighbor indices should be smaller than + // the new candidates to ensure that the previously selected neighbor indices will be + // selected again + if (sample_idx < unique_count) { + *(output_first + i) = + thrust::make_tuple(segment_sorted_tmp_sample_nbr_indices + [segment_idx * high_partition_oversampling_K + sample_idx], + static_cast(sample_idx)); + } else { + *(output_first + i) = + thrust::make_tuple(retry_sample_nbr_indices[i], + high_partition_oversampling_K + (sample_idx - unique_count)); + } + }); + } else { + thrust::tabulate( + handle.get_thrust_policy(), + tmp_sample_indices.begin(), + tmp_sample_indices.begin() + num_segments * high_partition_oversampling_K, + [high_partition_oversampling_K] __device__(size_t i) { + return static_cast(i % high_partition_oversampling_K); + }); + } + + // sort the (sample neighbor index, sample index) pairs (key: sample neighbor index) + + cub::DeviceSegmentedSort::SortPairs( + 
static_cast(nullptr), + tmp_storage_bytes, + retry_segment_indices ? (*retry_sample_nbr_indices).data() + : tmp_sample_nbr_indices.data(), + retry_segment_indices ? (*retry_segment_sorted_sample_nbr_indices).data() + : segment_sorted_tmp_sample_nbr_indices.data(), + retry_segment_indices ? (*retry_sample_indices).data() : tmp_sample_indices.data(), + retry_segment_indices ? (*retry_segment_sorted_sample_indices).data() + : segment_sorted_tmp_sample_indices.data(), + (retry_segment_indices ? (*retry_segment_indices).size() : num_segments) * + high_partition_oversampling_K, + retry_segment_indices ? (*retry_segment_indices).size() : num_segments, + thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{0}), + multiplier_t{high_partition_oversampling_K}), + thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{1}), + multiplier_t{high_partition_oversampling_K}), + handle.get_stream()); + if (tmp_storage_bytes > d_tmp_storage.size()) { + d_tmp_storage = rmm::device_uvector(tmp_storage_bytes, handle.get_stream()); + } + cub::DeviceSegmentedSort::SortPairs( + d_tmp_storage.data(), + tmp_storage_bytes, + retry_segment_indices ? (*retry_sample_nbr_indices).data() + : tmp_sample_nbr_indices.data(), + retry_segment_indices ? (*retry_segment_sorted_sample_nbr_indices).data() + : segment_sorted_tmp_sample_nbr_indices.data(), + retry_segment_indices ? (*retry_sample_indices).data() : tmp_sample_indices.data(), + retry_segment_indices ? (*retry_segment_sorted_sample_indices).data() + : segment_sorted_tmp_sample_indices.data(), + (retry_segment_indices ? (*retry_segment_indices).size() : num_segments) * + high_partition_oversampling_K, + retry_segment_indices ? (*retry_segment_indices).size() : num_segments, + thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{0}), + multiplier_t{high_partition_oversampling_K}), + thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{1}), + multiplier_t{high_partition_oversampling_K}), + handle.get_stream()); + + // count the number of unique neighbor indices + + if (retry_segment_indices) { + thrust::for_each( + handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator((*retry_segment_indices).size()), + [high_partition_oversampling_K, + unique_counts = unique_counts.data(), + retry_segment_indices = (*retry_segment_indices).data(), + retry_segment_sorted_pair_first = thrust::make_zip_iterator( + thrust::make_tuple((*retry_segment_sorted_sample_nbr_indices).begin(), + (*retry_segment_sorted_sample_indices).begin())), + segment_sorted_pair_first = thrust::make_zip_iterator(thrust::make_tuple( + segment_sorted_tmp_sample_nbr_indices.begin(), + segment_sorted_tmp_sample_indices.begin()))] __device__(size_t i) { + auto unique_count = static_cast(thrust::distance( + retry_segment_sorted_pair_first + high_partition_oversampling_K * i, + thrust::unique( + thrust::seq, + retry_segment_sorted_pair_first + high_partition_oversampling_K * i, + retry_segment_sorted_pair_first + high_partition_oversampling_K * (i + 1), + [] __device__(auto lhs, auto rhs) { + return thrust::get<0>(lhs) == thrust::get<0>(rhs); + }))); + auto segment_idx = retry_segment_indices[i]; + unique_counts[segment_idx] = unique_count; + thrust::copy( + thrust::seq, + retry_segment_sorted_pair_first + high_partition_oversampling_K * i, + retry_segment_sorted_pair_first + high_partition_oversampling_K * i + unique_count, + segment_sorted_pair_first + high_partition_oversampling_K * segment_idx); + }); 
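+        // Worked example with illustrative numbers (K = 2, high_partition_oversampling_K = 4):
+        // a retried segment may draw neighbor indices {5, 9, 5, 2}; after the segmented sort
+        // the distinct neighbor indices are {2, 5, 9}, so unique_count = 3 >= K and the
+        // segment passes. Because previously accepted indices were assigned sample indices
+        // smaller than the new candidates', thrust::unique keeps them in preference to
+        // freshly drawn duplicates, so earlier selections are never lost across retries.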
+ } else { + thrust::tabulate( + handle.get_thrust_policy(), + unique_counts.begin(), + unique_counts.end(), + [high_partition_oversampling_K, + segment_sorted_pair_first = thrust::make_zip_iterator(thrust::make_tuple( + segment_sorted_tmp_sample_nbr_indices.begin(), + segment_sorted_tmp_sample_indices.begin()))] __device__(size_t i) { + return static_cast(thrust::distance( + segment_sorted_pair_first + high_partition_oversampling_K * i, + thrust::unique(thrust::seq, + segment_sorted_pair_first + high_partition_oversampling_K * i, + segment_sorted_pair_first + high_partition_oversampling_K * (i + 1), + [] __device__(auto lhs, auto rhs) { + return thrust::get<0>(lhs) == thrust::get<0>(rhs); + }))); + }); + } + + auto num_retry_segments = + thrust::count_if(handle.get_thrust_policy(), + unique_counts.begin(), + unique_counts.end(), + [K] __device__(auto count) { return count < K; }); + if (num_retry_segments > 0) { + retry_segment_indices = + rmm::device_uvector(num_retry_segments, handle.get_stream()); + thrust::copy_if(handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator(num_segments), + (*retry_segment_indices).begin(), + [K, unique_counts = unique_counts.data()] __device__(size_t i) { + return unique_counts[i] < K; + }); + } else { + break; + } + } + + // sort the segment-sorted (sample index, sample neighbor index) pairs (key: sample index) + + cub::DeviceSegmentedSort::SortPairs( + static_cast(nullptr), + tmp_storage_bytes, + segment_sorted_tmp_sample_indices.data(), + tmp_sample_indices.data(), + segment_sorted_tmp_sample_nbr_indices.data(), + tmp_sample_nbr_indices.data(), + num_segments * high_partition_oversampling_K, + num_segments, + thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{0}), + multiplier_t{high_partition_oversampling_K}), + thrust::make_transform_iterator( + thrust::make_counting_iterator(size_t{0}), + cuda::proclaim_return_type( + [high_partition_oversampling_K, unique_counts = unique_counts.data()] __device__( + size_t i) { return i * high_partition_oversampling_K + unique_counts[i]; })), + handle.get_stream()); + if (tmp_storage_bytes > d_tmp_storage.size()) { + d_tmp_storage = rmm::device_uvector(tmp_storage_bytes, handle.get_stream()); + } + cub::DeviceSegmentedSort::SortPairs( + d_tmp_storage.data(), + tmp_storage_bytes, + segment_sorted_tmp_sample_indices.data(), + tmp_sample_indices.data(), + segment_sorted_tmp_sample_nbr_indices.data(), + tmp_sample_nbr_indices.data(), + num_segments * high_partition_oversampling_K, + num_segments, + thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{0}), + multiplier_t{high_partition_oversampling_K}), + thrust::make_transform_iterator( + thrust::make_counting_iterator(size_t{0}), + cuda::proclaim_return_type( + [high_partition_oversampling_K, unique_counts = unique_counts.data()] __device__( + size_t i) { return i * high_partition_oversampling_K + unique_counts[i]; })), + handle.get_stream()); + + // copy the neighbor indices back to sample_nbr_indices + + thrust::for_each( + handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator(num_segments * K), + [K, + high_partition_oversampling_K, + frontier_indices = frontier_indices.begin() + frontier_partition_offsets[2] + + seeds_to_sort_per_iteration * i, + tmp_sample_nbr_indices = tmp_sample_nbr_indices.data(), + sample_nbr_indices = sample_nbr_indices.data()] __device__(size_t i) { + auto seed_idx = *(frontier_indices + i / K); + auto 
sample_idx = static_cast(i % K); + *(sample_nbr_indices + seed_idx * K + sample_idx) = + *(tmp_sample_nbr_indices + (i / K) * high_partition_oversampling_K + sample_idx); + }); + } + } + + frontier_degrees.resize(0, handle.get_stream()); + frontier_degrees.shrink_to_fit(handle.get_stream()); + + return sample_nbr_indices; +#else + CUGRAPH_FAIL("unimplemented."); +#endif +} + +template +std::tuple, + std::optional>, + std::vector> +uniform_sample_and_compute_local_nbr_indices( + raft::handle_t const& handle, + GraphViewType const& graph_view, + VertexFrontierBucketType const& frontier, + AggregateLocalFrontierBuffer const& aggregate_local_frontier, + std::vector const& local_frontier_displacements, + std::vector const& local_frontier_sizes, + raft::random::RngState& rng_state, + size_t K, + bool with_replacement) +{ + using vertex_t = typename GraphViewType::vertex_type; + using edge_t = typename GraphViewType::edge_type; + using key_t = typename VertexFrontierBucketType::key_type; + + auto minor_comm_size = + GraphViewType::is_multi_gpu + ? handle.get_subcomm(cugraph::partition_manager::minor_comm_name()).get_size() + : int{1}; + assert(minor_comm_size == graph_view.number_of_local_edge_partitions()); + + auto edge_mask_view = graph_view.edge_mask_view(); + + // 1. compute degrees + + rmm::device_uvector frontier_degrees(0, handle.get_stream()); + auto frontier_partitioned_local_degree_displacements = + (minor_comm_size > 1) + ? std::make_optional>(size_t{0}, handle.get_stream()) + : std::nullopt; // one partition per gpu in the same minor_comm + + std::optional, rmm::device_uvector>>> + local_frontier_valid_local_nbr_count_inclusive_sums{}; // to avoid searching the entire + // neighbor list K times for high degree + // vertices with edge masking + if (edge_mask_view) { + local_frontier_valid_local_nbr_count_inclusive_sums = + std::vector, rmm::device_uvector>>{}; + (*local_frontier_valid_local_nbr_count_inclusive_sums) + .reserve(graph_view.number_of_local_edge_partitions()); + } + + { + auto aggregate_local_frontier_local_degrees = + (minor_comm_size > 1) + ? std::make_optional>( + local_frontier_displacements.back() + local_frontier_sizes.back(), handle.get_stream()) + : std::nullopt; + + for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { + auto edge_partition = + edge_partition_device_view_t( + graph_view.local_edge_partition_view(i)); + auto edge_partition_e_mask = + edge_mask_view + ? thrust::make_optional< + detail::edge_partition_edge_property_device_view_t>( + *edge_mask_view, i) + : thrust::nullopt; + + vertex_t const* edge_partition_frontier_major_first{nullptr}; + + auto edge_partition_frontier_key_first = + ((minor_comm_size > 1) ? get_dataframe_buffer_begin(*aggregate_local_frontier) + : frontier.begin()) + + local_frontier_displacements[i]; + if constexpr (std::is_same_v) { + edge_partition_frontier_major_first = edge_partition_frontier_key_first; + } else { + edge_partition_frontier_major_first = thrust::get<0>(edge_partition_frontier_key_first); + } + + auto edge_partition_frontier_local_degrees = + edge_partition_e_mask ? 
edge_partition.compute_local_degrees_with_mask( + (*edge_partition_e_mask).value_first(), + edge_partition_frontier_major_first, + edge_partition_frontier_major_first + local_frontier_sizes[i], + handle.get_stream()) + : edge_partition.compute_local_degrees( + edge_partition_frontier_major_first, + edge_partition_frontier_major_first + local_frontier_sizes[i], + handle.get_stream()); + + if (minor_comm_size > 1) { + // FIXME: this copy is unnecessary if edge_partition.compute_local_degrees() takes a pointer + // to the output array + thrust::copy( + handle.get_thrust_policy(), + edge_partition_frontier_local_degrees.begin(), + edge_partition_frontier_local_degrees.end(), + (*aggregate_local_frontier_local_degrees).begin() + local_frontier_displacements[i]); + } else { + frontier_degrees = std::move(edge_partition_frontier_local_degrees); + } + + if (edge_partition_e_mask) { + (*local_frontier_valid_local_nbr_count_inclusive_sums) + .push_back(compute_valid_local_nbr_count_inclusive_sums( + handle, + edge_partition, + *edge_partition_e_mask, + raft::device_span(edge_partition_frontier_major_first, + local_frontier_sizes[i]))); + } + } + + if (minor_comm_size > 1) { + auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); + + rmm::device_uvector frontier_gathered_local_degrees(0, handle.get_stream()); + std::tie(frontier_gathered_local_degrees, std::ignore) = + shuffle_values(minor_comm, + (*aggregate_local_frontier_local_degrees).begin(), + local_frontier_sizes, + handle.get_stream()); + aggregate_local_frontier_local_degrees = std::nullopt; + + frontier_degrees.resize(frontier.size(), handle.get_stream()); + frontier_partitioned_local_degree_displacements = + rmm::device_uvector(frontier_degrees.size() * minor_comm_size, handle.get_stream()); + thrust::for_each( + handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator(frontier_degrees.size()), + compute_local_degree_displacements_and_global_degree_t{ + raft::device_span(frontier_gathered_local_degrees.data(), + frontier_gathered_local_degrees.size()), + raft::device_span((*frontier_partitioned_local_degree_displacements).data(), + (*frontier_partitioned_local_degree_displacements).size()), + raft::device_span(frontier_degrees.data(), frontier_degrees.size()), + minor_comm_size}); + } + } + + // 2. sample neighbor indices + + rmm::device_uvector sample_nbr_indices(0, handle.get_stream()); + + if (with_replacement) { + if (frontier_degrees.size() > 0) { + sample_nbr_indices.resize(frontier.size() * K, handle.get_stream()); + cugraph::ops::graph::get_sampling_index(sample_nbr_indices.data(), + rng_state, + frontier_degrees.data(), + static_cast(frontier_degrees.size()), + static_cast(K), + with_replacement, + handle.get_stream()); + frontier_degrees.resize(0, handle.get_stream()); + frontier_degrees.shrink_to_fit(handle.get_stream()); + } + } else { + sample_nbr_indices = + get_sampling_index_without_replacement(handle, std::move(frontier_degrees), rng_state, K); + } + + // 3. 
shuffle neighbor indices + + auto sample_local_nbr_indices = std::move( + sample_nbr_indices); // neighbor index within an edge partition (note that each vertex's + // neighbors are distributed in minor_comm_size partitions) + std::optional> sample_key_indices{ + std::nullopt}; // relevant only when (minor_comm_size > 1) + std::vector local_frontier_sample_offsets{}; + if (minor_comm_size > 1) { + auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); + + sample_key_indices = + rmm::device_uvector(sample_local_nbr_indices.size(), handle.get_stream()); + auto minor_comm_ranks = + rmm::device_uvector(sample_local_nbr_indices.size(), handle.get_stream()); + auto intra_partition_displacements = + rmm::device_uvector(sample_local_nbr_indices.size(), handle.get_stream()); + rmm::device_uvector d_tx_counts(minor_comm_size, handle.get_stream()); + thrust::fill(handle.get_thrust_policy(), d_tx_counts.begin(), d_tx_counts.end(), size_t{0}); + auto input_pair_first = thrust::make_zip_iterator( + thrust::make_tuple(sample_local_nbr_indices.begin(), + thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{0}), + divider_t{K}))); + thrust::transform( + handle.get_thrust_policy(), + input_pair_first, + input_pair_first + sample_local_nbr_indices.size(), + thrust::make_zip_iterator(thrust::make_tuple(minor_comm_ranks.begin(), + intra_partition_displacements.begin(), + sample_local_nbr_indices.begin(), + (*sample_key_indices).begin())), + convert_pair_to_quadruplet_t{ + raft::device_span((*frontier_partitioned_local_degree_displacements).data(), + (*frontier_partitioned_local_degree_displacements).size()), + raft::device_span(d_tx_counts.data(), d_tx_counts.size()), + frontier.size(), + minor_comm_size, + cugraph::ops::graph::INVALID_ID}); + rmm::device_uvector tx_displacements(minor_comm_size, handle.get_stream()); + thrust::exclusive_scan( + handle.get_thrust_policy(), d_tx_counts.begin(), d_tx_counts.end(), tx_displacements.begin()); + auto tmp_sample_local_nbr_indices = + rmm::device_uvector(tx_displacements.back_element(handle.get_stream()) + + d_tx_counts.back_element(handle.get_stream()), + handle.get_stream()); + auto tmp_sample_key_indices = + rmm::device_uvector(tmp_sample_local_nbr_indices.size(), handle.get_stream()); + auto pair_first = thrust::make_zip_iterator( + thrust::make_tuple(sample_local_nbr_indices.begin(), (*sample_key_indices).begin())); + thrust::scatter_if( + handle.get_thrust_policy(), + pair_first, + pair_first + sample_local_nbr_indices.size(), + thrust::make_transform_iterator( + thrust::make_counting_iterator(size_t{0}), + shuffle_index_compute_offset_t{ + raft::device_span(minor_comm_ranks.data(), minor_comm_ranks.size()), + raft::device_span(intra_partition_displacements.data(), + intra_partition_displacements.size()), + raft::device_span(tx_displacements.data(), tx_displacements.size())}), + minor_comm_ranks.begin(), + thrust::make_zip_iterator( + thrust::make_tuple(tmp_sample_local_nbr_indices.begin(), tmp_sample_key_indices.begin())), + is_not_equal_t{-1}); + + sample_local_nbr_indices = std::move(tmp_sample_local_nbr_indices); + sample_key_indices = std::move(tmp_sample_key_indices); + + std::vector h_tx_counts(d_tx_counts.size()); + raft::update_host( + h_tx_counts.data(), d_tx_counts.data(), d_tx_counts.size(), handle.get_stream()); + handle.sync_stream(); + + pair_first = thrust::make_zip_iterator( + thrust::make_tuple(sample_local_nbr_indices.begin(), (*sample_key_indices).begin())); + auto [rx_value_buffer, rx_counts] 
=
+      shuffle_values(minor_comm, pair_first, h_tx_counts, handle.get_stream());
+
+    sample_local_nbr_indices = std::move(std::get<0>(rx_value_buffer));
+    sample_key_indices = std::move(std::get<1>(rx_value_buffer));
+    local_frontier_sample_offsets = std::vector<size_t>(rx_counts.size() + 1);
+    local_frontier_sample_offsets[0] = size_t{0};
+    std::inclusive_scan(
+      rx_counts.begin(), rx_counts.end(), local_frontier_sample_offsets.begin() + 1);
+  } else {
+    local_frontier_sample_offsets = std::vector<size_t>{size_t{0}, frontier.size() * K};
+  }
+
+  // 4. convert neighbor indices in the neighbor list considering edge mask to neighbor indices in
+  // the neighbor list ignoring edge mask
+
+  if (edge_mask_view) {
+    auto sample_key_idx_first = thrust::make_transform_iterator(
+      thrust::make_counting_iterator(size_t{0}),
+      cuda::proclaim_return_type<size_t>(
+        [K,
+         sample_key_indices = sample_key_indices
+                                ? thrust::make_optional<raft::device_span<size_t const>>(
+                                    (*sample_key_indices).data(), (*sample_key_indices).size())
+                                : thrust::nullopt] __device__(size_t i) {
+          return sample_key_indices ? (*sample_key_indices)[i] : i / K;
+        }));
+    auto pair_first =
+      thrust::make_zip_iterator(sample_local_nbr_indices.begin(), sample_key_idx_first);
+    for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) {
+      auto edge_partition =
+        edge_partition_device_view_t<vertex_t, edge_t, GraphViewType::is_multi_gpu>(
+          graph_view.local_edge_partition_view(i));
+      auto edge_partition_e_mask =
+        edge_mask_view
+          ? thrust::make_optional<
+              detail::edge_partition_edge_property_device_view_t<edge_t, uint32_t const*, bool>>(
+              *edge_mask_view, i)
+          : thrust::nullopt;
+
+      auto edge_partition_frontier_key_first =
+        ((minor_comm_size > 1) ? get_dataframe_buffer_begin(*aggregate_local_frontier)
+                               : frontier.begin()) +
+        local_frontier_displacements[i];
+      thrust::transform(
+        handle.get_thrust_policy(),
+        pair_first + local_frontier_sample_offsets[i],
+        pair_first + local_frontier_sample_offsets[i + 1],
+        sample_local_nbr_indices.begin() + local_frontier_sample_offsets[i],
+        find_nth_valid_nbr_idx_t<GraphViewType,
+                                 decltype(edge_partition_e_mask),
+                                 decltype(edge_partition_frontier_key_first)>{
+          edge_partition,
+          edge_partition_e_mask,
+          edge_partition_frontier_key_first,
+          thrust::make_tuple(
+            raft::device_span<size_t const>(
+              std::get<0>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).data(),
+              std::get<0>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).size()),
+            raft::device_span<edge_t const>(
+              std::get<1>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).data(),
+              std::get<1>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).size()))});
+    }
+  }
+
+  return std::make_tuple(std::move(sample_local_nbr_indices),
+                         std::move(sample_key_indices),
+                         std::move(local_frontier_sample_offsets));
+}
+
+#if 0
+biased_sampling_nbr_indices(raft::handle_t const& handle,
+                            GraphViewType const& graph_view,
+                            VertexFrontierBucketType const& frontier,
+                            std::vector<size_t> const& local_frontier_displacements,
+                            std::vector<size_t> const& local_frontier_sizes,
+                            raft::random::RngState& rng_state,
+                            size_t K,
+                            bool with_replacement)
+{
+  using bias_t = typename detail::edge_op_result_type::type;
+
+  auto minor_comm_size =
+    GraphViewType::is_multi_gpu
+      ? handle.get_subcomm(cugraph::partition_manager::minor_comm_name()).get_size()
+      : int{1};
+  assert(minor_comm_size == graph_view.number_of_local_edge_partitions());
+
+  auto edge_mask_view = graph_view.edge_mask_view();
+
+  // 1. compute degrees
+
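+  // A minimal sketch of what this step is expected to produce (the names below are
+  // hypothetical; this draft does not implement them yet): one bias sum per frontier key,
+  //   local_bias_sum[i] = sum of e_bias_op(key_i, minor, src_val, dst_val, e_val) over the
+  //                       local incident edges of major(key_i),
+  // with masked-out edges contributing a bias of 0.0 so that edge masking needs no separate
+  // handling in the biased code path.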
+  rmm::device_uvector<edge_t> frontier_degrees(0, handle.get_stream());
+  auto frontier_partitioned_local_bias_sum_displacements =
+    (minor_comm_size > 1)
+      ? std::make_optional<rmm::device_uvector<bias_t>>(size_t{0}, handle.get_stream())
+      : std::nullopt;  // one partition per gpu in the same minor_comm
+
+  {
+    auto aggregate_local_frontier_local_bias_sums =
+      (minor_comm_size > 1)
+        ? std::make_optional<rmm::device_uvector<bias_t>>(
+            local_frontier_displacements.back() + local_frontier_sizes.back(), handle.get_stream())
+        : std::nullopt;
+    std::vector>;
+
+    for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) {
+      auto edge_partition =
+        edge_partition_device_view_t<vertex_t, edge_t, GraphViewType::is_multi_gpu>(
+          graph_view.local_edge_partition_view(i));
+      auto edge_partition_e_mask =
+        edge_mask_view
+          ? thrust::make_optional<
+              detail::edge_partition_edge_property_device_view_t<edge_t, uint32_t const*, bool>>(
+              *edge_mask_view, i)
+          : thrust::nullopt;
+
+      vertex_t const* edge_partition_frontier_major_first{nullptr};
+
+      auto edge_partition_frontier_key_first =
+        ((minor_comm_size > 1) ? get_dataframe_buffer_begin(*aggregate_local_frontier)
+                               : frontier.begin()) +
+        local_frontier_displacements[i];
+      if constexpr (std::is_same_v<key_t, vertex_t>) {
+        edge_partition_frontier_major_first = edge_partition_frontier_key_first;
+      } else {
+        edge_partition_frontier_major_first = thrust::get<0>(edge_partition_frontier_key_first);
+      }
+
+      auto edge_partition_frontier_local_degrees = edge_partition.compute_local_degrees_with_mask(
+        (*edge_partition_e_mask).value_first(),
+        edge_partition_frontier_major_first,
+        edge_partition_frontier_major_first + local_frontier_sizes[i],
+        handle.get_stream());
+      auto edge_partition_frontier_local_degree_inclusive_sums;
+      // if with_replacement == false && degree <= K, skip bias computing? how should I handle
+      // bias == 0? update documentation for this.
+
+      auto edge_partition_frontier_e_biases;
+
+      if (minor_comm_size > 1) {
+        // FIXME: this copy is unnecessary if edge_partition.compute_local_degrees() takes a
+        // pointer to the output array
+        thrust::copy(
+          handle.get_thrust_policy(),
+          edge_partition_frontier_local_degrees.begin(),
+          edge_partition_frontier_local_degrees.end(),
+          (*aggregate_local_frontier_local_degrees).begin() + local_frontier_displacements[i]);
+      } else {
+        frontier_degrees = std::move(edge_partition_frontier_local_degrees);
+      }
+    }
+
+    if (minor_comm_size > 1) {
+      auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name());
+
+      rmm::device_uvector<edge_t> frontier_gathered_local_degrees(0, handle.get_stream());
+      std::tie(frontier_gathered_local_degrees, std::ignore) =
+        shuffle_values(minor_comm,
+                       (*aggregate_local_frontier_local_degrees).begin(),
+                       local_frontier_sizes,
+                       handle.get_stream());
+      aggregate_local_frontier_local_degrees = std::nullopt;
+
+      frontier_degrees.resize(frontier.size(), handle.get_stream());
+      frontier_partitioned_local_degree_displacements =
+        rmm::device_uvector<edge_t>(frontier_degrees.size() * minor_comm_size,
+                                    handle.get_stream());
+      thrust::for_each(
+        handle.get_thrust_policy(),
+        thrust::make_counting_iterator(size_t{0}),
+        thrust::make_counting_iterator(frontier_degrees.size()),
+        compute_local_degree_displacements_and_global_degree_t<edge_t>{
+          raft::device_span<edge_t const>(frontier_gathered_local_degrees.data(),
+                                          frontier_gathered_local_degrees.size()),
+          raft::device_span<edge_t>((*frontier_partitioned_local_degree_displacements).data(),
+                                    (*frontier_partitioned_local_degree_displacements).size()),
+          raft::device_span<edge_t>(frontier_degrees.data(), frontier_degrees.size()),
+          minor_comm_size});
+    }
+  }
+
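+  // Illustration of the inverse-CDF selection sketched in step 2 below (hypothetical
+  // numbers): with per-rank local bias sums {2.5, 0.5, 1.0} (global sum 4.0), a uniform
+  // draw r = 0.725 in [0.0, 1.0) scales to 2.9, which falls in rank 1's range [2.5, 3.0);
+  // minor_comm_rank is therefore 1 and the rank-local uniform draw becomes
+  // (2.9 - 2.5) / 0.5 = 0.8.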
+  // 2. sample neighbor indices
+
+  rmm::device_uvector<edge_t> sample_nbr_indices(0, handle.get_stream());
+
+  if (with_replacement) {
+    // generate random numbers in [0.0, 1.0);
+    // scale by bias_sum;
+    // find minor_comm_rank and compute local random number;
+  } else {
+    if (degree <= K) {
+      // thrust::sequence
+    } else if (degree < K * minor_comm_size * (status_size_per_K / weight_size)) {
+      // auto bias_first = ();
+      // gather biases;
+      // generate indices
+    } else {
+      // locally generate indices; gather states;
+      // generate indices
+    }
+  }
+
+  // 3. shuffle neighbor indices
+
+  auto sample_local_nbr_indices = std::move(
+    sample_nbr_indices);  // neighbor index within an edge partition (note that each vertex's
+                          // neighbors are distributed in minor_comm_size partitions)
+  std::optional<rmm::device_uvector<size_t>> sample_key_indices{
+    std::nullopt};  // relevant only when (minor_comm_size > 1)
+  std::vector<size_t> local_frontier_sample_offsets{};
+  if (minor_comm_size > 1) {
+    auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name());
+
+    sample_key_indices =
+      rmm::device_uvector<size_t>(sample_local_nbr_indices.size(), handle.get_stream());
+    auto minor_comm_ranks =
+      rmm::device_uvector<int>(sample_local_nbr_indices.size(), handle.get_stream());
+    auto intra_partition_displacements =
+      rmm::device_uvector<size_t>(sample_local_nbr_indices.size(), handle.get_stream());
+    rmm::device_uvector<size_t> d_tx_counts(minor_comm_size, handle.get_stream());
+    thrust::fill(handle.get_thrust_policy(), d_tx_counts.begin(), d_tx_counts.end(), size_t{0});
+    auto input_pair_first = thrust::make_zip_iterator(
+      thrust::make_tuple(sample_local_nbr_indices.begin(),
+                         thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{0}),
+                                                         divider_t<size_t>{K})));
+    thrust::transform(
+      handle.get_thrust_policy(),
+      input_pair_first,
+      input_pair_first + sample_local_nbr_indices.size(),
+      thrust::make_zip_iterator(thrust::make_tuple(minor_comm_ranks.begin(),
+                                                   intra_partition_displacements.begin(),
+                                                   sample_local_nbr_indices.begin(),
+                                                   (*sample_key_indices).begin())),
+      convert_pair_to_quadruplet_t<edge_t>{
+        raft::device_span<edge_t const>((*frontier_partitioned_local_degree_displacements).data(),
+                                        (*frontier_partitioned_local_degree_displacements).size()),
+        raft::device_span<size_t>(d_tx_counts.data(), d_tx_counts.size()),
+        frontier.size(),
+        minor_comm_size,
+        cugraph::ops::graph::INVALID_ID<edge_t>});
+    rmm::device_uvector<size_t> tx_displacements(minor_comm_size, handle.get_stream());
+    thrust::exclusive_scan(
+      handle.get_thrust_policy(), d_tx_counts.begin(), d_tx_counts.end(), tx_displacements.begin());
+    auto tmp_sample_local_nbr_indices =
+      rmm::device_uvector<edge_t>(tx_displacements.back_element(handle.get_stream()) +
+                                    d_tx_counts.back_element(handle.get_stream()),
+                                  handle.get_stream());
+    auto tmp_sample_key_indices =
+      rmm::device_uvector<size_t>(tmp_sample_local_nbr_indices.size(), handle.get_stream());
+    auto pair_first = thrust::make_zip_iterator(
+      thrust::make_tuple(sample_local_nbr_indices.begin(), (*sample_key_indices).begin()));
+    thrust::scatter_if(
+      handle.get_thrust_policy(),
+      pair_first,
+      pair_first + sample_local_nbr_indices.size(),
+      thrust::make_transform_iterator(
+        thrust::make_counting_iterator(size_t{0}),
+        shuffle_index_compute_offset_t{
+          raft::device_span<int const>(minor_comm_ranks.data(), minor_comm_ranks.size()),
+          raft::device_span<size_t const>(intra_partition_displacements.data(),
+                                          intra_partition_displacements.size()),
+          raft::device_span<size_t const>(tx_displacements.data(), tx_displacements.size())}),
+      minor_comm_ranks.begin(),
+      thrust::make_zip_iterator(
+  // 3. shuffle neighbor indices
+
+  auto sample_local_nbr_indices = std::move(
+    sample_nbr_indices);  // neighbor index within an edge partition (note that each vertex's
+                          // neighbors are distributed in minor_comm_size partitions)
+  std::optional> sample_key_indices{
+    std::nullopt};  // relevant only when (minor_comm_size > 1)
+  std::vector local_frontier_sample_offsets{};
+  if (minor_comm_size > 1) {
+    auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name());
+
+    sample_key_indices =
+      rmm::device_uvector(sample_local_nbr_indices.size(), handle.get_stream());
+    auto minor_comm_ranks =
+      rmm::device_uvector(sample_local_nbr_indices.size(), handle.get_stream());
+    auto intra_partition_displacements =
+      rmm::device_uvector(sample_local_nbr_indices.size(), handle.get_stream());
+    rmm::device_uvector d_tx_counts(minor_comm_size, handle.get_stream());
+    thrust::fill(handle.get_thrust_policy(), d_tx_counts.begin(), d_tx_counts.end(), size_t{0});
+    auto input_pair_first = thrust::make_zip_iterator(
+      thrust::make_tuple(sample_local_nbr_indices.begin(),
+                         thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{0}),
+                                                         divider_t{K})));
+    thrust::transform(
+      handle.get_thrust_policy(),
+      input_pair_first,
+      input_pair_first + sample_local_nbr_indices.size(),
+      thrust::make_zip_iterator(thrust::make_tuple(minor_comm_ranks.begin(),
+                                                   intra_partition_displacements.begin(),
+                                                   sample_local_nbr_indices.begin(),
+                                                   (*sample_key_indices).begin())),
+      convert_pair_to_quadruplet_t{
+        raft::device_span((*frontier_partitioned_local_degree_displacements).data(),
+                          (*frontier_partitioned_local_degree_displacements).size()),
+        raft::device_span(d_tx_counts.data(), d_tx_counts.size()),
+        frontier.size(),
+        minor_comm_size,
+        cugraph::ops::graph::INVALID_ID});
+    rmm::device_uvector tx_displacements(minor_comm_size, handle.get_stream());
+    thrust::exclusive_scan(
+      handle.get_thrust_policy(), d_tx_counts.begin(), d_tx_counts.end(), tx_displacements.begin());
+    auto tmp_sample_local_nbr_indices =
+      rmm::device_uvector(tx_displacements.back_element(handle.get_stream()) +
+                            d_tx_counts.back_element(handle.get_stream()),
+                          handle.get_stream());
+    auto tmp_sample_key_indices =
+      rmm::device_uvector(tmp_sample_local_nbr_indices.size(), handle.get_stream());
+    auto pair_first = thrust::make_zip_iterator(
+      thrust::make_tuple(sample_local_nbr_indices.begin(), (*sample_key_indices).begin()));
+    thrust::scatter_if(
+      handle.get_thrust_policy(),
+      pair_first,
+      pair_first + sample_local_nbr_indices.size(),
+      thrust::make_transform_iterator(
+        thrust::make_counting_iterator(size_t{0}),
+        shuffle_index_compute_offset_t{
+          raft::device_span(minor_comm_ranks.data(), minor_comm_ranks.size()),
+          raft::device_span(intra_partition_displacements.data(),
+                            intra_partition_displacements.size()),
+          raft::device_span(tx_displacements.data(), tx_displacements.size())}),
+      minor_comm_ranks.begin(),
+      thrust::make_zip_iterator(
+        thrust::make_tuple(tmp_sample_local_nbr_indices.begin(), tmp_sample_key_indices.begin())),
+      is_not_equal_t{-1});
+
+    sample_local_nbr_indices = std::move(tmp_sample_local_nbr_indices);
+    sample_key_indices       = std::move(tmp_sample_key_indices);
+
+    std::vector h_tx_counts(d_tx_counts.size());
+    raft::update_host(
+      h_tx_counts.data(), d_tx_counts.data(), d_tx_counts.size(), handle.get_stream());
+    handle.sync_stream();
+
+    pair_first = thrust::make_zip_iterator(
+      thrust::make_tuple(sample_local_nbr_indices.begin(), (*sample_key_indices).begin()));
+    auto [rx_value_buffer, rx_counts] =
+      shuffle_values(minor_comm, pair_first, h_tx_counts, handle.get_stream());
+
+    sample_local_nbr_indices        = std::move(std::get<0>(rx_value_buffer));
+    sample_key_indices              = std::move(std::get<1>(rx_value_buffer));
+    local_frontier_sample_offsets    = std::vector(rx_counts.size() + 1);
+    local_frontier_sample_offsets[0] = size_t{0};
+    std::inclusive_scan(
+      rx_counts.begin(), rx_counts.end(), local_frontier_sample_offsets.begin() + 1);
+  } else {
+    local_frontier_sample_offsets = std::vector{size_t{0}, frontier.size() * K};
+  }
+
+  return std::make_tuple(std::move(sample_local_nbr_indices),
+                         std::move(sample_key_indices),
+                         std::move(local_frontier_sample_offsets));
+}
+
+template 
+std::tuple,
+           std::optional>,
+           std::vector>  // (neighbor indices, key indices, local frontier sample offsets)
+sample_and_compute_local_nbr_indices(raft::handle_t const& handle,
+                                     GraphViewType const& graph_view,
+                                     VertexFrontierBucketType const& frontier,
+                                     AggregateLocalFrontierBuffer const& aggregate_local_frontier,
+                                     EdgeSrcValueInputWrapper edge_src_value_input,
+                                     EdgeDstValueInputWrapper edge_dst_value_input,
+                                     EdgeValueInputWrapper edge_value_input,
+                                     EdgeBiasOp e_bias_op,
+                                     std::vector const& local_frontier_displacements,
+                                     std::vector const& local_frontier_sizes,
+                                     raft::random::RngState& rng_state,
+                                     size_t K,
+                                     bool with_replacement,
+                                     bool do_expensive_check)
+{
+  using vertex_t = typename GraphViewType::vertex_type;
+  using edge_t   = typename GraphViewType::edge_type;
+  using key_t    = typename VertexFrontierBucketType::key_type;
+
+#ifndef NO_CUGRAPH_OPS
+
+  bool constexpr use_bias = !std::is_same_v>;
+
+  using edge_partition_src_input_device_view_t = std::conditional_t<
+    std::is_same_v,
+    edge_partition_endpoint_dummy_property_device_view_t,
+    edge_partition_endpoint_property_device_view_t<
+      vertex_t,
+      typename EdgeSrcValueInputWrapper::value_iterator,
+      typename EdgeSrcValueInputWrapper::value_type>>;
+  using edge_partition_dst_input_device_view_t = std::conditional_t<
+    std::is_same_v,
+    edge_partition_endpoint_dummy_property_device_view_t,
+    edge_partition_endpoint_property_device_view_t<
+      vertex_t,
+      typename EdgeDstValueInputWrapper::value_iterator,
+      typename EdgeDstValueInputWrapper::value_type>>;
+  using edge_partition_e_input_device_view_t = std::conditional_t<
+    std::is_same_v,
+    detail::edge_partition_edge_dummy_property_device_view_t,
+    detail::edge_partition_edge_property_device_view_t<
+      edge_t,
+      typename EdgeValueInputWrapper::value_iterator,
+      typename EdgeValueInputWrapper::value_type>>;
+
+  static_assert(GraphViewType::is_storage_transposed == incoming);
+
+  CUGRAPH_EXPECTS(K >= size_t{1},
+                  "Invalid input argument: invalid K, K should be a positive integer.");
+  CUGRAPH_EXPECTS(K <= static_cast(std::numeric_limits::max()),
+                  "Invalid input argument: the current implementation expects K to be no larger "
+                  "than std::numeric_limits::max().");
+
+  auto minor_comm_size =
+    GraphViewType::is_multi_gpu
+      ? 
handle.get_subcomm(cugraph::partition_manager::minor_comm_name()).get_size() + : int{1}; + assert(minor_comm_size == graph_view.number_of_local_edge_partitions()); + + if (do_expensive_check) { + // FIXME: should I check frontier & aggregate_local_frontier? + } + + // 1. compute degrees + + auto edge_mask_view = graph_view.edge_mask_view(); + + rmm::device_uvector frontier_degrees(frontier.size(), handle.get_stream()); + auto frontier_partitioned_local_degree_displacements = + ((minor_comm_size > 1) && (!use_bias || with_replacement)) + ? std::make_optional>(size_t{0}, handle.get_stream()) + : std::nullopt; // one partition per gpu in the same minor_comm + + // BIASED: unnecessary for biased as we can just set bias to 0 for masked out edges? + std::optional, rmm::device_uvector>>> + local_frontier_valid_local_nbr_count_inclusive_sums{}; // to avoid searching the entire + // neighbor list K times for high degree + // vertices with edge masking + if (edge_mask_view) { + local_frontier_valid_local_nbr_count_inclusive_sums = + std::vector, rmm::device_uvector>>{}; + (*local_frontier_valid_local_nbr_count_inclusive_sums) + .reserve(graph_view.number_of_local_edge_partitions()); + } + + { + auto aggregate_local_frontier_local_degrees = + (minor_comm_size > 1) + ? std::make_optional>( + local_frontier_displacements.back() + local_frontier_sizes.back(), handle.get_stream()) + : std::nullopt; + + for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { + auto edge_partition = + edge_partition_device_view_t( + graph_view.local_edge_partition_view(i)); + auto edge_partition_e_mask = + edge_mask_view + ? thrust::make_optional< + detail::edge_partition_edge_property_device_view_t>( + *edge_mask_view, i) + : thrust::nullopt; + + vertex_t const* edge_partition_frontier_major_first{nullptr}; + + auto edge_partition_frontier_key_first = + ((minor_comm_size > 1) ? get_dataframe_buffer_begin(*aggregate_local_frontier) + : frontier.begin()) + + local_frontier_displacements[i]; + if constexpr (std::is_same_v) { + edge_partition_frontier_major_first = edge_partition_frontier_key_first; + } else { + edge_partition_frontier_major_first = thrust::get<0>(edge_partition_frontier_key_first); + } + + auto edge_partition_frontier_local_degrees = + edge_partition_e_mask ? 
edge_partition.compute_local_degrees_with_mask( + (*edge_partition_e_mask).value_first(), + edge_partition_frontier_major_first, + edge_partition_frontier_major_first + local_frontier_sizes[i], + handle.get_stream()) + : edge_partition.compute_local_degrees( + edge_partition_frontier_major_first, + edge_partition_frontier_major_first + local_frontier_sizes[i], + handle.get_stream()); + + if (minor_comm_size > 1) { + // FIXME: this copy is unnecessary if edge_partition.compute_local_degrees() takes a pointer + // to the output array + thrust::copy( + handle.get_thrust_policy(), + edge_partition_frontier_local_degrees.begin(), + edge_partition_frontier_local_degrees.end(), + (*aggregate_local_frontier_local_degrees).begin() + local_frontier_displacements[i]); + } else { + frontier_degrees = std::move(edge_partition_frontier_local_degrees); + } + + if (edge_partition_e_mask) { + (*local_frontier_valid_local_nbr_count_inclusive_sums) + .push_back(compute_valid_local_nbr_count_inclusive_sums( + handle, + edge_partition, + *edge_partition_e_mask, + raft::device_span(edge_partition_frontier_major_first, + local_frontier_sizes[i]))); + } + } + + if (minor_comm_size > 1) { + auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); + + rmm::device_uvector frontier_gathered_local_degrees(0, handle.get_stream()); + std::tie(frontier_gathered_local_degrees, std::ignore) = + shuffle_values(minor_comm, + (*aggregate_local_frontier_local_degrees).begin(), + local_frontier_sizes, + handle.get_stream()); + aggregate_local_frontier_local_degrees = std::nullopt; + frontier_partitioned_local_degree_displacements = + rmm::device_uvector(frontier_degrees.size() * minor_comm_size, handle.get_stream()); + thrust::for_each( + handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator(frontier_degrees.size()), + compute_local_degree_displacements_and_global_degree_t{ + raft::device_span(frontier_gathered_local_degrees.data(), + frontier_gathered_local_degrees.size()), + raft::device_span((*frontier_partitioned_local_degree_displacements).data(), + (*frontier_partitioned_local_degree_displacements).size()), + raft::device_span(frontier_degrees.data(), frontier_degrees.size()), + minor_comm_size}); + } + } + + rmm::device_uvector sample_nbr_indices(0, handle.get_stream()); + + // BIAS: indices or random number + // 4. 
shuffle randomly selected indices + + auto sample_local_nbr_indices = std::move( + sample_nbr_indices); // neighbor index within an edge partition (note that each vertex's + // neighbors are distributed in minor_comm_size partitions) + std::optional> sample_key_indices{ + std::nullopt}; // relevant only when (minor_comm_size > 1) + std::vector local_frontier_sample_offsets{}; + if (minor_comm_size > 1) { + auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); + + sample_key_indices = + rmm::device_uvector(sample_local_nbr_indices.size(), handle.get_stream()); + auto minor_comm_ranks = + rmm::device_uvector(sample_local_nbr_indices.size(), handle.get_stream()); + auto intra_partition_displacements = + rmm::device_uvector(sample_local_nbr_indices.size(), handle.get_stream()); + rmm::device_uvector d_tx_counts(minor_comm_size, handle.get_stream()); + thrust::fill(handle.get_thrust_policy(), d_tx_counts.begin(), d_tx_counts.end(), size_t{0}); + auto input_pair_first = thrust::make_zip_iterator( + thrust::make_tuple(sample_local_nbr_indices.begin(), + thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{0}), + divider_t{K}))); + thrust::transform( + handle.get_thrust_policy(), + input_pair_first, + input_pair_first + sample_local_nbr_indices.size(), + thrust::make_zip_iterator(thrust::make_tuple(minor_comm_ranks.begin(), + intra_partition_displacements.begin(), + sample_local_nbr_indices.begin(), + (*sample_key_indices).begin())), + convert_pair_to_quadruplet_t{ + raft::device_span((*frontier_partitioned_local_degree_displacements).data(), + (*frontier_partitioned_local_degree_displacements).size()), + raft::device_span(d_tx_counts.data(), d_tx_counts.size()), + frontier.size(), + minor_comm_size, + cugraph::ops::graph::INVALID_ID}); + rmm::device_uvector tx_displacements(minor_comm_size, handle.get_stream()); + thrust::exclusive_scan( + handle.get_thrust_policy(), d_tx_counts.begin(), d_tx_counts.end(), tx_displacements.begin()); + auto tmp_sample_local_nbr_indices = + rmm::device_uvector(tx_displacements.back_element(handle.get_stream()) + + d_tx_counts.back_element(handle.get_stream()), + handle.get_stream()); + auto tmp_sample_key_indices = + rmm::device_uvector(tmp_sample_local_nbr_indices.size(), handle.get_stream()); + auto pair_first = thrust::make_zip_iterator( + thrust::make_tuple(sample_local_nbr_indices.begin(), (*sample_key_indices).begin())); + thrust::scatter_if( + handle.get_thrust_policy(), + pair_first, + pair_first + sample_local_nbr_indices.size(), + thrust::make_transform_iterator( + thrust::make_counting_iterator(size_t{0}), + shuffle_index_compute_offset_t{ + raft::device_span(minor_comm_ranks.data(), minor_comm_ranks.size()), + raft::device_span(intra_partition_displacements.data(), + intra_partition_displacements.size()), + raft::device_span(tx_displacements.data(), tx_displacements.size())}), + minor_comm_ranks.begin(), + thrust::make_zip_iterator( + thrust::make_tuple(tmp_sample_local_nbr_indices.begin(), tmp_sample_key_indices.begin())), + is_not_equal_t{-1}); + + sample_local_nbr_indices = std::move(tmp_sample_local_nbr_indices); + sample_key_indices = std::move(tmp_sample_key_indices); + + std::vector h_tx_counts(d_tx_counts.size()); + raft::update_host( + h_tx_counts.data(), d_tx_counts.data(), d_tx_counts.size(), handle.get_stream()); + handle.sync_stream(); + + pair_first = thrust::make_zip_iterator( + thrust::make_tuple(sample_local_nbr_indices.begin(), (*sample_key_indices).begin())); + auto [rx_value_buffer, 
rx_counts] = + shuffle_values(minor_comm, pair_first, h_tx_counts, handle.get_stream()); + + sample_local_nbr_indices = std::move(std::get<0>(rx_value_buffer)); + sample_key_indices = std::move(std::get<1>(rx_value_buffer)); + local_frontier_sample_offsets = std::vector(rx_counts.size() + 1); + local_frontier_sample_offsets[0] = size_t{0}; + std::inclusive_scan( + rx_counts.begin(), rx_counts.end(), local_frontier_sample_offsets.begin() + 1); + } else { + local_frontier_sample_offsets = std::vector{size_t{0}, frontier.size() * K}; + } + + return std::make_tuple(std::move(sample_local_nbr_indices), + std::move(sample_key_indices), + std::move(local_frontier_sample_offsets)); +#else + CUGRAPH_FAIL("unimplemented."); + return std::make_tuple( + rmm::device_uvector(0, handle.get_stream()), std::nullopt, std::vector{}); +#endif +} +#endif + +} // namespace detail + +} // namespace cugraph diff --git a/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh b/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh index 5240c49cb80..480aec6a02a 100644 --- a/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh +++ b/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh @@ -15,6 +15,7 @@ */ #pragma once +#include "prims/detail/sample_and_compute_local_nbr_indices.cuh" #include "prims/property_op_utils.cuh" #include @@ -53,101 +54,6 @@ namespace cugraph { namespace detail { -int32_t constexpr per_v_random_select_transform_outgoing_e_block_size = 256; - -size_t constexpr compute_valid_local_nbr_count_inclusive_sum_local_degree_threshold = - packed_bools_per_word() * - size_t{4} /* tuning parameter */; // minimum local degree to compute inclusive sums of valid - // local neighbors per word to accelerate finding n'th local - // neighbor vertex -size_t constexpr compute_valid_local_nbr_count_inclusive_sum_mid_local_degree_threshold = - packed_bools_per_word() * static_cast(raft::warp_size()) * - size_t{ - 4} /* tuning parameter */; // minimum local degree to use a CUDA warp to compute inclusive sums -size_t constexpr compute_valid_local_nbr_count_inclusive_sum_high_local_degree_threshold = - packed_bools_per_word() * per_v_random_select_transform_outgoing_e_block_size * - size_t{4} /* tuning parameter */; // minimum local degree to use a CUDA block to compute - // inclusive sums - -template -struct compute_local_degree_displacements_and_global_degree_t { - raft::device_span gathered_local_degrees{}; - raft::device_span - partitioned_local_degree_displacements{}; // one partition per gpu in the same minor_comm - raft::device_span global_degrees{}; - int minor_comm_size{}; - - __device__ void operator()(size_t i) const - { - constexpr int buffer_size = 8; // tuning parameter - edge_t displacements[buffer_size]; - edge_t sum{0}; - for (int round = 0; round < (minor_comm_size + buffer_size - 1) / buffer_size; ++round) { - auto loop_count = std::min(buffer_size, minor_comm_size - round * buffer_size); - for (int j = 0; j < loop_count; ++j) { - displacements[j] = sum; - sum += gathered_local_degrees[i + (round * buffer_size + j) * global_degrees.size()]; - } - thrust::copy( - thrust::seq, - displacements, - displacements + loop_count, - partitioned_local_degree_displacements.begin() + i * minor_comm_size + round * buffer_size); - } - global_degrees[i] = sum; - } -}; - -// convert a (neighbor index, key index) pair to a (minor_comm_rank, intra-partition offset, -// neighbor index, key index) quadruplet, minor_comm_rank is set to -1 if an neighbor index is -// invalid -template -struct 
convert_pair_to_quadruplet_t { - raft::device_span - partitioned_local_degree_displacements{}; // one partition per gpu in the same minor_comm - raft::device_span tx_counts{}; - size_t stride{}; - int minor_comm_size{}; - edge_t invalid_idx{}; - - __device__ thrust::tuple operator()( - thrust::tuple index_pair) const - { - auto nbr_idx = thrust::get<0>(index_pair); - auto key_idx = thrust::get<1>(index_pair); - auto local_nbr_idx = nbr_idx; - int minor_comm_rank{-1}; - size_t intra_partition_offset{}; - if (nbr_idx != invalid_idx) { - auto displacement_first = - partitioned_local_degree_displacements.begin() + key_idx * minor_comm_size; - minor_comm_rank = - static_cast(thrust::distance( - displacement_first, - thrust::upper_bound( - thrust::seq, displacement_first, displacement_first + minor_comm_size, nbr_idx))) - - 1; - local_nbr_idx -= *(displacement_first + minor_comm_rank); - cuda::atomic_ref counter(tx_counts[minor_comm_rank]); - intra_partition_offset = counter.fetch_add(size_t{1}, cuda::std::memory_order_relaxed); - } - return thrust::make_tuple(minor_comm_rank, intra_partition_offset, local_nbr_idx, key_idx); - } -}; - -struct shuffle_index_compute_offset_t { - raft::device_span minor_comm_ranks{}; - raft::device_span intra_partition_displacements{}; - raft::device_span tx_displacements{}; - - __device__ size_t operator()(size_t i) const - { - auto minor_comm_rank = minor_comm_ranks[i]; - assert(minor_comm_rank != -1); - return tx_displacements[minor_comm_rank] + intra_partition_displacements[i]; - } -}; - template struct check_invalid_t { edge_t invalid_idx{}; @@ -158,23 +64,12 @@ struct check_invalid_t { } }; -template -struct invalid_minor_comm_rank_t { - int invalid_minor_comm_rank{}; - __device__ bool operator()(thrust::tuple triplet) const - { - return thrust::get<1>(triplet) == invalid_minor_comm_rank; - } -}; - template struct transform_local_nbr_indices_t { @@ -186,19 +81,15 @@ struct transform_local_nbr_indices_t { thrust::optional local_key_indices{thrust::nullopt}; KeyIterator key_first{}; LocalNbrIdxIterator local_nbr_idx_first{}; - OutputValueIterator output_value_first{}; EdgePartitionSrcValueInputWrapper edge_partition_src_value_input; EdgePartitionDstValueInputWrapper edge_partition_dst_value_input; EdgePartitionEdgeValueInputWrapper edge_partition_e_value_input; - EdgePartitionEdgeMaskWrapper edge_partition_e_mask; - thrust::optional, raft::device_span>> - key_valid_local_nbr_count_inclusive_sums{}; EdgeOp e_op{}; edge_t invalid_idx{}; thrust::optional invalid_value{thrust::nullopt}; size_t K{}; - __device__ void operator()(size_t i) const + __device__ T operator()(size_t i) const { auto key_idx = local_key_indices ? 
(*local_key_indices)[i] : (i / K); auto key = *(key_first + key_idx); @@ -230,31 +121,6 @@ struct transform_local_nbr_indices_t { auto local_nbr_idx = *(local_nbr_idx_first + i); if (local_nbr_idx != invalid_idx) { vertex_t minor{}; - if (edge_partition_e_mask) { - if (local_degree < compute_valid_local_nbr_count_inclusive_sum_local_degree_threshold) { - local_nbr_idx = find_nth_set_bits( - (*edge_partition_e_mask).value_first(), edge_offset, local_degree, local_nbr_idx + 1); - } else { - auto inclusive_sum_first = - thrust::get<1>(*key_valid_local_nbr_count_inclusive_sums).begin(); - auto start_offset = thrust::get<0>(*key_valid_local_nbr_count_inclusive_sums)[key_idx]; - auto end_offset = thrust::get<0>(*key_valid_local_nbr_count_inclusive_sums)[key_idx + 1]; - auto word_idx = static_cast( - thrust::distance(inclusive_sum_first + start_offset, - thrust::upper_bound(thrust::seq, - inclusive_sum_first + start_offset, - inclusive_sum_first + end_offset, - local_nbr_idx))); - local_nbr_idx = word_idx * packed_bools_per_word() + - find_nth_set_bits( - (*edge_partition_e_mask).value_first(), - edge_offset + word_idx * packed_bools_per_word(), - local_degree - word_idx * packed_bools_per_word(), - (local_nbr_idx + 1) - - ((word_idx > 0) ? *(inclusive_sum_first + start_offset + word_idx - 1) - : edge_t{0})); - } - } minor = indices[local_nbr_idx]; auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor); @@ -271,14 +137,13 @@ struct transform_local_nbr_indices_t { } auto src_offset = GraphViewType::is_storage_transposed ? minor_offset : major_offset; auto dst_offset = GraphViewType::is_storage_transposed ? major_offset : minor_offset; - *(output_value_first + i) = - e_op(key_or_src, - key_or_dst, - edge_partition_src_value_input.get(src_offset), - edge_partition_dst_value_input.get(dst_offset), - edge_partition_e_value_input.get(edge_offset + local_nbr_idx)); + return e_op(key_or_src, + key_or_dst, + edge_partition_src_value_input.get(src_offset), + edge_partition_dst_value_input.get(dst_offset), + edge_partition_e_value_input.get(edge_offset + local_nbr_idx)); } else if (invalid_value) { - *(output_value_first + i) = *invalid_value; + return *invalid_value; } } }; @@ -327,630 +192,13 @@ struct return_value_compute_offset_t { } }; -template -__global__ static void compute_valid_local_nbr_inclusive_sums_mid_local_degree( - edge_partition_device_view_t edge_partition, - edge_partition_edge_property_device_view_t edge_partition_e_mask, - raft::device_span edge_partition_frontier_majors, - raft::device_span inclusive_sum_offsets, - raft::device_span frontier_indices, - raft::device_span inclusive_sums) -{ - static_assert(per_v_random_select_transform_outgoing_e_block_size % raft::warp_size() == 0); - - auto const tid = threadIdx.x + blockIdx.x * blockDim.x; - auto const lane_id = tid % raft::warp_size(); - - auto idx = static_cast(tid / raft::warp_size()); - - using WarpScan = cub::WarpScan; - __shared__ typename WarpScan::TempStorage temp_storage; - - while (idx < frontier_indices.size()) { - auto frontier_idx = frontier_indices[idx]; - auto major = edge_partition_frontier_majors[frontier_idx]; - vertex_t major_idx{}; - if constexpr (multi_gpu) { - major_idx = *(edge_partition.major_idx_from_major_nocheck(major)); - } else { - major_idx = edge_partition.major_offset_from_major_nocheck(major); - } - auto edge_offset = edge_partition.local_offset(major_idx); - auto local_degree = edge_partition.local_degree(major_idx); - - auto start_offset = inclusive_sum_offsets[frontier_idx]; - 
auto end_offset = inclusive_sum_offsets[frontier_idx + 1]; - auto num_inclusive_sums = end_offset - start_offset; - auto rounded_up_num_inclusive_sums = - ((num_inclusive_sums + raft::warp_size() - 1) / raft::warp_size()) * raft::warp_size(); - edge_t sum{0}; - for (size_t j = lane_id; j <= rounded_up_num_inclusive_sums; j += raft::warp_size()) { - auto inc = - (j < num_inclusive_sums) - ? static_cast(count_set_bits( - edge_partition_e_mask.value_first(), - edge_offset + packed_bools_per_word() * j, - cuda::std::min(packed_bools_per_word(), local_degree - packed_bools_per_word() * j))) - : edge_t{0}; - WarpScan(temp_storage).InclusiveSum(inc, inc); - inclusive_sums[start_offset + j] = sum + inc; - sum += __shfl_sync(raft::warp_full_mask(), inc, raft::warp_size() - 1); - } - - idx += gridDim.x * (blockDim.x / raft::warp_size()); - } -} - -template -__global__ static void compute_valid_local_nbr_inclusive_sums_high_local_degree( - edge_partition_device_view_t edge_partition, - edge_partition_edge_property_device_view_t edge_partition_e_mask, - raft::device_span edge_partition_frontier_majors, - raft::device_span inclusive_sum_offsets, - raft::device_span frontier_indices, - raft::device_span inclusive_sums) -{ - static_assert(per_v_random_select_transform_outgoing_e_block_size % raft::warp_size() == 0); - - auto idx = static_cast(blockIdx.x); - - using BlockScan = cub::BlockScan; - __shared__ typename BlockScan::TempStorage temp_storage; - - __shared__ edge_t sum; - - while (idx < frontier_indices.size()) { - auto frontier_idx = frontier_indices[idx]; - auto major = edge_partition_frontier_majors[frontier_idx]; - vertex_t major_idx{}; - if constexpr (multi_gpu) { - major_idx = *(edge_partition.major_idx_from_major_nocheck(major)); - } else { - major_idx = edge_partition.major_offset_from_major_nocheck(major); - } - auto edge_offset = edge_partition.local_offset(major_idx); - auto local_degree = edge_partition.local_degree(major_idx); - - auto start_offset = inclusive_sum_offsets[frontier_idx]; - auto end_offset = inclusive_sum_offsets[frontier_idx + 1]; - auto num_inclusive_sums = end_offset - start_offset; - auto rounded_up_num_inclusive_sums = - ((num_inclusive_sums + per_v_random_select_transform_outgoing_e_block_size - 1) / - per_v_random_select_transform_outgoing_e_block_size) * - per_v_random_select_transform_outgoing_e_block_size; - if (threadIdx.x == per_v_random_select_transform_outgoing_e_block_size - 1) { sum = 0; } - for (size_t j = threadIdx.x; j <= rounded_up_num_inclusive_sums; j += blockDim.x) { - auto inc = - (j < num_inclusive_sums) - ? 
static_cast(count_set_bits( - edge_partition_e_mask.value_first(), - edge_offset + packed_bools_per_word() * j, - cuda::std::min(packed_bools_per_word(), local_degree - packed_bools_per_word() * j))) - : edge_t{0}; - BlockScan(temp_storage).InclusiveSum(inc, inc); - inclusive_sums[start_offset + j] = sum + inc; - __syncthreads(); - if (threadIdx.x == per_v_random_select_transform_outgoing_e_block_size - 1) { sum += inc; } - } - - idx += gridDim.x; - } -} - -template -std::tuple, rmm::device_uvector> -compute_valid_local_nbr_count_inclusive_sums( - raft::handle_t const& handle, - edge_partition_device_view_t const& edge_partition, - edge_partition_edge_property_device_view_t const& - edge_partition_e_mask, - raft::device_span edge_partition_frontier_majors) -{ - auto edge_partition_local_degrees = - edge_partition.compute_local_degrees(edge_partition_frontier_majors.begin(), - edge_partition_frontier_majors.end(), - handle.get_stream()); - auto offsets = - rmm::device_uvector(edge_partition_frontier_majors.size() + 1, handle.get_stream()); - offsets.set_element_to_zero_async(0, handle.get_stream()); - auto size_first = thrust::make_transform_iterator( - edge_partition_local_degrees.begin(), - cuda::proclaim_return_type([] __device__(edge_t local_degree) { - return static_cast((local_degree + packed_bools_per_word() - 1) / - packed_bools_per_word()); - })); - thrust::inclusive_scan(handle.get_thrust_policy(), - size_first, - size_first + edge_partition_local_degrees.size(), - offsets.begin() + 1); - - rmm::device_uvector frontier_indices(edge_partition_frontier_majors.size(), - handle.get_stream()); - frontier_indices.resize( - thrust::distance( - frontier_indices.begin(), - thrust::copy_if( - handle.get_thrust_policy(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(edge_partition_frontier_majors.size()), - frontier_indices.begin(), - [threshold = compute_valid_local_nbr_count_inclusive_sum_local_degree_threshold, - local_degrees = raft::device_span( - edge_partition_local_degrees.data(), - edge_partition_local_degrees.size())] __device__(size_t i) { - return local_degrees[i] >= threshold; - })), - handle.get_stream()); - - auto low_last = thrust::partition( - handle.get_thrust_policy(), - frontier_indices.begin(), - frontier_indices.end(), - [threshold = compute_valid_local_nbr_count_inclusive_sum_mid_local_degree_threshold, - local_degrees = - raft::device_span(edge_partition_local_degrees.data(), - edge_partition_local_degrees.size())] __device__(size_t i) { - return local_degrees[i] < threshold; - }); - auto mid_last = thrust::partition( - handle.get_thrust_policy(), - low_last, - frontier_indices.end(), - [threshold = compute_valid_local_nbr_count_inclusive_sum_high_local_degree_threshold, - local_degrees = - raft::device_span(edge_partition_local_degrees.data(), - edge_partition_local_degrees.size())] __device__(size_t i) { - return local_degrees[i] < threshold; - }); - - rmm::device_uvector inclusive_sums(offsets.back_element(handle.get_stream()), - handle.get_stream()); - - thrust::for_each( - handle.get_thrust_policy(), - frontier_indices.begin(), - low_last, - [edge_partition, - edge_partition_e_mask, - edge_partition_frontier_majors, - offsets = raft::device_span(offsets.data(), offsets.size()), - inclusive_sums = raft::device_span(inclusive_sums.data(), - inclusive_sums.size())] __device__(size_t i) { - auto major = edge_partition_frontier_majors[i]; - vertex_t major_idx{}; - if constexpr (multi_gpu) { - major_idx = 
*(edge_partition.major_idx_from_major_nocheck(major)); - } else { - major_idx = edge_partition.major_offset_from_major_nocheck(major); - } - auto edge_offset = edge_partition.local_offset(major_idx); - auto local_degree = edge_partition.local_degree(major_idx); - edge_t sum{0}; - auto start_offset = offsets[i]; - auto end_offset = offsets[i + 1]; - for (size_t j = 0; j < end_offset - start_offset; ++j) { - sum += count_set_bits( - edge_partition_e_mask.value_first(), - edge_offset + packed_bools_per_word() * j, - cuda::std::min(packed_bools_per_word(), local_degree - packed_bools_per_word() * j)); - inclusive_sums[start_offset + j] = sum; - } - }); - - if (thrust::distance(low_last, mid_last) > 0) { - raft::grid_1d_warp_t update_grid(thrust::distance(low_last, mid_last), - per_v_random_select_transform_outgoing_e_block_size, - handle.get_device_properties().maxGridSize[0]); - compute_valid_local_nbr_inclusive_sums_mid_local_degree<<>>( - edge_partition, - edge_partition_e_mask, - edge_partition_frontier_majors, - raft::device_span(offsets.data(), offsets.size()), - raft::device_span(low_last, thrust::distance(low_last, mid_last)), - raft::device_span(inclusive_sums.data(), inclusive_sums.size())); - } - - if (thrust::distance(mid_last, frontier_indices.end()) > 0) { - raft::grid_1d_block_t update_grid(thrust::distance(mid_last, frontier_indices.end()), - per_v_random_select_transform_outgoing_e_block_size, - handle.get_device_properties().maxGridSize[0]); - compute_valid_local_nbr_inclusive_sums_high_local_degree<<>>( - edge_partition, - edge_partition_e_mask, - edge_partition_frontier_majors, - raft::device_span(offsets.data(), offsets.size()), - raft::device_span(mid_last, thrust::distance(mid_last, frontier_indices.end())), - raft::device_span(inclusive_sums.data(), inclusive_sums.size())); - } - - return std::make_tuple(std::move(offsets), std::move(inclusive_sums)); -} - -template -rmm::device_uvector get_sampling_index_without_replacement( - raft::handle_t const& handle, - rmm::device_uvector&& frontier_degrees, - raft::random::RngState& rng_state, - size_t K) -{ -#ifndef NO_CUGRAPH_OPS - edge_t mid_partition_degree_range_last = static_cast(K * 10); // tuning parameter - assert(mid_partition_degree_range_last > K); - size_t high_partition_over_sampling_K = K * 2; // tuning parameter - assert(high_partition_over_sampling_K > K); - - rmm::device_uvector sample_nbr_indices(frontier_degrees.size() * K, handle.get_stream()); - - rmm::device_uvector seed_indices(frontier_degrees.size(), handle.get_stream()); - thrust::sequence(handle.get_thrust_policy(), seed_indices.begin(), seed_indices.end(), size_t{0}); - auto low_first = - thrust::make_zip_iterator(thrust::make_tuple(frontier_degrees.begin(), seed_indices.begin())); - auto mid_first = thrust::partition( - handle.get_thrust_policy(), - low_first, - low_first + frontier_degrees.size(), - [K] __device__(auto pair) { return thrust::get<0>(pair) <= static_cast(K); }); - auto low_partition_size = static_cast(thrust::distance(low_first, mid_first)); - auto high_first = - thrust::partition(handle.get_thrust_policy(), - mid_first, - mid_first + (frontier_degrees.size() - low_partition_size), - [mid_partition_degree_range_last] __device__(auto pair) { - return thrust::get<0>(pair) < mid_partition_degree_range_last; - }); - auto mid_partition_size = static_cast(thrust::distance(mid_first, high_first)); - auto high_partition_size = frontier_degrees.size() - (low_partition_size + mid_partition_size); - - if (low_partition_size > 0) { - 
thrust::for_each(handle.get_thrust_policy(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(low_partition_size * K), - [K, - low_first, - sample_nbr_indices = sample_nbr_indices.data(), - invalid_idx = cugraph::ops::graph::INVALID_ID] __device__(size_t i) { - auto pair = *(low_first + (i / K)); - auto degree = thrust::get<0>(pair); - auto seed_idx = thrust::get<1>(pair); - auto sample_idx = static_cast(i % K); - sample_nbr_indices[seed_idx * K + sample_idx] = - (sample_idx < degree) ? sample_idx : invalid_idx; - }); - } - - if (mid_partition_size > 0) { - rmm::device_uvector tmp_sample_nbr_indices(mid_partition_size * K, handle.get_stream()); - // FIXME: we can avoid the follow-up copy if get_sampling_index takes output offsets for - // sampling output - cugraph::ops::graph::get_sampling_index(tmp_sample_nbr_indices.data(), - rng_state, - thrust::get<0>(mid_first.get_iterator_tuple()), - mid_partition_size, - static_cast(K), - false, - handle.get_stream()); - thrust::for_each(handle.get_thrust_policy(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(mid_partition_size * K), - [K, - seed_index_first = thrust::get<1>(mid_first.get_iterator_tuple()), - tmp_sample_nbr_indices = tmp_sample_nbr_indices.data(), - sample_nbr_indices = sample_nbr_indices.data()] __device__(size_t i) { - auto seed_idx = *(seed_index_first + i / K); - auto sample_idx = static_cast(i % K); - sample_nbr_indices[seed_idx * K + sample_idx] = tmp_sample_nbr_indices[i]; - }); - } - - if (high_partition_size > 0) { - // to limit memory footprint ((1 << 20) is a tuning parameter), std::max for forward progress - // guarantee when high_partition_over_sampling_K is exorbitantly large - auto seeds_to_sort_per_iteration = - std::max(static_cast(handle.get_device_properties().multiProcessorCount * (1 << 20)) / - high_partition_over_sampling_K, - size_t{1}); - - rmm::device_uvector tmp_sample_nbr_indices( - seeds_to_sort_per_iteration * high_partition_over_sampling_K, handle.get_stream()); - assert(high_partition_over_sampling_K * 2 <= - static_cast(std::numeric_limits::max())); - rmm::device_uvector tmp_sample_indices( - seeds_to_sort_per_iteration * high_partition_over_sampling_K, - handle.get_stream()); // sample indices within a segment (one partition per seed) - - rmm::device_uvector segment_sorted_tmp_sample_nbr_indices( - seeds_to_sort_per_iteration * high_partition_over_sampling_K, handle.get_stream()); - rmm::device_uvector segment_sorted_tmp_sample_indices( - seeds_to_sort_per_iteration * high_partition_over_sampling_K, handle.get_stream()); - - rmm::device_uvector d_tmp_storage(0, handle.get_stream()); - size_t tmp_storage_bytes{0}; - - auto num_chunks = - (high_partition_size + seeds_to_sort_per_iteration - 1) / seeds_to_sort_per_iteration; - for (size_t i = 0; i < num_chunks; ++i) { - size_t num_segments = std::min(seeds_to_sort_per_iteration, - high_partition_size - seeds_to_sort_per_iteration * i); - - rmm::device_uvector unique_counts(num_segments, handle.get_stream()); - - std::optional> retry_segment_indices{std::nullopt}; - std::optional> retry_degrees{std::nullopt}; - std::optional> retry_sample_nbr_indices{std::nullopt}; - std::optional> retry_sample_indices{std::nullopt}; - std::optional> retry_segment_sorted_sample_nbr_indices{ - std::nullopt}; - std::optional> retry_segment_sorted_sample_indices{std::nullopt}; - while (true) { - auto segment_degree_first = - thrust::get<0>(high_first.get_iterator_tuple()) + seeds_to_sort_per_iteration * i; - - if 
(retry_segment_indices) { - retry_degrees = - rmm::device_uvector((*retry_segment_indices).size(), handle.get_stream()); - thrust::transform( - handle.get_thrust_policy(), - (*retry_segment_indices).begin(), - (*retry_segment_indices).end(), - (*retry_degrees).begin(), - indirection_t{segment_degree_first}); - retry_sample_nbr_indices = rmm::device_uvector( - (*retry_segment_indices).size() * high_partition_over_sampling_K, handle.get_stream()); - retry_sample_indices = rmm::device_uvector( - (*retry_segment_indices).size() * high_partition_over_sampling_K, handle.get_stream()); - retry_segment_sorted_sample_nbr_indices = rmm::device_uvector( - (*retry_segment_indices).size() * high_partition_over_sampling_K, handle.get_stream()); - retry_segment_sorted_sample_indices = rmm::device_uvector( - (*retry_segment_indices).size() * high_partition_over_sampling_K, handle.get_stream()); - } - - cugraph::ops::graph::get_sampling_index( - retry_segment_indices ? (*retry_sample_nbr_indices).data() - : tmp_sample_nbr_indices.data(), - rng_state, - retry_segment_indices ? (*retry_degrees).begin() : segment_degree_first, - retry_segment_indices ? (*retry_degrees).size() : num_segments, - static_cast(high_partition_over_sampling_K), - true, - handle.get_stream()); - - if (retry_segment_indices) { - thrust::for_each( - handle.get_thrust_policy(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator((*retry_segment_indices).size() * - high_partition_over_sampling_K), - [high_partition_over_sampling_K, - unique_counts = unique_counts.data(), - segment_sorted_tmp_sample_nbr_indices = segment_sorted_tmp_sample_nbr_indices.data(), - retry_segment_indices = (*retry_segment_indices).data(), - retry_sample_nbr_indices = (*retry_sample_nbr_indices).data(), - retry_sample_indices = (*retry_sample_indices).data()] __device__(size_t i) { - auto segment_idx = retry_segment_indices[i / high_partition_over_sampling_K]; - auto sample_idx = static_cast(i % high_partition_over_sampling_K); - auto unique_count = unique_counts[segment_idx]; - auto output_first = thrust::make_zip_iterator( - thrust::make_tuple(retry_sample_nbr_indices, retry_sample_indices)); - // sample index for the previously selected neighbor indices should be smaller than - // the new candidates to ensure that the previously selected neighbor indices will be - // selected again - if (sample_idx < unique_count) { - *(output_first + i) = - thrust::make_tuple(segment_sorted_tmp_sample_nbr_indices - [segment_idx * high_partition_over_sampling_K + sample_idx], - static_cast(sample_idx)); - } else { - *(output_first + i) = - thrust::make_tuple(retry_sample_nbr_indices[i], - high_partition_over_sampling_K + (sample_idx - unique_count)); - } - }); - } else { - thrust::tabulate( - handle.get_thrust_policy(), - tmp_sample_indices.begin(), - tmp_sample_indices.begin() + num_segments * high_partition_over_sampling_K, - [high_partition_over_sampling_K] __device__(size_t i) { - return static_cast(i % high_partition_over_sampling_K); - }); - } - - // sort the (sample neighbor index, sample index) pairs (key: sample neighbor index) - - cub::DeviceSegmentedSort::SortPairs( - static_cast(nullptr), - tmp_storage_bytes, - retry_segment_indices ? (*retry_sample_nbr_indices).data() - : tmp_sample_nbr_indices.data(), - retry_segment_indices ? (*retry_segment_sorted_sample_nbr_indices).data() - : segment_sorted_tmp_sample_nbr_indices.data(), - retry_segment_indices ? (*retry_sample_indices).data() : tmp_sample_indices.data(), - retry_segment_indices ? 
(*retry_segment_sorted_sample_indices).data() - : segment_sorted_tmp_sample_indices.data(), - (retry_segment_indices ? (*retry_segment_indices).size() : num_segments) * - high_partition_over_sampling_K, - retry_segment_indices ? (*retry_segment_indices).size() : num_segments, - thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{0}), - multiplier_t{high_partition_over_sampling_K}), - thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{1}), - multiplier_t{high_partition_over_sampling_K}), - handle.get_stream()); - if (tmp_storage_bytes > d_tmp_storage.size()) { - d_tmp_storage = rmm::device_uvector(tmp_storage_bytes, handle.get_stream()); - } - cub::DeviceSegmentedSort::SortPairs( - d_tmp_storage.data(), - tmp_storage_bytes, - retry_segment_indices ? (*retry_sample_nbr_indices).data() - : tmp_sample_nbr_indices.data(), - retry_segment_indices ? (*retry_segment_sorted_sample_nbr_indices).data() - : segment_sorted_tmp_sample_nbr_indices.data(), - retry_segment_indices ? (*retry_sample_indices).data() : tmp_sample_indices.data(), - retry_segment_indices ? (*retry_segment_sorted_sample_indices).data() - : segment_sorted_tmp_sample_indices.data(), - (retry_segment_indices ? (*retry_segment_indices).size() : num_segments) * - high_partition_over_sampling_K, - retry_segment_indices ? (*retry_segment_indices).size() : num_segments, - thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{0}), - multiplier_t{high_partition_over_sampling_K}), - thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{1}), - multiplier_t{high_partition_over_sampling_K}), - handle.get_stream()); - - // count the number of unique neighbor indices - - if (retry_segment_indices) { - thrust::for_each( - handle.get_thrust_policy(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator((*retry_segment_indices).size()), - [high_partition_over_sampling_K, - unique_counts = unique_counts.data(), - retry_segment_indices = (*retry_segment_indices).data(), - retry_segment_sorted_pair_first = thrust::make_zip_iterator( - thrust::make_tuple((*retry_segment_sorted_sample_nbr_indices).begin(), - (*retry_segment_sorted_sample_indices).begin())), - segment_sorted_pair_first = thrust::make_zip_iterator(thrust::make_tuple( - segment_sorted_tmp_sample_nbr_indices.begin(), - segment_sorted_tmp_sample_indices.begin()))] __device__(size_t i) { - auto unique_count = static_cast(thrust::distance( - retry_segment_sorted_pair_first + high_partition_over_sampling_K * i, - thrust::unique( - thrust::seq, - retry_segment_sorted_pair_first + high_partition_over_sampling_K * i, - retry_segment_sorted_pair_first + high_partition_over_sampling_K * (i + 1), - [] __device__(auto lhs, auto rhs) { - return thrust::get<0>(lhs) == thrust::get<0>(rhs); - }))); - auto segment_idx = retry_segment_indices[i]; - unique_counts[segment_idx] = unique_count; - thrust::copy( - thrust::seq, - retry_segment_sorted_pair_first + high_partition_over_sampling_K * i, - retry_segment_sorted_pair_first + high_partition_over_sampling_K * i + unique_count, - segment_sorted_pair_first + high_partition_over_sampling_K * segment_idx); - }); - } else { - thrust::tabulate( - handle.get_thrust_policy(), - unique_counts.begin(), - unique_counts.end(), - [high_partition_over_sampling_K, - segment_sorted_pair_first = thrust::make_zip_iterator(thrust::make_tuple( - segment_sorted_tmp_sample_nbr_indices.begin(), - segment_sorted_tmp_sample_indices.begin()))] __device__(size_t i) { - return 
static_cast(thrust::distance( - segment_sorted_pair_first + high_partition_over_sampling_K * i, - thrust::unique(thrust::seq, - segment_sorted_pair_first + high_partition_over_sampling_K * i, - segment_sorted_pair_first + high_partition_over_sampling_K * (i + 1), - [] __device__(auto lhs, auto rhs) { - return thrust::get<0>(lhs) == thrust::get<0>(rhs); - }))); - }); - } - - auto num_retry_segments = - thrust::count_if(handle.get_thrust_policy(), - unique_counts.begin(), - unique_counts.end(), - [K] __device__(auto count) { return count < K; }); - if (num_retry_segments > 0) { - retry_segment_indices = - rmm::device_uvector(num_retry_segments, handle.get_stream()); - thrust::copy_if(handle.get_thrust_policy(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(num_segments), - (*retry_segment_indices).begin(), - [K, unique_counts = unique_counts.data()] __device__(size_t i) { - return unique_counts[i] < K; - }); - } else { - break; - } - } - - // sort the segment-sorted (sample index, sample neighbor index) pairs (key: sample index) - - cub::DeviceSegmentedSort::SortPairs( - static_cast(nullptr), - tmp_storage_bytes, - segment_sorted_tmp_sample_indices.data(), - tmp_sample_indices.data(), - segment_sorted_tmp_sample_nbr_indices.data(), - tmp_sample_nbr_indices.data(), - num_segments * high_partition_over_sampling_K, - num_segments, - thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{0}), - multiplier_t{high_partition_over_sampling_K}), - thrust::make_transform_iterator( - thrust::make_counting_iterator(size_t{0}), - cuda::proclaim_return_type( - [high_partition_over_sampling_K, unique_counts = unique_counts.data()] __device__( - size_t i) { return i * high_partition_over_sampling_K + unique_counts[i]; })), - handle.get_stream()); - if (tmp_storage_bytes > d_tmp_storage.size()) { - d_tmp_storage = rmm::device_uvector(tmp_storage_bytes, handle.get_stream()); - } - cub::DeviceSegmentedSort::SortPairs( - d_tmp_storage.data(), - tmp_storage_bytes, - segment_sorted_tmp_sample_indices.data(), - tmp_sample_indices.data(), - segment_sorted_tmp_sample_nbr_indices.data(), - tmp_sample_nbr_indices.data(), - num_segments * high_partition_over_sampling_K, - num_segments, - thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{0}), - multiplier_t{high_partition_over_sampling_K}), - thrust::make_transform_iterator( - thrust::make_counting_iterator(size_t{0}), - cuda::proclaim_return_type( - [high_partition_over_sampling_K, unique_counts = unique_counts.data()] __device__( - size_t i) { return i * high_partition_over_sampling_K + unique_counts[i]; })), - handle.get_stream()); - - // copy the neighbor indices back to sample_nbr_indices - - thrust::for_each( - handle.get_thrust_policy(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(num_segments * K), - [K, - high_partition_over_sampling_K, - seed_indices = - thrust::get<1>(high_first.get_iterator_tuple()) + seeds_to_sort_per_iteration * i, - tmp_sample_nbr_indices = tmp_sample_nbr_indices.data(), - sample_nbr_indices = sample_nbr_indices.data()] __device__(size_t i) { - auto seed_idx = *(seed_indices + i / K); - auto sample_idx = static_cast(i % K); - *(sample_nbr_indices + seed_idx * K + sample_idx) = - *(tmp_sample_nbr_indices + (i / K) * high_partition_over_sampling_K + sample_idx); - }); - } - } - - frontier_degrees.resize(0, handle.get_stream()); - frontier_degrees.shrink_to_fit(handle.get_stream()); - - return sample_nbr_indices; -#else - 
CUGRAPH_FAIL("unimplemented."); -#endif -} - template std::tuple>, @@ -961,6 +209,7 @@ per_v_random_select_transform_e(raft::handle_t const& handle, EdgeSrcValueInputWrapper edge_src_value_input, EdgeDstValueInputWrapper edge_dst_value_input, EdgeValueInputWrapper edge_value_input, + EdgeBiasOp e_bias_op, EdgeOp e_op, raft::random::RngState& rng_state, size_t K, @@ -1017,6 +266,7 @@ per_v_random_select_transform_e(raft::handle_t const& handle, GraphViewType::is_multi_gpu ? handle.get_subcomm(cugraph::partition_manager::minor_comm_name()).get_size() : int{1}; + assert(graph_view.number_of_local_edge_partitions() == minor_comm_size); if (do_expensive_check) { // FIXME: better re-factor this check function? @@ -1044,19 +294,15 @@ per_v_random_select_transform_e(raft::handle_t const& handle, "Invalid input argument: frontier includes out-of-range keys."); } - auto frontier_key_first = frontier.begin(); - auto frontier_key_last = frontier.end(); - std::vector local_frontier_sizes{}; if (minor_comm_size > 1) { auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); local_frontier_sizes = host_scalar_allgather( minor_comm, - static_cast(thrust::distance(frontier_key_first, frontier_key_last)), + frontier.size(), handle.get_stream()); } else { - local_frontier_sizes = std::vector{static_cast( - static_cast(thrust::distance(frontier_key_first, frontier_key_last)))}; + local_frontier_sizes = std::vector{frontier.size()}; } std::vector local_frontier_displacements(local_frontier_sizes.size()); std::exclusive_scan(local_frontier_sizes.begin(), @@ -1066,7 +312,7 @@ per_v_random_select_transform_e(raft::handle_t const& handle, // 1. aggregate frontier - auto aggregate_local_frontier_keys = + auto aggregate_local_frontier = (minor_comm_size > 1) ? std::make_optional( local_frontier_displacements.back() + local_frontier_sizes.back(), handle.get_stream()) @@ -1074,257 +320,50 @@ per_v_random_select_transform_e(raft::handle_t const& handle, if (minor_comm_size > 1) { auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); device_allgatherv(minor_comm, - frontier_key_first, - get_dataframe_buffer_begin(*aggregate_local_frontier_keys), + frontier.begin(), + get_dataframe_buffer_begin(*aggregate_local_frontier), local_frontier_sizes, local_frontier_displacements, handle.get_stream()); } - // 2. compute degrees - - auto edge_mask_view = graph_view.edge_mask_view(); - - auto aggregate_local_frontier_local_degrees = - (minor_comm_size > 1) - ? std::make_optional>( - local_frontier_displacements.back() + local_frontier_sizes.back(), handle.get_stream()) - : std::nullopt; - rmm::device_uvector frontier_degrees(frontier.size(), handle.get_stream()); - - std::optional, rmm::device_uvector>>> - local_frontier_valid_local_nbr_count_inclusive_sums{}; // to avoid searching the entire - // neighbor list K times for high degree - // vertices with edge masking - if (edge_mask_view) { - local_frontier_valid_local_nbr_count_inclusive_sums = - std::vector, rmm::device_uvector>>{}; - (*local_frontier_valid_local_nbr_count_inclusive_sums) - .reserve(graph_view.number_of_local_edge_partitions()); - } - + // 2. 
randomly select neighbor indices and compute local neighbor indices for every local edge + // partition + + auto [sample_local_nbr_indices, sample_key_indices, local_frontier_sample_offsets] = + uniform_sample_and_compute_local_nbr_indices(handle, + graph_view, + frontier, + aggregate_local_frontier, + local_frontier_displacements, + local_frontier_sizes, + rng_state, + K, + with_replacement); + + std::vector local_frontier_sample_counts(minor_comm_size); + std::adjacent_difference(local_frontier_sample_offsets.begin() + 1, + local_frontier_sample_offsets.end(), + local_frontier_sample_counts.begin()); + + // 3. transform + + auto sample_e_op_results = + allocate_dataframe_buffer(local_frontier_sample_offsets.back(), handle.get_stream()); for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { auto edge_partition = edge_partition_device_view_t( graph_view.local_edge_partition_view(i)); - auto edge_partition_e_mask = - edge_mask_view - ? thrust::make_optional< - detail::edge_partition_edge_property_device_view_t>( - *edge_mask_view, i) - : thrust::nullopt; - - vertex_t const* edge_partition_frontier_major_first{nullptr}; auto edge_partition_frontier_key_first = - ((minor_comm_size > 1) ? get_dataframe_buffer_begin(*aggregate_local_frontier_keys) - : frontier_key_first) + - local_frontier_displacements[i]; - if constexpr (std::is_same_v) { - edge_partition_frontier_major_first = edge_partition_frontier_key_first; - } else { - edge_partition_frontier_major_first = thrust::get<0>(edge_partition_frontier_key_first); - } - - auto edge_partition_frontier_local_degrees = - edge_partition_e_mask ? edge_partition.compute_local_degrees_with_mask( - (*edge_partition_e_mask).value_first(), - edge_partition_frontier_major_first, - edge_partition_frontier_major_first + local_frontier_sizes[i], - handle.get_stream()) - : edge_partition.compute_local_degrees( - edge_partition_frontier_major_first, - edge_partition_frontier_major_first + local_frontier_sizes[i], - handle.get_stream()); - - if (minor_comm_size > 1) { - // FIXME: this copy is unnecessary if edge_partition.compute_local_degrees() takes a pointer - // to the output array - thrust::copy( - handle.get_thrust_policy(), - edge_partition_frontier_local_degrees.begin(), - edge_partition_frontier_local_degrees.end(), - (*aggregate_local_frontier_local_degrees).begin() + local_frontier_displacements[i]); - } else { - frontier_degrees = std::move(edge_partition_frontier_local_degrees); - } - - if (edge_partition_e_mask) { - (*local_frontier_valid_local_nbr_count_inclusive_sums) - .push_back(compute_valid_local_nbr_count_inclusive_sums( - handle, - edge_partition, - *edge_partition_e_mask, - raft::device_span(edge_partition_frontier_major_first, - local_frontier_sizes[i]))); - } - } - - auto frontier_partitioned_local_degree_displacements = - (minor_comm_size > 1) - ? 
std::make_optional>(size_t{0}, handle.get_stream()) - : std::nullopt; // one partition per gpu in the same minor_comm - if (minor_comm_size > 1) { - auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); - - rmm::device_uvector frontier_gathered_local_degrees(0, handle.get_stream()); - std::tie(frontier_gathered_local_degrees, std::ignore) = - shuffle_values(minor_comm, - (*aggregate_local_frontier_local_degrees).begin(), - local_frontier_sizes, - handle.get_stream()); - aggregate_local_frontier_local_degrees = std::nullopt; - frontier_partitioned_local_degree_displacements = - rmm::device_uvector(frontier_degrees.size() * minor_comm_size, handle.get_stream()); - thrust::for_each( - handle.get_thrust_policy(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(frontier_degrees.size()), - compute_local_degree_displacements_and_global_degree_t{ - raft::device_span(frontier_gathered_local_degrees.data(), - frontier_gathered_local_degrees.size()), - raft::device_span((*frontier_partitioned_local_degree_displacements).data(), - (*frontier_partitioned_local_degree_displacements).size()), - raft::device_span(frontier_degrees.data(), frontier_degrees.size()), - minor_comm_size}); - } - - // 3. randomly select neighbor indices - - rmm::device_uvector sample_nbr_indices(0, handle.get_stream()); - if (with_replacement) { - if (frontier_degrees.size() > 0) { - sample_nbr_indices.resize(frontier.size() * K, handle.get_stream()); - cugraph::ops::graph::get_sampling_index(sample_nbr_indices.data(), - rng_state, - frontier_degrees.data(), - static_cast(frontier_degrees.size()), - static_cast(K), - with_replacement, - handle.get_stream()); - frontier_degrees.resize(0, handle.get_stream()); - frontier_degrees.shrink_to_fit(handle.get_stream()); - } - } else { - sample_nbr_indices = - get_sampling_index_without_replacement(handle, std::move(frontier_degrees), rng_state, K); - } - - // 4. 
shuffle randomly selected indices - - auto sample_local_nbr_indices = std::move( - sample_nbr_indices); // neighbor index within an edge partition (note that each vertex's - // neighbors are distributed in minor_comm_size partitions) - std::optional> sample_key_indices{ - std::nullopt}; // relevant only when (minor_comm_size > 1) - auto local_frontier_sample_counts = std::vector{}; - auto local_frontier_sample_displacements = std::vector{}; - if (minor_comm_size > 1) { - auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); - - sample_key_indices = - rmm::device_uvector(sample_local_nbr_indices.size(), handle.get_stream()); - auto minor_comm_ranks = - rmm::device_uvector(sample_local_nbr_indices.size(), handle.get_stream()); - auto intra_partition_displacements = - rmm::device_uvector(sample_local_nbr_indices.size(), handle.get_stream()); - rmm::device_uvector d_tx_counts(minor_comm_size, handle.get_stream()); - thrust::fill(handle.get_thrust_policy(), d_tx_counts.begin(), d_tx_counts.end(), size_t{0}); - auto input_pair_first = thrust::make_zip_iterator( - thrust::make_tuple(sample_local_nbr_indices.begin(), - thrust::make_transform_iterator(thrust::make_counting_iterator(size_t{0}), - divider_t{K}))); - thrust::transform( - handle.get_thrust_policy(), - input_pair_first, - input_pair_first + sample_local_nbr_indices.size(), - thrust::make_zip_iterator(thrust::make_tuple(minor_comm_ranks.begin(), - intra_partition_displacements.begin(), - sample_local_nbr_indices.begin(), - (*sample_key_indices).begin())), - convert_pair_to_quadruplet_t{ - raft::device_span((*frontier_partitioned_local_degree_displacements).data(), - (*frontier_partitioned_local_degree_displacements).size()), - raft::device_span(d_tx_counts.data(), d_tx_counts.size()), - frontier.size(), - minor_comm_size, - cugraph::ops::graph::INVALID_ID}); - rmm::device_uvector tx_displacements(minor_comm_size, handle.get_stream()); - thrust::exclusive_scan( - handle.get_thrust_policy(), d_tx_counts.begin(), d_tx_counts.end(), tx_displacements.begin()); - auto tmp_sample_local_nbr_indices = - rmm::device_uvector(tx_displacements.back_element(handle.get_stream()) + - d_tx_counts.back_element(handle.get_stream()), - handle.get_stream()); - auto tmp_sample_key_indices = - rmm::device_uvector(tmp_sample_local_nbr_indices.size(), handle.get_stream()); - auto pair_first = thrust::make_zip_iterator( - thrust::make_tuple(sample_local_nbr_indices.begin(), (*sample_key_indices).begin())); - thrust::scatter_if( - handle.get_thrust_policy(), - pair_first, - pair_first + sample_local_nbr_indices.size(), - thrust::make_transform_iterator( - thrust::make_counting_iterator(size_t{0}), - shuffle_index_compute_offset_t{ - raft::device_span(minor_comm_ranks.data(), minor_comm_ranks.size()), - raft::device_span(intra_partition_displacements.data(), - intra_partition_displacements.size()), - raft::device_span(tx_displacements.data(), tx_displacements.size())}), - minor_comm_ranks.begin(), - thrust::make_zip_iterator( - thrust::make_tuple(tmp_sample_local_nbr_indices.begin(), tmp_sample_key_indices.begin())), - is_not_equal_t{-1}); - - sample_local_nbr_indices = std::move(tmp_sample_local_nbr_indices); - sample_key_indices = std::move(tmp_sample_key_indices); - - std::vector h_tx_counts(d_tx_counts.size()); - raft::update_host( - h_tx_counts.data(), d_tx_counts.data(), d_tx_counts.size(), handle.get_stream()); - handle.sync_stream(); - - pair_first = thrust::make_zip_iterator( - 
thrust::make_tuple(sample_local_nbr_indices.begin(), (*sample_key_indices).begin())); - auto [rx_value_buffer, rx_counts] = - shuffle_values(minor_comm, pair_first, h_tx_counts, handle.get_stream()); - - sample_local_nbr_indices = std::move(std::get<0>(rx_value_buffer)); - sample_key_indices = std::move(std::get<1>(rx_value_buffer)); - local_frontier_sample_displacements = std::vector(rx_counts.size()); - std::exclusive_scan( - rx_counts.begin(), rx_counts.end(), local_frontier_sample_displacements.begin(), size_t{0}); - local_frontier_sample_counts = std::move(rx_counts); - } else { - local_frontier_sample_counts.push_back(frontier.size() * K); - local_frontier_sample_displacements.push_back(size_t{0}); - } - - // 5. transform - - auto sample_e_op_results = allocate_dataframe_buffer( - local_frontier_sample_displacements.back() + local_frontier_sample_counts.back(), - handle.get_stream()); - for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { - auto edge_partition = - edge_partition_device_view_t( - graph_view.local_edge_partition_view(i)); - auto edge_partition_e_mask = - edge_mask_view - ? thrust::make_optional< - detail::edge_partition_edge_property_device_view_t>( - *edge_mask_view, i) - : thrust::nullopt; - - auto edge_partition_frontier_key_first = - ((minor_comm_size > 1) ? get_dataframe_buffer_begin(*aggregate_local_frontier_keys) - : frontier_key_first) + + ((minor_comm_size > 1) ? get_dataframe_buffer_begin(*aggregate_local_frontier) + : frontier.begin()) + local_frontier_displacements[i]; auto edge_partition_sample_local_nbr_index_first = - sample_local_nbr_indices.begin() + local_frontier_sample_displacements[i]; + sample_local_nbr_indices.begin() + local_frontier_sample_offsets[i]; auto edge_partition_sample_e_op_result_first = - get_dataframe_buffer_begin(sample_e_op_results) + local_frontier_sample_displacements[i]; + get_dataframe_buffer_begin(sample_e_op_results) + local_frontier_sample_offsets[i]; edge_partition_src_input_device_view_t edge_partition_src_value_input{}; edge_partition_dst_input_device_view_t edge_partition_dst_value_input{}; @@ -1341,85 +380,60 @@ per_v_random_select_transform_e(raft::handle_t const& handle, if (minor_comm_size > 1) { auto edge_partition_sample_key_index_first = - (*sample_key_indices).begin() + local_frontier_sample_displacements[i]; - thrust::for_each( + (*sample_key_indices).begin() + local_frontier_sample_offsets[i]; + thrust::transform( handle.get_thrust_policy(), thrust::make_counting_iterator(size_t{0}), thrust::make_counting_iterator(local_frontier_sample_counts[i]), + edge_partition_sample_e_op_result_first, transform_local_nbr_indices_t{ edge_partition, thrust::make_optional(edge_partition_sample_key_index_first), edge_partition_frontier_key_first, edge_partition_sample_local_nbr_index_first, - edge_partition_sample_e_op_result_first, edge_partition_src_value_input, edge_partition_dst_value_input, edge_partition_e_value_input, - edge_partition_e_mask, - local_frontier_valid_local_nbr_count_inclusive_sums - ? 
-                raft::device_span<size_t const>(
-                  std::get<0>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).data(),
-                  std::get<0>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).size()),
-                raft::device_span<edge_t const>(
-                  std::get<1>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).data(),
-                  std::get<1>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).size())))
-            : thrust::nullopt,
           e_op,
           cugraph::ops::graph::INVALID_ID<edge_t>,
           to_thrust_optional(invalid_value),
           K});
     } else {
-      thrust::for_each(
+      thrust::transform(
         handle.get_thrust_policy(),
         thrust::make_counting_iterator(size_t{0}),
         thrust::make_counting_iterator(frontier.size() * K),
+        edge_partition_sample_e_op_result_first,
         transform_local_nbr_indices_t{
-          edge_partition,
-          thrust::nullopt,
-          edge_partition_frontier_key_first,
-          edge_partition_sample_local_nbr_index_first,
-          edge_partition_sample_e_op_result_first,
-          edge_partition_src_value_input,
-          edge_partition_dst_value_input,
-          edge_partition_e_value_input,
-          edge_partition_e_mask,
-          local_frontier_valid_local_nbr_count_inclusive_sums
-            ? thrust::make_optional(thrust::make_tuple(
-                raft::device_span<size_t const>(
-                  std::get<0>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).data(),
-                  std::get<0>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).size()),
-                raft::device_span<edge_t const>(
-                  std::get<1>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).data(),
-                  std::get<1>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).size())))
-            : thrust::nullopt,
-          e_op,
-          cugraph::ops::graph::INVALID_ID<edge_t>,
-          to_thrust_optional(invalid_value),
-          K});
+          T>{edge_partition,
+             thrust::nullopt,
+             edge_partition_frontier_key_first,
+             edge_partition_sample_local_nbr_index_first,
+             edge_partition_src_value_input,
+             edge_partition_dst_value_input,
+             edge_partition_e_value_input,
+             e_op,
+             cugraph::ops::graph::INVALID_ID<edge_t>,
+             to_thrust_optional(invalid_value),
+             K});
     }
   }
 
-  aggregate_local_frontier_keys = std::nullopt;
+  aggregate_local_frontier = std::nullopt;
 
-  // 6. shuffle randomly selected & transformed results and update sample_offsets
+  // 4. shuffle randomly selected & transformed results and update sample_offsets
 
   auto sample_offsets = invalid_value ? std::nullopt
                                       : std::make_optional<rmm::device_uvector<size_t>>(
@@ -1542,6 +556,8 @@ per_v_random_select_transform_e(raft::handle_t const& handle,
   return std::make_tuple(std::move(sample_offsets), std::move(sample_e_op_results));
 #else
   CUGRAPH_FAIL("unimplemented.");
+  return std::make_tuple(std::nullopt,
+                         allocate_dataframe_buffer<T>(size_t{0}, rmm::cuda_stream_view{}));
 #endif
 }
 
@@ -1579,12 +595,12 @@ per_v_random_select_transform_e(raft::handle_t const& handle,
 * to this process in multi-GPU). Use either cugraph::edge_property_t::view() (if @p e_op needs to
 * access edge property values) or cugraph::edge_dummy_property_t::view() (if @p e_op does not
 * access edge property values).
- * @param e_bias_op Quinary operator takes edge source, edge destination, property values for the
- * source, destination, and edge and returns a floating point bias value to be used in biased random
- * selection.
- * @param e_op Quinary operator takes edge source, edge destination, property values for the source,
- * destination, and edge and returns a value to be collected in the output. This function is called
- * only for the selected edges.
+ * @param e_bias_op Quinary operator takes (tagged-)edge source, edge destination, property values
+ * for the source, destination, and edge and returns a floating point bias value to be used in
+ * biased random selection.
+ * @param e_op Quinary operator takes (tagged-)edge source, edge destination, property values for
+ * the source, destination, and edge and returns a value to be collected in the output. This
+ * function is called only for the selected edges.
 * @param K Number of outgoing edges to select per (tagged-)vertex.
 * @param with_replacement A flag to specify whether a single outgoing edge can be selected multiple
 * times (if @p with_replacement = true) or can be selected only once (if @p with_replacement =
@@ -1616,7 +632,7 @@ per_v_random_select_transform_outgoing_e(raft::handle_t const& handle,
                                          VertexFrontierBucketType const& frontier,
                                          EdgeSrcValueInputWrapper edge_src_value_input,
                                          EdgeDstValueInputWrapper edge_dst_value_input,
-                                         EdgeValueInputWrapper egde_value_input,
+                                         EdgeValueInputWrapper edge_value_input,
                                          EdgeBiasOp e_bias_op,
                                          EdgeOp e_op,
                                          raft::random::RngState& rng_state,
@@ -1625,10 +641,19 @@ per_v_random_select_transform_outgoing_e(raft::handle_t const& handle,
                                          std::optional<T> invalid_value,
                                          bool do_expensive_check = false)
 {
-  CUGRAPH_FAIL("unimplemented.");
-
-  return std::make_tuple(std::nullopt,
-                         allocate_dataframe_buffer<T>(size_t{0}, rmm::cuda_stream_view{}));
+  return detail::per_v_random_select_transform_e(handle,
+                                                 graph_view,
+                                                 frontier,
+                                                 edge_src_value_input,
+                                                 edge_dst_value_input,
+                                                 edge_value_input,
+                                                 e_bias_op,
+                                                 e_op,
+                                                 rng_state,
+                                                 K,
+                                                 with_replacement,
+                                                 invalid_value,
+                                                 do_expensive_check);
 }
 
 /**
@@ -1664,9 +689,9 @@ per_v_random_select_transform_outgoing_e(raft::handle_t const& handle,
 * to this process in multi-GPU). Use either cugraph::edge_property_t::view() (if @p e_op needs to
 * access edge property values) or cugraph::edge_dummy_property_t::view() (if @p e_op does not
 * access edge property values).
- * @param e_op Quinary operator takes edge source, edge destination, property values for the source,
- * destination, and edge and returns a value to be collected in the output. This function is called
- * only for the selected edges.
+ * @param e_op Quinary operator takes (tagged-)edge source, edge destination, property values for
+ * the source, destination, and edge and returns a value to be collected in the output. This
+ * function is called only for the selected edges.
 * @param K Number of outgoing edges to select per (tagged-)vertex.
 * @param with_replacement A flag to specify whether a single outgoing edge can be selected multiple
 * times (if @p with_replacement = true) or can be selected only once (if @p with_replacement =
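For the doc comment above, a minimal sketch of a conforming e_op may help. The functor name is
illustrative; it assumes an untagged frontier and dummy src/dst/edge properties, in which case the
quinary operator receives thrust::nullopt_t placeholders (the pattern cuGraph's prims tests use):

#include <thrust/optional.h>
#include <thrust/tuple.h>

// Collect the (src, dst) endpoint pair of every sampled edge; the three trailing
// thrust::nullopt_t parameters stand for the unused src/dst/edge property values.
template <typename vertex_t>
struct return_edge_endpoints_t {
  __device__ thrust::tuple<vertex_t, vertex_t> operator()(
    vertex_t src, vertex_t dst, thrust::nullopt_t, thrust::nullopt_t, thrust::nullopt_t) const
  {
    return thrust::make_tuple(src, dst);
  }
};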
@@ -1705,18 +730,24 @@ per_v_random_select_transform_outgoing_e(raft::handle_t const& handle,
                                          std::optional<T> invalid_value,
                                          bool do_expensive_check = false)
 {
-  return detail::per_v_random_select_transform_e(handle,
-                                                 graph_view,
-                                                 frontier,
-                                                 edge_src_value_input,
-                                                 edge_dst_value_input,
-                                                 edge_value_input,
-                                                 e_op,
-                                                 rng_state,
-                                                 K,
-                                                 with_replacement,
-                                                 invalid_value,
-                                                 do_expensive_check);
+  return detail::per_v_random_select_transform_e(
+    handle,
+    graph_view,
+    frontier,
+    edge_src_value_input,
+    edge_dst_value_input,
+    edge_value_input,
+    detail::constant_e_bias_op_t{},
+    e_op,
+    rng_state,
+    K,
+    with_replacement,
+    invalid_value,
+    do_expensive_check);
 }
 
 }  // namespace cugraph

From 83abe0e91b571271659f2fb524e4fcab57825efa Mon Sep 17 00:00:00 2001
From: Seunghwa Kang
Date: Wed, 1 May 2024 14:39:11 -0700
Subject: [PATCH 03/53] add the thrust_tuple_get_or_identity utility function

---
 .../cugraph/utilities/thrust_tuple_utils.hpp  | 60 +++++++++++++++++--
 .../detail/extract_transform_v_frontier_e.cuh | 57 +++++-------------
 cpp/src/prims/detail/prim_functors.cuh        |  7 +--
 ...rm_reduce_v_frontier_outgoing_e_by_dst.cuh | 17 ++----
 cpp/src/prims/update_v_frontier.cuh           | 16 ++---
 5 files changed, 80 insertions(+), 77 deletions(-)

diff --git a/cpp/include/cugraph/utilities/thrust_tuple_utils.hpp b/cpp/include/cugraph/utilities/thrust_tuple_utils.hpp
index d98754f51d1..c69183085ba 100644
--- a/cpp/include/cugraph/utilities/thrust_tuple_utils.hpp
+++ b/cpp/include/cugraph/utilities/thrust_tuple_utils.hpp
@@ -17,6 +17,7 @@
 
 #include <thrust/tuple.h>
 
+#include
 #include <type_traits>
 #include <vector>
 
@@ -30,7 +31,7 @@ template <typename TupleType, size_t I, size_t N>
 struct is_thrust_tuple_of_arithemetic_impl {
   constexpr bool evaluate() const
   {
-    if (!std::is_arithmetic<typename thrust::tuple_element<I, TupleType>::type>::value) {
+    if (!std::is_arithmetic_v<typename thrust::tuple_element<I, TupleType>::type>) {
       return false;
     } else {
       return is_thrust_tuple_of_arithemetic_impl<TupleType, I + 1, N>().evaluate();
     }
@@ -123,19 +124,19 @@ struct is_arithmetic_vector : std::false_type {};
 template
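The trait touched in the hunk above checks a tuple's elements by compile-time recursion: test
element I, then recurse on I + 1 until the end index N. A standalone sketch of the same pattern
(hypothetical names; std::tuple stands in for thrust::tuple so the snippet compiles without CUDA):

#include <cstddef>
#include <tuple>
#include <type_traits>

// Recurse over elements I..N-1, failing fast on the first non-arithmetic element.
template <typename TupleType, std::size_t I, std::size_t N>
struct all_elements_arithmetic_impl {
  constexpr bool evaluate() const
  {
    if (!std::is_arithmetic_v<typename std::tuple_element<I, TupleType>::type>) {
      return false;
    } else {
      return all_elements_arithmetic_impl<TupleType, I + 1, N>().evaluate();
    }
  }
};

// Base case: every element has been checked.
template <typename TupleType, std::size_t I>
struct all_elements_arithmetic_impl<TupleType, I, I> {
  constexpr bool evaluate() const { return true; }
};

static_assert(all_elements_arithmetic_impl<std::tuple<int, float>, 0, 2>{}.evaluate());
static_assert(!all_elements_arithmetic_impl<std::tuple<int, int*>, 0, 2>{}.evaluate());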