Basic tutorial (#304)
@pattonw 
Adds a simple start-to-finish tutorial for DaCapo.
Still a few things left to resolve before it's ready for merging:

- [ ] Loss behaving strangely after the first validation (temporary
solution involves giving an extra singleton dimension to the provided
data)
- [ ] Validation fails (we want to show validation loss/scores,
validation outputs across iterations, and how to finally take the best
iteration and apply it, with post-processing, to a volume manually)
- [ ] stdout too verbose
- [ ] plotting functions should be built into `DaCapo` (there are
plotting functions, but I'm not sure how to include bokeh plots, so I
didn't dive too deep into this)
mzouink authored Oct 22, 2024
2 parents 1369fa3 + be53a19 commit f1fdef3
Showing 40 changed files with 1,076 additions and 392 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/docs.yaml
@@ -17,7 +17,11 @@ jobs:
fetch-depth: 0 # otherwise, you will failed to push refs to dest repo
- name: install dacapo
# run: pip install .[docs]
run: pip install sphinx-autodoc-typehints sphinx-autoapi sphinx-click sphinx-rtd-theme myst-parser
run: pip install sphinx-autodoc-typehints sphinx-autoapi sphinx-click sphinx-rtd-theme myst-parser jupytext ipykernel nbsphinx
- name: parse notebooks
run: jupytext --to notebook --execute ./docs/source/notebooks/*.py
- name: remove notebook scripts
run: rm ./docs/source/notebooks/*.py
- name: Build and Commit
uses: sphinx-notes/pages@v2
with:
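The new workflow steps convert jupytext-style Python scripts in `docs/source/notebooks/` into executed notebooks before the Sphinx build, then remove the scripts so only the generated `.ipynb` files remain. A minimal sketch of the percent-format script this step expects (the file name and contents below are illustrative, not part of this commit):

```python
# docs/source/notebooks/minimal_tutorial.py  -- hypothetical example script
# Convert and execute with: jupytext --to notebook --execute minimal_tutorial.py

# %% [markdown]
# # Minimal DaCapo tutorial
# This block becomes a markdown cell in the generated notebook.

# %%
# This block becomes an executed code cell in the generated notebook.
message = "hello from the docs build"
print(message)
```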
1 change: 1 addition & 0 deletions dacapo/__init__.py
@@ -8,3 +8,4 @@
from .validate import validate, validate_run # noqa
from .predict import predict # noqa
from .blockwise import run_blockwise, segment_blockwise # noqa
from . import predict_local
1 change: 0 additions & 1 deletion dacapo/blockwise/__init__.py
@@ -1,3 +1,2 @@
from .blockwise_task import DaCapoBlockwiseTask
from .scheduler import run_blockwise, segment_blockwise
from . import global_vars
1 change: 0 additions & 1 deletion dacapo/blockwise/global_vars.py

This file was deleted.

27 changes: 4 additions & 23 deletions dacapo/blockwise/predict_worker.py
@@ -17,7 +17,6 @@

import numpy as np
import click
from dacapo.blockwise import global_vars

import logging

@@ -28,20 +27,6 @@
path = __file__


def is_global_run_set(run_name) -> bool:
if global_vars.current_run is not None:
if global_vars.current_run.name == run_name:
return True
else:
logger.error(
f"Found global run {global_vars.current_run.name} but looking for {run_name}"
)
return False
else:
logger.error("No global run is set.")
return False


@click.group()
@click.option(
"--log-level",
@@ -131,14 +116,10 @@ def io_loop():
compute_context = create_compute_context()
device = compute_context.device

if is_global_run_set(run_name):
logger.warning("Using global run variable")
run = global_vars.current_run
else:
logger.warning("initiating local run in predict_worker")
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)
logger.warning("initiating local run in predict_worker")
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)

if iteration is not None and compute_context.distribute_workers:
# create weights store
7 changes: 6 additions & 1 deletion dacapo/blockwise/scheduler.py
@@ -74,7 +74,12 @@ def run_blockwise(
)
print("Running blockwise with worker_file: ", worker_file)
print(f"Using compute context: {create_compute_context()}")
success = daisy.run_blockwise([task])
compute_context = create_compute_context()
print(f"Using compute context: {compute_context}")

multiprocessing = compute_context.distribute_workers

success = daisy.run_blockwise([task], multiprocessing=multiprocessing)
return success


15 changes: 13 additions & 2 deletions dacapo/cli.py
@@ -76,8 +76,19 @@ def cli(log_level):
@click.option(
"-r", "--run-name", required=True, type=str, help="The NAME of the run to train."
)
def train(run_name):
dacapo.train(run_name) # TODO: run with compute_context
@click.option(
"--no-validation", is_flag=True, help="Disable validation after training."
)
def train(run_name, no_validation):
"""
Train a model with the specified run name.
Args:
run_name (str): The name of the run to train.
no_validation (bool): Flag to disable validation after training.
"""
do_validate = not no_validation
dacapo.train(run_name, do_validate=do_validate)


@cli.command()
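With the new flag, `dacapo train -r <run-name> --no-validation` skips validation after training. The equivalent call through the Python API, as wired up in this diff (the run name below is a placeholder, not part of this commit):

```python
import dacapo

# Sketch: the CLI's --no-validation flag maps to do_validate=False here.
# "my_run" is an illustrative run name that must already exist in the config store.
dacapo.train("my_run", do_validate=False)
```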
38 changes: 35 additions & 3 deletions dacapo/experiments/architectures/cnnectome_unet.py
@@ -130,6 +130,7 @@ def __init__(self, architecture_config):
activation after the upsample operation.
- use_attention (optional): Whether or not to use an attention block
in the U-Net.
- batch_norm (optional): Whether to use batch normalization.
Raises:
ValueError: If the input shape is not given.
Examples:
@@ -170,6 +171,7 @@ def __init__(self, architecture_config):
self.upsample_factors if self.upsample_factors is not None else []
)
self.use_attention = architecture_config.use_attention
self.batch_norm = architecture_config.batch_norm

self.unet = self.module()

@@ -261,6 +263,7 @@ def module(self):
upsample_channel_contraction=[False]
+ [True] * (len(downsample_factors) - 1),
use_attention=self.use_attention,
batch_norm=self.batch_norm,
)
if len(self.upsample_factors) > 0:
layers = [unet]
Expand All @@ -279,6 +282,7 @@ def module(self):
self.fmaps_out,
[(3,) * len(upsample_factor)] * 2,
activation="ReLU",
batch_norm=self.batch_norm,
)
layers.append(conv)
unet = torch.nn.Sequential(*layers)
@@ -455,6 +459,7 @@ def __init__(
upsample_channel_contraction=False,
activation_on_upsample=False,
use_attention=False,
batch_norm=True,
):
"""
Create a U-Net::
@@ -573,6 +578,7 @@ def __init__(

self.dims = len(downsample_factors[0])
self.use_attention = use_attention
self.batch_norm = batch_norm

# default arguments

@@ -611,6 +617,7 @@ def __init__(
kernel_size_down[level],
activation=activation,
padding=padding,
batch_norm=self.batch_norm,
)
for level in range(self.num_levels)
]
@@ -668,6 +675,7 @@ def __init__(
),
dims=self.dims,
upsample_factor=downsample_factors[level],
batch_norm=self.batch_norm,
)
for level in range(self.num_levels - 1)
]
@@ -694,6 +702,7 @@ def __init__(
kernel_size_up[level],
activation=activation,
padding=padding,
batch_norm=self.batch_norm,
)
for level in range(self.num_levels - 1)
]
@@ -827,7 +836,13 @@ class ConvPass(torch.nn.Module):
"""

def __init__(
self, in_channels, out_channels, kernel_sizes, activation, padding="valid"
self,
in_channels,
out_channels,
kernel_sizes,
activation,
padding="valid",
batch_norm=True,
):
"""
Convolutional pass module. This module performs a series of
@@ -869,6 +884,15 @@ def __init__(

try:
layers.append(conv(in_channels, out_channels, kernel_size, padding=pad))
if batch_norm:
layers.append(
{
2: torch.nn.BatchNorm2d,
3: torch.nn.BatchNorm3d,
}[
self.dims
](out_channels)
)
except KeyError:
raise RuntimeError("%dD convolution not implemented" % self.dims)

@@ -1283,7 +1307,7 @@ class AttentionBlockModule(nn.Module):
The AttentionBlockModule is an instance of the ``torch.nn.Module`` class.
"""

def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None):
def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None, batch_norm=True):
"""
Initialize the Attention Block Module.
@@ -1306,13 +1330,19 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None):
super(AttentionBlockModule, self).__init__()
self.dims = dims
self.kernel_sizes = [(1,) * self.dims, (1,) * self.dims]
self.batch_norm = batch_norm
if upsample_factor is not None:
self.upsample_factor = upsample_factor
else:
self.upsample_factor = (2,) * self.dims

self.W_g = ConvPass(
F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same"
F_g,
F_int,
kernel_sizes=self.kernel_sizes,
activation=None,
padding="same",
batch_norm=self.batch_norm,
)

self.W_x = nn.Sequential(
@@ -1322,6 +1352,7 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None):
kernel_sizes=self.kernel_sizes,
activation=None,
padding="same",
batch_norm=self.batch_norm,
),
Downsample(upsample_factor),
)
@@ -1332,6 +1363,7 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None):
kernel_sizes=self.kernel_sizes,
activation="Sigmoid",
padding="same",
batch_norm=self.batch_norm,
)

up_mode = {2: "bilinear", 3: "trilinear"}[self.dims]
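The `batch_norm` flag threads through `ConvPass`, so each convolution is now optionally followed by a `BatchNorm2d`/`BatchNorm3d` layer. A small sketch of using `ConvPass` on its own, assuming the constructor shown in this diff (the channel counts and shapes are illustrative):

```python
import torch
from dacapo.experiments.architectures.cnnectome_unet import ConvPass  # module path as in this diff

# Two valid 3x3x3 convolutions with ReLU; batch_norm=True inserts a
# BatchNorm3d after each convolution (the new default behaviour).
conv = ConvPass(
    in_channels=1,
    out_channels=12,
    kernel_sizes=[(3, 3, 3), (3, 3, 3)],
    activation="ReLU",
    padding="valid",
    batch_norm=True,
)

x = torch.randn(1, 1, 32, 32, 32)
print(conv(x).shape)  # two valid 3x3x3 convs shrink each spatial dim by 4
```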
4 changes: 4 additions & 0 deletions dacapo/experiments/architectures/cnnectome_unet_config.py
@@ -127,3 +127,7 @@ class CNNectomeUNetConfig(ArchitectureConfig):
"help_text": "Whether to use attention blocks in the UNet. This is supported for 2D and 3D."
},
)
batch_norm: bool = attr.ib(
default=True,
metadata={"help_text": "Whether to use batch normalization."},
)
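A hedged example of turning the new option off in an architecture config; the import path and the other field values below are assumptions for illustration, not taken from this commit:

```python
from funlib.geometry import Coordinate
from dacapo.experiments.architectures import CNNectomeUNetConfig  # import path assumed

# Illustrative 3D UNet config with batch normalization disabled.
architecture_config = CNNectomeUNetConfig(
    name="unet_no_batch_norm",
    input_shape=Coordinate(216, 216, 216),
    fmaps_in=1,
    num_fmaps=12,
    fmaps_out=72,
    fmap_inc_factor=6,
    downsample_factors=[(2, 2, 2), (3, 3, 3), (3, 3, 3)],
    batch_norm=False,  # field added in this commit; defaults to True
)
```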
2 changes: 1 addition & 1 deletion dacapo/experiments/datasplits/__init__.py
@@ -4,4 +4,4 @@
from .dummy_datasplit_config import DummyDataSplitConfig
from .train_validate_datasplit import TrainValidateDataSplit
from .train_validate_datasplit_config import TrainValidateDataSplitConfig
from .datasplit_generator import DataSplitGenerator
from .datasplit_generator import DataSplitGenerator, DatasetSpec
@@ -455,7 +455,7 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
axis=0,
)
if concatenated.shape[0] == 1:
print(
logger.info(
f"Concatenated array has only one channel: {self.name} {concatenated.shape}"
)
return concatenated
@@ -171,7 +171,9 @@ def roi(self) -> Roi:
This method returns the region of interest of the resampled array.
"""
return self._source_array.roi.snap_to_grid(self.voxel_size, mode="shrink")
return self._source_array.roi.snap_to_grid(
np.lcm(self._source_array.voxel_size, self.voxel_size), mode="shrink"
)

@property
def writable(self) -> bool:
@@ -281,7 +283,9 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
Note:
This method returns the data of the resampled array within the given region of interest.
"""
snapped_roi = roi.snap_to_grid(self._source_array.voxel_size, mode="grow")
snapped_roi = roi.snap_to_grid(
np.lcm(self._source_array.voxel_size, self.voxel_size), mode="grow"
)
resampled_array = funlib.persistence.Array(
rescale(
self._source_array[snapped_roi].astype(np.float32),
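Snapping to `np.lcm(source_voxel_size, target_voxel_size)` instead of a single voxel size keeps the ROI aligned with both the source grid and the resampled grid at once. A tiny check of the idea with illustrative voxel sizes:

```python
import numpy as np

# A ROI boundary aligned to 12 nm is aligned to both a 4 nm source grid
# and a 6 nm target grid; snapping to either voxel size alone is not enough.
source_voxel_size = np.array([4, 4, 4])
target_voxel_size = np.array([6, 6, 6])
print(np.lcm(source_voxel_size, target_voxel_size))  # -> [12 12 12]
```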
8 changes: 4 additions & 4 deletions dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
@@ -13,8 +13,6 @@

from collections import OrderedDict
import logging
from upath import UPath as Path
import json
from typing import Dict, Tuple, Any, Optional, List

logger = logging.getLogger(__name__)
@@ -273,7 +271,9 @@ def roi(self) -> Roi:
This method is used to return the region of interest of the array.
"""
if self.snap_to_grid is not None:
return self._daisy_array.roi.snap_to_grid(self.snap_to_grid, mode="shrink")
return self._daisy_array.roi.snap_to_grid(
np.lcm(self.voxel_size, self.snap_to_grid), mode="shrink"
)
else:
return self._daisy_array.roi

@@ -469,7 +469,7 @@ def create_from_array_identifier(
write_size = Coordinate((axis_length,) * voxel_size.dims) * voxel_size
write_size = Coordinate((min(a, b) for a, b in zip(write_size, roi.shape)))
zarr_container = zarr.open(array_identifier.container, "a")
if num_channels is None or num_channels == 1:
if num_channels is None:
axes = [axis for axis in axes if "c" not in axis]
num_channels = None
else:
4 changes: 2 additions & 2 deletions dacapo/experiments/datasplits/datasets/dataset.py
@@ -90,7 +90,7 @@ def __repr__(self) -> str:
Notes:
This method is used to return the official string representation of the dataset object.
"""
return f"Dataset({self.name})"
return f"ds_{self.name.replace('/', '_')}"

def __str__(self) -> str:
"""
@@ -109,7 +109,7 @@ def __str__(self) -> str:
Notes:
This method is used to return the string representation of the dataset object.
"""
return f"Dataset({self.name})"
return f"ds_{self.name.replace('/', '_')}"

def _neuroglancer_layers(self, prefix="", exclude_layers=None):
"""