Skip to content

Commit

Permalink
Merge pull request #85 from KamitaniLab/myaokai/feat_inv_test
Browse files Browse the repository at this point in the history
Feature inversion task (#81) に関わるPR内のテストコードの作成
  • Loading branch information
ShuntaroAoki authored Apr 4, 2024
2 parents 4d3c8af + d86e526 commit ff116ef
Show file tree
Hide file tree
Showing 9 changed files with 778 additions and 1 deletion.
2 changes: 1 addition & 1 deletion bdpy/dl/torch/domain/image_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(

if isinstance(center, (float, int)) or center.ndim == 0:
center = np.array([center])[np.newaxis, np.newaxis, np.newaxis]
if center.ndim == 1: # 1D vector (C,)
elif center.ndim == 1: # 1D vector (C,)
center = center[np.newaxis, :, np.newaxis, np.newaxis]
elif center.ndim == 3: # 3D vector (1, C, W, H)
center = center[np.newaxis]
Expand Down
Empty file.
138 changes: 138 additions & 0 deletions tests/dl/torch/domain/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Tests for bdpy.dl.torch.domain.core."""

import unittest
from bdpy.dl.torch.domain import core as core_module


class DummyAddDomain(core_module.Domain):
def send(self, num):
return num + 1

def receive(self, num):
return num - 1


class DummyDoubleDomain(core_module.Domain):
def send(self, num):
return num * 2

def receive(self, num):
return num // 2


class DummyUpperCaseDomain(core_module.Domain):
def send(self, text):
return text.upper()

def receive(self, value):
return value.lower()


class TestDomain(unittest.TestCase):
"""Tests for bdpy.dl.torch.domain.core.Domain."""
def setUp(self):
self.domain = DummyAddDomain()
self.original_space_num = 0
self.internal_space_num = 1

def test_instantiation(self):
"""Test instantiation."""
self.assertRaises(TypeError, core_module.Domain)

def test_send(self):
"""test send"""
self.assertEqual(self.domain.send(self.original_space_num), self.internal_space_num)

def test_receive(self):
"""test receive"""
self.assertEqual(self.domain.receive(self.internal_space_num), self.original_space_num)

def test_invertibility(self):
input_candidates = [-1, 0, 1, 0.5]
for x in input_candidates:
assert x == self.domain.send(self.domain.receive(x))
assert x == self.domain.receive(self.domain.send(x))


class TestInternalDomain(unittest.TestCase):
"""Tests for bdpy.dl.torch.domain.core.InternalDomain."""
def setUp(self):
self.domain = core_module.InternalDomain()
self.num = 1

def test_send(self):
"""test send"""
self.assertEqual(self.domain.send(self.num), self.num)

def test_receive(self):
"""test receive"""
self.assertEqual(self.domain.receive(self.num), self.num)

def test_invertibility(self):
input_candidates = [-1, 0, 1, 0.5]
for x in input_candidates:
assert x == self.domain.send(self.domain.receive(x))
assert x == self.domain.receive(self.domain.send(x))


class TestIrreversibleDomain(unittest.TestCase):
"""Tests for bdpy.dl.torch.domain.core.IrreversibleDomain."""
def setUp(self):
self.domain = core_module.IrreversibleDomain()
self.num = 1

def test_send(self):
"""test send"""
self.assertEqual(self.domain.send(self.num), self.num)

def test_receive(self):
"""test receive"""
self.assertEqual(self.domain.receive(self.num), self.num)


class TestComposedDomain(unittest.TestCase):
"""Tests for bdpy.dl.torch.domain.core.ComposedDomain."""
def setUp(self):
self.composed_domain = core_module.ComposedDomain([
DummyDoubleDomain(),
DummyAddDomain(),
])
self.original_space_num = 0
self.internal_space_num = 2

def test_send(self):
"""test send"""
self.assertEqual(self.composed_domain.send(self.original_space_num), self.internal_space_num)

def test_receive(self):
"""test receive"""
self.assertEqual(self.composed_domain.receive(self.internal_space_num), self.original_space_num)


class TestKeyValueDomain(unittest.TestCase):
"""Tests for bdpy.dl.torch.domain.core.KeyValueDomain."""
def setUp(self):
self.key_value_domain = core_module.KeyValueDomain({
"name": DummyUpperCaseDomain(),
"age": DummyDoubleDomain()
})
self.original_space_data = {"name": "alice", "age": 30}
self.internal_space_data = {"name": "ALICE", "age": 60}

def test_send(self):
"""test send"""
self.assertEqual(self.key_value_domain.send(self.original_space_data), self.internal_space_data)

def test_receive(self):
"""test receive"""
self.assertEqual(self.key_value_domain.receive(self.internal_space_data), self.original_space_data)


if __name__ == "__main__":
#unittest.main()
composed_domain = core_module.ComposedDomain([
DummyDoubleDomain(),
DummyAddDomain(),
])
print(composed_domain.receive(-1))
print(composed_domain.send(-2))
86 changes: 86 additions & 0 deletions tests/dl/torch/domain/test_feature_domain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Tests for bdpy.dl.torch.domain.feature_domain."""

import unittest
import torch
from bdpy.dl.torch.domain import feature_domain as feature_domain_module


class TestMethods(unittest.TestCase):
def setUp(self):
self.lnd_tensor = torch.empty((12, 196, 768))
self.nld_tensor = torch.empty((196, 12, 768))

def test_lnd2nld(self):
"""test _lnd2nld"""
self.assertEqual(feature_domain_module._lnd2nld(self.lnd_tensor).shape, self.nld_tensor.shape)

def test_nld2lnd(self):
"""test _nld2lnd"""
self.assertEqual(feature_domain_module._nld2lnd(self.nld_tensor).shape, self.lnd_tensor.shape)


class TestArbitraryFeatureKeyDomain(unittest.TestCase):
"""Tests for bdpy.dl.torch.domain.feature_domain.ArbitraryFeatureKeyDomain."""
def setUp(self):
self.to_internal_mapping = {
"self_key1": "internal_key1",
"self_key2": "internal_key2"
}
self.to_self_mapping = {
"internal_key1": "self_key1",
"internal_key2": "self_key2"
}
self.features = {
"self_key1": 123,
"self_key2": 456
}
self.internal_features = {
"internal_key1": 123,
"internal_key2": 456
}

def test_send(self):
"""test send"""
# when both are specified
domain = feature_domain_module.ArbitraryFeatureKeyDomain(
to_internal=self.to_internal_mapping,
to_self=self.to_self_mapping
)
self.assertEqual(domain.send(self.features), self.internal_features)

# when only to_self is specified
domain = feature_domain_module.ArbitraryFeatureKeyDomain(
to_self=self.to_self_mapping
)
self.assertEqual(domain.send(self.features), self.internal_features)

# when only to_internal is specified
domain = feature_domain_module.ArbitraryFeatureKeyDomain(
to_internal=self.to_internal_mapping
)
self.assertEqual(domain.send(self.features), self.internal_features)

def test_receive(self):
"""test receive"""
# when both are specified
domain = feature_domain_module.ArbitraryFeatureKeyDomain(
to_internal=self.to_internal_mapping,
to_self=self.to_self_mapping
)
self.assertEqual(domain.receive(self.internal_features), self.features)

# when only to_self is specified
domain = feature_domain_module.ArbitraryFeatureKeyDomain(
to_self=self.to_self_mapping
)
self.assertEqual(domain.receive(self.internal_features), self.features)

# when only to_internal is specified
domain = feature_domain_module.ArbitraryFeatureKeyDomain(
to_internal=self.to_internal_mapping
)
self.assertEqual(domain.receive(self.internal_features), self.features)


if __name__ == "__main__":
unittest.main()
143 changes: 143 additions & 0 deletions tests/dl/torch/domain/test_image_domain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Tests for bdpy.dl.torch.domain.image_domain."""

import unittest
import torch
import numpy as np
import warnings
from bdpy.dl.torch.domain import image_domain as iamge_domain_module


class TestAffineDomain(unittest.TestCase):
"""Tests for bdpy.dl.torch.domain.image_domain.AffineDomain"""
def setUp(self):
self.center0d = 0.0
self.center1d = np.random.randn(3)
self.center2d = np.random.randn(32, 32)
self.center3d = np.random.randn(3, 32, 32)
self.scale0d = 1
self.scale1d = np.random.randn(3)
self.scale2d = np.random.randn(32, 32)
self.scale3d = np.random.randn(3, 32, 32)
self.image = torch.rand((1, 3, 32, 32))

def test_instantiation(self):
"""Test instantiation."""
# Succeeds when center and scale are 0-dimensional
affine_domain = iamge_domain_module.AffineDomain(self.center0d, self.scale0d)
self.assertIsInstance(affine_domain, iamge_domain_module.AffineDomain)

# Succeeds when center and scale are 1-dimensional
affine_domain = iamge_domain_module.AffineDomain(self.center1d, self.scale1d)
self.assertIsInstance(affine_domain, iamge_domain_module.AffineDomain)

# Succeeds when center and scale are 3-dimensional
affine_domain = iamge_domain_module.AffineDomain(self.center3d, self.scale3d)
self.assertIsInstance(affine_domain, iamge_domain_module.AffineDomain)

# Fails when the center is neither 1-dimensional nor 3-dimensional
with self.assertRaises(ValueError):
iamge_domain_module.AffineDomain(self.center2d, self.scale0d)

# Fails when the scale is neither 1-dimensional nor 3-dimensional
with self.assertRaises(ValueError):
iamge_domain_module.AffineDomain(self.center0d, self.scale2d)

def test_send_and_receive(self):
"""Test send and receive"""
# when 0d
affine_domain = iamge_domain_module.AffineDomain(self.center0d, self.scale0d)
transformed_image = affine_domain.send(self.image)
center0d = torch.from_numpy(np.array([self.center0d])[np.newaxis, np.newaxis, np.newaxis])
scale0d = torch.from_numpy(np.array([self.scale0d])[np.newaxis, np.newaxis, np.newaxis])
expected_transformed_image = (self.image + center0d) / self.scale0d
torch.testing.assert_close(transformed_image, expected_transformed_image)
received_image = affine_domain.receive(transformed_image)
expected_received_image = expected_transformed_image * scale0d - center0d
torch.testing.assert_close(received_image, expected_received_image)

# when 1d
affine_domain = iamge_domain_module.AffineDomain(self.center1d, self.scale1d)
transformed_image = affine_domain.send(self.image)
center1d = self.center1d[np.newaxis, :, np.newaxis, np.newaxis]
scale1d = self.scale1d[np.newaxis, :, np.newaxis, np.newaxis]
expected_transformed_image = (self.image + center1d) / scale1d
torch.testing.assert_close(transformed_image, expected_transformed_image)
received_image = affine_domain.receive(transformed_image)
expected_received_image = expected_transformed_image * scale1d - center1d
torch.testing.assert_close(received_image, expected_received_image)

# when 3d
affine_domain = iamge_domain_module.AffineDomain(self.center3d, self.scale3d)
transformed_image = affine_domain.send(self.image)
center3d = self.center3d[np.newaxis]
scale3d = self.scale3d[np.newaxis]
expected_transformed_image = (self.image + center3d) / scale3d
torch.testing.assert_close(transformed_image, expected_transformed_image)
received_image = affine_domain.receive(transformed_image)
expected_received_image = expected_transformed_image * scale3d - center3d
torch.testing.assert_close(received_image, expected_received_image)


class TestRGBDomain(unittest.TestCase):
"""Tests fot bdpy.dl.torch.domain.image_domain.BGRDomain"""

def setUp(self):
self.bgr_image = torch.rand((1, 3, 32, 32))
self.rgb_image = self.bgr_image[:, [2, 1, 0], ...]

def test_send(self):
"""Test send"""
bgr_domain = iamge_domain_module.BGRDomain()
transformed_image = bgr_domain.send(self.bgr_image)
torch.testing.assert_close(transformed_image, self.rgb_image)

def test_receive(self):
"""Tests receive"""
bgr_domain = iamge_domain_module.BGRDomain()
received_image = bgr_domain.receive(self.rgb_image)
torch.testing.assert_close(received_image, self.bgr_image)


class TestPILDomainWithExplicitCrop(unittest.TestCase):
"""Tests fot bdpy.dl.torch.domain.image_domain.PILDomainWithExplicitCrop"""
def setUp(self):
self.expected_transformed_image = torch.rand((1, 3, 32, 32))
self.image = self.expected_transformed_image.permute(0, 2, 3, 1) * 255

def test_send(self):
"""Test send"""
pdwe_domain = iamge_domain_module.PILDomainWithExplicitCrop()
transformed_image = pdwe_domain.send(self.image)
torch.testing.assert_close(transformed_image, self.expected_transformed_image)

def test_receive(self):
"""Tests receive"""
pdwe_domain = iamge_domain_module.PILDomainWithExplicitCrop()
with warnings.catch_warnings(record=True) as w:
received_image = pdwe_domain.receive(self.expected_transformed_image)
self.assertTrue(any(isinstance(warn.message, RuntimeWarning) for warn in w))
torch.testing.assert_close(received_image, self.image)


class TestFixedResolutionDomain(unittest.TestCase):
"""Tests fot bdpy.dl.torch.domain.image_domain.FixedResolutionDomain"""
def setUp(self):
self.expected_received_image_size = (1, 3, 16, 16)
self.image =torch.rand((1, 3, 32, 32))

def test_send(self):
"""Test send"""
fr_domain = iamge_domain_module.FixedResolutionDomain((16, 16))
with self.assertRaises(RuntimeError):
fr_domain.send(self.image)

def test_receive(self):
"""Tests receive"""
fr_domain = iamge_domain_module.FixedResolutionDomain((16, 16))

received_image = fr_domain.receive(self.image)
self.assertEqual(received_image.size(), self.expected_received_image_size)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit ff116ef

Please sign in to comment.