forked from wangxu-scu/DRSL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
114 lines (94 loc) · 4.02 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
import torch.nn as nn
import torch.nn.functional as F
class ImageDNN(nn.Module):
    """Image-side encoder: three ``Linear -> BatchNorm1d -> ReLU`` stages.

    The stages map ``input_dim -> hidden_dim -> hidden_dim -> output_dim``.
    Because the last stage also ends in a ReLU, every output value is
    non-negative.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        hidden_dim: Width of the two hidden stages.
        output_dim: Width of the produced embedding.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        stage_shapes = (
            (input_dim, hidden_dim),
            (hidden_dim, hidden_dim),
            (hidden_dim, output_dim),
        )
        layers = []
        for fan_in, fan_out in stage_shapes:
            layers += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
            ]
        # The attribute keeps its original name so parameter keys stay
        # ``Sequential.<index>.*`` and existing checkpoints still load.
        self.Sequential = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stack and return the embedding.

        ``x`` is presumably a 2-D ``(batch, input_dim)`` tensor -- confirm
        against the caller.
        """
        return self.Sequential(x)
class TextDNN(nn.Module):
    """Text-side encoder: three ``Linear -> BatchNorm1d -> ReLU`` stages.

    Layer widths run ``input_dim -> hidden_dim -> hidden_dim -> output_dim``.
    The final ReLU makes every output value non-negative.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        hidden_dim: Width of the two hidden stages.
        output_dim: Width of the produced embedding.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        widths = (input_dim, hidden_dim, hidden_dim, output_dim)
        stages = []
        # Walk consecutive width pairs so each stage is built in order.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.BatchNorm1d(fan_out))
            stages.append(nn.ReLU())
        # Attribute name is unchanged so parameter keys remain
        # ``Sequential.<index>.*``.
        self.Sequential = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stack and return the embedding.

        ``x`` is presumably a 2-D ``(batch, input_dim)`` tensor -- confirm
        against the caller.
        """
        return self.Sequential(x)
class RelationDNN(nn.Module):
    """Relation head: two ``Linear -> BatchNorm1d -> ReLU`` stages plus a linear output.

    The last ``Linear`` has no normalisation or activation after it, so the
    returned values are unbounded raw scores.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        hidden_dim: Width of the two hidden stages.
        output_dim: Number of values produced per input row.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        blocks = []
        for fan_in in (input_dim, hidden_dim):
            blocks.extend(
                (
                    nn.Linear(fan_in, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.ReLU(),
                )
            )
        # Output projection, deliberately left without BatchNorm/ReLU.
        blocks.append(nn.Linear(hidden_dim, output_dim))
        # Attribute name is unchanged so parameter keys remain
        # ``Sequential.<index>.*``.
        self.Sequential = nn.Sequential(*blocks)

    def forward(self, x):
        """Return the raw scores for ``x``; no sigmoid is applied here."""
        return self.Sequential(x)
class Model(nn.Module):
    """Two-branch encoder with a relation head scoring every image/text pair.

    Images and texts are embedded by separate sub-networks. The relation
    head is then applied to the concatenation of each image embedding with
    each text embedding.

    Args:
        input_dim_I: Feature size of the image input.
        input_dim_T: Feature size of the text input.
        hidden_dim_I: Hidden width of the image encoder.
        hidden_dim_T: Hidden width of the text encoder.
        hidden_dim_R: Hidden width of the relation head.
        output_dim_I: Embedding size produced by the image encoder.
        output_dim_T: Embedding size produced by the text encoder.
        output_dim_R: Number of values the relation head emits per pair.
    """

    def __init__(
            self,
            input_dim_I,
            input_dim_T,
            hidden_dim_I,
            hidden_dim_T,
            hidden_dim_R,
            output_dim_I,
            output_dim_T,
            output_dim_R):
        super(Model, self).__init__()
        self.ImageDNN = ImageDNN(input_dim_I, hidden_dim_I, output_dim_I)
        self.TextDNN = TextDNN(input_dim_T, hidden_dim_T, output_dim_T)
        # The relation head consumes [image embedding | text embedding].
        self.RelationDNN = RelationDNN(
            output_dim_I + output_dim_T, hidden_dim_R, output_dim_R)

    def forward(self, img, text, return_relation_score=True):
        """Embed both modalities and optionally score all cross pairs.

        Args:
            img: Input passed to the image encoder.
            text: Input passed to the text encoder.
            return_relation_score: When truthy (default), return the pairwise
                relation scores. When falsy, return the two embeddings.

        Returns:
            The output of ``cal_relation_score`` when ``return_relation_score``
            is truthy, otherwise the tuple ``(y_I, y_T)`` of embeddings.
        """
        # Image pathway.
        y_I = self.ImageDNN(img)
        # Text pathway.
        y_T = self.TextDNN(text)

        # Truthiness test rather than ``is False``: a falsy flag that is not
        # the ``False`` singleton (0, numpy.bool_, ...) must also select the
        # embeddings branch instead of silently falling through.
        if not return_relation_score:
            return y_I, y_T

        return self.cal_relation_score(y_I, y_T)

    def cal_relation_score(self, y_I, y_T):
        """Apply the relation head to every (image, text) embedding pair.

        Args:
            y_I: 2-D tensor of shape ``(ni, di)`` holding image embeddings.
            y_T: 2-D tensor of shape ``(nt, dt)`` holding text embeddings.

        Returns:
            The relation head's output for an ``(ni * nt, di + dt)`` input in
            which row ``i * nt + j`` is ``y_I[i]`` concatenated with ``y_T[j]``.
        """
        ni, di = y_I.size(0), y_I.size(1)
        nt, dt = y_T.size(0), y_T.size(1)

        # Repeat each image embedding once per text: (ni, di) -> (ni * nt, di).
        y_I = y_I.unsqueeze(1).expand(ni, nt, di).reshape(-1, di)
        # Tile the full text batch once per image: (nt, dt) -> (ni * nt, dt).
        y_T = y_T.unsqueeze(0).expand(ni, nt, dt).reshape(-1, dt)

        pairs = torch.cat((y_I, y_T), 1)
        return self.RelationDNN(pairs)