Skip to content

Commit

Permalink
error handling 2d unet
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Nov 14, 2024
1 parent dd39075 commit 4a18c42
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 80 deletions.
4 changes: 2 additions & 2 deletions dacapo/predict_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def predict(
compute_context = create_compute_context()
device = compute_context.device

model_device = next(model.parameters()).device
model_device = str(next(model.parameters()).device).split(":")[0]

assert model_device == device, f"Model is not on the right device, Model: {model_device}, Compute device: {device}"
assert model_device == str(device), f"Model is not on the right device, Model: {model_device}, Compute device: {device}"

def predict_fn(block):
raw_input = raw_array.to_ndarray(block.read_roi)
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@
from .losses import dummy_loss
from .post_processors import argmax, threshold
from .predictors import distance_predictor, onehot_predictor
from .runs import dummy_run, distance_run, onehot_run
from .runs import dummy_run, distance_run, onehot_run, unet_2d_distance_run, unet_3d_distance_run
from .tasks import dummy_task, distance_task, onehot_task, six_onehot_task
from .trainers import dummy_trainer, gunpowder_trainer
22 changes: 11 additions & 11 deletions tests/fixtures/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ def dummy_architecture():
def unet_architecture():
yield CNNectomeUNetConfig(
name="tmp_unet_architecture",
input_shape=(2, 132, 132),
eval_shape_increase=(8, 32, 32),
input_shape=(132, 132),
eval_shape_increase=(32, 32),
fmaps_in=1,
num_fmaps=8,
fmaps_out=8,
fmap_inc_factor=2,
downsample_factors=[(1, 4, 4), (1, 4, 4)],
kernel_size_down=[[(1, 3, 3)] * 2] * 3,
kernel_size_up=[[(1, 3, 3)] * 2] * 2,
downsample_factors=[(4, 4), (4, 4)],
kernel_size_down=[[(3, 3)] * 2] * 3,
kernel_size_up=[[(3, 3)] * 2] * 2,
constant_upsample=True,
padding="valid",
)
Expand Down Expand Up @@ -70,18 +70,18 @@ def unet_architecture_builder(batch_norm, upsample, use_attention, three_d):
else:
return CNNectomeUNetConfig(
name=name,
input_shape=(2, 132, 132),
eval_shape_increase=(8, 32, 32),
input_shape=(132, 132),
eval_shape_increase=(32, 32),
fmaps_in=1,
num_fmaps=8,
fmaps_out=8,
fmap_inc_factor=2,
downsample_factors=[(1, 4, 4), (1, 4, 4)],
kernel_size_down=[[(1, 3, 3)] * 2] * 3,
kernel_size_up=[[(1, 3, 3)] * 2] * 2,
downsample_factors=[(4, 4), ( 4, 4)],
kernel_size_down=[[( 3, 3)] * 2] * 3,
kernel_size_up=[[(3, 3)] * 2] * 2,
constant_upsample=True,
padding="valid",
batch_norm=batch_norm,
use_attention=use_attention,
upsample_factors=[(1, 2, 2)] if upsample else [],
upsample_factors=[(2, 2)] if upsample else [],
)
43 changes: 40 additions & 3 deletions tests/fixtures/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def distance_run(
trainer_config=gunpowder_trainer,
datasplit_config=six_class_datasplit,
repetition=0,
num_iterations=100,
num_iterations=10,
)


Expand All @@ -35,7 +35,7 @@ def dummy_run(
trainer_config=dummy_trainer,
datasplit_config=dummy_datasplit,
repetition=0,
num_iterations=100,
num_iterations=10,
)


Expand All @@ -53,5 +53,42 @@ def onehot_run(
trainer_config=gunpowder_trainer,
datasplit_config=twelve_class_datasplit,
repetition=0,
num_iterations=100,
num_iterations=10,
)


@pytest.fixture()
def unet_2d_distance_run(
six_class_datasplit,
unet_architecture,
distance_task,
gunpowder_trainer,
):
yield RunConfig(
name="unet_2d_distance_run",
task_config=distance_task,
architecture_config=unet_architecture,
trainer_config=gunpowder_trainer,
datasplit_config=six_class_datasplit,
repetition=0,
num_iterations=10,
)



@pytest.fixture()
def unet_3d_distance_run(
six_class_datasplit,
unet_3d_architecture,
distance_task,
gunpowder_trainer,
):
yield RunConfig(
name="unet_3d_distance_run",
task_config=distance_task,
architecture_config=unet_3d_architecture,
trainer_config=gunpowder_trainer,
datasplit_config=six_class_datasplit,
repetition=0,
num_iterations=10,
)
37 changes: 35 additions & 2 deletions tests/operations/test_architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from pytest_lazy_fixtures import lf
import torch.nn as nn

from dacapo.experiments import Run
import logging

logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -58,7 +58,7 @@ def test_stored_architecture(
lf("unet_architecture"),
],
)
def test_2d_conv_unet(
def test_3d_conv_unet(
architecture_config,
):
architecture = architecture_config.architecture_type(architecture_config)
Expand All @@ -80,3 +80,36 @@ def test_2d_conv_unet(
for name, module in architecture.named_modules():
if isinstance(module, nn.Conv2d):
raise ValueError(f"Conv2d found in 3d unet {name}")



@pytest.mark.parametrize(
"run_config",
[
lf("unet_2d_distance_run"),
],
)
def test_2d_conv_unet_in_run(
run_config,
):
run = Run(run_config)
model = run.model
for name, module in model.named_modules():
if isinstance(module, nn.Conv3d):
raise ValueError(f"Conv3d found in 2d unet {name}")


@pytest.mark.parametrize(
"run_config",
[
lf("unet_3d_distance_run"),
],
)
def test_3d_conv_unet_in_run(
run_config,
):
run = Run(run_config)
model = run.model
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d):
raise ValueError(f"Conv2d found in 3d unet {name}")
120 changes: 61 additions & 59 deletions tests/operations/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,21 @@ def test_train(
@pytest.mark.parametrize("upsample", [False])
@pytest.mark.parametrize("use_attention", [False])
@pytest.mark.parametrize("three_d", [False])
def test_train_unet(
def test_train_non_stored_unet(
datasplit, task, trainer, batch_norm, upsample, use_attention, three_d
):
architecture_config = unet_architecture_builder(
batch_norm, upsample, use_attention, three_d
)

run_config = RunConfig(
name=f"{architecture_config.name}_run",
name=f"{architecture_config.name}_run_v",
task_config=task,
architecture_config=architecture_config,
trainer_config=trainer,
datasplit_config=datasplit,
repetition=0,
num_iterations=2,
num_iterations=1,
)
run = Run(run_config)
train_run(run)
Expand All @@ -110,14 +110,15 @@ def test_train_unet(
batch_norm, upsample, use_attention, three_d
)


run_config = RunConfig(
name=f"{architecture_config.name}_run",
task_config=task,
architecture_config=architecture_config,
trainer_config=trainer,
datasplit_config=datasplit,
repetition=0,
num_iterations=2,
num_iterations=1,
)
try:
store.store_run_config(run_config)
Expand All @@ -135,7 +136,7 @@ def test_train_unet(
train_run(run)

init_weights = weights_store.retrieve_weights(run.name, 0)
final_weights = weights_store.retrieve_weights(run.name, run.train_until)
final_weights = weights_store.retrieve_weights(run.name, 1)

for name, weight in init_weights.model.items():
weight_diff = (weight - final_weights.model[name]).any()
Expand All @@ -148,57 +149,58 @@ def test_train_unet(
assert training_stats.trained_until() == run_config.num_iterations


@pytest.mark.parametrize("upsample_datasplit", [lf("upsample_six_class_datasplit")])
@pytest.mark.parametrize("task", [lf("distance_task")])
@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("upsample", [True])
@pytest.mark.parametrize("use_attention", [True, False])
@pytest.mark.parametrize("three_d", [True, False])
def test_upsample_train_unet(
upsample_datasplit, task, trainer, batch_norm, upsample, use_attention, three_d
):
store = create_config_store()
stats_store = create_stats_store()
weights_store = create_weights_store()

architecture_config = unet_architecture_builder(
batch_norm, upsample, use_attention, three_d
)

run_config = RunConfig(
name=f"{architecture_config.name}_run",
task_config=task,
architecture_config=architecture_config,
trainer_config=trainer,
datasplit_config=upsample_datasplit,
repetition=0,
num_iterations=2,
)
try:
store.store_run_config(run_config)
except Exception as e:
store.delete_run_config(run_config.name)
store.store_run_config(run_config)

run = Run(run_config)

# -------------------------------------

# train

weights_store.store_weights(run, 0)
train_run(run)

init_weights = weights_store.retrieve_weights(run.name, 0)
final_weights = weights_store.retrieve_weights(run.name, run.train_until)

for name, weight in init_weights.model.items():
weight_diff = (weight - final_weights.model[name]).any()
assert weight_diff != 0, "Weights did not change"

# assert train_stats and validation_scores are available

training_stats = stats_store.retrieve_training_stats(run_config.name)

assert training_stats.trained_until() == run_config.num_iterations
# @pytest.mark.parametrize("upsample_datasplit", [lf("upsample_six_class_datasplit")])
# @pytest.mark.parametrize("task", [lf("distance_task")])
# @pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
# @pytest.mark.parametrize("batch_norm", [True, False])
# @pytest.mark.parametrize("upsample", [True])
# @pytest.mark.parametrize("use_attention", [True, False])
# @pytest.mark.parametrize("three_d", [True, False])
# def test_upsample_train_unet(
# upsample_datasplit, task, trainer, batch_norm, upsample, use_attention, three_d
# ):
# store = create_config_store()
# stats_store = create_stats_store()
# weights_store = create_weights_store()

# architecture_config = unet_architecture_builder(
# batch_norm, upsample, use_attention, three_d
# )

# run_config = RunConfig(
# name=f"{architecture_config.name}_run",
# task_config=task,
# architecture_config=architecture_config,
# trainer_config=trainer,
# datasplit_config=upsample_datasplit,
# repetition=0,
# num_iterations=1,
# )
# try:
# store.store_run_config(run_config)
# except Exception as e:
# store.delete_run_config(run_config.name)
# store.store_run_config(run_config)

# run = Run(run_config)

# # -------------------------------------

# # train

# weights_store.store_weights(run, 0)
# train_run(run)
# # weights_store.store_weights(run, run.train_until)

# init_weights = weights_store.retrieve_weights(run.name, 0)
# final_weights = weights_store.retrieve_weights(run.name, 1)

# for name, weight in init_weights.model.items():
# weight_diff = (weight - final_weights.model[name]).any()
# assert weight_diff != 0, "Weights did not change"

# # assert train_stats and validation_scores are available

# training_stats = stats_store.retrieve_training_stats(run_config.name)

# assert training_stats.trained_until() == run_config.num_iterations
4 changes: 2 additions & 2 deletions tests/operations/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,13 @@ def test_validate_unet(datasplit, task, trainer, architecture):
weights_store = create_weights_store()

run_config = RunConfig(
name=f"{architecture.name}_run",
name=f"{architecture.name}_run_validate",
task_config=task,
architecture_config=architecture,
trainer_config=trainer,
datasplit_config=datasplit,
repetition=0,
num_iterations=2,
num_iterations=10,
)
try:
store.store_run_config(run_config)
Expand Down

0 comments on commit 4a18c42

Please sign in to comment.