forked from YU1ut/imprinted-weights
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loader.py
89 lines (74 loc) · 3.22 KB
/
loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from PIL import Image
import os
import pandas as pd
import math
from torch.utils.data.sampler import WeightedRandomSampler
import numpy as np
def pil_loader(path):
    """Open the image file at ``path`` and return it converted to RGB.

    The conversion happens while the file is still open: Pillow reads pixel
    data lazily, so converting after the ``with`` block has closed the
    handle would fail.
    """
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
class ImageLoader(torch.utils.data.Dataset):
    """Image dataset whose classes are split into base and novel classes.

    ``root`` is expected to contain three space-separated, header-less
    metadata files keyed by an image id ``idx``:

    * ``images.txt`` (``idx path``),
    * ``image_class_labels.txt`` (``idx label``),
    * ``train_test_split.txt`` (``idx train_flag``).

    It should also contain an ``images/`` folder holding the files listed in
    ``images.txt``. This layout presumably follows CUB-200-2011; confirm
    against the data actually used.

    Labels read from the file are shifted down by one, so file label ``k``
    becomes ``k - 1``. Shifted labels below ``num_base_classes`` are *base*
    classes; those in ``[num_base_classes, num_classes)`` are *novel* classes.

    Args:
        root: Dataset root directory.
        transform: Optional callable applied to each loaded image.
        target_transform: Optional callable applied to each label.
        train: Keep rows whose ``train_flag`` equals this value
            (``True`` matches flag 1, ``False`` matches flag 0).
        num_classes: Drop rows whose shifted label is ``>= num_classes``.
        num_train_sample: If non-zero, keep only the first
            ``num_train_sample`` rows (in ``images.txt`` order) of each
            novel class.
        novel_only: If true, keep only novel-class rows.
        aug: If true, list every selected row ``aug_repeats`` times. Each
            copy is passed through ``transform`` separately in
            ``__getitem__``.
        loader: Callable mapping an image path to an image object.
        num_base_classes: Number of base classes (shifted labels
            ``0 .. num_base_classes - 1``). Defaults to 100.
        aug_repeats: Positive repetition factor used when ``aug`` is true.
            Defaults to 5.

    Raises:
        RuntimeError: If the filters leave no rows.
    """

    def __init__(self, root, transform=None, target_transform=None, train=False, num_classes=100,
                 num_train_sample=0, novel_only=False, aug=False, loader=pil_loader,
                 num_base_classes=100, aug_repeats=5):
        # Join the metadata tables on the image id instead of by row position,
        # so labels cannot be misaligned if the files list images in different
        # orders (and no duplicate ``idx`` columns are produced).
        data = (
            self._read_table(root, "images.txt", ['idx', 'path'])
            .merge(self._read_table(root, "image_class_labels.txt", ['idx', 'label']), on='idx')
            .merge(self._read_table(root, "train_test_split.txt", ['idx', 'train_flag']), on='idx')
        )
        # Shift labels to 0-based while ``data`` is still a freshly built frame.
        # Assigning into a filtered slice would trigger SettingWithCopyWarning.
        data['label'] = data['label'] - 1
        data = data[data['train_flag'] == train]

        # Split dataset: base classes are [0, num_base_classes),
        # novel classes are [num_base_classes, num_classes).
        data = data[data['label'] < num_classes]
        base_data = data[data['label'] < num_base_classes]
        novel_data = data[data['label'] >= num_base_classes]

        # Sample from novel classes: keep the first num_train_sample rows of each.
        if num_train_sample != 0:
            novel_data = novel_data.groupby('label').head(num_train_sample)

        # Whether to return only data of novel classes.
        data = novel_data if novel_only else pd.concat([base_data, novel_data])

        # Repeat the selection for data augmentation.
        if aug:
            data = pd.concat([data] * aug_repeats)

        imgs = data.reset_index(drop=True)
        if imgs.empty:
            raise RuntimeError(
                "no images selected under {!r} (train={}, num_classes={}, novel_only={})".format(
                    root, train, num_classes, novel_only))

        self.root = os.path.join(root, "images")
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.train = train

    @staticmethod
    def _read_table(root, name, columns):
        """Read the space-separated, header-less metadata file ``root/name``."""
        return pd.read_csv(os.path.join(root, name), sep=" ", header=None, names=columns)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index (int): Position in the dataset (anything ``int()`` accepts).

        Returns:
            tuple: ``(image, target)``. ``image`` is ``loader(path)`` passed
            through ``transform``, and ``target`` is the 0-based class label
            passed through ``target_transform``.
        """
        row = self.imgs.iloc[int(index)]
        img = self.loader(os.path.join(self.root, row['path']))
        target = row['label']
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of rows, counting augmentation repeats."""
        return len(self.imgs)

    def get_balanced_sampler(self):
        """Return a sampler that draws each class with equal probability.

        Every sample is weighted by ``1 / (number of samples in its class)``,
        and ``len(self)`` indices are drawn by ``WeightedRandomSampler``.
        """
        labels = self.imgs['label'].to_numpy()
        # Map each label to its position among the sorted unique labels. This
        # keeps the weights correct for any label set, e.g. novel-only labels
        # starting at num_base_classes, not just contiguous 0..K-1.
        _, class_index, class_counts = np.unique(labels, return_inverse=True, return_counts=True)
        samples_weight = torch.as_tensor(1.0 / class_counts[class_index], dtype=torch.double)
        return WeightedRandomSampler(samples_weight, len(samples_weight))