Skip to content

Commit

Permalink
feat: partial imgaug library support
Browse files Browse the repository at this point in the history
  • Loading branch information
nkemnitz committed Jul 21, 2023
1 parent c306552 commit 5a4026a
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 1 deletion.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ cloudvol = [
"zetta_utils[tensor_ops]"
]
augmentations = [
"zetta_utils[tensor_ops]"
"zetta_utils[tensor_ops]",
"imgaug == 0.4.0",
"imagecorruptions == 1.1.2",
]
convnet = [
"torch >= 2.0",
Expand Down
1 change: 1 addition & 0 deletions zetta_utils/augmentations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from . import tensor
from .common import prob_aug
from .imgaug import imgaug
49 changes: 49 additions & 0 deletions zetta_utils/augmentations/imgaug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

import einops
from imgaug import augmenters as iaa

from zetta_utils import builder
from zetta_utils.tensor_ops import convert


@builder.register("imgaug")
def imgaug(augmentables, augmenters):
for k, v in augmentables.items():
if k in ("image", "images", "heatmaps", "segmentation_maps"):
augmentables[k] = convert.to_np(einops.rearrange(v, "C X Y Z -> Z X Y C"))
elif k in ("keypoints", "bounding_boxes", "polygons", "line_strings"):
pass
else:
raise ValueError(f"Unsupported augmentable {k}")

for augmenter in augmenters:
augmentables = augmenter.augment(**augmentables, return_batch=True)
augmentables = {
k: getattr(augmentables, f"{k}_aug") for k in augmentables.get_column_names()
}

for k, v in augmentables.items():
if k in ("image", "images", "heatmaps", "segmentation_maps"):
augmentables[k] = convert.to_torch(einops.rearrange(v, "Z X Y C -> C X Y Z"))

return augmentables


def unpack_kwargs(cls):
class wrapper(cls):
def __call__(kwargs_tuple):
super.__call__(**kwargs_tuple)

return wrapper


for k in dir(iaa):
if k[0].isupper() and hasattr(getattr(iaa, k), "augment"):
builder.register(f"imgaug.augmenters.{k}")(unpack_kwargs(getattr(iaa, k)))

for k in dir(iaa.imgcorruptlike):
if k[0].isupper() and hasattr(getattr(iaa.imgcorruptlike, k), "augment"):
builder.register(f"imgaug.augmenters.imgcorruptlike.{k}")(
unpack_kwargs(getattr(iaa.imgcorruptlike, k))
)

0 comments on commit 5a4026a

Please sign in to comment.