From c4154c8ac772ee5e8c6bb3b36e7d736bec04e9a3 Mon Sep 17 00:00:00 2001 From: Piotr Bielak Date: Tue, 9 Aug 2022 16:32:09 +0200 Subject: [PATCH 1/2] Implement weighted random walks This commit implements weighted biased random walks as in the original Node2vec paper. In particular, it adds a new parameter to the `random_walk` function, i.e., `edge_weight`, which allows passing edge weights to the underlying random walk generation procedure. If edge weights are set, the function normalizes them by the node degree and converts the weights into CDFs over given nodes (needed by the rejection sampling method). The implementation of the new rejection sampling method is based on [1]. [1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69 * Update `random_walk` API * Implement weighted rejection sampling on CPU * Implement weighted random walk for GPU (CUDA) * Compute CDFs using C++/CUDA * Add tests for weighted random walks --- csrc/cpu/rw_cpu.cpp | 165 +++++++++++++++++++++++++++++++++++++++++++ csrc/cpu/rw_cpu.h | 5 ++ csrc/cuda/rw_cuda.cu | 160 +++++++++++++++++++++++++++++++++++++++++ csrc/cuda/rw_cuda.h | 5 ++ csrc/rw.cpp | 18 ++++- test/test_rw.py | 41 +++++++++++ torch_cluster/rw.py | 17 ++++- 7 files changed, 407 insertions(+), 4 deletions(-) diff --git a/csrc/cpu/rw_cpu.cpp b/csrc/cpu/rw_cpu.cpp index c26b28ed..4b5e63a4 100644 --- a/csrc/cpu/rw_cpu.cpp +++ b/csrc/cpu/rw_cpu.cpp @@ -137,3 +137,168 @@ random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start, return std::make_tuple(n_out, e_out); } + + +void compute_cdf(const int64_t *rowptr, const float_t *edge_weight, + float_t *edge_weight_cdf, int64_t numel) { + /* Convert edge weights to CDF as given in [1] + + [1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L148 + */ + at::parallel_for(0, numel - 1, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) { + for(int64_t i = begin; i < end; i++) { + int64_t row_start = rowptr[i], row_end = rowptr[i + 1]; + float_t acc = 0.0; + + for(int64_t j = row_start; j < row_end; j++) { + acc += edge_weight[j]; + edge_weight_cdf[j] = acc; + } + } + }); +} + + +int64_t get_offset(const float_t *edge_weight, int64_t start, int64_t end) { + /* + The implementation given in [1] utilizes the `searchsorted` function in Numpy. + It is also available in PyTorch and its C++ API (via `at::searchsorted()`). + However, the implementation is adopted to the general case where the searched + values can be a multidimensional tensor. In our case, we have a 1D tensor of + edge weights (in form of a Cumulative Distribution Function) and a single + value, whose position we want to compute. To eliminate the overhead introduced + in the PyTorch implementation, one can examine the source code of + `searchsorted` [2] and find that for our case the whole function call can be + reduced to calling the `cus_lower_bound()` function. Unfortunately, we cannot + access it directly (the namespace is not exposed to the public API), but the + implementation is just a simple binary search. The code was copied here and + reduced to the bare minimum. + + [1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69 + [2] https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Bucketization.cpp + */ + float_t value = ((float_t)rand() / RAND_MAX); // [0, 1) + int64_t original_start = start; + + while (start < end) { + const int64_t mid = start + ((end - start) >> 1); + const float_t mid_val = edge_weight[mid]; + if (!(mid_val >= value)) { + start = mid + 1; + } + else { + end = mid; + } + } + + return start - original_start; +} + +// See: https://louisabraham.github.io/articles/node2vec-sampling.html +// See also: https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69 +void rejection_sampling_weighted(const int64_t *rowptr, const int64_t *col, + const float_t *edge_weight_cdf, int64_t *start, + int64_t *n_out, int64_t *e_out, + const int64_t numel, const int64_t walk_length, + const double p, const double q) { + + double max_prob = fmax(fmax(1. / p, 1.), 1. / q); + double prob_0 = 1. / p / max_prob; + double prob_1 = 1. / max_prob; + double prob_2 = 1. / q / max_prob; + + int64_t grain_size = at::internal::GRAIN_SIZE / walk_length; + at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) { + for (auto n = begin; n < end; n++) { + int64_t t = start[n], v, x, e_cur, row_start, row_end; + + n_out[n * (walk_length + 1)] = t; + + row_start = rowptr[t], row_end = rowptr[t + 1]; + + if (row_end - row_start == 0) { + e_cur = -1; + v = t; + } else { + e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end); + v = col[e_cur]; + } + n_out[n * (walk_length + 1) + 1] = v; + e_out[n * walk_length] = e_cur; + + for (auto l = 1; l < walk_length; l++) { + row_start = rowptr[v], row_end = rowptr[v + 1]; + + if (row_end - row_start == 0) { + e_cur = -1; + x = v; + } else if (row_end - row_start == 1) { + e_cur = row_start; + x = col[e_cur]; + } else { + if (p == 1. && q == 1.) { + e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end); + x = col[e_cur]; + } + else { + while (true) { + e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end); + x = col[e_cur]; + + auto r = ((double)rand() / (RAND_MAX)); // [0, 1) + + if (x == t && r < prob_0) + break; + else if (is_neighbor(rowptr, col, x, t) && r < prob_1) + break; + else if (r < prob_2) + break; + } + } + } + + n_out[n * (walk_length + 1) + (l + 1)] = x; + e_out[n * walk_length + l] = e_cur; + t = v; + v = x; + } + } + }); +} + + +std::tuple +random_walk_weighted_cpu(torch::Tensor rowptr, torch::Tensor col, + torch::Tensor edge_weight, torch::Tensor start, + int64_t walk_length, double p, double q) { + CHECK_CPU(rowptr); + CHECK_CPU(col); + CHECK_CPU(edge_weight); + CHECK_CPU(start); + + CHECK_INPUT(rowptr.dim() == 1); + CHECK_INPUT(col.dim() == 1); + CHECK_INPUT(edge_weight.dim() == 1); + CHECK_INPUT(start.dim() == 1); + + auto n_out = torch::empty({start.size(0), walk_length + 1}, start.options()); + auto e_out = torch::empty({start.size(0), walk_length}, start.options()); + + auto rowptr_data = rowptr.data_ptr(); + auto col_data = col.data_ptr(); + auto edge_weight_data = edge_weight.data_ptr(); + auto start_data = start.data_ptr(); + auto n_out_data = n_out.data_ptr(); + auto e_out_data = e_out.data_ptr(); + + auto edge_weight_cdf = torch::empty({edge_weight.size(0)}, edge_weight.options()); + auto edge_weight_cdf_data = edge_weight_cdf.data_ptr(); + + compute_cdf(rowptr_data, edge_weight_data, edge_weight_cdf_data, rowptr.numel()); + + rejection_sampling_weighted(rowptr_data, col_data, edge_weight_cdf_data, + start_data, n_out_data, e_out_data, start.numel(), + walk_length, p, q); + + return std::make_tuple(n_out, e_out); +} diff --git a/csrc/cpu/rw_cpu.h b/csrc/cpu/rw_cpu.h index 6ea9fcd6..8d0c1085 100644 --- a/csrc/cpu/rw_cpu.h +++ b/csrc/cpu/rw_cpu.h @@ -5,3 +5,8 @@ std::tuple random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start, int64_t walk_length, double p, double q); + +std::tuple +random_walk_weighted_cpu(torch::Tensor rowptr, torch::Tensor col, + torch::Tensor edge_weight, torch::Tensor start, + int64_t walk_length, double p, double q); diff --git a/csrc/cuda/rw_cuda.cu b/csrc/cuda/rw_cuda.cu index 763b861f..5848b00b 100644 --- a/csrc/cuda/rw_cuda.cu +++ b/csrc/cuda/rw_cuda.cu @@ -150,3 +150,163 @@ random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start, return std::make_tuple(n_out.t().contiguous(), e_out.t().contiguous()); } + + +__global__ void cdf_kernel(const int64_t *rowptr, const float_t *edge_weight, + float_t *edge_weight_cdf, int64_t numel) { + const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (thread_idx < numel - 1) { + int64_t row_start = rowptr[thread_idx], row_end = rowptr[thread_idx + 1]; + + float_t acc = 0.0; + + for(int64_t i = row_start; i < row_end; i++) { + acc += edge_weight[i]; + edge_weight_cdf[i] = acc; + } + } +} + +__device__ void get_offset(const float_t *edge_weight, int64_t start, int64_t end, + float_t value, int64_t *position_out) { + int64_t original_start = start; + + while (start < end) { + const int64_t mid = start + ((end - start) >> 1); + const float_t mid_val = edge_weight[mid]; + if (!(mid_val >= value)) { + start = mid + 1; + } + else { + end = mid; + } + } + + *position_out = start - original_start; +} + +__global__ void +rejection_sampling_weighted_kernel(unsigned int seed, const int64_t *rowptr, + const int64_t *col, const float_t *edge_weight_cdf, + const int64_t *start, int64_t *n_out, + int64_t *e_out, const int64_t walk_length, + const int64_t numel, const double p, + const double q) { + + curandState_t state; + curand_init(seed, 0, 0, &state); + + double max_prob = fmax(fmax(1. / p, 1.), 1. / q); + double prob_0 = 1. / p / max_prob; + double prob_1 = 1. / max_prob; + double prob_2 = 1. / q / max_prob; + + const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (thread_idx < numel) { + int64_t t = start[thread_idx], v, x, e_cur, row_start, row_end, offset; + + n_out[thread_idx] = t; + + row_start = rowptr[t], row_end = rowptr[t + 1]; + + if (row_end - row_start == 0) { + e_cur = -1; + v = t; + } else { + get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset); + e_cur = row_start + offset; + v = col[e_cur]; + } + + n_out[numel + thread_idx] = v; + e_out[thread_idx] = e_cur; + + for (int64_t l = 1; l < walk_length; l++) { + row_start = rowptr[v], row_end = rowptr[v + 1]; + + if (row_end - row_start == 0) { + e_cur = -1; + x = v; + } else if (row_end - row_start == 1) { + e_cur = row_start; + x = col[e_cur]; + } else { + if (p == 1. && q == 1.) { + get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset); + e_cur = row_start + offset; + x = col[e_cur]; + } + else { + while (true) { + get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset); + e_cur = row_start + offset; + x = col[e_cur]; + + double r = curand_uniform(&state); // (0, 1] + + if (x == t && r < prob_0) + break; + + bool is_neighbor = false; + row_start = rowptr[x], row_end = rowptr[x + 1]; + for (int64_t i = row_start; i < row_end; i++) { + if (col[i] == t) { + is_neighbor = true; + break; + } + } + + if (is_neighbor && r < prob_1) + break; + else if (r < prob_2) + break; + } + } + } + + n_out[(l + 1) * numel + thread_idx] = x; + e_out[l * numel + thread_idx] = e_cur; + t = v; + v = x; + } + } +} + + +std::tuple +random_walk_weighted_cuda(torch::Tensor rowptr, torch::Tensor col, + torch::Tensor edge_weight, torch::Tensor start, + int64_t walk_length, double p, double q) { + CHECK_CUDA(rowptr); + CHECK_CUDA(col); + CHECK_CUDA(edge_weight); + CHECK_CUDA(start); + cudaSetDevice(rowptr.get_device()); + + CHECK_INPUT(rowptr.dim() == 1); + CHECK_INPUT(col.dim() == 1); + CHECK_INPUT(edge_weight.dim() == 1); + CHECK_INPUT(start.dim() == 1); + + auto n_out = torch::empty({walk_length + 1, start.size(0)}, start.options()); + auto e_out = torch::empty({walk_length, start.size(0)}, start.options()); + + auto stream = at::cuda::getCurrentCUDAStream(); + + auto edge_weight_cdf = torch::empty({edge_weight.size(0)}, edge_weight.options()); + + cdf_kernel<<>>( + rowptr.data_ptr(), edge_weight.data_ptr(), + edge_weight_cdf.data_ptr(), rowptr.numel()); + + rejection_sampling_weighted_kernel<<>>( + time(NULL), rowptr.data_ptr(), col.data_ptr(), + edge_weight_cdf.data_ptr(), start.data_ptr(), + n_out.data_ptr(), e_out.data_ptr(), + walk_length, start.numel(), p, q); + + return std::make_tuple(n_out.t().contiguous(), e_out.t().contiguous()); +} + diff --git a/csrc/cuda/rw_cuda.h b/csrc/cuda/rw_cuda.h index 79f4139f..b3919d83 100644 --- a/csrc/cuda/rw_cuda.h +++ b/csrc/cuda/rw_cuda.h @@ -5,3 +5,8 @@ std::tuple random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start, int64_t walk_length, double p, double q); + +std::tuple +random_walk_weighted_cuda(torch::Tensor rowptr, torch::Tensor col, + torch::Tensor edge_weight, torch::Tensor start, + int64_t walk_length, double p, double q); diff --git a/csrc/rw.cpp b/csrc/rw.cpp index 0f8de62d..5d48646d 100644 --- a/csrc/rw.cpp +++ b/csrc/rw.cpp @@ -33,5 +33,21 @@ random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start, } } +CLUSTER_API std::tuple +random_walk_weighted(torch::Tensor rowptr, torch::Tensor col, + torch::Tensor edge_weight, torch::Tensor start, + int64_t walk_length, double p, double q) { + if (rowptr.device().is_cuda()) { +#ifdef WITH_CUDA + return random_walk_weighted_cuda(rowptr, col, edge_weight, start, walk_length, p, q); +#else + AT_ERROR("Not compiled with CUDA support"); +#endif + } else { + return random_walk_weighted_cpu(rowptr, col, edge_weight, start, walk_length, p, q); + } +} + static auto registry = - torch::RegisterOperators().op("torch_cluster::random_walk", &random_walk); + torch::RegisterOperators().op("torch_cluster::random_walk", &random_walk) + .op("torch_cluster::random_walk_weighted", &random_walk_weighted); diff --git a/test/test_rw.py b/test/test_rw.py index 0ff91a44..345d9717 100644 --- a/test/test_rw.py +++ b/test/test_rw.py @@ -77,3 +77,44 @@ def test_rw_small_with_edge_indices(device): [1, 0, 1, 0], [-1, -1, -1, -1], ] + + +@pytest.mark.parametrize('device', devices) +def test_rw_large_with_edge_weights(device): + row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device) + col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device) + start = tensor([0, 1, 2, 3, 4], torch.long, device) + edge_weight = tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], torch.float, device) + walk_length = 10 + + out = random_walk( + row, col, start, walk_length, + edge_weight=edge_weight, + ) + assert out[:, 0].tolist() == start.tolist() + + for n in range(start.size(0)): + cur = start[n].item() + for i in range(1, walk_length): + assert out[n, i].item() in col[row == cur].tolist() + cur = out[n, i].item() + + +@pytest.mark.parametrize('device', devices) +def test_rw_small_with_edge_weights(device): + row = tensor([0, 1], torch.long, device) + col = tensor([1, 0], torch.long, device) + start = tensor([0, 1, 2], torch.long, device) + edge_weight = tensor([1, 1], torch.float, device) + walk_length = 4 + + out = random_walk( + row, col, start, walk_length, + num_nodes=3, + edge_weight=edge_weight, + ) + assert out.tolist() == [ + [0, 1, 0, 1, 0], + [1, 0, 1, 0, 1], + [2, 2, 2, 2, 2], + ] diff --git a/torch_cluster/rw.py b/torch_cluster/rw.py index 12e06837..672a75e2 100644 --- a/torch_cluster/rw.py +++ b/torch_cluster/rw.py @@ -15,6 +15,7 @@ def random_walk( coalesced: bool = True, num_nodes: Optional[int] = None, return_edge_indices: bool = False, + edge_weight: Optional[Tensor] = None, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Samples random walks of length :obj:`walk_length` from all node indices in :obj:`start` in the graph given by :obj:`(row, col)` as described in the @@ -39,6 +40,8 @@ def random_walk( return_edge_indices (bool, optional): Whether to additionally return the indices of edges traversed during the random walk. (default: :obj:`False`) + edge_weight (Tensor, optional): Weights of edges given by `row` and + `col` (default: :obj:`None`) :rtype: :class:`LongTensor` """ @@ -54,9 +57,17 @@ def random_walk( rowptr = row.new_zeros(num_nodes + 1) torch.cumsum(deg, 0, out=rowptr[1:]) - node_seq, edge_seq = torch.ops.torch_cluster.random_walk( - rowptr, col, start, walk_length, p, q, - ) + if edge_weight is None: + node_seq, edge_seq = torch.ops.torch_cluster.random_walk( + rowptr, col, start, walk_length, p, q, + ) + else: + # Normalize edge weights by node degrees + edge_weight = edge_weight / deg[row] + + node_seq, edge_seq = torch.ops.torch_cluster.random_walk_weighted( + rowptr, col, edge_weight, start, walk_length, p, q, + ) if return_edge_indices: return node_seq, edge_seq From 2b63087b865b1db6de7ecb0128417b38d16f059e Mon Sep 17 00:00:00 2001 From: Piotr Bielak Date: Fri, 9 Sep 2022 13:27:40 +0200 Subject: [PATCH 2/2] Fix edge weight normalization --- csrc/cpu/rw_cpu.cpp | 10 +++++++++- csrc/cuda/rw_cuda.cu | 8 +++++++- torch_cluster/rw.py | 3 --- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/csrc/cpu/rw_cpu.cpp b/csrc/cpu/rw_cpu.cpp index 4b5e63a4..5b624e83 100644 --- a/csrc/cpu/rw_cpu.cpp +++ b/csrc/cpu/rw_cpu.cpp @@ -148,10 +148,18 @@ void compute_cdf(const int64_t *rowptr, const float_t *edge_weight, at::parallel_for(0, numel - 1, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) { for(int64_t i = begin; i < end; i++) { int64_t row_start = rowptr[i], row_end = rowptr[i + 1]; + + // Compute sum to normalize weights + float_t sum = 0.0; + + for(int64_t j = row_start; j < row_end; j++) { + sum += edge_weight[j]; + } + float_t acc = 0.0; for(int64_t j = row_start; j < row_end; j++) { - acc += edge_weight[j]; + acc += edge_weight[j] / sum; edge_weight_cdf[j] = acc; } } diff --git a/csrc/cuda/rw_cuda.cu b/csrc/cuda/rw_cuda.cu index 5848b00b..32f186b7 100644 --- a/csrc/cuda/rw_cuda.cu +++ b/csrc/cuda/rw_cuda.cu @@ -159,10 +159,16 @@ __global__ void cdf_kernel(const int64_t *rowptr, const float_t *edge_weight, if (thread_idx < numel - 1) { int64_t row_start = rowptr[thread_idx], row_end = rowptr[thread_idx + 1]; + float_t sum = 0.0; + + for(int64_t i = row_start; i < row_end; i++) { + sum += edge_weight[i]; + } + float_t acc = 0.0; for(int64_t i = row_start; i < row_end; i++) { - acc += edge_weight[i]; + acc += edge_weight[i] / sum; edge_weight_cdf[i] = acc; } } diff --git a/torch_cluster/rw.py b/torch_cluster/rw.py index 672a75e2..7c976bc9 100644 --- a/torch_cluster/rw.py +++ b/torch_cluster/rw.py @@ -62,9 +62,6 @@ def random_walk( rowptr, col, start, walk_length, p, q, ) else: - # Normalize edge weights by node degrees - edge_weight = edge_weight / deg[row] - node_seq, edge_seq = torch.ops.torch_cluster.random_walk_weighted( rowptr, col, edge_weight, start, walk_length, p, q, )