-
Notifications
You must be signed in to change notification settings - Fork 0
/
conv.py
155 lines (125 loc) · 6.03 KB
/
conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Heterogeneous Graph Transformer (HGT) convolution layer.
import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, uniform
from torch_geometric.utils import softmax
class HGTConv(MessagePassing):
    """Heterogeneous Graph Transformer (HGT) convolution layer.

    Every edge ``<j, i>`` (``j`` = source, ``i`` = target) is processed with
    parameters selected by its meta relation
    ``<node_type[j], edge_type, node_type[i]>``:

    1. Heterogeneous mutual attention: Q is projected with the target type's
       linear layer, K with the source type's; K is then transformed by a
       per-relation, per-head ``d_k x d_k`` matrix and the dot-product logits
       are scaled by a per-relation, per-head prior and ``1 / sqrt(d_k)``.
       Logits are softmax-normalised over the edges sharing a target node.
    2. Heterogeneous message passing: V is projected with the source type's
       linear layer and transformed by a per-relation ``d_k x d_k`` matrix.
    3. Target-specific aggregation (:meth:`update`): summed messages go
       through GELU, a per-target-type linear layer and dropout, then a
       learnable gated residual and (optionally) a per-type LayerNorm.

    Args:
        in_dim: Feature size of the input node representations.
        out_dim: Feature size of the output representations. Must be a
            multiple of ``n_heads``. The residual in :meth:`update` adds the
            input features to ``out_dim``-sized outputs, so it only works when
            ``in_dim == out_dim``.
        num_types: Number of node types. Only ids in ``range(num_types)`` are
            processed; other edges/nodes get zero messages / zero outputs.
        num_relations: Number of relation (edge) types. Only ids in
            ``range(num_relations)`` are processed.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied after the per-type output linear.
        use_norm: If True, apply a per-type LayerNorm to the output.
        use_RTE: Reserved for relative temporal encoding. Not implemented: the
            flag is accepted for interface compatibility but has no effect
            (a warning is issued when it is set).
        **kwargs: Forwarded to ``MessagePassing.__init__``.

    Raises:
        ValueError: If ``n_heads`` is not a positive divisor of ``out_dim``.
    """

    def __init__(self, in_dim, out_dim, num_types, num_relations, n_heads,
                 dropout=0.2, use_norm=True, use_RTE=False, **kwargs):
        # Fail early: a non-divisible out_dim makes the per-head .view() calls
        # below either error obscurely or silently regroup rows.
        if n_heads <= 0 or out_dim % n_heads != 0:
            raise ValueError(
                f'n_heads must be a positive divisor of out_dim, '
                f'got out_dim={out_dim}, n_heads={n_heads}')
        super().__init__(aggr='add', **kwargs)
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_types = num_types
        self.num_relations = num_relations
        # Number of possible <source type, relation, target type> triples.
        # Kept as a public attribute; not used inside this layer.
        self.total_rel = num_types * num_relations * num_types
        self.n_heads = n_heads
        self.d_k = out_dim // n_heads
        self.sqrt_dk = math.sqrt(self.d_k)
        self.use_norm = use_norm
        self.use_RTE = use_RTE
        # Attention weights of the most recent forward pass (set in message()),
        # kept for later inspection / visualisation.
        self.att = None

        # One set of projections per node type. K/V are indexed by the source
        # type, Q and the output projection A by the target type.
        self.k_linears = nn.ModuleList()
        self.q_linears = nn.ModuleList()
        self.v_linears = nn.ModuleList()
        self.a_linears = nn.ModuleList()
        self.norms = nn.ModuleList()
        for _ in range(num_types):
            self.k_linears.append(nn.Linear(in_dim, out_dim))
            self.q_linears.append(nn.Linear(in_dim, out_dim))
            self.v_linears.append(nn.Linear(in_dim, out_dim))
            self.a_linears.append(nn.Linear(out_dim, out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))

        # Per-relation, per-head prior multiplying the attention logits.
        self.relation_pri = nn.Parameter(torch.ones(num_relations, n_heads))
        # Per-relation, per-head d_k x d_k transforms applied to K and to V.
        self.relation_att = nn.Parameter(
            torch.empty(num_relations, n_heads, self.d_k, self.d_k))
        self.relation_msg = nn.Parameter(
            torch.empty(num_relations, n_heads, self.d_k, self.d_k))
        # Per-type residual gate; sigmoid(skip[t]) weights the new representation.
        self.skip = nn.Parameter(torch.ones(num_types))
        self.dropout = nn.Dropout(dropout)

        if use_RTE:
            # TODO: add relative temporal encoding (modelling of edge_time).
            warnings.warn(
                'HGTConv: use_RTE=True is not implemented; '
                'temporal information (edge_time) is ignored.',
                stacklevel=2)

        glorot(self.relation_att)
        glorot(self.relation_msg)

    def forward(self, node_inp, node_type, edge_index, edge_type, edge_time=None):
        """Run one HGT layer.

        Args:
            node_inp: Node features; the last dimension must be ``in_dim``
                (they are fed to ``nn.Linear(in_dim, out_dim)`` layers).
            node_type: Per-node integer type ids.
            edge_index: Connectivity passed to ``propagate``; presumably the
                usual PyG ``(2, num_edges)`` layout — TODO confirm with callers.
            edge_type: Per-edge integer relation ids.
            edge_time: Per-edge timestamps for relative temporal encoding,
                which is not implemented; currently unused.

        Returns:
            The output of :meth:`update`: one ``out_dim``-sized row per node.
        """
        return self.propagate(edge_index, node_inp=node_inp, node_type=node_type,
                              edge_type=edge_type, edge_time=edge_time)

    def message(self, edge_index_i, node_inp_i, node_inp_j, node_type_i,
                node_type_j, edge_type, edge_time):
        """Compute attention-weighted messages for every edge ``<j, i>``.

        ``j`` is the source node and ``i`` the target node; ``*_i`` / ``*_j``
        arguments are node attributes gathered per edge by ``propagate``.

        Returns:
            Tensor of shape ``(num_edges, out_dim)``: each edge's per-head
            message scaled by its softmax-normalised attention weight.
        """
        num_edges = edge_index_i.size(0)
        # Allocate with the input's dtype *and* device: a default float32
        # buffer would downcast messages and make the per-type linear layers in
        # update() fail after model.double() / model.half().
        res_att = node_inp_i.new_zeros(num_edges, self.n_heads)
        res_msg = node_inp_i.new_zeros(num_edges, self.n_heads, self.d_k)

        for source_type in range(self.num_types):
            source_mask = node_type_j == source_type
            k_linear = self.k_linears[source_type]
            v_linear = self.v_linears[source_type]
            for target_type in range(self.num_types):
                pair_mask = (node_type_i == target_type) & source_mask
                q_linear = self.q_linears[target_type]
                for relation_type in range(self.num_relations):
                    # All edges with meta relation <source_type, relation_type, target_type>.
                    idx = (edge_type == relation_type) & pair_mask
                    if not idx.any():
                        continue
                    target_node_vec = node_inp_i[idx]
                    source_node_vec = node_inp_j[idx]
                    # TODO: relative temporal encoding of source_node_vec
                    # (use_RTE) is not implemented.

                    # Step 1: heterogeneous mutual attention.
                    q_mat = q_linear(target_node_vec).view(-1, self.n_heads, self.d_k)
                    k_mat = k_linear(source_node_vec).view(-1, self.n_heads, self.d_k)
                    # Move heads to the batch dim, (E, H, d_k) -> (H, E, d_k), so
                    # bmm applies each head's own d_k x d_k relation matrix.
                    k_mat = torch.bmm(k_mat.transpose(1, 0),
                                      self.relation_att[relation_type]).transpose(1, 0)
                    res_att[idx] = ((q_mat * k_mat).sum(dim=-1)
                                    * self.relation_pri[relation_type] / self.sqrt_dk)

                    # Step 2: heterogeneous message passing.
                    v_mat = v_linear(source_node_vec).view(-1, self.n_heads, self.d_k)
                    res_msg[idx] = torch.bmm(v_mat.transpose(1, 0),
                                             self.relation_msg[relation_type]).transpose(1, 0)

        # Softmax over the edges that share a target node (edge_index_i); the
        # result is kept on self.att for later visualisation.
        self.att = softmax(res_att, edge_index_i)
        res = res_msg * self.att.view(-1, self.n_heads, 1)
        del res_msg, res_att
        return res.view(-1, self.out_dim)

    def update(self, aggr_out, node_inp, node_type):
        """Step 3: target-specific aggregation with a learnable gated residual.

        For each node of type ``t``::

            alpha = sigmoid(skip[t])
            out = dropout(A_t(gelu(aggr_out))) * alpha + node_inp * (1 - alpha)
            out = LayerNorm_t(out)        # only when use_norm

        Nodes whose type is not in ``range(num_types)`` get an all-zero row.
        The residual term requires ``node_inp`` to have ``out_dim`` features.
        """
        aggr_out = F.gelu(aggr_out)
        # Match the input's dtype/device so double/half models keep their
        # precision (a default float32 buffer would silently cast).
        res = node_inp.new_zeros(aggr_out.size(0), self.out_dim)
        for target_type in range(self.num_types):
            idx = node_type == target_type
            if not idx.any():
                continue
            trans_out = self.dropout(self.a_linears[target_type](aggr_out[idx]))
            # Learnable residual gate: alpha weights the new representation,
            # (1 - alpha) the layer input.
            alpha = torch.sigmoid(self.skip[target_type])
            mixed = trans_out * alpha + node_inp[idx] * (1 - alpha)
            res[idx] = self.norms[target_type](mixed) if self.use_norm else mixed
        return res

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_dim={self.in_dim}, '
                f'out_dim={self.out_dim}, num_types={self.num_types}, '
                f'num_relations={self.num_relations})')