From dd39075af8f7466f88021b81eac9d77ce6a262cd Mon Sep 17 00:00:00 2001 From: mzouink Date: Thu, 14 Nov 2024 10:53:56 -0500 Subject: [PATCH] add validation tests --- dacapo/predict_local.py | 5 +++++ tests/operations/test_architecture.py | 32 +++++++++++++++++++++++++++ tests/operations/test_validate.py | 2 +- 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/dacapo/predict_local.py b/dacapo/predict_local.py index 674d00a40..875bed602 100644 --- a/dacapo/predict_local.py +++ b/dacapo/predict_local.py @@ -71,6 +71,10 @@ def predict( compute_context = create_compute_context() device = compute_context.device + model_device = next(model.parameters()).device + + assert model_device == device, f"Model is not on the right device, Model: {model_device}, Compute device: {device}" + def predict_fn(block): raw_input = raw_array.to_ndarray(block.read_roi) @@ -85,6 +89,7 @@ def predict_fn(block): raw_input = np.expand_dims(raw_input, 0) axis_names = ["c^"] + axis_names + with torch.no_grad(): model.eval() predictions = ( diff --git a/tests/operations/test_architecture.py b/tests/operations/test_architecture.py index de5d44f61..272c410fc 100644 --- a/tests/operations/test_architecture.py +++ b/tests/operations/test_architecture.py @@ -2,6 +2,7 @@ import pytest from pytest_lazy_fixtures import lf +import torch.nn as nn import logging @@ -27,6 +28,7 @@ def test_architecture( [ lf("dummy_architecture"), lf("unet_architecture"), + lf("unet_3d_architecture"), ], ) def test_stored_architecture( @@ -48,3 +50,33 @@ def test_stored_architecture( architecture = retrieved_arch_config.architecture_type(retrieved_arch_config) assert architecture.dims is not None, f"Architecture dims are None {architecture}" + + +@pytest.mark.parametrize( + "architecture_config", + [ + lf("unet_architecture"), + ], +) +def test_2d_conv_unet( + architecture_config, +): + architecture = architecture_config.architecture_type(architecture_config) + for name, module in architecture.named_modules(): + if isinstance(module, nn.Conv3d): + raise ValueError(f"Conv3d found in 2d unet {name}") + + +@pytest.mark.parametrize( + "architecture_config", + [ + lf("unet_3d_architecture"), + ], +) +def test_2d_conv_unet( + architecture_config, +): + architecture = architecture_config.architecture_type(architecture_config) + for name, module in architecture.named_modules(): + if isinstance(module, nn.Conv2d): + raise ValueError(f"Conv2d found in 3d unet {name}") diff --git a/tests/operations/test_validate.py b/tests/operations/test_validate.py index 1ae80cd01..78319654c 100644 --- a/tests/operations/test_validate.py +++ b/tests/operations/test_validate.py @@ -130,4 +130,4 @@ def test_validate_unet(datasplit, task, trainer, architecture): # ------------------------------------- weights_store.store_weights(run, 0) - validate_run(run, 0) + validate(run.name, 0)