Skip to content

Commit

Permalink
update negative sampling policy for dense graphs (bug fix)
Browse files (browse the repository at this point in the history)
  • Loading branch information
XinweiHe committed Dec 24, 2020
1 parent 38d98d6 commit fc10d98
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 11 deletions.
29 changes: 22 additions & 7 deletions deepsnap/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1688,12 +1688,12 @@ def _create_neg_sampling(
torch.cat([self.edge_index, self.edge_label_index], -1)
)

if len(edge_index_all) > 0:
negative_edges = self.negative_sampling(
edge_index_all, self.num_nodes, num_neg_edges
)
else:
return torch.tensor([], dtype=torch.long)
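# Handle multigraphs: drop duplicate (parallel) edges so that each node
# pair is counted only once when sampling negatives.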
edge_index_all_unique = torch.unique(edge_index_all, dim=1)

negative_edges = self.negative_sampling(
edge_index_all_unique, self.num_nodes, num_neg_edges
)

if not resample:
if self.edge_label is None:
Expand Down Expand Up @@ -1915,11 +1915,20 @@ def negative_sampling(edge_index, num_nodes=None, num_neg_samples=None):
:rtype: :class:`torch.LongTensor`
"""
num_neg_samples_available = min(
num_neg_samples, num_nodes * num_nodes - edge_index.shape[1]
)

if num_neg_samples_available == 0:
raise ValueError(
"No negative samples could be generated for a complete graph."
)

rng = range(num_nodes ** 2)
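# Flatten each edge (i, j) into a single index: idx = N * i + j,
# where N = num_nodes, i = edge_index[0] and j = edge_index[1].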
idx = (edge_index[0] * num_nodes + edge_index[1]).to("cpu")

perm = torch.tensor(random.sample(rng, num_neg_samples))
perm = torch.tensor(random.sample(rng, num_neg_samples_available))
mask = torch.from_numpy(np.isin(perm, idx)).to(torch.bool)
rest = mask.nonzero().view(-1)
while rest.numel() > 0: # pragma: no cover
Expand All @@ -1931,5 +1940,11 @@ def negative_sampling(edge_index, num_nodes=None, num_neg_samples=None):
row = perm // num_nodes
col = perm % num_nodes
neg_edge_index = torch.stack([row, col], dim=0).long()
if num_neg_samples_available < num_neg_samples:
multiplicity = math.ceil(
num_neg_samples / num_neg_samples_available
)
neg_edge_index = torch.cat([neg_edge_index] * multiplicity, dim=1)
neg_edge_index = neg_edge_index[:, :num_neg_samples]

return neg_edge_index.to(edge_index.device)
50 changes: 46 additions & 4 deletions deepsnap/hetero_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2830,11 +2830,19 @@ def _create_neg_sampling(
)
)

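# Handle multigraphs: for every message type, drop duplicate (parallel)
# edges so that each node pair is counted only once.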
edge_index_all_unique = {}
for message_type in edge_index_all:
edge_index_all_unique[message_type] = torch.unique(
edge_index_all[message_type],
dim=1
)

negative_edges = (
self.negative_sampling(
edge_index_all,
edge_index_all_unique,
self.num_nodes(),
num_neg_edges,
num_neg_edges
)
)

Expand Down Expand Up @@ -2931,6 +2939,22 @@ def negative_sampling(
:rtype: :class:`torch.LongTensor`
"""
num_neg_samples_available = {}
for message_type in edge_index:
head_type = message_type[0]
tail_type = message_type[2]
num_neg_samples_available[message_type] = min(
num_neg_samples[message_type],
num_nodes[head_type]
* num_nodes[tail_type]
- edge_index[message_type].shape[1]
)
if num_neg_samples_available[message_type] == 0:
raise ValueError(
"No negative samples could be generated for a "
f"complete graph in message_type: {message_type}."
)

rng = {}
for message_type in edge_index:
head_type = message_type[0]
Expand Down Expand Up @@ -2960,7 +2984,7 @@ def negative_sampling(
for message_type in edge_index:
samples = random.sample(
rng[message_type],
num_neg_samples[message_type]
num_neg_samples_available[message_type]
)
perm[message_type] = torch.tensor(samples)

Expand Down Expand Up @@ -3019,11 +3043,29 @@ def negative_sampling(
row[message_type],
col[message_type]
],
dim=0,
dim=0
).long()
for message_type in edge_index
}
)
for message_type in edge_index:
if (
num_neg_samples_available[message_type]
< num_neg_samples[message_type]
):
multiplicity = math.ceil(
num_neg_samples[message_type]
/ num_neg_samples_available[message_type]
)
neg_edge_index[message_type] = torch.cat(
[neg_edge_index[message_type]] * multiplicity,
dim=1
)
neg_edge_index[message_type] = (
neg_edge_index[message_type][
:, :num_neg_samples[message_type]
]
)

for message_type in edge_index:
neg_edge_index[message_type].to(
Expand Down

0 comments on commit fc10d98

Please sign in to comment.