Skip to content

Commit

Permalink
feat: imgaug library support + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nkemnitz committed Aug 25, 2023
1 parent 6c528fb commit cd58bfe
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 2 deletions.
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ cloudvol = [
"tensorstore == 0.1.33",
"zetta_utils[tensor_ops]",
]
augmentations = ["zetta_utils[tensor_ops]"]
augmentations = [
"zetta_utils[tensor_ops]",
"imgaug == 0.4.0",
"imagecorruptions == 1.1.2",
]
convnet = ["torch >= 2.0", "artificery >= 0.0.3.3"]
alignment = ["metroem >= 0.1.2", "torch >= 2.0"]
mazepa = [
Expand All @@ -87,7 +91,7 @@ mazepa-addons = [
]
segmentation = ["onnx >= 1.13.0", "onnxruntime-gpu >= 1.13.1"]
training = [
"zetta_utils[tensor_ops,cloudvol,convnet,viz,gcs]",
"zetta_utils[tensor_ops,cloudvol,convnet,viz,gcs,augmentations]",
"torch >= 2.0",
"pytorch-lightning == 1.7.7", # mypy error on newer
"torchmetrics == 0.11.4",
Expand Down
85 changes: 85 additions & 0 deletions tests/unit/augmentations/test_imgaug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import numpy as np

from zetta_utils.augmentations.imgaug import imgaug_augment

from ..helpers import assert_array_equal
from imgaug import augmenters as iaa


def test_imgaug_basic_ndarray():
image = np.random.randint(0, 255, (3, 128, 128, 2), dtype=np.uint8)
segmap = np.random.randint(0, 2 ** 15, (1, 64, 64, 2), dtype=np.uint16)
keypoints = [[(0.0, 0.0), (0.0, 128.0)], [(128.0, 128.0), (64.0, 64.0)]]

aug = iaa.Rot90()
augmented = imgaug_augment(
augmenters=aug, images=image, segmentation_maps=segmap, keypoints=keypoints
)

assert augmented.keys() == {"images", "segmentation_maps", "keypoints"}
assert augmented["images"].shape == (3, 128, 128, 2)
assert augmented["segmentation_maps"].shape == (1, 64, 64, 2)
assert_array_equal(augmented["images"], np.rot90(image, axes=(2, 1)))
assert_array_equal(augmented["segmentation_maps"], np.rot90(segmap, axes=(2, 1)))
assert augmented["keypoints"] == [[(128.0, 0.0), (0.0, 0.0)], [(0.0, 128.0), (64.0, 64.0)]]


def test_imgaug_basic_lists():
image = [np.random.randint(0, 255, (1, 128, 128, 1), dtype=np.uint8) for _ in range(2)]
heatmap = [np.random.rand(3, 64, 64, 1).astype(np.float32) for _ in range(2)]

aug = iaa.Fliplr()
augmented = imgaug_augment(augmenters=[aug], images=image, heatmaps=heatmap)

assert augmented.keys() == {"images", "heatmaps"}
assert len(augmented["images"]) == 2
assert len(augmented["heatmaps"]) == 2
assert_array_equal(augmented["images"][0], np.flip(image[0], axis=2))
assert_array_equal(augmented["images"][1], np.flip(image[1], axis=2))
assert_array_equal(augmented["heatmaps"][0], np.flip(heatmap[0], axis=2))
assert_array_equal(augmented["heatmaps"][1], np.flip(heatmap[1], axis=2))


def test_imgaug_custom_lists():
image = [np.random.randint(0, 255, (1, 128, 128, 1), dtype=np.uint8) for _ in range(2)]
seg = [np.random.randint(0, 2 ** 15, (1, 128, 128, 1), dtype=np.uint16) for _ in range(2)]
aff = [np.random.rand(3, 64, 64, 1).astype(np.float32) for _ in range(2)]

aug = iaa.Add(10)
augmented = imgaug_augment(
augmenters=aug,
src_img=image[0],
tgt_img=image[1],
src_seg=seg[0],
tgt_seg=seg[1],
src_aff=aff[0],
tgt_aff=aff[1],
)

assert augmented.keys() == {"src_img", "tgt_img", "src_seg", "tgt_seg", "src_aff", "tgt_aff"}
assert augmented["src_img"].shape == (1, 128, 128, 1)
assert augmented["tgt_img"].shape == (1, 128, 128, 1)
assert augmented["src_seg"].shape == (1, 128, 128, 1)
assert augmented["tgt_seg"].shape == (1, 128, 128, 1)
assert augmented["src_aff"].shape == (3, 64, 64, 1)
assert augmented["tgt_aff"].shape == (3, 64, 64, 1)

assert_array_equal(augmented["src_img"], (image[0].clip(0, 245) + 10))


def test_imgaug_mixed_lists():
image_group = np.random.randint(0, 255, (2, 64, 64, 10), dtype=np.uint8)
another_image = np.random.randint(0, 255, (3, 1024, 1024, 1), dtype=np.uint8)

aug = iaa.Invert()
augmented = imgaug_augment(
augmenters=aug,
images=image_group,
another_img=another_image,
)

assert augmented.keys() == {"images", "another_img"}
assert augmented["images"].shape == (2, 64, 64, 10)
assert augmented["another_img"].shape == (3, 1024, 1024, 1)
assert_array_equal(augmented["images"], np.invert(image_group))
assert_array_equal(augmented["another_img"], np.invert(another_image))
1 change: 1 addition & 0 deletions zetta_utils/augmentations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from . import tensor
from .common import prob_aug
from .imgaug import imgaug_augment
187 changes: 187 additions & 0 deletions zetta_utils/augmentations/imgaug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from __future__ import annotations

from typing import Any, Final, Optional, Sequence, Sized, Tuple, TypeVar, Union

from imgaug import augmenters as iaa
from imgaug.augmenters.meta import Augmenter

from zetta_utils import builder
from zetta_utils.tensor_ops import common, convert
from zetta_utils.tensor_typing import Tensor

SizedTypeVar = TypeVar("SizedTypeVar", bound=Sized)
TensorListTypeVar = TypeVar(
"TensorListTypeVar", Tensor, Sequence[Tensor], Union[Tensor, Sequence[Tensor]]
)

SUFFIX_MAPPING: Final = {
"img": "images",
"seg": "segmentation_maps",
"hm": "heatmaps",
"aff": "heatmaps",
"kp": "keypoints",
"bb": "bounding_boxes",
"poly": "polygons",
"ls": "line_strings",
}


def _ensure_list(x: SizedTypeVar | Sequence[SizedTypeVar] | None) -> Sequence[SizedTypeVar]:
if x is None:
return []
elif isinstance(x, Sequence):
return x
else:
return [x]


def _group_kwargs(
**kwargs: Optional[SizedTypeVar],
) -> tuple[dict[str, dict[str, SizedTypeVar]], dict[str, str]]:
groups: dict[str, dict[str, SizedTypeVar]] = {}
kwarg_mapping = {}
for k, v in kwargs.items():
if v is None:
continue
group_name, suffix = k.rsplit("_", 1)
try:
mapped_suffix = SUFFIX_MAPPING[suffix]
kwarg_mapping[k] = mapped_suffix
except KeyError:
raise ValueError(
f"Expected suffix `_img`, `_seg`, or `_aff` in custom augmentable key {k}"
)

groups.setdefault(group_name, {})[mapped_suffix] = v

return groups, kwarg_mapping


def _ungroup_kwargs(
groups: dict[str, dict[str, SizedTypeVar]], kwarg_mapping: dict[str, str]
) -> dict[str, SizedTypeVar]:
kwargs: dict[str, SizedTypeVar] = {}
for k, mapped_suffix in kwarg_mapping.items():
group_name, suffix = k.rsplit("_", 1)
if group_name == "_args":
kwargs[mapped_suffix] = groups[group_name][mapped_suffix]
else:
kwargs[f"{group_name}_{suffix}"] = groups[group_name][mapped_suffix]

return kwargs


def _ensure_nxyc(x: TensorListTypeVar) -> TensorListTypeVar:
if isinstance(x, Sequence):
return [convert.to_np(common.rearrange(v, pattern="C X Y 1 -> X Y C")) for v in x]
else:
return convert.to_np(common.rearrange(x, pattern="C X Y N -> N X Y C"))


def _ensure_cxyn(x: TensorListTypeVar, ref: TensorListTypeVar) -> TensorListTypeVar:
if isinstance(x, Sequence):
return [
convert.astype(common.rearrange(v, pattern="X Y C -> C X Y 1"), reference=ref[i])
for i, v in enumerate(x)
]
else:
return convert.astype(common.rearrange(x, pattern="N X Y C -> C X Y N"), reference=ref)


@builder.register("imgaug_augment")
def imgaug_augment(
augmenters: Augmenter | Sequence[Augmenter],
*,
images: Tensor | Sequence[Tensor] | None = None,
heatmaps: Tensor | Sequence[Tensor] | None = None,
segmentation_maps: Tensor | Sequence[Tensor] | None = None,
keypoints: Sequence[Sequence[Tuple[float, float]]] | None = None,
bounding_boxes: Sequence[Sequence[Tuple[float, float, float, float]]] | None = None,
polygons: Sequence[Sequence[Sequence[Tuple[float, float]]]] | None = None,
line_strings: Sequence[Sequence[Sequence[Tuple[float, float]]]] | None = None,
**kwargs: Tensor | Sequence[Tensor],
) -> dict[str, Any]:
"""This function is a wrapper for imgaug.augment to handle the CXYZ/ZXYC conversion.
It will call each provided augmenter on the provided augmentable dict.
For additionally supported types, see:
https://github.com/aleju/imgaug/blob/0101108d4fed06bc5056c4a03e2bcb0216dac326/imgaug/augmenters/meta.py#L1757-L1842
:param augmenters: A sequence of imgaug augmenters.
:param images: Either CXYZ tensor or list of CXY1 tensors. If not specified, at least one kwarg with `_img` suffix is required.
:param heatmaps: Either CXYZ tensor or list of CXY1 tensors.
:param segmentation_maps: Either CXYZ tensor or list of CXY1 tensors.
:param keypoints: List of lists of (x, y) coordinates.
:param bounding_boxes: List of lists of (x1, y1, x2, y2) coordinates.
:param polygons: List of lists of lists of (x, y) coordinates.
:param line_strings: List of lists of lists of (x, y) coordinates.
:param kwargs: Additional/alternative augmentables, each a CXYZ tensor or list of CXY1 tensors and
suffixes: `_img`, `_seg`, `_hm`/`_aff`, `_kp`, `_bb`, `_poly`, `_ls`.
:return: Augmented dictionary, same keys as input.
"""
augmenter: Augmenter = iaa.Sequential(_ensure_list(augmenters)).to_deterministic()
augmentables, kwarg_mapping = _group_kwargs(
_args_img=images,
_args_hm=heatmaps,
_args_seg=segmentation_maps,
_args_kp=keypoints,
_args_bb=bounding_boxes,
_args_poly=polygons,
_args_ls=line_strings,
**kwargs,
)

if not augmentables:
raise ValueError("Expected at least one image in `images` or `kwargs`")

for aug_group in augmentables.values():
res = augmenter.augment(
images=_ensure_nxyc(aug_group["images"]),
heatmaps=_ensure_nxyc(aug_group["heatmaps"]) if "heatmaps" in aug_group else None,
segmentation_maps=_ensure_nxyc(aug_group["segmentation_maps"])
if "segmentation_maps" in aug_group
else None,
keypoints=aug_group.get("keypoints", None),
bounding_boxes=aug_group.get("bounding_boxes", None),
polygons=aug_group.get("polygons", None),
line_strings=aug_group.get("line_strings", None),
return_batch=True,
)

aug_group["images"] = _ensure_cxyn(res.images_aug, aug_group["images"])
if "heatmaps" in aug_group:
aug_group["heatmaps"] = _ensure_cxyn(res.heatmaps_aug, aug_group["heatmaps"])
if "segmentation_maps" in aug_group:
aug_group["segmentation_maps"] = _ensure_cxyn(
res.segmentation_maps_aug, aug_group["segmentation_maps"]
)
if "keypoints" in aug_group:
aug_group["keypoints"] = res.keypoints_aug
if "bounding_boxes" in aug_group:
aug_group["bounding_boxes"] = res.bounding_boxes_aug
if "polygons" in aug_group:
aug_group["polygons"] = res.polygons_aug
if "line_strings" in aug_group:
aug_group["line_strings"] = res.line_strings_aug

return _ungroup_kwargs(augmentables, kwarg_mapping=kwarg_mapping)


def _unpack_kwargs(cls):
class wrapper(cls): # type: ignore # mypy doesn't like dynamic base classes
def __call__(kwargs_tuple):
super.__call__(**kwargs_tuple)

return wrapper


for k in dir(iaa):
if k[0].isupper() and hasattr(getattr(iaa, k), "augment"):
builder.register(f"imgaug.augmenters.{k}")(_unpack_kwargs(getattr(iaa, k)))

for k in dir(iaa.imgcorruptlike):
if k[0].isupper() and hasattr(getattr(iaa.imgcorruptlike, k), "augment"):
builder.register(f"imgaug.augmenters.imgcorruptlike.{k}")(
_unpack_kwargs(getattr(iaa.imgcorruptlike, k))
)

0 comments on commit cd58bfe

Please sign in to comment.