forked from palver7/CFLPytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtransformclassexp.py
145 lines (116 loc) · 5.5 KB
/
transformclassexp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""This script is for experimenting with custom data transforms with classes on custom datasets"""
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
import math
from PIL import Image
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import mytransforms
import os
class SUN360Dataset(Dataset):
    """Dataset of panoramas with edge maps (EM), corner maps (CM) and a
    corner list (cor), loaded from paths listed in a JSON annotation file.

    Each item is the 4-tuple ``(image, EM, CM, cor)``.
    """

    def __init__(self, file, transform=None, target_transform=None, joint_transform=None):
        """
        Args:
            file (string): Path to the json file with annotations.
                Columns (by position): 0 = RGB image path, 1 = edge-map path,
                2 = corner-map path, 3 = corner-list text-file path.
            transform (callable, optional): Optional transform applied to the
                RGB image only.
            target_transform (callable, optional): Optional transform applied
                to each map (edge and corner) individually.
            joint_transform (callable, optional): Optional transform applied
                jointly to the list ``[image, EM, CM, cor]`` so all elements
                receive consistent augmentation.
        """
        self.images_data = pd.read_json(file)
        self.transform = transform
        self.target_transform = target_transform
        self.joint_transform = joint_transform

    def __len__(self):
        return len(self.images_data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = self.images_data.iloc[idx, 0]
        EM_name = self.images_data.iloc[idx, 1]
        CM_name = self.images_data.iloc[idx, 2]
        CL_name = self.images_data.iloc[idx, 3]
        image = Image.open(img_name)
        EM = Image.open(EM_name)
        CM = Image.open(CM_name)
        with open(CL_name, mode='r') as f:
            cor = np.array([line.strip().split() for line in f], np.int32)
        # Corners are expected in pairs (presumably ceiling/floor — TODO
        # confirm); an odd count flags a malformed annotation file.
        if len(cor) % 2 != 0:
            print(CL_name.split('/')[-1])
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            CM = self.target_transform(CM)
            EM = self.target_transform(EM)
        if self.joint_transform is not None:
            image, EM, CM, cor = self.joint_transform([image, EM, CM, cor])
        # BUG FIX: `cor` was computed (and rebound by joint_transform) but
        # silently dropped; the training loop unpacks four values
        # (images, EM, CM, cor), so return it as well.
        return image, EM, CM, cor
class SplitDataset(Dataset):
    """Wrap an indexable dataset of ``(image, EM, CM)`` triples and apply
    optional per-item transforms.

    Args:
        dataset: Any indexable collection yielding ``(image, EM, CM)``.
        transform (callable, optional): Applied to the image.
        target_transform (callable, optional): Applied to both maps
            (corner map first, then edge map).
    """

    def __init__(self, dataset, transform=None, target_transform=None):
        self.images_data = dataset
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images_data)

    def __getitem__(self, idx):
        image, EM, CM = self.images_data[idx]
        apply_img = self.transform
        apply_map = self.target_transform
        if apply_img is not None:
            image = apply_img(image)
        if apply_map is not None:
            # CM is transformed before EM, matching the original order.
            CM, EM = apply_map(CM), apply_map(EM)
        return image, EM, CM
# Shared random-parameter generators so image, EM and CM all receive the
# identical randomly-drawn augmentation within one sample.
roll_gen = mytransforms.RandomHorizontalRollGenerator()
flip_gen = mytransforms.RandomHorizontalFlipGenerator()
panostretch_gen = mytransforms.RandomPanoStretchGenerator(max_stretch=2.0)

# Each sub-list holds one transform per element of [image, EM, CM, cor];
# None leaves that element untouched.
joint_transform = mytransforms.Compose([
    panostretch_gen,
    [mytransforms.RandomPanoStretch(panostretch_gen),
     mytransforms.RandomPanoStretch(panostretch_gen),
     mytransforms.RandomPanoStretch(panostretch_gen),
     None],
    flip_gen,
    [mytransforms.RandomHorizontalFlip(flip_gen),
     mytransforms.RandomHorizontalFlip(flip_gen),
     mytransforms.RandomHorizontalFlip(flip_gen),
     None],
    [transforms.ToTensor(), transforms.ToTensor(), transforms.ToTensor(), None],
    roll_gen,
    [mytransforms.RandomHorizontalRoll(roll_gen),
     mytransforms.RandomHorizontalRoll(roll_gen),
     mytransforms.RandomHorizontalRoll(roll_gen),
     None],
    # Erasing is applied to the RGB image only, never to the target maps.
    [transforms.RandomErasing(p=0.5, value=0), None, None, None],
])

trainset = SUN360Dataset(file="traindata.json", transform=None,
                         target_transform=None, joint_transform=joint_transform)
train_loader = DataLoader(trainset, batch_size=1,
                          shuffle=True, num_workers=2)
topil = transforms.ToPILImage()

# exist_ok=True replaces the race-prone exists()/makedirs() pair.
for subdir in ("result/RGB/", "result/EM/", "result/CM/"):
    os.makedirs(subdir, exist_ok=True)

for i, data in enumerate(train_loader):
    # NOTE(review): this unpacks four values, so SUN360Dataset.__getitem__
    # must return (image, EM, CM, cor) — verify the dataset returns the
    # corner list and does not drop it.
    images, EM, CM, cor = data
    # Drop the batch dimension (batch_size=1) before converting to PIL.
    images, EM, CM = torch.squeeze(images), torch.squeeze(EM), torch.squeeze(CM)
    im, edges, corners = topil(images), topil(EM), topil(CM)
    # Zero-pad single digits so filenames sort lexicographically.
    num = f"{i:02d}"
    im.save("result/RGB/RGB_{}.jpg".format(num))
    edges.save("result/EM/EM_{}.jpg".format(num))
    corners.save("result/CM/CM_{}.jpg".format(num))