-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #85 from KamitaniLab/myaokai/feat_inv_test
Feature inversion task (#81) に関わるPR内のテストコードの作成
- Loading branch information
Showing
9 changed files
with
778 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
"""Tests for bdpy.dl.torch.domain.core.""" | ||
|
||
import unittest | ||
from bdpy.dl.torch.domain import core as core_module | ||
|
||
|
||
class DummyAddDomain(core_module.Domain): | ||
def send(self, num): | ||
return num + 1 | ||
|
||
def receive(self, num): | ||
return num - 1 | ||
|
||
|
||
class DummyDoubleDomain(core_module.Domain): | ||
def send(self, num): | ||
return num * 2 | ||
|
||
def receive(self, num): | ||
return num // 2 | ||
|
||
|
||
class DummyUpperCaseDomain(core_module.Domain): | ||
def send(self, text): | ||
return text.upper() | ||
|
||
def receive(self, value): | ||
return value.lower() | ||
|
||
|
||
class TestDomain(unittest.TestCase): | ||
"""Tests for bdpy.dl.torch.domain.core.Domain.""" | ||
def setUp(self): | ||
self.domain = DummyAddDomain() | ||
self.original_space_num = 0 | ||
self.internal_space_num = 1 | ||
|
||
def test_instantiation(self): | ||
"""Test instantiation.""" | ||
self.assertRaises(TypeError, core_module.Domain) | ||
|
||
def test_send(self): | ||
"""test send""" | ||
self.assertEqual(self.domain.send(self.original_space_num), self.internal_space_num) | ||
|
||
def test_receive(self): | ||
"""test receive""" | ||
self.assertEqual(self.domain.receive(self.internal_space_num), self.original_space_num) | ||
|
||
def test_invertibility(self): | ||
input_candidates = [-1, 0, 1, 0.5] | ||
for x in input_candidates: | ||
assert x == self.domain.send(self.domain.receive(x)) | ||
assert x == self.domain.receive(self.domain.send(x)) | ||
|
||
|
||
class TestInternalDomain(unittest.TestCase): | ||
"""Tests for bdpy.dl.torch.domain.core.InternalDomain.""" | ||
def setUp(self): | ||
self.domain = core_module.InternalDomain() | ||
self.num = 1 | ||
|
||
def test_send(self): | ||
"""test send""" | ||
self.assertEqual(self.domain.send(self.num), self.num) | ||
|
||
def test_receive(self): | ||
"""test receive""" | ||
self.assertEqual(self.domain.receive(self.num), self.num) | ||
|
||
def test_invertibility(self): | ||
input_candidates = [-1, 0, 1, 0.5] | ||
for x in input_candidates: | ||
assert x == self.domain.send(self.domain.receive(x)) | ||
assert x == self.domain.receive(self.domain.send(x)) | ||
|
||
|
||
class TestIrreversibleDomain(unittest.TestCase): | ||
"""Tests for bdpy.dl.torch.domain.core.IrreversibleDomain.""" | ||
def setUp(self): | ||
self.domain = core_module.IrreversibleDomain() | ||
self.num = 1 | ||
|
||
def test_send(self): | ||
"""test send""" | ||
self.assertEqual(self.domain.send(self.num), self.num) | ||
|
||
def test_receive(self): | ||
"""test receive""" | ||
self.assertEqual(self.domain.receive(self.num), self.num) | ||
|
||
|
||
class TestComposedDomain(unittest.TestCase): | ||
"""Tests for bdpy.dl.torch.domain.core.ComposedDomain.""" | ||
def setUp(self): | ||
self.composed_domain = core_module.ComposedDomain([ | ||
DummyDoubleDomain(), | ||
DummyAddDomain(), | ||
]) | ||
self.original_space_num = 0 | ||
self.internal_space_num = 2 | ||
|
||
def test_send(self): | ||
"""test send""" | ||
self.assertEqual(self.composed_domain.send(self.original_space_num), self.internal_space_num) | ||
|
||
def test_receive(self): | ||
"""test receive""" | ||
self.assertEqual(self.composed_domain.receive(self.internal_space_num), self.original_space_num) | ||
|
||
|
||
class TestKeyValueDomain(unittest.TestCase): | ||
"""Tests for bdpy.dl.torch.domain.core.KeyValueDomain.""" | ||
def setUp(self): | ||
self.key_value_domain = core_module.KeyValueDomain({ | ||
"name": DummyUpperCaseDomain(), | ||
"age": DummyDoubleDomain() | ||
}) | ||
self.original_space_data = {"name": "alice", "age": 30} | ||
self.internal_space_data = {"name": "ALICE", "age": 60} | ||
|
||
def test_send(self): | ||
"""test send""" | ||
self.assertEqual(self.key_value_domain.send(self.original_space_data), self.internal_space_data) | ||
|
||
def test_receive(self): | ||
"""test receive""" | ||
self.assertEqual(self.key_value_domain.receive(self.internal_space_data), self.original_space_data) | ||
|
||
|
||
if __name__ == "__main__": | ||
#unittest.main() | ||
composed_domain = core_module.ComposedDomain([ | ||
DummyDoubleDomain(), | ||
DummyAddDomain(), | ||
]) | ||
print(composed_domain.receive(-1)) | ||
print(composed_domain.send(-2)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
"""Tests for bdpy.dl.torch.domain.feature_domain.""" | ||
|
||
import unittest | ||
import torch | ||
from bdpy.dl.torch.domain import feature_domain as feature_domain_module | ||
|
||
|
||
class TestMethods(unittest.TestCase): | ||
def setUp(self): | ||
self.lnd_tensor = torch.empty((12, 196, 768)) | ||
self.nld_tensor = torch.empty((196, 12, 768)) | ||
|
||
def test_lnd2nld(self): | ||
"""test _lnd2nld""" | ||
self.assertEqual(feature_domain_module._lnd2nld(self.lnd_tensor).shape, self.nld_tensor.shape) | ||
|
||
def test_nld2lnd(self): | ||
"""test _nld2lnd""" | ||
self.assertEqual(feature_domain_module._nld2lnd(self.nld_tensor).shape, self.lnd_tensor.shape) | ||
|
||
|
||
class TestArbitraryFeatureKeyDomain(unittest.TestCase): | ||
"""Tests for bdpy.dl.torch.domain.feature_domain.ArbitraryFeatureKeyDomain.""" | ||
def setUp(self): | ||
self.to_internal_mapping = { | ||
"self_key1": "internal_key1", | ||
"self_key2": "internal_key2" | ||
} | ||
self.to_self_mapping = { | ||
"internal_key1": "self_key1", | ||
"internal_key2": "self_key2" | ||
} | ||
self.features = { | ||
"self_key1": 123, | ||
"self_key2": 456 | ||
} | ||
self.internal_features = { | ||
"internal_key1": 123, | ||
"internal_key2": 456 | ||
} | ||
|
||
def test_send(self): | ||
"""test send""" | ||
# when both are specified | ||
domain = feature_domain_module.ArbitraryFeatureKeyDomain( | ||
to_internal=self.to_internal_mapping, | ||
to_self=self.to_self_mapping | ||
) | ||
self.assertEqual(domain.send(self.features), self.internal_features) | ||
|
||
# when only to_self is specified | ||
domain = feature_domain_module.ArbitraryFeatureKeyDomain( | ||
to_self=self.to_self_mapping | ||
) | ||
self.assertEqual(domain.send(self.features), self.internal_features) | ||
|
||
# when only to_internal is specified | ||
domain = feature_domain_module.ArbitraryFeatureKeyDomain( | ||
to_internal=self.to_internal_mapping | ||
) | ||
self.assertEqual(domain.send(self.features), self.internal_features) | ||
|
||
def test_receive(self): | ||
"""test receive""" | ||
# when both are specified | ||
domain = feature_domain_module.ArbitraryFeatureKeyDomain( | ||
to_internal=self.to_internal_mapping, | ||
to_self=self.to_self_mapping | ||
) | ||
self.assertEqual(domain.receive(self.internal_features), self.features) | ||
|
||
# when only to_self is specified | ||
domain = feature_domain_module.ArbitraryFeatureKeyDomain( | ||
to_self=self.to_self_mapping | ||
) | ||
self.assertEqual(domain.receive(self.internal_features), self.features) | ||
|
||
# when only to_internal is specified | ||
domain = feature_domain_module.ArbitraryFeatureKeyDomain( | ||
to_internal=self.to_internal_mapping | ||
) | ||
self.assertEqual(domain.receive(self.internal_features), self.features) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
"""Tests for bdpy.dl.torch.domain.image_domain.""" | ||
|
||
import unittest | ||
import torch | ||
import numpy as np | ||
import warnings | ||
from bdpy.dl.torch.domain import image_domain as iamge_domain_module | ||
|
||
|
||
class TestAffineDomain(unittest.TestCase): | ||
"""Tests for bdpy.dl.torch.domain.image_domain.AffineDomain""" | ||
def setUp(self): | ||
self.center0d = 0.0 | ||
self.center1d = np.random.randn(3) | ||
self.center2d = np.random.randn(32, 32) | ||
self.center3d = np.random.randn(3, 32, 32) | ||
self.scale0d = 1 | ||
self.scale1d = np.random.randn(3) | ||
self.scale2d = np.random.randn(32, 32) | ||
self.scale3d = np.random.randn(3, 32, 32) | ||
self.image = torch.rand((1, 3, 32, 32)) | ||
|
||
def test_instantiation(self): | ||
"""Test instantiation.""" | ||
# Succeeds when center and scale are 0-dimensional | ||
affine_domain = iamge_domain_module.AffineDomain(self.center0d, self.scale0d) | ||
self.assertIsInstance(affine_domain, iamge_domain_module.AffineDomain) | ||
|
||
# Succeeds when center and scale are 1-dimensional | ||
affine_domain = iamge_domain_module.AffineDomain(self.center1d, self.scale1d) | ||
self.assertIsInstance(affine_domain, iamge_domain_module.AffineDomain) | ||
|
||
# Succeeds when center and scale are 3-dimensional | ||
affine_domain = iamge_domain_module.AffineDomain(self.center3d, self.scale3d) | ||
self.assertIsInstance(affine_domain, iamge_domain_module.AffineDomain) | ||
|
||
# Fails when the center is neither 1-dimensional nor 3-dimensional | ||
with self.assertRaises(ValueError): | ||
iamge_domain_module.AffineDomain(self.center2d, self.scale0d) | ||
|
||
# Fails when the scale is neither 1-dimensional nor 3-dimensional | ||
with self.assertRaises(ValueError): | ||
iamge_domain_module.AffineDomain(self.center0d, self.scale2d) | ||
|
||
def test_send_and_receive(self): | ||
"""Test send and receive""" | ||
# when 0d | ||
affine_domain = iamge_domain_module.AffineDomain(self.center0d, self.scale0d) | ||
transformed_image = affine_domain.send(self.image) | ||
center0d = torch.from_numpy(np.array([self.center0d])[np.newaxis, np.newaxis, np.newaxis]) | ||
scale0d = torch.from_numpy(np.array([self.scale0d])[np.newaxis, np.newaxis, np.newaxis]) | ||
expected_transformed_image = (self.image + center0d) / self.scale0d | ||
torch.testing.assert_close(transformed_image, expected_transformed_image) | ||
received_image = affine_domain.receive(transformed_image) | ||
expected_received_image = expected_transformed_image * scale0d - center0d | ||
torch.testing.assert_close(received_image, expected_received_image) | ||
|
||
# when 1d | ||
affine_domain = iamge_domain_module.AffineDomain(self.center1d, self.scale1d) | ||
transformed_image = affine_domain.send(self.image) | ||
center1d = self.center1d[np.newaxis, :, np.newaxis, np.newaxis] | ||
scale1d = self.scale1d[np.newaxis, :, np.newaxis, np.newaxis] | ||
expected_transformed_image = (self.image + center1d) / scale1d | ||
torch.testing.assert_close(transformed_image, expected_transformed_image) | ||
received_image = affine_domain.receive(transformed_image) | ||
expected_received_image = expected_transformed_image * scale1d - center1d | ||
torch.testing.assert_close(received_image, expected_received_image) | ||
|
||
# when 3d | ||
affine_domain = iamge_domain_module.AffineDomain(self.center3d, self.scale3d) | ||
transformed_image = affine_domain.send(self.image) | ||
center3d = self.center3d[np.newaxis] | ||
scale3d = self.scale3d[np.newaxis] | ||
expected_transformed_image = (self.image + center3d) / scale3d | ||
torch.testing.assert_close(transformed_image, expected_transformed_image) | ||
received_image = affine_domain.receive(transformed_image) | ||
expected_received_image = expected_transformed_image * scale3d - center3d | ||
torch.testing.assert_close(received_image, expected_received_image) | ||
|
||
|
||
class TestRGBDomain(unittest.TestCase): | ||
"""Tests fot bdpy.dl.torch.domain.image_domain.BGRDomain""" | ||
|
||
def setUp(self): | ||
self.bgr_image = torch.rand((1, 3, 32, 32)) | ||
self.rgb_image = self.bgr_image[:, [2, 1, 0], ...] | ||
|
||
def test_send(self): | ||
"""Test send""" | ||
bgr_domain = iamge_domain_module.BGRDomain() | ||
transformed_image = bgr_domain.send(self.bgr_image) | ||
torch.testing.assert_close(transformed_image, self.rgb_image) | ||
|
||
def test_receive(self): | ||
"""Tests receive""" | ||
bgr_domain = iamge_domain_module.BGRDomain() | ||
received_image = bgr_domain.receive(self.rgb_image) | ||
torch.testing.assert_close(received_image, self.bgr_image) | ||
|
||
|
||
class TestPILDomainWithExplicitCrop(unittest.TestCase): | ||
"""Tests fot bdpy.dl.torch.domain.image_domain.PILDomainWithExplicitCrop""" | ||
def setUp(self): | ||
self.expected_transformed_image = torch.rand((1, 3, 32, 32)) | ||
self.image = self.expected_transformed_image.permute(0, 2, 3, 1) * 255 | ||
|
||
def test_send(self): | ||
"""Test send""" | ||
pdwe_domain = iamge_domain_module.PILDomainWithExplicitCrop() | ||
transformed_image = pdwe_domain.send(self.image) | ||
torch.testing.assert_close(transformed_image, self.expected_transformed_image) | ||
|
||
def test_receive(self): | ||
"""Tests receive""" | ||
pdwe_domain = iamge_domain_module.PILDomainWithExplicitCrop() | ||
with warnings.catch_warnings(record=True) as w: | ||
received_image = pdwe_domain.receive(self.expected_transformed_image) | ||
self.assertTrue(any(isinstance(warn.message, RuntimeWarning) for warn in w)) | ||
torch.testing.assert_close(received_image, self.image) | ||
|
||
|
||
class TestFixedResolutionDomain(unittest.TestCase): | ||
"""Tests fot bdpy.dl.torch.domain.image_domain.FixedResolutionDomain""" | ||
def setUp(self): | ||
self.expected_received_image_size = (1, 3, 16, 16) | ||
self.image =torch.rand((1, 3, 32, 32)) | ||
|
||
def test_send(self): | ||
"""Test send""" | ||
fr_domain = iamge_domain_module.FixedResolutionDomain((16, 16)) | ||
with self.assertRaises(RuntimeError): | ||
fr_domain.send(self.image) | ||
|
||
def test_receive(self): | ||
"""Tests receive""" | ||
fr_domain = iamge_domain_module.FixedResolutionDomain((16, 16)) | ||
|
||
received_image = fr_domain.receive(self.image) | ||
self.assertEqual(received_image.size(), self.expected_received_image_size) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Oops, something went wrong.