Skip to content

Commit

Permalink
Add pre and post processing steps to allow non float dtypes (kornia#2882
Browse files Browse the repository at this point in the history
)
  • Loading branch information
ashnair1 authored Jun 22, 2024
1 parent aec0dd8 commit eb30590
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 16 deletions.
74 changes: 67 additions & 7 deletions kornia/augmentation/container/augment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

import torch

from kornia.augmentation._2d.base import RigidAffineAugmentationBase2D
from kornia.augmentation._3d.base import AugmentationBase3D, RigidAffineAugmentationBase3D
from kornia.augmentation.base import _AugmentationBase
Expand All @@ -21,7 +23,11 @@

_BOXES_OPTIONS = {DataKey.BBOX, DataKey.BBOX_XYXY, DataKey.BBOX_XYWH}
_KEYPOINTS_OPTIONS = {DataKey.KEYPOINTS}
_IMG_MSK_OPTIONS = {DataKey.INPUT, DataKey.MASK}
_IMG_OPTIONS = {DataKey.INPUT, DataKey.IMAGE}
_MSK_OPTIONS = {DataKey.MASK}
_CLS_OPTIONS = {DataKey.CLASS, DataKey.LABEL}

MaskDataType = Union[Tensor, List[Tensor]]


class AugmentationSequential(TransformMatrixMinIn, ImageSequential):
Expand Down Expand Up @@ -195,6 +201,9 @@ class AugmentationSequential(TransformMatrixMinIn, ImageSequential):
dict_keys(['image', 'mask', 'mask-b', 'bbox', 'bbox-other'])
"""

input_dtype = None
mask_dtype = None

def __init__(
self,
*args: Union[_AugmentationBase, ImageSequential],
Expand Down Expand Up @@ -332,13 +341,23 @@ def _validate_args_datakeys(self, *args: DataType, data_keys: List[DataKey]) ->
def _arguments_preproc(self, *args: DataType, data_keys: List[DataKey]) -> List[DataType]:
inp: List[DataType] = []
for arg, dcate in zip(args, data_keys):
if DataKey.get(dcate) in _IMG_MSK_OPTIONS:
if DataKey.get(dcate) in _IMG_OPTIONS:
arg = cast(Tensor, arg)
self.input_dtype = arg.dtype
inp.append(arg)
elif DataKey.get(dcate) in _MSK_OPTIONS:
if isinstance(inp, list):
arg = cast(List[Tensor], arg)
self.mask_dtype = arg[0].dtype
else:
arg = cast(Tensor, arg)
self.mask_dtype = arg.dtype
inp.append(self._preproc_mask(arg))
elif DataKey.get(dcate) in _KEYPOINTS_OPTIONS:
inp.append(self._preproc_keypoints(arg, dcate))
elif DataKey.get(dcate) in _BOXES_OPTIONS:
inp.append(self._preproc_boxes(arg, dcate))
elif DataKey.get(dcate) is DataKey.CLASS:
elif DataKey.get(dcate) in _CLS_OPTIONS:
inp.append(arg)
else:
raise NotImplementedError(f"input type of {dcate} is not implemented.")
Expand All @@ -349,10 +368,13 @@ def _arguments_postproc(
) -> List[DataType]:
out: List[DataType] = []
for in_arg, out_arg, dcate in zip(in_args, out_args, data_keys):
if DataKey.get(dcate) in _IMG_MSK_OPTIONS:
if DataKey.get(dcate) in _IMG_OPTIONS:
# It is tensor type already.
out.append(out_arg)
# TODO: may add the float to integer (for masks), etc.
elif DataKey.get(dcate) in _MSK_OPTIONS:
_out_m = self._postproc_mask(cast(MaskDataType, out_arg))
out.append(_out_m)

elif DataKey.get(dcate) in _KEYPOINTS_OPTIONS:
_out_k = self._postproc_keypoint(in_arg, cast(Keypoints, out_arg), dcate)
Expand All @@ -372,7 +394,7 @@ def _arguments_postproc(
_out_b = _out_b.type(in_arg.dtype)
out.append(_out_b)

elif DataKey.get(dcate) is DataKey.CLASS:
elif DataKey.get(dcate) in _CLS_OPTIONS:
out.append(out_arg)

else:
Expand Down Expand Up @@ -472,6 +494,30 @@ def retrieve_key(key: str) -> DataKey:

return [DataKey.get(retrieve_key(k)) for k in keys]

def _preproc_mask(self, arg: MaskDataType) -> MaskDataType:
if isinstance(arg, list):
new_arg = []
for a in arg:
a_new = a.to(self.input_dtype) if self.input_dtype else a.to(torch.float)
new_arg.append(a_new)
return new_arg

else:
arg = arg.to(self.input_dtype) if self.input_dtype else arg.to(torch.float)
return arg

def _postproc_mask(self, arg: MaskDataType) -> MaskDataType:
if isinstance(arg, list):
new_arg = []
for a in arg:
a_new = a.to(self.mask_dtype) if self.mask_dtype else a.to(torch.float)
new_arg.append(a_new)
return new_arg

else:
arg = arg.to(self.mask_dtype) if self.mask_dtype else arg.to(torch.float)
return arg

def _preproc_boxes(self, arg: DataType, dcate: DataKey) -> Boxes:
if DataKey.get(dcate) in [DataKey.BBOX]:
mode = "vertices_plus"
Expand Down Expand Up @@ -509,17 +555,31 @@ def _postproc_boxes(self, in_arg: DataType, out_arg: Boxes, dcate: DataKey) -> U
return out_arg.to_tensor(mode=mode)

def _preproc_keypoints(self, arg: DataType, dcate: DataKey) -> Keypoints:
dtype = None

if self.contains_video_sequential:
arg = cast(Union[Tensor, List[Tensor]], arg)
return VideoKeypoints.from_tensor(arg)
if isinstance(arg, list):
if not torch.is_floating_point(arg[0]):
dtype = arg[0].dtype
arg = [a.float() for a in arg]
elif not torch.is_floating_point(arg):
dtype = arg.dtype
arg = arg.float()
video_result = VideoKeypoints.from_tensor(arg)
return video_result.type(dtype) if dtype else video_result
elif self.contains_3d_augmentation:
raise NotImplementedError("3D keypoint handlers are not yet supported.")
elif isinstance(arg, (Keypoints,)):
return arg
else:
arg = cast(Tensor, arg)
if not torch.is_floating_point(arg):
dtype = arg.dtype
arg = arg.float()
# TODO: Add List[Tensor] in the future.
return Keypoints.from_tensor(arg)
result = Keypoints.from_tensor(arg)
return result.type(dtype) if dtype else result

def _postproc_keypoint(
self, in_arg: DataType, out_arg: Keypoints, dcate: DataKey
Expand Down
9 changes: 6 additions & 3 deletions kornia/augmentation/container/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def transform(
outputs = []
for inp, dcate in zip(arg, _data_keys):
op = self._get_op(dcate)
extra_arg = extra_args[dcate] if dcate in extra_args else {}
extra_arg = extra_args.get(dcate, {})
if dcate.name == "MASK" and isinstance(inp, list):
outputs.append(MaskSequentialOps.transform_list(inp, module, param=param, extra_args=extra_arg))
else:
Expand Down Expand Up @@ -240,6 +240,7 @@ def transform(cls, input: Tensor, module: Module, param: ParamItem, extra_args:
to apply transformations.
param: the corresponding parameters to the module.
"""

