-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature inversion task (#81) に関わるPR内のテストコードの作成 #85
Changes from all commits
49b3542
fb203f1
3210670
31d3ba5
4f367b5
3660f40
33f2c9f
b03b68b
6cef2f2
547d337
b038a75
6605f28
7cbfc0e
ca3b162
948d5a5
bd2e180
3f69381
291807f
c87d8bc
7f0c323
13da19f
46d4d10
75a45ab
bbae3d7
556e6f0
7ddd4fb
d86e526
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
"""Tests for bdpy.dl.torch.domain.core.""" | ||
|
||
import unittest | ||
from bdpy.dl.torch.domain import core as core_module | ||
|
||
|
||
class DummyAddDomain(core_module.Domain): | ||
def send(self, num): | ||
return num + 1 | ||
|
||
def receive(self, num): | ||
return num - 1 | ||
|
||
myaokai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
class DummyDoubleDomain(core_module.Domain): | ||
def send(self, num): | ||
return num * 2 | ||
|
||
def receive(self, num): | ||
return num // 2 | ||
|
||
|
||
class DummyUpperCaseDomain(core_module.Domain): | ||
def send(self, text): | ||
return text.upper() | ||
|
||
def receive(self, value): | ||
return value.lower() | ||
|
||
|
||
class TestDomain(unittest.TestCase): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you very much for the implementations of the test cases for the domain module. I think basically the current implementations already cover the minimum test cases. But I would apppreciate it if you could add test cases which check the invertibility of the function
|
||
"""Tests for bdpy.dl.torch.domain.core.Domain.""" | ||
def setUp(self): | ||
self.domain = DummyAddDomain() | ||
self.original_space_num = 0 | ||
self.internal_space_num = 1 | ||
|
||
def test_instantiation(self): | ||
"""Test instantiation.""" | ||
self.assertRaises(TypeError, core_module.Domain) | ||
|
||
def test_send(self): | ||
"""test send""" | ||
self.assertEqual(self.domain.send(self.original_space_num), self.internal_space_num) | ||
|
||
def test_receive(self): | ||
"""test receive""" | ||
self.assertEqual(self.domain.receive(self.internal_space_num), self.original_space_num) | ||
|
||
def test_invertibility(self): | ||
input_candidates = [-1, 0, 1, 0.5] | ||
for x in input_candidates: | ||
assert x == self.domain.send(self.domain.receive(x)) | ||
assert x == self.domain.receive(self.domain.send(x)) | ||
|
||
|
||
class TestInternalDomain(unittest.TestCase): | ||
"""Tests for bdpy.dl.torch.domain.core.InternalDomain.""" | ||
def setUp(self): | ||
self.domain = core_module.InternalDomain() | ||
self.num = 1 | ||
|
||
def test_send(self): | ||
"""test send""" | ||
self.assertEqual(self.domain.send(self.num), self.num) | ||
|
||
def test_receive(self): | ||
"""test receive""" | ||
self.assertEqual(self.domain.receive(self.num), self.num) | ||
|
||
def test_invertibility(self): | ||
input_candidates = [-1, 0, 1, 0.5] | ||
for x in input_candidates: | ||
assert x == self.domain.send(self.domain.receive(x)) | ||
assert x == self.domain.receive(self.domain.send(x)) | ||
|
||
|
||
class TestIrreversibleDomain(unittest.TestCase): | ||
"""Tests for bdpy.dl.torch.domain.core.IrreversibleDomain.""" | ||
def setUp(self): | ||
self.domain = core_module.IrreversibleDomain() | ||
self.num = 1 | ||
|
||
def test_send(self): | ||
"""test send""" | ||
self.assertEqual(self.domain.send(self.num), self.num) | ||
|
||
def test_receive(self): | ||
"""test receive""" | ||
self.assertEqual(self.domain.receive(self.num), self.num) | ||
|
||
|
||
class TestComposedDomain(unittest.TestCase): | ||
"""Tests for bdpy.dl.torch.domain.core.ComposedDomain.""" | ||
def setUp(self): | ||
self.composed_domain = core_module.ComposedDomain([ | ||
DummyDoubleDomain(), | ||
DummyAddDomain(), | ||
]) | ||
self.original_space_num = 0 | ||
self.internal_space_num = 2 | ||
|
||
def test_send(self): | ||
"""test send""" | ||
self.assertEqual(self.composed_domain.send(self.original_space_num), self.internal_space_num) | ||
|
||
def test_receive(self): | ||
"""test receive""" | ||
self.assertEqual(self.composed_domain.receive(self.internal_space_num), self.original_space_num) | ||
|
||
|
||
class TestKeyValueDomain(unittest.TestCase): | ||
"""Tests for bdpy.dl.torch.domain.core.KeyValueDomain.""" | ||
def setUp(self): | ||
self.key_value_domain = core_module.KeyValueDomain({ | ||
"name": DummyUpperCaseDomain(), | ||
"age": DummyDoubleDomain() | ||
}) | ||
self.original_space_data = {"name": "alice", "age": 30} | ||
self.internal_space_data = {"name": "ALICE", "age": 60} | ||
|
||
def test_send(self): | ||
"""test send""" | ||
self.assertEqual(self.key_value_domain.send(self.original_space_data), self.internal_space_data) | ||
|
||
def test_receive(self): | ||
"""test receive""" | ||
self.assertEqual(self.key_value_domain.receive(self.internal_space_data), self.original_space_data) | ||
|
||
|
||
if __name__ == "__main__": | ||
#unittest.main() | ||
composed_domain = core_module.ComposedDomain([ | ||
DummyDoubleDomain(), | ||
DummyAddDomain(), | ||
]) | ||
print(composed_domain.receive(-1)) | ||
print(composed_domain.send(-2)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
"""Tests for bdpy.dl.torch.domain.feature_domain.""" | ||
|
||
import unittest | ||
import torch | ||
from bdpy.dl.torch.domain import feature_domain as feature_domain_module | ||
|
||
|
||
class TestMethods(unittest.TestCase): | ||
def setUp(self): | ||
self.lnd_tensor = torch.empty((12, 196, 768)) | ||
self.nld_tensor = torch.empty((196, 12, 768)) | ||
|
||
def test_lnd2nld(self): | ||
"""test _lnd2nld""" | ||
self.assertEqual(feature_domain_module._lnd2nld(self.lnd_tensor).shape, self.nld_tensor.shape) | ||
|
||
def test_nld2lnd(self): | ||
"""test _nld2lnd""" | ||
self.assertEqual(feature_domain_module._nld2lnd(self.nld_tensor).shape, self.lnd_tensor.shape) | ||
|
||
|
||
class TestArbitraryFeatureKeyDomain(unittest.TestCase): | ||
"""Tests for bdpy.dl.torch.domain.feature_domain.ArbitraryFeatureKeyDomain.""" | ||
def setUp(self): | ||
self.to_internal_mapping = { | ||
"self_key1": "internal_key1", | ||
"self_key2": "internal_key2" | ||
} | ||
self.to_self_mapping = { | ||
"internal_key1": "self_key1", | ||
"internal_key2": "self_key2" | ||
} | ||
self.features = { | ||
"self_key1": 123, | ||
"self_key2": 456 | ||
} | ||
self.internal_features = { | ||
"internal_key1": 123, | ||
"internal_key2": 456 | ||
} | ||
|
||
def test_send(self): | ||
"""test send""" | ||
# when both are specified | ||
domain = feature_domain_module.ArbitraryFeatureKeyDomain( | ||
to_internal=self.to_internal_mapping, | ||
to_self=self.to_self_mapping | ||
) | ||
self.assertEqual(domain.send(self.features), self.internal_features) | ||
|
||
# when only to_self is specified | ||
domain = feature_domain_module.ArbitraryFeatureKeyDomain( | ||
to_self=self.to_self_mapping | ||
) | ||
self.assertEqual(domain.send(self.features), self.internal_features) | ||
|
||
# when only to_internal is specified | ||
domain = feature_domain_module.ArbitraryFeatureKeyDomain( | ||
to_internal=self.to_internal_mapping | ||
) | ||
self.assertEqual(domain.send(self.features), self.internal_features) | ||
|
||
def test_receive(self): | ||
"""test receive""" | ||
# when both are specified | ||
domain = feature_domain_module.ArbitraryFeatureKeyDomain( | ||
to_internal=self.to_internal_mapping, | ||
to_self=self.to_self_mapping | ||
) | ||
self.assertEqual(domain.receive(self.internal_features), self.features) | ||
|
||
# when only to_self is specified | ||
domain = feature_domain_module.ArbitraryFeatureKeyDomain( | ||
to_self=self.to_self_mapping | ||
) | ||
self.assertEqual(domain.receive(self.internal_features), self.features) | ||
|
||
# when only to_internal is specified | ||
domain = feature_domain_module.ArbitraryFeatureKeyDomain( | ||
to_internal=self.to_internal_mapping | ||
) | ||
self.assertEqual(domain.receive(self.internal_features), self.features) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
"""Tests for bdpy.dl.torch.domain.image_domain.""" | ||
|
||
import unittest | ||
import torch | ||
import numpy as np | ||
import warnings | ||
from bdpy.dl.torch.domain import image_domain as iamge_domain_module | ||
|
||
|
||
class TestAffineDomain(unittest.TestCase): | ||
"""Tests for bdpy.dl.torch.domain.image_domain.AffineDomain""" | ||
def setUp(self): | ||
self.center0d = 0.0 | ||
self.center1d = np.random.randn(3) | ||
self.center2d = np.random.randn(32, 32) | ||
self.center3d = np.random.randn(3, 32, 32) | ||
self.scale0d = 1 | ||
self.scale1d = np.random.randn(3) | ||
self.scale2d = np.random.randn(32, 32) | ||
self.scale3d = np.random.randn(3, 32, 32) | ||
self.image = torch.rand((1, 3, 32, 32)) | ||
|
||
def test_instantiation(self): | ||
"""Test instantiation.""" | ||
# Succeeds when center and scale are 0-dimensional | ||
affine_domain = iamge_domain_module.AffineDomain(self.center0d, self.scale0d) | ||
self.assertIsInstance(affine_domain, iamge_domain_module.AffineDomain) | ||
|
||
# Succeeds when center and scale are 1-dimensional | ||
affine_domain = iamge_domain_module.AffineDomain(self.center1d, self.scale1d) | ||
self.assertIsInstance(affine_domain, iamge_domain_module.AffineDomain) | ||
|
||
# Succeeds when center and scale are 3-dimensional | ||
affine_domain = iamge_domain_module.AffineDomain(self.center3d, self.scale3d) | ||
self.assertIsInstance(affine_domain, iamge_domain_module.AffineDomain) | ||
|
||
# Fails when the center is neither 1-dimensional nor 3-dimensional | ||
with self.assertRaises(ValueError): | ||
iamge_domain_module.AffineDomain(self.center2d, self.scale0d) | ||
|
||
# Fails when the scale is neither 1-dimensional nor 3-dimensional | ||
with self.assertRaises(ValueError): | ||
iamge_domain_module.AffineDomain(self.center0d, self.scale2d) | ||
|
||
def test_send_and_receive(self): | ||
"""Test send and receive""" | ||
# when 0d | ||
affine_domain = iamge_domain_module.AffineDomain(self.center0d, self.scale0d) | ||
transformed_image = affine_domain.send(self.image) | ||
center0d = torch.from_numpy(np.array([self.center0d])[np.newaxis, np.newaxis, np.newaxis]) | ||
scale0d = torch.from_numpy(np.array([self.scale0d])[np.newaxis, np.newaxis, np.newaxis]) | ||
expected_transformed_image = (self.image + center0d) / self.scale0d | ||
torch.testing.assert_close(transformed_image, expected_transformed_image) | ||
received_image = affine_domain.receive(transformed_image) | ||
expected_received_image = expected_transformed_image * scale0d - center0d | ||
torch.testing.assert_close(received_image, expected_received_image) | ||
|
||
# when 1d | ||
affine_domain = iamge_domain_module.AffineDomain(self.center1d, self.scale1d) | ||
transformed_image = affine_domain.send(self.image) | ||
center1d = self.center1d[np.newaxis, :, np.newaxis, np.newaxis] | ||
scale1d = self.scale1d[np.newaxis, :, np.newaxis, np.newaxis] | ||
expected_transformed_image = (self.image + center1d) / scale1d | ||
torch.testing.assert_close(transformed_image, expected_transformed_image) | ||
received_image = affine_domain.receive(transformed_image) | ||
expected_received_image = expected_transformed_image * scale1d - center1d | ||
torch.testing.assert_close(received_image, expected_received_image) | ||
|
||
# when 3d | ||
affine_domain = iamge_domain_module.AffineDomain(self.center3d, self.scale3d) | ||
transformed_image = affine_domain.send(self.image) | ||
center3d = self.center3d[np.newaxis] | ||
scale3d = self.scale3d[np.newaxis] | ||
expected_transformed_image = (self.image + center3d) / scale3d | ||
torch.testing.assert_close(transformed_image, expected_transformed_image) | ||
received_image = affine_domain.receive(transformed_image) | ||
expected_received_image = expected_transformed_image * scale3d - center3d | ||
torch.testing.assert_close(received_image, expected_received_image) | ||
|
||
|
||
class TestRGBDomain(unittest.TestCase): | ||
"""Tests fot bdpy.dl.torch.domain.image_domain.BGRDomain""" | ||
|
||
def setUp(self): | ||
self.bgr_image = torch.rand((1, 3, 32, 32)) | ||
self.rgb_image = self.bgr_image[:, [2, 1, 0], ...] | ||
|
||
def test_send(self): | ||
"""Test send""" | ||
bgr_domain = iamge_domain_module.BGRDomain() | ||
transformed_image = bgr_domain.send(self.bgr_image) | ||
torch.testing.assert_close(transformed_image, self.rgb_image) | ||
|
||
def test_receive(self): | ||
"""Tests receive""" | ||
bgr_domain = iamge_domain_module.BGRDomain() | ||
received_image = bgr_domain.receive(self.rgb_image) | ||
torch.testing.assert_close(received_image, self.bgr_image) | ||
|
||
|
||
class TestPILDomainWithExplicitCrop(unittest.TestCase): | ||
"""Tests fot bdpy.dl.torch.domain.image_domain.PILDomainWithExplicitCrop""" | ||
def setUp(self): | ||
self.expected_transformed_image = torch.rand((1, 3, 32, 32)) | ||
self.image = self.expected_transformed_image.permute(0, 2, 3, 1) * 255 | ||
|
||
def test_send(self): | ||
"""Test send""" | ||
pdwe_domain = iamge_domain_module.PILDomainWithExplicitCrop() | ||
transformed_image = pdwe_domain.send(self.image) | ||
torch.testing.assert_close(transformed_image, self.expected_transformed_image) | ||
|
||
def test_receive(self): | ||
"""Tests receive""" | ||
pdwe_domain = iamge_domain_module.PILDomainWithExplicitCrop() | ||
with warnings.catch_warnings(record=True) as w: | ||
received_image = pdwe_domain.receive(self.expected_transformed_image) | ||
self.assertTrue(any(isinstance(warn.message, RuntimeWarning) for warn in w)) | ||
torch.testing.assert_close(received_image, self.image) | ||
|
||
|
||
class TestFixedResolutionDomain(unittest.TestCase): | ||
"""Tests fot bdpy.dl.torch.domain.image_domain.FixedResolutionDomain""" | ||
def setUp(self): | ||
self.expected_received_image_size = (1, 3, 16, 16) | ||
self.image =torch.rand((1, 3, 32, 32)) | ||
|
||
def test_send(self): | ||
"""Test send""" | ||
fr_domain = iamge_domain_module.FixedResolutionDomain((16, 16)) | ||
with self.assertRaises(RuntimeError): | ||
fr_domain.send(self.image) | ||
|
||
def test_receive(self): | ||
"""Tests receive""" | ||
fr_domain = iamge_domain_module.FixedResolutionDomain((16, 16)) | ||
|
||
received_image = fr_domain.receive(self.image) | ||
self.assertEqual(received_image.size(), self.expected_received_image_size) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thank you very much for fixing the potential bug I made 🙇