diff --git a/pyproject.toml b/pyproject.toml index 1df488792..bf02155ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,11 @@ cloudvol = [ "tensorstore == 0.1.33", "zetta_utils[tensor_ops]", ] -augmentations = ["zetta_utils[tensor_ops]"] +augmentations = [ + "zetta_utils[tensor_ops]", + "imgaug @ git+https://github.com/u7122029/imgaug.git@418f22d4275e9f90274900e27b595ed678bb4bcc", + "imagecorruptions == 1.1.2", +] convnet = ["torch >= 2.0", "artificery >= 0.0.3.3"] alignment = ["metroem >= 0.1.2", "torch >= 2.0"] mazepa = [ @@ -87,7 +91,7 @@ mazepa-addons = [ ] segmentation = ["onnx >= 1.13.0", "onnxruntime-gpu >= 1.13.1"] training = [ - "zetta_utils[tensor_ops,cloudvol,convnet,viz,gcs]", + "zetta_utils[tensor_ops,cloudvol,convnet,viz,gcs,augmentations]", "torch >= 2.0", "pytorch-lightning == 1.7.7", # mypy error on newer "torchmetrics == 0.11.4", diff --git a/tests/unit/augmentations/test_imgaug.py b/tests/unit/augmentations/test_imgaug.py new file mode 100644 index 000000000..c5540c62d --- /dev/null +++ b/tests/unit/augmentations/test_imgaug.py @@ -0,0 +1,175 @@ +import numpy as np +import pytest +import torch +from imgaug import augmenters as iaa + +import zetta_utils as zu +from zetta_utils.augmentations.imgaug import imgaug_augment + +from ..helpers import assert_array_equal + + +def test_imgaug_basic_ndarray(): + image = np.random.randint(0, 255, (3, 128, 128, 2), dtype=np.uint8) + segmap = np.random.randint(0, 2 ** 15, (1, 64, 64, 2), dtype=np.uint16) + keypoints = [[(0.0, 0.0), (0.0, 128.0)], [(128.0, 128.0), (64.0, 64.0)]] + bboxes = [[(0.0, 0.0, 128.0, 128.0)], [(64.0, 64.0, 128.0, 128.0)]] + polygons = [[(0.0, 0.0), (0.0, 128.0), (128.0, 128.0)]], [ + [(128.0, 0.0), (128.0, 128.0), (0.0, 0.0)] + ] + line_strings = [[(0.0, 0.0), (0.0, 128.0), (128.0, 128.0)]], [ + [(128.0, 0.0), (128.0, 128.0), (0.0, 0.0)] + ] + + aug = iaa.Rot90() + augmented = imgaug_augment( + augmenters=aug, + images=image, + segmentation_maps=segmap, + keypoints=keypoints, + bounding_boxes=bboxes, + polygons=polygons, + line_strings=line_strings, + ) + + assert augmented.keys() == { + "images", + "segmentation_maps", + "keypoints", + "bounding_boxes", + "polygons", + "line_strings", + } + assert augmented["images"].shape == (3, 128, 128, 2) + assert augmented["segmentation_maps"].shape == (1, 64, 64, 2) + assert_array_equal(augmented["images"], np.rot90(image, axes=(2, 1))) + assert_array_equal(augmented["segmentation_maps"], np.rot90(segmap, axes=(2, 1))) + assert augmented["keypoints"] == [[(128.0, 0.0), (0.0, 0.0)], [(0.0, 128.0), (64.0, 64.0)]] + + +def test_imgaug_basic_tensor(): + image = torch.randint(0, 255, (3, 128, 128, 2), dtype=torch.uint8) + segmap = torch.randint(0, 2 ** 15, (1, 64, 64, 2), dtype=torch.int16) + keypoints = [[(0.0, 0.0), (0.0, 128.0)], [(128.0, 128.0), (64.0, 64.0)]] + + aug = iaa.Rot90() + augmented = imgaug_augment( + augmenters=aug, images=image, segmentation_maps=segmap, keypoints=keypoints + ) + + assert augmented.keys() == {"images", "segmentation_maps", "keypoints"} + assert augmented["images"].shape == (3, 128, 128, 2) + assert augmented["segmentation_maps"].shape == (1, 64, 64, 2) + assert_array_equal(augmented["images"], torch.rot90(image, dims=(2, 1))) + assert_array_equal(augmented["segmentation_maps"], torch.rot90(segmap, dims=(2, 1))) + assert augmented["keypoints"] == [[(128.0, 0.0), (0.0, 0.0)], [(0.0, 128.0), (64.0, 64.0)]] + + +def test_imgaug_basic_lists_ndarray(): + image = [np.random.randint(0, 255, (1, 128, 128, 1), dtype=np.uint8) for _ in range(2)] + heatmap = [np.random.rand(3, 64, 64, 1).astype(np.float32) for _ in range(2)] + + aug = iaa.Fliplr() + augmented = imgaug_augment(augmenters=[aug], images=image, heatmaps=heatmap) + + assert augmented.keys() == {"images", "heatmaps"} + assert len(augmented["images"]) == 2 + assert len(augmented["heatmaps"]) == 2 + assert_array_equal(augmented["images"][0], np.flip(image[0], axis=2)) + assert_array_equal(augmented["images"][1], np.flip(image[1], axis=2)) + assert_array_equal(augmented["heatmaps"][0], np.flip(heatmap[0], axis=2)) + assert_array_equal(augmented["heatmaps"][1], np.flip(heatmap[1], axis=2)) + + +def test_imgaug_basic_lists_tensor(): + image = [torch.randint(0, 255, (1, 128, 128, 1), dtype=torch.uint8) for _ in range(2)] + heatmap = [torch.rand(3, 64, 64, 1, dtype=torch.float32) for _ in range(2)] + + aug = iaa.Fliplr() + augmented = imgaug_augment(augmenters=[aug], images=image, heatmaps=heatmap) + + assert augmented.keys() == {"images", "heatmaps"} + assert len(augmented["images"]) == 2 + assert len(augmented["heatmaps"]) == 2 + assert_array_equal(augmented["images"][0], torch.flip(image[0], dims=(2,))) + assert_array_equal(augmented["images"][1], torch.flip(image[1], dims=(2,))) + assert_array_equal(augmented["heatmaps"][0], torch.flip(heatmap[0], dims=(2,))) + assert_array_equal(augmented["heatmaps"][1], torch.flip(heatmap[1], dims=(2,))) + + +def test_imgaug_custom_lists(): + image = [np.random.randint(0, 255, (1, 128, 128, 1), dtype=np.uint8) for _ in range(2)] + seg = [np.random.randint(0, 2 ** 15, (1, 128, 128, 1), dtype=np.uint16) for _ in range(2)] + aff = [np.random.rand(3, 64, 64, 1).astype(np.float32) for _ in range(2)] + + aug = iaa.Add(10) + augmented = imgaug_augment( + augmenters=aug, + src_img=image[0], + tgt_img=image[1], + src_seg=seg[0], + tgt_seg=seg[1], + src_aff=aff[0], + tgt_aff=aff[1], + ) + + assert augmented.keys() == {"src_img", "tgt_img", "src_seg", "tgt_seg", "src_aff", "tgt_aff"} + assert augmented["src_img"].shape == (1, 128, 128, 1) + assert augmented["tgt_img"].shape == (1, 128, 128, 1) + assert augmented["src_seg"].shape == (1, 128, 128, 1) + assert augmented["tgt_seg"].shape == (1, 128, 128, 1) + assert augmented["src_aff"].shape == (3, 64, 64, 1) + assert augmented["tgt_aff"].shape == (3, 64, 64, 1) + + assert_array_equal(augmented["src_img"], (image[0].clip(0, 245) + 10)) + + +def test_imgaug_mixed_lists(): + image_group = np.random.randint(0, 255, (2, 64, 64, 10), dtype=np.uint8) + another_image = np.random.randint(0, 255, (3, 1024, 1024, 1), dtype=np.uint8) + + aug = [iaa.Invert()] + augmented = imgaug_augment( + augmenters=aug, + images=image_group, + another_img=another_image, + ) + + assert augmented.keys() == {"images", "another_img"} + assert augmented["images"].shape == (2, 64, 64, 10) + assert augmented["another_img"].shape == (3, 1024, 1024, 1) + assert_array_equal(augmented["images"], np.invert(image_group)) + assert_array_equal(augmented["another_img"], np.invert(another_image)) + + +def test_imgaug_exceptions(): + seg = [np.random.randint(0, 2 ** 15, (1, 128, 128, 1), dtype=np.uint16) for _ in range(2)] + aug = iaa.Invert() + + with pytest.raises(ValueError): + imgaug_augment(aug, data_seg=seg) + + with pytest.raises(ValueError): + imgaug_augment(aug, data_unknownsuffix=seg) + + +def test_imgaug_builder(): + zu.load_all_modules() # pylint: disable=protected-access + spec = zu.builder.build( + spec={ + "@type": "imgaug_readproc", + "@mode": "partial", + "augmenters": [ + { + "@type": "imgaug.augmenters.Sequential", + "children": [ + {"@type": "imgaug.augmenters.Add", "value": 0}, + {"@type": "imgaug.augmenters.imgcorruptlike.DefocusBlur", "severity": 5}, + ], + }, + ], + } + ) + arr = np.zeros((1, 128, 128, 1), dtype=np.uint8) + assert spec({"images": arr}).keys() == {"images"} + assert_array_equal(spec(arr), arr) diff --git a/zetta_utils/api/v0.py b/zetta_utils/api/v0.py index 9a7c427ff..88a96511c 100644 --- a/zetta_utils/api/v0.py +++ b/zetta_utils/api/v0.py @@ -22,6 +22,7 @@ from zetta_utils.alignment.misalignment_detector import MisalignmentDetector, naive_misd from zetta_utils.alignment.online_finetuner import align_with_online_finetuner from zetta_utils.augmentations.common import prob_aug +from zetta_utils.augmentations.imgaug import imgaug_augment, imgaug_readproc from zetta_utils.augmentations.tensor import ( add_scalar_aug, clamp_values_aug, diff --git a/zetta_utils/augmentations/__init__.py b/zetta_utils/augmentations/__init__.py index 25f615a96..17018ad01 100644 --- a/zetta_utils/augmentations/__init__.py +++ b/zetta_utils/augmentations/__init__.py @@ -1,2 +1,3 @@ from . import tensor from .common import prob_aug +from .imgaug import imgaug_augment diff --git a/zetta_utils/augmentations/imgaug.py b/zetta_utils/augmentations/imgaug.py new file mode 100644 index 000000000..a312d21f2 --- /dev/null +++ b/zetta_utils/augmentations/imgaug.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +from typing import Any, Final, Sequence, Sized, Tuple, TypeVar, overload + +from imgaug import augmenters as iaa +from imgaug.augmenters.meta import Augmenter +from numpy.typing import NDArray + +from zetta_utils import builder +from zetta_utils.tensor_ops import common, convert +from zetta_utils.tensor_typing import Tensor, TensorTypeVar + +SizedTypeVar = TypeVar("SizedTypeVar", bound=Sized) +TensorListTypeVar = TypeVar("TensorListTypeVar", Tensor, Sequence) +T = TypeVar("T") + +SUFFIX_MAPPING: Final = { + "img": "images", + "seg": "segmentation_maps", + "mask": "segmentation_maps", + "hm": "heatmaps", + "aff": "heatmaps", + "kp": "keypoints", + "bb": "bounding_boxes", + "poly": "polygons", + "ls": "line_strings", +} + + +def _ensure_list(augmenter: Augmenter | Sequence[Augmenter]) -> Sequence[Augmenter]: + return augmenter if isinstance(augmenter, Sequence) else [augmenter] + + +@overload +def _ensure_nxyc(x: Tensor) -> NDArray: + ... + + +@overload +def _ensure_nxyc(x: Sequence[Tensor]) -> list[NDArray]: + ... + + +def _ensure_nxyc(x: Tensor | Sequence[Tensor]) -> NDArray | list[NDArray]: + if isinstance(x, Sequence): + return [convert.to_np(common.rearrange(v, pattern="C X Y 1 -> X Y C")) for v in x] + else: + return convert.to_np(common.rearrange(x, pattern="C X Y N -> N X Y C")) + + +@overload +def _ensure_cxyn(x: Tensor, ref: TensorTypeVar) -> TensorTypeVar: + ... + + +@overload +def _ensure_cxyn(x: Sequence[Tensor], ref: Sequence[TensorTypeVar]) -> list[TensorTypeVar]: + ... + + +def _ensure_cxyn( + x: Tensor | Sequence[Tensor], ref: TensorTypeVar | Sequence[TensorTypeVar] +) -> TensorTypeVar | list[TensorTypeVar]: + if isinstance(ref, Sequence): + assert isinstance(x, Sequence) + return [ + convert.astype( + common.rearrange(v, pattern="X Y C -> C X Y 1"), reference=ref[i] # type: ignore + ) + for i, v in enumerate(x) + ] + else: + return convert.astype( + common.rearrange(x, pattern="N X Y C -> C X Y N"), reference=ref # type: ignore + ) + + +def _group_kwargs( + **kwargs: TensorListTypeVar | None, +) -> tuple[dict[str, dict[str, TensorListTypeVar]], dict[str, str]]: + groups: dict[str, dict[str, TensorListTypeVar]] = {} + kwarg_mapping: dict[str, str] = {} + for k, v in kwargs.items(): + if v is None: + continue + group_name, suffix = k.rsplit("_", 1) + try: + mapped_suffix = SUFFIX_MAPPING[suffix] + kwarg_mapping[k] = mapped_suffix + except KeyError as e: + raise ValueError( + f"Expected suffix `_img`, `_seg`, or `_aff` in custom augmentable key {k}" + ) from e + + groups.setdefault(group_name, {})[mapped_suffix] = v + + return groups, kwarg_mapping + + +def _ungroup_kwargs( + groups: dict[str, dict[str, TensorListTypeVar]], kwarg_mapping: dict[str, str] +) -> dict[str, TensorListTypeVar]: + kwargs = {} + for k, mapped_suffix in kwarg_mapping.items(): + group_name, suffix = k.rsplit("_", 1) + if group_name == "_args": + kwargs[mapped_suffix] = groups[group_name][mapped_suffix] + else: + kwargs[f"{group_name}_{suffix}"] = groups[group_name][mapped_suffix] + + return kwargs + + +@builder.register("imgaug_readproc") +def imgaug_readproc( + *args, # the zetta_utils builder puts the layer/layerset as the first argument + **kwargs, # and the augmenters in the kwargs +): + assert len(args) == 1 + augmenters = kwargs.pop("augmenters", None) + assert augmenters is not None + if isinstance(args[0], dict): + return imgaug_augment(augmenters, **args[0], **kwargs) + else: # Tensor + return imgaug_augment(augmenters, images=args[0], **kwargs)["images"] + + +def imgaug_augment( + augmenters: Augmenter | Sequence[Augmenter], + *, + images: Tensor | Sequence[Tensor] | None = None, + heatmaps: Tensor | Sequence[Tensor] | None = None, + segmentation_maps: Tensor | Sequence[Tensor] | None = None, + keypoints: Sequence[Sequence[Tuple[float, float]]] | None = None, + bounding_boxes: Sequence[Sequence[Tuple[float, float, float, float]]] | None = None, + polygons: Sequence[Sequence[Sequence[Tuple[float, float]]]] | None = None, + line_strings: Sequence[Sequence[Sequence[Tuple[float, float]]]] | None = None, + **kwargs: Tensor | Sequence[Tensor], +) -> dict[str, Any]: + """This function is a wrapper for imgaug.augment to handle the CXYZ/ZXYC conversion. + It will call each provided augmenter on the provided augmentable dict. + + For additionally supported types, see: + https://github.com/aleju/imgaug/blob/0101108d4fed06bc5056c4a03e2bcb0216dac326/imgaug/augmenters/meta.py#L1757-L1842 + + :param augmenters: A sequence of imgaug augmenters. + :param images: Either CXYZ tensor or list of CXY1 tensors. If not specified, + at least one kwarg with `_img` suffix is required. + :param heatmaps: Either CXYZ tensor or list of CXY1 tensors. + :param segmentation_maps: Either CXYZ tensor or list of CXY1 tensors. + :param keypoints: List of lists of (x, y) coordinates. + :param bounding_boxes: List of lists of (x1, y1, x2, y2) coordinates. + :param polygons: List of lists of lists of (x, y) coordinates. + :param line_strings: List of lists of lists of (x, y) coordinates. + :param kwargs: Additional/alternative augmentables, each a CXYZ tensor or list of CXY1 tensors + and suffixes: `_img`, `_seg`, `_hm`/`_aff`, `_kp`, `_bb`, `_poly`, `_ls`. + + :return: Augmented dictionary, same keys as input. + """ + augmenter: Augmenter = iaa.Sequential(_ensure_list(augmenters)).to_deterministic() + augmentables, kwarg_mapping = _group_kwargs( # type: ignore + _args_img=images, + _args_hm=heatmaps, + _args_seg=segmentation_maps, + _args_kp=keypoints, + _args_bb=bounding_boxes, + _args_poly=polygons, + _args_ls=line_strings, + **kwargs, + ) + + if not any("images" in group for group in augmentables.values()): + raise ValueError("Expected at least one image in `images` or `kwargs`") + + for aug_group in augmentables.values(): + res = augmenter.augment( + images=_ensure_nxyc(aug_group["images"]), # type: ignore + heatmaps=_ensure_nxyc(aug_group["heatmaps"]) # type: ignore + if "heatmaps" in aug_group + else None, + segmentation_maps=_ensure_nxyc(aug_group["segmentation_maps"]) # type: ignore + if "segmentation_maps" in aug_group + else None, + keypoints=aug_group.get("keypoints", None), + bounding_boxes=aug_group.get("bounding_boxes", None), + polygons=aug_group.get("polygons", None), + line_strings=aug_group.get("line_strings", None), + return_batch=True, + ) + + aug_group["images"] = _ensure_cxyn(res.images_aug, aug_group["images"]) # type: ignore + if "heatmaps" in aug_group: + aug_group["heatmaps"] = _ensure_cxyn( + res.heatmaps_aug, aug_group["heatmaps"] # type: ignore + ) + if "segmentation_maps" in aug_group: + aug_group["segmentation_maps"] = _ensure_cxyn( + res.segmentation_maps_aug, aug_group["segmentation_maps"] # type: ignore + ) + if "keypoints" in aug_group: + aug_group["keypoints"] = res.keypoints_aug + if "bounding_boxes" in aug_group: + aug_group["bounding_boxes"] = res.bounding_boxes_aug + if "polygons" in aug_group: + aug_group["polygons"] = res.polygons_aug + if "line_strings" in aug_group: + aug_group["line_strings"] = res.line_strings_aug + + return _ungroup_kwargs(augmentables, kwarg_mapping=kwarg_mapping) + + +for attr in dir(iaa): + if attr[0].isupper() and hasattr(getattr(iaa, attr), "augment"): + builder.register(f"imgaug.augmenters.{attr}")(getattr(iaa, attr)) + +for attr in dir(iaa.imgcorruptlike): + if attr[0].isupper() and hasattr(getattr(iaa.imgcorruptlike, attr), "augment"): + builder.register(f"imgaug.augmenters.imgcorruptlike.{attr}")( + getattr(iaa.imgcorruptlike, attr) + ) diff --git a/zetta_utils/tensor_ops/convert.py b/zetta_utils/tensor_ops/convert.py index 78dd0390f..7910dadd4 100644 --- a/zetta_utils/tensor_ops/convert.py +++ b/zetta_utils/tensor_ops/convert.py @@ -50,6 +50,8 @@ def to_torch(data: Tensor, device: torch.types.Device = None) -> torch.Tensor: raise ValueError("Unable to convert uint32 dtype to int32") data = data.astype(np.int32) + if any(v < 0 for v in data.strides): # torch.from_numpy does not support negative strides + data = data.copy("K") result = torch.from_numpy(data).to(device) # type: ignore # pytorch bug return result