correct sparse_quantize
chrischoy committed Jan 4, 2020
1 parent bd852b4 commit c6e421a
Showing 4 changed files with 66 additions and 55 deletions.
1 change: 1 addition & 0 deletions lib/dataset.py
@@ -430,6 +430,7 @@ def initialize_data_loader(DatasetClass,

if augment_data:
input_transforms += [
+ t.RandomDropout(0.2),
t.RandomHorizontalFlip(DatasetClass.ROTATION_AXIS, DatasetClass.IS_TEMPORAL),
t.ChromaticAutoContrast(),
t.ChromaticTranslation(config.data_aug_color_trans_ratio),
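The transforms appended here all share the (coords, feats, labels) call signature used in lib/transforms.py below, so they can be chained by any simple composition wrapper. A minimal sketch of such a wrapper (the Compose name is an illustrative helper, not necessarily the repo's own class):

# Minimal sketch of chaining (coords, feats, labels) transforms.
# `Compose` is an illustrative name, not necessarily the wrapper this repo uses.
class Compose(object):

  def __init__(self, transforms):
    self.transforms = transforms

  def __call__(self, coords, feats, labels):
    for transform in self.transforms:
      coords, feats, labels = transform(coords, feats, labels)
    return coords, feats, labels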
68 changes: 32 additions & 36 deletions lib/train.py
@@ -69,42 +69,38 @@ def train(model, data_loader, val_data_loader, config, transform_data_fn=None):
data_time, batch_loss = 0, 0
iter_timer.tic()

- try:  # torch issue #16998
-   for sub_iter in range(config.iter_size):
-     # Get training data
-     data_timer.tic()
-     if config.return_transformation:
-       coords, input, target, pointcloud, transformation = data_iter.next()
-     else:
-       coords, input, target = data_iter.next()
-
-     # For some networks, making the network invariant to even, odd coords is important
-     coords[:, :3] += (torch.rand(3) * 100).type_as(coords)
-
-     # Preprocess input
-     color = input[:, :3].int()
-     if config.normalize_color:
-       input[:, :3] = input[:, :3] / 255. - 0.5
-     sinput = SparseTensor(input, coords).to(device)
-
-     data_time += data_timer.toc(False)
-
-     # Feed forward
-     inputs = (sinput,) if config.wrapper_type == 'None' else (sinput, coords, color)
-     # model.initialize_coords(*init_args)
-     soutput = model(*inputs)
-     # The output of the network is not sorted
-     target = target.long().to(device)
-
-     loss = criterion(soutput.F, target.long())
-
-     # Compute and accumulate gradient
-     loss /= config.iter_size
-     batch_loss += loss.item()
-     loss.backward()
- except Exception as e:
-   logging.error(e)
-   continue
+ for sub_iter in range(config.iter_size):
+   # Get training data
+   data_timer.tic()
+   if config.return_transformation:
+     coords, input, target, pointcloud, transformation = data_iter.next()
+   else:
+     coords, input, target = data_iter.next()
+
+   # For some networks, making the network invariant to even, odd coords is important
+   coords[:, :3] += (torch.rand(3) * 100).type_as(coords)
+
+   # Preprocess input
+   color = input[:, :3].int()
+   if config.normalize_color:
+     input[:, :3] = input[:, :3] / 255. - 0.5
+   sinput = SparseTensor(input, coords).to(device)
+
+   data_time += data_timer.toc(False)
+
+   # Feed forward
+   inputs = (sinput,) if config.wrapper_type == 'None' else (sinput, coords, color)
+   # model.initialize_coords(*init_args)
+   soutput = model(*inputs)
+   # The output of the network is not sorted
+   target = target.long().to(device)
+
+   loss = criterion(soutput.F, target.long())
+
+   # Compute and accumulate gradient
+   loss /= config.iter_size
+   batch_loss += loss.item()
+   loss.backward()

# Update number of steps
optimizer.step()
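With the try/except wrapper (a workaround for torch issue #16998) removed, the sub-iteration loop above is plain gradient accumulation: the loss is scaled by 1 / iter_size, backward() runs once per sub-iteration, and the single optimizer.step() that follows applies the accumulated gradient. A self-contained sketch of the same pattern (the toy model, data, and iter_size value are illustrative placeholders):

# Standalone sketch of the gradient-accumulation pattern used above.
import torch
import torch.nn as nn

model = nn.Linear(8, 4)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
iter_size = 4

optimizer.zero_grad()
batch_loss = 0.0
for sub_iter in range(iter_size):
  feats = torch.randn(16, 8)
  target = torch.randint(0, 4, (16,))
  loss = criterion(model(feats), target) / iter_size  # scale so gradients average
  batch_loss += loss.item()
  loss.backward()  # accumulate into .grad
optimizer.step()  # one parameter update per iter_size sub-iterations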
17 changes: 17 additions & 0 deletions lib/transforms.py
@@ -135,6 +135,23 @@ def __call__(self, coords, feats, labels):
##############################
# Coordinate transformations
##############################
+ class RandomDropout(object):
+
+   def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.5):
+     """
+     dropout_ratio: fraction of points to drop when the dropout fires
+     dropout_application_ratio: intended probability of applying the dropout
+       (unused here; dropout_ratio currently gates application as well)
+     """
+     self.dropout_ratio = dropout_ratio
+     self.dropout_application_ratio = dropout_application_ratio
+
+   def __call__(self, coords, feats, labels):
+     if random.random() < self.dropout_ratio:
+       N = len(coords)
+       inds = np.random.choice(N, int(N * (1 - self.dropout_ratio)), replace=False)
+       return coords[inds], feats[inds], labels[inds]
+     return coords, feats, labels
+
+
class RandomHorizontalFlip(object):

def __init__(self, upright_axis, is_temporal):
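A quick usage sketch of the new RandomDropout transform on dummy arrays (shapes and values are illustrative): when the dropout fires, roughly dropout_ratio of the points are removed, and coords, feats, and labels stay index-aligned.

import numpy as np

# Assumes the RandomDropout class added to lib/transforms.py above is importable.
coords = np.random.randint(0, 100, (1000, 3))
feats = np.random.rand(1000, 3).astype(np.float32)
labels = np.random.randint(0, 20, (1000,))

dropout = RandomDropout(dropout_ratio=0.2)
coords_d, feats_d, labels_d = dropout(coords, feats, labels)
# Either all 1000 points survive (the dropout did not fire) or roughly 800 remain,
# with the three arrays still index-aligned.
print(len(coords_d), len(feats_d), len(labels_d))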
35 changes: 16 additions & 19 deletions lib/voxelizer.py
@@ -84,12 +84,14 @@ def clip(self, coords, center=None, trans_aug_ratio=None):
trans = np.multiply(trans_aug_ratio, bound_size)
center += trans
# Clip points outside the limit
- clip_inds = ((coords[:, 0] >= (lim[0][0] + center[0])) &
-              (coords[:, 0] < (lim[0][1] + center[0])) &
-              (coords[:, 1] >= (lim[1][0] + center[1])) &
-              (coords[:, 1] < (lim[1][1] + center[1])) &
-              (coords[:, 2] >= (lim[2][0] + center[2])) &
-              (coords[:, 2] < (lim[2][1] + center[2])))
+ clip_inds = ((coords[:, 0] >=
+               (lim[0][0] + center[0])) & (coords[:, 0] <
+                                           (lim[0][1] + center[0])) & (coords[:, 1] >=
+                                                                       (lim[1][0] + center[1])) &
+              (coords[:, 1] <
+               (lim[1][1] + center[1])) & (coords[:, 2] >=
+                                           (lim[2][0] + center[2])) & (coords[:, 2] <
+                                                                       (lim[2][1] + center[2])))
return clip_inds

def voxelize(self, coords, feats, labels, center=None):
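The reformatted mask in clip() above still checks each axis against the half-open interval [lim[i][0] + center[i], lim[i][1] + center[i]). Assuming lim behaves like a (3, 2) array of per-axis bounds, an equivalent vectorized form of the same check would be:

import numpy as np

# Equivalent vectorized form of the per-axis bound check in clip(), assuming
# `lim` is a (3, 2) array of [lower, upper) bounds and `center` a length-3 offset.
coords = np.random.uniform(-5.0, 5.0, (1000, 3))
lim = np.array([[-2.0, 2.0], [-2.0, 2.0], [-2.0, 2.0]])
center = np.zeros(3)

clip_inds = np.all((coords >= lim[:, 0] + center) & (coords < lim[:, 1] + center), axis=1)
clipped = coords[clip_inds]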
@@ -123,12 +125,9 @@ def voxelize(self, coords, feats, labels, center=None):
rigid_transformation = M_t @ rigid_transformation
coords_aug = np.floor(coords_aug - min_coords)

- inds = ME.utils.sparse_quantize(coords_aug, return_index=True)
- coords_aug, feats, labels = coords_aug[inds], feats[inds], labels[inds]
-
- # Normal rotation
- if feats.shape[1] > 6:
-   feats[:, 3:6] = feats[:, 3:6] @ (M_r[:3, :3].T)
+ # key = self.hash(coords_aug)  # floor happens by astype(np.uint64)
+ coords_aug, feats, labels = ME.utils.sparse_quantize(
+     coords_aug, feats, labels=labels, ignore_label=self.ignore_label)

return coords_aug, feats, labels, rigid_transformation.flatten()
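This is the substance of the "correct sparse_quantize" fix: rather than requesting unique-voxel indices with return_index=True and slicing coords, feats, and labels by hand, the features and labels are passed to ME.utils.sparse_quantize directly, so quantization, feature selection, and label-collision handling via ignore_label happen in one call. A before/after sketch on dummy data, assuming the MinkowskiEngine 0.4-era signature sparse_quantize(coords, feats=None, labels=None, ignore_label=..., return_index=False):

import numpy as np
import MinkowskiEngine as ME

coords = np.floor(np.random.uniform(0, 10, (1000, 3)))
feats = np.random.rand(1000, 6).astype(np.float32)
labels = np.random.randint(0, 20, (1000,))
ignore_label = 255  # illustrative; the repo uses self.ignore_label

# Old pattern: take unique-voxel indices, then slice everything by hand.
inds = ME.utils.sparse_quantize(coords, return_index=True)
coords_q, feats_q, labels_q = coords[inds], feats[inds], labels[inds]

# New pattern: let sparse_quantize select features and resolve label
# collisions (voxels containing conflicting labels receive ignore_label).
coords_q, feats_q, labels_q = ME.utils.sparse_quantize(
    coords, feats, labels=labels, ignore_label=ignore_label)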

@@ -140,7 +139,9 @@ def voxelize_temporal(self,
return_transformation=False):
# Legacy code, remove
if centers is None:
- centers = [None, ] * len(coords_t)
+ centers = [
+     None,
+ ] * len(coords_t)
coords_tc, feats_tc, labels_tc, transformation_tc = [], [], [], []

# ######################### Data Augmentation #############################
@@ -171,12 +172,8 @@
homo_coords = np.hstack((coords, np.ones((coords.shape[0], 1), dtype=coords.dtype)))
coords_aug = np.floor(homo_coords @ rigid_transformation.T)[:, :3]

- inds = ME.utils.sparse_quantize(coords_aug, return_index=True)
- coords_aug, feats, labels = coords_aug[inds], feats[inds], labels[inds]
-
- # If use normal rotation
- if feats.shape[1] > 6:
-   feats[:, 3:6] = feats[:, 3:6] @ (M_r[:3, :3].T)
+ coords_aug, feats, labels = ME.utils.sparse_quantize(
+     coords_aug, feats, labels=labels, ignore_label=self.ignore_label)

coords_tc.append(coords_aug)
feats_tc.append(feats)
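For context, the voxelize path that feeds sparse_quantize applies a 4x4 rigid transformation in homogeneous coordinates and floors the result onto the voxel grid, as in the homo_coords lines above. A self-contained sketch of that step (the scale value standing in for 1 / voxel_size is illustrative):

import numpy as np

coords = np.random.uniform(-5.0, 5.0, (1000, 3))
scale = 1.0 / 0.05  # 5 cm voxels; illustrative value

# 4x4 homogeneous transform: scale into voxel units (rotation and translation
# augmentations would be composed into the same matrix).
rigid_transformation = np.eye(4)
rigid_transformation[:3, :3] *= scale

homo_coords = np.hstack((coords, np.ones((coords.shape[0], 1), dtype=coords.dtype)))
coords_aug = np.floor(homo_coords @ rigid_transformation.T)[:, :3]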
