Skip to content

Commit

Permalink
Merge pull request #348 from Visual-Behavior/augmentations_warning
Browse files Browse the repository at this point in the history
Add warning when augmentations fail
  • Loading branch information
thibo73800 authored Apr 6, 2023
2 parents 7b98304 + d9b3878 commit e78bdac
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 17 deletions.
41 changes: 29 additions & 12 deletions aloscene/tensors/augmented_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def _torch_function_get_self(cls, func, types, args, kwargs):
""" Based on this dicussion https://github.com/pytorch/pytorch/issues/63767
"""Based on this dicussion https://github.com/pytorch/pytorch/issues/63767
"A simple solution would be to scan the args for the first subclass of this class.
My question is more: will forcing this to be a subclass actually be a problem for some use case?
Expand Down Expand Up @@ -36,7 +36,7 @@ class AugmentedTensor(torch.Tensor):

# Ignore named tansors userwarning.
ERROR_MSG = "Named tensors and all their associated APIs are an experimental feature and subject to change"
warnings.filterwarnings(action='ignore', message=ERROR_MSG)
warnings.filterwarnings(action="ignore", message=ERROR_MSG)

@staticmethod
def __new__(cls, x, names=None, device=None, *args, **kwargs):
Expand Down Expand Up @@ -335,7 +335,6 @@ def __getitem__(self, idx):

tensor = tensor.reset_names() if len(self.shape) == len(tensor.shape) else tensor.as_tensor()


# if not idx.dtype == torch.bool:
# if not torch.equal(idx ** 3, idx):
# raise IndexError(f"Unvalid mask. Expected mask elements to be in [0, 1, True, False]")
Expand Down Expand Up @@ -475,6 +474,7 @@ def _fillup_dict(dm, sub_label, dim, target_dim):
else:
for s in range(len(sub_label)):
_fillup_dict(dm[s], sub_label[s], dim + 1, target_dim)

_fillup_dict(dict_merge[key], label, 0, target_dim)

