forked from facebookresearch/swav
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmulticropdataset.py
94 lines (81 loc) · 2.96 KB
/
multicropdataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import random
from logging import getLogger
from PIL import ImageFilter
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# Module-level logger handle. ``getLogger`` with no name returns the root
# logger, so messages go wherever the root logger is configured to send them.
logger = getLogger(name=None)
class MultiCropDataset(datasets.ImageFolder):
    """ImageFolder that returns a list of randomly augmented crops per image.

    Crops are organised in groups. Group ``i`` yields ``nmb_crops[i]`` crops
    of size ``size_crops[i]``. Each crop is produced by ``RandomResizedCrop``
    with scale range ``(min_scale_crops[i], max_scale_crops[i])``, followed by
    a random horizontal flip, colour distortion, a random Gaussian blur,
    ``ToTensor`` and ``Normalize``. Class labels from the folder structure are
    not returned.

    Args:
        data_path: root directory handed to ``torchvision.datasets.ImageFolder``.
        size_crops: crop output size for each group.
        nmb_crops: number of crops for each group.
        min_scale_crops: lower ``RandomResizedCrop`` scale bound per group.
        max_scale_crops: upper ``RandomResizedCrop`` scale bound per group.
        size_dataset: when ``>= 0``, keep only the first ``size_dataset``
            samples. A negative value keeps them all.
        return_index: when True, ``__getitem__`` returns ``(index, crops)``
            instead of just ``crops``.
        mean: per-channel normalisation mean (keyword-only). Defaults to
            ``[0.485, 0.456, 0.406]``.
        std: per-channel normalisation std (keyword-only). Defaults to
            ``[0.228, 0.224, 0.225]``.

    Raises:
        ValueError: if ``size_crops``, ``min_scale_crops`` or
            ``max_scale_crops`` does not have the same length as ``nmb_crops``.
    """

    def __init__(
        self,
        data_path,
        size_crops,
        nmb_crops,
        min_scale_crops,
        max_scale_crops,
        size_dataset=-1,
        return_index=False,
        *,
        mean=None,
        std=None,
    ):
        # Validate before scanning data_path so bad arguments fail fast.
        # Raise ValueError rather than assert so the check survives `python -O`.
        n_groups = len(nmb_crops)
        for arg_name, values in (
            ("size_crops", size_crops),
            ("min_scale_crops", min_scale_crops),
            ("max_scale_crops", max_scale_crops),
        ):
            if len(values) != n_groups:
                raise ValueError(
                    f"{arg_name} has {len(values)} entries but nmb_crops has "
                    f"{n_groups}"
                )

        super().__init__(data_path)

        if size_dataset >= 0:
            self.samples = self.samples[:size_dataset]
            # ImageFolder also exposes `imgs` and `targets`. Keep them in step
            # with the truncated `samples` list so they do not go stale.
            self.imgs = self.samples
            self.targets = [target for _, target in self.samples]
        self.return_index = return_index

        if mean is None:
            mean = [0.485, 0.456, 0.406]
        if std is None:
            # NOTE(review): 0.228 differs from the commonly quoted ImageNet
            # value 0.229. It is kept unchanged so existing results stay
            # reproducible; confirm against the evaluation pipeline.
            std = [0.228, 0.224, 0.225]

        color_transform = [get_color_distortion(), PILRandomGaussianBlur()]
        trans = []
        for size, n_crops, min_scale, max_scale in zip(
            size_crops, nmb_crops, min_scale_crops, max_scale_crops
        ):
            crop_transform = transforms.Compose([
                transforms.RandomResizedCrop(size, scale=(min_scale, max_scale)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.Compose(color_transform),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ])
            # The same pipeline object is repeated n_crops times.
            # torchvision's random transforms draw new parameters on every
            # call, so each crop is still augmented independently.
            trans.extend([crop_transform] * n_crops)
        self.trans = trans

    def __getitem__(self, index):
        """Load sample ``index`` and apply every crop pipeline to it.

        Returns the list of crops, or ``(index, crops)`` when
        ``return_index`` was set. The folder label is ignored.
        """
        path, _ = self.samples[index]
        image = self.loader(path)
        multi_crops = [transform(image) for transform in self.trans]
        if self.return_index:
            return index, multi_crops
        return multi_crops
class PILRandomGaussianBlur(object):
    """Randomly apply a Gaussian blur to a PIL image.

    With probability ``p`` the image is filtered with
    ``PIL.ImageFilter.GaussianBlur``, using a radius drawn uniformly between
    ``radius_min`` and ``radius_max``. Otherwise the input is returned
    unchanged.
    This transform was used in SimCLR - https://arxiv.org/abs/2002.05709

    Args:
        p: probability of applying the blur.
        radius_min: lower bound of the blur radius.
        radius_max: upper bound of the blur radius.
    """

    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        """Return ``img`` blurred with probability ``self.prob``, else ``img`` itself."""
        # Use Python's `random` for the apply/skip decision, as the radius draw
        # already does, rather than `np.random`. That leaves a single RNG to
        # seed. PyTorch DataLoader reseeds `random` per worker, whereas older
        # PyTorch versions left NumPy's global state identical in every worker.
        # random.random() lies in [0, 1), so `<` applies the blur with exactly
        # probability p; the old `<=` could still fire when p == 0.
        if random.random() >= self.prob:
            return img
        radius = random.uniform(self.radius_min, self.radius_max)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))
def get_color_distortion(s=1.0):
    """Build the SimCLR-style colour distortion pipeline.

    Args:
        s: strength multiplier applied to every jitter magnitude.

    Returns:
        A ``transforms.Compose`` with two steps. First, with probability 0.8,
        it applies ``ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)``
        (brightness, contrast, saturation, hue). Then it converts the image to
        grayscale with probability 0.2.
    """
    strength = 0.8 * s
    jitter = transforms.ColorJitter(strength, strength, strength, 0.2 * s)
    return transforms.Compose([
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
    ])