Skip to content

Commit

Permalink
add validation tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Nov 14, 2024
1 parent ec4b5d3 commit dd39075
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
5 changes: 5 additions & 0 deletions dacapo/predict_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def predict(
compute_context = create_compute_context()
device = compute_context.device

model_device = next(model.parameters()).device

assert model_device == device, f"Model is not on the right device, Model: {model_device}, Compute device: {device}"

def predict_fn(block):
raw_input = raw_array.to_ndarray(block.read_roi)

Expand All @@ -85,6 +89,7 @@ def predict_fn(block):
raw_input = np.expand_dims(raw_input, 0)
axis_names = ["c^"] + axis_names


with torch.no_grad():
model.eval()
predictions = (
Expand Down
32 changes: 32 additions & 0 deletions tests/operations/test_architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
from pytest_lazy_fixtures import lf
import torch.nn as nn

import logging

Expand All @@ -27,6 +28,7 @@ def test_architecture(
[
lf("dummy_architecture"),
lf("unet_architecture"),
lf("unet_3d_architecture"),
],
)
def test_stored_architecture(
Expand All @@ -48,3 +50,33 @@ def test_stored_architecture(
architecture = retrieved_arch_config.architecture_type(retrieved_arch_config)

assert architecture.dims is not None, f"Architecture dims are None {architecture}"


@pytest.mark.parametrize(
"architecture_config",
[
lf("unet_architecture"),
],
)
def test_2d_conv_unet(
architecture_config,
):
architecture = architecture_config.architecture_type(architecture_config)
for name, module in architecture.named_modules():
if isinstance(module, nn.Conv3d):
raise ValueError(f"Conv3d found in 2d unet {name}")


@pytest.mark.parametrize(
"architecture_config",
[
lf("unet_3d_architecture"),
],
)
def test_2d_conv_unet(
architecture_config,
):
architecture = architecture_config.architecture_type(architecture_config)
for name, module in architecture.named_modules():
if isinstance(module, nn.Conv2d):
raise ValueError(f"Conv2d found in 3d unet {name}")
2 changes: 1 addition & 1 deletion tests/operations/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,4 @@ def test_validate_unet(datasplit, task, trainer, architecture):

# -------------------------------------
weights_store.store_weights(run, 0)
validate_run(run, 0)
validate(run.name, 0)

0 comments on commit dd39075

Please sign in to comment.