-
Notifications
You must be signed in to change notification settings - Fork 0
/
augmentation.py
117 lines (81 loc) · 3.48 KB
/
augmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import random
import numpy as np
import torch
from skimage.transform import rotate
class Compose:
    """Chain several sample transforms into a single callable.

    Each transform is called with the output of the previous one, in order,
    and the final result is returned.

    Args:
        transforms: iterable of callables that take and return a sample.
            ``None`` (the default) means "no transforms": the sample is
            returned unchanged.
    """

    def __init__(self, transforms=None):
        # Treat None as an empty pipeline so the default constructor is usable;
        # iterating None in __call__ would raise TypeError.
        self.transforms = transforms if transforms is not None else []

    def __call__(self, sample):
        for transform in self.transforms:
            sample = transform(sample)
        return sample
class NormalizeIntensity:
    """Scale CT intensities to [-1, 1].

    Values are clipped to [-1024, 1024] and then divided by 1024. Only
    ``sample['input']`` is changed; all other keys are left untouched.
    """

    def __call__(self, sample):
        # Store a new array rather than writing into img[:, :, :]. An in-place
        # write keeps the input dtype, so integer CT data (e.g. int16) would have
        # every normalized value truncated to -1, 0 or 1. It would also
        # overwrite the caller's array.
        sample['input'] = self.normalize_ct(sample['input'])
        return sample

    @staticmethod
    def normalize_ct(img):
        """Clip ``img`` to [-1024, 1024] and divide by 1024, giving values in [-1, 1]."""
        norm_img = np.clip(img, -1024, 1024) / 1024
        return norm_img
class ToTensor:
    """Turn the numpy arrays of a sample dict into float32 ``torch`` tensors.

    Args:
        mode: ``'train'`` converts ``sample['input']`` and ``sample['target']``,
            each after adding a new leading axis of size 1
            (``np.expand_dims(..., axis=0)``). ``'test'`` converts only
            ``sample['input']`` and adds no axis.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, mode='train'):
        allowed = ('train', 'test')
        if mode not in allowed:
            raise ValueError(f"Argument 'mode' must be 'train' or 'test'. Received {mode}")
        self.mode = mode

    def __call__(self, sample):
        if self.mode != 'train':
            # NOTE(review): unlike 'train', no leading axis is added here. Presumably
            # test inputs already carry it or are handled elsewhere; confirm.
            sample['input'] = torch.from_numpy(sample['input']).float()
            return sample

        image, label = sample['input'], sample['target']
        image, label = np.expand_dims(image, axis=0), np.expand_dims(label, axis=0)
        # Both tensors are built before either key is written back.
        sample['input'], sample['target'] = torch.from_numpy(image).float(), torch.from_numpy(label).float()
        return sample
class Mirroring:
    """Randomly flip ``sample['input']`` and ``sample['target']`` together.

    With probability ``p``, a random number of axes between 0 and ``len(axes)``
    (inclusive) is drawn without replacement from ``axes``. Both arrays are
    then flipped along exactly those axes. When zero axes are drawn, the arrays
    come back unflipped but still copied.

    Args:
        p: probability of applying the transform on a call.
        axes: candidate axes that may be flipped. Defaults to ``(0, 1)``, so the
            third axis of a 3-D volume is never flipped unless it is listed
            (e.g. ``axes=(0, 1, 2)``).
    """

    def __init__(self, p=0.5, axes=(0, 1)):
        self.p = p
        self.axes = tuple(axes)

    def __call__(self, sample):
        if random.random() < self.p:
            img, mask = sample['input'], sample['target']
            n_axes = random.randint(0, len(self.axes))
            random_axes = tuple(random.sample(self.axes, n_axes))
            img = np.flip(img, axis=random_axes)
            mask = np.flip(mask, axis=random_axes)
            # np.flip returns a view with negative strides. Copying gives
            # independent arrays with positive strides (torch.from_numpy, used
            # by ToTensor, rejects negative strides).
            sample['input'], sample['target'] = img.copy(), mask.copy()
        return sample
class RandomRotation:
    """Randomly rotate 3-D input and target arrays in one or more planes.

    With probability ``p``, between one and three of the planes indexed 0, 1, 2
    (see ``rotate_3d_along_axis``) are chosen without repetition, in random
    order. For each chosen plane, an integer angle is drawn with
    ``random.randrange(*angle_range)`` and its sign is flipped with probability
    0.5. Input and target are rotated by the same angle. The input uses
    ``order=1`` interpolation and the target ``order=0`` (nearest neighbour),
    so target values are not blended. Rotated arrays keep their incoming dtype.

    Args:
        p: probability of applying the transform on a call.
        angle_range: ``(start, stop)`` passed to ``random.randrange``. Angles
            are integer degrees in ``[start, stop)``.
    """

    def __init__(self, p=0.5, angle_range=(5, 15)):
        self.p = p
        # Keep a private tuple: avoids a shared mutable default and aliasing
        # of a list the caller may mutate later.
        self.angle_range = tuple(angle_range)

    def __call__(self, sample):
        if random.random() < self.p:
            img, mask = sample['input'], sample['target']
            n_axes = random.randint(1, 3)
            random_axes = random.sample([0, 1, 2], n_axes)
            for axis in random_axes:
                angle = random.randrange(*self.angle_range)
                angle = -angle if random.random() < 0.5 else angle
                # Rebind to new arrays instead of writing into img[:, :, :], so
                # the caller's arrays (e.g. ones cached by a dataset) are never
                # modified. Casting back to the incoming dtype stores the same
                # values an in-place write would have.
                img = np.ascontiguousarray(
                    RandomRotation.rotate_3d_along_axis(img, angle, axis, 1), dtype=img.dtype
                )
                mask = np.ascontiguousarray(
                    RandomRotation.rotate_3d_along_axis(mask, angle, axis, 0), dtype=mask.dtype
                )
            sample['input'], sample['target'] = img, mask
        return sample

    @staticmethod
    def rotate_3d_along_axis(img, angle, axis, order):
        """Rotate a 3-D array by ``angle`` degrees in the plane selected by ``axis``.

        ``skimage.transform.rotate`` rotates the first two dimensions. The array
        is therefore transposed to bring the chosen plane to the front, rotated,
        and transposed back:

        * ``axis=0``: plane of dimensions (0, 1)
        * ``axis=1``: plane of dimensions (1, 2)
        * ``axis=2``: plane of dimensions (2, 0)

        Args:
            img: array to rotate; its shape is preserved.
            angle: rotation angle in degrees.
            axis: plane index, 0, 1 or 2.
            order: spline interpolation order passed to ``rotate``.

        Returns:
            The rotated array (``preserve_range=True``, ``mode='symmetric'``).

        Raises:
            ValueError: if ``axis`` is not 0, 1 or 2.
        """
        if axis == 0:
            return rotate(img, angle, order=order, mode='symmetric', preserve_range=True)
        if axis == 1:
            forward, backward = (1, 2, 0), (2, 0, 1)
        elif axis == 2:
            forward, backward = (2, 0, 1), (1, 2, 0)
        else:
            raise ValueError(f"Argument 'axis' must be 0, 1 or 2. Received {axis}")
        rot_img = np.transpose(img, axes=forward)
        rot_img = rotate(rot_img, angle, order=order, mode='symmetric', preserve_range=True)
        return np.transpose(rot_img, axes=backward)