Skip to content

Commit

Permalink
ME collation
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Jan 27, 2020
1 parent 4a4c0c5 commit 1c937e8
Showing 1 changed file with 7 additions and 14 deletions.
21 changes: 7 additions & 14 deletions lib/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import scipy.interpolate
import torch

import MinkowskiEngine as ME


# A sparse tensor consists of coordinates and associated features.
# You must apply augmentation to both.
Expand Down Expand Up @@ -260,19 +262,15 @@ def __call__(self, list_data):
f'limit. Truncating batch size at {batch_id} out of {num_full_batch_size} with {batch_num_points - num_points}.'
)
break
coords_batch.append(
torch.cat((torch.from_numpy(
coords[batch_id]).int(), torch.ones(num_points, 1).int() * batch_id), 1))
coords_batch.append(torch.from_numpy(coords[batch_id]).int())
feats_batch.append(torch.from_numpy(feats[batch_id]))
labels_batch.append(torch.from_numpy(labels[batch_id]).int())

batch_id += 1

# Concatenate all lists
coords_batch = torch.cat(coords_batch, 0).int()
feats_batch = torch.cat(feats_batch, 0).float()
labels_batch = torch.cat(labels_batch, 0).int()
return coords_batch, feats_batch, labels_batch
coords_batch, feats_batch, labels_batch = ME.utils.sparse_collate(coords_batch, feats_batch, labels_batch)
return coords_batch, feats_batch.float(), labels_batch


class cflt_collate_fn_factory:
Expand All @@ -293,16 +291,11 @@ def __call__(self, list_data):
num_truncated_batch = coords_batch[:, -1].max().item() + 1

batch_id = 0
pointclouds_batch, transformations_batch = [], []
transformations_batch = []
for transformation in transformations:
if batch_id >= num_truncated_batch:
break
transformations_batch.append(
torch.cat(
(torch.from_numpy(transformation), torch.ones(transformation.shape[0], 1) * batch_id),
1))
transformations_batch.append(torch.from_numpy(transformation).float())
batch_id += 1

pointclouds_batch = torch.cat(pointclouds_batch, 0).float()
transformations_batch = torch.cat(transformations_batch, 0).float()
return coords_batch, feats_batch, labels_batch, transformations_batch

0 comments on commit 1c937e8

Please sign in to comment.