Skip to content

Commit

Permalink
Implement weighted random walks
Browse files Browse the repository at this point in the history
This commit implements weighted biased random walks as in the original
Node2vec paper. In particular, it adds a new parameter to the `random_walk`
function, i.e., `edge_weight`, which allows passing edge weights to the
underlying random walk generation procedure. If edge weights are set,
the function normalizes them by the node degree and converts the weights
into CDFs over given nodes (needed by the rejection sampling method).
The implementation of the new rejection sampling method is based on [1].

[1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69

* Update `random_walk` API
* Implement weighted rejection sampling on CPU
* Implement weighted random walk for GPU (CUDA)
* Compute CDFs using C++/CUDA
* Add tests for weighted random walks
  • Loading branch information
pbielak committed Aug 10, 2022
1 parent cc4696b commit c4154c8
Show file tree
Hide file tree
Showing 7 changed files with 407 additions and 4 deletions.
165 changes: 165 additions & 0 deletions csrc/cpu/rw_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,168 @@ random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,

return std::make_tuple(n_out, e_out);
}


void compute_cdf(const int64_t *rowptr, const float_t *edge_weight,
float_t *edge_weight_cdf, int64_t numel) {
/* Convert edge weights to CDF as given in [1]
[1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L148
*/
at::parallel_for(0, numel - 1, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
for(int64_t i = begin; i < end; i++) {
int64_t row_start = rowptr[i], row_end = rowptr[i + 1];
float_t acc = 0.0;

for(int64_t j = row_start; j < row_end; j++) {
acc += edge_weight[j];
edge_weight_cdf[j] = acc;
}
}
});
}


int64_t get_offset(const float_t *edge_weight, int64_t start, int64_t end) {
/*
The implementation given in [1] utilizes the `searchsorted` function in Numpy.
It is also available in PyTorch and its C++ API (via `at::searchsorted()`).
However, the implementation is adopted to the general case where the searched
values can be a multidimensional tensor. In our case, we have a 1D tensor of
edge weights (in form of a Cumulative Distribution Function) and a single
value, whose position we want to compute. To eliminate the overhead introduced
in the PyTorch implementation, one can examine the source code of
`searchsorted` [2] and find that for our case the whole function call can be
reduced to calling the `cus_lower_bound()` function. Unfortunately, we cannot
access it directly (the namespace is not exposed to the public API), but the
implementation is just a simple binary search. The code was copied here and
reduced to the bare minimum.
[1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69
[2] https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Bucketization.cpp
*/
float_t value = ((float_t)rand() / RAND_MAX); // [0, 1)
int64_t original_start = start;

while (start < end) {
const int64_t mid = start + ((end - start) >> 1);
const float_t mid_val = edge_weight[mid];
if (!(mid_val >= value)) {
start = mid + 1;
}
else {
end = mid;
}
}

return start - original_start;
}

// See: https://louisabraham.github.io/articles/node2vec-sampling.html
// See also: https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69
void rejection_sampling_weighted(const int64_t *rowptr, const int64_t *col,
const float_t *edge_weight_cdf, int64_t *start,
int64_t *n_out, int64_t *e_out,
const int64_t numel, const int64_t walk_length,
const double p, const double q) {

double max_prob = fmax(fmax(1. / p, 1.), 1. / q);
double prob_0 = 1. / p / max_prob;
double prob_1 = 1. / max_prob;
double prob_2 = 1. / q / max_prob;

int64_t grain_size = at::internal::GRAIN_SIZE / walk_length;
at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
for (auto n = begin; n < end; n++) {
int64_t t = start[n], v, x, e_cur, row_start, row_end;

n_out[n * (walk_length + 1)] = t;

row_start = rowptr[t], row_end = rowptr[t + 1];

if (row_end - row_start == 0) {
e_cur = -1;
v = t;
} else {
e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
v = col[e_cur];
}
n_out[n * (walk_length + 1) + 1] = v;
e_out[n * walk_length] = e_cur;

for (auto l = 1; l < walk_length; l++) {
row_start = rowptr[v], row_end = rowptr[v + 1];

if (row_end - row_start == 0) {
e_cur = -1;
x = v;
} else if (row_end - row_start == 1) {
e_cur = row_start;
x = col[e_cur];
} else {
if (p == 1. && q == 1.) {
e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
x = col[e_cur];
}
else {
while (true) {
e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
x = col[e_cur];

auto r = ((double)rand() / (RAND_MAX)); // [0, 1)

if (x == t && r < prob_0)
break;
else if (is_neighbor(rowptr, col, x, t) && r < prob_1)
break;
else if (r < prob_2)
break;
}
}
}

n_out[n * (walk_length + 1) + (l + 1)] = x;
e_out[n * walk_length + l] = e_cur;
t = v;
v = x;
}
}
});
}


