forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathggsnn.py
102 lines (81 loc) · 3.03 KB
/
ggsnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
Gated Graph Sequence Neural Network for sequence outputs
"""
import torch
import torch.nn.functional as F
from dgl.nn.pytorch import GatedGraphConv, GlobalAttentionPooling
from torch import nn
class GGSNN(nn.Module):
    """Gated Graph Sequence Neural Network producing one prediction per step.

    At every output step the current node annotations are zero-padded to
    ``out_feats`` and propagated with a ``GatedGraphConv``; the propagated
    node states (concatenated with the annotations) are then used both to
    predict the step's class logits through attention pooling and to
    compute the node annotations fed to the next step.

    Args:
        annotation_size: Width of the per-node annotation vector.
        out_feats: Hidden size of the gated graph convolution. Must be at
            least ``annotation_size`` because annotations are zero-padded
            up to this width.
        n_steps: Number of propagation steps inside ``GatedGraphConv``.
        n_etypes: Number of edge types passed to ``GatedGraphConv``.
        max_seq_length: Number of output steps generated by ``forward``.
        num_cls: Number of classes predicted at every output step.

    Raises:
        ValueError: If ``out_feats`` is smaller than ``annotation_size``.
    """

    def __init__(
        self,
        annotation_size: int,
        out_feats: int,
        n_steps: int,
        n_etypes: int,
        max_seq_length: int,
        num_cls: int,
    ) -> None:
        super().__init__()
        # Fail at construction time rather than with a negative-dimension
        # error from torch.zeros on the first forward pass.
        if out_feats < annotation_size:
            raise ValueError(
                f"out_feats ({out_feats}) must be >= annotation_size "
                f"({annotation_size})"
            )

        self.annotation_size = annotation_size
        self.out_feats = out_feats
        self.max_seq_length = max_seq_length

        self.ggnn = GatedGraphConv(
            in_feats=out_feats,
            out_feats=out_feats,
            n_steps=n_steps,
            n_etypes=n_etypes,
        )

        # All three heads consume [propagated state, annotation].
        joint_feats = annotation_size + out_feats
        self.annotation_out_layer = nn.Linear(joint_feats, annotation_size)
        pooling_gate_nn = nn.Linear(joint_feats, 1)
        self.pooling = GlobalAttentionPooling(pooling_gate_nn)
        self.output_layer = nn.Linear(joint_feats, num_cls)

        # Not used by forward() below (it calls sequence_loss); kept so the
        # attribute remains available to existing callers.
        self.loss_fn = nn.CrossEntropyLoss(reduction="none")

    def forward(self, graph, seq_lengths, ground_truth=None):
        """Generate ``max_seq_length`` predictions for ``graph``.

        Note:
            The ``"type"`` edge feature and the ``"annotation"`` node
            feature are *popped* from ``graph``, so they are no longer
            stored on the graph after this call.

        Args:
            graph: Graph carrying ``edata["type"]`` and
                ``ndata["annotation"]``; the annotation's last dimension
                must equal ``annotation_size``.
            seq_lengths: Per-sample sequence lengths, forwarded to
                ``sequence_loss``; only read when ``ground_truth`` is given.
            ground_truth: Optional targets forwarded to ``sequence_loss``.

        Returns:
            ``preds`` (argmax over the class axis of the per-step logits,
            stacked along dim 1) when ``ground_truth`` is ``None``,
            otherwise the tuple ``(loss, preds)``.
        """
        etypes = graph.edata.pop("type")
        annotation = graph.ndata.pop("annotation").float()
        assert annotation.size()[-1] == self.annotation_size, (
            f"expected annotation width {self.annotation_size}, "
            f"got {annotation.size()[-1]}"
        )

        # The padding depends only on the node count, the layer sizes and
        # the device, none of which change between steps, so build it once.
        zero_pad = torch.zeros(
            [graph.num_nodes(), self.out_feats - self.annotation_size],
            dtype=torch.float,
            device=annotation.device,
        )

        all_logits = []
        for _ in range(self.max_seq_length):
            # The propagation input is detached from the previous step's
            # annotation prediction; the non-detached annotation is still
            # concatenated to the propagated state right after.
            h1 = torch.cat([annotation.detach(), zero_pad], -1)
            out = self.ggnn(graph, h1, etypes)
            out = torch.cat([out, annotation], -1)

            # Graph-level readout followed by the per-step classifier.
            logits = self.output_layer(self.pooling(graph, out))
            all_logits.append(logits)

            # Node annotations for the next output step.
            annotation = F.softmax(self.annotation_out_layer(out), -1)

        all_logits = torch.stack(all_logits, 1)
        preds = torch.argmax(all_logits, -1)

        if ground_truth is not None:
            loss = sequence_loss(all_logits, ground_truth, seq_lengths)
            return loss, preds
        return preds
def sequence_loss(logits, ground_truth, seq_length=None):
    """Return the mean cross-entropy of a batch of sequences.

    Args:
        logits: Tensor indexed as ``(batch, step, class)``.
        ground_truth: Class indices indexed as ``(batch, step)``.
        seq_length: Optional per-sample number of valid steps (a tensor or
            a sequence of ints). When given, steps at index
            ``>= seq_length[i]`` are ignored and each sample's summed loss
            is divided by its length before averaging over the batch.
            When ``None``, the loss is averaged over every step.

    Returns:
        A scalar tensor holding the loss.

    Note:
        A sample with length 0 contributes a loss of 0 instead of the
        ``0 / 0 = NaN`` that would otherwise propagate to the batch mean.
    """
    # CrossEntropyLoss expects the class axis at dim 1.
    per_step = nn.CrossEntropyLoss(reduction="none")(
        logits.permute((0, 2, 1)), ground_truth
    )
    if seq_length is None:
        return per_step.mean()

    # Keep the lengths on the same device as the loss so that the mask can
    # be combined with it even if the caller left them on the CPU.
    lengths = torch.as_tensor(seq_length, device=per_step.device)
    steps = torch.arange(per_step.size(1), device=per_step.device)

    # Broadcast (1, step) against (batch, 1): True where the step is valid.
    mask = (steps.unsqueeze(0) < lengths.unsqueeze(-1)).to(per_step.dtype)

    # Clamp the divisor so empty sequences yield 0 rather than NaN.
    denom = lengths.to(per_step.dtype).clamp(min=1)
    per_sample = (per_step * mask).sum(-1) / denom
    return per_sample.mean()