Skip to content

Commit

Permalink
fix: apply pre-commit hooks and formatting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
JSabadin committed Oct 4, 2024
1 parent fccd251 commit 20f788b
Showing 1 changed file with 76 additions and 32 deletions.
108 changes: 76 additions & 32 deletions luxonis_ml/data/augmentations/custom/mosaic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
Expand All @@ -8,13 +9,12 @@
)
from albumentations.core.transforms_interface import (
BoxInternalType,
ImageColorType,
KeypointInternalType,
)

from ..batch_transform import BatchBasedTransform
from ..utils import AUGMENTATIONS
import random


@AUGMENTATIONS.register_module()
class Mosaic4(BatchBasedTransform):
Expand Down Expand Up @@ -99,7 +99,8 @@ def get_transform_init_args_names(self) -> Tuple[str, ...]:
)

def _generate_random_crop_center(self) -> Tuple[int, int]:
"""Generate a random crop center within the bounds of the mosaic image size."""
"""Generate a random crop center within the bounds of the mosaic
image size."""
crop_x = random.randint(0, max(0, self.out_width))
crop_y = random.randint(0, max(0, self.out_height))
return crop_x, crop_y
Expand All @@ -114,7 +115,12 @@ def targets_as_params(self):
return ["image_batch"]

def apply_to_image_batch(
self, image_batch: List[np.ndarray], indices: List[int], x_crop: int, y_crop: int, **params
self,
image_batch: List[np.ndarray],
indices: List[int],
x_crop: int,
y_crop: int,
**params,
) -> List[np.ndarray]:
"""Applies the transformation to a batch of images.
Expand All @@ -139,13 +145,23 @@ def apply_to_image_batch(
]
image_chunk = [image_batch[i] for i in idx_chunk]
mosaiced = mosaic4(
image_chunk, self.out_height, self.out_width, x_crop, y_crop, self.value,
image_chunk,
self.out_height,
self.out_width,
x_crop,
y_crop,
self.value,
)
output_batch.append(mosaiced)
return output_batch

def apply_to_mask_batch(
self, mask_batch: List[np.ndarray], indices: List[int], x_crop: int, y_crop: int, **params
self,
mask_batch: List[np.ndarray],
indices: List[int],
x_crop: int,
y_crop: int,
**params,
) -> List[np.ndarray]:
"""Applies the transformation to a batch of masks.
Expand All @@ -170,7 +186,12 @@ def apply_to_mask_batch(
]
mask_chunk = [mask_batch[i] for i in idx_chunk]
mosaiced = mosaic4(
mask_chunk, self.out_height, self.out_width, x_crop, y_crop, self.mask_value
mask_chunk,
self.out_height,
self.out_width,
x_crop,
y_crop,
self.mask_value,
)
output_batch.append(mosaiced)
return output_batch
Expand Down Expand Up @@ -299,7 +320,7 @@ def get_params_dependent_on_targets(
f"The batch size (= {n}) should be larger than "
+ f"{self.n_tiles} x out_batch_size (= {self.n_tiles * self.out_batch_size})"
)
indices = [0,1,2,3]
indices = [0, 1, 2, 3]
image_shapes = [tuple(image.shape[:2]) for image in image_batch]
x_crop, y_crop = self._generate_random_crop_center()
return {
Expand All @@ -309,6 +330,7 @@ def get_params_dependent_on_targets(
"y_crop": y_crop,
}


def mosaic4(
image_batch: List[np.ndarray],
height: int,
Expand All @@ -317,12 +339,12 @@ def mosaic4(
y_crop: int,
value: Optional[int] = None,
) -> np.ndarray:
"""Arrange the images in a 2x2 grid layout. The input images should have the same
number of channels but can have different widths and heights. The gaps are filled by
the value.
"""Arrange the images in a 2x2 grid layout. The input images should
have the same number of channels but can have different widths and
heights. The gaps are filled by the value.
@param image_batch: Image list. The length should be four. Each image can has
different size.
@param image_batch: Image list. The length should be four. Each
image can has different size.
@type image_batch: List[np.ndarray]
@param height: Height of output mosaic image
@type height: int
Expand All @@ -339,11 +361,15 @@ def mosaic4(
"""
N_TILES = 4
if len(image_batch) != N_TILES:
raise ValueError(f"Length of image_batch should be 4. Got {len(image_batch)}")
raise ValueError(
f"Length of image_batch should be 4. Got {len(image_batch)}"
)

for i in range(N_TILES - 1):
if image_batch[0].shape[2:] != image_batch[i + 1].shape[2:]:
raise ValueError("All images should have the same number of channels.")
raise ValueError(
"All images should have the same number of channels."
)

dtype = image_batch[0].dtype
img4 = np.full(
Expand All @@ -355,7 +381,6 @@ def mosaic4(
xc = width // 2
yc = height // 2


for i, img in enumerate(image_batch):
(h, w) = img.shape[:2]

Expand All @@ -366,10 +391,20 @@ def mosaic4(
x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, width * 2), yc
x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
elif i == 2: # bottom left
x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(height * 2, yc + h)
x1a, y1a, x2a, y2a = (
max(xc - w, 0),
yc,
xc,
min(height * 2, yc + h),
)
x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
elif i == 3: # bottom right
x1a, y1a, x2a, y2a = xc, yc, min(xc + w, width * 2), min(height * 2, yc + h)
x1a, y1a, x2a, y2a = (
xc,
yc,
min(xc + w, width * 2),
min(height * 2, yc + h),
)
x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

img4_region = img4[y1a:y2a, x1a:x2a]
Expand All @@ -384,8 +419,8 @@ def mosaic4(
img4[y1a : y1a + min_h, x1a : x1a + min_w] = img[
y1b : y1b + min_h, x1b : x1b + min_w
]
img4 = img4[y_crop:y_crop+height, x_crop:x_crop+width]

img4 = img4[y_crop : y_crop + height, x_crop : x_crop + width]

return img4

Expand All @@ -400,20 +435,21 @@ def bbox_mosaic4(
x_crop: int,
y_crop: int,
) -> BoxInternalType:
"""Adjust bounding box coordinates to account for mosaic grid position.
"""Adjust bounding box coordinates to account for mosaic grid
position.
This function modifies bounding boxes according to their placement in a 2x2 grid
mosaic, shifting their coordinates based on the tile's relative position within the
mosaic.
This function modifies bounding boxes according to their placement
in a 2x2 grid mosaic, shifting their coordinates based on the tile's
relative position within the mosaic.
@param bbox: Bounding box coordinates to be transformed.
@type bbox: BoxInternalType
@param rows: Height of the original image.
@type rows: int
@param cols: Width of the original image.
@type cols: int
@param position_index: Position of the image in the 2x2 grid. (0 = top-left, 1 =
top-right, 2 = bottom-left, 3 = bottom-right).
@param position_index: Position of the image in the 2x2 grid. (0 =
top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right).
@type position_index: int
@param height: Height of the final output mosaic image.
@type height: int
Expand Down Expand Up @@ -444,12 +480,18 @@ def bbox_mosaic4(
shift_x = xc
shift_y = yc

bbox = (bbox[0] + shift_x - x_crop, bbox[1] + shift_y - y_crop, bbox[2] + shift_x - x_crop, bbox[3] + shift_y - y_crop)
bbox = (
bbox[0] + shift_x - x_crop,
bbox[1] + shift_y - y_crop,
bbox[2] + shift_x - x_crop,
bbox[3] + shift_y - y_crop,
)

bbox = normalize_bbox(bbox, height, width)

return bbox


def keypoint_mosaic4(
keypoint: KeypointInternalType,
rows: int,
Expand All @@ -462,17 +504,19 @@ def keypoint_mosaic4(
) -> KeypointInternalType:
"""Adjust keypoint coordinates based on mosaic grid position.
This function adjusts the keypoint coordinates by placing them in one of the 2x2
mosaic grid cells, with shifts relative to the mosaic center.
This function adjusts the keypoint coordinates by placing them in
one of the 2x2 mosaic grid cells, with shifts relative to the mosaic
center.
@param keypoint: Keypoint coordinates and attributes (x, y, angle, scale).
@param keypoint: Keypoint coordinates and attributes (x, y, angle,
scale).
@type keypoint: KeypointInternalType
@param rows: Height of the original image.
@type rows: int
@param cols: Width of the original image.
@type cols: int
@param position_index: Position of the image in the 2x2 grid. (0 = top-left, 1 =
top-right, 2 = bottom-left, 3 = bottom-right).
@param position_index: Position of the image in the 2x2 grid. (0 =
top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right).
@type position_index: int
@param height: Height of the final output mosaic image.
@type height: int
Expand Down

0 comments on commit 20f788b

Please sign in to comment.