Skip to content

Commit

Permalink
heavy unet tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Nov 12, 2024
1 parent 460cf2a commit a56dca3
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 11 deletions.
7 changes: 4 additions & 3 deletions dacapo/compute_context/local_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ def device(self):
if self._device is None:
if torch.cuda.is_available():
# TODO: make this more sophisticated, for multiple GPUs for instance
free = torch.cuda.mem_get_info()[0] / 1024**3
if free < self.oom_limit: # less than 1 GB free, decrease chance of OOM
return torch.device("cpu")
# commented out code below is for checking free memory and falling back on CPU, whhen model in GPU and memory is low model get moved to CPU
# free = torch.cuda.mem_get_info()[0] / 1024**3
# if free < self.oom_limit: # less than 1 GB free, decrease chance of OOM
# return torch.device("cpu")
return torch.device("cuda")
# Multiple MPS ops are not available yet : https://github.com/pytorch/pytorch/issues/77764
# got error aten::max_pool3d_with_indices
Expand Down
2 changes: 1 addition & 1 deletion dacapo/experiments/architectures/cnnectome_unet_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,6 @@ class CNNectomeUNetConfig(ArchitectureConfig):
},
)
batch_norm: bool = attr.ib(
default=True,
default=False,
metadata={"help_text": "Whether to use batch normalization."},
)
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from funlib.persistence import Array

from typing import Iterable
import logging

logger = logging.getLogger(__name__)

class ThresholdPostProcessor(PostProcessor):
"""
Expand Down Expand Up @@ -135,7 +137,7 @@ def process_block(block):
data = input_array[write_roi] > parameters.threshold
data = data.astype(np.uint8)
if int(data.max()) == 0:
print("No data in block", write_roi)
logger.debug("No data in block", write_roi)
return
output_array[write_roi] = data

Expand Down
12 changes: 9 additions & 3 deletions dacapo/experiments/tasks/predictors/distance_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,11 @@ def create_distance_mask(
)
slices = tmp.ndim * (slice(1, -1),)
tmp[slices] = channel_mask
sampling = tuple(float(v) / 2 for v in voxel_size)
sampling = sampling[-len(tmp.shape) :]
boundary_distance = distance_transform_edt(
tmp,
sampling=voxel_size,
sampling=sampling,
)
if self.epsilon is None:
add = 0
Expand Down Expand Up @@ -315,13 +317,17 @@ def process(
distances = np.ones(channel.shape, dtype=np.float32) * max_distance
else:
# get distances (voxel_size/2 because image is doubled)
sampling = tuple(float(v) / 2 for v in voxel_size)
# fixing the sampling for 2D images
if len(boundaries.shape) < len(sampling):
sampling = sampling[-len(boundaries.shape):]
distances = distance_transform_edt(
boundaries, sampling=tuple(float(v) / 2 for v in voxel_size)
boundaries, sampling=sampling
)
distances = distances.astype(np.float32)

# restore original shape
downsample = (slice(None, None, 2),) * len(voxel_size)
downsample = (slice(None, None, 2),) * distances.ndim
distances = distances[downsample]

# todo: inverted distance
Expand Down
7 changes: 4 additions & 3 deletions dacapo/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ def validate_run(run: Run, iteration: int, datasets_config=None):
# validation_dataset.name,
# criterion,
# )
dataset_iteration_scores.append(
[getattr(scores, criterion) for criterion in scores.criteria]
)
except:
logger.error(
f"Could not evaluate run {run.name} on dataset {validation_dataset.name} with parameters {parameters}.",
Expand All @@ -257,9 +260,7 @@ def validate_run(run: Run, iteration: int, datasets_config=None):
# the evaluator
# array_store.remove(output_array_identifier)

dataset_iteration_scores.append(
[getattr(scores, criterion) for criterion in scores.criteria]
)


iteration_scores.append(dataset_iteration_scores)
# array_store.remove(prediction_array_identifier)
Expand Down
111 changes: 111 additions & 0 deletions tests/operations/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,58 @@
import pytest
from pytest_lazy_fixtures import lf

from dacapo.experiments.run_config import RunConfig

import logging

logging.basicConfig(level=logging.INFO)

from dacapo.experiments.architectures import DummyArchitectureConfig, CNNectomeUNetConfig

import pytest


def unet_architecture(batch_norm, upsample,use_attention, three_d):
name = "3d_unet" if three_d else "2d_unet"
name = f"{name}_bn" if batch_norm else name
name = f"{name}_up" if upsample else name
name = f"{name}_att" if use_attention else name

if three_d:
return CNNectomeUNetConfig(
name=name,
input_shape=(188, 188, 188),
eval_shape_increase=(72, 72, 72),
fmaps_in=1,
num_fmaps=6,
fmaps_out=6,
fmap_inc_factor=2,
downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)],
constant_upsample=True,
upsample_factors=[(2, 2, 2)] if upsample else [],
batch_norm=batch_norm,
use_attention=use_attention,
)
else:
return CNNectomeUNetConfig(
name=name,
input_shape=(2, 132, 132),
eval_shape_increase=(8, 32, 32),
fmaps_in=2,
num_fmaps=8,
fmaps_out=8,
fmap_inc_factor=2,
downsample_factors=[(1, 4, 4), (1, 4, 4)],
kernel_size_down=[[(1, 3, 3)] * 2] * 3,
kernel_size_up=[[(1, 3, 3)] * 2] * 2,
constant_upsample=True,
padding="valid",
batch_norm=batch_norm,
use_attention=use_attention,
upsample_factors=[(1, 2, 2)] if upsample else [],
)



# skip the test for the Apple Paravirtual device
# that does not support Metal 2.0
Expand Down Expand Up @@ -59,3 +107,66 @@ def test_train(
training_stats = stats_store.retrieve_training_stats(run_config.name)

assert training_stats.trained_until() == run_config.num_iterations


@pytest.mark.parametrize("datasplit", [lf("six_class_datasplit")])
@pytest.mark.parametrize("task", [lf("distance_task")])
@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("upsample", [True, False])
@pytest.mark.parametrize("use_attention", [True, False])
@pytest.mark.parametrize("three_d", [True, False])
def test_train_unet(
datasplit,
task,
trainer,
batch_norm,
upsample,
use_attention,
three_d):

store = create_config_store()
stats_store = create_stats_store()
weights_store = create_weights_store()

architecture_config = unet_architecture(batch_norm, upsample,use_attention, three_d)

run_config = RunConfig(
name=f"{architecture_config.name}_run",
task_config=task,
architecture_config=architecture_config,
trainer_config=trainer,
datasplit_config=datasplit,
repetition=0,
num_iterations=2,
)
try:
store.store_run_config(run_config)
except Exception as e:
store.delete_run_config(run_config.name)
store.store_run_config(run_config)

run = Run(run_config)

# -------------------------------------

# train

weights_store.store_weights(run, 0)
train_run(run)

init_weights = weights_store.retrieve_weights(run.name, 0)
final_weights = weights_store.retrieve_weights(run.name, run.train_until)

for name, weight in init_weights.model.items():
weight_diff = (weight - final_weights.model[name]).sum()
assert abs(weight_diff) > np.finfo(weight_diff.numpy().dtype).eps, weight_diff

# assert train_stats and validation_scores are available

training_stats = stats_store.retrieve_training_stats(run_config.name)

assert training_stats.trained_until() == run_config.num_iterations



0 comments on commit a56dca3

Please sign in to comment.