Commit: stanford training

chrischoy committed Jan 14, 2020
1 parent c6e421a commit c4140c7
Showing 23 changed files with 923 additions and 161 deletions.
29 changes: 29 additions & 0 deletions README.md
@@ -63,9 +63,38 @@ export BATCH_SIZE=N; \
The above script trains a network. You have to change the arguments accordingly. The first argument to the script is the GPU id. The second argument is the log directory postfix; change it to mark your experimental setup. The final argument is a series of miscellaneous arguments: specify the Synthia directory here, and wrap all of these arguments in double quotes. An invocation sketch follows.
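For example, an invocation might look like the following (a sketch: the script name `train_synthia.sh` and the `--synthia_path` flag are assumptions, not verified against the repo; replace the path with your own):

```
export BATCH_SIZE=8; \
./scripts/train_synthia.sh 0 \
	"-default" \
	"--synthia_path /path/to/synthia"
```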


## Stanford 3D Dataset

1. Download the Stanford 3D dataset from [the website](http://buildingparser.stanford.edu/dataset.html).

2. Preprocess

Modify the input and output directories accordingly (sketched after the command below) in

`lib/datasets/preprocessing/stanford.py`

And run

```
python -m lib.datasets.preprocessing.stanford
```
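
The variables to edit look something like the following (hypothetical names for illustration; check the top of `lib/datasets/preprocessing/stanford.py` for the actual ones):

```
STANFORD_3D_IN_PATH = '/path/to/Stanford3dDataset_v1.2/'
STANFORD_3D_OUT_PATH = '/path/to/Stanford3D_processed/'
```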

3. Train

Modify the Stanford 3D path in the script and run

```
./scripts/train_stanford.sh 0 \
"-default" \
""
```
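
Here the third argument is empty, so the defaults in `config.py` are used; the Stanford path itself is set inside the script.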

## Model Zoo

| Model | Dataset | Voxel Size | Conv1 Kernel Size | Performance | Link |
|:-------------:|:-------------------:|:----------:|:-----------------:|:------------------------:|:------:|
| Mink16UNet34C | ScanNet train + val | 2cm | 3 | Test set 73.6% mIoU | [download](https://node1.chrischoy.org/data/publications/minknet/Mink16UNet34C_ScanNet.pth) |
| Mink16UNet34C | ScanNet train | 2cm | 5 | Val 72.219% mIoU without rotation average [per class performance](https://github.com/chrischoy/SpatioTemporalSegmentation/issues/13) | [download](https://node1.chrischoy.org/data/publications/minknet/MinkUNet34C-train-conv1-5.pth) |
| Mink16UNet18 | Stanford Area5 train | 5cm | 5 | Area 5 test 65.483% mIoU w/o rotation average, no sliding window | [download](https://node1.chrischoy.org/data/publications/minknet/Mink16UNet18_stanford-conv1-5.pth) |

Note that the sliding-window style evaluation (cropping and stitching results) used in many related works effectively acts as an ensemble (rotation averaging), which boosts performance.
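
For concreteness, rotation-averaged evaluation amounts to averaging predictions over rotated copies of the input. A minimal sketch (the `model(coords, feats)` interface is hypothetical, not this repo's API):

```
import numpy as np

def rotation_averaged_logits(model, coords, feats, num_rotations=8):
  # Average class scores over rotations about the gravity (z) axis;
  # this is a test-time ensemble, hence the performance boost.
  avg = None
  for i in range(num_rotations):
    theta = 2 * np.pi * i / num_rotations
    c, s = np.cos(theta), np.sin(theta)
    R = np.array([[c, -s, 0.], [s, c, 0.], [0., 0., 1.]])
    logits = model(coords @ R.T, feats)  # hypothetical model call
    avg = logits if avg is None else avg + logits
  return avg / num_rotations
```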
14 changes: 4 additions & 10 deletions config.py
@@ -93,7 +93,7 @@ def add_argument_group(name):

# Data
data_arg = add_argument_group('Data')
-data_arg.add_argument('--dataset', type=str, default='ScannetSparseVoxelization2cmDataset')
+data_arg.add_argument('--dataset', type=str, default='ScannetVoxelization2cmDataset')
data_arg.add_argument('--temporal_dilation', type=int, default=30)
data_arg.add_argument('--temporal_numseq', type=int, default=3)
data_arg.add_argument('--point_lim', type=int, default=-1)
@@ -134,15 +134,15 @@ def add_argument_group(name):
help='Scannet online voxelization dataset root dir')

data_arg.add_argument(
-    '--stanford3d_online_path',
+    '--stanford3d_path',
    type=str,
-    default='/home/chrischoy/datasets/stanford_preprocessed',
+    default='/home/chrischoy/datasets/Stanford3D',
help='Stanford precropped dataset root dir')
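
With the renamed flag, pointing training at the Stanford data might look like this (a sketch; `main.py` as the entry point is an assumption):

```
python main.py --dataset StanfordVoxelization2cmDataset \
    --stanford3d_path /home/chrischoy/datasets/Stanford3D
```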

# Training / test parameters
train_arg = add_argument_group('Training')
train_arg.add_argument('--is_train', type=str2bool, default=True)
-train_arg.add_argument('--stat_freq', type=int, default=10, help='print frequency')
+train_arg.add_argument('--stat_freq', type=int, default=40, help='print frequency')
train_arg.add_argument('--test_stat_freq', type=int, default=100, help='print frequency')
train_arg.add_argument('--save_freq', type=int, default=1000, help='save frequency')
train_arg.add_argument('--val_freq', type=int, default=1000, help='validation frequency')
@@ -174,12 +174,6 @@ def add_argument_group(name):
'--data_aug_color_trans_ratio', type=float, default=0.10, help='Color translation range')
data_aug_arg.add_argument(
'--data_aug_color_jitter_std', type=float, default=0.05, help='STD of color jitter')
-data_aug_arg.add_argument(
-    '--data_aug_height_trans_std', type=float, default=1, help='STD of height translation')
-data_aug_arg.add_argument(
-    '--data_aug_height_jitter_std', type=float, default=0.1, help='STD of height jitter')
-data_aug_arg.add_argument(
-    '--data_aug_normal_jitter_std', type=float, default=0.01, help='STD of normal jitter')
data_aug_arg.add_argument('--normalize_color', type=str2bool, default=True)
data_aug_arg.add_argument('--data_aug_scale_min', type=float, default=0.9)
data_aug_arg.add_argument('--data_aug_scale_max', type=float, default=1.1)
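
For reference, the scale-range arguments are typically consumed like this (a sketch, not the repo's exact voxelizer code):

```
import numpy as np

def random_scale(coords, scale_min=0.9, scale_max=1.1):
  # Sample one global scale factor and apply it to all coordinates.
  s = np.random.uniform(scale_min, scale_max)
  return coords * s
```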
1 change: 1 addition & 0 deletions lib/__init__.py
@@ -0,0 +1 @@
+import open3d as o3d
37 changes: 25 additions & 12 deletions lib/dataset.py
@@ -6,6 +6,7 @@
import numpy as np
from enum import Enum

+import torch
from torch.utils.data import Dataset, DataLoader

import MinkowskiEngine as ME
@@ -196,6 +197,9 @@ class VoxelizationDataset(VoxelizationDatasetBase):
# MISC.
PREVOXELIZATION_VOXEL_SIZE = None

+  # Augment coords to feats
+  AUGMENT_COORDS_TO_FEATS = False

def __init__(self,
data_paths,
prevoxel_transform=None,
@@ -244,24 +248,33 @@ def __init__(self,
self.label_map = label_map
self.NUM_LABELS -= len(self.IGNORE_LABELS)

+  def _augment_coords_to_feats(self, coords, feats, labels=None):
+    norm_coords = coords - coords.mean(0)
+    # color must come first.
+    if isinstance(coords, np.ndarray):
+      feats = np.concatenate((feats, norm_coords), 1)
+    else:
+      feats = torch.cat((feats, norm_coords), 1)
+    return coords, feats, labels

def convert_mat2cfl(self, mat):
# Generally, xyz,rgb,label
return mat[:, :3], mat[:, 3:-1], mat[:, -1]

def __getitem__(self, index):
-pointcloud, center = self.load_ply(index)
+coords, feats, labels, center = self.load_ply(index)
# Downsample the pointcloud with finer voxel size before transformation for memory and speed
if self.PREVOXELIZATION_VOXEL_SIZE is not None:
inds = ME.utils.sparse_quantize(
-pointcloud[:, :3] / self.PREVOXELIZATION_VOXEL_SIZE, return_index=True)
-pointcloud = pointcloud[inds]
+coords / self.PREVOXELIZATION_VOXEL_SIZE, return_index=True)
+coords = coords[inds]
+feats = feats[inds]
+labels = labels[inds]

# Prevoxel transformations
if self.prevoxel_transform is not None:
-pointcloud = self.prevoxel_transform(pointcloud)
+coords, feats, labels = self.prevoxel_transform(coords, feats, labels)

-coords, feats, labels = self.convert_mat2cfl(pointcloud)
coords, feats, labels, transformation = self.voxelizer.voxelize(
coords, feats, labels, center=center)

@@ -273,9 +286,14 @@ def __getitem__(self, index):
if self.IGNORE_LABELS is not None:
labels = np.array([self.label_map[x] for x in labels], dtype=np.int)

+# Use coordinate features if config is set
+if self.AUGMENT_COORDS_TO_FEATS:
+  coords, feats, labels = self._augment_coords_to_feats(coords, feats, labels)

return_args = [coords, feats, labels]
if self.return_transformation:
-return_args.extend([pointcloud.astype(np.float32), transformation.astype(np.float32)])
+return_args.append(transformation.astype(np.float32))

return tuple(return_args)
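
For illustration, the item layout after this change looks like the following (a sketch; the indexing interface is illustrative and `dataset` is any VoxelizationDataset instance):

```
coords, feats, labels = dataset[0]       # default
coords, feats, labels, T = dataset[0]    # if return_transformation=True
```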


@@ -319,10 +337,6 @@ def __init__(self,
def load_world_pointcloud(self, filename):
raise NotImplementedError

-def convert_mat2cfl(self, mat):
-  # Generally, xyz,rgb,label
-  return mat[:, :3], mat[:, 3:-1], mat[:, -1]

def __getitem__(self, index):
for seq_idx, numel in enumerate(self.numels):
if index >= numel:
@@ -353,7 +367,6 @@ def __getitem__(self, index):
# Apply prevoxel transformations
ptcs = [self.prevoxel_transform(ptc) for ptc in ptcs]

-ptcs = [self.convert_mat2cfl(ptc) for ptc in ptcs]
coords, feats, labels = zip(*ptcs)
outs = self.voxelizer.voxelize_temporal(
coords, feats, labels, centers=centers, return_transformation=self.return_transformation)
25 changes: 14 additions & 11 deletions lib/datasets/__init__.py
@@ -1,14 +1,17 @@
-from .synthia import SynthiaCVPR15cmVoxelizationDataset, SynthiaCVPR30cmVoxelizationDataset, \
-    SynthiaAllSequencesVoxelizationDataset, SynthiaTemporalVoxelizationDataset
-from .stanford import StanfordVoxelizationDataset, StanfordVoxelization2cmDataset
-from .scannet import ScannetVoxelizationDataset, ScannetVoxelization2cmDataset
-
-DATASETS = [
-    StanfordVoxelizationDataset, StanfordVoxelization2cmDataset, ScannetVoxelizationDataset,
-    ScannetVoxelization2cmDataset, SynthiaCVPR15cmVoxelizationDataset,
-    SynthiaCVPR30cmVoxelizationDataset, SynthiaTemporalVoxelizationDataset,
-    SynthiaAllSequencesVoxelizationDataset
-]
+import lib.datasets.synthia as synthia
+import lib.datasets.stanford as stanford
+import lib.datasets.scannet as scannet
+
+DATASETS = []
+
+
+def add_datasets(module):
+  DATASETS.extend([getattr(module, a) for a in dir(module) if 'Dataset' in a])
+
+
+add_datasets(stanford)
+add_datasets(synthia)
+add_datasets(scannet)


def load_dataset(name):
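The body of `load_dataset` is truncated above; a typical implementation of this registry lookup would be (a sketch, not necessarily the repo's exact code):

```
def load_dataset(name):
  # Return the dataset class whose class name matches `name`.
  mdict = {dataset.__name__: dataset for dataset in DATASETS}
  if name not in mdict:
    raise ValueError(f'Dataset {name} not found in the registry')
  return mdict[name]
```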
67 changes: 31 additions & 36 deletions lib/datasets/preprocessing/scannet.py
@@ -1,8 +1,6 @@
from pathlib import Path
-from random import shuffle

import numpy as np
-import sys
from lib.pc_utils import read_plyfile, save_point_cloud
from concurrent.futures import ProcessPoolExecutor
SCANNET_RAW_PATH = Path('/path/ScanNet_data/')
@@ -19,47 +17,44 @@
print('start preprocess')
# Preprocess data.


def handle_process(path):
  f = Path(path.split(',')[0])
  phase_out_path = Path(path.split(',')[1])
  pointcloud = read_plyfile(f)
  # Make sure alpha value is meaningless.
  assert np.unique(pointcloud[:, -1]).size == 1
  # Load label file.
  label_f = f.parent / (f.stem + '.labels' + f.suffix)
  if label_f.is_file():
    label = read_plyfile(label_f)
    # Sanity check that the pointcloud and its label have the same vertices.
    assert pointcloud.shape[0] == label.shape[0]
    assert np.allclose(pointcloud[:, :3], label[:, :3])
  else:  # Label may not exist in test case.
    label = np.zeros_like(pointcloud)
-  xyz = pointcloud[:, :3]
-  pool = ProcessPoolExecutor(max_workers=9)
-  all_points = np.empty((0, 3))
  out_f = phase_out_path / (f.name[:-len(POINTCLOUD_FILE)] + f.suffix)
  processed = np.hstack((pointcloud[:, :6], np.array([label[:, -1]]).T))
  save_point_cloud(processed, out_f, with_label=True, verbose=False)


path_list = []
for out_path, in_path in SUBSETS.items():
  phase_out_path = SCANNET_OUT_PATH / out_path
  phase_out_path.mkdir(parents=True, exist_ok=True)
  for f in (SCANNET_RAW_PATH / in_path).glob('*/*' + POINTCLOUD_FILE):
-    path_list.append(str(f)+','+str(phase_out_path))
+    path_list.append(str(f) + ',' + str(phase_out_path))

pool = ProcessPoolExecutor(max_workers=20)
-result = list(pool.map(handle_process,path_list))
-for i in result:
-  pass
+result = list(pool.map(handle_process, path_list))

# Fix bug in the data.
for files, bug_index in BUGS.items():
  print(files)

  for f in SCANNET_OUT_PATH.glob(files):
    pointcloud = read_plyfile(f)
    bug_mask = pointcloud[:, -1] == bug_index
    print(f'Fixing {f} bugged label {bug_index} x {bug_mask.sum()}')
    pointcloud[bug_mask, -1] = 0
    save_point_cloud(pointcloud, f, with_label=True, verbose=False)
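
After preprocessing, each output PLY stores x, y, z, r, g, b plus a label column. A quick sanity-check sketch (the output path and scene filename are placeholders; assumes `read_plyfile` returns an (N, 7) array for labeled files):

```
from lib.pc_utils import read_plyfile

pc = read_plyfile('/path/to/scannet_processed/train/scene0000_00.ply')
assert pc.shape[1] == 7  # x, y, z, r, g, b, label
print(pc[:5])
```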