
Commit

accelerate preprocess && small bug fix (#14)
N/A

use multi-process to reduce preprocessing time

Signed-off-by: s <panyunyi97>
faultaddr authored and chrischoy committed Dec 12, 2019
1 parent 36d67fe commit ee203da
Showing 2 changed files with 48 additions and 40 deletions.
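The gist of the change: instead of reading, checking, and saving each ScanNet scene sequentially, the preprocessing script now builds a list of per-scene tasks up front and hands it to a concurrent.futures.ProcessPoolExecutor, so scenes are processed in parallel worker processes. A minimal sketch of that fan-out pattern (illustrative only, not the committed code; preprocess_one, IN_DIR, and OUT_DIR are placeholder names to be replaced with real paths and the real per-scene work):

from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

IN_DIR = Path('/path/to/raw_scans')    # placeholder input directory
OUT_DIR = Path('/path/to/processed')   # placeholder output directory

def preprocess_one(task):
  # Runs in a worker process: each task is one (input_file, output_dir) pair.
  in_file, out_dir = task
  out_file = Path(out_dir) / Path(in_file).name
  out_file.write_bytes(Path(in_file).read_bytes())  # stand-in for the real per-scene work
  return str(out_file)

if __name__ == '__main__':
  OUT_DIR.mkdir(parents=True, exist_ok=True)
  tasks = [(str(f), str(OUT_DIR)) for f in IN_DIR.glob('*.ply')]
  with ProcessPoolExecutor(max_workers=20) as pool:
    # map() pickles each task to a worker and yields results in input order.
    for out_file in pool.map(preprocess_one, tasks):
      print('wrote', out_file)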
86 changes: 47 additions & 39 deletions lib/datasets/preprocessing/scannet.py
@@ -2,56 +2,64 @@
 from random import shuffle
 
 import numpy as np
 
+import sys
 from lib.pc_utils import read_plyfile, save_point_cloud
 
 
-SCANNET_RAW_PATH = Path('/data/chrischoy/datasets/scannet_raw')
-SCANNET_OUT_PATH = Path('/data/chrischoy/datasets/scannet_processed')
+from concurrent.futures import ProcessPoolExecutor
+SCANNET_RAW_PATH = Path('/path/ScanNet_data/')
+SCANNET_OUT_PATH = Path('/path/scans_processed/')
 TRAIN_DEST = 'train'
 TEST_DEST = 'test'
 SUBSETS = {TRAIN_DEST: 'scans', TEST_DEST: 'scans_test'}
 POINTCLOUD_FILE = '_vh_clean_2.ply'
 BUGS = {
-    'train/scene0270_00_*.ply': 50,
-    'train/scene0270_02_*.ply': 50,
-    'train/scene0384_00_*.ply': 149,
+    'train/scene0270_00.ply': 50,
+    'train/scene0270_02.ply': 50,
+    'train/scene0384_00.ply': 149,
 }
+print('start preprocess')
+# Preprocess data.
 
+def handle_process(path):
+  f = Path(path.split(',')[0])
+  phase_out_path = Path(path.split(',')[1])
+  pointcloud = read_plyfile(f)
+  # Make sure alpha value is meaningless.
+  assert np.unique(pointcloud[:, -1]).size == 1
+  # Load label file.
+  label_f = f.parent / (f.stem + '.labels' + f.suffix)
+  if label_f.is_file():
+    label = read_plyfile(label_f)
+    # Sanity check that the pointcloud and its label has same vertices.
+    assert pointcloud.shape[0] == label.shape[0]
+    assert np.allclose(pointcloud[:, :3], label[:, :3])
+  else:  # Label may not exist in test case.
+    label = np.zeros_like(pointcloud)
+  xyz = pointcloud[:, :3]
+  pool = ProcessPoolExecutor(max_workers=9)
+  all_points = np.empty((0, 3))
+  out_f = phase_out_path / (f.name[:-len(POINTCLOUD_FILE)] + f.suffix)
+  processed = np.hstack((pointcloud[:, :6], np.array([label[:, -1]]).T))
+  save_point_cloud(processed, out_f, with_label=True, verbose=False)
 
-# Preprocess data.
+path_list = []
 for out_path, in_path in SUBSETS.items():
-  phase_out_path = SCANNET_OUT_PATH / out_path
-  phase_out_path.mkdir(parents=True, exist_ok=True)
-  for f in (SCANNET_RAW_PATH / in_path).glob('*/*' + POINTCLOUD_FILE):
-    # Load pointcloud file.
-    pointcloud = read_plyfile(f)
-    # Make sure alpha value is meaningless.
-    assert np.unique(pointcloud[:, -1]).size == 1
-    # Load label file.
-    label_f = f.parent / (f.stem + '.labels' + f.suffix)
-    if label_f.is_file():
-      label = read_plyfile(label_f)
-      # Sanity check that the pointcloud and its label has same vertices.
-      assert pointcloud.shape[0] == label.shape[0]
-      assert np.allclose(pointcloud[:, :3], label[:, :3])
-    else:  # Label may not exist in test case.
-      label = np.zeros_like(pointcloud)
-    xyz = pointcloud[:, :3]
-
-    all_points = np.empty((0, 3))
-    out_f = phase_out_path / (f.name[:-len(POINTCLOUD_FILE)] + f.suffix)
-    processed = np.hstack((pointcloud[:, :6], np.array([label[:, -1]]).T))
-    save_point_cloud(processed, out_f, with_label=True, verbose=False)
+  phase_out_path = SCANNET_OUT_PATH / out_path
+  phase_out_path.mkdir(parents=True, exist_ok=True)
+  for f in (SCANNET_RAW_PATH / in_path).glob('*/*' + POINTCLOUD_FILE):
+    path_list.append(str(f)+','+str(phase_out_path))
 
-    # Check that all points are included in the crops.
-    assert set(tuple(l) for l in all_points.tolist()) == set(tuple(l) for l in xyz.tolist())
+pool = ProcessPoolExecutor(max_workers=20)
+result = list(pool.map(handle_process,path_list))
+for i in result:
+  pass
 
 # Fix bug in the data.
 for files, bug_index in BUGS.items():
-  for f in SCANNET_OUT_PATH.glob(files):
-    pointcloud = read_plyfile(f)
-    bug_mask = pointcloud[:, -1] == bug_index
-    print(f'Fixing {f} bugged label {bug_index} x {bug_mask.sum()}')
-    pointcloud[bug_mask, -1] = 0
-    save_point_cloud(pointcloud, f, with_label=True, verbose=False)
+  print(files)
+
+  for f in SCANNET_OUT_PATH.glob(files):
+    pointcloud = read_plyfile(f)
+    bug_mask = pointcloud[:, -1] == bug_index
+    print(f'Fixing {f} bugged label {bug_index} x {bug_mask.sum()}')
+    pointcloud[bug_mask, -1] = 0
+    save_point_cloud(pointcloud, f, with_label=True, verbose=False)
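Two implementation details in the diff above are easy to miss. Each entry of path_list packs the input PLY path and its output directory into one comma-separated string, which handle_process splits back into the two paths, and the ProcessPoolExecutor(max_workers=9) created inside handle_process is never used; the module-level pool with max_workers=20 is what actually parallelizes the work. Because that pool is created at import time without a __main__ guard, the script relies on fork-style process start (the Linux default). A hedged sketch of an equivalent driver that passes (input, output) tuples and guards the pool, assuming handle_process were adapted to unpack a tuple instead of splitting a string:

if __name__ == '__main__':
  tasks = []
  for out_path, in_path in SUBSETS.items():
    phase_out_path = SCANNET_OUT_PATH / out_path
    phase_out_path.mkdir(parents=True, exist_ok=True)
    for f in (SCANNET_RAW_PATH / in_path).glob('*/*' + POINTCLOUD_FILE):
      tasks.append((f, phase_out_path))  # Path objects pickle cleanly, no string packing needed
  with ProcessPoolExecutor(max_workers=20) as pool:
    # The context manager waits for all workers and shuts the pool down cleanly.
    list(pool.map(handle_process, tasks))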
2 changes: 1 addition & 1 deletion lib/datasets/scannet.py
@@ -99,7 +99,7 @@ def __init__(self,
     data_root = config.scannet_path
     if phase not in [DatasetPhase.Train, DatasetPhase.TrainVal]:
       self.CLIP_BOUND = self.TEST_CLIP_BOUND
-    data_paths = read_txt(os.path.join('./splits/scannt', self.DATA_PATH_FILE[phase]))
+    data_paths = read_txt(os.path.join('./splits/scannet', self.DATA_PATH_FILE[phase]))
     logging.info('Loading {}: {}'.format(self.__class__.__name__, self.DATA_PATH_FILE[phase]))
     super().__init__(
         data_paths,

