import torch
from torch import nn

# Define some constants
KERNEL_SIZE = 3
PADDING = KERNEL_SIZE // 2
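# With an odd kernel size, PADDING = KERNEL_SIZE // 2 gives "same" convolutions:
# a 3x3 kernel with padding 1 keeps the 4x8 tactile feature maps at 4x8.
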

class ConvLSTMCell(nn.Module):
    """
    Generate a convolutional LSTM cell.

    A single Conv2d produces all four LSTM gates at once, and a linear head maps
    the flattened hidden state to two classes (slip / no slip).
    """

    def __init__(self, input_size, hidden_size):
        super(ConvLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # one convolution computes the input, forget, output and cell gates together
        self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size,
                               KERNEL_SIZE, padding=PADDING)
        # spatial size of the tactile maps expected by the classifier head
        self.height, self.width = 4, 8
        self.linear = nn.Linear(self.hidden_size * self.height * self.width, 2)

    def forward(self, input_, prev_state):
        # get batch and spatial sizes
        batch_size = input_.size(0)
        spatial_size = input_.size()[2:]

        # generate empty prev_state, if None is provided
        if prev_state is None:
            state_size = [batch_size, self.hidden_size] + list(spatial_size)
            prev_state = (
                torch.zeros(state_size, device=input_.device),
                torch.zeros(state_size, device=input_.device),
            )

        prev_hidden, prev_cell = prev_state

        # data size is [batch, channel, height, width]
        stacked_inputs = torch.cat((input_, prev_hidden), 1)
        gates = self.Gates(stacked_inputs)

        # chunk across the channel dimension
        in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)

        # apply sigmoid non-linearity to the gates
        in_gate = torch.sigmoid(in_gate)
        remember_gate = torch.sigmoid(remember_gate)
        out_gate = torch.sigmoid(out_gate)

        # apply tanh non-linearity to the candidate cell state
        cell_gate = torch.tanh(cell_gate)

        # compute current cell and hidden state
        cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
        hidden = out_gate * torch.tanh(cell)

        # flatten the hidden state and classify slip / no slip
        flat = hidden.view(-1, hidden.size(1) * hidden.size(2) * hidden.size(3))
        out = self.linear(flat)

        return out, (hidden, cell)
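
# The forward pass above is the standard ConvLSTM update, where each gate is a
# convolution of the stacked [x_t, h_{t-1}] (the four kernel chunks of self.Gates)
# and '.' denotes element-wise multiplication:
#   i_t = sigmoid(W_i * [x_t, h_{t-1}] + b_i)   (in_gate)
#   f_t = sigmoid(W_f * [x_t, h_{t-1}] + b_f)   (remember_gate)
#   o_t = sigmoid(W_o * [x_t, h_{t-1}] + b_o)   (out_gate)
#   g_t = tanh(W_g * [x_t, h_{t-1}] + b_g)      (cell_gate)
#   c_t = f_t . c_{t-1} + i_t . g_t
#   h_t = o_t . tanh(c_t)
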

def _main():
    """
    Run some basic tests on the API
    """
    # define batch_size, channels, height, width
    b, c, h, w = 16, 3, 4, 8
    d = 5           # hidden state size
    lr = 1e-1       # learning rate
    T = 10          # sequence length
    max_epoch = 20  # number of epochs

    # set manual seed
    torch.manual_seed(0)

    print('Instantiate model')
    model = ConvLSTMCell(c, d)
    print(repr(model))

    print('Create input and target tensors')
    x = torch.rand(T, b, c, h, w)
    # random binary slip / no-slip labels, one per sequence in the batch
    y = torch.randint(0, 2, (b,))
    print(x.size(), y.size())

    print('Create a cross-entropy criterion')
    loss_fn = nn.CrossEntropyLoss()

    print('Run for', max_epoch, 'iterations')
    for epoch in range(max_epoch):
        state = None
        out = None

        # unroll the cell over the sequence; only the last output is scored
        for t in range(T):
            out, state = model(x[t], state)

        loss = loss_fn(out, y)
        print(' > Epoch {:2d} loss: {:.3f}'.format(epoch + 1, loss.item()))

        # zero grad parameters
        model.zero_grad()

        # compute new grad parameters through time!
        loss.backward()

        # plain SGD step against the gradient
        with torch.no_grad():
            for p in model.parameters():
                p.sub_(p.grad * lr)

    print('Input size:', list(x.size()))
    print('Target size:', list(y.size()))
    print('Last hidden state size:', list(state[0].size()))
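
def predict_slip(model, sequence):
    """
    Minimal inference sketch, not part of the original training loop: unroll a
    trained ConvLSTMCell over one tactile sequence of shape
    [T, batch, channels, height, width] and return the predicted class index per
    batch element (0 = no slip, 1 = slip is an assumed label order).
    """
    model.eval()
    state = None
    logits = None
    with torch.no_grad():
        for t in range(sequence.size(0)):
            logits, state = model(sequence[t], state)
    # class index with the highest score at the final time step
    return logits.argmax(dim=1)
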

if __name__ == '__main__':
    _main()