validation local blockwise: in progress #294

Status: Closed. Wants to merge 29 commits. The diff below shows changes from all commits.

Commits (29):
a1e46bd  local blockwise (mzouink, Sep 23, 2024)
a81983b  :art: Format Python code with psf/black (mzouink, Sep 23, 2024)
59a9b67  Format Python code with psf/black push (#295) (rhoadesScholar, Sep 24, 2024)
473bc38  all local changes (mzouink, Sep 26, 2024)
ca206aa  all local changes (mzouink, Sep 26, 2024)
1eca4e8  local changes (mzouink, Sep 26, 2024)
4520c3f  remove plot debug prints (mzouink, Oct 2, 2024)
aca7f68  env DACAPO_OPTIONS_FILE (mzouink, Oct 2, 2024)
9dfbdc1  remove trash (mzouink, Oct 2, 2024)
5de82f2  :art: Format Python code with psf/black (mzouink, Oct 2, 2024)
39f9dc9  Format Python code with psf/black push (#299) (mzouink, Oct 2, 2024)
00149c7  Merge branch 'main' into fix_local_predict (mzouink, Oct 2, 2024)
2775317  fix restart run (mzouink, Oct 2, 2024)
35d8644  batch norm params (mzouink, Oct 2, 2024)
d372cd8  revert changes (mzouink, Oct 2, 2024)
e9285c8  gunpowder trainer reject min option (mzouink, Oct 2, 2024)
ae93dbc  remove + uint8 (mzouink, Oct 2, 2024)
6d7d5d8  fix validate (mzouink, Oct 2, 2024)
6c922da  :art: Format Python code with psf/black (mzouink, Oct 2, 2024)
3abcf96  Format Python code with psf/black push (#301) (mzouink, Oct 2, 2024)
eb03070  avoid adding auxiliary loss if not predicting affs since this results… (pattonw, Oct 16, 2024)
a1e492a  Add a minimal tutorial (pattonw, Oct 16, 2024)
0ef4f61  add the minimal tutorial and notebook parsing to the docs workflow (pattonw, Oct 16, 2024)
8de17a3  Basic tutorial (#303) (mzouink, Oct 16, 2024)
689a5a8  Merge branch 'basic_tutorial' into fix_local_predict (mzouink, Oct 16, 2024)
6c539e3  fix tests (mzouink, Oct 16, 2024)
27f531c  fix validate tests (mzouink, Oct 16, 2024)
60a2364  :art: Format Python code with psf/black (mzouink, Oct 16, 2024)
937371e  Format Python code with psf/black push (#307) (mzouink, Oct 18, 2024)
6 changes: 5 additions & 1 deletion .github/workflows/docs.yaml
@@ -17,7 +17,11 @@ jobs:
           fetch-depth: 0 # otherwise, you will failed to push refs to dest repo
       - name: install dacapo
         # run: pip install .[docs]
-        run: pip install sphinx-autodoc-typehints sphinx-autoapi sphinx-click sphinx-rtd-theme myst-parser
+        run: pip install sphinx-autodoc-typehints sphinx-autoapi sphinx-click sphinx-rtd-theme myst-parser jupytext ipykernel nbsphinx
+      - name: parse notebooks
+        run: jupytext --to notebook --execute ./docs/source/notebooks/*.py
+      - name: remove notebook scripts
+        run: rm ./docs/source/notebooks/*.py
       - name: Build and Commit
         uses: sphinx-notes/pages@v2
         with:
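For context, the new "parse notebooks" step relies on jupytext's percent format: a plain .py script whose `# %%` markers delimit notebook cells, which `jupytext --to notebook --execute` turns into an executed .ipynb for nbsphinx to render. A minimal sketch of such a script (the path and cell contents here are hypothetical, not part of this PR):

```python
# docs/source/notebooks/example_tutorial.py  (hypothetical file)
# %% [markdown]
# # Minimal tutorial
# Each `# %%` marker below becomes one notebook cell once
# `jupytext --to notebook --execute` runs in the docs workflow.

# %%
import math

# %%
# This cell executes during conversion; its output is baked into the .ipynb.
print(math.sqrt(2))
```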
7 changes: 6 additions & 1 deletion dacapo/blockwise/scheduler.py
@@ -74,7 +74,12 @@ def run_blockwise(
     )
     print("Running blockwise with worker_file: ", worker_file)
     print(f"Using compute context: {create_compute_context()}")
-    success = daisy.run_blockwise([task])
+    compute_context = create_compute_context()
+    print(f"Using compute context: {compute_context}")
+
+    multiprocessing = compute_context.distribute_workers
+
+    success = daisy.run_blockwise([task], multiprocessing=multiprocessing)
     return success
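The substantive change: daisy is no longer asked to fork local worker processes unconditionally; it now follows the compute context. A rough sketch of the resulting control flow (assuming, as the diff does, a daisy version whose `run_blockwise` accepts a `multiprocessing` flag; the import path for `create_compute_context` is an assumption):

```python
import daisy
from dacapo.compute_context import create_compute_context

def run_task(task) -> bool:
    compute_context = create_compute_context()
    # distribute_workers=True -> daisy spawns worker subprocesses;
    # False -> blocks are processed in the current process.
    return daisy.run_blockwise([task], multiprocessing=compute_context.distribute_workers)
```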


15 changes: 13 additions & 2 deletions dacapo/cli.py
@@ -76,8 +76,19 @@ def cli(log_level):
 @click.option(
     "-r", "--run-name", required=True, type=str, help="The NAME of the run to train."
 )
-def train(run_name):
-    dacapo.train(run_name)  # TODO: run with compute_context
+@click.option(
+    "--no-validation", is_flag=True, help="Disable validation after training."
+)
+def train(run_name, no_validation):
+    """
+    Train a model with the specified run name.
+
+    Args:
+        run_name (str): The name of the run to train.
+        no_validation (bool): Flag to disable validation after training.
+    """
+    do_validate = not no_validation
+    dacapo.train(run_name, do_validate=do_validate)


@cli.command()
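For reference, the new CLI flag maps onto the Python API like this ("my_run" is a placeholder run name):

```python
import dacapo

dacapo.train("my_run")                     # default: validate after training
dacapo.train("my_run", do_validate=False)  # equivalent to `dacapo train -r my_run --no-validation`
```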
38 changes: 35 additions & 3 deletions dacapo/experiments/architectures/cnnectome_unet.py
@@ -130,6 +130,7 @@ def __init__(self, architecture_config):
             activation after the upsample operation.
         - use_attention (optional): Whether or not to use an attention block
             in the U-Net.
+        - batch_norm (optional): Whether to use batch normalization.
     Raises:
         ValueError: If the input shape is not given.
     Examples:
@@ -170,6 +171,7 @@ def __init__(self, architecture_config):
             self.upsample_factors if self.upsample_factors is not None else []
         )
         self.use_attention = architecture_config.use_attention
+        self.batch_norm = architecture_config.batch_norm

         self.unet = self.module()

@@ -261,6 +263,7 @@ def module(self):
             upsample_channel_contraction=[False]
             + [True] * (len(downsample_factors) - 1),
             use_attention=self.use_attention,
+            batch_norm=self.batch_norm,
         )
         if len(self.upsample_factors) > 0:
             layers = [unet]
@@ -279,6 +282,7 @@ def module(self):
                 self.fmaps_out,
                 [(3,) * len(upsample_factor)] * 2,
                 activation="ReLU",
+                batch_norm=self.batch_norm,
             )
             layers.append(conv)
             unet = torch.nn.Sequential(*layers)
@@ -455,6 +459,7 @@ def __init__(
         upsample_channel_contraction=False,
         activation_on_upsample=False,
         use_attention=False,
+        batch_norm=True,
     ):
         """
         Create a U-Net::
@@ -573,6 +578,7 @@ def __init__(

         self.dims = len(downsample_factors[0])
         self.use_attention = use_attention
+        self.batch_norm = batch_norm

         # default arguments

@@ -611,6 +617,7 @@ def __init__(
                     kernel_size_down[level],
                     activation=activation,
                     padding=padding,
+                    batch_norm=self.batch_norm,
                 )
                 for level in range(self.num_levels)
             ]
@@ -668,6 +675,7 @@ def __init__(
                     ),
                     dims=self.dims,
                     upsample_factor=downsample_factors[level],
+                    batch_norm=self.batch_norm,
                 )
                 for level in range(self.num_levels - 1)
             ]
@@ -694,6 +702,7 @@ def __init__(
                     kernel_size_up[level],
                     activation=activation,
                     padding=padding,
+                    batch_norm=self.batch_norm,
                 )
                 for level in range(self.num_levels - 1)
             ]
@@ -827,7 +836,13 @@ class ConvPass(torch.nn.Module):
     """

     def __init__(
-        self, in_channels, out_channels, kernel_sizes, activation, padding="valid"
+        self,
+        in_channels,
+        out_channels,
+        kernel_sizes,
+        activation,
+        padding="valid",
+        batch_norm=True,
     ):
         """
         Convolutional pass module. This module performs a series of
@@ -869,6 +884,15 @@ def __init__(

         try:
             layers.append(conv(in_channels, out_channels, kernel_size, padding=pad))
+            if batch_norm:
+                layers.append(
+                    {
+                        2: torch.nn.BatchNorm2d,
+                        3: torch.nn.BatchNorm3d,
+                    }[
+                        self.dims
+                    ](out_channels)
+                )
         except KeyError:
             raise RuntimeError("%dD convolution not implemented" % self.dims)
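The dims-to-module dispatch above is the same dictionary-lookup idiom the class already uses to pick its convolution type. Isolated as a standalone sketch (the helper name is hypothetical):

```python
import torch

def make_batch_norm(dims: int, out_channels: int) -> torch.nn.Module:
    # 2D conv stacks get BatchNorm2d, 3D stacks get BatchNorm3d;
    # any other dimensionality raises KeyError, like the conv lookup.
    return {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}[dims](out_channels)

bn = make_batch_norm(3, out_channels=12)
print(bn)  # BatchNorm3d(12, eps=1e-05, momentum=0.1, ...)
```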

@@ -1283,7 +1307,7 @@ class AttentionBlockModule(nn.Module):
     The AttentionBlockModule is an instance of the ``torch.nn.Module`` class.
     """

-    def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None):
+    def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None, batch_norm=True):
         """
         Initialize the Attention Block Module.

@@ -1306,13 +1330,19 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None):
         super(AttentionBlockModule, self).__init__()
         self.dims = dims
         self.kernel_sizes = [(1,) * self.dims, (1,) * self.dims]
+        self.batch_norm = batch_norm
         if upsample_factor is not None:
             self.upsample_factor = upsample_factor
         else:
             self.upsample_factor = (2,) * self.dims

         self.W_g = ConvPass(
-            F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same"
+            F_g,
+            F_int,
+            kernel_sizes=self.kernel_sizes,
+            activation=None,
+            padding="same",
+            batch_norm=self.batch_norm,
         )

         self.W_x = nn.Sequential(
@@ -1322,6 +1352,7 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None):
                 kernel_sizes=self.kernel_sizes,
                 activation=None,
                 padding="same",
+                batch_norm=self.batch_norm,
             ),
             Downsample(upsample_factor),
         )
@@ -1332,6 +1363,7 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None):
             kernel_sizes=self.kernel_sizes,
             activation="Sigmoid",
             padding="same",
+            batch_norm=self.batch_norm,
         )

         up_mode = {2: "bilinear", 3: "trilinear"}[self.dims]
4 changes: 4 additions & 0 deletions dacapo/experiments/architectures/cnnectome_unet_config.py
@@ -127,3 +127,7 @@ class CNNectomeUNetConfig(ArchitectureConfig):
             "help_text": "Whether to use attention blocks in the UNet. This is supported for 2D and 3D."
         },
     )
+    batch_norm: bool = attr.ib(
+        default=True,
+        metadata={"help_text": "Whether to use batch normalization."},
+    )
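Downstream, the flag can be set per experiment through the architecture config. A hedged sketch of opting out; only `batch_norm` is added by this PR, and the remaining field names and values are illustrative assumptions about the existing config:

```python
from funlib.geometry import Coordinate
from dacapo.experiments.architectures import CNNectomeUNetConfig

architecture_config = CNNectomeUNetConfig(
    name="unet_no_bn",
    input_shape=Coordinate(216, 216, 216),
    fmaps_in=1,
    num_fmaps=12,
    fmaps_out=72,
    fmap_inc_factor=6,
    downsample_factors=[(2, 2, 2), (3, 3, 3), (3, 3, 3)],
    batch_norm=False,  # disable the BatchNorm layers introduced above
)
```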
2 changes: 1 addition & 1 deletion dacapo/experiments/datasplits/__init__.py
@@ -4,4 +4,4 @@
 from .dummy_datasplit_config import DummyDataSplitConfig
 from .train_validate_datasplit import TrainValidateDataSplit
 from .train_validate_datasplit_config import TrainValidateDataSplitConfig
-from .datasplit_generator import DataSplitGenerator
+from .datasplit_generator import DataSplitGenerator, DatasetSpec
@@ -171,7 +171,9 @@ def roi(self) -> Roi:
         This method returns the region of interest of the resampled array.

         """
-        return self._source_array.roi.snap_to_grid(self.voxel_size, mode="shrink")
+        return self._source_array.roi.snap_to_grid(
+            np.lcm(self._source_array.voxel_size, self.voxel_size), mode="shrink"
+        )

     @property
     def writable(self) -> bool:
@@ -281,7 +283,9 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
         Note:
             This method returns the data of the resampled array within the given region of interest.
         """
-        snapped_roi = roi.snap_to_grid(self._source_array.voxel_size, mode="grow")
+        snapped_roi = roi.snap_to_grid(
+            np.lcm(self._source_array.voxel_size, self.voxel_size), mode="grow"
+        )
         resampled_array = funlib.persistence.Array(
             rescale(
                 self._source_array[snapped_roi].astype(np.float32),
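The reason for snapping to the least common multiple rather than either voxel size alone: an ROI aligned to lcm(source, target) lands on whole voxels in both grids, so neither the read nor the write is fractionally offset. Illustrative numbers only:

```python
import numpy as np

source_voxel_size = np.array([8, 8, 8])
target_voxel_size = np.array([10, 10, 10])
grid = np.lcm(source_voxel_size, target_voxel_size)
print(grid)  # [40 40 40]: multiples of 40 are valid in both the 8nm and 10nm grids
```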
8 changes: 4 additions & 4 deletions dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
@@ -13,8 +13,6 @@

 from collections import OrderedDict
 import logging
-from upath import UPath as Path
-import json
 from typing import Dict, Tuple, Any, Optional, List

 logger = logging.getLogger(__name__)
@@ -273,7 +271,9 @@ def roi(self) -> Roi:
             This method is used to return the region of interest of the array.
         """
         if self.snap_to_grid is not None:
-            return self._daisy_array.roi.snap_to_grid(self.snap_to_grid, mode="shrink")
+            return self._daisy_array.roi.snap_to_grid(
+                np.lcm(self.voxel_size, self.snap_to_grid), mode="shrink"
+            )
         else:
             return self._daisy_array.roi

@@ -469,7 +469,7 @@ def create_from_array_identifier(
         write_size = Coordinate((axis_length,) * voxel_size.dims) * voxel_size
         write_size = Coordinate((min(a, b) for a, b in zip(write_size, roi.shape)))
         zarr_container = zarr.open(array_identifier.container, "a")
-        if num_channels is None or num_channels == 1:
+        if num_channels is None:
             axes = [axis for axis in axes if "c" not in axis]
             num_channels = None
         else:
4 changes: 2 additions & 2 deletions dacapo/experiments/datasplits/datasets/dataset.py
@@ -90,7 +90,7 @@ def __repr__(self) -> str:
         Notes:
             This method is used to return the official string representation of the dataset object.
         """
-        return f"Dataset({self.name})"
+        return f"ds_{self.name.replace('/', '_')}"

     def __str__(self) -> str:
         """
@@ -109,7 +109,7 @@ def __str__(self) -> str:
         Notes:
             This method is used to return the string representation of the dataset object.
         """
-        return f"Dataset({self.name})"
+        return f"ds_{self.name.replace('/', '_')}"

     def _neuroglancer_layers(self, prefix="", exclude_layers=None):
         """
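The new representation is identifier-safe: dataset names that contain zarr group separators no longer leak slashes into places like file names. A quick illustration (the name is made up):

```python
name = "crops/crop_1/raw"  # illustrative dataset name with zarr-style separators
print(f"ds_{name.replace('/', '_')}")  # ds_crops_crop_1_raw
```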
38 changes: 21 additions & 17 deletions dacapo/experiments/datasplits/datasplit_generator.py
@@ -85,6 +85,9 @@ def resize_if_needed(
             f"have different dimensions {zarr_array.dims}"
         )
     if any([u > 1 or d > 1 for u, d in zip(raw_upsample, raw_downsample)]):
+        print(
+            f"dataset {array_config} needs resampling to {target_resolution}, upsample: {raw_upsample}, downsample: {raw_downsample}"
+        )
         return ResampledArrayConfig(
             name=f"{extra_str}_{array_config.name}_{array_config.dataset}_resampled",
             source_array_config=array_config,
@@ -93,6 +96,7 @@
             interp_order=False,
         )
     else:
+        # print(f"dataset {array_config.dataset} does not need resampling found {raw_voxel_size}=={target_resolution}")
         return array_config


@@ -959,23 +963,23 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
                 constant=1,
             )

-        if len(target_images) > 1:
-            gt_config = ConcatArrayConfig(
-                name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_gt",
-                channels=[organelle for organelle in current_targets],
-                # source_array_configs={k: gt for k, gt in target_images.items()},
-                source_array_configs={k: target_images[k] for k in current_targets},
-            )
-            mask_config = ConcatArrayConfig(
-                name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_mask",
-                channels=[organelle for organelle in current_targets],
-                # source_array_configs={k: mask for k, mask in target_masks.items()},
-                # to be sure to have the same order
-                source_array_configs={k: target_masks[k] for k in current_targets},
-            )
-        else:
-            gt_config = list(target_images.values())[0]
-            mask_config = list(target_masks.values())[0]
+        # if len(target_images) > 1:
+        gt_config = ConcatArrayConfig(
+            name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_gt",
+            channels=[organelle for organelle in current_targets],
+            # source_array_configs={k: gt for k, gt in target_images.items()},
+            source_array_configs={k: target_images[k] for k in current_targets},
+        )
+        mask_config = ConcatArrayConfig(
+            name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_mask",
+            channels=[organelle for organelle in current_targets],
+            # source_array_configs={k: mask for k, mask in target_masks.items()},
+            # to be sure to have the same order
+            source_array_configs={k: target_masks[k] for k in current_targets},
+        )
+        # else:
+        #     gt_config = list(target_images.values())[0]
+        #     mask_config = list(target_masks.values())[0]

         return raw_config, gt_config, mask_config

@@ -127,8 +127,9 @@ def evaluate(self, output_array_identifier, evaluation_array):
         This function is used to evaluate the output array against the evaluation array.
         """
         output_array = ZarrArray.open_from_array_identifier(output_array_identifier)
-        evaluation_data = evaluation_array[evaluation_array.roi].squeeze()
-        output_data = output_array[output_array.roi].squeeze()
+        # removed the .squeeze() because it was used for batch size and now we are feeding 4d c, z, y, x
+        evaluation_data = evaluation_array[evaluation_array.roi]
+        output_data = output_array[output_array.roi]
         print(
             f"Evaluating binary segmentations on evaluation_data of shape: {evaluation_data.shape}"
         )
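The inline comment explains the change; the hazard it avoids is worth spelling out. On single-channel 4D (c, z, y, x) data, `squeeze()` removes the channel axis along with any singleton spatial axes. Illustrative shapes:

```python
import numpy as np

single_channel = np.zeros((1, 64, 64, 64))  # (c, z, y, x) with one channel
print(single_channel.squeeze().shape)  # (64, 64, 64): the channel axis is silently lost
print(single_channel.shape)            # (1, 64, 64, 64): preserved without squeeze()
```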
19 changes: 12 additions & 7 deletions dacapo/experiments/tasks/losses/affinities_loss.py
@@ -79,10 +79,15 @@ def compute(self, prediction, target, weight):
             weight[:, self.num_affinities :, ...],
         )

-        return (
-            torch.nn.BCEWithLogitsLoss(reduction="none")(affs, affs_target)
-            * affs_weight
-        ).mean() + self.lsds_to_affs_weight_ratio * (
-            torch.nn.MSELoss(reduction="none")(torch.nn.Sigmoid()(aux), aux_target)
-            * aux_weight
-        ).mean()
+        if aux.shape[1] == 0:
+            return torch.nn.BCEWithLogitsLoss(reduction="none")(
+                affs, affs_target
+            ).mean()
+        else:
+            return (
+                torch.nn.BCEWithLogitsLoss(reduction="none")(affs, affs_target)
+                * affs_weight
+            ).mean() + self.lsds_to_affs_weight_ratio * (
+                torch.nn.MSELoss(reduction="none")(torch.nn.Sigmoid()(aux), aux_target)
+                * aux_weight
+            ).mean()
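A standalone sketch of the new guard with dummy tensors (channel counts are illustrative): when the prediction carries no auxiliary LSD channels, `aux` is empty along dim 1 and the MSE term is skipped rather than evaluated on zero-channel tensors.

```python
import torch

num_affinities = 3
prediction = torch.randn(1, num_affinities, 8, 8, 8)  # no aux channels appended
target = torch.rand(1, num_affinities, 8, 8, 8)

affs, aux = prediction[:, :num_affinities], prediction[:, num_affinities:]
affs_target = target[:, :num_affinities]

if aux.shape[1] == 0:
    loss = torch.nn.BCEWithLogitsLoss(reduction="none")(affs, affs_target).mean()
else:
    loss = ...  # combined BCE + weighted MSE branch, as in the diff above
print(loss)
```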