Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature inversion task (#81) に関わるPR内のテストコードの作成 #85

Merged
merged 27 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
49b3542
first commit
myaokai Jan 23, 2024
fb203f1
create test files
myaokai Jan 23, 2024
3210670
add TestBaseLatent
myaokai Jan 24, 2024
31d3ba5
Update tests/recon/torch/modules/test_latent.py
myaokai Jan 24, 2024
4f367b5
Update tests/recon/torch/modules/test_latent.py
myaokai Jan 24, 2024
3660f40
Update tests/recon/torch/modules/test_latent.py
myaokai Jan 24, 2024
33f2c9f
update test latent
myaokai Jan 25, 2024
b03b68b
create test_core
myaokai Jan 25, 2024
6cef2f2
update test_core
myaokai Jan 25, 2024
547d337
create test_feature_domain
myaokai Jan 25, 2024
b038a75
create test_image_domain
myaokai Jan 26, 2024
6605f28
Update test_image_domain.py
myaokai Jan 30, 2024
7cbfc0e
create task/test_core
myaokai Jan 30, 2024
ca3b162
create inversion test
myaokai Feb 6, 2024
948d5a5
Update test_inversion.py
myaokai Feb 8, 2024
bd2e180
Update test_inversion.py
myaokai Feb 8, 2024
3f69381
Update tests/dl/torch/domain/test_core.py
myaokai Mar 21, 2024
291807f
Update tests/dl/torch/domain/test_core.py
myaokai Mar 21, 2024
c87d8bc
Update tests/dl/torch/domain/test_image_domain.py
myaokai Mar 21, 2024
7f0c323
Update tests/dl/torch/domain/test_image_domain.py
myaokai Mar 21, 2024
13da19f
Update tests/dl/torch/domain/test_core.py
myaokai Mar 21, 2024
46d4d10
Update tests/recon/torch/modules/test_latent.py
myaokai Mar 21, 2024
75a45ab
Update tests/recon/torch/task/test_inversion.py
myaokai Mar 21, 2024
bbae3d7
Update test_core.py
myaokai Mar 21, 2024
556e6f0
fix
myaokai Mar 21, 2024
7ddd4fb
fix
myaokai Mar 21, 2024
d86e526
fix
myaokai Mar 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bdpy/dl/torch/domain/image_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(

if isinstance(center, (float, int)) or center.ndim == 0:
center = np.array([center])[np.newaxis, np.newaxis, np.newaxis]
if center.ndim == 1: # 1D vector (C,)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you very much for fixing the potential bug I made 🙇

elif center.ndim == 1: # 1D vector (C,)
center = center[np.newaxis, :, np.newaxis, np.newaxis]
elif center.ndim == 3: # 3D vector (1, C, W, H)
center = center[np.newaxis]
Expand Down
Empty file.
138 changes: 138 additions & 0 deletions tests/dl/torch/domain/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Tests for bdpy.dl.torch.domain.core."""

import unittest
from bdpy.dl.torch.domain import core as core_module


class DummyAddDomain(core_module.Domain):
def send(self, num):
return num + 1

def receive(self, num):
return num - 1

myaokai marked this conversation as resolved.
Show resolved Hide resolved

class DummyDoubleDomain(core_module.Domain):
def send(self, num):
return num * 2

def receive(self, num):
return num // 2


class DummyUpperCaseDomain(core_module.Domain):
def send(self, text):
return text.upper()

def receive(self, value):
return value.lower()


class TestDomain(unittest.TestCase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you very much for the implementations of the test cases for the domain module. I think basically the current implementations already cover the minimum test cases. But I would apppreciate it if you could add test cases which check the invertibility of the function domain.send . domain.receive over several example inputs including edge cases. I anticipate the tests like following:

domain = ...  # instantiate some subclass of the reversible domain class
input_candidates = [...]  # candidate of possible inputs including typical & edge-case examples
for x in input_candidates:
    assert x == domain.send(domain.receive(x))
    assert x == domain.receive(domain.send(x))

"""Tests for bdpy.dl.torch.domain.core.Domain."""
def setUp(self):
self.domain = DummyAddDomain()
self.original_space_num = 0
self.internal_space_num = 1

def test_instantiation(self):
"""Test instantiation."""
self.assertRaises(TypeError, core_module.Domain)

def test_send(self):
"""test send"""
self.assertEqual(self.domain.send(self.original_space_num), self.internal_space_num)

def test_receive(self):
"""test receive"""
self.assertEqual(self.domain.receive(self.internal_space_num), self.original_space_num)

def test_invertibility(self):
input_candidates = [-1, 0, 1, 0.5]
for x in input_candidates:
assert x == self.domain.send(self.domain.receive(x))
assert x == self.domain.receive(self.domain.send(x))


class TestInternalDomain(unittest.TestCase):
"""Tests for bdpy.dl.torch.domain.core.InternalDomain."""
def setUp(self):
self.domain = core_module.InternalDomain()
self.num = 1

def test_send(self):
"""test send"""
self.assertEqual(self.domain.send(self.num), self.num)

def test_receive(self):
"""test receive"""
self.assertEqual(self.domain.receive(self.num), self.num)

def test_invertibility(self):
input_candidates = [-1, 0, 1, 0.5]
for x in input_candidates:
assert x == self.domain.send(self.domain.receive(x))
assert x == self.domain.receive(self.domain.send(x))


class TestIrreversibleDomain(unittest.TestCase):
"""Tests for bdpy.dl.torch.domain.core.IrreversibleDomain."""
def setUp(self):
self.domain = core_module.IrreversibleDomain()
self.num = 1

def test_send(self):
"""test send"""
self.assertEqual(self.domain.send(self.num), self.num)

def test_receive(self):
"""test receive"""
self.assertEqual(self.domain.receive(self.num), self.num)


class TestComposedDomain(unittest.TestCase):
"""Tests for bdpy.dl.torch.domain.core.ComposedDomain."""
def setUp(self):
self.composed_domain = core_module.ComposedDomain([
DummyDoubleDomain(),
DummyAddDomain(),
])
self.original_space_num = 0
self.internal_space_num = 2

def test_send(self):
"""test send"""
self.assertEqual(self.composed_domain.send(self.original_space_num), self.internal_space_num)

def test_receive(self):
"""test receive"""
self.assertEqual(self.composed_domain.receive(self.internal_space_num), self.original_space_num)


class TestKeyValueDomain(unittest.TestCase):
"""Tests for bdpy.dl.torch.domain.core.KeyValueDomain."""
def setUp(self):
self.key_value_domain = core_module.KeyValueDomain({
"name": DummyUpperCaseDomain(),
"age": DummyDoubleDomain()
})
self.original_space_data = {"name": "alice", "age": 30}
self.internal_space_data = {"name": "ALICE", "age": 60}

def test_send(self):
"""test send"""
self.assertEqual(self.key_value_domain.send(self.original_space_data), self.internal_space_data)

def test_receive(self):
"""test receive"""
self.assertEqual(self.key_value_domain.receive(self.internal_space_data), self.original_space_data)


if __name__ == "__main__":
#unittest.main()
composed_domain = core_module.ComposedDomain([
DummyDoubleDomain(),
DummyAddDomain(),
])
print(composed_domain.receive(-1))
print(composed_domain.send(-2))
86 changes: 86 additions & 0 deletions tests/dl/torch/domain/test_feature_domain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Tests for bdpy.dl.torch.domain.feature_domain."""

import unittest
import torch
from bdpy.dl.torch.domain import feature_domain as feature_domain_module


class TestMethods(unittest.TestCase):
def setUp(self):
self.lnd_tensor = torch.empty((12, 196, 768))
self.nld_tensor = torch.empty((196, 12, 768))

def test_lnd2nld(self):
"""test _lnd2nld"""
self.assertEqual(feature_domain_module._lnd2nld(self.lnd_tensor).shape, self.nld_tensor.shape)

def test_nld2lnd(self):
"""test _nld2lnd"""
self.assertEqual(feature_domain_module._nld2lnd(self.nld_tensor).shape, self.lnd_tensor.shape)


class TestArbitraryFeatureKeyDomain(unittest.TestCase):
"""Tests for bdpy.dl.torch.domain.feature_domain.ArbitraryFeatureKeyDomain."""
def setUp(self):
self.to_internal_mapping = {
"self_key1": "internal_key1",
"self_key2": "internal_key2"
}
self.to_self_mapping = {
"internal_key1": "self_key1",
"internal_key2": "self_key2"
}
self.features = {
"self_key1": 123,
"self_key2": 456
}
self.internal_features = {
"internal_key1": 123,
"internal_key2": 456
}

def test_send(self):
"""test send"""
# when both are specified
domain = feature_domain_module.ArbitraryFeatureKeyDomain(
to_internal=self.to_internal_mapping,
to_self=self.to_self_mapping
)
self.assertEqual(domain.send(self.features), self.internal_features)

# when only to_self is specified
domain = feature_domain_module.ArbitraryFeatureKeyDomain(
to_self=self.to_self_mapping
)
self.assertEqual(domain.send(self.features), self.internal_features)

# when only to_internal is specified
domain = feature_domain_module.ArbitraryFeatureKeyDomain(
to_internal=self.to_internal_mapping
)
self.assertEqual(domain.send(self.features), self.internal_features)

def test_receive(self):
"""test receive"""
# when both are specified
domain = feature_domain_module.ArbitraryFeatureKeyDomain(
to_internal=self.to_internal_mapping,
to_self=self.to_self_mapping
)
self.assertEqual(domain.receive(self.internal_features), self.features)

# when only to_self is specified
domain = feature_domain_module.ArbitraryFeatureKeyDomain(
to_self=self.to_self_mapping
)
self.assertEqual(domain.receive(self.internal_features), self.features)

# when only to_internal is specified
domain = feature_domain_module.ArbitraryFeatureKeyDomain(
to_internal=self.to_internal_mapping
)
self.assertEqual(domain.receive(self.internal_features), self.features)


if __name__ == "__main__":
unittest.main()
143 changes: 143 additions & 0 deletions tests/dl/torch/domain/test_image_domain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Tests for bdpy.dl.torch.domain.image_domain."""

import unittest
import torch
import numpy as np
import warnings
from bdpy.dl.torch.domain import image_domain as iamge_domain_module


class TestAffineDomain(unittest.TestCase):
"""Tests for bdpy.dl.torch.domain.image_domain.AffineDomain"""
def setUp(self):
self.center0d = 0.0
self.center1d = np.random.randn(3)
self.center2d = np.random.randn(32, 32)
self.center3d = np.random.randn(3, 32, 32)
self.scale0d = 1
self.scale1d = np.random.randn(3)
self.scale2d = np.random.randn(32, 32)
self.scale3d = np.random.randn(3, 32, 32)
self.image = torch.rand((1, 3, 32, 32))

def test_instantiation(self):
"""Test instantiation."""
# Succeeds when center and scale are 0-dimensional
affine_domain = iamge_domain_module.AffineDomain(self.center0d, self.scale0d)
self.assertIsInstance(affine_domain, iamge_domain_module.AffineDomain)

# Succeeds when center and scale are 1-dimensional
affine_domain = iamge_domain_module.AffineDomain(self.center1d, self.scale1d)
self.assertIsInstance(affine_domain, iamge_domain_module.AffineDomain)

# Succeeds when center and scale are 3-dimensional
affine_domain = iamge_domain_module.AffineDomain(self.center3d, self.scale3d)
self.assertIsInstance(affine_domain, iamge_domain_module.AffineDomain)

# Fails when the center is neither 1-dimensional nor 3-dimensional
with self.assertRaises(ValueError):
iamge_domain_module.AffineDomain(self.center2d, self.scale0d)

# Fails when the scale is neither 1-dimensional nor 3-dimensional
with self.assertRaises(ValueError):
iamge_domain_module.AffineDomain(self.center0d, self.scale2d)

def test_send_and_receive(self):
"""Test send and receive"""
# when 0d
affine_domain = iamge_domain_module.AffineDomain(self.center0d, self.scale0d)
transformed_image = affine_domain.send(self.image)
center0d = torch.from_numpy(np.array([self.center0d])[np.newaxis, np.newaxis, np.newaxis])
scale0d = torch.from_numpy(np.array([self.scale0d])[np.newaxis, np.newaxis, np.newaxis])
expected_transformed_image = (self.image + center0d) / self.scale0d
torch.testing.assert_close(transformed_image, expected_transformed_image)
received_image = affine_domain.receive(transformed_image)
expected_received_image = expected_transformed_image * scale0d - center0d
torch.testing.assert_close(received_image, expected_received_image)

# when 1d
affine_domain = iamge_domain_module.AffineDomain(self.center1d, self.scale1d)
transformed_image = affine_domain.send(self.image)
center1d = self.center1d[np.newaxis, :, np.newaxis, np.newaxis]
scale1d = self.scale1d[np.newaxis, :, np.newaxis, np.newaxis]
expected_transformed_image = (self.image + center1d) / scale1d
torch.testing.assert_close(transformed_image, expected_transformed_image)
received_image = affine_domain.receive(transformed_image)
expected_received_image = expected_transformed_image * scale1d - center1d
torch.testing.assert_close(received_image, expected_received_image)

# when 3d
affine_domain = iamge_domain_module.AffineDomain(self.center3d, self.scale3d)
transformed_image = affine_domain.send(self.image)
center3d = self.center3d[np.newaxis]
scale3d = self.scale3d[np.newaxis]
expected_transformed_image = (self.image + center3d) / scale3d
torch.testing.assert_close(transformed_image, expected_transformed_image)
received_image = affine_domain.receive(transformed_image)
expected_received_image = expected_transformed_image * scale3d - center3d
torch.testing.assert_close(received_image, expected_received_image)


class TestRGBDomain(unittest.TestCase):
"""Tests fot bdpy.dl.torch.domain.image_domain.BGRDomain"""

def setUp(self):
self.bgr_image = torch.rand((1, 3, 32, 32))
self.rgb_image = self.bgr_image[:, [2, 1, 0], ...]

def test_send(self):
"""Test send"""
bgr_domain = iamge_domain_module.BGRDomain()
transformed_image = bgr_domain.send(self.bgr_image)
torch.testing.assert_close(transformed_image, self.rgb_image)

def test_receive(self):
"""Tests receive"""
bgr_domain = iamge_domain_module.BGRDomain()
received_image = bgr_domain.receive(self.rgb_image)
torch.testing.assert_close(received_image, self.bgr_image)


class TestPILDomainWithExplicitCrop(unittest.TestCase):
"""Tests fot bdpy.dl.torch.domain.image_domain.PILDomainWithExplicitCrop"""
def setUp(self):
self.expected_transformed_image = torch.rand((1, 3, 32, 32))
self.image = self.expected_transformed_image.permute(0, 2, 3, 1) * 255

def test_send(self):
"""Test send"""
pdwe_domain = iamge_domain_module.PILDomainWithExplicitCrop()
transformed_image = pdwe_domain.send(self.image)
torch.testing.assert_close(transformed_image, self.expected_transformed_image)

def test_receive(self):
"""Tests receive"""
pdwe_domain = iamge_domain_module.PILDomainWithExplicitCrop()
with warnings.catch_warnings(record=True) as w:
received_image = pdwe_domain.receive(self.expected_transformed_image)
self.assertTrue(any(isinstance(warn.message, RuntimeWarning) for warn in w))
torch.testing.assert_close(received_image, self.image)


class TestFixedResolutionDomain(unittest.TestCase):
"""Tests fot bdpy.dl.torch.domain.image_domain.FixedResolutionDomain"""
def setUp(self):
self.expected_received_image_size = (1, 3, 16, 16)
self.image =torch.rand((1, 3, 32, 32))

def test_send(self):
"""Test send"""
fr_domain = iamge_domain_module.FixedResolutionDomain((16, 16))
with self.assertRaises(RuntimeError):
fr_domain.send(self.image)

def test_receive(self):
"""Tests receive"""
fr_domain = iamge_domain_module.FixedResolutionDomain((16, 16))

received_image = fr_domain.receive(self.image)
self.assertEqual(received_image.size(), self.expected_received_image_size)


if __name__ == "__main__":
unittest.main()
Loading