Skip to content

Commit

Permalink
synthia temporal, network cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Nov 5, 2019
1 parent 1fa6da7 commit 36d67fe
Show file tree
Hide file tree
Showing 10 changed files with 180 additions and 236 deletions.
18 changes: 14 additions & 4 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def add_argument_group(name):

# Network
net_arg = add_argument_group('Network')
net_arg.add_argument(
'--model', type=str, default='ResUNet14', help='Model name')
net_arg.add_argument('--model', type=str, default='ResUNet14', help='Model name')
net_arg.add_argument(
'--conv1_kernel_size', type=int, default=3, help='First layer conv kernel size')
net_arg.add_argument('--weights', type=str, default='None', help='Saved weights to load')
Expand Down Expand Up @@ -117,8 +116,16 @@ def add_argument_group(name):
data_arg.add_argument(
'--synthia_path',
type=str,
default='/home/chrischoy/datasets/synthia_preprocessed',
default='/home/chrischoy/datasets/Synthia/Synthia4D',
help='Point Cloud dataset root dir')
# For temporal sequences
data_arg.add_argument(
'--synthia_camera_path', type=str, default='/home/chrischoy/datasets/Synthia/%s/CameraParams/')
data_arg.add_argument('--synthia_camera_intrinsic_file', type=str, default='intrinsics.txt')
data_arg.add_argument(
'--synthia_camera_extrinsics_file', type=str, default='Stereo_Right/Omni_F/%s.txt')
data_arg.add_argument('--temporal_rand_dilation', type=str2bool, default=False)
data_arg.add_argument('--temporal_rand_numseq', type=str2bool, default=False)

data_arg.add_argument(
'--scannet_path',
Expand Down Expand Up @@ -179,7 +186,10 @@ def add_argument_group(name):
data_aug_arg.add_argument(
'--data_aug_hue_max', type=float, default=0.5, help='Hue translation range. [0, 1]')
data_aug_arg.add_argument(
'--data_aug_saturation_max', type=float, default=0.20, help='Saturation translation range, [0, 1]')
'--data_aug_saturation_max',
type=float,
default=0.20,
help='Saturation translation range, [0, 1]')

# Test
test_arg = add_argument_group('Test')
Expand Down
26 changes: 18 additions & 8 deletions lib/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ class VoxelizationDataset(VoxelizationDatasetBase):
ELASTIC_DISTORT_PARAMS = None

# MISC.
PREVOXELIZE_VOXEL_SIZE = None
PREVOXELIZATION_VOXEL_SIZE = None

def __init__(self,
data_paths,
Expand Down Expand Up @@ -252,9 +252,9 @@ def __getitem__(self, index):
pointcloud, center = self.load_ply(index)

# Downsample the pointcloud with finer voxel size before transformation for memory and speed
if self.PREVOXELIZE_VOXEL_SIZE is not None:
if self.PREVOXELIZATION_VOXEL_SIZE is not None:
inds = ME.utils.sparse_quantize(
pointcloud[:, :3] / self.PREVOXELIZE_VOXEL_SIZE, return_index=True)
pointcloud[:, :3] / self.PREVOXELIZATION_VOXEL_SIZE, return_index=True)
pointcloud = pointcloud[inds]

# Prevoxel transformations
Expand Down Expand Up @@ -296,9 +296,18 @@ def __init__(self,
augment_data=False,
config=None,
**kwargs):
VoxelizationDataset.__init__(self, data_paths, input_transform, target_transform, data_root,
ignore_label, return_transformation, augment_data, config,
**kwargs)
VoxelizationDataset.__init__(
self,
data_paths,
prevoxel_transform=prevoxel_transform,
input_transform=input_transform,
target_transform=target_transform,
data_root=data_root,
ignore_label=ignore_label,
return_transformation=return_transformation,
augment_data=augment_data,
config=config,
**kwargs)
self.temporal_dilation = temporal_dilation
self.temporal_numseq = temporal_numseq
temporal_window = temporal_dilation * (temporal_numseq - 1) + 1
Expand Down Expand Up @@ -333,10 +342,11 @@ def __getitem__(self, index):
ptcs, centers = zip(*world_pointclouds)

# Downsample pointcloud for speed and memory
if self.PREVOXELIZE_VOXEL_SIZE is not None:
if self.PREVOXELIZATION_VOXEL_SIZE is not None:
new_ptcs = []
for ptc in ptcs:
inds = ME.utils.sparse_quantize(ptc[:, :3] / self.PREVOXELIZE_VOXEL_SIZE, return_index=True)
inds = ME.utils.sparse_quantize(
ptc[:, :3] / self.PREVOXELIZATION_VOXEL_SIZE, return_index=True)
new_ptcs.append(ptc[inds])
ptcs = new_ptcs

