preprocessing script fix
chrischoy committed Oct 11, 2019
1 parent 4b1315c commit 7699599
Showing 2 changed files with 26 additions and 34 deletions.
19 changes: 19 additions & 0 deletions README.md
@@ -1,3 +1,22 @@
# Spatio-Temporal Segmentation

This repository contains the accompanying code for [4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR'19](https://arxiv.org/abs/1904.08755).


## ScanNet Training

First, preprocess all ScanNet raw point clouds with the following command after setting the paths correctly.

```
python -m lib.datasets.preprocessing.scannet
```

Then, train the ScanNet network with

```
./scripts/train_scannet.sh 0 -default "--scannet_path /path/to/preprocessed/scannet"
```

The first argument is the GPU id, the second is the path postfix, and the last is a string of miscellaneous arguments.

41 changes: 7 additions & 34 deletions lib/datasets/preprocessing/scannet.py
@@ -47,38 +47,11 @@
# Check that all points are included in the crops.
assert set(tuple(l) for l in all_points.tolist()) == set(tuple(l) for l in xyz.tolist())
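The coverage assertion above compares point sets as tuples of coordinates. A standalone sketch of the same check, with synthetic crops and hypothetical coordinates:

```python
import numpy as np

# Original cloud and two crops that should jointly cover every point.
xyz = np.array([[0., 0., 0.], [1., 1., 1.], [2., 2., 2.]])
crops = [np.array([[1., 1., 1.], [0., 0., 0.]]), np.array([[2., 2., 2.]])]

# Concatenate the crops and compare as sets of coordinate tuples,
# so ordering differences between crops and the original do not matter.
all_points = np.concatenate(crops)
covered = set(tuple(p) for p in all_points.tolist()) == set(tuple(p) for p in xyz.tolist())
```

Comparing as sets also tolerates duplicate points across overlapping crops, which a shape comparison would not.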

# Split trainval data to train/val according to scene.
trainval_files = [f.name for f in (SCANNET_OUT_PATH / TRAIN_DEST).glob('*.ply')]
trainval_scenes = list(set(f.split('_')[0] for f in trainval_files))
shuffle(trainval_scenes)
num_train = int(len(trainval_scenes))
train_scenes = trainval_scenes[:num_train]
val_scenes = trainval_scenes[num_train:]
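The split above groups crops by scene so that no scene leaks across train and val. A self-contained sketch of that idea (hypothetical file names; a `train_ratio` parameter is assumed, since the snippet above keeps every scene in train):

```python
from random import shuffle

def split_by_scene(filenames, train_ratio=0.8):
    """Split files into train/val so that all crops of a scene stay in one split."""
    scenes = sorted(set(f.split('_')[0] for f in filenames))
    shuffle(scenes)
    num_train = int(len(scenes) * train_ratio)
    train_scenes = set(scenes[:num_train])
    train = [f for f in filenames if f.split('_')[0] in train_scenes]
    val = [f for f in filenames if f.split('_')[0] not in train_scenes]
    return train, val

# Example: scene0000 has two crops; both must land in the same split.
files = ['scene0000_00.ply', 'scene0000_01.ply', 'scene0001_00.ply', 'scene0002_00.ply']
train, val = split_by_scene(files, train_ratio=0.5)
```

Matching on the exact scene prefix (rather than a substring test) avoids accidentally matching one scene id inside another.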

# Collect file lists for all phases.
train_files = [f'{TRAIN_DEST}/{f}' for f in trainval_files if any(s in f for s in train_scenes)]
val_files = [f'{TRAIN_DEST}/{f}' for f in trainval_files if any(s in f for s in val_scenes)]
test_files = [f'{TEST_DEST}/{f.name}' for f in (SCANNET_OUT_PATH / TEST_DEST).glob('*.ply')]

# Data sanity check.
assert not set(train_files).intersection(val_files)
assert all((SCANNET_OUT_PATH / f).is_file() for f in train_files)
assert all((SCANNET_OUT_PATH / f).is_file() for f in val_files)
assert all((SCANNET_OUT_PATH / f).is_file() for f in test_files)

# Write file lists for all phases.
with open(SCANNET_OUT_PATH / 'train.txt', 'w') as f:
    f.writelines([f + '\n' for f in train_files])
with open(SCANNET_OUT_PATH / 'val.txt', 'w') as f:
    f.writelines([f + '\n' for f in val_files])
with open(SCANNET_OUT_PATH / 'test.txt', 'w') as f:
    f.writelines([f + '\n' for f in test_files])

# Fix bug in the data.
# for files, bug_index in BUGS.items():
#     for f in SCANNET_OUT_PATH.glob(files):
#         pointcloud = read_plyfile(f)
#         bug_mask = pointcloud[:, -1] == bug_index
#         print(f'Fixing {f} bugged label {bug_index} x {bug_mask.sum()}')
#         pointcloud[bug_mask, -1] = 0
#         save_point_cloud(pointcloud, f, with_label=True, verbose=False)
for files, bug_index in BUGS.items():
    for f in SCANNET_OUT_PATH.glob(files):
        pointcloud = read_plyfile(f)
        bug_mask = pointcloud[:, -1] == bug_index
        print(f'Fixing {f} bugged label {bug_index} x {bug_mask.sum()}')
        pointcloud[bug_mask, -1] = 0
        save_point_cloud(pointcloud, f, with_label=True, verbose=False)
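The newly enabled fix above zeroes out a bugged label with a boolean mask over the last column. A minimal NumPy sketch of that masking step, with synthetic data and an assumed x/y/z/label column layout:

```python
import numpy as np

# Synthetic point cloud: columns are x, y, z, label (layout assumed for illustration).
pointcloud = np.array([
    [0.0, 0.0, 0.0, 5],
    [1.0, 0.0, 0.0, 50],   # 50 plays the role of a bugged label index
    [0.0, 1.0, 0.0, 50],
    [0.0, 0.0, 1.0, 7],
], dtype=np.float64)

bug_index = 50
bug_mask = pointcloud[:, -1] == bug_index  # boolean mask over the label column
pointcloud[bug_mask, -1] = 0               # reset bugged labels to 0 (unannotated)
```

Boolean-mask assignment rewrites only the affected rows in place, leaving coordinates and valid labels untouched.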
