recurrent_NNs.py
import torch
import torch.nn as nn


# RNN-based language model: embedding -> LSTM -> linear decoder over the vocabulary
class RNNLM(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
        super(RNNLM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, h):
        # Embed word ids to vectors: (batch, seq_len) -> (batch, seq_len, embed_size)
        x = self.embed(x)
        # Forward propagate the LSTM; h is an (hidden state, cell state) tuple
        out, (h, c) = self.lstm(x, h)
        # Reshape output to (batch_size * sequence_length, hidden_size)
        out = out.reshape(out.size(0) * out.size(1), out.size(2))
        # Decode the hidden states of all time steps into vocabulary logits
        out = self.linear(out)
        return out, (h, c)
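

# A minimal usage sketch for RNNLM. The batch size, sequence length, and
# hyperparameter values below are illustrative assumptions, not values taken
# from this repository. nn.LSTM expects the initial state as an (h0, c0)
# tuple, each tensor of shape (num_layers, batch_size, hidden_size).
def demo_rnnlm():
    vocab_size, embed_size, hidden_size, num_layers = 1000, 128, 256, 2
    batch_size, seq_len = 4, 10
    model = RNNLM(vocab_size, embed_size, hidden_size, num_layers)
    x = torch.randint(0, vocab_size, (batch_size, seq_len))
    h0 = torch.zeros(num_layers, batch_size, hidden_size)
    c0 = torch.zeros(num_layers, batch_size, hidden_size)
    out, (h, c) = model(x, (h0, c0))
    # out has shape (batch_size * seq_len, vocab_size): one logit row per token
    assert out.shape == (batch_size * seq_len, vocab_size)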


# Vanilla RNN language model with a ReLU nonlinearity and a linear readout layer
class RNNModel(nn.Module):
    def __init__(self, embed_size, hidden_size, num_layers, vocab_size):
        super(RNNModel, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        # Number of hidden dimensions
        self.hidden_size = hidden_size
        # Number of stacked RNN layers
        self.num_layers = num_layers
        # RNN
        self.rnn = nn.RNN(
            embed_size, hidden_size, num_layers, batch_first=True, nonlinearity="relu"
        )
        # Readout layer
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, h0):
        # Embedding output already tracks gradients, so no requires_grad_() is needed
        x = self.embed(x)
        # Run the whole sequence through the RNN in one call
        out, hn = self.rnn(x, h0)
        # Decode every time step: (batch, seq_len, hidden_size) -> (batch, seq_len, vocab_size)
        out = self.fc(out)
        return out, hn
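

# A minimal usage sketch for RNNModel (illustrative values, not from this
# file). For nn.RNN the initial hidden state h0 is a single tensor of shape
# (num_layers, batch_size, hidden_size), unlike the LSTM's (h0, c0) tuple.
def demo_rnnmodel():
    vocab_size, embed_size, hidden_size, num_layers = 1000, 128, 256, 2
    batch_size, seq_len = 4, 10
    model = RNNModel(embed_size, hidden_size, num_layers, vocab_size)
    x = torch.randint(0, vocab_size, (batch_size, seq_len))
    h0 = torch.zeros(num_layers, batch_size, hidden_size)
    out, hn = model(x, h0)
    # Here the output keeps its sequence structure: (batch, seq_len, vocab_size)
    assert out.shape == (batch_size, seq_len, vocab_size)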


# Bidirectional LSTM language model; the decoder sees the concatenated forward
# and backward hidden states (2 * hidden_size features per time step)
class RNNLM_bilstm(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
        super(RNNLM_bilstm, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(
            embed_size, hidden_size, num_layers, batch_first=True, bidirectional=True
        )
        self.linear = nn.Linear(hidden_size * 2, vocab_size)

    def forward(self, x, h, out_inds):
        # Embed word ids to vectors
        x = self.embed(x)
        # Forward propagate the LSTM
        out, (h, c) = self.lstm(x, h)
        # Reshape output to (batch_size * sequence_length, 2 * hidden_size)
        out = out.reshape(out.size(0) * out.size(1), out.size(2))
        # Keep only the requested rows of the flattened output
        out = out[out_inds]
        # Decode the selected hidden states into vocabulary logits
        out = self.linear(out)
        return out, (h, c)
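

# A minimal usage sketch for RNNLM_bilstm (illustrative values, not from this
# file). With bidirectional=True, the state tensors have a first dimension of
# num_layers * 2. The out_inds argument selects rows of the flattened
# (batch_size * seq_len, 2 * hidden_size) output; picking the last time step
# of each sequence, as done here, is one assumed use of it.
def demo_rnnlm_bilstm():
    vocab_size, embed_size, hidden_size, num_layers = 1000, 128, 256, 2
    batch_size, seq_len = 4, 10
    model = RNNLM_bilstm(vocab_size, embed_size, hidden_size, num_layers)
    x = torch.randint(0, vocab_size, (batch_size, seq_len))
    h0 = torch.zeros(num_layers * 2, batch_size, hidden_size)
    c0 = torch.zeros(num_layers * 2, batch_size, hidden_size)
    # Row index of the last position of each sequence in the flattened output
    out_inds = [i * seq_len + (seq_len - 1) for i in range(batch_size)]
    out, (h, c) = model(x, (h0, c0), out_inds)
    # One logit row per selected position
    assert out.shape == (batch_size, vocab_size)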