-
Notifications
You must be signed in to change notification settings - Fork 63
/
bilstm.py
35 lines (30 loc) · 1.41 KB
/
bilstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
class BiLSTMSentiment(nn.Module):
    """Single-layer bidirectional LSTM sentence classifier.

    Token indices are embedded, run through a bidirectional LSTM, and the
    final forward and backward hidden states are concatenated and projected
    onto ``label_size`` classes, returned as log-probabilities.

    Args:
        embedding_dim: Size of each token embedding.
        hidden_dim: Hidden size of each LSTM direction.
        vocab_size: Number of rows in the embedding table.
        label_size: Number of output classes.
        use_gpu: If true, ``init_hidden`` allocates the initial state on CUDA.
        batch_size: Batch size the recurrent state is allocated for; inputs
            to ``forward`` are reshaped to this many columns.
        dropout: Stored as ``self.dropout``. NOTE(review): it is not applied
            anywhere in this class.
    """

    def __init__(self, embedding_dim, hidden_dim, vocab_size, label_size, use_gpu, batch_size, dropout=0.5):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.use_gpu = use_gpu
        self.batch_size = batch_size
        self.dropout = dropout
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, bidirectional=True)
        # The two directions' final states are concatenated, hence 2 * hidden_dim.
        self.hidden2label = nn.Linear(hidden_dim * 2, label_size)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        """Return a zeroed ``(h_0, c_0)`` pair for the LSTM.

        Each tensor has shape ``(2, batch_size, hidden_dim)``, one slice per
        direction, and is placed on CUDA when ``use_gpu`` is set, else on CPU.
        """
        device = "cuda" if self.use_gpu else "cpu"
        shape = (2, self.batch_size, self.hidden_dim)
        return (torch.zeros(shape, device=device),
                torch.zeros(shape, device=device))

    def forward(self, sentence):
        """Return class log-probabilities for a batch of token-index sequences.

        Args:
            sentence: Integer token indices, reshaped below to
                ``(len(sentence), batch_size)`` -- i.e. sequence-first, the
                layout ``nn.LSTM`` expects with its default ``batch_first=False``.

        Returns:
            Tensor of shape ``(batch_size, label_size)`` holding log-probabilities.
        """
        x = self.embeddings(sentence).view(len(sentence), self.batch_size, -1)
        # The final state is stored back on self.hidden and used as the initial
        # state of the next call. NOTE(review): presumably callers reset it
        # with init_hidden() before each batch -- confirm in the training loop.
        _, self.hidden = self.lstm(x, self.hidden)
        h_n = self.hidden[0]
        # lstm_out[-1] would pair the forward state after the whole sentence
        # with a backward state that has only seen the last token. h_n holds
        # each direction's state after reading the full sequence; the last two
        # slices are the forward and backward states of the top layer.
        features = torch.cat((h_n[-2], h_n[-1]), dim=1)
        y = self.hidden2label(features)
        # Explicit dim=1 (the class axis) matches what the deprecated implicit
        # default chose for 2-D input.
        return F.log_softmax(y, dim=1)