forked from amazon-science/tgl
-
Notifications
You must be signed in to change notification settings - Fork 1
/
modules.py
99 lines (92 loc) · 5.08 KB
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
import dgl
from memorys import *
from layers import *
class GeneralModel(torch.nn.Module):
    """Config-driven model: optional node memory, stacked per-history layers, edge predictor.

    The model is assembled from four parameter dicts. Layers are stored in a
    ``ModuleDict`` under keys of the form ``'l{layer}h{history}'``; the
    optional JODIE time embedding of history slot ``h`` lives under
    ``'l0h{h}t'``.

    Args:
        dim_node: Node dimension (by name). Passed to the GRU/RNN memory
            updaters and used as the first layer's input width when no node
            memory is configured.
        dim_edge: Edge dimension (by name). Passed to every attention layer
            and used in the memory updater's second constructor argument.
        sample_param: Dict; ``'history'`` gives the number of history slots.
        memory_param: Dict; ``'type' == 'node'`` enables a memory updater
            selected by ``'memory_update'`` (``'gru'``, ``'rnn'`` or
            ``'transformer'``) and sized by ``'dim_out'`` / ``'dim_time'``.
        gnn_param: Dict; ``'arch'`` is ``'transformer_attention'`` or
            ``'identity'``. Also read: ``'layer'``, ``'dim_time'``,
            ``'att_head'``, ``'dim_out'`` and the optional
            ``'time_transform'`` / ``'combine'`` keys.
            NOTE: this dict is modified in place (``'dim_out'`` is filled in
            when missing; ``'layer'`` is forced to 1 for the identity arch).
        train_param: Dict; ``'dropout'`` and ``'att_dropout'`` are passed to
            the attention layers, and the whole dict to the transformer
            memory updater.
        combined: Forwarded as ``combined=`` to the first-layer attention
            modules only; deeper layers always receive ``combined=False``.

    Raises:
        NotImplementedError: For an unknown ``memory_param['memory_update']``
            or ``gnn_param['arch']``.
    """

    def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, combined=False):
        super().__init__()
        self.dim_node = dim_node
        self.dim_node_input = dim_node
        self.dim_edge = dim_edge
        self.sample_param = sample_param
        self.memory_param = memory_param
        # Written into the caller's dict on purpose: the same object is kept
        # as self.gnn_param, so the default is visible to both sides.
        if 'dim_out' not in gnn_param:
            gnn_param['dim_out'] = memory_param['dim_out']
        self.gnn_param = gnn_param
        self.train_param = train_param

        if memory_param['type'] == 'node':
            update = memory_param['memory_update']
            # The transformer updater takes train_param as its last argument,
            # whereas the GRU/RNN updaters take dim_node.
            if update == 'gru':
                updater_cls, last_arg = GRUMemeoryUpdater, dim_node
            elif update == 'rnn':
                updater_cls, last_arg = RNNMemeoryUpdater, dim_node
            elif update == 'transformer':
                updater_cls, last_arg = TransformerMemoryUpdater, train_param
            else:
                raise NotImplementedError(
                    'unsupported memory_update: {!r}'.format(update))
            dim_mem = memory_param['dim_out']
            # Second argument is 2 * dim_mem + dim_edge (presumably the raw
            # message width -- confirm against the updater implementations).
            self.memory_updater = updater_cls(
                memory_param, 2 * dim_mem + dim_edge, dim_mem,
                memory_param['dim_time'], last_arg)
            # With node memory enabled, the first layer is sized by the
            # memory's dim_out instead of dim_node.
            self.dim_node_input = dim_mem

        def attention_layer(dim_in, use_combined):
            # One attention layer; only the input width and the `combined`
            # flag differ between the first layer and the deeper ones.
            return TransfomerAttentionLayer(
                dim_in, dim_edge, gnn_param['dim_time'], gnn_param['att_head'],
                train_param['dropout'], train_param['att_dropout'],
                gnn_param['dim_out'], combined=use_combined)

        self.layers = torch.nn.ModuleDict()
        arch = gnn_param['arch']
        if arch == 'transformer_attention':
            # Layer 0 is always built, even if gnn_param['layer'] < 1.
            for hist_idx in range(sample_param['history']):
                self.layers[f'l0h{hist_idx}'] = attention_layer(self.dim_node_input, combined)
            for layer_idx in range(1, gnn_param['layer']):
                for hist_idx in range(sample_param['history']):
                    self.layers[f'l{layer_idx}h{hist_idx}'] = attention_layer(gnn_param['dim_out'], False)
        elif arch == 'identity':
            # The identity arch has exactly one layer; this overrides the
            # value in the (shared) gnn_param dict.
            self.gnn_param['layer'] = 1
            use_jodie = gnn_param.get('time_transform') == 'JODIE'
            for hist_idx in range(sample_param['history']):
                self.layers[f'l0h{hist_idx}'] = IdentityNormLayer(self.dim_node_input)
                if use_jodie:
                    self.layers[f'l0h{hist_idx}t'] = JODIETimeEmbedding(gnn_param['dim_out'])
        else:
            raise NotImplementedError('unsupported arch: {!r}'.format(arch))

        self.edge_predictor = EdgePredictor(gnn_param['dim_out'])
        # The combiner exists only for combine == 'rnn'; get_emb needs it
        # whenever sample_param['history'] != 1.
        if gnn_param.get('combine') == 'rnn':
            self.combiner = torch.nn.RNN(gnn_param['dim_out'], gnn_param['dim_out'])

    def forward(self, mfgs, neg_samples=1):
        """Compute embeddings via ``get_emb`` and pass them to the edge predictor.

        Args:
            mfgs: Same structure as accepted by ``get_emb``.
            neg_samples: Forwarded to the edge predictor as ``neg_samples=``.

        Returns:
            Whatever ``self.edge_predictor`` returns.
        """
        return self.edge_predictor(self.get_emb(mfgs), neg_samples=neg_samples)

    def get_emb(self, mfgs):
        """Run the memory updater (if any) and all layers; return the embeddings.

        Args:
            mfgs: Indexed as ``mfgs[layer][history]``; every entry is passed
                to the matching layer and must expose a ``srcdata`` mapping
                (presumably DGL message-flow-graph blocks -- confirm with the
                caller). ``srcdata['h']`` of the next layer's entries is
                overwritten in place with this layer's output.

        Returns:
            The single last-layer output when ``history == 1``. Otherwise the
            per-history outputs are stacked along a new leading dimension,
            fed to ``self.combiner``, and the slice at the last index of that
            leading dimension of the combiner's first return value is returned.
        """
        if self.memory_param['type'] == 'node':
            self.memory_updater(mfgs[0])

        num_layers = self.gnn_param['layer']
        history = self.sample_param['history']
        use_jodie = self.gnn_param.get('time_transform') == 'JODIE'
        last_layer = num_layers - 1

        out = []
        for layer_idx in range(num_layers):
            for hist_idx in range(history):
                block = mfgs[layer_idx][hist_idx]
                rst = self.layers[f'l{layer_idx}h{hist_idx}'](block)
                if use_jodie:
                    # The time embedding is always looked up under the
                    # layer-0 key, regardless of the current layer index.
                    rst = self.layers[f'l0h{hist_idx}t'](
                        rst, block.srcdata['mem_ts'], block.srcdata['ts'])
                if layer_idx != last_layer:
                    # Feed this layer's output to the next layer's input.
                    mfgs[layer_idx + 1][hist_idx].srcdata['h'] = rst
                else:
                    out.append(rst)

        if history == 1:
            return out[0]
        stacked = torch.stack(out, dim=0)
        return self.combiner(stacked)[0][-1, :, :]
class NodeClassificationModel(torch.nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        dim_in: Input width of the first linear layer.
        dim_hid: Output width of the first linear layer and input width of
            the second.
        num_class: Output width of the second linear layer.
    """

    def __init__(self, dim_in, dim_hid, num_class):
        super().__init__()
        self.fc1 = torch.nn.Linear(in_features=dim_in, out_features=dim_hid)
        self.fc2 = torch.nn.Linear(in_features=dim_hid, out_features=num_class)

    def forward(self, x):
        """Return the second linear layer's output for ``x`` (no softmax is applied)."""
        hidden = torch.nn.functional.relu(self.fc1(x))
        return self.fc2(hidden)