Skip to content

Commit

Permalink
Merge branch 'main' into upsample
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink authored Jan 2, 2025
2 parents 9db5c85 + e9f255c commit 59c8d64
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 57 deletions.
4 changes: 3 additions & 1 deletion dacapo/experiments/trainers/gunpowder_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,9 @@ def iterate(self, num_iterations, model, optimizer, device):
),
}
if mask is not None:
snapshot_arrays["volumes/mask"] = mask
snapshot_arrays["volumes/mask"] = np_to_funlib_array(
mask[0], offset=target.offset, voxel_size=target.voxel_size
)
logger.warning(
f"Saving Snapshot. Iteration: {iteration}, "
f"Loss: {loss.detach().cpu().numpy().item()}!"
Expand Down
2 changes: 1 addition & 1 deletion dacapo/tmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def create_from_identifier(
return prepare_ds(
out_path,
shape=(*list_num_channels, *roi.shape / voxel_size),
offset=roi.offset / voxel_size,
offset=roi.offset,
voxel_size=voxel_size,
axis_names=axis_names,
dtype=dtype,
Expand Down
77 changes: 38 additions & 39 deletions dacapo/utils/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,45 +83,44 @@ class MakeRaw(gp.BatchFilter):
process: Generate the raw image from the labels.
"""

class Pipeline:
def __init__(
self,
raw,
labels,
gaussian_noise_args: Iterable = (0.5, 0.1),
gaussian_noise_lim: float = 0.3,
gaussian_blur_args: Iterable = (0.5, 1.5),
membrane_like=True,
membrane_size=3,
inside_value=0.5,
):
"""
Initialize the Pipeline object.
Args:
raw: The raw data.
labels: The labels data.
gaussian_noise_args: Tuple of two floats representing the mean and standard deviation
of the Gaussian noise to be added to the data. Default is (0.5, 0.1).
gaussian_noise_lim: The limit of the Gaussian noise. Default is 0.3.
gaussian_blur_args: Tuple of two floats representing the mean and standard deviation
of the Gaussian blur to be applied to the data. Default is (0.5, 1.5).
membrane_like: Boolean indicating whether to apply membrane-like transformation to the data.
Default is True.
membrane_size: The size of the membrane. Default is 3.
inside_value: The value to be assigned to the inside of the membrane. Default is 0.5.
Examples:
>>> Pipeline(raw="RAW", labels="LABELS", gaussian_noise_args=(0.5, 0.1), gaussian_noise_lim=0.3,
>>> gaussian_blur_args=(0.5, 1.5), membrane_like=True, membrane_size=3, inside_value=0.5)
"""
self.raw = raw
self.labels = labels
self.gaussian_noise_args = gaussian_noise_args
self.gaussian_noise_lim = gaussian_noise_lim
self.gaussian_blur_args = gaussian_blur_args
self.membrane_like = membrane_like
self.membrane_size = membrane_size
self.inside_value = inside_value
def __init__(
self,
raw,
labels,
gaussian_noise_args: Iterable = (0.5, 0.1),
gaussian_noise_lim: float = 0.3,
gaussian_blur_args: Iterable = (0.5, 1.5),
membrane_like=True,
membrane_size=3,
inside_value=0.5,
):
"""
Initialize the Pipeline object.
Args:
raw: The raw data.
labels: The labels data.
gaussian_noise_args: Tuple of two floats representing the mean and standard deviation
of the Gaussian noise to be added to the data. Default is (0.5, 0.1).
gaussian_noise_lim: The limit of the Gaussian noise. Default is 0.3.
gaussian_blur_args: Tuple of two floats representing the mean and standard deviation
of the Gaussian blur to be applied to the data. Default is (0.5, 1.5).
membrane_like: Boolean indicating whether to apply membrane-like transformation to the data.
Default is True.
membrane_size: The size of the membrane. Default is 3.
inside_value: The value to be assigned to the inside of the membrane. Default is 0.5.
Examples:
>>> Pipeline(raw="RAW", labels="LABELS", gaussian_noise_args=(0.5, 0.1), gaussian_noise_lim=0.3,
>>> gaussian_blur_args=(0.5, 1.5), membrane_like=True, membrane_size=3, inside_value=0.5)
"""
self.raw = raw
self.labels = labels
self.gaussian_noise_args = gaussian_noise_args
self.gaussian_noise_lim = gaussian_noise_lim
self.gaussian_blur_args = gaussian_blur_args
self.membrane_like = membrane_like
self.membrane_size = membrane_size
self.inside_value = inside_value

def setup(self):
"""
Expand Down
18 changes: 14 additions & 4 deletions tests/operations/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from funlib.persistence import prepare_ds
from funlib.geometry import Coordinate

from dacapo.experiments.trainers import GunpowderTrainerConfig
from dacapo.experiments.datasplits import SimpleDataSplitConfig
from dacapo.experiments.tasks import (
DistanceTaskConfig,
Expand All @@ -13,6 +14,19 @@
from pathlib import Path


def build_test_train_config(multiprocessing: bool):
"""
Builds the simplest possible trainer given the parameters.
"""
return GunpowderTrainerConfig(
name="test_trainer",
batch_size=1,
learning_rate=0.0001,
num_data_fetchers=1 + multiprocessing,
snapshot_interval=1,
)


def build_test_data_config(
tmpdir: Path, data_dims: int, channels: bool, upsample: bool, task_type: str
):
Expand Down Expand Up @@ -104,9 +118,7 @@ def build_test_architecture_config(
data_dims: int,
architecture_dims: int,
channels: bool,
batch_norm: bool,
upsample: bool,
use_attention: bool,
padding: str,
):
"""
Expand Down Expand Up @@ -160,7 +172,5 @@ def build_test_architecture_config(
kernel_size_up=kernel_size_up,
constant_upsample=True,
upsample_factors=upsample_factors,
batch_norm=batch_norm,
use_attention=use_attention,
padding=padding,
)
36 changes: 24 additions & 12 deletions tests/operations/test_mini.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from ..fixtures import *
from .helpers import (
build_test_train_config,
build_test_data_config,
build_test_task_config,
build_test_architecture_config,
)

from dacapo.store.create_store import create_array_store
from dacapo.experiments import Run
from dacapo.train import train_run
from dacapo.validate import validate_run

import zarr

import pytest
from pytest_lazy_fixtures import lf

from dacapo.experiments.run_config import RunConfig

Expand All @@ -22,34 +25,30 @@
@pytest.mark.parametrize("data_dims", [2, 3])
@pytest.mark.parametrize("channels", [True, False])
@pytest.mark.parametrize("task", ["distance", "onehot", "affs"])
@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")])
@pytest.mark.parametrize("architecture_dims", [2, 3])
@pytest.mark.parametrize("upsample", [True, False])
# @pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("batch_norm", [False])
# @pytest.mark.parametrize("use_attention", [True, False])
@pytest.mark.parametrize("use_attention", [False])
@pytest.mark.parametrize("padding", ["valid", "same"])
@pytest.mark.parametrize("func", ["train", "validate"])
@pytest.mark.parametrize("multiprocessing", [False])
def test_mini(
tmpdir,
data_dims,
channels,
task,
trainer,
architecture_dims,
batch_norm,
upsample,
use_attention,
padding,
func,
multiprocessing,
):
# Invalid configurations:
if data_dims == 2 and architecture_dims == 3:
# cannot train a 3D model on 2D data
# TODO: maybe check that an appropriate warning is raised somewhere
return

trainer_config = build_test_train_config(multiprocessing)

data_config = build_test_data_config(
tmpdir,
data_dims,
Expand All @@ -62,17 +61,15 @@ def test_mini(
data_dims,
architecture_dims,
channels,
batch_norm,
upsample,
use_attention,
padding,
)

run_config = RunConfig(
name=f"test_{func}",
task_config=task_config,
architecture_config=architecture_config,
trainer_config=trainer,
trainer_config=trainer_config,
datasplit_config=data_config,
repetition=0,
num_iterations=1,
Expand All @@ -81,5 +78,20 @@ def test_mini(

if func == "train":
train_run(run)
array_store = create_array_store()
snapshot_container = array_store.snapshot_container(run.name).container
assert snapshot_container.exists()
assert all(
x in zarr.open(snapshot_container)
for x in [
"0/volumes/raw",
"0/volumes/gt",
"0/volumes/target",
"0/volumes/weight",
"0/volumes/prediction",
"0/volumes/gradients",
"0/volumes/mask",
]
)
elif func == "validate":
validate_run(run, 1)

0 comments on commit 59c8d64

Please sign in to comment.