diff --git a/config.py b/config.py index a54c80d..b37ff9a 100644 --- a/config.py +++ b/config.py @@ -30,8 +30,7 @@ def add_argument_group(name): # Network net_arg = add_argument_group('Network') -net_arg.add_argument( - '--model', type=str, default='ResUNet14', help='Model name') +net_arg.add_argument('--model', type=str, default='ResUNet14', help='Model name') net_arg.add_argument( '--conv1_kernel_size', type=int, default=3, help='First layer conv kernel size') net_arg.add_argument('--weights', type=str, default='None', help='Saved weights to load') @@ -117,8 +116,16 @@ def add_argument_group(name): data_arg.add_argument( '--synthia_path', type=str, - default='/home/chrischoy/datasets/synthia_preprocessed', + default='/home/chrischoy/datasets/Synthia/Synthia4D', help='Point Cloud dataset root dir') +# For temporal sequences +data_arg.add_argument( + '--synthia_camera_path', type=str, default='/home/chrischoy/datasets/Synthia/%s/CameraParams/') +data_arg.add_argument('--synthia_camera_intrinsic_file', type=str, default='intrinsics.txt') +data_arg.add_argument( + '--synthia_camera_extrinsics_file', type=str, default='Stereo_Right/Omni_F/%s.txt') +data_arg.add_argument('--temporal_rand_dilation', type=str2bool, default=False) +data_arg.add_argument('--temporal_rand_numseq', type=str2bool, default=False) data_arg.add_argument( '--scannet_path', @@ -179,7 +186,10 @@ def add_argument_group(name): data_aug_arg.add_argument( '--data_aug_hue_max', type=float, default=0.5, help='Hue translation range. 
[0, 1]') data_aug_arg.add_argument( - '--data_aug_saturation_max', type=float, default=0.20, help='Saturation translation range, [0, 1]') + '--data_aug_saturation_max', + type=float, + default=0.20, + help='Saturation translation range, [0, 1]') # Test test_arg = add_argument_group('Test') diff --git a/lib/dataset.py b/lib/dataset.py index b12e0c7..8cb7984 100644 --- a/lib/dataset.py +++ b/lib/dataset.py @@ -194,7 +194,7 @@ class VoxelizationDataset(VoxelizationDatasetBase): ELASTIC_DISTORT_PARAMS = None # MISC. - PREVOXELIZE_VOXEL_SIZE = None + PREVOXELIZATION_VOXEL_SIZE = None def __init__(self, data_paths, @@ -252,9 +252,9 @@ def __getitem__(self, index): pointcloud, center = self.load_ply(index) # Downsample the pointcloud with finer voxel size before transformation for memory and speed - if self.PREVOXELIZE_VOXEL_SIZE is not None: + if self.PREVOXELIZATION_VOXEL_SIZE is not None: inds = ME.utils.sparse_quantize( - pointcloud[:, :3] / self.PREVOXELIZE_VOXEL_SIZE, return_index=True) + pointcloud[:, :3] / self.PREVOXELIZATION_VOXEL_SIZE, return_index=True) pointcloud = pointcloud[inds] # Prevoxel transformations @@ -296,9 +296,18 @@ def __init__(self, augment_data=False, config=None, **kwargs): - VoxelizationDataset.__init__(self, data_paths, input_transform, target_transform, data_root, - ignore_label, return_transformation, augment_data, config, - **kwargs) + VoxelizationDataset.__init__( + self, + data_paths, + prevoxel_transform=prevoxel_transform, + input_transform=input_transform, + target_transform=target_transform, + data_root=data_root, + ignore_label=ignore_label, + return_transformation=return_transformation, + augment_data=augment_data, + config=config, + **kwargs) self.temporal_dilation = temporal_dilation self.temporal_numseq = temporal_numseq temporal_window = temporal_dilation * (temporal_numseq - 1) + 1 @@ -333,10 +342,11 @@ def __getitem__(self, index): ptcs, centers = zip(*world_pointclouds) # Downsample pointcloud for speed and memory - if 
self.PREVOXELIZE_VOXEL_SIZE is not None: + if self.PREVOXELIZATION_VOXEL_SIZE is not None: new_ptcs = [] for ptc in ptcs: - inds = ME.utils.sparse_quantize(ptc[:, :3] / self.PREVOXELIZE_VOXEL_SIZE, return_index=True) + inds = ME.utils.sparse_quantize( + ptc[:, :3] / self.PREVOXELIZATION_VOXEL_SIZE, return_index=True) new_ptcs.append(ptc[inds]) ptcs = new_ptcs diff --git a/lib/datasets/__init__.py b/lib/datasets/__init__.py index dce4449..cdd0eb0 100644 --- a/lib/datasets/__init__.py +++ b/lib/datasets/__init__.py @@ -1,12 +1,13 @@ from .synthia import SynthiaCVPR15cmVoxelizationDataset, SynthiaCVPR30cmVoxelizationDataset, \ - SynthiaAllSequencesVoxelizationDataset + SynthiaAllSequencesVoxelizationDataset, SynthiaTemporalVoxelizationDataset from .stanford import StanfordVoxelizationDataset, StanfordVoxelization2cmDataset from .scannet import ScannetVoxelizationDataset, ScannetVoxelization2cmDataset DATASETS = [ StanfordVoxelizationDataset, StanfordVoxelization2cmDataset, ScannetVoxelizationDataset, ScannetVoxelization2cmDataset, SynthiaCVPR15cmVoxelizationDataset, - SynthiaCVPR30cmVoxelizationDataset, SynthiaAllSequencesVoxelizationDataset + SynthiaCVPR30cmVoxelizationDataset, SynthiaTemporalVoxelizationDataset, + SynthiaAllSequencesVoxelizationDataset ] diff --git a/lib/datasets/synthia.py b/lib/datasets/synthia.py index 9fbbeb1..06c1e5d 100644 --- a/lib/datasets/synthia.py +++ b/lib/datasets/synthia.py @@ -110,7 +110,7 @@ class SynthiaVoxelizationDataset(VoxelizationDataset): TEST_CLIP_BOUND = ((-2500, 2500), (-2500, 2500), (-2500, 2500)) VOXEL_SIZE = 15 # cm - PREVOXELIZE_VOXEL_SIZE = 7.5 + PREVOXELIZATION_VOXEL_SIZE = 7.5 # Elastic distortion, (granularity, magitude) pairs ELASTIC_DISTORT_PARAMS = ((80, 300),) @@ -166,7 +166,7 @@ class SynthiaTemporalVoxelizationDataset(TemporalVoxelizationDataset): TEST_CLIP_BOUND = ((-2500, 2500), (-2500, 2500), (-2500, 2500)) VOXEL_SIZE = 15 # cm - PREVOXELIZE_VOXEL_SIZE = 7.5 + PREVOXELIZATION_VOXEL_SIZE = 7.5 # For temporal 
sequences, the voxel locations has to be aligned exactly. ELASTIC_DISTORT_PARAMS = None @@ -179,13 +179,19 @@ class SynthiaTemporalVoxelizationDataset(TemporalVoxelizationDataset): NUM_LABELS = 16 # Automatically subtract ignore labels after processed IGNORE_LABELS = (0, 1, 13, 14) # void, sky, reserved, reserved + # Split used in the Minkowski ConvNet, CVPR'19 + DATA_PATH_FILE = { + DatasetPhase.Train: 'train_cvpr19.txt', + DatasetPhase.Val: 'val_cvpr19.txt', + DatasetPhase.Test: 'test_cvpr19.txt' + } + def __init__(self, config, prevoxel_transform=None, input_transform=None, target_transform=None, augment_data=True, - elastic_distortion=False, cache=False, phase=DatasetPhase.Train): if isinstance(phase, str): @@ -193,7 +199,7 @@ def __init__(self, if phase not in [DatasetPhase.Train, DatasetPhase.TrainVal]: self.CLIP_BOUND = self.TEST_CLIP_BOUND data_root = config.synthia_path - data_paths = read_txt(osp.join(data_root, self.DATA_PATH_FILE[phase])) + data_paths = read_txt(osp.join('./splits/synthia4d', self.DATA_PATH_FILE[phase])) data_paths = sorted([d.split()[0] for d in data_paths]) seq2files = defaultdict(list) for f in data_paths: @@ -211,15 +217,15 @@ def __init__(self, TemporalVoxelizationDataset.__init__( self, file_seq_list, - data_root=data_root, + prevoxel_transform=prevoxel_transform, input_transform=input_transform, target_transform=target_transform, + data_root=data_root, ignore_label=config.ignore_label, temporal_dilation=config.temporal_dilation, temporal_numseq=config.temporal_numseq, return_transformation=config.return_transformation, augment_data=augment_data, - elastic_distortion=elastic_distortion, config=config) def load_world_pointcloud(self, filename): diff --git a/lib/pc_utils.py b/lib/pc_utils.py index cb382b1..77a3212 100644 --- a/lib/pc_utils.py +++ b/lib/pc_utils.py @@ -1,10 +1,8 @@ import os -import logging import numpy as np from numpy.linalg import matrix_rank, inv from plyfile import PlyData, PlyElement import pandas as pd -from 
retrying import retry
 
 COLOR_MAP_RGB = (
     (241, 255, 82),
@@ -27,16 +25,6 @@
 IGNORE_COLOR = (0, 0, 0)
 
 
-def retry_on_ioerror(exc):
-  logging.warning("Retrying file load")
-  return isinstance(exc, IOError)
-
-
-@retry(
-    retry_on_exception=retry_on_ioerror,
-    wait_exponential_multiplier=1000,
-    wait_exponential_max=10000,
-    stop_max_delay=30000)
 def read_plyfile(filepath):
   """Read ply file and return it as numpy array. Returns None if emtpy."""
   with open(filepath, 'rb') as f:
diff --git a/lib/voxelizer.py b/lib/voxelizer.py
index a7a87de..3ef059f 100644
--- a/lib/voxelizer.py
+++ b/lib/voxelizer.py
@@ -132,6 +132,89 @@ def voxelize(self, coords, feats, labels, center=None):
 
     return coords_aug, feats, labels, rigid_transformation.flatten()
 
+  def voxelize_temporal(self,
+                        coords_t,
+                        feats_t,
+                        labels_t,
+                        centers=None,
+                        return_transformation=False):
+    """Voxelize a sequence of point clouds with one shared transformation.
+
+    The matrices returned by `get_transformation_matrix()` are obtained once
+    and reused for every frame. Each frame is optionally clipped, transformed,
+    floored to integer coordinates and deduplicated with
+    `ME.utils.sparse_quantize`.
+
+    Args:
+      coords_t: sequence of per-frame coordinate arrays (presumably N x 3 --
+        TODO confirm against callers).
+      feats_t: sequence of per-frame feature arrays, indexed like the coords.
+      labels_t: sequence of per-frame label arrays; an entry may be None.
+      centers: optional sequence of per-frame centers forwarded to `clip`.
+        When None, None is forwarded for every frame.
+      return_transformation: if True, also return the flattened
+        transformation matrix used for each frame.
+
+    Returns:
+      Tuple `(coords, feats, labels)` of per-frame lists, followed by a list
+      of flattened transformation matrices if `return_transformation` is set.
+    """
+    if centers is None:
+      centers = [None] * len(coords_t)
+    coords_tc, feats_tc, labels_tc, transformation_tc = [], [], [], []
+
+    # ######################### Data Augmentation #############################
+    # Get rotation and scale. Computed once, outside the frame loop, so every frame uses the
+    # same matrices.
+    M_v, M_r = self.get_transformation_matrix()
+    # Apply transformations
+    rigid_transformation = M_v
+    if self.use_augmentation:
+      rigid_transformation = M_r @ rigid_transformation
+    # ######################### Voxelization #############################
+    # Voxelize coords
+    for coords, feats, labels, center in zip(coords_t, feats_t, labels_t, centers):
+
+      ###################################
+      # Clip the data if bound exists
+      if self.clip_bound is not None:
+        trans_aug_ratio = np.zeros(3)
+        # NOTE(review): the translation ratio is re-sampled for every frame, so frames of one
+        # sequence can be clipped with different offsets -- confirm this is intended.
+        if self.use_augmentation and self.translation_augmentation_ratio_bound is not None:
+          for axis_ind, trans_ratio_bound in enumerate(self.translation_augmentation_ratio_bound):
+            trans_aug_ratio[axis_ind] = np.random.uniform(*trans_ratio_bound)
+
+        clip_inds = self.clip(coords, center, trans_aug_ratio)
+        coords, feats = coords[clip_inds], feats[clip_inds]
+        if labels is not None:
+          labels = labels[clip_inds]
+      ###################################
+
+      homo_coords = np.hstack((coords, np.ones((coords.shape[0], 1), dtype=coords.dtype)))
+      coords_aug = np.floor(homo_coords @ rigid_transformation.T)[:, :3]
+
+      inds = ME.utils.sparse_quantize(coords_aug, return_index=True)
+      coords_aug, feats = coords_aug[inds], feats[inds]
+      # Labels are optional (the clipping branch above already allows None), so guard here too.
+      if labels is not None:
+        labels = labels[inds]
+
+      # If use normal rotation
+      if feats.shape[1] > 6:
+        feats[:, 3:6] = feats[:, 3:6] @ (M_r[:3, :3].T)
+
+      coords_tc.append(coords_aug)
+      feats_tc.append(feats)
+      labels_tc.append(labels)
+      transformation_tc.append(rigid_transformation.flatten())
+
+    return_args = [coords_tc, feats_tc, labels_tc]
+    if return_transformation:
+      return_args.append(transformation_tc)
+
+    return tuple(return_args)
+
 
 def test():
   N = 16575
diff --git a/models/res16unet.py b/models/res16unet.py
index 7006422..765fab2 100644
--- a/models/res16unet.py
+++ b/models/res16unet.py
@@ -1,6 +1,6 @@
 from models.resnet import ResNetBase, get_norm
 from models.modules.common import ConvType, NormType, conv, conv_tr
-from models.modules.resnet_block import BasicBlock, Bottleneck, BasicBlockIN, BottleneckIN, BasicBlockLN
+from models.modules.resnet_block import BasicBlock, Bottleneck
 
 from MinkowskiEngine import MinkowskiReLU
 import MinkowskiEngine.MinkowskiOps as me
@@ -331,67 +331,6 @@ class Res16UNet34C(Res16UNet34):
   PLANES = (32, 64, 128, 256, 256, 128, 96, 96)
 
 
-# Experimentally, worse than others
-class Res16UNetLN14(Res16UNet14):
-  NORM_TYPE = NormType.SPARSE_LAYER_NORM
-  BLOCK = BasicBlockLN
-
-
-class Res16UNetTemporalBase(Res16UNetBase):
-  """
-  Res16UNet that can take 4D independently. No temporal convolution.
- """ - CONV_TYPE = ConvType.SPATIAL_HYPERCUBE - - def __init__(self, in_channels, out_channels, config, D=4, **kwargs): - super(Res16UNetTemporalBase, self).__init__(in_channels, out_channels, config, D, **kwargs) - - -class Res16UNetTemporal14(Res16UNet14, Res16UNetTemporalBase): - pass - - -class Res16UNetTemporal18(Res16UNet18, Res16UNetTemporalBase): - pass - - -class Res16UNetTemporal34(Res16UNet34, Res16UNetTemporalBase): - pass - - -class Res16UNetTemporal50(Res16UNet50, Res16UNetTemporalBase): - pass - - -class Res16UNetTemporal101(Res16UNet101, Res16UNetTemporalBase): - pass - - -class Res16UNetTemporalIN14(Res16UNetTemporal14): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BasicBlockIN - - -class Res16UNetTemporalIN18(Res16UNetTemporal18): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BasicBlockIN - - -class Res16UNetTemporalIN34(Res16UNetTemporal34): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BasicBlockIN - - -class Res16UNetTemporalIN50(Res16UNetTemporal50): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BottleneckIN - - -class Res16UNetTemporalIN101(Res16UNetTemporal101): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BottleneckIN - - class STRes16UNetBase(Res16UNetBase): CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS @@ -404,6 +343,10 @@ class STRes16UNet14(STRes16UNetBase, Res16UNet14): pass +class STRes16UNet14A(STRes16UNetBase, Res16UNet14A): + pass + + class STRes16UNet18(STRes16UNetBase, Res16UNet18): pass diff --git a/models/resunet.py b/models/resunet.py index 2f15dfc..012e19b 100644 --- a/models/resunet.py +++ b/models/resunet.py @@ -1,6 +1,6 @@ from models.resnet import ResNetBase, get_norm from models.modules.common import ConvType, NormType, conv, conv_tr -from models.modules.resnet_block import BasicBlock, Bottleneck, BasicBlockSN, BottleneckSN, BasicBlockIN, BottleneckIN, BasicBlockLN +from models.modules.resnet_block import BasicBlock, Bottleneck from MinkowskiEngine import 
MinkowskiReLU import MinkowskiEngine.MinkowskiOps as me @@ -257,122 +257,6 @@ class ResUNet34F(ResUNet34): PLANES = (32, 64, 128, 256, 128, 64, 32) -class ResUNetSN14(ResUNet14): - NORM_TYPE = NormType.SPARSE_SWITCH_NORM - BLOCK = BasicBlockSN - - -class ResUNetSN18(ResUNet18): - NORM_TYPE = NormType.SPARSE_SWITCH_NORM - BLOCK = BasicBlockSN - - -class ResUNetSN34(ResUNet34): - NORM_TYPE = NormType.SPARSE_SWITCH_NORM - BLOCK = BasicBlockSN - - -class ResUNetSN50(ResUNet50): - NORM_TYPE = NormType.SPARSE_SWITCH_NORM - BLOCK = BottleneckSN - - -class ResUNetSN101(ResUNet101): - NORM_TYPE = NormType.SPARSE_SWITCH_NORM - BLOCK = BottleneckSN - - -class ResUNetIN14(ResUNet14): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BasicBlockIN - - -class ResUNetIN18(ResUNet18): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BasicBlockIN - - -class ResUNetIN34(ResUNet34): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BasicBlockIN - - -class ResUNetIN34E(ResUNet34E): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BasicBlockIN - - -class ResUNetIN50(ResUNet50): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BottleneckIN - - -class ResUNetIN101(ResUNet101): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BottleneckIN - - -# Experimentally, worse than others -class ResUNetLN14(ResUNet14): - NORM_TYPE = NormType.SPARSE_LAYER_NORM - BLOCK = BasicBlockLN - - -class ResUNetTemporalBase(ResUNetBase): - """ - ResUNet that can take 4D independently. No temporal convolution. 
- """ - CONV_TYPE = ConvType.SPATIAL_HYPERCUBE - - def __init__(self, in_channels, out_channels, config, D=4, **kwargs): - super(ResUNetTemporalBase, self).__init__(in_channels, out_channels, config, D, **kwargs) - - -class ResUNetTemporal14(ResUNet14, ResUNetTemporalBase): - pass - - -class ResUNetTemporal18(ResUNet18, ResUNetTemporalBase): - pass - - -class ResUNetTemporal34(ResUNet34, ResUNetTemporalBase): - pass - - -class ResUNetTemporal50(ResUNet50, ResUNetTemporalBase): - pass - - -class ResUNetTemporal101(ResUNet101, ResUNetTemporalBase): - pass - - -class ResUNetTemporalIN14(ResUNetTemporal14): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BasicBlockIN - - -class ResUNetTemporalIN18(ResUNetTemporal18): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BasicBlockIN - - -class ResUNetTemporalIN34(ResUNetTemporal34): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BasicBlockIN - - -class ResUNetTemporalIN50(ResUNetTemporal50): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BottleneckIN - - -class ResUNetTemporalIN101(ResUNetTemporal101): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BottleneckIN - - class STResUNetBase(ResUNetBase): CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS @@ -401,31 +285,6 @@ class STResUNet101(STResUNetBase, ResUNet101): pass -class STResUNetIN14(STResUNet14): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BasicBlockIN - - -class STResUNetIN18(STResUNet18): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BasicBlockIN - - -class STResUNetIN34(STResUNet34): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BasicBlockIN - - -class STResUNetIN50(STResUNet50): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BottleneckIN - - -class STResUNetIN101(STResUNet101): - NORM_TYPE = NormType.SPARSE_INSTANCE_NORM - BLOCK = BottleneckIN - - class STResTesseractUNetBase(STResUNetBase): CONV_TYPE = ConvType.HYPERCUBE diff --git a/scripts/train_synthia4d.sh b/scripts/train_synthia4d.sh 
new file mode 100755
index 0000000..880f1d1
--- /dev/null
+++ b/scripts/train_synthia4d.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+# Usage: train_synthia4d.sh GPU_IDS [LOG_SUFFIX] [EXTRA_MAIN_ARGS]
+#   $1: value exported as CUDA_VISIBLE_DEVICES
+#   $2: suffix appended to ./outputs/Synthia4D to form the log directory
+#   $3: extra arguments appended to the main.py command line (unquoted, so it word-splits)
+
+set -x
+# Exit script when a command returns nonzero state
+set -e
+
+set -o pipefail
+
+export PYTHONUNBUFFERED="True"
+export CUDA_VISIBLE_DEVICES=$1
+
+export BATCH_SIZE=${BATCH_SIZE:-9}
+export DATASET=${DATASET:-SynthiaCVPR15cmVoxelizationDataset}
+export MODEL=${MODEL:-Res16UNet14A}
+
+export TIME=$(date +"%Y-%m-%d_%H-%M-%S")
+
+export LOG_DIR=./outputs/Synthia4D$2/$TIME
+
+# Save the experiment detail and dir to the common log file
+mkdir -p "$LOG_DIR"
+
+LOG="$LOG_DIR/$TIME.txt"
+
+python main.py \
+    --log_dir "$LOG_DIR" \
+    --dataset "$DATASET" \
+    --model "$MODEL" \
+    --lr 1e-1 \
+    --batch_size "$BATCH_SIZE" \
+    --scheduler PolyLR \
+    --max_iter 120000 \
+    --train_limit_numpoints 1200000 \
+    --train_phase train \
+    $3 2>&1 | tee -a "$LOG"
diff --git a/scripts/train_synthia4d_temporal.sh b/scripts/train_synthia4d_temporal.sh
new file mode 100755
index 0000000..204a650
--- /dev/null
+++ b/scripts/train_synthia4d_temporal.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+# Usage: train_synthia4d_temporal.sh GPU_IDS UNUSED [EXTRA_MAIN_ARGS]
+#   $1: value exported as CUDA_VISIBLE_DEVICES
+#   $2: not read by this script
+#   $3: extra arguments appended to the main.py command line (unquoted, so it word-splits)
+
+set -x
+# Exit script when a command returns nonzero state
+set -e
+
+set -o pipefail
+
+export PYTHONUNBUFFERED="True"
+export CUDA_VISIBLE_DEVICES=$1
+
+export BATCH_SIZE=${BATCH_SIZE:-9}
+export DATASET=${DATASET:-SynthiaTemporalVoxelizationDataset}
+export MODEL=${MODEL:-STRes16UNet14A}
+
+export TIME=$(date +"%Y-%m-%d_%H-%M-%S")
+
+export LOG_DIR=./outputs/$DATASET/$TIME
+
+# Save the experiment detail and dir to the common log file
+mkdir -p "$LOG_DIR"
+
+LOG="$LOG_DIR/$TIME.txt"
+
+python main.py \
+    --log_dir "$LOG_DIR" \
+    --dataset "$DATASET" \
+    --model "$MODEL" \
+    --lr 1e-1 \
+    --batch_size "$BATCH_SIZE" \
+    --scheduler PolyLR \
+    --max_iter 120000 \
+    --train_limit_numpoints 1500000 \
+    --train_phase train \
+    $3 2>&1 | tee -a "$LOG"