diff --git a/tests/test_bounding_box.py b/tests/test_bounding_box.py new file mode 100644 index 000000000..cf7dccfb1 --- /dev/null +++ b/tests/test_bounding_box.py @@ -0,0 +1,31 @@ +from wkcuber.api.bounding_box import BoundingBox, Mag +import pytest + + +def test_align_with_mag(): + + assert BoundingBox((1, 1, 1), (10, 10, 10)).align_with_mag(Mag(2)) == BoundingBox( + topleft=(0, 0, 0), size=(12, 12, 12) + ) + assert BoundingBox((1, 1, 1), (9, 9, 9)).align_with_mag(Mag(2)) == BoundingBox( + topleft=(0, 0, 0), size=(10, 10, 10) + ) + assert BoundingBox((1, 1, 1), (9, 9, 9)).align_with_mag(Mag(4)) == BoundingBox( + topleft=(0, 0, 0), size=(12, 12, 12) + ) + assert BoundingBox((1, 2, 3), (9, 9, 9)).align_with_mag(Mag(2)) == BoundingBox( + topleft=(0, 2, 2), size=(10, 10, 10) + ) + + +def test_in_mag(): + + with pytest.raises(AssertionError): + BoundingBox((1, 2, 3), (9, 9, 9)).in_mag(Mag(2)) + + with pytest.raises(AssertionError): + BoundingBox((2, 2, 2), (9, 9, 9)).in_mag(Mag(2)) + + assert BoundingBox((2, 2, 2), (10, 10, 10)).in_mag(Mag(2)) == BoundingBox( + topleft=(1, 1, 1), size=(5, 5, 5) + ) diff --git a/wkcuber/api/bounding_box.py b/wkcuber/api/bounding_box.py index 8e5f21c17..f380560f2 100644 --- a/wkcuber/api/bounding_box.py +++ b/wkcuber/api/bounding_box.py @@ -185,20 +185,30 @@ def is_empty(self) -> bool: return not all(self.size > 0) - def in_mag(self, mag: Mag, ceil: bool = False) -> "BoundingBox": + def in_mag(self, mag: Mag) -> "BoundingBox": np_mag = np.array(mag.to_array()) - def ceil_maybe(array: np.ndarray) -> np.ndarray: - if ceil: - return np.ceil(array) - return array + assert ( + np.count_nonzero(self.topleft % np_mag) == 0 + ), f"topleft {self.topleft} is not aligned with the mag {mag}. Use BoundingBox.align_with_mag()." + assert ( + np.count_nonzero(self.bottomright % np_mag) == 0 + ), f"bottomright {self.bottomright} is not aligned with the mag {mag}. Use BoundingBox.align_with_mag()." return BoundingBox( - topleft=ceil_maybe(self.topleft / np_mag).astype(np.int), - size=ceil_maybe(self.size / np_mag).astype(np.int), + topleft=(self.topleft // np_mag).astype(np.int), + size=(self.size // np_mag).astype(np.int), ) + def align_with_mag(self, mag: Mag): + """Rounds the bounding box up, so that both topleft and bottomright are divisible by mag.""" + + np_mag = np.array(mag.to_array()) + topleft = (self.topleft // np_mag).astype(np.int) * np_mag + bottomright = np.ceil(self.bottomright / np_mag).astype(np.int) * np_mag + return BoundingBox(topleft, bottomright - topleft) + def contains(self, coord: Shape3D) -> bool: coord = np.array(coord)