Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement weighted random walks #140

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions csrc/cpu/rw_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,176 @@ random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,

return std::make_tuple(n_out, e_out);
}


void compute_cdf(const int64_t *rowptr, const float_t *edge_weight,
float_t *edge_weight_cdf, int64_t numel) {
/* Convert edge weights to CDF as given in [1]

[1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L148
*/
at::parallel_for(0, numel - 1, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
for(int64_t i = begin; i < end; i++) {
int64_t row_start = rowptr[i], row_end = rowptr[i + 1];

// Compute sum to normalize weights
float_t sum = 0.0;

for(int64_t j = row_start; j < row_end; j++) {
sum += edge_weight[j];
}

float_t acc = 0.0;

for(int64_t j = row_start; j < row_end; j++) {
acc += edge_weight[j] / sum;
edge_weight_cdf[j] = acc;
}
}
});
}


int64_t get_offset(const float_t *edge_weight, int64_t start, int64_t end) {
/*
The implementation given in [1] utilizes the `searchsorted` function in Numpy.
It is also available in PyTorch and its C++ API (via `at::searchsorted()`).
However, the implementation is adopted to the general case where the searched
values can be a multidimensional tensor. In our case, we have a 1D tensor of
edge weights (in form of a Cumulative Distribution Function) and a single
value, whose position we want to compute. To eliminate the overhead introduced
in the PyTorch implementation, one can examine the source code of
`searchsorted` [2] and find that for our case the whole function call can be
reduced to calling the `cus_lower_bound()` function. Unfortunately, we cannot
access it directly (the namespace is not exposed to the public API), but the
implementation is just a simple binary search. The code was copied here and
reduced to the bare minimum.

[1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69
[2] https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Bucketization.cpp
*/
float_t value = ((float_t)rand() / RAND_MAX); // [0, 1)
int64_t original_start = start;

while (start < end) {
const int64_t mid = start + ((end - start) >> 1);
const float_t mid_val = edge_weight[mid];
if (!(mid_val >= value)) {
start = mid + 1;
}
else {
end = mid;
}
}

return start - original_start;
}

// See: https://louisabraham.github.io/articles/node2vec-sampling.html
// See also: https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69
void rejection_sampling_weighted(const int64_t *rowptr, const int64_t *col,
const float_t *edge_weight_cdf, int64_t *start,
int64_t *n_out, int64_t *e_out,
const int64_t numel, const int64_t walk_length,
const double p, const double q) {

double max_prob = fmax(fmax(1. / p, 1.), 1. / q);
double prob_0 = 1. / p / max_prob;
double prob_1 = 1. / max_prob;
double prob_2 = 1. / q / max_prob;

int64_t grain_size = at::internal::GRAIN_SIZE / walk_length;
at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
for (auto n = begin; n < end; n++) {
int64_t t = start[n], v, x, e_cur, row_start, row_end;

n_out[n * (walk_length + 1)] = t;

row_start = rowptr[t], row_end = rowptr[t + 1];

if (row_end - row_start == 0) {
e_cur = -1;
v = t;
} else {
e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
v = col[e_cur];
}
n_out[n * (walk_length + 1) + 1] = v;
e_out[n * walk_length] = e_cur;

for (auto l = 1; l < walk_length; l++) {
row_start = rowptr[v], row_end = rowptr[v + 1];

if (row_end - row_start == 0) {
e_cur = -1;
x = v;
} else if (row_end - row_start == 1) {
e_cur = row_start;
x = col[e_cur];
} else {
if (p == 1. && q == 1.) {
e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
x = col[e_cur];
}
else {
while (true) {
e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
x = col[e_cur];

auto r = ((double)rand() / (RAND_MAX)); // [0, 1)

if (x == t && r < prob_0)
break;
else if (is_neighbor(rowptr, col, x, t) && r < prob_1)
break;
else if (r < prob_2)
break;
}
}
}

n_out[n * (walk_length + 1) + (l + 1)] = x;
e_out[n * walk_length + l] = e_cur;
t = v;
v = x;
}
}
});
}


std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor edge_weight, torch::Tensor start,
int64_t walk_length, double p, double q) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
CHECK_CPU(edge_weight);
CHECK_CPU(start);

CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(edge_weight.dim() == 1);
CHECK_INPUT(start.dim() == 1);

auto n_out = torch::empty({start.size(0), walk_length + 1}, start.options());
auto e_out = torch::empty({start.size(0), walk_length}, start.options());

auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto edge_weight_data = edge_weight.data_ptr<float_t>();
auto start_data = start.data_ptr<int64_t>();
auto n_out_data = n_out.data_ptr<int64_t>();
auto e_out_data = e_out.data_ptr<int64_t>();

auto edge_weight_cdf = torch::empty({edge_weight.size(0)}, edge_weight.options());
auto edge_weight_cdf_data = edge_weight_cdf.data_ptr<float_t>();

compute_cdf(rowptr_data, edge_weight_data, edge_weight_cdf_data, rowptr.numel());

rejection_sampling_weighted(rowptr_data, col_data, edge_weight_cdf_data,
start_data, n_out_data, e_out_data, start.numel(),
walk_length, p, q);

return std::make_tuple(n_out, e_out);
}
5 changes: 5 additions & 0 deletions csrc/cpu/rw_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,8 @@
std::tuple<torch::Tensor, torch::Tensor>
random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q);

std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor edge_weight, torch::Tensor start,
int64_t walk_length, double p, double q);
166 changes: 166 additions & 0 deletions csrc/cuda/rw_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,169 @@ random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,

