Skip to content

Commit

Permalink
Add checkerboard inversion
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Jun 6, 2024
1 parent 84ad062 commit 3e36b6e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,13 @@ class Checkerboard(Coupling):
Checkerboard coupling for image data.
"""

def __init__(self, event_shape, resolution: int = 2):
def __init__(self, event_shape, resolution: int = 2, invert: bool = False):
"""
:param event_shape: image shape with the form (n_channels, width, height). Note: width and height must be equal
and a power of two.
:param resolution: resolution of the checkerboard along one axis - the number of squares. Must be a power of two
and smaller than image width.
:param invert: invert the checkerboard mask.
"""
n_channels, width, _ = event_shape
assert width % resolution == 0
Expand All @@ -93,6 +94,8 @@ def __init__(self, event_shape, resolution: int = 2):
a = torch.tensor([[1, 0] * half_resolution, [0, 1] * half_resolution] * half_resolution)
mask = torch.kron(a, torch.ones((square_side_length, square_side_length)))
mask = mask.bool()
if invert:
mask = ~mask
super().__init__(event_shape, mask)


Expand Down
16 changes: 16 additions & 0 deletions test/test_checkerboard_coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,19 @@ def test_checkerboard_medium():
], dtype=torch.bool)[None].repeat(3, 1, 1)
)
assert torch.allclose(coupling.target_mask, ~coupling.source_mask)


def test_checkerboard_small_inverted():
torch.manual_seed(0)
image_shape = (3, 4, 4)
coupling = Checkerboard(image_shape, resolution=2, invert=True)
assert torch.allclose(
coupling.source_mask,
~torch.tensor([
[1, 1, 0, 0],
[1, 1, 0, 0],
[0, 0, 1, 1],
[0, 0, 1, 1],
], dtype=torch.bool)[None].repeat(3, 1, 1)
)
assert torch.allclose(coupling.target_mask, ~coupling.source_mask)

0 comments on commit 3e36b6e

Please sign in to comment.