# datasets.py (forked from Wovchena/text-detection-fots.pytorch)
import math
import os
import random
import re

import cv2
import numpy as np
import scipy.io
import torch
import torch.utils.data
import torchvision
import shapely.errors  # import the submodule explicitly: a bare `import shapely` may not expose shapely.errors
from shapely.geometry import Polygon, box


def point_dist_to_line(p1, p2, p3):
    """Compute the distance from point p3 to the line through p1 and p2.

    Uses the 2D cross product: |(p2 - p1) x (p1 - p3)| / |p2 - p1|. Falls back
    to the point-to-point distance when p1 == p2 and the line is degenerate.
    """
    if not np.array_equal(p1, p2):
        return np.abs(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1)
    else:
        return np.linalg.norm(p3 - p1)


IN_OUT_RATIO = 4
IN_SIDE = 640
OUT_SIDE = IN_SIDE // IN_OUT_RATIO
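

# The detector predicts per-pixel maps at 1/4 of the input resolution (an
# EAST/FOTS-style convention), so targets for a 640x640 crop are built on a
# 160x160 grid; regression distances are scaled back to input pixels below.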
def transform(im, quads, texts, normalizer, data_set):
    # upscale
    scale = 2560 / np.maximum(im.shape[0], im.shape[1])
    upscaled = cv2.resize(im, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
    quads = quads * scale
    # rotate
    # grab the dimensions of the image and then determine the center
    (h, w) = upscaled.shape[:2]
    (cX, cY) = (w / 2, h / 2)
    # grab the rotation matrix (applying the negative of the angle to rotate
    # clockwise), then grab the sine and cosine (i.e., the rotation components
    # of the matrix)
    angle = torch.empty(1).uniform_(-10, 10).item()
    M = cv2.getRotationMatrix2D((cX, cY), angle=angle, scale=1.0)
    cos = np.abs(M[0, 0])
    sin = np.abs(M[0, 1])
    # compute the new bounding dimensions of the image
    nW = int((h * sin) + (w * cos))  # TODO replace with round and do it later
    nH = int((h * cos) + (w * sin))
    # adjust the rotation matrix to take the translation into account
    M[0, 2] += (nW / 2) - cX
    M[1, 2] += (nH / 2) - cY
    # perform the actual rotation
    rotated = cv2.warpAffine(upscaled, M, (nW, nH))
    quads = cv2.transform(quads, M)
    # stretch vertically by a random factor
    stretch_k = torch.empty(1).uniform_(0.8, 1.2).item()
    stretched = cv2.resize(rotated, None, fx=1, fy=stretch_k, interpolation=cv2.INTER_CUBIC)
    quads[:, :, 1] = quads[:, :, 1] * stretch_k
    quads /= IN_OUT_RATIO
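    # From here on the quads live in output-map coordinates (1/4 of the input).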
    training_mask = np.ones((OUT_SIDE, OUT_SIDE), dtype=float)
    classification = np.zeros((OUT_SIDE, OUT_SIDE), dtype=float)
    regression = np.zeros((4,) + classification.shape, dtype=float)
    tmp_cls = np.empty(classification.shape, dtype=float)
    thetas = np.zeros(classification.shape, dtype=float)
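    # Target layout (EAST-style geometry): `classification` marks shrunk text
    # pixels, `regression` holds the four distances from each positive pixel to
    # the edges of its rotated box (in input pixels), `thetas` the box angle,
    # and `training_mask` zeroes out pixels that must not contribute to the loss.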
    # crop
    # SynthText has some short images, so the vertical range for the crop origin
    # can be empty; in that case the crop's y coordinate can only be 0.
    crop_max_y = stretched.shape[0] // IN_OUT_RATIO - OUT_SIDE
    if 0 != crop_max_y:
        crop_point = (torch.randint(low=0, high=stretched.shape[1] // IN_OUT_RATIO - OUT_SIDE, size=(1,), dtype=torch.int16).item(),
                      torch.randint(low=0, high=crop_max_y, size=(1,), dtype=torch.int16).item())
    else:
        crop_point = (torch.randint(low=0, high=stretched.shape[1] // IN_OUT_RATIO - OUT_SIDE, size=(1,), dtype=torch.int16).item(),
                      0)
    crop_box = box(crop_point[0], crop_point[1], crop_point[0] + OUT_SIDE, crop_point[1] + OUT_SIDE)
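    # Build per-pixel targets for every ground-truth quad that survives the crop.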
    for quad_id, quad in enumerate(quads):
        polygon = Polygon(quad)
        try:
            intersected_polygon = polygon.intersection(crop_box)
        except shapely.errors.TopologicalError:  # some quads in SynthText have their points in the wrong order
            quad[[1, 2]] = quad[[2, 1]]  # swap via fancy indexing: a plain tuple swap of numpy rows would alias the views and corrupt the quad
            polygon = Polygon(quad)
            intersected_polygon = polygon.intersection(crop_box)
        if type(intersected_polygon) is Polygon:
            intersected_quad = np.array(intersected_polygon.exterior.coords[:-1])
            intersected_quad -= crop_point
            intersected_minAreaRect = cv2.minAreaRect(intersected_quad.astype(np.float32))
            intersected_minAreaRect_boxPoints = cv2.boxPoints(intersected_minAreaRect)
            cv2.fillConvexPoly(training_mask, intersected_minAreaRect_boxPoints.round().astype(int), 0)
            minAreaRect = cv2.minAreaRect(quad.astype(np.float32))
            shrinkage = min(minAreaRect[1][0], minAreaRect[1][1]) * 0.6
            shrunk_width_and_height = (intersected_minAreaRect[1][0] - shrinkage, intersected_minAreaRect[1][1] - shrinkage)
            if shrunk_width_and_height[0] >= 0 and shrunk_width_and_height[1] >= 0 and texts[quad_id]:
                shrunk_minAreaRect = intersected_minAreaRect[0], shrunk_width_and_height, intersected_minAreaRect[2]
                poly = intersected_minAreaRect_boxPoints
                if intersected_minAreaRect[2] >= -45:
                    poly = np.array([poly[1], poly[2], poly[3], poly[0]])
                else:
                    poly = np.array([poly[2], poly[3], poly[0], poly[1]])
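                # The >= -45 branch relies on the pre-4.5 OpenCV convention,
                # where cv2.minAreaRect returns angles in (-90, 0]; the
                # reordering makes poly[3] -> poly[2] run along a consistent
                # side of the box so the rotation angle can be read off below.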
                angle_cos = (poly[2, 0] - poly[3, 0]) / np.sqrt(
                    (poly[2, 0] - poly[3, 0]) ** 2 + (poly[2, 1] - poly[3, 1]) ** 2 + 1e-5)  # TODO tg or ctg
                angle = np.arccos(angle_cos)
                if poly[2, 1] > poly[3, 1]:
                    angle *= -1
                angle += 45 * np.pi / 180  # the angle is actually in [-pi/4, pi/4]; shift it to [0, pi/2] for learning
                tmp_cls.fill(0)
                round_shrink_minAreaRect_boxPoints = cv2.boxPoints(shrunk_minAreaRect)
                cv2.fillConvexPoly(tmp_cls, round_shrink_minAreaRect_boxPoints.round(out=round_shrink_minAreaRect_boxPoints).astype(int), 1)
                cv2.rectangle(tmp_cls, (0, 0), (tmp_cls.shape[1] - 1, tmp_cls.shape[0] - 1), 0, thickness=int(round(shrinkage * 2)))
                classification += tmp_cls
                training_mask += tmp_cls
                thetas += tmp_cls * angle
                points = np.nonzero(tmp_cls)
                pointsT = np.transpose(points)
                for point in pointsT:
                    for plane in range(3):  # TODO width - dist, height - other dist and more precise dist
                        regression[(plane,) + tuple(point)] = point_dist_to_line(poly[plane], poly[plane + 1], np.array((point[1], point[0]))) * IN_OUT_RATIO
                    regression[(3,) + tuple(point)] = point_dist_to_line(poly[3], poly[0], np.array((point[1], point[0]))) * IN_OUT_RATIO
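    # If the crop contains no positive pixels, resample another image with
    # probability 0.9; the remaining 10% keeps some text-free crops in training.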
    if 0 == np.count_nonzero(classification) and 0.1 < torch.rand(1).item():
        return data_set[torch.randint(low=0, high=len(data_set), size=(1,), dtype=torch.int16).item()]
    # Avoiding training on black corners decreases hmean, see d9c727a8defbd1c8022478ae798c907ccd2fa0b2. This may happen
    # because of OHEM: it already guides the training, and it won't select black corner pixels if the net is good at
    # classifying them. It could easily be verified by removing OHEM, but I didn't test it.
    cropped = stretched[crop_point[1] * IN_OUT_RATIO:crop_point[1] * IN_OUT_RATIO + IN_SIDE, crop_point[0] * IN_OUT_RATIO:crop_point[0] * IN_OUT_RATIO + IN_SIDE]
    cropped = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB).astype(np.float64) / 255
    permuted = np.transpose(cropped, (2, 0, 1))
    permuted = torch.from_numpy(permuted).float()
    permuted = normalizer(permuted)
    return permuted, torch.from_numpy(classification).float(), torch.from_numpy(regression).float(), torch.from_numpy(
        thetas).float(), torch.from_numpy(training_mask).float()


class ICDAR2015(torch.utils.data.Dataset):
    def __init__(self, root, transform):
        self.transform = transform
        self.root = root
        self.img_dir = 'ch4_training_images'
        self.labels_dir = 'ch4_training_localization_transcription_gt'
        self.image_prefix = []
        self.pattern = re.compile('^' + '(\\d+),' * 8 + '(.+)$')
        self.normalizer = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                           std=[0.229, 0.224, 0.225])
        for dirEntry in os.scandir(os.path.join(root, self.img_dir)):
            self.image_prefix.append(dirEntry.name[:-4])
    def __len__(self):
        return len(self.image_prefix)

    def __getitem__(self, idx):
        img = cv2.imread(os.path.join(self.root, self.img_dir, self.image_prefix[idx] + '.jpg'), cv2.IMREAD_COLOR).astype(np.float32)
        quads = []
        texts = []
        # each gt line is "x1,y1,x2,y2,x3,y3,x4,y4,transcription"; '###' marks an illegible, don't-care region
        lines = [line.rstrip('\n') for line in open(os.path.join(self.root, self.labels_dir, 'gt_' + self.image_prefix[idx] + '.txt'),
                                                    encoding='utf-8-sig')]
        for line in lines:
            matches = self.pattern.findall(line)[0]
            numbers = np.array(matches[:8], dtype=float)
            quads.append(numbers.reshape((4, 2)))
            texts.append('###' != matches[8])
        return transform(img, np.stack(quads), texts, self.normalizer, self)


class SynthText(torch.utils.data.Dataset):
    def __init__(self, root, transform):
        self.transform = transform
        self.root = root
        self.labels = scipy.io.loadmat(os.path.join(root, 'gt.mat'))
        self.broken_image_ids = set()
        # sample_path = labels['imnames'][0, 1][0]
        # sample_boxes = np.transpose(labels['wordBB'][0, 1], (2, 1, 0))
        self.pattern = re.compile('^' + '(\\d+),' * 8 + '(.+)$')
        self.normalizer = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                           std=[0.229, 0.224, 0.225])

    def __len__(self):
        return self.labels['imnames'].shape[1] // 105  # there are more than 105 rendered images per source image, so treat each group of 105 as one sample
    def __getitem__(self, idx):
        idx = (idx * 105) + random.randint(0, 104)  # compensate for the reduced dataset size while maintaining diversity
        if idx in self.broken_image_ids:
            return self[torch.randint(low=0, high=len(self), size=(1,), dtype=torch.int16).item()]
        img = cv2.imread(os.path.join(self.root, self.labels['imnames'][0, idx][0]), cv2.IMREAD_COLOR).astype(np.float32)
        if 190 >= img.shape[0]:  # the image is too short: cropping 640x640 after the transformations would be impossible
            self.broken_image_ids.add(idx)
            return self[torch.randint(low=0, high=len(self), size=(1,), dtype=torch.int16).item()]
        coordinates = self.labels['wordBB'][0, idx]
        if len(coordinates.shape) == 2:
            coordinates = np.expand_dims(coordinates, axis=2)
        transposed = np.transpose(coordinates, (2, 1, 0))  # (2, 4, N) -> (box, corner, xy)
        if (transposed > 0).all() and (transposed[:, :, 0] < img.shape[1]).all() and (transposed[:, :, 1] < img.shape[0]).all():  # x within width, y within height
            if ((transposed[:, 0] != transposed[:, 1]).all() and
                    (transposed[:, 0] != transposed[:, 2]).all() and
                    (transposed[:, 0] != transposed[:, 3]).all() and
                    (transposed[:, 1] != transposed[:, 2]).all() and
                    (transposed[:, 1] != transposed[:, 3]).all() and
                    (transposed[:, 2] != transposed[:, 3]).all()):  # boxes can come as [p1, p1, p2, p2], while we need [p1, p2, p3, p4]
                return transform(img, transposed, (True,) * len(transposed), self.normalizer, self)
        self.broken_image_ids.add(idx)
        return self[torch.randint(low=0, high=len(self), size=(1,), dtype=torch.int16).item()]


if '__main__' == __name__:
    icdar = ICDAR2015('C:\\Users\\vzlobin\\Documents\\repo\\FOTS.PyTorch\\data\\icdar\\icdar2015\\4.4\\training', transform)
    # dl = torch.utils.data.DataLoader(icdar, batch_size=4, shuffle=False, sampler=None, batch_sampler=None, num_workers=4, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
    for image_i in range(len(icdar)):
        normalized, classification, regression, thetas, training_mask = icdar[image_i]
        # undo the ImageNet normalization for display
        permuted = normalized * torch.tensor([0.229, 0.224, 0.225])[:, None, None] + torch.tensor([0.485, 0.456, 0.406])[:, None, None]
        cropped = permuted.permute(1, 2, 0).numpy()
        cv2.imshow('orig', cv2.resize(cropped[:, :, ::-1], (640, 640)))
        cropped = cv2.resize(cropped, (160, 160))
        cv2.imshow('img', cv2.resize(cropped[:, :, ::-1] * training_mask.numpy()[:, :, None], (640, 640)))
        cv2.imshow('training_mask', cv2.resize(training_mask.numpy() * 255, (640, 640)))
        cv2.imshow('classification', cv2.resize(classification.numpy() * 255, (640, 640)))
        regression = regression.numpy()
        for i in range(4):
            m = np.amax(regression[i])
            if 0 != m:
                cv2.imshow(str(i), cv2.resize(regression[i, :, :] / m, (640, 640)))
            else:
                cv2.imshow(str(i), cv2.resize(regression[i, :, :], (640, 640)))
        thetas = thetas.numpy()
        minim = np.amin(thetas)
        m = np.amax(thetas)
        print(m * 180 / np.pi)
        cv2.imshow('angle', cv2.resize(np.array(np.around(thetas / m * 255), dtype=np.uint8), (640, 640)))  # scale thetas to [0, 255] for display
        cv2.waitKey(0)