diff --git a/alodataset/transforms.py b/alodataset/transforms.py
index b5592369..9d288c98 100644
--- a/alodataset/transforms.py
+++ b/alodataset/transforms.py
@@ -622,13 +622,14 @@ def set_params(self):
     def apply(self, frame: Frame):
         n_frame = frame.norm01()
-        gaussian_noise = torch.normal(mean=0, std=self.gaussian_std, size=frame.shape)
-        shot_noise = torch.normal(mean=0, std=self.shot_std, size=frame.shape)
+        gaussian_noise = torch.normal(mean=0, std=self.gaussian_std, size=frame.shape, device=frame.device)
+        shot_noise = torch.normal(mean=0, std=self.shot_std, size=frame.shape, device=frame.device)
         noisy_frame = n_frame + n_frame * n_frame * shot_noise + gaussian_noise
         noisy_frame = torch.clip(noisy_frame, 0, 1)
-        if n_frame.normalization != frame.normalization:
-            n_frame = n_frame.norm_as(frame)
+        if noisy_frame.normalization != frame.normalization:
+            noisy_frame = noisy_frame.norm_as(frame)
+
         return noisy_frame
diff --git a/aloscene/__init__.py b/aloscene/__init__.py
index 07faefb3..77e3b830 100644
--- a/aloscene/__init__.py
+++ b/aloscene/__init__.py
@@ -8,7 +8,62 @@
 from .points_2d import Points2D
 from .points_3d import Points3D
 from .disparity import Disparity
+from .pose import Pose
 from .bounding_boxes_2d import BoundingBoxes2D
 from .bounding_boxes_3d import BoundingBoxes3D
 from .oriented_boxes_2d import OrientedBoxes2D
 from .frame import Frame
+from .tensors.spatial_augmented_tensor import SpatialAugmentedTensor
+
+from .renderer import Renderer
+
+def batch_list(tensors):
+    return SpatialAugmentedTensor.batch_list(tensors)
+
+_renderer = None
+def render(
+    views: list,
+    renderer: str = "cv",
+    size=None,
+    record_file: str = None,
+    fps=30,
+    grid_size=None,
+    skip_views=False,
+):
+    """Render a list of views.
+
+    Parameters
+    ----------
+    views : list
+        List of views (np.ndarray) to display
+    renderer : str
+        Renderer to use. Can be either "cv" or "matplotlib"
+    size : tuple
+        Tuple or None. If not None, the tuple values (height, width) will be used
+        to set the size of each grid cell of the display. If only one view is used,
+        the view will be resized to the cell grid size.
+    record_file : str
+        None by default. Used to save the rendering into one video.
+    skip_views : bool, optional
+        Skip views, in order to speed up the render process, by default False
+    """
+    global _renderer
+    _renderer = Renderer() if _renderer is None else _renderer
+
+    _renderer.render(
+        views=views,
+        renderer=renderer,
+        cell_grid_size=size,
+        record_file=record_file,
+        fps=fps,
+        grid_size=grid_size,
+        skip_views=skip_views
+    )
+
+
+def save_renderer():
+    """If render() was called with a `record_file`, then this method will
+    save the final video to the file system. Warning: it is currently not possible
+    to save multiple streams directly with aloception. To do so, one can manually
+    create multiple `Renderer` instances.
+    """
+    _renderer.save()
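Below is a minimal usage sketch for the new top-level `render` / `save_renderer` helpers. It assumes a `Frame` built from an image path and its `get_view()` method, both part of the existing aloscene API rather than of this diff, and the file paths are made up:

```python
import aloscene

# Hypothetical input path, for illustration only
frame = aloscene.Frame("/path/to/image.jpg")

# Display the frame view with the OpenCV renderer and record it into a video
aloscene.render([frame.get_view()], renderer="cv", record_file="/tmp/out.mp4", fps=30)

# Once all render() calls are done, write the recorded video to disk
aloscene.save_renderer()
```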
+ """ + _renderer.save() diff --git a/aloscene/camera_calib.py b/aloscene/camera_calib.py index a2b0d8f0..51efb776 100644 --- a/aloscene/camera_calib.py +++ b/aloscene/camera_calib.py @@ -187,6 +187,23 @@ def __new__(cls, x, *args, **kwargs): tensor = super().__new__(cls, x, *args, **kwargs) return tensor + def translation_with(self, tgt_pos): + """ Compute the translation with an other pos + + Parameters + ---------- + tgt_pos: aloscene.Pose + Target pose to compute the translation with + + Returns + ------- + n_pos: torch.tensor + Translation tensor of shape (..., 3) + """ + Ttgt2self = torch.linalg.solve(self.as_tensor(), tgt_pos.as_tensor()) + translation = Ttgt2self[..., :3, -1] + return translation + def __init__(self, x, *args, **kwargs): assert x.shape[-2] == 4 and x.shape[-1] == 4 super().__init__(x) diff --git a/aloscene/depth.py b/aloscene/depth.py index c942a5ae..b8bb1288 100644 --- a/aloscene/depth.py +++ b/aloscene/depth.py @@ -23,7 +23,7 @@ class Depth(aloscene.tensors.SpatialAugmentedTensor): occlusion : aloscene.Mask Occlusion mask for this Depth map. Default value : None. is_bsolute: bool - Either depth values refer to real values or shifted and scaled ones. + Either depth values refer to real values or shifted and scaled ones. scale: float Scale used to to shift depth. Pass this argument only if is_bsolute is set to True shift: float @@ -32,14 +32,14 @@ class Depth(aloscene.tensors.SpatialAugmentedTensor): @staticmethod def __new__( - cls, - x, - occlusion: Mask = None, - is_absolute=False, - scale=None, - shift=None, - *args, - names=("C", "H", "W"), + cls, + x, + occlusion: Mask = None, + is_absolute=False, + scale=None, + shift=None, + *args, + names=("C", "H", "W"), **kwargs): if is_absolute and not (shift and scale): raise AttributeError('absolute depth requires shift and scale arguments') @@ -58,9 +58,20 @@ def __new__( def __init__(self, x, *args, **kwargs): super().__init__(x) - def encode_inverse(self): + def encode_inverse(self, prior_clamp_min=None, prior_clamp_max=None, post_clamp_min=None, post_clamp_max=None): """Undo encode_absolute tansformation - + + Parameters + ---------- + prior_clamp_min: float | None + Clamp min depth before to convert to idepth + prior_clamp_max: float | None + Clamp max depth before to convert to idepth + post_clamp_min: float | None + Clamp min output idepth + post_clamp_max: float | None + Clamp max output idepth + Exemples ------- >>> not_absolute_depth = Depth(torch.ones((1, 1, 1)), is_absolute=False) @@ -72,23 +83,40 @@ def encode_inverse(self): depth = self if not depth.is_absolute: raise ExecError('can not inverse depth, already inversed') + shift = depth.shift if depth.shift is not None else 0 + scale = depth.scale if depth.scale is not None else 1 + + if prior_clamp_min is not None or prior_clamp_max is not None: + depth = torch.clamp(depth, min=prior_clamp_min, max=prior_clamp_max) + depth = 1 / depth - depth = (depth - depth.shift) / depth.scale + depth = (depth - shift) / scale + + if post_clamp_min is not None or post_clamp_max is not None: + depth = torch.clamp(depth, min=post_clamp_min, max=post_clamp_max) + depth.scale = None depth.shift = None depth.is_absolute = False return depth - - def encode_absolute(self, scale=1, shift=0): + + def encode_absolute(self, scale=1, shift=0, prior_clamp_min=None, prior_clamp_max=None, post_clamp_min=None, post_clamp_max=None): """Transforms inverted depth to absolute depth - + Parameters ---------- scale: (: float) Multiplication factor. Default is 1. 
diff --git a/aloscene/depth.py b/aloscene/depth.py
index c942a5ae..b8bb1288 100644
--- a/aloscene/depth.py
+++ b/aloscene/depth.py
@@ -23,7 +23,7 @@ class Depth(aloscene.tensors.SpatialAugmentedTensor):
     occlusion : aloscene.Mask
         Occlusion mask for this Depth map. Default value : None.
     is_bsolute: bool
-        Either depth values refer to real values or shifted and scaled ones.
+        Whether depth values are real values or shifted and scaled ones.
     scale: float
         Scale used to to shift depth. Pass this argument only if is_bsolute is set to True
     shift: float
@@ -32,14 +32,14 @@
     @staticmethod
     def __new__(
-        cls,
-        x,
-        occlusion: Mask = None,
-        is_absolute=False,
-        scale=None,
-        shift=None,
-        *args,
-        names=("C", "H", "W"),
+        cls,
+        x,
+        occlusion: Mask = None,
+        is_absolute=False,
+        scale=None,
+        shift=None,
+        *args,
+        names=("C", "H", "W"),
         **kwargs):
         if is_absolute and not (shift and scale):
             raise AttributeError('absolute depth requires shift and scale arguments')
@@ -58,9 +58,20 @@ def __new__(
     def __init__(self, x, *args, **kwargs):
         super().__init__(x)
 
-    def encode_inverse(self):
+    def encode_inverse(self, prior_clamp_min=None, prior_clamp_max=None, post_clamp_min=None, post_clamp_max=None):
         """Undo encode_absolute tansformation
-
+
+        Parameters
+        ----------
+        prior_clamp_min: float | None
+            Clamp min depth before converting to idepth
+        prior_clamp_max: float | None
+            Clamp max depth before converting to idepth
+        post_clamp_min: float | None
+            Clamp min output idepth
+        post_clamp_max: float | None
+            Clamp max output idepth
+
         Exemples
         -------
         >>> not_absolute_depth = Depth(torch.ones((1, 1, 1)), is_absolute=False)
@@ -72,23 +83,40 @@ def encode_inverse(self):
         depth = self
         if not depth.is_absolute:
             raise ExecError('can not inverse depth, already inversed')
+        shift = depth.shift if depth.shift is not None else 0
+        scale = depth.scale if depth.scale is not None else 1
+
+        if prior_clamp_min is not None or prior_clamp_max is not None:
+            depth = torch.clamp(depth, min=prior_clamp_min, max=prior_clamp_max)
+
         depth = 1 / depth
-        depth = (depth - depth.shift) / depth.scale
+        depth = (depth - shift) / scale
+
+        if post_clamp_min is not None or post_clamp_max is not None:
+            depth = torch.clamp(depth, min=post_clamp_min, max=post_clamp_max)
+
         depth.scale = None
         depth.shift = None
         depth.is_absolute = False
         return depth
-
-    def encode_absolute(self, scale=1, shift=0):
+
+    def encode_absolute(self, scale=1, shift=0, prior_clamp_min=None, prior_clamp_max=None, post_clamp_min=None, post_clamp_max=None):
         """Transforms inverted depth to absolute depth
-
+
         Parameters
         ----------
         scale: (: float)
             Multiplication factor. Default is 1.
         shift: (: float)
             Addition intercept. Default is 0.
+        prior_clamp_min: float | None
+            Clamp min idepth before converting to depth
+        prior_clamp_max: float | None
+            Clamp max idepth before converting to depth
+        post_clamp_min: float | None
+            Clamp min output depth
+        post_clamp_max: float | None
+            Clamp max output depth
 
         Exemples
         --------
         >>> inverted_depth = Depth(torch.ones((1, 1, 1)), is_absolute=False)
@@ -100,12 +128,24 @@ def encode_absolute(self, scale=1, shift=0):
         depth, names = self.rename(None), self.names
         if depth.is_absolute:
             raise ExecError('depth already in absolute state, call encode_inverse first')
+
         depth = depth * scale + shift
+
+        if prior_clamp_min is not None or prior_clamp_max is not None:
+            depth = torch.clamp(depth, min=prior_clamp_min, max=prior_clamp_max)
+
         depth[torch.unsqueeze(depth < 1e-8, dim=0)] = 1e-8
         depth.scale = scale
         depth.shift = shift
         depth.is_absolute = True
-        return (1 / depth).rename(*names)
+
+        n_depth = (1 / depth).rename(*names)
+
+        if post_clamp_min is not None or post_clamp_max is not None:
+            n_depth = torch.clamp(n_depth, min=post_clamp_min, max=post_clamp_max)
+
+        return n_depth
+
     def append_occlusion(self, occlusion: Mask, name: str = None):
         """Attach an occlusion mask to the depth tensor.
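A hedged sketch of the new clamping options on `encode_absolute` / `encode_inverse`. The idepth values are made up, and the behaviour noted in the comments follows the code above under the same assumptions as the docstring examples:

```python
import torch
from aloscene import Depth

# Inverse-depth (idepth) map with made-up values, including a very small one
idepth = Depth(torch.tensor([[[0.5, 0.001]]]), is_absolute=False)

# idepth -> absolute depth, clamping the output depth to at most 100
depth = idepth.encode_absolute(post_clamp_max=100)  # 1/0.5 = 2.0 ; 1/0.001 clamped to 100

# absolute depth -> idepth again, clamping the input depth to at least 1 beforehand
idepth_back = depth.encode_inverse(prior_clamp_min=1)
```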
diff --git a/aloscene/frame.py b/aloscene/frame.py
index b851493a..f531d451 100644
--- a/aloscene/frame.py
+++ b/aloscene/frame.py
@@ -6,7 +6,7 @@
 import aloscene
 from aloscene.renderer import View
-from aloscene import BoundingBoxes2D, BoundingBoxes3D, Depth, Disparity, Flow, Mask, Labels, Points2D, Points3D
+from aloscene import BoundingBoxes2D, BoundingBoxes3D, Depth, Disparity, Flow, Mask, Labels, Points2D, Points3D, Pose
 
 # from aloscene.camera_calib import CameraExtrinsic, CameraIntrinsic
 from aloscene.io.image import load_image
@@ -113,6 +113,7 @@
         tensor.add_child("depth", depth, align_dim=["B", "T"], mergeable=True)
         tensor.add_child("segmentation", segmentation, align_dim=["B", "T"], mergeable=False)
         tensor.add_child("labels", labels, align_dim=["B", "T"], mergeable=True)
+        tensor.add_child("pose", None, align_dim=["B", "T"], mergeable=True)
 
         # Add other tensor property
         tensor.add_property("normalization", normalization)
@@ -304,6 +305,19 @@ def append_segmentation(self, segmentation: Mask, name: str = None):
         """
         self._append_child("segmentation", segmentation, name)
 
+    def append_pose(self, pose: Pose, name: str = None):
+        """Attach a pose to the frame.
+
+        Parameters
+        ----------
+        pose : :mod:`Pose <aloscene.pose>`
+            Pose to attach to the Frame
+        name : str
+            If None, the pose will be attached without name (if possible). Otherwise, if no other unnamed
+            pose is attached to the frame, the pose will be added to the set of poses.
+        """
+        self._append_child("pose", pose, name)
+
     @staticmethod
     def _get_mean_std_tensor(shape, names, mean_std: tuple, device="cpu"):
         """Utils method to a get the mean and the std
diff --git a/aloscene/pose.py b/aloscene/pose.py
new file mode 100644
index 00000000..f6be2572
--- /dev/null
+++ b/aloscene/pose.py
@@ -0,0 +1,38 @@
+from aloscene.camera_calib import CameraExtrinsic
+import torch
+
+
+class Pose(CameraExtrinsic):
+    """Pose Tensor.
+    Usually used to store World2Frame coordinates.
+
+    Parameters
+    ----------
+    x: torch.Tensor
+        Pose matrix
+    """
+
+    @staticmethod
+    def __new__(cls, x, *args, names=(None, None), **kwargs):
+        tensor = super().__new__(cls, x, *args, names=names, **kwargs)
+        return tensor
+
+    def __init__(self, x, *args, **kwargs):
+        super().__init__(x)
+
+    def _hflip(self, *args, **kwargs):
+        return self.clone()
+
+    def _vflip(self, *args, **kwargs):
+        return self.clone()
+
+    def _resize(self, *args, **kwargs):
+        # Resizing the image does not change the pose
+        return self.clone()
+
+    def _crop(self, *args, **kwargs):
+        # Cropping the image does not change the pose
+        return self.clone()
+
+    def _pad(self, *args, **kwargs):
+        # Padding the image does not change the pose
+        return self.clone()
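Finally, a hypothetical end-to-end sketch of the new `Pose` attached to a frame. The frame content, image size and poses are made up, and it assumes the child can be read back as `frame.pose` and that `Frame.hflip()` is available from the existing augmented-tensor API:

```python
import torch
import aloscene
from aloscene import Pose

# Frame with random content and an identity world-to-frame pose
frame = aloscene.Frame(torch.rand(3, 128, 128), names=("C", "H", "W"))
frame.append_pose(Pose(torch.eye(4)))

# Spatial augmentations leave the pose untouched (see Pose._hflip above)
flipped = frame.hflip()

# Relative translation between this pose and another (identity) pose
print(flipped.pose.translation_with(Pose(torch.eye(4))))
```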