diff --git a/cpp/src/c_api/neighbor_sampling.cpp b/cpp/src/c_api/neighbor_sampling.cpp index 5b05b54201..cc4c09ead2 100644 --- a/cpp/src/c_api/neighbor_sampling.cpp +++ b/cpp/src/c_api/neighbor_sampling.cpp @@ -875,19 +875,20 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { handle_.get_comms(), num_local_labels, handle_.get_stream()); std::exclusive_scan( - global_labels.begin(), global_labels.end(), global_labels.begin(), size_t{0}); - - // Compute the global start_vertex_label_offsets - cugraph::detail::transform_increment(handle_.get_stream(), - start_vertex_offsets_->as_type(), - start_vertex_offsets_->size_, - global_labels[handle_.get_comms().get_rank()]); + global_labels.begin(), global_labels.end(), global_labels.begin(), label_t{0}); // Retrieve the start_vertex_labels start_vertex_labels = cugraph::detail::convert_starting_vertex_offsets_to_labels( handle_, raft::device_span{start_vertex_offsets_->as_type(), start_vertex_offsets_->size_}); + + // Compute the global start_vertex_label_offsets + cugraph::detail::transform_increment(handle_.get_stream(), + (label_t*)(*start_vertex_labels).data(), + (size_t)(*start_vertex_labels).size(), + (label_t)global_labels[handle_.get_comms().get_rank()] + ); } if constexpr (multi_gpu) {