diff --git a/dacapo/experiments/datasplits/simple_config.py b/dacapo/experiments/datasplits/simple_config.py
index 53a66945b..9ee88283a 100644
--- a/dacapo/experiments/datasplits/simple_config.py
+++ b/dacapo/experiments/datasplits/simple_config.py
@@ -44,7 +44,7 @@ def get_paths(self, group_name: str) -> list[Path]:
                 len(level_2_matches) == 0
             ), f"Found raw data at {level_1} and {level_2}"
             return [Path(x).parent for x in level_1_matches]
-        elif len(level_2_matches).parent > 0:
+        elif len(level_2_matches) > 0:
             return [Path(x) for x in level_2_matches]
 
         raise Exception(f"No raw data found at {level_0} or {level_1} or {level_2}")
diff --git a/dacapo/experiments/tasks/hot_distance_task.py b/dacapo/experiments/tasks/hot_distance_task.py
index 630e58ed5..382eaf8b1 100644
--- a/dacapo/experiments/tasks/hot_distance_task.py
+++ b/dacapo/experiments/tasks/hot_distance_task.py
@@ -4,6 +4,7 @@
 from .predictors import HotDistancePredictor
 from .task import Task
 
+import warnings
 
 class HotDistanceTask(Task):
     """
@@ -34,10 +35,19 @@ def __init__(self, task_config):
             >>> task = HotDistanceTask(task_config)
 
         """
+
+        if task_config.kernel_size is None:
+            warnings.warn(
+                "The default kernel size of 3 will be changing to 1. "
+                "Please specify the kernel size explicitly.",
+                DeprecationWarning,
+            )
+            task_config.kernel_size = 3
         self.predictor = HotDistancePredictor(
             channels=task_config.channels,
             scale_factor=task_config.scale_factor,
             mask_distances=task_config.mask_distances,
+            kernel_size=task_config.kernel_size,
         )
         self.loss = HotDistanceLoss()
         self.post_processor = ThresholdPostProcessor()
diff --git a/dacapo/experiments/tasks/hot_distance_task_config.py b/dacapo/experiments/tasks/hot_distance_task_config.py
index 18cab91b3..d140e38e4 100644
--- a/dacapo/experiments/tasks/hot_distance_task_config.py
+++ b/dacapo/experiments/tasks/hot_distance_task_config.py
@@ -56,3 +56,8 @@ class HotDistanceTaskConfig(TaskConfig):
             "is less than the distance to object boundary."
         },
     )
+
+
+    kernel_size: int | None = attr.ib(
+        default=None,
+    )
\ No newline at end of file
diff --git a/dacapo/experiments/tasks/one_hot_task.py b/dacapo/experiments/tasks/one_hot_task.py
index 870140f50..55d115d15 100644
--- a/dacapo/experiments/tasks/one_hot_task.py
+++ b/dacapo/experiments/tasks/one_hot_task.py
@@ -4,6 +4,8 @@
 from .predictors import OneHotPredictor
 from .task import Task
 
+import warnings
+
 
 class OneHotTask(Task):
     """
@@ -30,7 +32,17 @@ def __init__(self, task_config):
         Examples:
             >>> task = OneHotTask(task_config)
         """
-        self.predictor = OneHotPredictor(classes=task_config.classes)
+
+        if task_config.kernel_size is None:
+            warnings.warn(
+                "The default kernel size of 3 will be changing to 1. "
+                "Please specify the kernel size explicitly.",
+                DeprecationWarning,
+            )
+            task_config.kernel_size = 3
+        self.predictor = OneHotPredictor(
+            classes=task_config.classes, kernel_size=task_config.kernel_size
+        )
         self.loss = DummyLoss()
         self.post_processor = ArgmaxPostProcessor()
         self.evaluator = DummyEvaluator()
diff --git a/dacapo/experiments/tasks/one_hot_task_config.py b/dacapo/experiments/tasks/one_hot_task_config.py
index de4817a0e..4207448de 100644
--- a/dacapo/experiments/tasks/one_hot_task_config.py
+++ b/dacapo/experiments/tasks/one_hot_task_config.py
@@ -28,3 +28,6 @@ class OneHotTaskConfig(TaskConfig):
     classes: List[str] = attr.ib(
         metadata={"help_text": "The classes corresponding with each id starting from 0"}
     )
+    kernel_size: int | None = attr.ib(
+        default=None,
+    )
diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py
index f736d3e17..a88b89267 100644
--- a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py
+++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py
@@ -133,7 +133,7 @@ def process(
             overwrite=True,
         )
 
-        read_roi = Roi((0, 0, 0), block_size[-self.prediction_array.dims :])
+        read_roi = Roi((0,) * block_size.dims, block_size)
         input_array = open_ds(
             f"{self.prediction_array_identifier.container.path}/{self.prediction_array_identifier.dataset}"
         )
diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py
index 0c137e2f6..778064fcc 100644
--- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py
+++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py
@@ -111,13 +111,13 @@ def process(
         if self.prediction_array._source_data.chunks is not None:
             block_size = self.prediction_array._source_data.chunks
 
-        write_size = [
+        write_size = Coordinate([
             b * v
             for b, v in zip(
                 block_size[-self.prediction_array.dims :],
                 self.prediction_array.voxel_size,
             )
-        ]
+        ])
         output_array = create_from_identifier(
             output_array_identifier,
             self.prediction_array.axis_names,
@@ -128,7 +128,7 @@ def process(
             overwrite=True,
         )
 
-        read_roi = Roi((0, 0, 0), write_size[-self.prediction_array.dims :])
+        read_roi = Roi(write_size * 0, write_size)
         input_array = open_ds(
             f"{self.prediction_array_identifier.container.path}/{self.prediction_array_identifier.dataset}"
         )
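Both post-processors above now build their read ROI from the block size itself instead of a hard-coded 3D zero offset, so 2D prediction arrays no longer hit a dimension mismatch. A small sketch of the pattern, assuming standard funlib.geometry semantics:

from funlib.geometry import Coordinate, Roi

block_size = Coordinate((64, 64))  # a 2D block; the old code assumed 3D

# Roi((0, 0, 0), block_size) would pair a 3D offset with a 2D shape;
# deriving the offset from block_size.dims works for any dimensionality.
read_roi = Roi((0,) * block_size.dims, block_size)
assert read_roi.dims == 2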
diff --git a/dacapo/experiments/tasks/predictors/distance_predictor.py b/dacapo/experiments/tasks/predictors/distance_predictor.py
index 07cb92701..172db065a 100644
--- a/dacapo/experiments/tasks/predictors/distance_predictor.py
+++ b/dacapo/experiments/tasks/predictors/distance_predictor.py
@@ -223,6 +223,10 @@ def create_distance_mask(
             >>> predictor.create_distance_mask(distances, mask, voxel_size, normalize, normalize_args)
 
         """
+        no_channel_dim = len(mask.shape) == len(distances.shape) - 1
+        if no_channel_dim:
+            mask = mask[np.newaxis]
+
         mask_output = mask.copy()
         for i, (channel_distance, channel_mask) in enumerate(zip(distances, mask)):
             tmp = np.zeros(
@@ -275,6 +279,8 @@ def create_distance_mask(
                         np.sum(channel_mask_output)
                     )
                 )
+        if no_channel_dim:
+            mask_output = mask_output[0]
         return mask_output
 
     def process(
@@ -300,7 +306,20 @@ def process(
             >>> predictor.process(labels, voxel_size, normalize, normalize_args)
 
         """
+
+        num_dims = len(labels.shape)
+        if num_dims == voxel_size.dims:
+            channel_dim = False
+        elif num_dims == voxel_size.dims + 1:
+            channel_dim = True
+        else:
+            raise ValueError("Cannot handle multiple channel dims")
+
+        if not channel_dim:
+            labels = labels[np.newaxis]
+
         all_distances = np.zeros(labels.shape, dtype=np.float32) - 1
+
         for ii, channel in enumerate(labels):
             boundaries = self.__find_boundaries(channel)
 
@@ -358,7 +377,7 @@ def __find_boundaries(self, labels: np.ndarray):
         # bound.: 00000001000100000001000      2n - 1
 
         if labels.dtype == bool:
-            raise ValueError("Labels should not be bools")
+            # raise ValueError("Labels should not be bools")
             labels = labels.astype(np.uint8)
 
         logger.debug(f"computing boundaries for {labels.shape}")
diff --git a/dacapo/experiments/tasks/predictors/dummy_predictor.py b/dacapo/experiments/tasks/predictors/dummy_predictor.py
index 2c495da56..46da2f6d9 100644
--- a/dacapo/experiments/tasks/predictors/dummy_predictor.py
+++ b/dacapo/experiments/tasks/predictors/dummy_predictor.py
@@ -50,7 +50,7 @@ def create_model(self, architecture):
             >>> model = predictor.create_model(architecture)
         """
         head = torch.nn.Conv3d(
-            architecture.num_out_channels, self.embedding_dims, kernel_size=3
+            architecture.num_out_channels, self.embedding_dims, kernel_size=1
         )
 
         return Model(architecture, head)
diff --git a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py
index 7c2361aee..4a18a3154 100644
--- a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py
+++ b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py
@@ -49,7 +49,7 @@ class HotDistancePredictor(Predictor):
         This is a subclass of Predictor.
     """
 
-    def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool):
+    def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool, kernel_size: int):
         """
         Initializes the HotDistancePredictor.
 
@@ -64,6 +64,7 @@ def __init__(self, channels: List[str], scale_factor: float, mask_distances: boo
         Note:
             The channels argument is a list of strings, each string is the name of a class that is being segmented.
         """
+        self.kernel_size = kernel_size
         self.channels = (
             channels * 2
         )  # one hot + distance (TODO: add hot/distance to channel names)
@@ -119,11 +120,11 @@ def create_model(self, architecture):
         """
         if architecture.dims == 2:
             head = torch.nn.Conv2d(
-                architecture.num_out_channels, self.embedding_dims, kernel_size=3
+                architecture.num_out_channels, self.embedding_dims, self.kernel_size
             )
         elif architecture.dims == 3:
             head = torch.nn.Conv3d(
-                architecture.num_out_channels, self.embedding_dims, kernel_size=3
+                architecture.num_out_channels, self.embedding_dims, self.kernel_size
             )
 
         return Model(architecture, head)
diff --git a/dacapo/experiments/tasks/predictors/one_hot_predictor.py b/dacapo/experiments/tasks/predictors/one_hot_predictor.py
index 1ad7fdeec..ff6e21db6 100644
--- a/dacapo/experiments/tasks/predictors/one_hot_predictor.py
+++ b/dacapo/experiments/tasks/predictors/one_hot_predictor.py
@@ -30,7 +30,7 @@ class OneHotPredictor(Predictor):
         This is a subclass of Predictor.
     """
 
-    def __init__(self, classes: List[str]):
+    def __init__(self, classes: List[str], kernel_size: int):
        """
         Initialize the OneHotPredictor.
 
@@ -42,6 +42,7 @@ def __init__(self, classes: List[str]):
             >>> predictor = OneHotPredictor(classes)
         """
         self.classes = classes
+        self.kernel_size = kernel_size
 
     @property
     def embedding_dims(self):
@@ -70,8 +71,17 @@ def create_model(self, architecture):
         Examples:
             >>> model = predictor.create_model(architecture)
         """
-        head = torch.nn.Conv3d(
-            architecture.num_out_channels, self.embedding_dims, kernel_size=3
+
+        if architecture.dims == 3:
+            conv_layer = torch.nn.Conv3d
+        elif architecture.dims == 2:
+            conv_layer = torch.nn.Conv2d
+        else:
+            raise Exception(f"Unsupported number of dimensions: {architecture.dims}")
+        head = conv_layer(
+            architecture.num_out_channels,
+            self.embedding_dims,
+            kernel_size=self.kernel_size,
         )
 
         return Model(architecture, head)
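The distance predictor changes above share one pattern: add a singleton channel axis when the input lacks one, process per channel, then strip the axis again on the way out. A NumPy-only sketch of that round trip (illustrative; the real per-channel masking is elided):

import numpy as np

def normalize_channel_dim(mask: np.ndarray, distances: np.ndarray) -> np.ndarray:
    # As in the patch, mask may arrive with one fewer dim than distances.
    no_channel_dim = mask.ndim == distances.ndim - 1
    if no_channel_dim:
        mask = mask[np.newaxis]
    out = mask.copy()
    # ... per-channel work over zip(distances, mask) would happen here ...
    if no_channel_dim:
        out = out[0]  # strip the axis we added
    return out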
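The predictor heads above now take their kernel size from the task config, which falls back to the legacy value of 3 with a DeprecationWarning when unset. A minimal config-side sketch, using only constructor arguments visible in this patch (the name is hypothetical); passing kernel_size explicitly opts out of the fallback:

from dacapo.experiments.tasks import OneHotTaskConfig

# Explicit kernel_size: no DeprecationWarning, no implicit fallback to 3.
task_config = OneHotTaskConfig(
    name="example_one_hot_task",  # hypothetical name
    classes=["bg", "fg"],
    kernel_size=1,
)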
diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py
index bea9c96e2..507151ad7 100644
--- a/dacapo/experiments/trainers/gunpowder_trainer.py
+++ b/dacapo/experiments/trainers/gunpowder_trainer.py
@@ -268,13 +268,13 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
         request.add(weight_key, output_size)
         request.add(
             mask_placeholder,
-            prediction_voxel_size * self.mask_integral_downsample_factor,
+            prediction_voxel_size,
         )
         # request additional keys for snapshots
         request.add(gt_key, output_size)
         request.add(mask_key, output_size)
         request[mask_placeholder].roi = request[mask_placeholder].roi.snap_to_grid(
-            prediction_voxel_size * self.mask_integral_downsample_factor
+            prediction_voxel_size
         )
 
         self._request = request
diff --git a/dacapo/predict_local.py b/dacapo/predict_local.py
index a4f21ab15..88aaa5cfc 100644
--- a/dacapo/predict_local.py
+++ b/dacapo/predict_local.py
@@ -44,10 +44,12 @@ def predict(
     else:
         input_roi = output_roi.grow(context, context)
 
-    read_roi = Roi((0, 0, 0), input_size)
+    read_roi = Roi((0,) * input_size.dims, input_size)
     write_roi = read_roi.grow(-context, -context)
 
-    axes = ["c^", "z", "y", "x"]
+    axes = raw_array.axis_names
+    if "c^" not in axes:
+        axes = ["c^"] + axes
 
     num_channels = model.num_out_channels
 
@@ -73,8 +75,8 @@ def predict(
 
     model_device = str(next(model.parameters()).device).split(":")[0]
 
-    assert model_device == str(
-        device
+    assert (
+        model_device == str(device)
     ), f"Model is not on the right device, Model: {model_device}, Compute device: {device}"
 
     def predict_fn(block):
@@ -103,7 +105,7 @@ def predict_fn(block):
             predictions = Array(
                 predictions,
                 block.write_roi.offset,
-                raw_array.voxel_size,
+                output_voxel_size,
                 axis_names,
                 raw_array.units,
             )
@@ -120,7 +122,7 @@ def predict_fn(block):
     task = daisy.Task(
         f"predict_{out_container}_{out_dataset}",
         total_roi=input_roi,
-        read_roi=Roi((0, 0, 0), input_size),
+        read_roi=Roi((0,) * input_size.dims, input_size),
         write_roi=Roi(context, output_size),
         process_function=predict_fn,
         check_function=None,
diff --git a/dacapo/utils/balance_weights.py b/dacapo/utils/balance_weights.py
index e713745c6..5a53852ef 100644
--- a/dacapo/utils/balance_weights.py
+++ b/dacapo/utils/balance_weights.py
@@ -69,6 +69,9 @@ def balance_weights(
             scale_slab *= np.take(w, labels_slab)
     """
 
+    if label_data.dtype == bool:
+        label_data = label_data.astype(np.uint8)
+
     if moving_counts is None:
         moving_counts = []
     unique_labels = np.unique(label_data)
diff --git a/tests/conf.py b/tests/conf.py
deleted file mode 100644
index 57a8708d5..000000000
--- a/tests/conf.py
+++ /dev/null
@@ -1,3 +0,0 @@
-import multiprocessing as mp
-
-mp.set_start_method("fork", force=True)
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 000000000..9a90c5cab
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,28 @@
+import multiprocessing as mp
+import os
+import yaml
+
+from dacapo.options import Options
+
+import pytest
+
+
+@pytest.fixture(params=["fork", "spawn"], autouse=True)
+def context(request, monkeypatch):
+    ctx = mp.get_context(request.param)
+    monkeypatch.setattr(mp, "Queue", ctx.Queue)
+    monkeypatch.setattr(mp, "Process", ctx.Process)
+    monkeypatch.setattr(mp, "Event", ctx.Event)
+    monkeypatch.setattr(mp, "Value", ctx.Value)
+
+
+@pytest.fixture(autouse=True)
+def runs_base_dir(tmpdir):
+    options_file = tmpdir / "dacapo.yaml"
+    os.environ["DACAPO_OPTIONS_FILE"] = f"{options_file}"
+
+    with open(options_file, "w") as f:
+        f.write(yaml.safe_dump({"runs_base_dir": f"{tmpdir}"}))
+
+    assert Options.config_file() == options_file
+    assert Options.instance().runs_base_dir == tmpdir
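The new tests/conftest.py isolates each test behind a throwaway options file selected via the DACAPO_OPTIONS_FILE environment variable. The same mechanism works outside pytest; a sketch with hypothetical paths:

import os
import yaml

options_file = "/tmp/dacapo.yaml"  # hypothetical location
with open(options_file, "w") as f:
    f.write(yaml.safe_dump({"runs_base_dir": "/tmp/runs"}))

# Must be set before Options is first read.
os.environ["DACAPO_OPTIONS_FILE"] = options_file

from dacapo.options import Options

assert str(Options.instance().runs_base_dir) == "/tmp/runs"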
diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py
index 373b80796..e0d4a47a0 100644
--- a/tests/fixtures/__init__.py
+++ b/tests/fixtures/__init__.py
@@ -3,7 +3,6 @@
     dummy_architecture,
     unet_architecture,
     unet_3d_architecture,
-    unet_architecture_builder,
 )
 from .arrays import dummy_array, zarr_array, cellmap_array
 from .datasplits import (
diff --git a/tests/fixtures/architectures.py b/tests/fixtures/architectures.py
index 4baca2da7..79e7f9fca 100644
--- a/tests/fixtures/architectures.py
+++ b/tests/fixtures/architectures.py
@@ -17,15 +17,15 @@ def dummy_architecture():
 def unet_architecture():
     yield CNNectomeUNetConfig(
         name="tmp_unet_architecture",
-        input_shape=(132, 132),
-        eval_shape_increase=(32, 32),
+        input_shape=(1, 132, 132),
+        eval_shape_increase=(1, 32, 32),
         fmaps_in=1,
         num_fmaps=8,
         fmaps_out=8,
         fmap_inc_factor=2,
-        downsample_factors=[(4, 4), (4, 4)],
-        kernel_size_down=[[(3, 3)] * 2] * 3,
-        kernel_size_up=[[(3, 3)] * 2] * 2,
+        downsample_factors=[(1, 4, 4), (1, 4, 4)],
+        kernel_size_down=[[(1, 3, 3)] * 2] * 3,
+        kernel_size_up=[[(1, 3, 3)] * 2] * 2,
         constant_upsample=True,
         padding="valid",
     )
@@ -44,44 +44,3 @@ def unet_3d_architecture():
         downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)],
         constant_upsample=True,
     )
-
-
-def unet_architecture_builder(batch_norm, upsample, use_attention, three_d):
-    name = "3d_unet" if three_d else "2d_unet"
-    name = f"{name}_bn" if batch_norm else name
-    name = f"{name}_up" if upsample else name
-    name = f"{name}_att" if use_attention else name
-
-    if three_d:
-        return CNNectomeUNetConfig(
-            name=name,
-            input_shape=(188, 188, 188),
-            eval_shape_increase=(72, 72, 72),
-            fmaps_in=1,
-            num_fmaps=6,
-            fmaps_out=6,
-            fmap_inc_factor=2,
-            downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)],
-            constant_upsample=True,
-            upsample_factors=[(2, 2, 2)] if upsample else [],
-            batch_norm=batch_norm,
-            use_attention=use_attention,
-        )
-    else:
-        return CNNectomeUNetConfig(
-            name=name,
-            input_shape=(132, 132),
-            eval_shape_increase=(32, 32),
-            fmaps_in=1,
-            num_fmaps=8,
-            fmaps_out=8,
-            fmap_inc_factor=2,
-            downsample_factors=[(4, 4), (4, 4)],
-            kernel_size_down=[[(3, 3)] * 2] * 3,
-            kernel_size_up=[[(3, 3)] * 2] * 2,
-            constant_upsample=True,
-            padding="valid",
-            batch_norm=batch_norm,
-            use_attention=use_attention,
-            upsample_factors=[(2, 2)] if upsample else [],
-        )
diff --git a/tests/fixtures/predictors.py b/tests/fixtures/predictors.py
index cc93369cf..c6dd6de51 100644
--- a/tests/fixtures/predictors.py
+++ b/tests/fixtures/predictors.py
@@ -10,4 +10,4 @@ def distance_predictor():
 
 @pytest.fixture()
 def onehot_predictor():
-    yield OneHotPredictor(classes=["a", "b", "c"])
+    yield OneHotPredictor(classes=["a", "b", "c"], kernel_size=1)
diff --git a/tests/fixtures/tasks.py b/tests/fixtures/tasks.py
index bd8b25084..5792811b4 100644
--- a/tests/fixtures/tasks.py
+++ b/tests/fixtures/tasks.py
@@ -51,6 +51,7 @@ def onehot_task():
     yield OneHotTaskConfig(
         name="one_hot_task",
         classes=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"],
+        kernel_size=1,
     )
 
 
@@ -59,4 +60,5 @@ def six_onehot_task():
     yield OneHotTaskConfig(
         name="one_hot_task",
         classes=["a", "b", "c", "d", "e", "f"],
+        kernel_size=1,
     )
diff --git a/tests/operations/helpers.py b/tests/operations/helpers.py
new file mode 100644
index 000000000..74fb43208
--- /dev/null
+++ b/tests/operations/helpers.py
@@ -0,0 +1,162 @@
+import numpy as np
+from funlib.persistence import prepare_ds
+from funlib.geometry import Coordinate
+
+from dacapo.experiments.datasplits import SimpleDataSplitConfig
+from dacapo.experiments.tasks import (
+    DistanceTaskConfig,
+    OneHotTaskConfig,
+    AffinitiesTaskConfig,
+)
+from dacapo.experiments.architectures import CNNectomeUNetConfig
+
+from pathlib import Path
+
+
+def build_test_data_config(
+    tmpdir: Path, data_dims: int, channels: bool, upsample: bool, task_type: str
+):
+    """
+    Builds the simplest possible datasplit given the parameters.
+
+    Labels are alternating planes/lines of 0/1 in the last dimension.
+    Intensities are random where labels are > 0, else 0. (If channels, stack twice.)
+    If task_type is "semantic", labels are binarized via labels > 0.
+
+    If upsampling, labels are upsampled by a factor of 2 in each dimension.
+    """
+
+    data_shape = (32, 32, 32)[-data_dims:]
+    axis_names = ["z", "y", "x"][-data_dims:]
+    mesh = np.meshgrid(
+        *[np.linspace(0, dim - 1, dim * (1 + upsample)) for dim in data_shape]
+    )
+    labels = mesh[-1] * (mesh[-1] % 2 > 0.75)
+
+    intensities = np.random.rand(*labels.shape) * (labels > 0)
+
+    if channels:
+        intensities = np.stack([intensities, intensities], axis=0)
+
+    intensities_array = prepare_ds(
+        tmpdir / "test_data.zarr/raw",
+        intensities.shape,
+        offset=(0,) * data_dims,
+        voxel_size=(2,) * data_dims,
+        axis_names=["c^"] * int(channels) + axis_names,
+        dtype=intensities.dtype,
+        mode="w",
+    )
+    intensities_array[:] = intensities
+
+    if task_type == "semantic":
+        labels = labels > 0
+
+    labels_array = prepare_ds(
+        tmpdir / "test_data.zarr/labels",
+        labels.shape,
+        offset=(0,) * data_dims,
+        voxel_size=(2 - upsample,) * data_dims,
+        axis_names=axis_names,
+        dtype=labels.dtype,
+        mode="w",
+    )
+    labels_array[:] = labels
+
+    return SimpleDataSplitConfig(name="test_data", path=tmpdir / "test_data.zarr")
+
+
+def build_test_task_config(task, data_dims: int, architecture_dims: int):
+    """
+    Build the simplest task config given the parameters.
+    """
+    if task == "distance":
+        return DistanceTaskConfig(
+            name="test_distance_task",
+            channels=["fg"],
+            clip_distance=4,
+            tol_distance=4,
+            scale_factor=8,
+        )
+    if task == "onehot":
+        return OneHotTaskConfig(
+            name="test_onehot_task", classes=["bg", "fg"], kernel_size=1
+        )
+    if task == "affs":
+        # TODO: should configs be able to take any sequence for the neighborhood?
+        if data_dims == 2:
+            # 2D
+            neighborhood = [Coordinate(1, 0), Coordinate(0, 1)]
+        elif data_dims == 3 and architecture_dims == 2:
+            # 3D but only generate 2D affs
+            neighborhood = [Coordinate(0, 1, 0), Coordinate(0, 0, 1)]
+        elif data_dims == 3 and architecture_dims == 3:
+            # 3D
+            neighborhood = [
+                Coordinate(1, 0, 0),
+                Coordinate(0, 1, 0),
+                Coordinate(0, 0, 1),
+            ]
+        return AffinitiesTaskConfig(name="test_affs_task", neighborhood=neighborhood)
+
+
+def build_test_architecture_config(
+    data_dims: int,
+    architecture_dims: int,
+    channels: bool,
+    batch_norm: bool,
+    upsample: bool,
+    use_attention: bool,
+    padding: str,
+):
+    """
+    Build the simplest architecture config given the parameters.
+    """
+    if data_dims == 2:
+        input_shape = (32, 32)
+        eval_shape_increase = (8, 8)
+        downsample_factors = [(2, 2)]
+        upsample_factors = [(2, 2)] * int(upsample)
+
+        kernel_size_down = None  # the default should work
+        kernel_size_up = None  # the default should work
+
+    elif data_dims == 3 and architecture_dims == 2:
+        input_shape = (1, 32, 32)
+        eval_shape_increase = (15, 8, 8)
+        downsample_factors = [(1, 2, 2)]
+
+        # test data upsamples in all dimensions so we have
+        # to upsample here too
+        upsample_factors = [(2, 2, 2)] * int(upsample)
+
+        # we have to force the 3D kernels to be 2D
+        kernel_size_down = [[(1, 3, 3)] * 2] * 2
+        kernel_size_up = [[(1, 3, 3)] * 2] * 1
+
+    elif data_dims == 3 and architecture_dims == 3:
+        input_shape = (32, 32, 32)
+        eval_shape_increase = (8, 8, 8)
+        downsample_factors = [(2, 2, 2)]
+        upsample_factors = [(2, 2, 2)] * int(upsample)
+
+        kernel_size_down = None  # the default should work
+        kernel_size_up = None  # the default should work
+
+    return CNNectomeUNetConfig(
+        name="test_cnnectome_unet",
+        input_shape=input_shape,
+        eval_shape_increase=eval_shape_increase,
+        fmaps_in=1 + channels,
+        num_fmaps=2,
+        fmaps_out=2,
+        fmap_inc_factor=2,
+        downsample_factors=downsample_factors,
+        kernel_size_down=kernel_size_down,
+        kernel_size_up=kernel_size_up,
+        constant_upsample=True,
+        upsample_factors=upsample_factors,
+        batch_norm=batch_norm,
+        use_attention=use_attention,
+        padding=padding,
+    )
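The three helpers above compose into a complete run configuration; tests/operations/test_mini.py below drives them through pytest parametrization, but they also work standalone. A minimal sketch, assuming the tests package is importable from the repository root and with all paths and parameter choices purely illustrative:

from pathlib import Path

from tests.operations.helpers import (
    build_test_data_config,
    build_test_task_config,
    build_test_architecture_config,
)

tmpdir = Path("/tmp/dacapo_example")  # hypothetical scratch directory

data_config = build_test_data_config(
    tmpdir, data_dims=3, channels=False, upsample=False, task_type="semantic"
)
task_config = build_test_task_config("onehot", data_dims=3, architecture_dims=3)
architecture_config = build_test_architecture_config(
    data_dims=3,
    architecture_dims=3,
    channels=False,
    batch_norm=False,
    upsample=False,
    use_attention=False,
    padding="valid",
)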
diff --git a/tests/operations/test_architecture.py b/tests/operations/test_architecture.py
index 1969ce33f..2be724d07 100644
--- a/tests/operations/test_architecture.py
+++ b/tests/operations/test_architecture.py
@@ -52,28 +52,14 @@ def test_stored_architecture(
     assert architecture.dims is not None, f"Architecture dims are None {architecture}"
 
 
-@pytest.mark.parametrize(
-    "architecture_config",
-    [
-        lf("unet_architecture"),
-    ],
-)
-def test_3d_conv_unet(
-    architecture_config,
-):
-    architecture = architecture_config.architecture_type(architecture_config)
-    for name, module in architecture.named_modules():
-        if isinstance(module, nn.Conv3d):
-            raise ValueError(f"Conv3d found in 2d unet {name}")
-
-
 @pytest.mark.parametrize(
     "architecture_config",
     [
         lf("unet_3d_architecture"),
+        lf("unet_architecture"),
     ],
 )
-def test_2d_conv_unet(
+def test_conv_dims(
     architecture_config,
 ):
     architecture = architecture_config.architecture_type(architecture_config)
@@ -82,21 +68,6 @@ def test_conv_dims(
         if isinstance(module, nn.Conv2d):
             raise ValueError(f"Conv2d found in 3d unet {name}")
 
 
-@pytest.mark.parametrize(
-    "run_config",
-    [
-        lf("unet_2d_distance_run"),
-    ],
-)
-def test_2d_conv_unet_in_run(
-    run_config,
-):
-    run = Run(run_config)
-    model = run.model
-    for name, module in model.named_modules():
-        if isinstance(module, nn.Conv3d):
-            raise ValueError(f"Conv3d found in 2d unet {name}")
-
 
 @pytest.mark.parametrize(
     "run_config",
diff --git a/tests/operations/test_mini.py b/tests/operations/test_mini.py
new file mode 100644
index 000000000..abc49239a
--- /dev/null
+++ b/tests/operations/test_mini.py
@@ -0,0 +1,76 @@
+from ..fixtures import *
+from .helpers import (
+    build_test_data_config,
+    build_test_task_config,
+    build_test_architecture_config,
+)
+
+from dacapo.experiments import Run
+from dacapo.experiments.run_config import RunConfig
+from dacapo.train import train_run
+from dacapo.validate import validate_run
+
+import pytest
+from pytest_lazy_fixtures import lf
+
+
+# TODO: Move unet parameters that don't affect interaction with other modules
+# to a separate architecture test
+@pytest.mark.parametrize("data_dims", [2, 3])
+@pytest.mark.parametrize("channels", [True, False])
+@pytest.mark.parametrize("task", ["distance", "onehot", "affs"])
+@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
+@pytest.mark.parametrize("architecture_dims", [2, 3])
+@pytest.mark.parametrize("upsample", [True, False])
+# @pytest.mark.parametrize("batch_norm", [True, False])
+@pytest.mark.parametrize("batch_norm", [False])
+# @pytest.mark.parametrize("use_attention", [True, False])
+@pytest.mark.parametrize("use_attention", [False])
+@pytest.mark.parametrize("padding", ["valid", "same"])
+@pytest.mark.parametrize("func", ["train", "validate"])
+def test_mini(
+    tmpdir,
+    data_dims,
+    channels,
+    task,
+    trainer,
+    architecture_dims,
+    batch_norm,
+    upsample,
+    use_attention,
+    padding,
+    func,
+):
+    # Invalid configurations:
+    if data_dims == 2 and architecture_dims == 3:
+        # cannot train a 3D model on 2D data
+        # TODO: maybe check that an appropriate warning is raised somewhere
+        return
+
+    data_config = build_test_data_config(
+        tmpdir,
+        data_dims,
+        channels,
+        upsample,
+        "instance" if task == "affs" else "semantic",
+    )
+    task_config = build_test_task_config(task, data_dims, architecture_dims)
+    architecture_config = build_test_architecture_config(
+        data_dims, architecture_dims, channels, batch_norm, upsample, use_attention, padding
+    )
+
+    run_config = RunConfig(
+        name=f"test_{func}",
+        task_config=task_config,
+        architecture_config=architecture_config,
+        trainer_config=trainer,
+        datasplit_config=data_config,
+        repetition=0,
+        num_iterations=1,
+    )
+    run = Run(run_config)
+
+    if func == "train":
+        train_run(run)
+    elif func == "validate":
+        validate_run(run, 1)
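test_mini deliberately fans out into a large parameter matrix; when debugging a single combination it can be run programmatically with pytest's standard keyword filtering (the -k expression below is illustrative):

import pytest

# Run only the train variants of the mini matrix.
pytest.main(["tests/operations/test_mini.py", "-k", "train"])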
diff --git a/tests/operations/test_train.py b/tests/operations/test_train.py
index 00c6b36e9..ae8ad1760 100644
--- a/tests/operations/test_train.py
+++ b/tests/operations/test_train.py
@@ -9,12 +9,6 @@
 import pytest
 from pytest_lazy_fixtures import lf
 
-from dacapo.experiments.run_config import RunConfig
-
-import logging
-
-logging.basicConfig(level=logging.INFO)
-
 import pytest
 
 
@@ -30,7 +24,7 @@
         lf("hot_distance_run"),
     ],
 )
-def test_train(
+def test_large(
     options,
     run_config,
 ):
@@ -65,142 +59,3 @@
 
     assert training_stats.trained_until() == run_config.num_iterations
 
-
-@pytest.mark.parametrize("datasplit", [lf("six_class_datasplit")])
-@pytest.mark.parametrize("task", [lf("distance_task")])
-@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
-@pytest.mark.parametrize("batch_norm", [False])
-@pytest.mark.parametrize("upsample", [False])
-@pytest.mark.parametrize("use_attention", [False])
-@pytest.mark.parametrize("three_d", [False])
-def test_train_non_stored_unet(
-    datasplit, task, trainer, batch_norm, upsample, use_attention, three_d
-):
-    architecture_config = unet_architecture_builder(
-        batch_norm, upsample, use_attention, three_d
-    )
-
-    run_config = RunConfig(
-        name=f"{architecture_config.name}_run_v",
-        task_config=task,
-        architecture_config=architecture_config,
-        trainer_config=trainer,
-        datasplit_config=datasplit,
-        repetition=0,
-        num_iterations=10,
-    )
-    run = Run(run_config)
-    train_run(run)
-
-
-@pytest.mark.parametrize("datasplit", [lf("six_class_datasplit")])
-@pytest.mark.parametrize("task", [lf("distance_task")])
-@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
-@pytest.mark.parametrize("batch_norm", [True, False])
-@pytest.mark.parametrize("upsample", [False])
-@pytest.mark.parametrize("use_attention", [True, False])
-@pytest.mark.parametrize("three_d", [True, False])
-def test_train_unet(
-    datasplit, task, trainer, batch_norm, upsample, use_attention, three_d
-):
-    store = create_config_store()
-    stats_store = create_stats_store()
-    weights_store = create_weights_store()
-
-    architecture_config = unet_architecture_builder(
-        batch_norm, upsample, use_attention, three_d
-    )
-
-    run_config = RunConfig(
-        name=f"{architecture_config.name}_run",
-        task_config=task,
-        architecture_config=architecture_config,
-        trainer_config=trainer,
-        datasplit_config=datasplit,
-        repetition=0,
-        num_iterations=10,
-    )
-    try:
-        store.store_run_config(run_config)
-    except Exception as e:
-        store.delete_run_config(run_config.name)
-        store.store_run_config(run_config)
-
-    run = Run(run_config)
-
-    # -------------------------------------
-
-    # train
-
-    weights_store.store_weights(run, 0)
-    train_run(run)
-
-    init_weights = weights_store.retrieve_weights(run.name, 0)
-    final_weights = weights_store.retrieve_weights(run.name, 10)
-
-    for name, weight in init_weights.model.items():
-        weight_diff = (weight - final_weights.model[name]).any()
-        assert weight_diff != 0, "Weights did not change"
-
-    # assert train_stats and validation_scores are available
-
-    training_stats = stats_store.retrieve_training_stats(run_config.name)
-
-    assert training_stats.trained_until() == run_config.num_iterations
-
-
-# @pytest.mark.parametrize("upsample_datasplit", [lf("upsample_six_class_datasplit")])
-# @pytest.mark.parametrize("task", [lf("distance_task")])
-# @pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
-# @pytest.mark.parametrize("batch_norm", [True, False])
-# @pytest.mark.parametrize("upsample", [True])
-# @pytest.mark.parametrize("use_attention", [True, False])
-# @pytest.mark.parametrize("three_d", [True, False])
-# def test_upsample_train_unet(
-#     upsample_datasplit, task, trainer, batch_norm, upsample, use_attention, three_d
-# ):
-#     store = create_config_store()
-#     stats_store = create_stats_store()
-#     weights_store = create_weights_store()
-
-#     architecture_config = unet_architecture_builder(
-#         batch_norm, upsample, use_attention, three_d
-#     )
-
-#     run_config = RunConfig(
-#         name=f"{architecture_config.name}_run",
-#         task_config=task,
-#         architecture_config=architecture_config,
-#         trainer_config=trainer,
-#         datasplit_config=upsample_datasplit,
-#         repetition=0,
-#         num_iterations=10,
-#     )
-#     try:
-#         store.store_run_config(run_config)
-#     except Exception as e:
-#         store.delete_run_config(run_config.name)
-#         store.store_run_config(run_config)
-
-#     run = Run(run_config)
-
-#     # -------------------------------------
-
-#     # train
-
-#     weights_store.store_weights(run, 0)
-#     train_run(run)
-#     # weights_store.store_weights(run, run.train_until)
-
-#     init_weights = weights_store.retrieve_weights(run.name, 0)
-#     final_weights = weights_store.retrieve_weights(run.name, 10)
-
-#     for name, weight in init_weights.model.items():
-#         weight_diff = (weight - final_weights.model[name]).any()
-#         assert weight_diff != 0, "Weights did not change"
-
-#     # assert train_stats and validation_scores are available
-
-#     training_stats = stats_store.retrieve_training_stats(run_config.name)
-
-#     assert training_stats.trained_until() == run_config.num_iterations
diff --git a/tests/operations/test_validate.py b/tests/operations/test_validate.py
index 97819400e..4df49a602 100644
--- a/tests/operations/test_validate.py
+++ b/tests/operations/test_validate.py
@@ -1,14 +1,9 @@
-import os
-from upath import UPath as Path
-import shutil
 from ..fixtures import *
 
 from dacapo.experiments import Run
 from dacapo.store.create_store import create_config_store, create_weights_store
 from dacapo import validate, validate_run
 
-from dacapo.experiments.run_config import RunConfig
-
 import pytest
 from pytest_lazy_fixtures import lf
 
@@ -24,110 +19,23 @@
         lf("onehot_run"),
     ],
 )
-def test_validate(
+def test_large(
     options,
     run_config,
 ):
-    # set debug to True to run the test in a specific directory (for debugging)
-    debug = False
-    if debug:
-        tmp_path = f"{Path(__file__).parent}/tmp"
-        if os.path.exists(tmp_path):
-            shutil.rmtree(tmp_path, ignore_errors=True)
-        os.makedirs(tmp_path, exist_ok=True)
-        old_path = os.getcwd()
-        os.chdir(tmp_path)
-        # when done debugging, delete "tests/operations/tmp"
-    # -------------------------------------
     store = create_config_store()
+    weights_store = create_weights_store()
     store.store_run_config(run_config)
 
+    # validate
     validate(run_config.name, 0)
-    # weights_store.store_weights(run, 1)
-    # validate_run(run_config.name, 1)
+
+    # validate_run
+    run = Run(run_config)
+    weights_store.store_weights(run, 1)
+    validate_run(run, 1)
 
     # test validating weights that don't exist
     with pytest.raises(FileNotFoundError):
         validate(run_config.name, 2)
-    if debug:
-        os.chdir(old_path)
-
-
-@pytest.mark.parametrize(
-    "run_config",
-    [
-        lf("distance_run"),
-        lf("onehot_run"),
-    ],
-)
-def test_validate_run(
-    options,
-    run_config,
-):
-    # set debug to True to run the test in a specific directory (for debugging)
-    debug = False
-    if debug:
-        tmp_path = f"{Path(__file__).parent}/tmp"
-        if os.path.exists(tmp_path):
-            shutil.rmtree(tmp_path, ignore_errors=True)
-        os.makedirs(tmp_path, exist_ok=True)
-        old_path = os.getcwd()
-        os.chdir(tmp_path)
-        # when done debugging, delete "tests/operations/tmp"
-    # -------------------------------------
-
-    # create a store
-
-    store = create_config_store()
-    weights_store = create_weights_store()
-
-    # store the configs
-
-    store.store_run_config(run_config)
-
-    run_config = store.retrieve_run_config(run_config.name)
-    run = Run(run_config)
-
-    # -------------------------------------
-
-    # validate
-
-    # test validating iterations for which we know there are weights
-    weights_store.store_weights(run, 0)
-    validate_run(run, 0)
-
-    if debug:
-        os.chdir(old_path)
-
-
-@pytest.mark.parametrize("datasplit", [lf("six_class_datasplit")])
-@pytest.mark.parametrize("task", [lf("distance_task"), lf("six_onehot_task")])
-@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
-@pytest.mark.parametrize(
-    "architecture", [lf("unet_architecture"), lf("unet_3d_architecture")]
-)
-def test_validate_unet(datasplit, task, trainer, architecture):
-    store = create_config_store()
-    weights_store = create_weights_store()
-
-    run_config = RunConfig(
-        name=f"{architecture.name}_run_validate",
-        task_config=task,
-        architecture_config=architecture,
-        trainer_config=trainer,
-        datasplit_config=datasplit,
-        repetition=0,
-        num_iterations=10,
-    )
-    try:
-        store.store_run_config(run_config)
-    except Exception as e:
-        store.delete_run_config(run_config.name)
-        store.store_run_config(run_config)
-
-    run = Run(run_config)
-
-    # -------------------------------------
-    weights_store.store_weights(run, 0)
-    validate(run.name, 0)