-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
95 lines (85 loc) · 3.84 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torchvision
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
from PIL import Image
class MSCTD(Dataset):
    """
    Dataset that serves (image, text, sentiment) samples from an MSCTD-style folder.

    Unless overridden, files are looked up under ``<root>/<split>/``:
    ``english_<split>.txt`` (one utterance per line), ``image/<index>.jpg``,
    ``sentiment_<split>.txt`` (one integer label per line) and
    ``image_index_<split>.txt`` (one bracketed, comma-separated list of
    integers per line).

    :param root: root path of the dataset
    :param split: name of the split sub-folder (e.g. train, dev, test)
    :param image_transform: optional callable applied to the opened PIL image
    :param text_transform: optional callable applied to the text line
    :param sentiment_transform: optional callable applied to the sentiment label
    :param has_data: dict with the keys 'image' and 'text' telling which
        modalities to serve; ``None`` (default) or a missing key means enabled
    :param text_path: path of the text file (overrides the default)
    :param image_path: path of the image folder (overrides the default)
    :param sentiment_path: path of the sentiment file (overrides the default)
    :param image_index_path: path of the image index file (overrides the default)
    :return: each item is ``(image, text, sentiment)``, ``(text, sentiment)`` or
        ``(image, sentiment)`` depending on ``has_data``; the parsed image
        index lists are only exposed through the ``image_index`` attribute
    Example:
    >>> from torchvision import transforms
    >>> image_transform = transforms.Compose([
    ...     transforms.Resize((640, 1280)),
    ...     transforms.ToTensor(),
    ...     transforms.Lambda(lambda x: x.permute(1, 2, 0))
    ... ])
    >>> text_transform = None
    >>> sentiment_transform = None
    >>> dataset = MSCTD(root='data', split='train', image_transform=image_transform,
    ...                 text_transform=text_transform, sentiment_transform=sentiment_transform)
    >>> image, text, sentiment = dataset[0]
    """

    def __init__(self, root, split, image_transform=None, text_transform=None, sentiment_transform=None,
                 has_data=None, text_path=None, image_path=None, sentiment_path=None,
                 image_index_path=None):
        super().__init__()
        # A fresh dict per call: a dict literal as the default would be shared
        # between every instance created without an explicit has_data.
        if has_data is None:
            has_data = {'image': True, 'text': True}
        data_path = os.path.join(root, split)
        default_path = {
            'text': os.path.join(data_path, 'english_' + split + '.txt'),
            'image': os.path.join(data_path, 'image'),
            'sentiment': os.path.join(data_path, 'sentiment_' + split + '.txt'),
            'image_index': os.path.join(data_path, 'image_index_' + split + '.txt'),
        }
        # self.image / self.text act as "modality enabled" markers:
        # None means the modality is switched off.
        self.image = [] if has_data.get('image', True) else None
        self.image_transform = image_transform
        self.image_path = image_path if image_path else default_path['image']
        self.text = [] if has_data.get('text', True) else None
        self.text_transform = text_transform
        self.text_path = text_path if text_path else default_path['text']
        self.sentiment_path = sentiment_path if sentiment_path else default_path['sentiment']
        self.image_index_path = image_index_path if image_index_path else default_path['image_index']
        self.sentiment = []
        self.image_index = []
        self.sentiment_transform = sentiment_transform
        self.load_data()

    def load_data(self):
        """Read the sentiment labels, the text lines (if enabled) and the image index lists."""
        # atleast_1d: loadtxt yields a 0-d array for a one-line file, which
        # would break len() and integer indexing.
        self.sentiment = np.atleast_1d(np.loadtxt(self.sentiment_path, dtype=int))
        # Only touch the text file when text is enabled, so that
        # has_data['text'] = False is honoured and self.text stays None.
        if self.text is not None:
            with open(self.text_path, 'r', encoding='utf-8') as f:
                self.text = [line.strip() for line in f]
        with open(self.image_index_path, 'r', encoding='utf-8') as f:
            self.image_index = [
                self._parse_index_line(line) for line in f if line.strip()
            ]

    @staticmethod
    def _parse_index_line(line):
        """Turn one '[1, 2, 3]' line into a list of ints ('[]' gives an empty list)."""
        # Strip whitespace and brackets explicitly instead of slicing fixed
        # offsets, so a final line without a trailing newline parses correctly.
        inner = line.strip().strip('[]')
        return [int(token) for token in inner.split(',') if token.strip()]

    def __getitem__(self, index):
        """
        Return the sample at ``index``.

        :param index: position of the sample; also used as the image file name
            (``<index>.jpg`` inside the image folder)
        :return: ``(image, text, sentiment)``, ``(text, sentiment)`` or
            ``(image, sentiment)`` depending on the enabled modalities
        :raises ValueError: if both image and text are disabled
        """
        image = None
        text = None
        sentiment = self.sentiment[index]
        if self.image is not None:
            imag_path = os.path.join(self.image_path, str(index) + '.jpg')
            image = Image.open(imag_path)
            if self.image_transform:
                image = self.image_transform(image)
        if self.text is not None:
            text = self.text[index]
            if self.text_transform:
                text = self.text_transform(text)
        if self.sentiment_transform:
            sentiment = self.sentiment_transform(sentiment)
        if text is not None and image is not None:
            return image, text, sentiment
        if text is not None:
            return text, sentiment
        if image is not None:
            return image, sentiment
        raise ValueError('Either image or text should be not None')

    def __len__(self):
        """Return the number of samples, i.e. the number of sentiment labels."""
        return len(self.sentiment)