diff --git a/pyproject.toml b/pyproject.toml index 608261844..bd940c0b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,7 +73,9 @@ cloudvol = [ "zetta_utils[tensor_ops]" ] augmentations = [ - "zetta_utils[tensor_ops]" + "zetta_utils[tensor_ops]", + "imgaug == 0.4.0", + "imagecorruptions == 1.1.2", ] convnet = [ "torch >= 2.0", diff --git a/zetta_utils/augmentations/__init__.py b/zetta_utils/augmentations/__init__.py index 25f615a96..2e405230a 100644 --- a/zetta_utils/augmentations/__init__.py +++ b/zetta_utils/augmentations/__init__.py @@ -1,2 +1,3 @@ from . import tensor from .common import prob_aug +from .imgaug import imgaug diff --git a/zetta_utils/augmentations/imgaug.py b/zetta_utils/augmentations/imgaug.py new file mode 100644 index 000000000..9a184042a --- /dev/null +++ b/zetta_utils/augmentations/imgaug.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import einops +from imgaug import augmenters as iaa + +from zetta_utils import builder +from zetta_utils.tensor_ops import convert + + +@builder.register("imgaug") +def imgaug(augmentables, augmenters): + for k, v in augmentables.items(): + if k in ("image", "images", "heatmaps", "segmentation_maps"): + augmentables[k] = convert.to_np(einops.rearrange(v, "C X Y Z -> Z X Y C")) + elif k in ("keypoints", "bounding_boxes", "polygons", "line_strings"): + pass + else: + raise ValueError(f"Unsupported augmentable {k}") + + for augmenter in augmenters: + augmentables = augmenter.augment(**augmentables, return_batch=True) + augmentables = { + k: getattr(augmentables, f"{k}_aug") for k in augmentables.get_column_names() + } + + for k, v in augmentables.items(): + if k in ("image", "images", "heatmaps", "segmentation_maps"): + augmentables[k] = convert.to_torch(einops.rearrange(v, "Z X Y C -> C X Y Z")) + + return augmentables + + +def unpack_kwargs(cls): + class wrapper(cls): + def __call__(kwargs_tuple): + super.__call__(**kwargs_tuple) + + return wrapper + + +for k in dir(iaa): + if k[0].isupper() and hasattr(getattr(iaa, k), "augment"): + builder.register(f"imgaug.augmenters.{k}")(unpack_kwargs(getattr(iaa, k))) + +for k in dir(iaa.imgcorruptlike): + if k[0].isupper() and hasattr(getattr(iaa.imgcorruptlike, k), "augment"): + builder.register(f"imgaug.augmenters.imgcorruptlike.{k}")( + unpack_kwargs(getattr(iaa.imgcorruptlike, k)) + )