forked from pytorch/examples
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdata.py
70 lines (49 loc) · 2 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from os.path import exists, join, basename
from os import makedirs, remove
from six.moves import urllib
import tarfile
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize
from dataset import DatasetFromFolder
def download_bsd300(dest="dataset"):
    """Download and unpack the BSDS300 image archive if it is not present.

    The download is skipped entirely when ``<dest>/BSDS300/images`` already
    exists.

    Args:
        dest: Directory the archive is downloaded into and extracted under.

    Returns:
        The path ``<dest>/BSDS300/images``.

    Raises:
        ValueError: If an archive member would be written outside ``dest``.
    """
    # Function-scope imports so this block does not depend on edits to the
    # module's import section.
    import shutil
    from os.path import realpath

    output_image_dir = join(dest, "BSDS300/images")
    if exists(output_image_dir):
        return output_image_dir

    # ``dest`` may already exist without the images (for example after an
    # interrupted run); an unconditional makedirs() would raise in that case.
    if not exists(dest):
        makedirs(dest)

    url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
    print("downloading url ", url)

    file_path = join(dest, basename(url))
    try:
        response = urllib.request.urlopen(url)
        try:
            # Stream to disk rather than holding the whole archive in memory.
            with open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)
        finally:
            response.close()

        print("Extracting data")
        dest_root = realpath(dest)
        # join(root, "") yields the root with exactly one trailing separator.
        dest_prefix = join(dest_root, "")
        with tarfile.open(file_path) as tar:
            for item in tar:
                # The archive comes over plain HTTP; refuse members whose
                # resolved path would land outside the destination directory.
                target = realpath(join(dest_root, item.name))
                if target != dest_root and not target.startswith(dest_prefix):
                    raise ValueError(
                        "refusing to extract %r outside of %r" % (item.name, dest)
                    )
                tar.extract(item, dest)
    finally:
        # Remove the archive on success and on failure so that a partial
        # download is never left behind.
        if exists(file_path):
            remove(file_path)

    return output_image_dir
def calculate_valid_crop_size(crop_size, upscale_factor):
    """Return ``crop_size`` reduced by its remainder modulo ``upscale_factor``.

    The result is evenly divisible by ``upscale_factor``.
    """
    remainder = crop_size % upscale_factor
    return crop_size - remainder
def input_transform(crop_size, upscale_factor):
    """Compose the transform pipeline used for input images.

    The pipeline applies ``CenterCrop(crop_size)``, then
    ``Resize(crop_size // upscale_factor)``, then ``ToTensor()``.
    """
    steps = [CenterCrop(crop_size)]
    steps.append(Resize(crop_size // upscale_factor))
    steps.append(ToTensor())
    return Compose(steps)
def target_transform(crop_size):
    """Compose the transform pipeline used for target images.

    The pipeline applies ``CenterCrop(crop_size)`` followed by ``ToTensor()``.
    """
    steps = [CenterCrop(crop_size)]
    steps.append(ToTensor())
    return Compose(steps)
def get_training_set(upscale_factor, crop_size=256):
    """Return a ``DatasetFromFolder`` over the ``train`` directory of BSDS300.

    The images directory is obtained from ``download_bsd300()``, which
    downloads the archive if it is missing.

    Args:
        upscale_factor: Factor passed to the input transform; the crop size
            is first reduced so that it is divisible by this factor.
        crop_size: Requested crop size before that reduction. Defaults to
            256, the value that was previously hard-coded.

    Returns:
        A ``DatasetFromFolder`` configured with the input and target
        transforms for the adjusted crop size.
    """
    root_dir = download_bsd300()
    train_dir = join(root_dir, "train")
    valid_crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
    return DatasetFromFolder(
        train_dir,
        input_transform=input_transform(valid_crop_size, upscale_factor),
        target_transform=target_transform(valid_crop_size),
    )
def get_test_set(upscale_factor, crop_size=256):
    """Return a ``DatasetFromFolder`` over the ``test`` directory of BSDS300.

    The images directory is obtained from ``download_bsd300()``, which
    downloads the archive if it is missing.

    Args:
        upscale_factor: Factor passed to the input transform; the crop size
            is first reduced so that it is divisible by this factor.
        crop_size: Requested crop size before that reduction. Defaults to
            256, the value that was previously hard-coded.

    Returns:
        A ``DatasetFromFolder`` configured with the input and target
        transforms for the adjusted crop size.
    """
    root_dir = download_bsd300()
    test_dir = join(root_dir, "test")
    valid_crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
    return DatasetFromFolder(
        test_dir,
        input_transform=input_transform(valid_crop_size, upscale_factor),
        target_transform=target_transform(valid_crop_size),
    )