-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathload_data.py
105 lines (87 loc) · 4.49 KB
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
from torch.utils.data.dataset import Dataset
from scipy.io import loadmat, savemat
from torch.utils.data import DataLoader
from util import BackgroundGenerator
import numpy as np
class DataLoaderX(DataLoader):
    """``DataLoader`` whose iterator is wrapped in a ``BackgroundGenerator``.

    NOTE(review): ``BackgroundGenerator`` comes from the project's ``util``
    module; presumably it prefetches batches on a background thread --
    confirm against its implementation.
    """

    def __iter__(self):
        # Build the stock DataLoader iterator first, then hand it to the wrapper.
        base_iterator = super().__iter__()
        return BackgroundGenerator(base_iterator)
class CustomDataSet(Dataset):
    """Paired dataset: item ``i`` is ``(images[i], texts[i], labels[i])``.

    The three containers are stored as given and indexed with the same key;
    the dataset length is taken from ``images`` alone.
    """

    def __init__(self, images, texts, labels):
        self.images, self.texts, self.labels = images, texts, labels

    def __getitem__(self, index):
        # Same lookup order as the stored fields: image, text, label.
        return self.images[index], self.texts[index], self.labels[index]

    def __len__(self):
        return len(self.images)
class SingleModalDataSet(Dataset):
    """Single-modality dataset: item ``i`` is ``(data[i], labels[i])``.

    Both containers are stored as given; the dataset length is taken from
    ``data`` alone.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __getitem__(self, index):
        # Sample first, then its label, both looked up with the same key.
        return self.data[index], self.labels[index]

    def __len__(self):
        return len(self.data)
def get_loader(path, batch_size, INCOMPLETE=False, USE_INCOMPLETE=False):
    """Load the train/test ``.mat`` splits and wrap them in data loaders.

    Six files are read with ``scipy.io.loadmat``: ``train_img.mat``,
    ``test_img.mat``, ``train_txt.mat``, ``test_txt.mat``, ``train_lab.mat``
    and ``test_lab.mat``. From each file the variable named after the file
    stem is taken (e.g. ``train_img`` from ``train_img.mat``).

    The training rows are divided into fifths (``split = n_train // 5``):
    rows ``[0, split)`` are treated as complete pairs, ``[split, 3*split)``
    as image-only and ``[3*split, 5*split)`` as text-only. Up to four
    trailing rows beyond ``5*split`` belong to none of these ranges.

    Args:
        path: Prefix prepended verbatim to every file name (plain string
            concatenation), so a directory path must end with a separator.
        batch_size: Batch size of the ``'train'`` and ``'test'`` loaders.
            With ``USE_INCOMPLETE`` the three training loaders use
            ``batch_size // 5`` (complete) and twice that (each
            single-modality subset), clamped so they are never zero.
        INCOMPLETE: If true, zero the text features of the image-only rows
            and the image features of the text-only rows, in place. The
            zeroed arrays are also what ``input_data_par`` reports.
        USE_INCOMPLETE: If true, return separate ``'train_complete'``,
            ``'train_img'`` and ``'train_text'`` loaders over the three row
            ranges instead of a single ``'train'`` loader.

    Returns:
        A ``(dataloader, input_data_par)`` tuple. ``dataloader`` maps split
        names to ``DataLoaderX`` instances (training splits shuffled, the
        test split not). ``input_data_par`` holds the six loaded arrays plus
        ``img_dim``, ``text_dim`` and ``num_class``, each taken from the
        second-axis size of the corresponding training array.
    """

    def _load(stem):
        # Each file stores its payload under a variable named like the file.
        return loadmat(path + stem + ".mat")[stem]

    img_train, img_test = _load("train_img"), _load("test_img")
    text_train, text_test = _load("train_txt"), _load("test_txt")
    label_train, label_test = _load("train_lab"), _load("test_lab")

    # Row ranges of the simulated incomplete-modality layout.
    split = img_train.shape[0] // 5
    complete_rows = slice(0, split)
    img_only_rows = slice(split, 3 * split)
    text_only_rows = slice(3 * split, 5 * split)

    if INCOMPLETE:
        # Blank out the modality that is "missing" for each row range.
        text_train[img_only_rows] = 0
        img_train[text_only_rows] = 0

    test_set = CustomDataSet(images=img_test, texts=text_test, labels=label_test)

    if USE_INCOMPLETE:
        dataset = {
            'train_complete': CustomDataSet(images=img_train[complete_rows],
                                            texts=text_train[complete_rows],
                                            labels=label_train[complete_rows]),
            'train_img': SingleModalDataSet(data=img_train[img_only_rows],
                                            labels=label_train[img_only_rows]),
            'train_text': SingleModalDataSet(data=text_train[text_only_rows],
                                             labels=label_train[text_only_rows]),
            'test': test_set,
        }
        # Keep the 1:2:2 ratio of the row ranges, but never let a batch size
        # reach 0 (DataLoader rejects it) when batch_size < 5.
        complete_bs = max(1, batch_size // 5)
        batch_sizes = {
            'train_complete': complete_bs,
            'train_img': 2 * complete_bs,
            'train_text': 2 * complete_bs,
            'test': batch_size,
        }
    else:
        dataset = {
            'train': CustomDataSet(images=img_train, texts=text_train, labels=label_train),
            'test': test_set,
        }
        batch_sizes = {'train': batch_size, 'test': batch_size}

    # Every training split is shuffled; only the test split keeps its order.
    dataloader = {
        name: DataLoaderX(subset, batch_size=batch_sizes[name],
                          shuffle=(name != 'test'), num_workers=0)
        for name, subset in dataset.items()
    }

    input_data_par = {
        'img_test': img_test,
        'text_test': text_test,
        'label_test': label_test,
        'img_train': img_train,
        'text_train': text_train,
        'label_train': label_train,
        'img_dim': img_train.shape[1],
        'text_dim': text_train.shape[1],
        'num_class': label_train.shape[1],
    }
    return dataloader, input_data_par