Expand Down
5 changes: 3 additions & 2 deletions lib/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from .synthia import SynthiaCVPR15cmVoxelizationDataset, SynthiaCVPR30cmVoxelizationDataset, \
SynthiaAllSequencesVoxelizationDataset
SynthiaAllSequencesVoxelizationDataset, SynthiaTemporalVoxelizationDataset
from .stanford import StanfordVoxelizationDataset, StanfordVoxelization2cmDataset
from .scannet import ScannetVoxelizationDataset, ScannetVoxelization2cmDataset

DATASETS = [
StanfordVoxelizationDataset, StanfordVoxelization2cmDataset, ScannetVoxelizationDataset,
ScannetVoxelization2cmDataset, SynthiaCVPR15cmVoxelizationDataset,
SynthiaCVPR30cmVoxelizationDataset, SynthiaAllSequencesVoxelizationDataset
SynthiaCVPR30cmVoxelizationDataset, SynthiaTemporalVoxelizationDataset,
SynthiaAllSequencesVoxelizationDataset
]


Expand Down
18 changes: 12 additions & 6 deletions lib/datasets/synthia.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class SynthiaVoxelizationDataset(VoxelizationDataset):
TEST_CLIP_BOUND = ((-2500, 2500), (-2500, 2500), (-2500, 2500))
VOXEL_SIZE = 15 # cm

PREVOXELIZE_VOXEL_SIZE = 7.5
PREVOXELIZATION_VOXEL_SIZE = 7.5
# Elastic distortion, (granularity, magnitude) pairs
ELASTIC_DISTORT_PARAMS = ((80, 300),)

Expand Down Expand Up @@ -166,7 +166,7 @@ class SynthiaTemporalVoxelizationDataset(TemporalVoxelizationDataset):
TEST_CLIP_BOUND = ((-2500, 2500), (-2500, 2500), (-2500, 2500))
VOXEL_SIZE = 15 # cm

PREVOXELIZE_VOXEL_SIZE = 7.5
PREVOXELIZATION_VOXEL_SIZE = 7.5
# For temporal sequences, the voxel locations have to be aligned exactly.
ELASTIC_DISTORT_PARAMS = None

Expand All @@ -179,21 +179,27 @@ class SynthiaTemporalVoxelizationDataset(TemporalVoxelizationDataset):
NUM_LABELS = 16 # Automatically subtract ignore labels after processed
IGNORE_LABELS = (0, 1, 13, 14) # void, sky, reserved, reserved

# Split used in the Minkowski ConvNet, CVPR'19
DATA_PATH_FILE = {
DatasetPhase.Train: 'train_cvpr19.txt',
DatasetPhase.Val: 'val_cvpr19.txt',
DatasetPhase.Test: 'test_cvpr19.txt'
}