std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor edge_weight, torch::Tensor start,
int64_t walk_length, double p, double q) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
CHECK_CPU(edge_weight);
CHECK_CPU(start);

CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(edge_weight.dim() == 1);
CHECK_INPUT(start.dim() == 1);

auto n_out = torch::empty({start.size(0), walk_length + 1}, start.options());
auto e_out = torch::empty({start.size(0), walk_length}, start.options());

auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto edge_weight_data = edge_weight.data_ptr<float_t>();
auto start_data = start.data_ptr<int64_t>();
auto n_out_data = n_out.data_ptr<int64_t>();
auto e_out_data = e_out.data_ptr<int64_t>();

auto edge_weight_cdf = torch::empty({edge_weight.size(0)}, edge_weight.options());
auto edge_weight_cdf_data = edge_weight_cdf.data_ptr<float_t>();

compute_cdf(rowptr_data, edge_weight_data, edge_weight_cdf_data, rowptr.numel());

rejection_sampling_weighted(rowptr_data, col_data, edge_weight_cdf_data,
start_data, n_out_data, e_out_data, start.numel(),
walk_length, p, q);

return std::make_tuple(n_out, e_out);
}
5 changes: 5 additions & 0 deletions csrc/cpu/rw_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,8 @@
std::tuple<torch::Tensor, torch::Tensor>
random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q);

std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor edge_weight, torch::Tensor start,
int64_t walk_length, double p, double q);
160 changes: 160 additions & 0 deletions csrc/cuda/rw_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,163 @@ random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,

return std::make_tuple(n_out.t().contiguous(), e_out.t().contiguous());
}


__global__ void cdf_kernel(const int64_t *rowptr, const float_t *edge_weight,
float_t *edge_weight_cdf, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;

if (thread_idx < numel - 1) {
int64_t row_start = rowptr[thread_idx], row_end = rowptr[thread_idx + 1];

float_t acc = 0.0;

for(int64_t i = row_start; i < row_end; i++) {
acc += edge_weight[i];
edge_weight_cdf[i] = acc;
}
}
}

__device__ void get_offset(const float_t *edge_weight, int64_t start, int64_t end,
float_t value, int64_t *position_out) {
int64_t original_start = start;

while (start < end) {
const int64_t mid = start + ((end - start) >> 1);
const float_t mid_val = edge_weight[mid];
if (!(mid_val >= value)) {
start = mid + 1;
}
else {
end = mid;
}
}

*position_out = start - original_start;
}