return dict_merge
Expand All @@ -499,7 +499,9 @@ def _merge_tensor(self, n_tensor, tensor_list, func, types, args=(), kwargs=None
setattr(n_tensor, prop, None)
else:
values = set([prop_name_to_value[prop], getattr(tensor, prop)])
raise RuntimeError(f"Encountered different values for property '{prop}' while merging AugmentedTensor: {values}")
raise RuntimeError(
f"Encountered different values for property '{prop}' while merging AugmentedTensor: {values}"
)
else:
prop_name_to_value[prop] = getattr(tensor, prop)

Expand Down Expand Up @@ -545,15 +547,19 @@ def _merge_tensor(self, n_tensor, tensor_list, func, types, args=(), kwargs=None
if intersection:
del labels_dict2list[label_name][key]
else:
raise RuntimeError(f"Error during merging. Some tensors have label '{label_name}' with key '{key}' and some don't")
raise RuntimeError(
f"Error during merging. Some tensors have label '{label_name}' with key '{key}' and some don't"
)
else:
args = list(args)
args[0] = labels_dict2list[label_name][key]
labels_dict2list[label_name][key] = func(*tuple(args), **kwargs)
# if we removed all keys, set this child to None
if intersection and not labels_dict2list[label_name]:
labels_dict2list[label_name] = None
elif intersection and (len(labels_dict2list[label_name]) != dim_size or (None in labels_dict2list[label_name])):
elif intersection and (
len(labels_dict2list[label_name]) != dim_size or (None in labels_dict2list[label_name])
):
labels_dict2list[label_name] = None
else:
args = list(args)
Expand Down Expand Up @@ -595,7 +601,6 @@ def __iter__(self):
for t in range(len(self)):
yield self[t]


@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
self = _torch_function_get_self(cls, func, types, args, kwargs)
Expand All @@ -614,10 +619,10 @@ def _merging_frame(args):
if func.__name__ == "__reduce_ex__":
self.rename_(None, auto_restore_names=True)
tensor = super().__torch_function__(func, types, args, kwargs)
#tensor = super().torch_func_method(func, types, args, kwargs)
# tensor = super().torch_func_method(func, types, args, kwargs)
else:
tensor = super().__torch_function__(func, types, args, kwargs)
#tensor = super().torch_func_method(func, types, args, kwargs)
# tensor = super().torch_func_method(func, types, args, kwargs)

if isinstance(tensor, type(self)):
tensor._property_list = self._property_list
Expand Down Expand Up @@ -853,6 +858,9 @@ def _hflip_label(self, label, **kwargs):
try:
label_flipped = label._hflip(**kwargs)
except AttributeError:
print(
f"[WARNING] Horizontal flip returned AttributeError on {type(label).__name__}, returning unflipped tensor."
)
return label
else:
return label_flipped
Expand Down Expand Up @@ -914,6 +922,7 @@ def resize_func(label):
label_resized = label._resize(size01, **kwargs)
return label_resized
except AttributeError:
print(f"[WARNING] resize returned AttributeError on {type(label).__name__}, returning initial tensor.")
return label

resized = self._resize(size01, **kwargs)
Expand All @@ -924,7 +933,7 @@ def resize_func(label):
def _resize(self, *args, **kwargs):
raise Exception("This Augmented tensor should implement this method")

def rotate(self, angle, center=None,**kwargs):
def rotate(self, angle, center=None, **kwargs):
"""
Rotate AugmentedTensor, and its labels recursively
Expand All @@ -941,12 +950,15 @@ def rotate(self, angle, center=None,**kwargs):

def rotate_func(label):
try:
label_rotated = label._rotate(angle, center,**kwargs)
label_rotated = label._rotate(angle, center, **kwargs)
return label_rotated
except AttributeError:
print(
f"[WARNING] Rotate returned AttributeError on {type(label).__name__}, returning unrotated tensor."
)
return label

rotated = self._rotate(angle, center,**kwargs)
rotated = self._rotate(angle, center, **kwargs)
rotated.recursive_apply_on_children_(rotate_func)

return rotated
Expand All @@ -956,6 +968,7 @@ def _crop_label(self, label, H_crop, W_crop, **kwargs):
label_resized = label._crop(H_crop, W_crop, **kwargs)
return label_resized
except AttributeError:
print(f"[WARNING] Crop returned AttributeError on {type(label).__name__}, returning uncropped tensor.")
return label

def crop(self, H_crop: tuple, W_crop: tuple, **kwargs):
Expand Down Expand Up @@ -991,6 +1004,7 @@ def _pad_label(self, label, offset_y, offset_x, **kwargs):
label_pad = label._pad(offset_y, offset_x, **kwargs)
return label_pad
except AttributeError:
print(f"[WARNING] Padding returned AttributeError on {type(label).__name__}, returning unpadded tensor.")
return label

def pad(self, offset_y: tuple = None, offset_x: tuple = None, multiple: int = None, **kwargs):
Expand Down Expand Up @@ -1043,6 +1057,9 @@ def _spatial_shift_label(self, label, shift_y, shift_x, **kwargs):
label_shift = label._spatial_shift(shift_y, shift_x, **kwargs)
return label_shift
except AttributeError:
print(
f"[WARNING] Spatial shift returned AttributeError on {type(label).__name__}, returning unshifted tensor."
)
return label

def spatial_shift(self, shift_y: float, shift_x: float, **kwargs):
Expand Down
27 changes: 22 additions & 5 deletions aloscene/tensors/spatial_augmented_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import warnings


class SpatialAugmentedTensor(AugmentedTensor):
"""Spatial Augmented Tensor. Used to represets any 2D data. The spatial augmented tensor can be used as a
basis for images, depth or and spatially related data. Moreover, for stereo setup, the augmented tensor
Expand Down Expand Up @@ -131,7 +132,10 @@ def get_view(self, views: list = [], exclude=[], size=None, grid_size=None, titl
"""
_views = [v for v in views if isinstance(v, View)]
if len(_views) > 0:
return View(Renderer.get_grid_view(_views, grid_size=None, cell_grid_size=size, add_title=add_title, **kwargs), title=title)
return View(
Renderer.get_grid_view(_views, grid_size=None, cell_grid_size=size, add_title=add_title, **kwargs),
title=title,
)

# Include type
include_type = [
Expand Down Expand Up @@ -427,9 +431,15 @@ def _relative_to_absolute_hs_ws(self, hs=None, ws=None, assert_integer=True, war
assert hs is None or isinstance(hs, (list, tuple)), "hs should be a list or a tuple of floats"
assert ws is None or isinstance(ws, (list, tuple)), "ws should be a list or a tuple of floats"
if hs is not None:
hs = [self.relative_to_absolute(h, "H", assert_integer=assert_integer, warn_non_integer=warn_non_integer) for h in hs]
hs = [
self.relative_to_absolute(h, "H", assert_integer=assert_integer, warn_non_integer=warn_non_integer)
for h in hs
]
if ws is not None:
ws = [self.relative_to_absolute(w, "W", assert_integer=assert_integer, warn_non_integer=warn_non_integer) for w in ws]
ws = [
self.relative_to_absolute(w, "W", assert_integer=assert_integer, warn_non_integer=warn_non_integer)
for w in ws
]
return hs, ws

def _hflip_label(self, label, **kwargs):
Expand All @@ -441,6 +451,9 @@ def _hflip_label(self, label, **kwargs):
frame_size=self.HW, cam_intrinsic=self.cam_intrinsic, cam_extrinsic=self.cam_extrinsic, **kwargs
)
except AttributeError:
print(
f"[WARNING] Horizontal flip returned AttributeError on {type(label).__name__}, returning unflipped tensor."
)
return label
else:
return label_flipped
Expand All @@ -454,6 +467,9 @@ def _vflip_label(self, label, **kwargs):
frame_size=self.HW, cam_intrinsic=self.cam_intrinsic, cam_extrinsic=self.cam_extrinsic, **kwargs
)
except AttributeError:
print(
f"[WARNING] Vertical flip returned AttributeError on {type(label).__name__}, returning unflipped tensor."
)
return label
else:
return label_flipped
Expand Down Expand Up @@ -528,7 +544,7 @@ def _resize(self, size, interpolation=InterpolationMode.BILINEAR, **kwargs):
return self.rename(None).view(shapes).reset_names()
return F.resize(self.rename(None), (h, w), interpolation=interpolation).reset_names()

def _rotate(self, angle, center=None,**kwargs):
def _rotate(self, angle, center=None, **kwargs):
"""Rotate SpatialAugmentedTensor, but not its labels
Parameters
Expand All @@ -546,7 +562,7 @@ def _rotate(self, angle, center=None,**kwargs):
assert not (
("N" in self.names and self.size("N") == 0) or ("C" in self.names and self.size("C") == 0)
), "rotation is not possible on an empty tensor"
return F.rotate(self.rename(None), angle,center=center).reset_names()
return F.rotate(self.rename(None), angle, center=center).reset_names()

def _crop(self, H_crop: tuple, W_crop: tuple, **kwargs):
"""Crop the SpatialAugmentedTensor
Expand Down Expand Up @@ -576,6 +592,7 @@ def _pad_label(self, label, offset_y, offset_x, **kwargs):
label_pad = label._pad(offset_y, offset_x, **kwargs)
return label_pad
except AttributeError:
print(f"[WARNING] Padding returned AttributeError on {type(label).__name__}, returning unpadded tensor.")
return label

def _pad(self, offset_y: tuple, offset_x: tuple, **kwargs):
Expand Down

0 comments on commit e78bdac

Please sign in to comment.