diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 507151ad..15da00a0 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -361,7 +361,9 @@ def iterate(self, num_iterations, model, optimizer, device): ), } if mask is not None: - snapshot_arrays["volumes/mask"] = mask + snapshot_arrays["volumes/mask"] = np_to_funlib_array( + mask[0], offset=target.offset, voxel_size=target.voxel_size + ) logger.warning( f"Saving Snapshot. Iteration: {iteration}, " f"Loss: {loss.detach().cpu().numpy().item()}!" diff --git a/dacapo/tmp.py b/dacapo/tmp.py index 672745c9..6c455aff 100644 --- a/dacapo/tmp.py +++ b/dacapo/tmp.py @@ -71,7 +71,7 @@ def create_from_identifier( return prepare_ds( out_path, shape=(*list_num_channels, *roi.shape / voxel_size), - offset=roi.offset / voxel_size, + offset=roi.offset, voxel_size=voxel_size, axis_names=axis_names, dtype=dtype, diff --git a/dacapo/utils/pipeline.py b/dacapo/utils/pipeline.py index 99f823eb..9daef7c1 100644 --- a/dacapo/utils/pipeline.py +++ b/dacapo/utils/pipeline.py @@ -83,45 +83,44 @@ class MakeRaw(gp.BatchFilter): process: Generate the raw image from the labels. """ - class Pipeline: - def __init__( - self, - raw, - labels, - gaussian_noise_args: Iterable = (0.5, 0.1), - gaussian_noise_lim: float = 0.3, - gaussian_blur_args: Iterable = (0.5, 1.5), - membrane_like=True, - membrane_size=3, - inside_value=0.5, - ): - """ - Initialize the Pipeline object. - - Args: - raw: The raw data. - labels: The labels data. - gaussian_noise_args: Tuple of two floats representing the mean and standard deviation - of the Gaussian noise to be added to the data. Default is (0.5, 0.1). - gaussian_noise_lim: The limit of the Gaussian noise. Default is 0.3. - gaussian_blur_args: Tuple of two floats representing the mean and standard deviation - of the Gaussian blur to be applied to the data. Default is (0.5, 1.5). - membrane_like: Boolean indicating whether to apply membrane-like transformation to the data. - Default is True. - membrane_size: The size of the membrane. Default is 3. - inside_value: The value to be assigned to the inside of the membrane. Default is 0.5. - Examples: - >>> Pipeline(raw="RAW", labels="LABELS", gaussian_noise_args=(0.5, 0.1), gaussian_noise_lim=0.3, - >>> gaussian_blur_args=(0.5, 1.5), membrane_like=True, membrane_size=3, inside_value=0.5) - """ - self.raw = raw - self.labels = labels - self.gaussian_noise_args = gaussian_noise_args - self.gaussian_noise_lim = gaussian_noise_lim - self.gaussian_blur_args = gaussian_blur_args - self.membrane_like = membrane_like - self.membrane_size = membrane_size - self.inside_value = inside_value + def __init__( + self, + raw, + labels, + gaussian_noise_args: Iterable = (0.5, 0.1), + gaussian_noise_lim: float = 0.3, + gaussian_blur_args: Iterable = (0.5, 1.5), + membrane_like=True, + membrane_size=3, + inside_value=0.5, + ): + """ + Initialize the Pipeline object. + + Args: + raw: The raw data. + labels: The labels data. + gaussian_noise_args: Tuple of two floats representing the mean and standard deviation + of the Gaussian noise to be added to the data. Default is (0.5, 0.1). + gaussian_noise_lim: The limit of the Gaussian noise. Default is 0.3. + gaussian_blur_args: Tuple of two floats representing the mean and standard deviation + of the Gaussian blur to be applied to the data. Default is (0.5, 1.5). + membrane_like: Boolean indicating whether to apply membrane-like transformation to the data. + Default is True. + membrane_size: The size of the membrane. Default is 3. + inside_value: The value to be assigned to the inside of the membrane. Default is 0.5. + Examples: + >>> Pipeline(raw="RAW", labels="LABELS", gaussian_noise_args=(0.5, 0.1), gaussian_noise_lim=0.3, + >>> gaussian_blur_args=(0.5, 1.5), membrane_like=True, membrane_size=3, inside_value=0.5) + """ + self.raw = raw + self.labels = labels + self.gaussian_noise_args = gaussian_noise_args + self.gaussian_noise_lim = gaussian_noise_lim + self.gaussian_blur_args = gaussian_blur_args + self.membrane_like = membrane_like + self.membrane_size = membrane_size + self.inside_value = inside_value def setup(self): """ diff --git a/tests/operations/helpers.py b/tests/operations/helpers.py index 74fb4320..d4be3f70 100644 --- a/tests/operations/helpers.py +++ b/tests/operations/helpers.py @@ -2,6 +2,7 @@ from funlib.persistence import prepare_ds from funlib.geometry import Coordinate +from dacapo.experiments.trainers import GunpowderTrainerConfig from dacapo.experiments.datasplits import SimpleDataSplitConfig from dacapo.experiments.tasks import ( DistanceTaskConfig, @@ -13,6 +14,19 @@ from pathlib import Path +def build_test_train_config(multiprocessing: bool): + """ + Builds the simplest possible trainer given the parameters. + """ + return GunpowderTrainerConfig( + name="test_trainer", + batch_size=1, + learning_rate=0.0001, + num_data_fetchers=1 + multiprocessing, + snapshot_interval=1, + ) + + def build_test_data_config( tmpdir: Path, data_dims: int, channels: bool, upsample: bool, task_type: str ): @@ -104,9 +118,7 @@ def build_test_architecture_config( data_dims: int, architecture_dims: int, channels: bool, - batch_norm: bool, upsample: bool, - use_attention: bool, padding: str, ): """ @@ -160,7 +172,5 @@ def build_test_architecture_config( kernel_size_up=kernel_size_up, constant_upsample=True, upsample_factors=upsample_factors, - batch_norm=batch_norm, - use_attention=use_attention, padding=padding, ) diff --git a/tests/operations/test_mini.py b/tests/operations/test_mini.py index f5070553..57b25bdf 100644 --- a/tests/operations/test_mini.py +++ b/tests/operations/test_mini.py @@ -1,16 +1,19 @@ from ..fixtures import * from .helpers import ( + build_test_train_config, build_test_data_config, build_test_task_config, build_test_architecture_config, ) +from dacapo.store.create_store import create_array_store from dacapo.experiments import Run from dacapo.train import train_run from dacapo.validate import validate_run +import zarr + import pytest -from pytest_lazy_fixtures import lf from dacapo.experiments.run_config import RunConfig @@ -22,27 +25,21 @@ @pytest.mark.parametrize("data_dims", [2, 3]) @pytest.mark.parametrize("channels", [True, False]) @pytest.mark.parametrize("task", ["distance", "onehot", "affs"]) -@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")]) @pytest.mark.parametrize("architecture_dims", [2, 3]) @pytest.mark.parametrize("upsample", [True, False]) -# @pytest.mark.parametrize("batch_norm", [True, False]) -@pytest.mark.parametrize("batch_norm", [False]) -# @pytest.mark.parametrize("use_attention", [True, False]) -@pytest.mark.parametrize("use_attention", [False]) @pytest.mark.parametrize("padding", ["valid", "same"]) @pytest.mark.parametrize("func", ["train", "validate"]) +@pytest.mark.parametrize("multiprocessing", [False]) def test_mini( tmpdir, data_dims, channels, task, - trainer, architecture_dims, - batch_norm, upsample, - use_attention, padding, func, + multiprocessing, ): # Invalid configurations: if data_dims == 2 and architecture_dims == 3: @@ -50,6 +47,8 @@ def test_mini( # TODO: maybe check that an appropriate warning is raised somewhere return + trainer_config = build_test_train_config(multiprocessing) + data_config = build_test_data_config( tmpdir, data_dims, @@ -62,9 +61,7 @@ def test_mini( data_dims, architecture_dims, channels, - batch_norm, upsample, - use_attention, padding, ) @@ -72,7 +69,7 @@ def test_mini( name=f"test_{func}", task_config=task_config, architecture_config=architecture_config, - trainer_config=trainer, + trainer_config=trainer_config, datasplit_config=data_config, repetition=0, num_iterations=1, @@ -81,5 +78,20 @@ def test_mini( if func == "train": train_run(run) + array_store = create_array_store() + snapshot_container = array_store.snapshot_container(run.name).container + assert snapshot_container.exists() + assert all( + x in zarr.open(snapshot_container) + for x in [ + "0/volumes/raw", + "0/volumes/gt", + "0/volumes/target", + "0/volumes/weight", + "0/volumes/prediction", + "0/volumes/gradients", + "0/volumes/mask", + ] + ) elif func == "validate": validate_run(run, 1)