-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: imgaug library support + tests
- Loading branch information
Showing
4 changed files
with
279 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import numpy as np | ||
|
||
from zetta_utils.augmentations.imgaug import imgaug_augment | ||
|
||
from ..helpers import assert_array_equal | ||
from imgaug import augmenters as iaa | ||
|
||
|
||
def test_imgaug_basic_ndarray(): | ||
image = np.random.randint(0, 255, (3, 128, 128, 2), dtype=np.uint8) | ||
segmap = np.random.randint(0, 2 ** 15, (1, 64, 64, 2), dtype=np.uint16) | ||
keypoints = [[(0.0, 0.0), (0.0, 128.0)], [(128.0, 128.0), (64.0, 64.0)]] | ||
|
||
aug = iaa.Rot90() | ||
augmented = imgaug_augment( | ||
augmenters=aug, images=image, segmentation_maps=segmap, keypoints=keypoints | ||
) | ||
|
||
assert augmented.keys() == {"images", "segmentation_maps", "keypoints"} | ||
assert augmented["images"].shape == (3, 128, 128, 2) | ||
assert augmented["segmentation_maps"].shape == (1, 64, 64, 2) | ||
assert_array_equal(augmented["images"], np.rot90(image, axes=(2, 1))) | ||
assert_array_equal(augmented["segmentation_maps"], np.rot90(segmap, axes=(2, 1))) | ||
assert augmented["keypoints"] == [[(128.0, 0.0), (0.0, 0.0)], [(0.0, 128.0), (64.0, 64.0)]] | ||
|
||
|
||
def test_imgaug_basic_lists(): | ||
image = [np.random.randint(0, 255, (1, 128, 128, 1), dtype=np.uint8) for _ in range(2)] | ||
heatmap = [np.random.rand(3, 64, 64, 1).astype(np.float32) for _ in range(2)] | ||
|
||
aug = iaa.Fliplr() | ||
augmented = imgaug_augment(augmenters=[aug], images=image, heatmaps=heatmap) | ||
|
||
assert augmented.keys() == {"images", "heatmaps"} | ||
assert len(augmented["images"]) == 2 | ||
assert len(augmented["heatmaps"]) == 2 | ||
assert_array_equal(augmented["images"][0], np.flip(image[0], axis=2)) | ||
assert_array_equal(augmented["images"][1], np.flip(image[1], axis=2)) | ||
assert_array_equal(augmented["heatmaps"][0], np.flip(heatmap[0], axis=2)) | ||
assert_array_equal(augmented["heatmaps"][1], np.flip(heatmap[1], axis=2)) | ||
|
||
|
||
def test_imgaug_custom_lists(): | ||
image = [np.random.randint(0, 255, (1, 128, 128, 1), dtype=np.uint8) for _ in range(2)] | ||
seg = [np.random.randint(0, 2 ** 15, (1, 128, 128, 1), dtype=np.uint16) for _ in range(2)] | ||
aff = [np.random.rand(3, 64, 64, 1).astype(np.float32) for _ in range(2)] | ||
|
||
aug = iaa.Add(10) | ||
augmented = imgaug_augment( | ||
augmenters=aug, | ||
src_img=image[0], | ||
tgt_img=image[1], | ||
src_seg=seg[0], | ||
tgt_seg=seg[1], | ||
src_aff=aff[0], | ||
tgt_aff=aff[1], | ||
) | ||
|
||
assert augmented.keys() == {"src_img", "tgt_img", "src_seg", "tgt_seg", "src_aff", "tgt_aff"} | ||
assert augmented["src_img"].shape == (1, 128, 128, 1) | ||
assert augmented["tgt_img"].shape == (1, 128, 128, 1) | ||
assert augmented["src_seg"].shape == (1, 128, 128, 1) | ||
assert augmented["tgt_seg"].shape == (1, 128, 128, 1) | ||
assert augmented["src_aff"].shape == (3, 64, 64, 1) | ||
assert augmented["tgt_aff"].shape == (3, 64, 64, 1) | ||
|
||
assert_array_equal(augmented["src_img"], (image[0].clip(0, 245) + 10)) | ||
|
||
|
||
def test_imgaug_mixed_lists(): | ||
image_group = np.random.randint(0, 255, (2, 64, 64, 10), dtype=np.uint8) | ||
another_image = np.random.randint(0, 255, (3, 1024, 1024, 1), dtype=np.uint8) | ||
|
||
aug = iaa.Invert() | ||
augmented = imgaug_augment( | ||
augmenters=aug, | ||
images=image_group, | ||
another_img=another_image, | ||
) | ||
|
||
assert augmented.keys() == {"images", "another_img"} | ||
assert augmented["images"].shape == (2, 64, 64, 10) | ||
assert augmented["another_img"].shape == (3, 1024, 1024, 1) | ||
assert_array_equal(augmented["images"], np.invert(image_group)) | ||
assert_array_equal(augmented["another_img"], np.invert(another_image)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from . import tensor | ||
from .common import prob_aug | ||
from .imgaug import imgaug_augment |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any, Final, Optional, Sequence, Sized, Tuple, TypeVar, Union | ||
|
||
from imgaug import augmenters as iaa | ||
from imgaug.augmenters.meta import Augmenter | ||
|
||
from zetta_utils import builder | ||
from zetta_utils.tensor_ops import common, convert | ||
from zetta_utils.tensor_typing import Tensor | ||
|
||
SizedTypeVar = TypeVar("SizedTypeVar", bound=Sized) | ||
TensorListTypeVar = TypeVar( | ||
"TensorListTypeVar", Tensor, Sequence[Tensor], Union[Tensor, Sequence[Tensor]] | ||
) | ||
|
||
SUFFIX_MAPPING: Final = { | ||
"img": "images", | ||
"seg": "segmentation_maps", | ||
"hm": "heatmaps", | ||
"aff": "heatmaps", | ||
"kp": "keypoints", | ||
"bb": "bounding_boxes", | ||
"poly": "polygons", | ||
"ls": "line_strings", | ||
} | ||
|
||
|
||
def _ensure_list(x: SizedTypeVar | Sequence[SizedTypeVar] | None) -> Sequence[SizedTypeVar]: | ||
if x is None: | ||
return [] | ||
elif isinstance(x, Sequence): | ||
return x | ||
else: | ||
return [x] | ||
|
||
|
||
def _group_kwargs( | ||
**kwargs: Optional[SizedTypeVar], | ||
) -> tuple[dict[str, dict[str, SizedTypeVar]], dict[str, str]]: | ||
groups: dict[str, dict[str, SizedTypeVar]] = {} | ||
kwarg_mapping = {} | ||
for k, v in kwargs.items(): | ||
if v is None: | ||
continue | ||
group_name, suffix = k.rsplit("_", 1) | ||
try: | ||
mapped_suffix = SUFFIX_MAPPING[suffix] | ||
kwarg_mapping[k] = mapped_suffix | ||
except KeyError: | ||
raise ValueError( | ||
f"Expected suffix `_img`, `_seg`, or `_aff` in custom augmentable key {k}" | ||
) | ||
|
||
groups.setdefault(group_name, {})[mapped_suffix] = v | ||
|
||
return groups, kwarg_mapping | ||
|
||
|
||
def _ungroup_kwargs( | ||
groups: dict[str, dict[str, SizedTypeVar]], kwarg_mapping: dict[str, str] | ||
) -> dict[str, SizedTypeVar]: | ||
kwargs: dict[str, SizedTypeVar] = {} | ||
for k, mapped_suffix in kwarg_mapping.items(): | ||
group_name, suffix = k.rsplit("_", 1) | ||
if group_name == "_args": | ||
kwargs[mapped_suffix] = groups[group_name][mapped_suffix] | ||
else: | ||
kwargs[f"{group_name}_{suffix}"] = groups[group_name][mapped_suffix] | ||
|
||
return kwargs | ||
|
||
|
||
def _ensure_nxyc(x: TensorListTypeVar) -> TensorListTypeVar: | ||
if isinstance(x, Sequence): | ||
return [convert.to_np(common.rearrange(v, pattern="C X Y 1 -> X Y C")) for v in x] | ||
else: | ||
return convert.to_np(common.rearrange(x, pattern="C X Y N -> N X Y C")) | ||
|
||
|
||
def _ensure_cxyn(x: TensorListTypeVar, ref: TensorListTypeVar) -> TensorListTypeVar: | ||
if isinstance(x, Sequence): | ||
return [ | ||
convert.astype(common.rearrange(v, pattern="X Y C -> C X Y 1"), reference=ref[i]) | ||
for i, v in enumerate(x) | ||
] | ||
else: | ||
return convert.astype(common.rearrange(x, pattern="N X Y C -> C X Y N"), reference=ref) | ||
|
||
|
||
@builder.register("imgaug_augment") | ||
def imgaug_augment( | ||
augmenters: Augmenter | Sequence[Augmenter], | ||
*, | ||
images: Tensor | Sequence[Tensor] | None = None, | ||
heatmaps: Tensor | Sequence[Tensor] | None = None, | ||
segmentation_maps: Tensor | Sequence[Tensor] | None = None, | ||
keypoints: Sequence[Sequence[Tuple[float, float]]] | None = None, | ||
bounding_boxes: Sequence[Sequence[Tuple[float, float, float, float]]] | None = None, | ||
polygons: Sequence[Sequence[Sequence[Tuple[float, float]]]] | None = None, | ||
line_strings: Sequence[Sequence[Sequence[Tuple[float, float]]]] | None = None, | ||
**kwargs: Tensor | Sequence[Tensor], | ||
) -> dict[str, Any]: | ||
"""This function is a wrapper for imgaug.augment to handle the CXYZ/ZXYC conversion. | ||
It will call each provided augmenter on the provided augmentable dict. | ||
For additionally supported types, see: | ||
https://github.com/aleju/imgaug/blob/0101108d4fed06bc5056c4a03e2bcb0216dac326/imgaug/augmenters/meta.py#L1757-L1842 | ||
:param augmenters: A sequence of imgaug augmenters. | ||
:param images: Either CXYZ tensor or list of CXY1 tensors. If not specified, at least one kwarg with `_img` suffix is required. | ||
:param heatmaps: Either CXYZ tensor or list of CXY1 tensors. | ||
:param segmentation_maps: Either CXYZ tensor or list of CXY1 tensors. | ||
:param keypoints: List of lists of (x, y) coordinates. | ||
:param bounding_boxes: List of lists of (x1, y1, x2, y2) coordinates. | ||
:param polygons: List of lists of lists of (x, y) coordinates. | ||
:param line_strings: List of lists of lists of (x, y) coordinates. | ||
:param kwargs: Additional/alternative augmentables, each a CXYZ tensor or list of CXY1 tensors and | ||
suffixes: `_img`, `_seg`, `_hm`/`_aff`, `_kp`, `_bb`, `_poly`, `_ls`. | ||
:return: Augmented dictionary, same keys as input. | ||
""" | ||
augmenter: Augmenter = iaa.Sequential(_ensure_list(augmenters)).to_deterministic() | ||
augmentables, kwarg_mapping = _group_kwargs( | ||
_args_img=images, | ||
_args_hm=heatmaps, | ||
_args_seg=segmentation_maps, | ||
_args_kp=keypoints, | ||
_args_bb=bounding_boxes, | ||
_args_poly=polygons, | ||
_args_ls=line_strings, | ||
**kwargs, | ||
) | ||
|
||
if not augmentables: | ||
raise ValueError("Expected at least one image in `images` or `kwargs`") | ||
|
||
for aug_group in augmentables.values(): | ||
res = augmenter.augment( | ||
images=_ensure_nxyc(aug_group["images"]), | ||
heatmaps=_ensure_nxyc(aug_group["heatmaps"]) if "heatmaps" in aug_group else None, | ||
segmentation_maps=_ensure_nxyc(aug_group["segmentation_maps"]) | ||
if "segmentation_maps" in aug_group | ||
else None, | ||
keypoints=aug_group.get("keypoints", None), | ||
bounding_boxes=aug_group.get("bounding_boxes", None), | ||
polygons=aug_group.get("polygons", None), | ||
line_strings=aug_group.get("line_strings", None), | ||
return_batch=True, | ||
) | ||
|
||
aug_group["images"] = _ensure_cxyn(res.images_aug, aug_group["images"]) | ||
if "heatmaps" in aug_group: | ||
aug_group["heatmaps"] = _ensure_cxyn(res.heatmaps_aug, aug_group["heatmaps"]) | ||
if "segmentation_maps" in aug_group: | ||
aug_group["segmentation_maps"] = _ensure_cxyn( | ||
res.segmentation_maps_aug, aug_group["segmentation_maps"] | ||
) | ||
if "keypoints" in aug_group: | ||
aug_group["keypoints"] = res.keypoints_aug | ||
if "bounding_boxes" in aug_group: | ||
aug_group["bounding_boxes"] = res.bounding_boxes_aug | ||
if "polygons" in aug_group: | ||
aug_group["polygons"] = res.polygons_aug | ||
if "line_strings" in aug_group: | ||
aug_group["line_strings"] = res.line_strings_aug | ||
|
||
return _ungroup_kwargs(augmentables, kwarg_mapping=kwarg_mapping) | ||
|
||
|
||
def _unpack_kwargs(cls): | ||
class wrapper(cls): # type: ignore # mypy doesn't like dynamic base classes | ||
def __call__(kwargs_tuple): | ||
super.__call__(**kwargs_tuple) | ||
|
||
return wrapper | ||
|
||
|
||
for k in dir(iaa): | ||
if k[0].isupper() and hasattr(getattr(iaa, k), "augment"): | ||
builder.register(f"imgaug.augmenters.{k}")(_unpack_kwargs(getattr(iaa, k))) | ||
|
||
for k in dir(iaa.imgcorruptlike): | ||
if k[0].isupper() and hasattr(getattr(iaa.imgcorruptlike, k), "augment"): | ||
builder.register(f"imgaug.augmenters.imgcorruptlike.{k}")( | ||
_unpack_kwargs(getattr(iaa.imgcorruptlike, k)) | ||
) |