From 1c937e8bcbf030974667e16f8c76ff1c9c6988af Mon Sep 17 00:00:00 2001 From: Chris Choy Date: Mon, 27 Jan 2020 01:20:20 -0800 Subject: [PATCH] ME collation --- lib/transforms.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/lib/transforms.py b/lib/transforms.py index 831a5f1..7882d37 100644 --- a/lib/transforms.py +++ b/lib/transforms.py @@ -7,6 +7,8 @@ import scipy.interpolate import torch +import MinkowskiEngine as ME + # A sparse tensor consists of coordinates and associated features. # You must apply augmentation to both. @@ -260,19 +262,15 @@ def __call__(self, list_data): f'limit. Truncating batch size at {batch_id} out of {num_full_batch_size} with {batch_num_points - num_points}.' ) break - coords_batch.append( - torch.cat((torch.from_numpy( - coords[batch_id]).int(), torch.ones(num_points, 1).int() * batch_id), 1)) + coords_batch.append(torch.from_numpy(coords[batch_id]).int()) feats_batch.append(torch.from_numpy(feats[batch_id])) labels_batch.append(torch.from_numpy(labels[batch_id]).int()) batch_id += 1 # Concatenate all lists - coords_batch = torch.cat(coords_batch, 0).int() - feats_batch = torch.cat(feats_batch, 0).float() - labels_batch = torch.cat(labels_batch, 0).int() - return coords_batch, feats_batch, labels_batch + coords_batch, feats_batch, labels_batch = ME.utils.sparse_collate(coords_batch, feats_batch, labels_batch) + return coords_batch, feats_batch.float(), labels_batch class cflt_collate_fn_factory: @@ -293,16 +291,11 @@ def __call__(self, list_data): num_truncated_batch = coords_batch[:, -1].max().item() + 1 batch_id = 0 - pointclouds_batch, transformations_batch = [], [] + transformations_batch = [] for transformation in transformations: if batch_id >= num_truncated_batch: break - transformations_batch.append( - torch.cat( - (torch.from_numpy(transformation), torch.ones(transformation.shape[0], 1) * batch_id), - 1)) + transformations_batch.append(torch.from_numpy(transformation).float()) batch_id += 1 - pointclouds_batch = torch.cat(pointclouds_batch, 0).float() - transformations_batch = torch.cat(transformations_batch, 0).float() return coords_batch, feats_batch, labels_batch, transformations_batch