if isinstance(module, (K.GeometricAugmentationBase2D,)):
input = module.transform_masks(
input,
Expand Down Expand Up @@ -269,7 +270,8 @@ def transform(cls, input: Tensor, module: Module, param: ParamItem, extra_args:
input = module.transform_masks(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)

elif isinstance(module, (K.auto.operations.OperationBase,)):
return MaskSequentialOps.transform(input, module=module.op, param=param, extra_args=extra_args)
input = MaskSequentialOps.transform(input, module=module.op, param=param, extra_args=extra_args)

return input

@classmethod
Expand Down Expand Up @@ -344,6 +346,7 @@ def inverse(cls, input: Tensor, module: Module, param: ParamItem, extra_args: Di
to apply transformations.
param: the corresponding parameters to the module.
"""

if isinstance(module, (K.GeometricAugmentationBase2D,)):
if module.transform_matrix is None:
raise ValueError(f"No valid transformation matrix found in {module.__class__}.")
Expand All @@ -365,7 +368,7 @@ def inverse(cls, input: Tensor, module: Module, param: ParamItem, extra_args: Di
input = module.inverse_masks(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)

elif isinstance(module, (K.auto.operations.OperationBase,)):
return MaskSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)
input = MaskSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)

return input

Expand Down
14 changes: 8 additions & 6 deletions tests/augmentation/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,17 +555,18 @@ def test_masks_without_channel_dim(self, device, dtype, B, C_i, C_m, keepdim):
else:
assert (*(1,) * (4 - len(img_shape)), *img_shape) == out[0].shape

out_mask_shape = tuple(x if x else 1 for x in (B, C_m, *img_shape[-2:]))
out_mask_shape = tuple(x or 1 for x in (B, C_m, *img_shape[-2:]))
assert out[1].shape == out_mask_shape

@pytest.mark.slow
@pytest.mark.parametrize("random_apply", [1, (2, 2), (1, 2), (2,), 10, True, False])
def test_forward_and_inverse(self, random_apply, device, dtype):
@pytest.mark.parametrize("mask_dtype", [torch.int32, torch.int64, torch.float32])
def test_forward_and_inverse(self, random_apply, device, dtype, mask_dtype):
inp = torch.randn(1, 3, 1000, 500, device=device, dtype=dtype)
bbox = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]], device=device, dtype=dtype)
keypoints = torch.tensor([[[465, 115], [545, 116]]], device=device, dtype=dtype)
mask = bbox_to_mask(
torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]], device=device, dtype=dtype), 1000, 500
torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]], device=device, dtype=mask_dtype), 1000, 500
)[:, None]
aug = K.AugmentationSequential(
K.ImageSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0), K.RandomAffine(360, p=1.0)),
Expand Down Expand Up @@ -715,11 +716,12 @@ def test_inverse_and_forward_return_transform(self, random_apply, device, dtype)
if random_apply is False:
reproducibility_test((inp, mask, bbox, keypoints, bbox_2, bbox_wh, bbox_wh_2), aug)

def test_transform_list_of_masks_and_boxes(self, device, dtype):
@pytest.mark.parametrize("mask_dtype", [torch.int32, torch.int64, torch.float32])
def test_transform_list_of_masks_and_boxes(self, device, dtype, mask_dtype):
input = torch.randn(2, 3, 256, 256, device=device, dtype=dtype)
mask = [
torch.ones(1, 3, 256, 256, device=device, dtype=dtype),
torch.ones(1, 2, 256, 256, device=device, dtype=dtype),
torch.ones(1, 3, 256, 256, device=device, dtype=mask_dtype),
torch.ones(1, 2, 256, 256, device=device, dtype=mask_dtype),
]

bbox = [
Expand Down

0 comments on commit eb30590

Please sign in to comment.