def __init__(self,
config,
prevoxel_transform=None,
input_transform=None,
target_transform=None,
augment_data=True,
elastic_distortion=False,
cache=False,
phase=DatasetPhase.Train):
if isinstance(phase, str):
phase = str2datasetphase_type(phase)
if phase not in [DatasetPhase.Train, DatasetPhase.TrainVal]:
self.CLIP_BOUND = self.TEST_CLIP_BOUND
data_root = config.synthia_path
data_paths = read_txt(osp.join(data_root, self.DATA_PATH_FILE[phase]))
data_paths = read_txt(osp.join('./splits/synthia4d', self.DATA_PATH_FILE[phase]))
data_paths = sorted([d.split()[0] for d in data_paths])
seq2files = defaultdict(list)
for f in data_paths:
Expand All @@ -211,15 +217,15 @@ def __init__(self,
TemporalVoxelizationDataset.__init__(
self,
file_seq_list,
data_root=data_root,
prevoxel_transform=prevoxel_transform,
input_transform=input_transform,
target_transform=target_transform,
data_root=data_root,
ignore_label=config.ignore_label,
temporal_dilation=config.temporal_dilation,
temporal_numseq=config.temporal_numseq,
return_transformation=config.return_transformation,
augment_data=augment_data,
elastic_distortion=elastic_distortion,
config=config)

def load_world_pointcloud(self, filename):
Expand Down
12 changes: 0 additions & 12 deletions lib/pc_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import os
import logging
import numpy as np
from numpy.linalg import matrix_rank, inv
from plyfile import PlyData, PlyElement
import pandas as pd
from retrying import retry

COLOR_MAP_RGB = (
(241, 255, 82),
Expand All @@ -27,16 +25,6 @@
IGNORE_COLOR = (0, 0, 0)


def retry_on_ioerror(exc):
    """Decide whether a failed file load should be retried.

    Used as the ``retry_on_exception`` predicate of the ``@retry`` decorator.

    Args:
        exc: The exception raised by the decorated call.

    Returns:
        True if ``exc`` is an ``IOError`` (the call should be retried),
        False otherwise.
    """
    should_retry = isinstance(exc, IOError)
    # Only announce a retry when one will actually happen; logging for every
    # exception type would be misleading for errors that are re-raised.
    if should_retry:
        logging.warning("Retrying file load")
    return should_retry


@retry(
retry_on_exception=retry_on_ioerror,
wait_exponential_multiplier=1000,
wait_exponential_max=10000,
stop_max_delay=30000)
def read_plyfile(filepath):
"""Read ply file and return it as numpy array. Returns None if empty."""
with open(filepath, 'rb') as f:
Expand Down
57 changes: 57 additions & 0 deletions lib/voxelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,63 @@ def voxelize(self, coords, feats, labels, center=None):

return coords_aug, feats, labels, rigid_transformation.flatten()

def voxelize_temporal(self,
                      coords_t,
                      feats_t,
                      labels_t,
                      centers=None,
                      return_transformation=False):
    """Voxelize a sequence of point clouds with one shared transformation.

    A single transformation (voxelization matrix, optionally composed with
    the augmentation matrix) is drawn once and applied to every frame. Each
    frame is then clipped (if a clip bound is set), transformed, floored and
    de-duplicated with ``ME.utils.sparse_quantize``.

    Args:
        coords_t: Sequence of per-frame coordinate arrays
            (assumes each is (N, 3) — TODO confirm against caller).
        feats_t: Sequence of per-frame feature arrays, row-aligned with the
            coordinates.
        labels_t: Sequence of per-frame label arrays; an entry may be None.
        centers: Optional sequence of per-frame centers passed to
            ``self.clip``. Defaults to None for every frame.
        return_transformation: When True, also return the flattened
            transformation matrix used for each frame.

    Returns:
        ``(coords_tc, feats_tc, labels_tc)`` as lists with one entry per
        frame, plus ``transformation_tc`` as a fourth element when
        ``return_transformation`` is True.
    """
    # Legacy code, remove
    if centers is None:
        centers = [None] * len(coords_t)
    coords_tc, feats_tc, labels_tc, transformation_tc = [], [], [], []

    # ######################### Data Augmentation #############################
    # Get rotation and scale
    M_v, M_r = self.get_transformation_matrix()
    # Apply transformations
    rigid_transformation = M_v
    if self.use_augmentation:
        rigid_transformation = M_r @ rigid_transformation

    # ######################### Voxelization #############################
    # Voxelize coords
    for coords, feats, labels, center in zip(coords_t, feats_t, labels_t, centers):

        # Clip the data if bound exists
        if self.clip_bound is not None:
            trans_aug_ratio = np.zeros(3)
            if self.use_augmentation and self.translation_augmentation_ratio_bound is not None:
                for axis_ind, trans_ratio_bound in enumerate(
                        self.translation_augmentation_ratio_bound):
                    trans_aug_ratio[axis_ind] = np.random.uniform(*trans_ratio_bound)

            # NOTE(review): assumes self.clip always returns an index/mask
            # array usable for indexing — confirm against its definition.
            clip_inds = self.clip(coords, center, trans_aug_ratio)
            coords, feats = coords[clip_inds], feats[clip_inds]
            if labels is not None:
                labels = labels[clip_inds]

        # Append a column of ones so the 4x4 matrix applies in one matmul.
        homo_coords = np.hstack((coords, np.ones((coords.shape[0], 1), dtype=coords.dtype)))
        coords_aug = np.floor(homo_coords @ rigid_transformation.T)[:, :3]

        inds = ME.utils.sparse_quantize(coords_aug, return_index=True)
        coords_aug, feats = coords_aug[inds], feats[inds]
        # Labels are optional (see the clip branch above); indexing None
        # would raise TypeError, so only subsample when they exist.
        if labels is not None:
            labels = labels[inds]

        # If use normal rotation
        if feats.shape[1] > 6:
            feats[:, 3:6] = feats[:, 3:6] @ (M_r[:3, :3].T)

        coords_tc.append(coords_aug)
        feats_tc.append(feats)
        labels_tc.append(labels)
        transformation_tc.append(rigid_transformation.flatten())

    return_args = [coords_tc, feats_tc, labels_tc]
    if return_transformation:
        return_args.append(transformation_tc)

    return tuple(return_args)


def test():
N = 16575
Expand Down
67 changes: 5 additions & 62 deletions models/res16unet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from models.resnet import ResNetBase, get_norm
from models.modules.common import ConvType, NormType, conv, conv_tr
from models.modules.resnet_block import BasicBlock, Bottleneck, BasicBlockIN, BottleneckIN, BasicBlockLN
from models.modules.resnet_block import BasicBlock, Bottleneck

from MinkowskiEngine import MinkowskiReLU
import MinkowskiEngine.MinkowskiOps as me
Expand Down Expand Up @@ -331,67 +331,6 @@ class Res16UNet34C(Res16UNet34):
PLANES = (32, 64, 128, 256, 256, 128, 96, 96)


# Experimentally, worse than others
class Res16UNetLN14(Res16UNet14):
    """Res16UNet14 with BasicBlockLN blocks and NormType.SPARSE_LAYER_NORM."""

    BLOCK = BasicBlockLN
    NORM_TYPE = NormType.SPARSE_LAYER_NORM


class Res16UNetTemporalBase(Res16UNetBase):
    """Res16UNet base for 4D input where frames are processed independently.

    Sets CONV_TYPE to ConvType.SPATIAL_HYPERCUBE (no temporal convolution)
    and defaults the dimension argument ``D`` to 4.
    """

    CONV_TYPE = ConvType.SPATIAL_HYPERCUBE

    def __init__(self, in_channels, out_channels, config, D=4, **kwargs):
        # Forward everything unchanged; only the default of D differs here.
        super().__init__(in_channels, out_channels, config, D, **kwargs)


class Res16UNetTemporal14(Res16UNet14, Res16UNetTemporalBase):
    """Res16UNet14 combined with Res16UNetTemporalBase; adds no members."""


class Res16UNetTemporal18(Res16UNet18, Res16UNetTemporalBase):
    """Res16UNet18 combined with Res16UNetTemporalBase; adds no members."""


class Res16UNetTemporal34(Res16UNet34, Res16UNetTemporalBase):
    """Res16UNet34 combined with Res16UNetTemporalBase; adds no members."""


class Res16UNetTemporal50(Res16UNet50, Res16UNetTemporalBase):
    """Res16UNet50 combined with Res16UNetTemporalBase; adds no members."""


class Res16UNetTemporal101(Res16UNet101, Res16UNetTemporalBase):
    """Res16UNet101 combined with Res16UNetTemporalBase; adds no members."""


class Res16UNetTemporalIN14(Res16UNetTemporal14):
    """Res16UNetTemporal14 with BasicBlockIN and NormType.SPARSE_INSTANCE_NORM."""

    BLOCK = BasicBlockIN
    NORM_TYPE = NormType.SPARSE_INSTANCE_NORM


class Res16UNetTemporalIN18(Res16UNetTemporal18):
    """Res16UNetTemporal18 with BasicBlockIN and NormType.SPARSE_INSTANCE_NORM."""

    BLOCK = BasicBlockIN
    NORM_TYPE = NormType.SPARSE_INSTANCE_NORM


class Res16UNetTemporalIN34(Res16UNetTemporal34):
    """Res16UNetTemporal34 with BasicBlockIN and NormType.SPARSE_INSTANCE_NORM."""

    BLOCK = BasicBlockIN
    NORM_TYPE = NormType.SPARSE_INSTANCE_NORM


class Res16UNetTemporalIN50(Res16UNetTemporal50):
    """Res16UNetTemporal50 with BottleneckIN and NormType.SPARSE_INSTANCE_NORM."""

    BLOCK = BottleneckIN
    NORM_TYPE = NormType.SPARSE_INSTANCE_NORM


class Res16UNetTemporalIN101(Res16UNetTemporal101):
    """Res16UNetTemporal101 with BottleneckIN and NormType.SPARSE_INSTANCE_NORM."""

    BLOCK = BottleneckIN
    NORM_TYPE = NormType.SPARSE_INSTANCE_NORM


class STRes16UNetBase(Res16UNetBase):

CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS
Expand All @@ -404,6 +343,10 @@ class STRes16UNet14(STRes16UNetBase, Res16UNet14):
pass


class STRes16UNet14A(STRes16UNetBase, Res16UNet14A):
    """STRes16UNetBase combined with Res16UNet14A; adds no members."""


class STRes16UNet18(STRes16UNetBase, Res16UNet18):
    """STRes16UNetBase combined with Res16UNet18; adds no members."""

Expand Down
Loading

0 comments on commit 36d67fe

Please sign in to comment.