Skip to content

Commit

Permalink
feat: kornia binary morph ops (#428)
Browse files Browse the repository at this point in the history
Co-authored-by: Sergiy Popovych <[email protected]>
  • Loading branch information
nkemnitz and supersergiy authored Aug 1, 2023
1 parent 7254e12 commit e7dc3cb
Show file tree
Hide file tree
Showing 3 changed files with 308 additions and 82 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ tensor-ops = [
"fastremap >= 1.12.2",
"einops >= 0.4.1",
"torchfields >= 0.1.2",
"kornia >= 0.6.12",
]
viz = [
"zetta_utils[tensor_ops]",
Expand Down
183 changes: 144 additions & 39 deletions tests/unit/tensor_ops/test_mask.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# pylint: disable=missing-docstring,invalid-name
import numpy as np
import pytest
import skimage
import torch

from zetta_utils.tensor_ops import mask
Expand Down Expand Up @@ -80,66 +83,168 @@ def test_filter_cc_big():
assert_array_equal(result, expected)


def test_coarsen_width1():
a = (
torch.Tensor(
def test_kornia_closing():
a = np.expand_dims(
np.array(
[
[
[1, 1, 0, 0],
[1, 1, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 1],
[0, 1, 1, 1],
[0, 1, 0, 1],
[0, 1, 1, 1],
]
]
).unsqueeze(-1)
> 0
],
dtype=np.uint8,
),
-1,
)

expected = (
torch.Tensor(
expected = np.expand_dims(
np.array(
[
[
[1, 1, 1, 0],
[1, 1, 1, 0],
[1, 1, 1, 1],
[0, 0, 0, 0],
[0, 1, 1, 0],
[0, 1, 1, 0],
[0, 0, 0, 0],
]
],
dtype=np.uint8,
),
-1,
)

result = mask.kornia_closing(
a, "square", width=3, border_type="constant", border_value=0, device="cpu"
)
assert_array_equal(result, expected)


def test_kornia_opening():
a = np.expand_dims(
np.array(
[
[
[0, 0, 0, 0],
[0, 1, 1, 1],
[0, 0, 1, 1],
[0, 1, 1, 1],
]
]
).unsqueeze(-1)
> 0
),
-1,
)

result = mask.coarsen(
a,
width=1,
expected = np.expand_dims(
np.array(
[
[
[0, 0, 0, 0],
[0, 0, 1, 1],
[0, 0, 1, 1],
[0, 0, 1, 1],
]
]
),
-1,
)

result = mask.kornia_opening(a, torch.ones(3, 3), border_type="geodesic")
assert_array_equal(result, expected)


def test_binary_closing():
a = torch.tensor(
[
def test_kornia_dilation():
a = np.expand_dims(
np.array(
[
[0, 0, 0, 0],
[0, 1, 1, 1],
[0, 1, 0, 1],
[0, 1, 1, 1],
[
[0, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
]
]
]
).unsqueeze(-1)
),
-1,
)

expected = torch.Tensor(
[
expected = np.expand_dims(
np.array(
[
[0, 0, 0, 0],
[0, 1, 1, 1],
[0, 1, 1, 1],
[0, 1, 1, 1],
[
[1, 1, 1, 0],
[1, 1, 1, 0],
[1, 1, 1, 0],
[0, 0, 0, 0],
]
]
]
).unsqueeze(-1)
),
-1,
)

result = mask.binary_closing(
a,
result = mask.kornia_dilation(a, torch.ones(3, 3), border_type="geodesic")
assert_array_equal(result, expected)


def test_kornia_erosion():
a = np.expand_dims(
np.array(
[
[
[0, 0, 0, 0],
[0, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 0, 0],
]
]
),
-1,
)

expected = np.expand_dims(
np.array(
[
[
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[1, 0, 0, 0],
]
]
),
-1,
)
assert_array_equal(result.bool(), expected.bool())

result = mask.kornia_erosion(a, torch.ones(3, 3), border_type="geodesic")
assert_array_equal(result, expected)


@pytest.mark.parametrize(
"kernel, width, expected_kernel",
[
["square", 7, torch.ones(7, 7)],
["diamond", 7, torch.tensor(skimage.morphology.diamond(7))],
["star", 5, torch.tensor(skimage.morphology.star(5))],
["disk", 7, torch.tensor(skimage.morphology.disk(7))],
[torch.ones(5, 3), None, torch.ones(5, 3)],
[np.ones((5, 3)), None, torch.ones(5, 3)],
],
)
def test_normalize_kernel(kernel, width, expected_kernel):
result = mask._normalize_kernel(kernel, width, device=None) # pylint: disable=protected-access
assert_array_equal(result, expected_kernel)


@pytest.mark.parametrize(
"kernel, width, expected_exc",
[
["ball", 7, ValueError],
["square", 2.5, TypeError],
["square", -1, ValueError],
[torch.ones(5, 3, 2), None, ValueError],
[np.ones((5, 3, 2)), None, ValueError],
],
)
def test_normalize_kernel_exc(kernel, width, expected_exc):
with pytest.raises(expected_exc):
mask._normalize_kernel(kernel, width, device=None) # pylint: disable=protected-access
Loading

0 comments on commit e7dc3cb

Please sign in to comment.