-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
2 changed files
with
176 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Copyright 2024 The Google Research Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
"""Functions for processing object data.""" | ||
|
||
from typing import Tuple | ||
import numpy as np | ||
from tf import transformations | ||
|
||
|
||
def mask_to_points(mask: np.ndarray, | ||
voxel_size: Tuple[float, float, float], | ||
origin_in_center=True): | ||
"""Converts an object mask to a dense point cloud. | ||
Args: | ||
mask: 3-5 dim boolean ndarray representing the object mask | ||
voxel_size: xyz voxel size in physical units | ||
origin_in_center: whether the origin of the coordinate system should | ||
be located in the center of `mask` | ||
Returns: | ||
[3, n] array with coordinates of the `True` voxels of mask; | ||
coordinates are expressed in physical units | ||
""" | ||
if mask.ndim == 5: | ||
mask = mask.squeeze(axis=(0, 4)) | ||
elif mask.ndim == 4: | ||
mask = mask.squeeze(axis=3) | ||
elif mask.ndim != 3: | ||
raise ValueError('mask needs to be a 3-5 dimensional array.') | ||
|
||
if origin_in_center: | ||
r = (np.array(mask.shape) - 1) / 2 | ||
else: | ||
r = np.array([0, 0, 0]) | ||
|
||
z, y, x = np.where(mask) | ||
return np.array([ | ||
(x - r[2]) * voxel_size[0], # | ||
(y - r[1]) * voxel_size[1], # | ||
(z - r[0]) * voxel_size[2] | ||
]) | ||
|
||
|
||
def compute_orientation(points: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: | ||
eigvals, eigvecs = np.linalg.eig(np.cov(points)) | ||
sorted_idx = np.argsort(eigvals)[::-1] | ||
return eigvecs[:, sorted_idx], eigvals[sorted_idx] | ||
|
||
|
||
def compute_rotation_matrix(eigvecs: np.ndarray) -> np.ndarray: | ||
"""Computes a rotation matrix to put the object in a standard orientation. | ||
Args: | ||
eigvecs: two xyz vectors ([3, 2] ndarray) corresponding to the | ||
orthogonal directions of maximum variance, in descending order of | ||
variance | ||
Returns: | ||
3x3 rotation matrix which will reorient the coordinate system so that | ||
'z' is the direction of maximum variance, and 'y' is the axis of 2nd | ||
largest variance | ||
""" | ||
# 1st rotation to make z the axis of maximum variance | ||
base = np.array([0, 0, 1]) | ||
u = eigvecs[:, 0] | ||
axis = np.cross(u, base) | ||
theta = np.arccos(u.dot(base)) | ||
|
||
# Cut to 3x3 since we don't need homegeneous coordinates. | ||
rot = transformations.rotation_matrix(theta, axis)[:3, :3] | ||
|
||
# 2nd rotation to align the 2nd direction of max variance with the y axis | ||
v = np.matmul(rot, eigvecs[:, 1]) | ||
base = np.array([0, 1, 0]) | ||
axis = np.cross(v, base) | ||
theta = np.arccos(v.dot(np.array([0, 1, 0]))) | ||
rot2 = transformations.rotation_matrix(theta, axis)[:3, :3] | ||
|
||
# Compose the two rotations. | ||
return np.matmul(rot2, rot) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright 2024 The Google Research Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
"""Implements the disjoint-set forest/union-find data structure.""" | ||
|
||
|
||
class UnionFind: | ||
"""Dictionary-based implementation of the disjoint-set forest data structure. | ||
Uses union-by-rank and path compression. Works with any hashable type. | ||
Does not store singleton sets. | ||
More info: http://en.wikipedia.org/wiki/Disjoint-set_data_structure | ||
""" | ||
|
||
def __init__(self): | ||
self._parents = {} | ||
self._ranks = {} | ||
|
||
def __bool__(self): | ||
return bool(self._parents) | ||
|
||
def Union(self, a, b): | ||
"""Merge 'a' and 'b' into a single set.""" | ||
root_a = self.Find(a) | ||
root_b = self.Find(b) | ||
if root_a == root_b: | ||
return | ||
|
||
rank_a = self._ranks.setdefault(root_a, 1) | ||
rank_b = self._ranks.setdefault(root_b, 1) | ||
|
||
if rank_a < rank_b: | ||
self._parents[root_a] = root_b | ||
elif rank_a > rank_b: | ||
self._parents[root_b] = root_a | ||
else: | ||
self._parents[root_b] = root_a | ||
self._ranks[root_a] += 1 | ||
|
||
def IsSingleton(self, a): | ||
"""Returns whether set 'a' only contains a single element.""" | ||
return a not in self._ranks | ||
|
||
def Find(self, a): | ||
"""Finds the representative of 'a'. | ||
If 'a' was not seen before, treats it as a singleton set. | ||
Args: | ||
a: object to find a representative for | ||
Returns: | ||
representative of 'a' | ||
""" | ||
if a not in self._parents: | ||
return a | ||
|
||
# Find representative. | ||
path = [a] | ||
root = self._parents[a] | ||
while root != path[-1]: | ||
path.append(root) | ||
root = self._parents.get(root, root) | ||
|
||
# Compress path. | ||
for ancestor in path[::-1]: | ||
self._parents[ancestor] = root | ||
|
||
return root |