Skip to content

Commit

Permalink
Added GridDropout and RandomDropPlane
Browse files Browse the repository at this point in the history
  • Loading branch information
ZFTurbo committed Jul 24, 2020
1 parent 6392f6c commit bead760
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 1 deletion.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,7 @@ class ElasticTransformPseudo2D(DualTransform):
class ElasticTransform(DualTransform):
class Rotate(DualTransform):
class RandomCropFromBorders(DualTransform):
class GridDropout(DualTransform):
class RandomDropPlane(DualTransform):
```

10 changes: 9 additions & 1 deletion volumentations/augmentations/functionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,4 +391,12 @@ def clamping_crop(img, sh0_min, sh1_min, sh2_min, sh0_max, sh1_max, sh2_max):
sh1_max = h
if sh2_max > w:
sh2_max = w
return img[int(sh0_min): int(sh0_max), int(sh1_min): int(sh1_max), int(sh2_min): int(sh2_max)]
return img[int(sh0_min): int(sh0_max), int(sh1_min): int(sh1_max), int(sh2_min): int(sh2_max)]


def cutout(img, holes, fill_value=0):
# Make a copy of the input image since we don't want to modify it directly
img = img.copy()
for x1, y1, z1, x2, y2, z2 in holes:
img[y1:y2, x1:x2, z1:z2] = fill_value
return img
209 changes: 209 additions & 0 deletions volumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,3 +479,212 @@ def apply(self, img, sh0_min=0, sh0_max=0, sh1_min=0, sh1_max=0, sh2_min=0, sh2_

def apply_to_mask(self, mask, sh0_min=0, sh0_max=0, sh1_min=0, sh1_max=0, sh2_min=0, sh2_max=0, **params):
return F.clamping_crop(mask, sh0_min, sh1_min, sh2_min, sh0_max, sh1_max, sh2_max)


class GridDropout(DualTransform):
"""GridDropout, drops out rectangular regions of an image and the corresponding mask in a grid fashion.
Args:
ratio (float): the ratio of the mask holes to the unit_size (same for horizontal and vertical directions).
Must be between 0 and 1. Default: 0.5.
unit_size_min (int): minimum size of the grid unit. Must be between 2 and the image shorter edge.
If 'None', holes_number_x and holes_number_y are used to setup the grid. Default: `None`.
unit_size_max (int): maximum size of the grid unit. Must be between 2 and the image shorter edge.
If 'None', holes_number_x and holes_number_y are used to setup the grid. Default: `None`.
holes_number_x (int): the number of grid units in x direction. Must be between 1 and image width//2.
If 'None', grid unit width is set as image_width//10. Default: `None`.
holes_number_y (int): the number of grid units in y direction. Must be between 1 and image height//2.
If `None`, grid unit height is set equal to the grid unit width or image height, whatever is smaller.
holes_number_z (int): the number of grid units in z direction. Must be between 1 and image depth//2.
If `None`, grid unit depth is set equal to the grid unit width or image height, whatever is smaller.
shift_x (int): offsets of the grid start in x direction from (0,0) coordinate.
Clipped between 0 and grid unit_width - hole_width. Default: 0.
shift_y (int): offsets of the grid start in y direction from (0,0) coordinate.
Clipped between 0 and grid unit height - hole_height. Default: 0.
shift_z (int): offsets of the grid start in z direction from (0,0) coordinate.
Clipped between 0 and grid unit depth - hole_depth. Default: 0.
random_offset (boolean): weather to offset the grid randomly between 0 and grid unit size - hole size
If 'True', entered shift_x, shift_y, shift_z are ignored and set randomly. Default: `False`.
fill_value (int): value for the dropped pixels. Default = 0
mask_fill_value (int): value for the dropped pixels in mask.
If `None`, tranformation is not applied to the mask. Default: `None`.
Targets:
image, mask
Image types:
uint8, float32
References:
https://arxiv.org/abs/2001.04086
"""

def __init__(
self,
ratio: float = 0.5,
unit_size_min: int = None,
unit_size_max: int = None,
holes_number_x: int = None,
holes_number_y: int = None,
holes_number_z: int = None,
shift_x: int = 0,
shift_y: int = 0,
shift_z: int = 0,
random_offset: bool = False,
fill_value: int = 0,
mask_fill_value: int = None,
always_apply: bool = False,
p: float = 0.5,
):
super(GridDropout, self).__init__(always_apply, p)
self.ratio = ratio
self.unit_size_min = unit_size_min
self.unit_size_max = unit_size_max
self.holes_number_x = holes_number_x
self.holes_number_y = holes_number_y
self.holes_number_z = holes_number_z
self.shift_x = shift_x
self.shift_y = shift_y
self.shift_z = shift_z
self.random_offset = random_offset
self.fill_value = fill_value
self.mask_fill_value = mask_fill_value
if not 0 < self.ratio <= 1:
raise ValueError("ratio must be between 0 and 1.")

def apply(self, image, holes=(), **params):
return F.cutout(image, holes, self.fill_value)

def apply_to_mask(self, image, holes=(), **params):
if self.mask_fill_value is None:
return image

return F.cutout(image, holes, self.mask_fill_value)

def get_params(self, **data):
img = data["image"]
height, width, depth = img.shape[:3]
# set grid using unit size limits
if self.unit_size_min and self.unit_size_max:
if not 2 <= self.unit_size_min <= self.unit_size_max:
raise ValueError("Max unit size should be >= min size, both at least 2 pixels.")
if self.unit_size_max > min(height, width):
raise ValueError("Grid size limits must be within the shortest image edge.")
unit_width = random.randint(self.unit_size_min, self.unit_size_max + 1)
unit_height = unit_width
unit_depth = unit_width
else:
# set grid using holes numbers
if self.holes_number_x is None:
unit_width = max(2, width // 10)
else:
if not 1 <= self.holes_number_x <= width // 2:
raise ValueError("The hole_number_x must be between 1 and image width//2.")
unit_width = width // self.holes_number_x
if self.holes_number_y is None:
unit_height = max(min(unit_width, height), 2)
else:
if not 1 <= self.holes_number_y <= height // 2:
raise ValueError("The hole_number_y must be between 1 and image height//2.")
unit_height = height // self.holes_number_y
if self.holes_number_z is None:
unit_depth = max(min(unit_height, depth), 2)
else:
if not 1 <= self.holes_number_z <= depth // 2:
raise ValueError("The hole_number_z must be between 1 and image depth//2.")
unit_depth = depth // self.holes_number_z

hole_width = int(unit_width * self.ratio)
hole_height = int(unit_height * self.ratio)
hole_depth = int(unit_depth * self.ratio)
# min 1 pixel and max unit length - 1
hole_width = min(max(hole_width, 1), unit_width - 1)
hole_height = min(max(hole_height, 1), unit_height - 1)
hole_depth = min(max(hole_depth, 1), unit_depth - 1)
# set offset of the grid
if self.shift_x is None:
shift_x = 0
else:
shift_x = min(max(0, self.shift_x), unit_width - hole_width)
if self.shift_y is None:
shift_y = 0
else:
shift_y = min(max(0, self.shift_y), unit_height - hole_height)
if self.shift_z is None:
shift_z = 0
else:
shift_z = min(max(0, self.shift_z), unit_depth - hole_depth)
if self.random_offset:
shift_x = random.randint(0, unit_width - hole_width)
shift_y = random.randint(0, unit_height - hole_height)
shift_z = random.randint(0, unit_depth - hole_depth)
holes = []
for i in range(width // unit_width + 1):
for j in range(height // unit_height + 1):
for k in range(depth // unit_depth + 1):
x1 = min(shift_x + unit_width * i, width)
y1 = min(shift_y + unit_height * j, height)
z1 = min(shift_z + unit_depth * j, depth)
x2 = min(x1 + hole_width, width)
y2 = min(y1 + hole_height, height)
z2 = min(z1 + hole_depth, depth)
holes.append((x1, y1, z1, x2, y2, z2))

return {"holes": holes}

def get_transform_init_args_names(self):
return (
"ratio",
"unit_size_min",
"unit_size_max",
"holes_number_x",
"holes_number_y",
"shift_x",
"shift_y",
"mask_fill_value",
"random_offset",
)


class RandomDropPlane(DualTransform):
"""Randomly drop some planes in random axis
Args:
plane_drop_prob (float): float value in (0.0, 1.0) range. Default: 0.1
axes (tuple). Default: 0
p (float): probability of applying the transform. Default: 1.
Targets:
image, mask
Image types:
uint8, float32
"""

def __init__(
self,
plane_drop_prob=0.1,
axes=(0,),
always_apply=False,
p=1.0
):
super(RandomDropPlane, self).__init__(always_apply, p)
self.plane_drop_prob = plane_drop_prob
self.axes = axes

def get_params(self, **data):
img = data["image"]
axis = random.choice(self.axes)
r = img.shape[axis]
indexes = []
for i in range(r):
if random.uniform(0, 1) > self.plane_drop_prob:
indexes.append(i)
if len(indexes) == 0:
indexes.append(0)

return {
"indexes": indexes, "axis": axis,
}

def apply(self, img, indexes=(), axis=0, **params):
return np.take(img, indexes, axis=axis)

def apply_to_mask(self, mask, indexes=(), axis=0, **params):
return np.take(mask, indexes, axis=axis)

0 comments on commit bead760

Please sign in to comment.