Skip to content

Commit

Permalink
Tests v0 3 5 (#346)
Browse files Browse the repository at this point in the history
Simplified the parameterized train test and added validation. Fixed the
bugs that were found.
  • Loading branch information
mzouink authored Nov 19, 2024
2 parents 18a0b20 + e302c33 commit aa3b72c
Show file tree
Hide file tree
Showing 25 changed files with 379 additions and 350 deletions.
2 changes: 1 addition & 1 deletion dacapo/experiments/datasplits/simple_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_paths(self, group_name: str) -> list[Path]:
len(level_2_matches) == 0
), f"Found raw data at {level_1} and {level_2}"
return [Path(x).parent for x in level_1_matches]
elif len(level_2_matches).parent > 0:
elif len(level_2_matches) > 0:
return [Path(x) for x in level_2_matches]

raise Exception(f"No raw data found at {level_0} or {level_1} or {level_2}")
Expand Down
10 changes: 10 additions & 0 deletions dacapo/experiments/tasks/hot_distance_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .predictors import HotDistancePredictor
from .task import Task

import warnings

class HotDistanceTask(Task):
"""
Expand Down Expand Up @@ -34,10 +35,19 @@ def __init__(self, task_config):
>>> task = HotDistanceTask(task_config)
"""

if task_config.kernel_size is None:
warnings.warn(
"The default kernel size of 3 will be changing to 1. "
"Please specify the kernel size explicitly.",
DeprecationWarning,
)
task_config.kernel_size = 3
self.predictor = HotDistancePredictor(
channels=task_config.channels,
scale_factor=task_config.scale_factor,
mask_distances=task_config.mask_distances,
kernel_size=task_config.kernel_size,
)
self.loss = HotDistanceLoss()
self.post_processor = ThresholdPostProcessor()
Expand Down
5 changes: 5 additions & 0 deletions dacapo/experiments/tasks/hot_distance_task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,8 @@ class HotDistanceTaskConfig(TaskConfig):
"is less than the distance to object boundary."
},
)


kernel_size: int | None = attr.ib(
default=None,
)
14 changes: 13 additions & 1 deletion dacapo/experiments/tasks/one_hot_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from .predictors import OneHotPredictor
from .task import Task

import warnings


class OneHotTask(Task):
"""
Expand All @@ -30,7 +32,17 @@ def __init__(self, task_config):
Examples:
>>> task = OneHotTask(task_config)
"""
self.predictor = OneHotPredictor(classes=task_config.classes)

if task_config.kernel_size is None:
warnings.warn(
"The default kernel size of 3 will be changing to 1. "
"Please specify the kernel size explicitly.",
DeprecationWarning,
)
task_config.kernel_size = 3
self.predictor = OneHotPredictor(
classes=task_config.classes, kernel_size=task_config.kernel_size
)
self.loss = DummyLoss()
self.post_processor = ArgmaxPostProcessor()
self.evaluator = DummyEvaluator()
3 changes: 3 additions & 0 deletions dacapo/experiments/tasks/one_hot_task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ class OneHotTaskConfig(TaskConfig):
classes: List[str] = attr.ib(
metadata={"help_text": "The classes corresponding with each id starting from 0"}
)
kernel_size: int | None = attr.ib(
default=None,
)
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def process(
overwrite=True,
)

