diff --git a/dacapo/compute_context/local_torch.py b/dacapo/compute_context/local_torch.py
index a547b7dd7..08e813712 100644
--- a/dacapo/compute_context/local_torch.py
+++ b/dacapo/compute_context/local_torch.py
@@ -56,9 +56,10 @@ def device(self):
         if self._device is None:
             if torch.cuda.is_available():
                 # TODO: make this more sophisticated, for multiple GPUs for instance
-                free = torch.cuda.mem_get_info()[0] / 1024**3
-                if free < self.oom_limit:  # less than 1 GB free, decrease chance of OOM
-                    return torch.device("cpu")
+                # The check below (commented out) falls back to the CPU when free GPU memory is low; it would move the model off the GPU to reduce the chance of OOM.
+                # free = torch.cuda.mem_get_info()[0] / 1024**3
+                # if free < self.oom_limit:  # less than 1 GB free, decrease chance of OOM
+                #     return torch.device("cpu")
                 return torch.device("cuda")
             # Multiple MPS ops are not available yet : https://github.com/pytorch/pytorch/issues/77764
             # got error aten::max_pool3d_with_indices
diff --git a/dacapo/experiments/architectures/cnnectome_unet_config.py b/dacapo/experiments/architectures/cnnectome_unet_config.py
index 7eab80115..643921386 100644
--- a/dacapo/experiments/architectures/cnnectome_unet_config.py
+++ b/dacapo/experiments/architectures/cnnectome_unet_config.py
@@ -128,6 +128,6 @@ class CNNectomeUNetConfig(ArchitectureConfig):
         },
     )
     batch_norm: bool = attr.ib(
-        default=True,
+        default=False,
         metadata={"help_text": "Whether to use batch normalization."},
     )
diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py
index 59059e516..38bfda883 100644
--- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py
+++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py
@@ -18,7 +18,9 @@
 from funlib.persistence import Array
 from typing import Iterable
+import logging
 
+logger = logging.getLogger(__name__)
 
 
 class ThresholdPostProcessor(PostProcessor):
     """
@@ -135,7 +137,7 @@ def process_block(block):
             data = input_array[write_roi] > parameters.threshold
             data = data.astype(np.uint8)
             if int(data.max()) == 0:
-                print("No data in block", write_roi)
+                logger.debug("No data in block %s", write_roi)
                 return
             output_array[write_roi] = data
diff --git a/dacapo/experiments/tasks/predictors/distance_predictor.py b/dacapo/experiments/tasks/predictors/distance_predictor.py
index 861a9e1dd..568deae53 100644
--- a/dacapo/experiments/tasks/predictors/distance_predictor.py
+++ b/dacapo/experiments/tasks/predictors/distance_predictor.py
@@ -231,9 +231,11 @@ def create_distance_mask(
         )
         slices = tmp.ndim * (slice(1, -1),)
         tmp[slices] = channel_mask
+        sampling = tuple(float(v) / 2 for v in voxel_size)
+        sampling = sampling[-len(tmp.shape) :]
         boundary_distance = distance_transform_edt(
             tmp,
-            sampling=voxel_size,
+            sampling=sampling,
         )
         if self.epsilon is None:
             add = 0
@@ -315,13 +317,17 @@ def process(
             distances = np.ones(channel.shape, dtype=np.float32) * max_distance
         else:
             # get distances (voxel_size/2 because image is doubled)
+            sampling = tuple(float(v) / 2 for v in voxel_size)
+            # trim the sampling to the array's dimensionality for 2D inputs
+            if len(boundaries.shape) < len(sampling):
+                sampling = sampling[-len(boundaries.shape) :]
             distances = distance_transform_edt(
-                boundaries, sampling=tuple(float(v) / 2 for v in voxel_size)
+                boundaries, sampling=sampling
             )
             distances = distances.astype(np.float32)
 
             # restore original shape
-            downsample = (slice(None, None, 2),) * len(voxel_size)
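+            # use the array's rank rather than len(voxel_size) so 2D arrays are downsampled correctly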
+            downsample = (slice(None, None, 2),) * distances.ndim
             distances = distances[downsample]
 
             # todo: inverted distance
diff --git a/dacapo/validate.py b/dacapo/validate.py
index 4e091ff55..b826a6dbd 100644
--- a/dacapo/validate.py
+++ b/dacapo/validate.py
@@ -246,6 +246,9 @@ def validate_run(run: Run, iteration: int, datasets_config=None):
                 #     validation_dataset.name,
                 #     criterion,
                 # )
+            dataset_iteration_scores.append(
+                [getattr(scores, criterion) for criterion in scores.criteria]
+            )
         except:
             logger.error(
                 f"Could not evaluate run {run.name} on dataset {validation_dataset.name} with parameters {parameters}.",
@@ -257,9 +260,7 @@ def validate_run(run: Run, iteration: int, datasets_config=None):
             # the evaluator
             # array_store.remove(output_array_identifier)
 
-        dataset_iteration_scores.append(
-            [getattr(scores, criterion) for criterion in scores.criteria]
-        )
+
     iteration_scores.append(dataset_iteration_scores)
     # array_store.remove(prediction_array_identifier)
diff --git a/tests/operations/test_train.py b/tests/operations/test_train.py
index a852101be..374621c6b 100644
--- a/tests/operations/test_train.py
+++ b/tests/operations/test_train.py
@@ -9,10 +9,58 @@
 import pytest
 from pytest_lazy_fixtures import lf
 
+from dacapo.experiments.run_config import RunConfig
+
 import logging
 
 logging.basicConfig(level=logging.INFO)
 
+from dacapo.experiments.architectures import CNNectomeUNetConfig
+
+import numpy as np
+
+
+def unet_architecture(batch_norm, upsample, use_attention, three_d):
+    name = "3d_unet" if three_d else "2d_unet"
+    name = f"{name}_bn" if batch_norm else name
+    name = f"{name}_up" if upsample else name
+    name = f"{name}_att" if use_attention else name
+
+    if three_d:
+        return CNNectomeUNetConfig(
+            name=name,
+            input_shape=(188, 188, 188),
+            eval_shape_increase=(72, 72, 72),
+            fmaps_in=1,
+            num_fmaps=6,
+            fmaps_out=6,
+            fmap_inc_factor=2,
+            downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)],
+            constant_upsample=True,
+            upsample_factors=[(2, 2, 2)] if upsample else [],
+            batch_norm=batch_norm,
+            use_attention=use_attention,
+        )
+    else:
+        return CNNectomeUNetConfig(
+            name=name,
+            input_shape=(2, 132, 132),
+            eval_shape_increase=(8, 32, 32),
+            fmaps_in=2,
+            num_fmaps=8,
+            fmaps_out=8,
+            fmap_inc_factor=2,
+            downsample_factors=[(1, 4, 4), (1, 4, 4)],
+            kernel_size_down=[[(1, 3, 3)] * 2] * 3,
+            kernel_size_up=[[(1, 3, 3)] * 2] * 2,
+            constant_upsample=True,
+            padding="valid",
+            batch_norm=batch_norm,
+            use_attention=use_attention,
+            upsample_factors=[(1, 2, 2)] if upsample else [],
+        )
+
 
 # skip the test for the Apple Paravirtual device
 # that does not support Metal 2.0
@@ -59,3 +107,66 @@ def test_train(
     training_stats = stats_store.retrieve_training_stats(run_config.name)
 
     assert training_stats.trained_until() == run_config.num_iterations
+
+
+@pytest.mark.parametrize("datasplit", [lf("six_class_datasplit")])
+@pytest.mark.parametrize("task", [lf("distance_task")])
+@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
+@pytest.mark.parametrize("batch_norm", [True, False])
+@pytest.mark.parametrize("upsample", [True, False])
+@pytest.mark.parametrize("use_attention", [True, False])
+@pytest.mark.parametrize("three_d", [True, False])
+def test_train_unet(
+    datasplit,
+    task,
+    trainer,
+    batch_norm,
+    upsample,
+    use_attention,
+    three_d,
+):
+    store = create_config_store()
+    stats_store = create_stats_store()
+    weights_store = create_weights_store()
+
+    architecture_config = unet_architecture(batch_norm, upsample, use_attention, three_d)
+
+    run_config = RunConfig(
+        name=f"{architecture_config.name}_run",
+        task_config=task,
+        architecture_config=architecture_config,
+        trainer_config=trainer,
+        datasplit_config=datasplit,
+        repetition=0,
+        num_iterations=2,
+    )
+    try:
+        store.store_run_config(run_config)
+    except Exception:
+        # a run config with this name may already exist from an earlier test run; replace it
+        store.delete_run_config(run_config.name)
+        store.store_run_config(run_config)
+
+    run = Run(run_config)
+
+    # -------------------------------------
+
+    # train
+
+    weights_store.store_weights(run, 0)
+    train_run(run)
+
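+    # the weights stored at iteration 0 are the baseline for the change check below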
+    init_weights = weights_store.retrieve_weights(run.name, 0)
+    final_weights = weights_store.retrieve_weights(run.name, run.train_until)
+
+    for name, weight in init_weights.model.items():
+        weight_diff = (weight - final_weights.model[name]).sum()
+        assert abs(weight_diff) > np.finfo(weight_diff.numpy().dtype).eps, weight_diff
+
+    # assert train_stats and validation_scores are available
+
+    training_stats = stats_store.retrieve_training_stats(run_config.name)
+
+    assert training_stats.trained_until() == run_config.num_iterations
\ No newline at end of file