__global__ void
rejection_sampling_weighted_kernel(unsigned int seed, const int64_t *rowptr,
const int64_t *col, const float_t *edge_weight_cdf,
const int64_t *start, int64_t *n_out,
int64_t *e_out, const int64_t walk_length,
const int64_t numel, const double p,
const double q) {

curandState_t state;
curand_init(seed, 0, 0, &state);

double max_prob = fmax(fmax(1. / p, 1.), 1. / q);
double prob_0 = 1. / p / max_prob;
double prob_1 = 1. / max_prob;
double prob_2 = 1. / q / max_prob;

const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;

if (thread_idx < numel) {
int64_t t = start[thread_idx], v, x, e_cur, row_start, row_end, offset;

n_out[thread_idx] = t;

row_start = rowptr[t], row_end = rowptr[t + 1];

if (row_end - row_start == 0) {
e_cur = -1;
v = t;
} else {
get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset);
e_cur = row_start + offset;
v = col[e_cur];
}

n_out[numel + thread_idx] = v;
e_out[thread_idx] = e_cur;

for (int64_t l = 1; l < walk_length; l++) {
row_start = rowptr[v], row_end = rowptr[v + 1];

if (row_end - row_start == 0) {
e_cur = -1;
x = v;
} else if (row_end - row_start == 1) {
e_cur = row_start;
x = col[e_cur];
} else {
if (p == 1. && q == 1.) {
get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset);
e_cur = row_start + offset;
x = col[e_cur];
}
else {
while (true) {
get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset);
e_cur = row_start + offset;
x = col[e_cur];

double r = curand_uniform(&state); // (0, 1]

if (x == t && r < prob_0)
break;

bool is_neighbor = false;
row_start = rowptr[x], row_end = rowptr[x + 1];
for (int64_t i = row_start; i < row_end; i++) {
if (col[i] == t) {
is_neighbor = true;
break;
}
}

if (is_neighbor && r < prob_1)
break;
else if (r < prob_2)
break;
}
}
}

n_out[(l + 1) * numel + thread_idx] = x;
e_out[l * numel + thread_idx] = e_cur;
t = v;
v = x;
}
}
}


std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor edge_weight, torch::Tensor start,
int64_t walk_length, double p, double q) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(edge_weight);
CHECK_CUDA(start);
cudaSetDevice(rowptr.get_device());

CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(edge_weight.dim() == 1);
CHECK_INPUT(start.dim() == 1);

auto n_out = torch::empty({walk_length + 1, start.size(0)}, start.options());
auto e_out = torch::empty({walk_length, start.size(0)}, start.options());

auto stream = at::cuda::getCurrentCUDAStream();

auto edge_weight_cdf = torch::empty({edge_weight.size(0)}, edge_weight.options());

cdf_kernel<<<BLOCKS(rowptr.numel()), THREADS, 0, stream>>>(
rowptr.data_ptr<int64_t>(), edge_weight.data_ptr<float_t>(),
edge_weight_cdf.data_ptr<float_t>(), rowptr.numel());

rejection_sampling_weighted_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
time(NULL), rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
edge_weight_cdf.data_ptr<float_t>(), start.data_ptr<int64_t>(),
n_out.data_ptr<int64_t>(), e_out.data_ptr<int64_t>(),
walk_length, start.numel(), p, q);

return std::make_tuple(n_out.t().contiguous(), e_out.t().contiguous());
}

5 changes: 5 additions & 0 deletions csrc/cuda/rw_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,8 @@
std::tuple<torch::Tensor, torch::Tensor>
random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q);

std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor edge_weight, torch::Tensor start,
int64_t walk_length, double p, double q);
18 changes: 17 additions & 1 deletion csrc/rw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,21 @@ random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
}
}

CLUSTER_API std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor edge_weight, torch::Tensor start,
int64_t walk_length, double p, double q) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
return random_walk_weighted_cuda(rowptr, col, edge_weight, start, walk_length, p, q);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return random_walk_weighted_cpu(rowptr, col, edge_weight, start, walk_length, p, q);
}
}

static auto registry =
torch::RegisterOperators().op("torch_cluster::random_walk", &random_walk);
torch::RegisterOperators().op("torch_cluster::random_walk", &random_walk)
.op("torch_cluster::random_walk_weighted", &random_walk_weighted);
Loading

0 comments on commit c4154c8

Please sign in to comment.