Augmentation SpeedUp #147

Merged: 15 commits, Aug 15, 2024
Changes from all commits
137 changes: 81 additions & 56 deletions luxonis_ml/data/augmentations/utils.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple

import albumentations as A
import cv2
@@ -7,7 +7,7 @@
from luxonis_ml.utils.registry import Registry

from ..utils.enums import LabelType
from .batch_compose import BatchCompose, ForEach
from .batch_compose import BatchCompose
from .batch_transform import BatchBasedTransform

AUGMENTATIONS = Registry(name="augmentations")
@@ -69,26 +69,21 @@ def _parse_cfg(
else:
resize = A.Resize(image_size[0], image_size[1])

pixel_augs = []
spatial_augs = []
batched_augs = []
spatial_augs.append(resize)
if augmentations:
for aug in augmentations:
curr_aug = AUGMENTATIONS.get(aug["name"])(**aug.get("params", {}))
if isinstance(curr_aug, A.ImageOnlyTransform):
pixel_augs.append(curr_aug)
elif isinstance(curr_aug, A.DualTransform):
spatial_augs.append(curr_aug)
elif isinstance(curr_aug, BatchBasedTransform):
if isinstance(curr_aug, BatchBasedTransform):
self.is_batched = True
self.aug_batch_size = max(self.aug_batch_size, curr_aug.batch_size)
batched_augs.append(curr_aug)
# NOTE: always perform resize last
spatial_augs.append(resize)
else:
spatial_augs.append(curr_aug)

batch_transform = BatchCompose(
[
ForEach(pixel_augs),
*batched_augs,
],
bbox_params=A.BboxParams(
@@ -136,6 +131,10 @@ def __call__(
@return: Output image and its annotations
"""

present_annotations = {
key for _, annotations in data for key in annotations.keys()
}
return_mask = LabelType.SEGMENTATION in present_annotations
image_batch = []
mask_batch = []
bboxes_batch = []
@@ -145,10 +144,8 @@ def __call__(
keypoints_visibility_batch = []
keypoints_classes_batch = []

present_annotations = set()
bbox_counter = 0
for img, annotations in data:
present_annotations.update(annotations.keys())
(
classes,
mask,
@@ -157,10 +154,13 @@
keypoints_points,
keypoints_visibility,
keypoints_classes,
) = self.prepare_img_annotations(annotations, *img.shape[:-1], nk=nk)
) = self.prepare_img_annotations(
annotations, *img.shape[:-1], nk=nk, return_mask=return_mask
)

image_batch.append(img)
mask_batch.append(mask)
if return_mask:
mask_batch.append(mask)

bboxes_batch.append(bboxes_points)
bboxes_visibility_batch.append(
@@ -173,32 +173,37 @@
keypoints_visibility_batch.append(keypoints_visibility)
keypoints_classes_batch.append(keypoints_classes)

# Apply transforms
# NOTE: All keys (including label_fields) must have _batch suffix when using BatchCompose
transformed = self.batch_transform(
image_batch=image_batch,
mask_batch=mask_batch,
bboxes_batch=bboxes_batch,
bboxes_visibility_batch=bboxes_visibility_batch,
bboxes_classes_batch=bboxes_classes_batch,
keypoints_batch=keypoints_batch,
keypoints_visibility_batch=keypoints_visibility_batch,
keypoints_classes_batch=keypoints_classes_batch,
)
transform_args = {
"image_batch": image_batch,
"bboxes_batch": bboxes_batch,
"bboxes_visibility_batch": bboxes_visibility_batch,
"bboxes_classes_batch": bboxes_classes_batch,
"keypoints_batch": keypoints_batch,
"keypoints_visibility_batch": keypoints_visibility_batch,
"keypoints_classes_batch": keypoints_classes_batch,
}
if return_mask:
transform_args["mask_batch"] = mask_batch
Collaborator commented:
Do we see any concerns with mask_batch missing downstream? Would it be safer to keep it in the dict but with an empty value?

Contributor @JSabadin commented on Aug 15, 2024:
No, the augmentations still work if mask_batch is simply missing; I don't think it poses any threat.

Additionally, passing an empty value for masks instead of adding them only when necessary fails: the mask cannot be empty, it would have to contain zeros, and that is slow because the augmentations are then applied to the dummy mask. This is what we were doing before, and it was slowing the code down.

Moreover, using dummy keypoints (just one keypoint at x: 0, y: 0) or dummy bounding boxes does not affect the time it takes to apply the augmentations.
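
To illustrate the point, a minimal sketch of the pattern (plain albumentations here for simplicity; the PR's BatchCompose applies the same idea with *_batch keys, and the names below are illustrative, not the exact luxonis-ml API):

import albumentations as A
import numpy as np

transform = A.Compose([A.HorizontalFlip(p=1.0)])
image = np.zeros((256, 256, 3), dtype=np.uint8)

# Fast path: no segmentation labels, so no mask key is passed at all.
args = {"image": image}

has_segmentation = False  # would be: LabelType.SEGMENTATION in present_annotations
if has_segmentation:
    # Only build (and pay for) a mask when the sample actually has one.
    args["mask"] = np.zeros((256, 256), dtype=np.uint8)

# With "mask" absent, the transforms skip mask handling entirely instead of
# augmenting a dummy all-zeros mask on every call.
out = transform(**args)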

Member Author commented:
CC: @JSabadin


# convert to numpy arrays
for key in transformed:
transformed[key] = np.array(transformed[key][0])
# Apply transforms
transformed = self.batch_transform(force_apply=False, **transform_args)
transformed = {key: np.array(value[0]) for key, value in transformed.items()}

# Prepare the spatial transform arguments
spatial_transform_args = {
"image": transformed["image_batch"],
"bboxes": transformed["bboxes_batch"],
"bboxes_visibility": transformed["bboxes_visibility_batch"],
"bboxes_classes": transformed["bboxes_classes_batch"],
"keypoints": transformed["keypoints_batch"],
"keypoints_visibility": transformed["keypoints_visibility_batch"],
"keypoints_classes": transformed["keypoints_classes_batch"],
}
if return_mask:
spatial_transform_args["mask"] = transformed["mask_batch"]

transformed = self.spatial_transform(
image=transformed["image_batch"],
mask=transformed["mask_batch"],
bboxes=transformed["bboxes_batch"],
bboxes_visibility=transformed["bboxes_visibility_batch"],
bboxes_classes=transformed["bboxes_classes_batch"],
keypoints=transformed["keypoints_batch"],
keypoints_visibility=transformed["keypoints_visibility_batch"],
keypoints_classes=transformed["keypoints_classes_batch"],
force_apply=False, **spatial_transform_args
)

out_image, out_mask, out_bboxes, out_keypoints = self.post_transform_process(
@@ -207,12 +212,13 @@
nk=nk,
filter_kpts_by_bbox=(LabelType.BOUNDINGBOX in present_annotations)
and (LabelType.KEYPOINTS in present_annotations),
return_mask=return_mask,
)

out_annotations = {}
for key in present_annotations:
if key == LabelType.CLASSIFICATION:
out_annotations[LabelType.CLASSIFICATION] = classes
out_annotations[LabelType.CLASSIFICATION] = classes # type: ignore
elif key == LabelType.SEGMENTATION:
out_annotations[LabelType.SEGMENTATION] = out_mask
elif key == LabelType.BOUNDINGBOX:
@@ -223,10 +229,15 @@
return out_image, out_annotations

def prepare_img_annotations(
self, annotations: Dict[LabelType, np.ndarray], ih: int, iw: int, nk: int
self,
annotations: Dict[LabelType, np.ndarray],
ih: int,
iw: int,
nk: int,
return_mask: bool = True,
) -> Tuple[
np.ndarray,
np.ndarray,
Optional[np.ndarray],
np.ndarray,
np.ndarray,
np.ndarray,
@@ -241,16 +252,20 @@ def prepare_img_annotations(
@param ih: Input image height
@type iw: int
@param iw: Input image width
@rtype: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray,
np.ndarray, np.ndarray]
@type return_mask: bool
@param return_mask: Whether to compute and return mask
@rtype: Tuple[np.ndarray, Optional[np.ndarray], np.ndarray, np.ndarray,
np.ndarray, np.ndarray, np.ndarray]
@return: Annotations in albumentations format
"""

classes = annotations.get(LabelType.CLASSIFICATION, np.zeros(1))

seg = annotations.get(LabelType.SEGMENTATION, np.zeros((1, ih, iw)))
mask = np.argmax(seg, axis=0) + 1
mask[np.sum(seg, axis=0) == 0] = 0 # only background has value 0
mask = None
if return_mask:
seg = annotations.get(LabelType.SEGMENTATION, np.zeros((1, ih, iw)))
mask = np.argmax(seg, axis=0) + 1
mask[np.sum(seg, axis=0) == 0] = 0 # only background has value 0

# COCO format in albumentations is [x,y,w,h] non-normalized
bboxes = annotations.get(LabelType.BOUNDINGBOX, np.zeros((0, 5)))
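
For reference, the one-hot-to-index mask conversion above can be traced on a tiny example (shapes chosen arbitrarily for illustration):

import numpy as np

# One-hot segmentation stack of shape (n_classes, H, W) with two classes.
seg = np.array([
    [[1, 0],
     [0, 0]],  # class 0 claims the top-left pixel
    [[0, 1],
     [0, 0]],  # class 1 claims the top-right pixel
])

mask = np.argmax(seg, axis=0) + 1   # class indices shifted to 1..n_classes
mask[np.sum(seg, axis=0) == 0] = 0  # unlabeled pixels become background (0)

print(mask)
# [[1 2]
#  [0 0]]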
@@ -287,7 +302,8 @@ def post_transform_process(
ns: int,
nk: int,
filter_kpts_by_bbox: bool,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
return_mask: bool = True,
) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray, np.ndarray]:
"""Postprocessing of albumentations output to LuxonisLoader format.

@type transformed_data: Dict[str, np.ndarray]
@@ -299,21 +315,30 @@
@type filter_kpts_by_bbox: bool
@param filter_kpts_by_bbox: If True removes keypoint instances if its bounding
box was removed.
@rtype: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
@rtype: Tuple[np.ndarray, Optional[np.ndarray], np.ndarray, np.ndarray]
@return: Postprocessed annotations
"""

out_image = transformed_data["image"].astype(np.float32)
out_image = transformed_data["image"]
ih, iw, _ = out_image.shape
if not self.train_rgb:
out_image = cv2.cvtColor(out_image, cv2.COLOR_RGB2BGR)

transformed_mask = transformed_data["mask"]
out_mask = np.zeros((ns, *transformed_mask.shape))
for key in np.unique(transformed_mask):
if key != 0:
out_mask[int(key) - 1, ...] = transformed_mask == key
out_mask[out_mask > 0] = 1
out_image = out_image.astype(np.float32)

out_mask = None
if return_mask:
transformed_mask = transformed_data.get("mask")
out_mask = (
np.zeros((ns, *transformed_mask.shape))
if transformed_mask is not None
else None
)
if transformed_mask is not None:
assert out_mask is not None
for key in np.unique(transformed_mask):
if key != 0:
out_mask[int(key) - 1, ...] = transformed_mask == key
out_mask[out_mask > 0] = 1

if transformed_data["bboxes"]:
transformed_bboxes_classes = np.expand_dims(
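
As a closing note, the inverse conversion in post_transform_process (index mask back to per-class binary masks) can be sketched the same way; ns is the number of segmentation classes, as in the diff above:

import numpy as np

ns = 2
transformed_mask = np.array([[1, 2],
                             [0, 0]])  # 0 is background

out_mask = np.zeros((ns, *transformed_mask.shape))
for key in np.unique(transformed_mask):
    if key != 0:
        # channel k-1 holds the binary mask for class index k
        out_mask[int(key) - 1, ...] = transformed_mask == key
out_mask[out_mask > 0] = 1

# out_mask[0] -> pixels of class 1, out_mask[1] -> pixels of class 2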