model.py
#!/usr/bin/env python
import torch.nn as nn
# ==================================================================================================
# -- model -----------------------------------------------------------------------------------------
# ==================================================================================================
class ImageCaptioningNet(nn.Module):
    """Wraps an encoder and a decoder into a single image captioning model."""

    def __init__(self, encoder, decoder):
        super(ImageCaptioningNet, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def trainable_parameters(self):
        """Returns the combined trainable parameters of the encoder and the decoder."""
        return self.encoder.trainable_parameters() + self.decoder.trainable_parameters()
    def forward(self, imgs, train_caps, lengths):
        """Runs a teacher-forced pass: encodes the images, then decodes the training captions."""
        features = self.encoder(imgs)
        out, state = self.decoder(features, train_caps, lengths)
        return out, state
    def predict(self, img, max_seq_length=25):
        """
        Predicts a caption for a single image.

        :param img: float tensor of shape (channels, height, width)
        :param max_seq_length: maximum length of the predicted caption.
        """
        # -------
        # Encoder
        # -------
        # features shape -- (batch_size(1), num_pixels, encoder_size)
        img = img.unsqueeze(0)  # Adding batch size dim: (batch_size(1), channels, height, width)
        features = self.encoder(img)

        # -------
        # Decoder
        # -------
        # out    -- list of predicted word indices.
        # alphas -- list of attention weights for each predicted word, each of shape (batch_size(1), num_pixels)
        out, alphas = self.decoder.sample(features, max_seq_length)
        return out, alphas
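

# ==================================================================================================
# -- usage sketch ----------------------------------------------------------------------------------
# ==================================================================================================
# Not part of the original module: a minimal, hedged example of how ImageCaptioningNet could be
# exercised. DummyEncoder and DummyDecoder are hypothetical stubs that only mimic the interface the
# class above assumes -- an encoder returning features of shape (batch_size, num_pixels,
# encoder_size) and a decoder exposing forward(features, captions, lengths),
# sample(features, max_seq_length) and trainable_parameters(). Real CNN / attention-LSTM modules
# would take their place.
if __name__ == "__main__":
    import torch

    class DummyEncoder(nn.Module):
        def __init__(self, num_pixels=49, encoder_size=256):
            super().__init__()
            self.num_pixels = num_pixels
            self.encoder_size = encoder_size
            self.proj = nn.Linear(3, encoder_size)  # placeholder trainable layer

        def forward(self, imgs):
            # Stand-in for a CNN: emit a fake grid of num_pixels feature vectors per image.
            return torch.randn(imgs.size(0), self.num_pixels, self.encoder_size)

        def trainable_parameters(self):
            return list(self.proj.parameters())

    class DummyDecoder(nn.Module):
        def __init__(self, vocab_size=100, encoder_size=256):
            super().__init__()
            self.fc = nn.Linear(encoder_size, vocab_size)

        def forward(self, features, captions, lengths):
            # Score each pixel feature against the vocabulary; no recurrent state in this stub.
            return self.fc(features), None

        def sample(self, features, max_seq_length):
            # Greedy "sampling" from the first pixel feature, with uniform attention weights.
            out = [int(self.fc(features[:, 0]).argmax(dim=1)) for _ in range(max_seq_length)]
            alphas = [torch.full((1, features.size(1)), 1.0 / features.size(1))
                      for _ in range(max_seq_length)]
            return out, alphas

        def trainable_parameters(self):
            return list(self.fc.parameters())

    model = ImageCaptioningNet(DummyEncoder(), DummyDecoder())
    imgs = torch.randn(2, 3, 224, 224)                 # batch of 2 RGB images
    caps = torch.randint(0, 100, (2, 10))              # batch of padded caption indices
    out, _ = model(imgs, caps, lengths=[10, 10])       # teacher-forced training pass
    caption, alphas = model.predict(torch.randn(3, 224, 224), max_seq_length=5)
    print(out.shape, len(caption), alphas[0].shape)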