diff --git a/robomimic/config/base_config.py b/robomimic/config/base_config.py
index 62129cd0..c4befcaf 100644
--- a/robomimic/config/base_config.py
+++ b/robomimic/config/base_config.py
@@ -272,6 +272,16 @@ def observation_config(self):
         self.observation.encoder.scan.core_kwargs = Config()                   # See models/obs_core.py for important kwargs to set and defaults used
         self.observation.encoder.scan.core_kwargs.do_not_lock_keys()

+        # =============== Spatial default encoder (pointnet) ===============
+        self.observation.encoder.spatial = deepcopy(self.observation.encoder.rgb)
+
+        # Spatial: modify the core class + kwargs; otherwise it is the same as the rgb encoder
+        self.observation.encoder.spatial.core_class = "SpatialCore"            # Default SpatialCore class uses a PointNet-style encoder to process this modality
+        # self.observation.encoder.spatial.core_class = "SpatialTest"          # alternative: PointMLP encoder (testing only)
+        # self.observation.encoder.spatial.core_class = "SparseTransformer"    # alternative: transformer over subsampled object points
+        self.observation.encoder.spatial.core_kwargs = Config()                # See models/obs_core.py for important kwargs to set and defaults used
+        self.observation.encoder.spatial.core_kwargs.do_not_lock_keys()
+
     def meta_config(self):
         """
         This function populates the `config.meta` attribute of the config. This portion of the config
diff --git a/robomimic/envs/env_robosuite.py b/robomimic/envs/env_robosuite.py
index c17c0d7d..3fac8778 100644
--- a/robomimic/envs/env_robosuite.py
+++ b/robomimic/envs/env_robosuite.py
@@ -7,7 +7,9 @@
 import numpy as np
 from copy import deepcopy

+import mujoco
 import robosuite
+from robosuite.utils.camera_utils import get_real_depth_map, get_camera_extrinsic_matrix, get_camera_intrinsic_matrix
 try:
     # this is needed for ensuring robosuite can find the additional mimicgen environments (see https://mimicgen.github.io)
     import mimicgen_envs
@@ -70,7 +72,7 @@ def __init__(
             ignore_done=True,
             use_object_obs=True,
             use_camera_obs=use_image_obs,
-            camera_depths=False,
+            camera_depths=kwargs['camera_depths'],
         )
         kwargs.update(update_kwargs)
@@ -201,9 +203,21 @@ def get_observation(self, di=None):
                 ret[k] = di[k][::-1]
                 if self.postprocess_visual_obs:
                     ret[k] = ObsUtils.process_obs(obs=ret[k], obs_key=k)
+            if (k in ObsUtils.OBS_KEYS_TO_MODALITIES) and ObsUtils.key_is_obs_modality(key=k, obs_modality="depth"):
+                ret[k] = di[k][::-1]
+                ret[k] = get_real_depth_map(self.env.sim, ret[k])
+                if self.postprocess_visual_obs:
+                    ret[k] = ObsUtils.process_obs(obs=ret[k], obs_key=k)

         # "object" key contains object information
         ret["object"] = np.array(di["object-state"])
+
+        # save camera intrinsics and extrinsics
+        for cam_idx, camera_name in enumerate(self.env.camera_names):
+            cam_height = self.env.camera_heights[cam_idx]
+            cam_width = self.env.camera_widths[cam_idx]
+            ret[f'{camera_name}_extrinsic'] = get_camera_extrinsic_matrix(self.env.sim, camera_name)
+            ret[f'{camera_name}_intrinsic'] = get_camera_intrinsic_matrix(self.env.sim, camera_name, cam_height, cam_width)

         if self._is_v1:
             for robot in self.env.robots:
@@ -357,6 +371,8 @@ def create_for_data_processing(
         image_modalities = list(camera_names)
         if is_v1:
             image_modalities = ["{}_image".format(cn) for cn in camera_names]
+            if kwargs['camera_depths']:
+                depth_modalities = ["{}_depth".format(cn) for cn in camera_names]
         elif has_camera:
             # v0.3 only had support for one image, and it was named "rgb"
             assert len(image_modalities) == 1
@@ -367,6 +383,8 @@ def create_for_data_processing(
                 "rgb": image_modalities,
             }
         }
+        if kwargs['camera_depths']:
obs_modality_specs['obs']['depth'] = depth_modalities ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs) # note that @postprocess_visual_obs is False since this env's images will be written to a dataset diff --git a/robomimic/models/base_nets.py b/robomimic/models/base_nets.py index 0a4927e0..a5f003bf 100644 --- a/robomimic/models/base_nets.py +++ b/robomimic/models/base_nets.py @@ -15,6 +15,9 @@ from torchvision import transforms from torchvision import models as vision_models +import sys +sys.path.append('/home/yixuan/general_dp/robomimic') + import robomimic.utils.tensor_utils as TensorUtils @@ -1109,3 +1112,233 @@ def forward(self, x): # weighted mean-pooling return torch.sum(x * self.agg_weight, dim=1) raise Exception("unexpected agg type: {}".forward(self.agg_type)) + +class Conv1dBNReLU(Module): + """Applies a 1D convolution over an input signal composed of several input planes, + optionally followed by batch normalization and ReLU activation. + """ + + def __init__( + self, in_channels, out_channels, kernel_size, relu=True, bn=True, **kwargs + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + self.conv = nn.Conv1d( + in_channels, out_channels, kernel_size, bias=(not bn), **kwargs + ) + self.bn = nn.BatchNorm1d(out_channels) if bn else None + self.relu = nn.ReLU(inplace=True) if relu else None + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.relu is not None: + x = self.relu(x) + return x + + +class Conv2dBNReLU(Module): + """Applies a 2D convolution (optionally with batch normalization and relu activation) + over an input signal composed of several input planes. + """ + + def __init__( + self, in_channels, out_channels, kernel_size, relu=True, bn=True, **kwargs + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size, bias=(not bn), **kwargs + ) + self.bn = nn.BatchNorm2d(out_channels) if bn else None + self.relu = nn.ReLU(inplace=True) if relu else None + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.relu is not None: + x = self.relu(x) + return x + + +class LinearBNReLU(Module): + """Applies a linear transformation to the incoming data + optionally followed by batch normalization and relu activation + """ + + def __init__(self, in_channels, out_channels, relu=True, bn=True): + super(LinearBNReLU, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + self.fc = nn.Linear(in_channels, out_channels, bias=(not bn)) + self.bn = nn.BatchNorm1d(out_channels) if bn else None + self.relu = nn.ReLU(inplace=True) if relu else None + + def forward(self, x): + x = self.fc(x) + if self.bn is not None: + x = self.bn(x) + if self.relu is not None: + x = self.relu(x) + return x + +def mlp_bn_relu(in_channels, out_channels_list): + c_in = in_channels + layers = [] + for c_out in out_channels_list: + layers.append(LinearBNReLU(c_in, c_out, relu=True, bn=True)) + c_in = c_out + return nn.Sequential(*layers) + + +def mlp_relu(in_channels, out_channels_list): + c_in = in_channels + layers = [] + for c_out in out_channels_list: + layers.append(LinearBNReLU(c_in, c_out, relu=True, bn=False)) + c_in = c_out + return nn.Sequential(*layers) + + +def mlp1d_bn_relu(in_channels, out_channels_list): + c_in = in_channels + layers = [] + for c_out in out_channels_list: + layers.append(Conv1dBNReLU(c_in, c_out, 1, 
relu=True)) + c_in = c_out + return nn.Sequential(*layers) + + +def mlp1d_relu(in_channels, out_channels_list): + c_in = in_channels + layers = [] + for c_out in out_channels_list: + layers.append(Conv1dBNReLU(c_in, c_out, 1, relu=True, bn=False)) + c_in = c_out + return nn.Sequential(*layers) + + +def mlp2d_bn_relu(in_channels, out_channels_list): + c_in = in_channels + layers = [] + for c_out in out_channels_list: + layers.append(Conv2dBNReLU(c_in, c_out, 1, relu=True)) + c_in = c_out + return nn.Sequential(*layers) + + +def mlp2d_relu(in_channels, out_channels_list): + c_in = in_channels + layers = [] + for c_out in out_channels_list: + layers.append(Conv2dBNReLU(c_in, c_out, 1, relu=True, bn=False)) + c_in = c_out + return nn.Sequential(*layers) + +class PointNet(Module): + """PointNet for classification. + Notes: + 1. The original implementation includes dropout for global MLPs. + 2. The original implementation decays the BN momentum. + """ + + def __init__( + self, + in_channels=3, + local_channels=(64, 64, 64, 128, 1024), + global_channels=(512, 256), + use_bn=True, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = (local_channels + global_channels)[-1] + self.use_bn = use_bn + + if use_bn: + self.mlp_local = mlp1d_bn_relu(in_channels, local_channels) + self.mlp_global = mlp_bn_relu(local_channels[-1], global_channels) + else: + self.mlp_local = mlp1d_relu(in_channels, local_channels) + self.mlp_global = mlp_relu(local_channels[-1], global_channels) + + self.reset_parameters() + + def forward_internal(self, points, points_feature=None, points_mask=None) -> dict: + # points: [B, 3, N]; points_feature: [B, C, N], points_mask: [B, N] + if points_feature is not None: + input_feature = torch.cat([points, points_feature], dim=1) + else: + input_feature = points + + local_feature = self.mlp_local(input_feature) + if points_mask is not None: + local_feature = torch.where( + points_mask.unsqueeze(1), local_feature, torch.zeros_like(local_feature) + ) + global_feature, max_indices = torch.max(local_feature, 2) + output_feature = self.mlp_global(global_feature) + + return {"feature": output_feature, "max_indices": max_indices} + + def forward(self, feats_points): + # points: [B, 3 + C, N] + points = feats_points[:, :3, :] + if feats_points.shape[1] > 3: + points_feature = feats_points[:, 3:, :] + else: + points_feature = None + return self.forward_internal(points, points_feature)['feature'] + + def reset_parameters(self): + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)): + if module.bias is not None: + nn.init.zeros_(module.bias) + if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)): + module.momentum = 0.01 + +class PointMLP(Module): + def __init__( + self, + ): + super().__init__() + + self.model = nn.Sequential( + nn.Linear(3 * 100 * 2, 512), + nn.ReLU(), + nn.Linear(512, 256), + nn.ReLU(), + nn.Linear(256, 256), + nn.ReLU(), + nn.Linear(256, 256), + nn.ReLU(), + nn.Linear(256, 256), + ) + print('WARNING: PointMLP is only for testing purpose') + + def forward(self, points): + # points: [B, 3 + C, N] + points = points[:, :3, :] + B, _, N = points.shape + points = points.reshape(B, 3 * 100 * 2) + return self.model(points) + +def test_pointnet(): + feat_points = torch.rand(2, 3+1024, 1000) # (B, C, N) + point_net = PointNet(3) + out = point_net(feat_points) + print(out.shape) + +if __name__ == '__main__': + test_pointnet() diff --git a/robomimic/models/obs_core.py b/robomimic/models/obs_core.py index 
c784fa27..2c0e3d5e 100644 --- a/robomimic/models/obs_core.py +++ b/robomimic/models/obs_core.py @@ -4,17 +4,21 @@ and randomizers (e.g. Randomizer, CropRandomizer). """ -import abc -import numpy as np -import textwrap -import random - import torch import torch.nn as nn from torchvision.transforms import Lambda, Compose import torchvision.transforms.functional as TVF +import abc +import numpy as np +import textwrap +import random +import sys +sys.path.append('/home/yixuan22/general_dp/robomimic') + import robomimic.models.base_nets as BaseNets +from robomimic.models.pointnet_utils import PointNetEncoder +from robomimic.models.pointnet2_utils import PointNet2Encoder import robomimic.utils.tensor_utils as TensorUtils import robomimic.utils.obs_utils as ObsUtils from robomimic.utils.python_utils import extract_class_init_kwargs_from_dict @@ -310,6 +314,343 @@ def __repr__(self): msg = header + '(' + msg + '\n)' return msg +""" +================================================ +Spatial Core Networks (PointNet) +================================================ +""" +# class SpatialCore(EncoderCore, BaseNets.ConvBase): +# """ +# A PointNet +# """ +# def __init__(self, +# input_shape, +# output_dim=256): +# super(SpatialCore, self).__init__(input_shape=input_shape) +# self.output_dim = output_dim +# # self.nets = PointNet(in_channels=input_shape[0]) +# # self.nets = PointNet(in_channels=3) +# # self.nets = PointNetEncoder(global_feat=True, channel=3) +# self.nets = PointNet2Encoder(in_channel=3) +# self.compositional = True +# self.use_pos = True +# if self.use_pos: +# self.pos_mlp = nn.Sequential( +# nn.Linear(3, 64), +# nn.ReLU(), +# nn.Linear(64, 128), +# nn.ReLU(), +# nn.Linear(128, 256), +# nn.ReLU(), +# nn.Linear(256, 128), +# nn.ReLU(), +# nn.Linear(128, 64), +# ) +# if self.compositional: +# self.output_dim += 64 +# else: +# self.output_dim += 64 * 2 + +# if self.compositional: +# self.output_dim *= 2 + +# self.post_proc_mlp = nn.Sequential( +# nn.Linear(self.output_dim, 512), +# nn.ReLU(), +# nn.Linear(512, 512), +# nn.ReLU(), +# nn.Linear(512, 512), +# nn.ReLU(), +# nn.Linear(512, self.output_dim), +# ) + +# def output_shape(self, input_shape): +# return [self.output_dim] + +# def forward(self, inputs): +# """ +# Forward pass through visual core. 
+# """ +# ndim = len(self.input_shape) +# B, D, N = inputs.shape +# N_per_obj = 100 +# N_obj = N // N_per_obj +# inputs = inputs[:, :3, :] +# # pointnet_feats, _, _ = self.nets(inputs) +# if self.compositional: +# inputs = inputs.reshape(B, 3, N_obj, N_per_obj) +# inputs = inputs.permute(0, 2, 1, 3).reshape(B * N_obj, 3, N_per_obj) +# pointnet_feats, _ = self.nets(inputs) # (B * N_obj, 256) +# else: +# pointnet_feats, _ = self.nets(inputs) # (B, 256) + +# # append pos feats +# if self.use_pos: +# if self.compositional: +# pos_feats = self.pos_mlp(inputs[:, :3, :].mean(dim=-1)) +# else: +# inputs = inputs.reshape(B, 3, N_obj, N_per_obj) +# inputs = inputs.permute(0, 2, 1, 3).reshape(B * N_obj, 3, N_per_obj) +# pos_feats = self.pos_mlp(inputs[:, :3, :].mean(dim=-1)) # (B * N_obj, 64) +# pos_feats = pos_feats.reshape(B, N_obj * 64) # (B, N_obj * 64) +# pointnet_feats = torch.cat([pointnet_feats, pos_feats], dim=-1) # (B * N_obj or B, 256 + 64) + +# if self.compositional: +# pointnet_feats = pointnet_feats.reshape(B, self.output_dim) + +# pointnet_feats = self.post_proc_mlp(pointnet_feats) + +# return pointnet_feats + +class SpatialCore(EncoderCore, BaseNets.ConvBase): + """ + A PointNet + """ + def __init__(self, + input_shape, + output_dim=256): + super(SpatialCore, self).__init__(input_shape=input_shape) + self.output_dim = output_dim + self.compositional = True + self.use_pos = True + self.use_feats = False + self.preproc_feats = False + self.proj_feats = False + pn_net_cls = PointNet2Encoder # PointNet; PointNet2Encoder + if self.use_feats: + if self.preproc_feats: + self.nets = pn_net_cls(in_channels=3+4, use_bn=False) + elif self.proj_feats: + self.nets = pn_net_cls(in_channels=3+3, use_bn=False) + else: + self.nets = pn_net_cls(in_channels=input_shape[0], use_bn=False) + else: + self.nets = pn_net_cls(in_channels=3, use_bn=False) + if self.use_pos: + self.pos_mlp = nn.Sequential( + nn.Linear(3, 64), + nn.ReLU(), + nn.Linear(64, 128), + nn.ReLU(), + nn.Linear(128, 256), + nn.ReLU(), + nn.Linear(256, 128), + nn.ReLU(), + nn.Linear(128, 64), + ) + if self.compositional: + self.output_dim += 64 + else: + self.output_dim += 64 * 2 + if self.preproc_feats: + self.preproc_mlp = nn.Sequential( + nn.Linear(input_shape[0] - 3, 512), + nn.ReLU(), + nn.Linear(512, 256), + nn.ReLU(), + nn.Linear(256, 128), + nn.ReLU(), + nn.Linear(128, 64), + nn.ReLU(), + nn.Linear(64, 16), + nn.ReLU(), + nn.Linear(16, 4), + ) + + if self.proj_feats: + self.feats_proj = nn.Linear(input_shape[0] - 3, 3) + + if self.compositional: + self.output_dim *= 2 + + if self.use_pos: + self.post_proc_mlp = nn.Sequential( + nn.Linear(self.output_dim, 512), + nn.ReLU(), + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, self.output_dim), + ) + + def output_shape(self, input_shape): + return [self.output_dim] + + def forward(self, inputs): + """ + Forward pass through visual core. 
+ """ + ndim = len(self.input_shape) + if not self.use_feats: + inputs = inputs[:, :3, :] + else: + if self.preproc_feats: + feats = inputs[:, 3:, :] + B_feats, D_feats, N_feats = feats.shape + feats = feats.permute(0, 2, 1).reshape(B_feats * N_feats, D_feats) + feats = self.preproc_mlp(feats) + feats = feats.reshape(B_feats, N_feats, 4).permute(0, 2, 1) + inputs = torch.cat([inputs[:, :3, :], feats], dim=1) + elif self.proj_feats: + feats = inputs[:, 3:, :] + B_feats, D_feats, N_feats = feats.shape + feats = feats.permute(0, 2, 1).reshape(B_feats * N_feats, D_feats) + feats = self.feats_proj(feats) + feats = feats.reshape(B_feats, N_feats, 3).permute(0, 2, 1) + inputs = torch.cat([inputs[:, :3, :], feats], dim=1) + + B, D, N = inputs.shape + N_per_obj = 100 + N_obj = N // N_per_obj + # pointnet_feats, _, _ = self.nets(inputs) + if self.compositional: + inputs = inputs.reshape(B, D, N_obj, N_per_obj) + inputs = inputs.permute(0, 2, 1, 3).reshape(B * N_obj, D, N_per_obj) + # pointnet_feats, _ = self.nets(inputs) # (B * N_obj, 256) + pointnet_feats = self.nets(inputs) # (B * N_obj, 256) + else: + # pointnet_feats, _ = self.nets(inputs) # (B, 256) + pointnet_feats = self.nets(inputs) # (B, 256) + + # append pos feats + if self.use_pos: + if self.compositional: + pos_feats = self.pos_mlp(inputs[:, :3, :].mean(dim=-1)) + else: + inputs = inputs.reshape(B, 3, N_obj, N_per_obj) + inputs = inputs.permute(0, 2, 1, 3).reshape(B * N_obj, 3, N_per_obj) + pos_feats = self.pos_mlp(inputs[:, :3, :].mean(dim=-1)) # (B * N_obj, 64) + pos_feats = pos_feats.reshape(B, N_obj * 64) # (B, N_obj * 64) + pointnet_feats = torch.cat([pointnet_feats, pos_feats], dim=-1) # (B * N_obj or B, 256 + 64) + + if self.compositional: + pointnet_feats = pointnet_feats.reshape(B, self.output_dim) + if self.use_pos: + pointnet_feats = self.post_proc_mlp(pointnet_feats) + + return pointnet_feats + +class SparseTransformer(EncoderCore, BaseNets.ConvBase): + def __init__(self, + input_shape, + output_dim=256): + super(SparseTransformer, self).__init__(input_shape=input_shape) + n_head = 4 + n_layer = 8 + p_drop_attn = 0.3 + self.pos_feat_dim = 128 + self.dino_feat_dim = 128 + self.use_feats = False + if self.use_feats: + n_emb = self.pos_feat_dim + self.dino_feat_dim + else: + n_emb = self.pos_feat_dim + self.pos_mlp = nn.Sequential( + nn.Linear(3, 64), + nn.ReLU(), + nn.Linear(64, 128), + nn.ReLU(), + nn.Linear(128, 128), + nn.ReLU(), + nn.Linear(128, self.pos_feat_dim), + ) + self.dino_feat_mlp = nn.Sequential( + nn.Linear(input_shape[0] - 3, 512), + nn.ReLU(), + nn.Linear(512, 256), + nn.ReLU(), + nn.Linear(256, 128), + nn.ReLU(), + nn.Linear(128, self.dino_feat_dim), + ) + self.cls_token = nn.Parameter(torch.zeros(1, 1, n_emb)) + encoder_layer = nn.TransformerEncoderLayer( + d_model=n_emb, + nhead=n_head, + dim_feedforward=4*n_emb, + dropout=p_drop_attn, + activation='gelu', + batch_first=True, + norm_first=True + ) + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=n_layer + ) + self.output_dim = n_emb + # self.postproc_mlp = nn.Sequential( + # nn.Linear(n_emb, 256), + # nn.ReLU(), + # nn.Linear(256, 256), + # nn.ReLU(), + # nn.Linear(256, 256), + # nn.ReLU(), + # nn.Linear(256, output_dim), + # ) + # self.output_dim = output_dim + + # init cls token + nn.init.normal_(self.cls_token, std=0.02) + + def output_shape(self, input_shape): + return [self.output_dim] + + def forward(self, inputs): + B, D, N = inputs.shape + inputs = inputs.permute(0, 2, 1) # (B, N, D) + subsample = 20 + N_per_obj = 
100 + N_obj = N // N_per_obj + assert N % N_per_obj == 0 # N must be divisible by N_per_obj + inputs = inputs.reshape(B, N_obj, N_per_obj, D) # (B, N_obj, N_per_obj, D) + inputs = inputs[:, :, :subsample, :] # (B, N_obj, subsample, D) + inputs = inputs.reshape(B, N_obj * subsample, D) # (B, N_obj * subsample, D) + + # preprocess inputs + pos_feats = self.pos_mlp(inputs[..., :3].reshape(B * N_obj * subsample, 3)).reshape(B, N_obj * subsample, self.pos_feat_dim) + if self.use_feats: + feat_feats = self.dino_feat_mlp(inputs[..., 3:].reshape(B * N_obj * subsample, D - 3)).reshape(B, N_obj * subsample, self.dino_feat_dim) + + # transformer + if self.use_feats: + tf_input = torch.cat([pos_feats, feat_feats], dim=-1) # (B, N_obj * subsample, n_emb) + else: + tf_input = pos_feats # (B, N_obj * subsample, n_emb) + tf_input = torch.cat([self.cls_token.repeat(B, 1, 1), tf_input], dim=1) # (B, N_obj * subsample + 1, n_emb) + + tf_output = self.encoder(tf_input) # (B, N_obj * subsample + 1, n_emb) + tf_output = tf_output[:, 0, :] # (B, n_emb) + + # # postprocess + # output = self.postproc_mlp(tf_output) # (B, output_dim) + + return tf_output + +""" +================================================ +Spatial Test Networks (PointMLP) +================================================ +""" +class SpatialTest(EncoderCore, BaseNets.ConvBase): + """ + A PointMLP + """ + def __init__(self, + input_shape, + output_dim=256): + super(SpatialTest, self).__init__(input_shape=input_shape) + self.output_dim = output_dim + self.nets = PointMLP() + + def output_shape(self, input_shape): + return [self.output_dim] + + def forward(self, inputs): + """ + Forward pass through visual core. + """ + return super(SpatialTest, self).forward(inputs) """ ================================================ @@ -826,3 +1167,11 @@ def __repr__(self): msg = header + f"(input_shape={self.input_shape}, noise_mean={self.noise_mean}, noise_std={self.noise_std}, " \ f"limits={self.limits}, num_samples={self.num_samples})" return msg + +def test_spatial_core(): + spatial_core = SpatialCore(input_shape=[100, 1027]).cuda() + pcd = torch.randn(16, 1027, 200).cuda() + print(spatial_core(pcd).shape) + +if __name__ == "__main__": + test_spatial_core() diff --git a/robomimic/models/obs_nets.py b/robomimic/models/obs_nets.py index b3284185..13046f0a 100644 --- a/robomimic/models/obs_nets.py +++ b/robomimic/models/obs_nets.py @@ -103,7 +103,7 @@ class ObservationEncoder(Module): Call @register_obs_key to register observation keys with the encoder and then finally call @make to create the encoder networks. """ - def __init__(self, feature_activation=nn.ReLU): + def __init__(self, feature_activation=None): """ Args: feature_activation: non-linearity to apply after each obs net - defaults to ReLU. Pass diff --git a/robomimic/models/pointnet2_utils.py b/robomimic/models/pointnet2_utils.py new file mode 100644 index 00000000..a8467a0c --- /dev/null +++ b/robomimic/models/pointnet2_utils.py @@ -0,0 +1,359 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from time import time +import numpy as np + +def timeit(tag, t): + print("{}: {}s".format(tag, time() - t)) + return time() + +def pc_normalize(pc): + l = pc.shape[0] + centroid = np.mean(pc, axis=0) + pc = pc - centroid + m = np.max(np.sqrt(np.sum(pc**2, axis=1))) + pc = pc / m + return pc + +def square_distance(src, dst): + """ + Calculate Euclid distance between each two points. 
+ + src^T * dst = xn * xm + yn * ym + zn * zm; + sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; + sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; + dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 + = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst + + Input: + src: source points, [B, N, C] + dst: target points, [B, M, C] + Output: + dist: per-point square distance, [B, N, M] + """ + B, N, _ = src.shape + _, M, _ = dst.shape + dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) + dist += torch.sum(src ** 2, -1).view(B, N, 1) + dist += torch.sum(dst ** 2, -1).view(B, 1, M) + return dist + + +def index_points(points, idx): + """ + + Input: + points: input points data, [B, N, C] + idx: sample index data, [B, S] + Return: + new_points:, indexed points data, [B, S, C] + """ + device = points.device + B = points.shape[0] + view_shape = list(idx.shape) + view_shape[1:] = [1] * (len(view_shape) - 1) + repeat_shape = list(idx.shape) + repeat_shape[0] = 1 + batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) + new_points = points[batch_indices, idx, :] + return new_points + + +def farthest_point_sample(xyz, npoint): + """ + Input: + xyz: pointcloud data, [B, N, 3] + npoint: number of samples + Return: + centroids: sampled pointcloud index, [B, npoint] + """ + device = xyz.device + B, N, C = xyz.shape + centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) + distance = torch.ones(B, N).to(device) * 1e10 + farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) + batch_indices = torch.arange(B, dtype=torch.long).to(device) + for i in range(npoint): + centroids[:, i] = farthest + centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) + dist = torch.sum((xyz - centroid) ** 2, -1) + mask = dist < distance + distance[mask] = dist[mask] + farthest = torch.max(distance, -1)[1] + return centroids + + +def query_ball_point(radius, nsample, xyz, new_xyz): + """ + Input: + radius: local region radius + nsample: max sample number in local region + xyz: all points, [B, N, 3] + new_xyz: query points, [B, S, 3] + Return: + group_idx: grouped points index, [B, S, nsample] + """ + device = xyz.device + B, N, C = xyz.shape + _, S, _ = new_xyz.shape + group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) + sqrdists = square_distance(new_xyz, xyz) + group_idx[sqrdists > radius ** 2] = N + group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] + group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) + mask = group_idx == N + group_idx[mask] = group_first[mask] + return group_idx + + +def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): + """ + Input: + npoint: + radius: + nsample: + xyz: input points position data, [B, N, 3] + points: input points data, [B, N, D] + Return: + new_xyz: sampled points position data, [B, npoint, nsample, 3] + new_points: sampled points data, [B, npoint, nsample, 3+D] + """ + B, N, C = xyz.shape + S = npoint + fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] + new_xyz = index_points(xyz, fps_idx) + idx = query_ball_point(radius, nsample, xyz, new_xyz) + grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] + grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) + + if points is not None: + grouped_points = index_points(points, idx) + new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] + else: + new_points = grouped_xyz_norm + if returnfps: + return new_xyz, new_points, grouped_xyz, fps_idx + 
else: + return new_xyz, new_points + + +def sample_and_group_all(xyz, points): + """ + Input: + xyz: input points position data, [B, N, 3] + points: input points data, [B, N, D] + Return: + new_xyz: sampled points position data, [B, 1, 3] + new_points: sampled points data, [B, 1, N, 3+D] + """ + device = xyz.device + B, N, C = xyz.shape + new_xyz = torch.zeros(B, 1, C).to(device) + grouped_xyz = xyz.view(B, 1, N, C) + if points is not None: + new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) + else: + new_points = grouped_xyz + return new_xyz, new_points + + +class PointNetSetAbstraction(nn.Module): + def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all, bn=False): + super(PointNetSetAbstraction, self).__init__() + self.npoint = npoint + self.radius = radius + self.nsample = nsample + self.bn = bn + self.mlp_convs = nn.ModuleList() + self.mlp_bns = nn.ModuleList() + last_channel = in_channel + for out_channel in mlp: + self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) + if self.bn: + self.mlp_bns.append(nn.BatchNorm2d(out_channel)) + last_channel = out_channel + self.group_all = group_all + + def forward(self, xyz, points): + """ + Input: + xyz: input points position data, [B, C, N] + points: input points data, [B, D, N] + Return: + new_xyz: sampled points position data, [B, C, S] + new_points_concat: sample points feature data, [B, D', S] + """ + xyz = xyz.permute(0, 2, 1) + if points is not None: + points = points.permute(0, 2, 1) + + if self.group_all: + new_xyz, new_points = sample_and_group_all(xyz, points) + else: + new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) + # new_xyz: sampled points position data, [B, npoint, C] + # new_points: sampled points data, [B, npoint, nsample, C+D] + new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] + for i, conv in enumerate(self.mlp_convs): + if self.bn: + bn = self.mlp_bns[i] + new_points = F.relu(bn(conv(new_points))) + else: + new_points = F.relu(conv(new_points)) + + new_points = torch.max(new_points, 2)[0] + new_xyz = new_xyz.permute(0, 2, 1) + return new_xyz, new_points + + +class PointNetSetAbstractionMsg(nn.Module): + def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): + super(PointNetSetAbstractionMsg, self).__init__() + self.npoint = npoint + self.radius_list = radius_list + self.nsample_list = nsample_list + self.conv_blocks = nn.ModuleList() + self.bn_blocks = nn.ModuleList() + for i in range(len(mlp_list)): + convs = nn.ModuleList() + bns = nn.ModuleList() + last_channel = in_channel + 3 + for out_channel in mlp_list[i]: + convs.append(nn.Conv2d(last_channel, out_channel, 1)) + bns.append(nn.BatchNorm2d(out_channel)) + last_channel = out_channel + self.conv_blocks.append(convs) + self.bn_blocks.append(bns) + + def forward(self, xyz, points): + """ + Input: + xyz: input points position data, [B, C, N] + points: input points data, [B, D, N] + Return: + new_xyz: sampled points position data, [B, C, S] + new_points_concat: sample points feature data, [B, D', S] + """ + xyz = xyz.permute(0, 2, 1) + if points is not None: + points = points.permute(0, 2, 1) + + B, N, C = xyz.shape + S = self.npoint + new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) + new_points_list = [] + for i, radius in enumerate(self.radius_list): + K = self.nsample_list[i] + group_idx = query_ball_point(radius, K, xyz, new_xyz) + grouped_xyz = index_points(xyz, group_idx) + grouped_xyz -= new_xyz.view(B, S, 1, C) + if 
points is not None: + grouped_points = index_points(points, group_idx) + grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) + else: + grouped_points = grouped_xyz + + grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] + for j in range(len(self.conv_blocks[i])): + conv = self.conv_blocks[i][j] + bn = self.bn_blocks[i][j] + grouped_points = F.relu(bn(conv(grouped_points))) + new_points = torch.max(grouped_points, 2)[0] # [B, D', S] + new_points_list.append(new_points) + + new_xyz = new_xyz.permute(0, 2, 1) + new_points_concat = torch.cat(new_points_list, dim=1) + return new_xyz, new_points_concat + + +class PointNetFeaturePropagation(nn.Module): + def __init__(self, in_channel, mlp): + super(PointNetFeaturePropagation, self).__init__() + self.mlp_convs = nn.ModuleList() + self.mlp_bns = nn.ModuleList() + last_channel = in_channel + for out_channel in mlp: + self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) + self.mlp_bns.append(nn.BatchNorm1d(out_channel)) + last_channel = out_channel + + def forward(self, xyz1, xyz2, points1, points2): + """ + Input: + xyz1: input points position data, [B, C, N] + xyz2: sampled input points position data, [B, C, S] + points1: input points data, [B, D, N] + points2: input points data, [B, D, S] + Return: + new_points: upsampled points data, [B, D', N] + """ + xyz1 = xyz1.permute(0, 2, 1) + xyz2 = xyz2.permute(0, 2, 1) + + points2 = points2.permute(0, 2, 1) + B, N, C = xyz1.shape + _, S, _ = xyz2.shape + + if S == 1: + interpolated_points = points2.repeat(1, N, 1) + else: + dists = square_distance(xyz1, xyz2) + dists, idx = dists.sort(dim=-1) + dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] + + dist_recip = 1.0 / (dists + 1e-8) + norm = torch.sum(dist_recip, dim=2, keepdim=True) + weight = dist_recip / norm + interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) + + if points1 is not None: + points1 = points1.permute(0, 2, 1) + new_points = torch.cat([points1, interpolated_points], dim=-1) + else: + new_points = interpolated_points + + new_points = new_points.permute(0, 2, 1) + for i, conv in enumerate(self.mlp_convs): + bn = self.mlp_bns[i] + new_points = F.relu(bn(conv(new_points))) + return new_points + +class PointNet2Encoder(nn.Module): + def __init__(self, in_channels=3, use_bn=False): + super(PointNet2Encoder, self).__init__() + + # self.sa1 = PointNetSetAbstractionMsg(50, [0.02, 0.04, 0.08], [8, 16, 64], in_channel - 3,[[32, 32, 64], [64, 64, 128], [64, 96, 128]]) + # self.sa2 = PointNetSetAbstractionMsg(10, [0.04, 0.08, 0.16], [16, 32, 64], 320,[[64, 64, 128], [128, 128, 256], [128, 128, 256]]) + # self.sa3 = PointNetSetAbstraction(None, None, None, 640 + 3, [256, 512, 1024], True) + self.sa1 = PointNetSetAbstraction(npoint=64, radius=0.04, nsample=16, in_channel=in_channels, mlp=[64, 64, 128], group_all=False) + self.sa2 = PointNetSetAbstraction(npoint=16, radius=0.08, nsample=32, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False) + self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True) + self.fc1 = nn.Linear(1024, 512) + self.bn = use_bn + if self.bn: + self.bn1 = nn.BatchNorm1d(512) + self.drop1 = nn.Dropout(0.4) + self.fc2 = nn.Linear(512, 256) + if self.bn: + self.bn2 = nn.BatchNorm1d(256) + + def forward(self, xyz): + B, D, N = xyz.size() + if D > 3: + norm = xyz[:, 3:, :] + xyz = xyz[:, :3, :] + else: + norm = None + l1_xyz, l1_points = self.sa1(xyz, norm) + l2_xyz, 
l2_points = self.sa2(l1_xyz, l1_points) + l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) + x = l3_points.view(B, 1024) + if self.bn: + x = self.drop1(F.relu(self.bn1(self.fc1(x)))) + x = self.bn2(self.fc2(x)) + else: + x = self.drop1(F.relu(self.fc1(x))) + x = self.fc2(x) + + return x diff --git a/robomimic/models/pointnet_utils.py b/robomimic/models/pointnet_utils.py new file mode 100644 index 00000000..64b6ed7a --- /dev/null +++ b/robomimic/models/pointnet_utils.py @@ -0,0 +1,142 @@ +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.utils.data +from torch.autograd import Variable +import numpy as np +import torch.nn.functional as F + + +class STN3d(nn.Module): + def __init__(self, channel): + super(STN3d, self).__init__() + self.conv1 = torch.nn.Conv1d(channel, 64, 1) + self.conv2 = torch.nn.Conv1d(64, 128, 1) + self.conv3 = torch.nn.Conv1d(128, 1024, 1) + self.fc1 = nn.Linear(1024, 512) + self.fc2 = nn.Linear(512, 256) + self.fc3 = nn.Linear(256, 9) + self.relu = nn.ReLU() + + self.bn1 = nn.BatchNorm1d(64) + self.bn2 = nn.BatchNorm1d(128) + self.bn3 = nn.BatchNorm1d(1024) + self.bn4 = nn.BatchNorm1d(512) + self.bn5 = nn.BatchNorm1d(256) + + def forward(self, x): + batchsize = x.size()[0] + x = F.relu(self.bn1(self.conv1(x))) + x = F.relu(self.bn2(self.conv2(x))) + x = F.relu(self.bn3(self.conv3(x))) + x = torch.max(x, 2, keepdim=True)[0] + x = x.view(-1, 1024) + + x = F.relu(self.bn4(self.fc1(x))) + x = F.relu(self.bn5(self.fc2(x))) + x = self.fc3(x) + + iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat( + batchsize, 1) + if x.is_cuda: + iden = iden.cuda() + x = x + iden + x = x.view(-1, 3, 3) + return x + + +class STNkd(nn.Module): + def __init__(self, k=64): + super(STNkd, self).__init__() + self.conv1 = torch.nn.Conv1d(k, 64, 1) + self.conv2 = torch.nn.Conv1d(64, 128, 1) + self.conv3 = torch.nn.Conv1d(128, 1024, 1) + self.fc1 = nn.Linear(1024, 512) + self.fc2 = nn.Linear(512, 256) + self.fc3 = nn.Linear(256, k * k) + self.relu = nn.ReLU() + + self.bn1 = nn.BatchNorm1d(64) + self.bn2 = nn.BatchNorm1d(128) + self.bn3 = nn.BatchNorm1d(1024) + self.bn4 = nn.BatchNorm1d(512) + self.bn5 = nn.BatchNorm1d(256) + + self.k = k + + def forward(self, x): + batchsize = x.size()[0] + x = F.relu(self.bn1(self.conv1(x))) + x = F.relu(self.bn2(self.conv2(x))) + x = F.relu(self.bn3(self.conv3(x))) + x = torch.max(x, 2, keepdim=True)[0] + x = x.view(-1, 1024) + + x = F.relu(self.bn4(self.fc1(x))) + x = F.relu(self.bn5(self.fc2(x))) + x = self.fc3(x) + + iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1, self.k * self.k).repeat( + batchsize, 1) + if x.is_cuda: + iden = iden.cuda() + x = x + iden + x = x.view(-1, self.k, self.k) + return x + + +class PointNetEncoder(nn.Module): + def __init__(self, global_feat=True, feature_transform=False, channel=3): + super(PointNetEncoder, self).__init__() + self.stn = STN3d(channel) + self.conv1 = torch.nn.Conv1d(channel, 64, 1) + self.conv2 = torch.nn.Conv1d(64, 128, 1) + self.conv3 = torch.nn.Conv1d(128, 1024, 1) + self.bn1 = nn.BatchNorm1d(64) + self.bn2 = nn.BatchNorm1d(128) + self.bn3 = nn.BatchNorm1d(1024) + self.global_feat = global_feat + self.feature_transform = feature_transform + if self.feature_transform: + self.fstn = STNkd(k=64) + + def forward(self, x): + B, D, N = x.size() + trans = self.stn(x) + x = x.transpose(2, 1) + if D > 3: + feature = x[:, :, 3:] + x = x[:, :, :3] + x = torch.bmm(x, trans) + if D > 3: + x = torch.cat([x, 
feature], dim=2)
+        x = x.transpose(2, 1)
+        x = F.relu(self.bn1(self.conv1(x)))
+
+        if self.feature_transform:
+            trans_feat = self.fstn(x)
+            x = x.transpose(2, 1)
+            x = torch.bmm(x, trans_feat)
+            x = x.transpose(2, 1)
+        else:
+            trans_feat = None
+
+        pointfeat = x
+        x = F.relu(self.bn2(self.conv2(x)))
+        x = self.bn3(self.conv3(x))
+        x = torch.max(x, 2, keepdim=True)[0]
+        x = x.view(-1, 1024)
+        if self.global_feat:
+            return x, trans, trans_feat
+        else:
+            x = x.view(-1, 1024, 1).repeat(1, 1, N)
+            return torch.cat([x, pointfeat], 1), trans, trans_feat
+
+
+def feature_transform_reguliarzer(trans):
+    d = trans.size()[1]
+    I = torch.eye(d)[None, :, :]
+    if trans.is_cuda:
+        I = I.cuda()
+    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2)))
+    return loss
\ No newline at end of file
diff --git a/robomimic/scripts/dataset_states_to_obs.py b/robomimic/scripts/dataset_states_to_obs.py
index eef2412a..8e215ebf 100644
--- a/robomimic/scripts/dataset_states_to_obs.py
+++ b/robomimic/scripts/dataset_states_to_obs.py
@@ -149,6 +149,7 @@ def extract_trajectory(
 def dataset_states_to_obs(args):
     # create environment to use for data processing
     env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=args.dataset)
+    env_meta["env_kwargs"]["camera_depths"] = args.camera_depths
     env = EnvUtils.create_env_for_data_processing(
         env_meta=env_meta,
         camera_names=args.camera_names,
@@ -320,6 +321,13 @@ def dataset_states_to_obs(args):
         action='store_true',
         help="(optional) copy rewards from source file instead of inferring them",
     )
+
+    # flag for rendering camera depths
+    parser.add_argument(
+        "--camera_depths",
+        action='store_true',
+        help="(optional) render depth observations for each camera and store them in the output dataset",
+    )

     # flag for copying dones from source file instead of re-writing them
     parser.add_argument(
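
For context, a sketch of how the new depth outputs could be consumed (the observation names below follow the keys added in env_robosuite.py; everything else is illustrative, not part of the diff). A depth-enabled dataset would first be regenerated with dataset_states_to_obs.py using the new --camera_depths flag together with the usual --dataset / --output_name / --camera_names arguments. Because get_observation() now converts MuJoCo's normalized depth to metric depth via get_real_depth_map and also stores per-camera {camera}_intrinsic / {camera}_extrinsic matrices, each depth map can be back-projected into a world-frame point cloud for the spatial encoders. A minimal sketch, assuming the unprocessed (H, W, 1) depth map in meters; the helper name is hypothetical and the axis convention may need adjusting:

import numpy as np

def depth_to_world_pointcloud(depth, K, T_cam_in_world):
    """
    Back-project a metric depth map (H, W, 1) into an (H*W, 3) world-frame point cloud.
    K is the 3x3 intrinsic matrix and T_cam_in_world the 4x4 extrinsic matrix, matching
    the `{camera}_intrinsic` / `{camera}_extrinsic` keys stored by get_observation().
    """
    H, W = depth.shape[:2]
    z = depth.reshape(H, W)
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    # standard pinhole back-projection into the camera frame
    x = (u - K[0, 2]) * z / K[0, 0]
    y = (v - K[1, 2]) * z / K[1, 1]
    pts_cam = np.stack([x, y, z], axis=-1).reshape(-1, 3)
    # NOTE (assumption): depending on the MuJoCo/OpenGL camera convention,
    # the y and z axes may need to be negated before applying the extrinsic.
    pts_hom = np.concatenate([pts_cam, np.ones((pts_cam.shape[0], 1))], axis=1)
    return (T_cam_in_world @ pts_hom.T).T[:, :3]

# obs = env.get_observation()
# pcd = depth_to_world_pointcloud(
#     obs["agentview_depth"], obs["agentview_intrinsic"], obs["agentview_extrinsic"]
# )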
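
Separately, since EnvRobosuite.__init__ now reads kwargs['camera_depths'] directly, environment metadata written before this change does not contain that key. A defensive sketch (an assumption, not part of the diff) for creating environments outside of dataset_states_to_obs.py, e.g. for evaluation rollouts:

import robomimic.utils.file_utils as FileUtils
import robomimic.utils.env_utils as EnvUtils

dataset_path = "/path/to/dataset.hdf5"  # hypothetical path
env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path)
# older metadata has no "camera_depths" entry; default it so kwargs['camera_depths'] does not raise a KeyError
env_meta["env_kwargs"].setdefault("camera_depths", False)
env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render=False, render_offscreen=True)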
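
Finally, a shape-check sketch for the new encoder cores under their built-in layout assumptions (batch and feature sizes are illustrative). Both SpatialCore and SparseTransformer expect inputs of shape (B, 3 + C, N) with xyz first and N a multiple of the hard-coded group size N_per_obj = 100, and SpatialCore's compositional path additionally assumes two such groups (its output dimension is doubled):

import torch
from robomimic.models.obs_core import SpatialCore, SparseTransformer

B, C_feat, N = 4, 1024, 200                 # 2 objects x 100 points each (example sizes)
pcd = torch.randn(B, 3 + C_feat, N)

spatial = SpatialCore(input_shape=[3 + C_feat, N])
print(spatial(pcd).shape)                   # expected (4, 640) with the default flags (use_feats=False)

sparse_tf = SparseTransformer(input_shape=[3 + C_feat, N])
print(sparse_tf(pcd).shape)                 # expected (4, 128); only xyz is used when use_feats=False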