Skip to content

Commit

Permalink
Fix edge weight normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
pbielak committed Sep 9, 2022
1 parent c4154c8 commit bcda1c7
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions torch_cluster/rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,17 @@ def random_walk(
rowptr, col, start, walk_length, p, q,
)
else:
# Normalize edge weights by node degrees
edge_weight = edge_weight / deg[row]
# Normalize edge weights
from torch_sparse import SparseTensor

adj = SparseTensor(
row=row,
col=col,
value=edge_weight,
sparse_sizes=(num_nodes, num_nodes),
)

edge_weight = edge_weight / adj.sum(dim=1).repeat_interleave(deg)

node_seq, edge_seq = torch.ops.torch_cluster.random_walk_weighted(
rowptr, col, edge_weight, start, walk_length, p, q,
Expand Down

0 comments on commit bcda1c7

Please sign in to comment.