diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py
index 3ea282acc..23f9a14fe 100644
--- a/tests/fixtures/__init__.py
+++ b/tests/fixtures/__init__.py
@@ -1,5 +1,5 @@
 from .db import options
-from .architectures import dummy_architecture
+from .architectures import dummy_architecture, unet_architecture
 from .arrays import dummy_array, zarr_array, cellmap_array
 from .datasplits import dummy_datasplit, twelve_class_datasplit, six_class_datasplit
 from .evaluators import binary_3_channel_evaluator
diff --git a/tests/fixtures/architectures.py b/tests/fixtures/architectures.py
index 6980c8f6b..e940e5aed 100644
--- a/tests/fixtures/architectures.py
+++ b/tests/fixtures/architectures.py
@@ -1,4 +1,7 @@
-from dacapo.experiments.architectures import DummyArchitectureConfig
+from dacapo.experiments.architectures import (
+    DummyArchitectureConfig,
+    CNNectomeUNetConfig,
+)
 
 import pytest
 
@@ -8,3 +11,21 @@ def dummy_architecture():
     yield DummyArchitectureConfig(
         name="dummy_architecture", num_in_channels=1, num_out_channels=12
     )
+
+
+@pytest.fixture()
+def unet_architecture():
+    yield CNNectomeUNetConfig(
+        name="tmp_unet_architecture",
+        input_shape=(2, 132, 132),
+        eval_shape_increase=(8, 32, 32),
+        fmaps_in=2,
+        num_fmaps=8,
+        fmaps_out=8,
+        fmap_inc_factor=2,
+        downsample_factors=[(1, 4, 4), (1, 4, 4)],
+        kernel_size_down=[[(1, 3, 3)] * 2] * 3,
+        kernel_size_up=[[(1, 3, 3)] * 2] * 2,
+        constant_upsample=True,
+        padding="valid",
+    )
diff --git a/tests/operations/test_architecture.py b/tests/operations/test_architecture.py
new file mode 100644
index 000000000..5ba387b44
--- /dev/null
+++ b/tests/operations/test_architecture.py
@@ -0,0 +1,26 @@
+from ..fixtures import *
+
+import pytest
+from pytest_lazy_fixtures import lf
+
+import logging
+
+logging.basicConfig(level=logging.INFO)
+
+
+@pytest.mark.parametrize(
+    "architecture_config",
+    [
+        lf("dummy_architecture"),
+        lf("unet_architecture"),
+    ],
+)
+def test_architecture(
+    architecture_config,
+):
+
+    architecture_type = architecture_config.architecture_type
+
+    architecture = architecture_type(architecture_config)
+
+    assert architecture.dims is not None, f"Architecture dims are None {architecture}"
diff --git a/tests/operations/test_context.py b/tests/operations/test_context.py
index b7f70500e..b2924e721 100644
--- a/tests/operations/test_context.py
+++ b/tests/operations/test_context.py
@@ -3,12 +3,20 @@
 import pytest
 
 
-@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("device", [""])
 def test_create_compute_context(device):
     compute_context = create_compute_context()
     assert compute_context is not None
     assert compute_context.device is not None
     if torch.cuda.is_available():
-        assert compute_context.device == torch.device('cuda'), "Model is not on CUDA when CUDA is available {}".format(compute_context.device)
+        assert compute_context.device == torch.device(
+            "cuda"
+        ), "Model is not on CUDA when CUDA is available {}".format(
+            compute_context.device
+        )
     else:
-        assert compute_context.device == torch.device('cpu'), "Model is not on CPU when CUDA is not available {}".format(compute_context.device)
\ No newline at end of file
+        assert compute_context.device == torch.device(
+            "cpu"
+        ), "Model is not on CPU when CUDA is not available {}".format(
+            compute_context.device
+        )
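
For reference, a minimal standalone sketch of what the new test exercises, re-traced outside of pytest. It only repeats what test_architecture does with the unet_architecture fixture, assuming dacapo is importable and that CNNectomeUNetConfig accepts exactly the keyword arguments used in the fixture above:

    # Minimal sketch, assuming the dacapo API shown in the diff.
    from dacapo.experiments.architectures import CNNectomeUNetConfig

    # Same configuration as the unet_architecture fixture.
    config = CNNectomeUNetConfig(
        name="tmp_unet_architecture",
        input_shape=(2, 132, 132),
        eval_shape_increase=(8, 32, 32),
        fmaps_in=2,
        num_fmaps=8,
        fmaps_out=8,
        fmap_inc_factor=2,
        downsample_factors=[(1, 4, 4), (1, 4, 4)],
        kernel_size_down=[[(1, 3, 3)] * 2] * 3,
        kernel_size_up=[[(1, 3, 3)] * 2] * 2,
        constant_upsample=True,
        padding="valid",
    )

    # Mirrors test_architecture: build the architecture via the config's
    # architecture_type indirection and check its dimensionality is set.
    architecture = config.architecture_type(config)
    assert architecture.dims is not None, f"Architecture dims are None {architecture}"

Going through config.architecture_type(config) rather than a concrete class keeps the test agnostic to the architecture implementation, which is why the same parametrized test covers both the dummy and the UNet configs.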