From d9b38785495146041ff010049162c907dd3a8b7a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lien=20Cecille?=
Date: Thu, 6 Apr 2023 11:31:50 +0200
Subject: [PATCH] Add warning when augmentations fail

---
 aloscene/tensors/augmented_tensor.py         | 41 ++++++++++++++------
 aloscene/tensors/spatial_augmented_tensor.py | 27 ++++++++++---
 2 files changed, 51 insertions(+), 17 deletions(-)

diff --git a/aloscene/tensors/augmented_tensor.py b/aloscene/tensors/augmented_tensor.py
index 56d3d53d..61f06590 100644
--- a/aloscene/tensors/augmented_tensor.py
+++ b/aloscene/tensors/augmented_tensor.py
@@ -7,7 +7,7 @@
 
 
 def _torch_function_get_self(cls, func, types, args, kwargs):
-    """ Based on this dicussion https://github.com/pytorch/pytorch/issues/63767
+    """Based on this dicussion https://github.com/pytorch/pytorch/issues/63767
     "A simple solution would be to scan the args for the first subclass of this class.
     My question is more: will forcing this to be a subclass actually be a problem for some use
     case?
@@ -36,7 +36,7 @@ class AugmentedTensor(torch.Tensor):
 
     # Ignore named tansors userwarning.
     ERROR_MSG = "Named tensors and all their associated APIs are an experimental feature and subject to change"
-    warnings.filterwarnings(action='ignore', message=ERROR_MSG)
+    warnings.filterwarnings(action="ignore", message=ERROR_MSG)
 
     @staticmethod
     def __new__(cls, x, names=None, device=None, *args, **kwargs):
@@ -335,7 +335,6 @@ def __getitem__(self, idx):
 
             tensor = tensor.reset_names() if len(self.shape) == len(tensor.shape) else tensor.as_tensor()
 
-
             # if not idx.dtype == torch.bool:
             # if not torch.equal(idx ** 3, idx):
            # raise IndexError(f"Unvalid mask. Expected mask elements to be in [0, 1, True, False]")
@@ -475,6 +474,7 @@ def _fillup_dict(dm, sub_label, dim, target_dim):
                 else:
                     for s in range(len(sub_label)):
                         _fillup_dict(dm[s], sub_label[s], dim + 1, target_dim)
+
             _fillup_dict(dict_merge[key], label, 0, target_dim)
         return dict_merge
 
@@ -499,7 +499,9 @@ def _merge_tensor(self, n_tensor, tensor_list, func, types, args=(), kwargs=None
                         setattr(n_tensor, prop, None)
                     else:
                         values = set([prop_name_to_value[prop], getattr(tensor, prop)])
-                        raise RuntimeError(f"Encountered different values for property '{prop}' while merging AugmentedTensor: {values}")
+                        raise RuntimeError(
+                            f"Encountered different values for property '{prop}' while merging AugmentedTensor: {values}"
+                        )
                 else:
                     prop_name_to_value[prop] = getattr(tensor, prop)
 
@@ -545,7 +547,9 @@ def _merge_tensor(self, n_tensor, tensor_list, func, types, args=(), kwargs=None
                     if intersection:
                         del labels_dict2list[label_name][key]
                     else:
-                        raise RuntimeError(f"Error during merging. Some tensors have label '{label_name}' with key '{key}' and some don't")
+                        raise RuntimeError(
+                            f"Error during merging. Some tensors have label '{label_name}' with key '{key}' and some don't"
+                        )
                 else:
                     args = list(args)
                     args[0] = labels_dict2list[label_name][key]
@@ -553,7 +557,9 @@ def _merge_tensor(self, n_tensor, tensor_list, func, types, args=(), kwargs=None
             # if we removed all keys, set this child to None
             if intersection and not labels_dict2list[label_name]:
                 labels_dict2list[label_name] = None
-            elif intersection and (len(labels_dict2list[label_name]) != dim_size or (None in labels_dict2list[label_name])):
+            elif intersection and (
+                len(labels_dict2list[label_name]) != dim_size or (None in labels_dict2list[label_name])
+            ):
                 labels_dict2list[label_name] = None
             else:
                 args = list(args)
@@ -595,7 +601,6 @@ def __iter__(self):
         for t in range(len(self)):
             yield self[t]
 
-
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         self = _torch_function_get_self(cls, func, types, args, kwargs)
@@ -614,10 +619,10 @@ def _merging_frame(args):
         if func.__name__ == "__reduce_ex__":
             self.rename_(None, auto_restore_names=True)
             tensor = super().__torch_function__(func, types, args, kwargs)
-            #tensor = super().torch_func_method(func, types, args, kwargs)
+            # tensor = super().torch_func_method(func, types, args, kwargs)
         else:
             tensor = super().__torch_function__(func, types, args, kwargs)
-            #tensor = super().torch_func_method(func, types, args, kwargs)
+            # tensor = super().torch_func_method(func, types, args, kwargs)
 
         if isinstance(tensor, type(self)):
             tensor._property_list = self._property_list
@@ -853,6 +858,9 @@ def _hflip_label(self, label, **kwargs):
         try:
             label_flipped = label._hflip(**kwargs)
         except AttributeError:
+            print(
+                f"[WARNING] Horizontal flip returned AttributeError on {type(label).__name__}, returning unflipped tensor."
+            )
             return label
         else:
             return label_flipped
@@ -914,6 +922,7 @@ def resize_func(label):
                 label_resized = label._resize(size01, **kwargs)
                 return label_resized
             except AttributeError:
+                print(f"[WARNING] resize returned AttributeError on {type(label).__name__}, returning initial tensor.")
                 return label
 
         resized = self._resize(size01, **kwargs)
@@ -924,7 +933,7 @@ def resize_func(label):
     def _resize(self, *args, **kwargs):
         raise Exception("This Augmented tensor should implement this method")
 
-    def rotate(self, angle, center=None,**kwargs):
+    def rotate(self, angle, center=None, **kwargs):
         """
         Rotate AugmentedTensor, and its labels recursively
 
@@ -941,12 +950,15 @@ def rotate(self, angle, center=None,**kwargs):
 
         def rotate_func(label):
             try:
-                label_rotated = label._rotate(angle, center,**kwargs)
+                label_rotated = label._rotate(angle, center, **kwargs)
                 return label_rotated
             except AttributeError:
+                print(
+                    f"[WARNING] Rotate returned AttributeError on {type(label).__name__}, returning unrotated tensor."
+                )
                 return label
 
-        rotated = self._rotate(angle, center,**kwargs)
+        rotated = self._rotate(angle, center, **kwargs)
         rotated.recursive_apply_on_children_(rotate_func)
         return rotated
 
@@ -956,6 +968,7 @@ def _crop_label(self, label, H_crop, W_crop, **kwargs):
             label_resized = label._crop(H_crop, W_crop, **kwargs)
             return label_resized
         except AttributeError:
+            print(f"[WARNING] Crop returned AttributeError on {type(label).__name__}, returning uncropped tensor.")
             return label
 
     def crop(self, H_crop: tuple, W_crop: tuple, **kwargs):
@@ -991,6 +1004,7 @@ def _pad_label(self, label, offset_y, offset_x, **kwargs):
             label_pad = label._pad(offset_y, offset_x, **kwargs)
             return label_pad
         except AttributeError:
+            print(f"[WARNING] Padding returned AttributeError on {type(label).__name__}, returning unpadded tensor.")
             return label
 
     def pad(self, offset_y: tuple = None, offset_x: tuple = None, multiple: int = None, **kwargs):
@@ -1043,6 +1057,9 @@ def _spatial_shift_label(self, label, shift_y, shift_x, **kwargs):
             label_shift = label._spatial_shift(shift_y, shift_x, **kwargs)
             return label_shift
         except AttributeError:
+            print(
+                f"[WARNING] Spatial shift returned AttributeError on {type(label).__name__}, returning unshifted tensor."
+            )
             return label
 
     def spatial_shift(self, shift_y: float, shift_x: float, **kwargs):
diff --git a/aloscene/tensors/spatial_augmented_tensor.py b/aloscene/tensors/spatial_augmented_tensor.py
index 3b8ede0e..b4cb5009 100644
--- a/aloscene/tensors/spatial_augmented_tensor.py
+++ b/aloscene/tensors/spatial_augmented_tensor.py
@@ -14,6 +14,7 @@
 
 import warnings
 
+
 class SpatialAugmentedTensor(AugmentedTensor):
     """Spatial Augmented Tensor. Used to represets any 2D data. The spatial augmented tensor
     can be used as a basis for images, depth or and spatially related data. Moreover, for stereo setup, the augmented tensor
@@ -131,7 +132,10 @@ def get_view(self, views: list = [], exclude=[], size=None, grid_size=None, titl
         """
         _views = [v for v in views if isinstance(v, View)]
         if len(_views) > 0:
-            return View(Renderer.get_grid_view(_views, grid_size=None, cell_grid_size=size, add_title=add_title, **kwargs), title=title)
+            return View(
+                Renderer.get_grid_view(_views, grid_size=None, cell_grid_size=size, add_title=add_title, **kwargs),
+                title=title,
+            )
 
         # Include type
         include_type = [
@@ -427,9 +431,15 @@ def _relative_to_absolute_hs_ws(self, hs=None, ws=None, assert_integer=True, war
         assert hs is None or isinstance(hs, (list, tuple)), "hs should be a list or a tuple of floats"
         assert ws is None or isinstance(ws, (list, tuple)), "ws should be a list or a tuple of floats"
         if hs is not None:
-            hs = [self.relative_to_absolute(h, "H", assert_integer=assert_integer, warn_non_integer=warn_non_integer) for h in hs]
+            hs = [
+                self.relative_to_absolute(h, "H", assert_integer=assert_integer, warn_non_integer=warn_non_integer)
+                for h in hs
+            ]
         if ws is not None:
-            ws = [self.relative_to_absolute(w, "W", assert_integer=assert_integer, warn_non_integer=warn_non_integer) for w in ws]
+            ws = [
+                self.relative_to_absolute(w, "W", assert_integer=assert_integer, warn_non_integer=warn_non_integer)
+                for w in ws
+            ]
         return hs, ws
 
     def _hflip_label(self, label, **kwargs):
@@ -441,6 +451,9 @@ def _hflip_label(self, label, **kwargs):
                 frame_size=self.HW, cam_intrinsic=self.cam_intrinsic, cam_extrinsic=self.cam_extrinsic, **kwargs
             )
         except AttributeError:
+            print(
+                f"[WARNING] Horizontal flip returned AttributeError on {type(label).__name__}, returning unflipped tensor."
+            )
             return label
         else:
             return label_flipped
@@ -454,6 +467,9 @@ def _vflip_label(self, label, **kwargs):
                 frame_size=self.HW, cam_intrinsic=self.cam_intrinsic, cam_extrinsic=self.cam_extrinsic, **kwargs
             )
         except AttributeError:
+            print(
+                f"[WARNING] Vertical flip returned AttributeError on {type(label).__name__}, returning unflipped tensor."
+            )
             return label
         else:
             return label_flipped
@@ -528,7 +544,7 @@ def _resize(self, size, interpolation=InterpolationMode.BILINEAR, **kwargs):
             return self.rename(None).view(shapes).reset_names()
         return F.resize(self.rename(None), (h, w), interpolation=interpolation).reset_names()
 
-    def _rotate(self, angle, center=None,**kwargs):
+    def _rotate(self, angle, center=None, **kwargs):
         """Rotate SpatialAugmentedTensor, but not its labels
 
         Parameters
@@ -546,7 +562,7 @@ def _rotate(self, angle, center=None,**kwargs):
         assert not (
             ("N" in self.names and self.size("N") == 0) or ("C" in self.names and self.size("C") == 0)
         ), "rotation is not possible on an empty tensor"
-        return F.rotate(self.rename(None), angle,center=center).reset_names()
+        return F.rotate(self.rename(None), angle, center=center).reset_names()
 
     def _crop(self, H_crop: tuple, W_crop: tuple, **kwargs):
         """Crop the SpatialAugmentedTensor
@@ -576,6 +592,7 @@ def _pad_label(self, label, offset_y, offset_x, **kwargs):
             label_pad = label._pad(offset_y, offset_x, **kwargs)
             return label_pad
         except AttributeError:
+            print(f"[WARNING] Padding returned AttributeError on {type(label).__name__}, returning unpadded tensor.")
             return label
 
     def _pad(self, offset_y: tuple, offset_x: tuple, **kwargs):
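
Illustration (not part of the patch): a minimal sketch of the fallback pattern the warnings above are attached to. DummyLabel is a hypothetical stand-in for a child label that does not implement _hflip; the real aloscene classes are not needed to run it, and the helper below only mirrors the shape of AugmentedTensor._hflip_label rather than reproducing it exactly.

class DummyLabel:
    """Hypothetical label type with no _hflip implementation."""


def _hflip_label(label, **kwargs):
    # Try the label's own _hflip; if the label does not implement it,
    # print the warning and return the label unchanged, as the patch does.
    try:
        return label._hflip(**kwargs)
    except AttributeError:
        print(
            f"[WARNING] Horizontal flip returned AttributeError on {type(label).__name__}, returning unflipped tensor."
        )
        return label


label = _hflip_label(DummyLabel())
# prints: [WARNING] Horizontal flip returned AttributeError on DummyLabel, returning unflipped tensor.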