read_roi = Roi((0, 0, 0), block_size[-self.prediction_array.dims :])
read_roi = Roi((0,)*block_size.dims, block_size)
input_array = open_ds(
f"{self.prediction_array_identifier.container.path}/{self.prediction_array_identifier.dataset}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ def process(
if self.prediction_array._source_data.chunks is not None:
block_size = self.prediction_array._source_data.chunks

write_size = [
write_size = Coordinate([
b * v
for b, v in zip(
block_size[-self.prediction_array.dims :],
self.prediction_array.voxel_size,
)
]
])
output_array = create_from_identifier(
output_array_identifier,
self.prediction_array.axis_names,
Expand All @@ -128,7 +128,7 @@ def process(
overwrite=True,
)

read_roi = Roi((0, 0, 0), write_size[-self.prediction_array.dims :])
read_roi = Roi(write_size * 0, write_size)
input_array = open_ds(
f"{self.prediction_array_identifier.container.path}/{self.prediction_array_identifier.dataset}"
)
Expand Down
21 changes: 20 additions & 1 deletion dacapo/experiments/tasks/predictors/distance_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ def create_distance_mask(
>>> predictor.create_distance_mask(distances, mask, voxel_size, normalize, normalize_args)
"""
no_channel_dim = len(mask.shape) == len(distances.shape) - 1
if no_channel_dim:
mask = mask[np.newaxis]

mask_output = mask.copy()
for i, (channel_distance, channel_mask) in enumerate(zip(distances, mask)):
tmp = np.zeros(
Expand Down Expand Up @@ -275,6 +279,8 @@ def create_distance_mask(
np.sum(channel_mask_output)
)
)
if no_channel_dim:
mask_output = mask_output[0]
return mask_output

def process(
Expand All @@ -300,7 +306,20 @@ def process(
>>> predictor.process(labels, voxel_size, normalize, normalize_args)
"""

num_dims = len(labels.shape)
if num_dims == voxel_size.dims:
channel_dim = False
elif num_dims == voxel_size.dims + 1:
channel_dim = True
else:
raise ValueError("Cannot handle multiple channel dims")

if not channel_dim:
labels = labels[np.newaxis]

all_distances = np.zeros(labels.shape, dtype=np.float32) - 1

for ii, channel in enumerate(labels):
boundaries = self.__find_boundaries(channel)

Expand Down Expand Up @@ -358,7 +377,7 @@ def __find_boundaries(self, labels: np.ndarray):
# bound.: 00000001000100000001000 2n - 1

if labels.dtype == bool:
raise ValueError("Labels should not be bools")
# raise ValueError("Labels should not be bools")
labels = labels.astype(np.uint8)

logger.debug(f"computing boundaries for {labels.shape}")
Expand Down
2 changes: 1 addition & 1 deletion dacapo/experiments/tasks/predictors/dummy_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def create_model(self, architecture):
>>> model = predictor.create_model(architecture)
"""
head = torch.nn.Conv3d(
architecture.num_out_channels, self.embedding_dims, kernel_size=3
architecture.num_out_channels, self.embedding_dims, kernel_size=1
)

return Model(architecture, head)
Expand Down
7 changes: 4 additions & 3 deletions dacapo/experiments/tasks/predictors/hot_distance_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class HotDistancePredictor(Predictor):
This is a subclass of Predictor.
"""

def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool):
def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool, kernel_size: int):
"""
Initializes the HotDistancePredictor.
Expand All @@ -64,6 +64,7 @@ def __init__(self, channels: List[str], scale_factor: float, mask_distances: boo
Note:
The channels argument is a list of strings, each string is the name of a class that is being segmented.
"""
self.kernel_size = kernel_size
self.channels = (
channels * 2
) # one hot + distance (TODO: add hot/distance to channel names)
Expand Down Expand Up @@ -119,11 +120,11 @@ def create_model(self, architecture):
"""
if architecture.dims == 2:
head = torch.nn.Conv2d(
architecture.num_out_channels, self.embedding_dims, kernel_size=3
architecture.num_out_channels, self.embedding_dims, self.kernel_size
)
elif architecture.dims == 3:
head = torch.nn.Conv3d(
architecture.num_out_channels, self.embedding_dims, kernel_size=3
architecture.num_out_channels, self.embedding_dims, self.kernel_size
)

return Model(architecture, head)
Expand Down
16 changes: 13 additions & 3 deletions dacapo/experiments/tasks/predictors/one_hot_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class OneHotPredictor(Predictor):
This is a subclass of Predictor.
"""

def __init__(self, classes: List[str]):
def __init__(self, classes: List[str], kernel_size: int):
"""
Initialize the OneHotPredictor.
Expand All @@ -42,6 +42,7 @@ def __init__(self, classes: List[str]):
>>> predictor = OneHotPredictor(classes)
"""
self.classes = classes
self.kernel_size = kernel_size

@property
def embedding_dims(self):
Expand Down Expand Up @@ -70,8 +71,17 @@ def create_model(self, architecture):
Examples:
>>> model = predictor.create_model(architecture)
"""
head = torch.nn.Conv3d(
architecture.num_out_channels, self.embedding_dims, kernel_size=3

if architecture.dims == 3:
conv_layer = torch.nn.Conv3d
elif architecture.dims == 2:
conv_layer = torch.nn.Conv2d
else:
raise Exception(f"Unsupported number of dimensions: {architecture.dims}")
head = conv_layer(
architecture.num_out_channels,
self.embedding_dims,
kernel_size=self.kernel_size,
)

return Model(architecture, head)
Expand Down
4 changes: 2 additions & 2 deletions dacapo/experiments/trainers/gunpowder_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,13 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
request.add(weight_key, output_size)
request.add(
mask_placeholder,
prediction_voxel_size * self.mask_integral_downsample_factor,
prediction_voxel_size,
)
# request additional keys for snapshots
request.add(gt_key, output_size)
request.add(mask_key, output_size)
request[mask_placeholder].roi = request[mask_placeholder].roi.snap_to_grid(
prediction_voxel_size * self.mask_integral_downsample_factor
prediction_voxel_size
)

self._request = request
Expand Down
14 changes: 8 additions & 6 deletions dacapo/predict_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ def predict(
else:
input_roi = output_roi.grow(context, context)

read_roi = Roi((0, 0, 0), input_size)
read_roi = Roi((0,) * input_size.dims, input_size)
write_roi = read_roi.grow(-context, -context)

axes = ["c^", "z", "y", "x"]
axes = raw_array.axis_names
if "c^" not in axes:
axes = ["c^"] + axes

num_channels = model.num_out_channels

Expand All @@ -73,8 +75,8 @@ def predict(

model_device = str(next(model.parameters()).device).split(":")[0]

assert model_device == str(
device
assert (
model_device == str(device)
), f"Model is not on the right device, Model: {model_device}, Compute device: {device}"

def predict_fn(block):
Expand Down Expand Up @@ -103,7 +105,7 @@ def predict_fn(block):
predictions = Array(
predictions,
block.write_roi.offset,
raw_array.voxel_size,
output_voxel_size,
axis_names,
raw_array.units,
)
Expand All @@ -120,7 +122,7 @@ def predict_fn(block):
task = daisy.Task(
f"predict_{out_container}_{out_dataset}",
total_roi=input_roi,
read_roi=Roi((0, 0, 0), input_size),
read_roi=Roi((0,)*input_size.dims, input_size),
write_roi=Roi(context, output_size),
process_function=predict_fn,
check_function=None,
Expand Down
3 changes: 3 additions & 0 deletions dacapo/utils/balance_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def balance_weights(
scale_slab *= np.take(w, labels_slab)
"""

if label_data.dtype == bool:
label_data = label_data.astype(np.uint8)

if moving_counts is None:
moving_counts = []
unique_labels = np.unique(label_data)
Expand Down
3 changes: 0 additions & 3 deletions tests/conf.py

This file was deleted.

28 changes: 28 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import multiprocessing as mp
import os
import yaml

from dacapo.options import Options

import pytest


@pytest.fixture(params=["fork", "spawn"], autouse=True)
def context(request, monkeypatch):
    """Run every test once per multiprocessing start method.

    The fixture is parametrized over ``"fork"`` and ``"spawn"``. For the
    current parameter it replaces ``multiprocessing.Queue``, ``Process``,
    ``Event`` and ``Value`` with the versions bound to that start method's
    context. ``monkeypatch`` restores the originals after the test.

    The previous implementation ignored the parameter and always used
    ``"spawn"``, so each test ran twice under the same start method and the
    ``"fork"`` variant was never exercised.
    """
    start_method = request.param

    # "fork" does not exist on every platform (e.g. Windows); skip instead of
    # failing with a ValueError from ``get_context``.
    if start_method not in mp.get_all_start_methods():
        pytest.skip(
            f"multiprocessing start method {start_method!r} is not available"
        )

    ctx = mp.get_context(start_method)
    for name in ("Queue", "Process", "Event", "Value"):
        monkeypatch.setattr(mp, name, getattr(ctx, name))


@pytest.fixture(autouse=True)
def runs_base_dir(tmpdir, monkeypatch):
    """Point dacapo at a throw-away options file for the duration of a test.

    Writes a ``dacapo.yaml`` inside pytest's per-test ``tmpdir`` whose
    ``runs_base_dir`` is that same directory, exposes it through the
    ``DACAPO_OPTIONS_FILE`` environment variable, and asserts that
    ``Options`` reports both the file and the directory.
    """
    options_file = tmpdir / "dacapo.yaml"

    # monkeypatch undoes the change on teardown, so the variable never keeps
    # pointing at a temporary path after the test has finished.
    monkeypatch.setenv("DACAPO_OPTIONS_FILE", f"{options_file}")

    # Write through the handle we opened rather than opening the file a
    # second time via ``options_file.write``.
    with open(options_file, "w") as f:
        f.write(yaml.safe_dump({"runs_base_dir": f"{tmpdir}"}))

    assert Options.config_file() == options_file
    assert Options.instance().runs_base_dir == tmpdir
1 change: 0 additions & 1 deletion tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
dummy_architecture,
unet_architecture,
unet_3d_architecture,
unet_architecture_builder,
)
from .arrays import dummy_array, zarr_array, cellmap_array
from .datasplits import (
Expand Down
Loading

0 comments on commit aa3b72c

Please sign in to comment.