diff --git a/MinkowskiEngine/MinkowskiOps.py b/MinkowskiEngine/MinkowskiOps.py index 7bba38e3..d3160dae 100644 --- a/MinkowskiEngine/MinkowskiOps.py +++ b/MinkowskiEngine/MinkowskiOps.py @@ -77,3 +77,54 @@ def cat(*sparse_tensors): torch.cat(tens, dim=1), coords_key=sparse_tensors[0].coords_key, coords_manager=coords_man) + + +def cat_union(A, B): + r"""Concatenate sparse tensors (different sparsity) + + Concatenate sparse tensor features with different sparsity patterns. All sparse tensors must have the same + `coords_man` (does not need the same coordinates). If the coordinate is matched, corresponding features + are concatenated, otherwise, the zero features are concatenated to original features. + + Example:: + + >>> import MinkowskiEngine as ME + >>> import torch + >>> feats1 = torch.zeros(3, 3) + 1 + >>> coords1 = torch.Tensor([[0,0,0], [0,0,1], [0,1,0]]) + >>> feats2 = torch.zeros(4, 3) + 2 + >>> coords2 = torch.Tensor([[0,0,0], [0,0,1], [0,1,0], [1,1,1]]) + >>> a = ME.SparseTensor(feats1, coords1) + >>> b = ME.SparseTensor(feats2, coords2, coords_manager=a.coords_man, force_creation=True) + >>> result = ME.cat_union(a, b) # the coordinates are 'coords2' and the feature is [0,0,0,2,2,2] + # the feature of coordinate [1,1,1] is [0,0,0,2,2,2] + + """ + + cm = A.coords_man + assert cm == B.coords_man, "different coords_man" + assert A.tensor_stride == B.tensor_stride, "different tensor_stride" + + zeros_cat_with_A = torch.zeros([A.F.shape[0], B.F.shape[1]]).to(A.device) + zeros_cat_with_B = torch.zeros([B.F.shape[0], A.F.shape[1]]).to(A.device) + + feats_A = torch.cat([A.F, zeros_cat_with_A], dim=1) + feats_B = torch.cat([zeros_cat_with_B, B.F], dim=1) + + new_A = SparseTensor( + feats=feats_A, + coords=A.C, + coords_manager=cm, + force_creation=True, + tensor_stride=A.tensor_stride, + ) + + new_B = SparseTensor( + feats=feats_B, + coords=B.C, + coords_manager=cm, + force_creation=True, + tensor_stride=A.tensor_stride, + ) + + return new_A + new_B diff --git a/MinkowskiEngine/__init__.py b/MinkowskiEngine/__init__.py index 2360b408..1c30865b 100644 --- a/MinkowskiEngine/__init__.py +++ b/MinkowskiEngine/__init__.py @@ -72,7 +72,7 @@ import MinkowskiOps -from MinkowskiOps import MinkowskiLinear, cat +from MinkowskiOps import MinkowskiLinear, cat, cat_union import MinkowskiFunctional