Skip to content

Commit

Permalink
Biased RW implementation and test
Browse files Browse the repository at this point in the history
  • Loading branch information
Garrett Cornett committed Jun 24, 2024
1 parent c4e28bf commit 12af533
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 6 deletions.
78 changes: 75 additions & 3 deletions cpp/src/sampling/random_walks_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "detail/graph_partition_utils.cuh"
#include "prims/per_v_random_select_transform_outgoing_e.cuh"
#include "prims/vertex_frontier.cuh"
#include "prims/property_op_utils.cuh"

#include <cugraph/algorithms.hpp>
#include <cugraph/detail/shuffle_wrappers.hpp>
Expand All @@ -29,6 +30,7 @@
#include <cugraph/partition_manager.hpp>
#include <cugraph/utilities/host_scalar_comm.hpp>
#include <cugraph/utilities/shuffle_comm.cuh>
#include <cugraph/graph_functions.hpp>

#include <raft/core/handle.hpp>
#include <raft/random/rng.cuh>
Expand Down Expand Up @@ -70,6 +72,18 @@ struct sample_edges_op_t {
}
};

template <typename vertex_t, typename bias_t>
struct e_bias_op_t {

raft::device_span<bias_t const> vertex_weight_sum{};

__device__ bias_t
operator()(vertex_t src, vertex_t, thrust::nullopt_t, thrust::nullopt_t, bias_t weight) const
{
return weight / vertex_weight_sum[src];
}
};

template <typename weight_t>
struct uniform_selector {
raft::random::RngState rng_state_;
Expand Down Expand Up @@ -139,7 +153,9 @@ struct uniform_selector {

template <typename weight_t>
struct biased_selector {
uint64_t seed_{0};
raft::random::RngState rng_state_;

biased_selector(uint64_t seed) : rng_state_(seed) {}

template <typename GraphViewType>
std::tuple<rmm::device_uvector<typename GraphViewType::vertex_type>,
Expand All @@ -156,7 +172,63 @@ struct biased_selector {
// instead of making a decision based on the index I need to find
// upper_bound (or is it lower_bound) of the random number and
// the cumulative weight.
CUGRAPH_FAIL("biased sampling not implemented");

// Create vertex frontier
using vertex_t = typename GraphViewType::vertex_type;

using tag_t = void;

cugraph::vertex_frontier_t<vertex_t, tag_t, GraphViewType::is_multi_gpu, false> vertex_frontier(handle, 1);

vertex_frontier.bucket(0).insert(current_vertices.begin(), current_vertices.end());

// Create data structs for results
rmm::device_uvector<vertex_t> minors(0, handle.get_stream());
// Should this be optional? Necessary for biased
std::optional<rmm::device_uvector<weight_t>> weights{std::nullopt};

if (edge_weight_view) {
auto vertex_weight_sum = compute_out_weight_sums(handle, graph_view, *edge_weight_view);
auto [sample_offsets, sample_e_op_results] =
cugraph::per_v_random_select_transform_outgoing_e(
handle,
graph_view,
vertex_frontier.bucket(0),
cugraph::edge_src_dummy_property_t{}.view(),
cugraph::edge_dst_dummy_property_t{}.view(),
*edge_weight_view,
e_bias_op_t<vertex_t, weight_t>{
raft::device_span<weight_t const>(vertex_weight_sum.data(), vertex_weight_sum.size())},
sample_edges_op_t<vertex_t, weight_t>{},
rng_state_,
size_t{1},
true,
std::make_optional(
thrust::make_tuple(vertex_t{cugraph::invalid_vertex_id<vertex_t>::value}, weight_t{0.0})));
minors = std::move(std::get<0>(sample_e_op_results));
weights = std::move(std::get<1>(sample_e_op_results));
} else {
// Just uniform random walk
auto [sample_offsets, sample_e_op_results] =
cugraph::per_v_random_select_transform_outgoing_e(
handle,
graph_view,
vertex_frontier.bucket(0),
cugraph::edge_src_dummy_property_t{}.view(),
cugraph::edge_dst_dummy_property_t{}.view(),
cugraph::edge_dummy_property_t{}.view(),
sample_edges_op_t<vertex_t, void>{},
rng_state_,
size_t{1},
true,
std::make_optional(vertex_t{cugraph::invalid_vertex_id<vertex_t>::value}));

minors = std::move(sample_e_op_results);
}

// Return results
return std::make_tuple(std::move(minors), std::move(weights));

}
};

Expand Down Expand Up @@ -483,7 +555,7 @@ biased_random_walks(raft::handle_t const& handle,
std::optional<edge_property_view_t<edge_t, weight_t const*>>{edge_weight_view},
start_vertices,
max_length,
detail::biased_selector<weight_t>{(seed == 0 ? detail::get_current_time_nanoseconds() : seed)});
detail::biased_selector<weight_t>((seed == 0 ? detail::get_current_time_nanoseconds() : seed)));
}

template <typename vertex_t, typename edge_t, typename weight_t, bool multi_gpu>
Expand Down
4 changes: 2 additions & 2 deletions cpp/tests/sampling/mg_random_walks_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct BiasedRandomWalks_Usecase {
}

// FIXME: Not currently implemented
bool expect_throw() { return true; }
bool expect_throw() { return !test_weighted; }
};

struct Node2VecRandomWalks_Usecase {
Expand Down Expand Up @@ -295,7 +295,7 @@ INSTANTIATE_TEST_SUITE_P(
cugraph::test::File_Usecase("test/datasets/web-Google.mtx"),
cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"),
cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx"))));

INSTANTIATE_TEST_SUITE_P(
simple_test,
Tests_BiasedRandomWalks_File,
Expand Down
2 changes: 1 addition & 1 deletion cpp/tests/sampling/sg_random_walks_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ struct BiasedRandomWalks_Usecase {
}

// FIXME: Not currently implemented
bool expect_throw() { return true; }
bool expect_throw() { return !test_weighted; }
};

struct Node2VecRandomWalks_Usecase {
Expand Down

0 comments on commit 12af533

Please sign in to comment.