return std::make_tuple(n_out.t().contiguous(), e_out.t().contiguous());
}


__global__ void cdf_kernel(const int64_t *rowptr, const float_t *edge_weight,
float_t *edge_weight_cdf, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;

if (thread_idx < numel - 1) {
int64_t row_start = rowptr[thread_idx], row_end = rowptr[thread_idx + 1];

float_t sum = 0.0;

for(int64_t i = row_start; i < row_end; i++) {
sum += edge_weight[i];
}

float_t acc = 0.0;

for(int64_t i = row_start; i < row_end; i++) {
acc += edge_weight[i] / sum;
edge_weight_cdf[i] = acc;
}
}
}

__device__ void get_offset(const float_t *edge_weight, int64_t start, int64_t end,
float_t value, int64_t *position_out) {
int64_t original_start = start;

while (start < end) {
const int64_t mid = start + ((end - start) >> 1);
const float_t mid_val = edge_weight[mid];
if (!(mid_val >= value)) {
start = mid + 1;
}
else {
end = mid;
}
}

*position_out = start - original_start;
}

__global__ void
rejection_sampling_weighted_kernel(unsigned int seed, const int64_t *rowptr,
const int64_t *col, const float_t *edge_weight_cdf,
const int64_t *start, int64_t *n_out,
int64_t *e_out, const int64_t walk_length,
const int64_t numel, const double p,
const double q) {

curandState_t state;
curand_init(seed, 0, 0, &state);

double max_prob = fmax(fmax(1. / p, 1.), 1. / q);
double prob_0 = 1. / p / max_prob;
double prob_1 = 1. / max_prob;
double prob_2 = 1. / q / max_prob;

const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;

if (thread_idx < numel) {
int64_t t = start[thread_idx], v, x, e_cur, row_start, row_end, offset;

n_out[thread_idx] = t;

row_start = rowptr[t], row_end = rowptr[t + 1];

if (row_end - row_start == 0) {
e_cur = -1;
v = t;
} else {
get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset);
e_cur = row_start + offset;
v = col[e_cur];
}

n_out[numel + thread_idx] = v;
e_out[thread_idx] = e_cur;

for (int64_t l = 1; l < walk_length; l++) {
row_start = rowptr[v], row_end = rowptr[v + 1];

if (row_end - row_start == 0) {
e_cur = -1;
x = v;
} else if (row_end - row_start == 1) {
e_cur = row_start;
x = col[e_cur];
} else {
if (p == 1. && q == 1.) {
get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset);
e_cur = row_start + offset;
x = col[e_cur];
}
else {
while (true) {
get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset);
e_cur = row_start + offset;
x = col[e_cur];

double r = curand_uniform(&state); // (0, 1]

if (x == t && r < prob_0)
break;

bool is_neighbor = false;
row_start = rowptr[x], row_end = rowptr[x + 1];
for (int64_t i = row_start; i < row_end; i++) {
if (col[i] == t) {
is_neighbor = true;
break;
}
}

if (is_neighbor && r < prob_1)
break;
else if (r < prob_2)
break;
}
}
}

n_out[(l + 1) * numel + thread_idx] = x;
e_out[l * numel + thread_idx] = e_cur;
t = v;
v = x;
}
}
}


std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor edge_weight, torch::Tensor start,
int64_t walk_length, double p, double q) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(edge_weight);
CHECK_CUDA(start);
cudaSetDevice(rowptr.get_device());

CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(edge_weight.dim() == 1);
CHECK_INPUT(start.dim() == 1);

auto n_out = torch::empty({walk_length + 1, start.size(0)}, start.options());
auto e_out = torch::empty({walk_length, start.size(0)}, start.options());

auto stream = at::cuda::getCurrentCUDAStream();

auto edge_weight_cdf = torch::empty({edge_weight.size(0)}, edge_weight.options());

cdf_kernel<<<BLOCKS(rowptr.numel()), THREADS, 0, stream>>>(
rowptr.data_ptr<int64_t>(), edge_weight.data_ptr<float_t>(),
edge_weight_cdf.data_ptr<float_t>(), rowptr.numel());

rejection_sampling_weighted_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
time(NULL), rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
edge_weight_cdf.data_ptr<float_t>(), start.data_ptr<int64_t>(),
n_out.data_ptr<int64_t>(), e_out.data_ptr<int64_t>(),
walk_length, start.numel(), p, q);

return std::make_tuple(n_out.t().contiguous(), e_out.t().contiguous());
}

5 changes: 5 additions & 0 deletions csrc/cuda/rw_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,8 @@
std::tuple<torch::Tensor, torch::Tensor>
random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q);

std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor edge_weight, torch::Tensor start,
int64_t walk_length, double p, double q);
18 changes: 17 additions & 1 deletion csrc/rw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,21 @@ random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
}
}

CLUSTER_API std::tuple<torch::Tensor, torch::Tensor>
random_walk_weighted(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor edge_weight, torch::Tensor start,
int64_t walk_length, double p, double q) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
return random_walk_weighted_cuda(rowptr, col, edge_weight, start, walk_length, p, q);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return random_walk_weighted_cpu(rowptr, col, edge_weight, start, walk_length, p, q);
}
}

static auto registry =
torch::RegisterOperators().op("torch_cluster::random_walk", &random_walk);
torch::RegisterOperators().op("torch_cluster::random_walk", &random_walk)
.op("torch_cluster::random_walk_weighted", &random_walk_weighted);
Loading