Skip to content

Commit

Permalink
Merge branch 'tests_v0_3_5' of https://github.com/janelia-cellmap/dacapo
Browse files Browse the repository at this point in the history
 into tests_v0_3_5
  • Loading branch information
mzouink committed Nov 13, 2024
2 parents 4192606 + 9f7b0e4 commit c81f4de
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

logger = logging.getLogger(__name__)


class ThresholdPostProcessor(PostProcessor):
"""
A post-processor that applies a threshold to the prediction.
Expand Down
6 changes: 2 additions & 4 deletions dacapo/experiments/tasks/predictors/distance_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,8 @@ def process(
sampling = tuple(float(v) / 2 for v in voxel_size)
# fixing the sampling for 2D images
if len(boundaries.shape) < len(sampling):
sampling = sampling[-len(boundaries.shape):]
distances = distance_transform_edt(
boundaries, sampling=sampling
)
sampling = sampling[-len(boundaries.shape) :]
distances = distance_transform_edt(boundaries, sampling=sampling)
distances = distances.astype(np.float32)

# restore original shape
Expand Down
4 changes: 1 addition & 3 deletions dacapo/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def validate_run(run: Run, iteration: int, datasets_config=None):
# criterion,
# )
dataset_iteration_scores.append(
[getattr(scores, criterion) for criterion in scores.criteria]
[getattr(scores, criterion) for criterion in scores.criteria]
)
except:
logger.error(
Expand All @@ -260,8 +260,6 @@ def validate_run(run: Run, iteration: int, datasets_config=None):
# the evaluator
# array_store.remove(output_array_identifier)



iteration_scores.append(dataset_iteration_scores)
# array_store.remove(prediction_array_identifier)

Expand Down
3 changes: 2 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
docker
aws
cosem_starter
roadmap
autoapi/index
cli

.. include:: ../../README.md
:parser: myst_parser.sphinx_
:parser: myst_parser.sphinx_
77 changes: 77 additions & 0 deletions docs/source/roadmap.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
.. _sec_roadmap:

Road Map
========

Overview
--------

+-----------------------------------+------------------+-------------------------------+
| Task | Priority | Current State |
+===================================+==================+===============================+
| Write Documentation | High | Started with a long way to go |
+-----------------------------------+------------------+-------------------------------+
| Simplify configurations | High | First draft complete |
+-----------------------------------+------------------+-------------------------------+
| Develop Data Conventions | High | First draft complete |
+-----------------------------------+------------------+-------------------------------+
| Improve Blockwise Post-Processing | Low | Not Started |
+-----------------------------------+------------------+-------------------------------+
| Simplify Array handling | High | Almost done (Up/Down sampling)|
+-----------------------------------+------------------+-------------------------------+

Detailed Road Map
-----------------

- [ ] Write Documentation
- [ ] tutorials: not more than three, simple and continuously tested (with Github actions, small U-Net on CPU could work)
- [x] Basic tutorial: train a U-Net on a toy dataset
- [ ] Parametrize the basic tutorial across tasks (instance/semantic segmentation).
- [ ] Improve visualizations. Move some simple plotting functions to DaCapo.
- [ ] Add a pure pytorch implementation to show benefits side-by-side
- [ ] Track performance metrics (e.g., loss, accuracy, etc.) so we can make sure we aren't regressing
- [ ] semantic segmentation (LM and EM)
- [ ] instance segmentation (LM or EM, can be simulated)
- [ ] general documentation of CLI, also API for developers (curate docstrings)
- [x] Simplify configurations
- [x] Deprecate old configs
- [x] Add simplified config for simple cases
- [x] can still get rid of `*Config` classes
- [x] Develop Data Conventions
- [x] document conventions
- [ ] convenience scripts to convert dataset into our convention (even starting from directories of PNG files)
- [ ] Improve Blockwise Post-Processing
- [ ] De-duplicate code between “in-memory” and “block-wise” processing
- [ ] have only block-wise algorithms, use those also for “in-memory”
- [ ] no more “in-memory”, this is just a run with a different Compute Context
- [ ] Incorporate `volara` into DaCapo (embargo until January)
- [ ] Improve debugging support (logging of chain of commands for reproducible runs)
- [ ] Split long post-processing steps into several smaller ones for composability (e.g., support running each step independently if we want to support choosing between `waterz` and `mutex_watershed` for fragment generation or agglomeration)
- [x] Incorporate `funlib.persistence` adaptors.
- [x] all of those can be adapters:
- [x] Binarize Labels into Mask
- [x] Scale/Shift intensities
- [ ] Up/Down sample (if easily possible)
- [ ] DVID source
- [x] Datatype conversions
- [x] everything else
- [x] simplify array configs accordingly

Can Have
--------

- [ ] Support other stats stores. Too much time, effort and code was put into the stats and didn’t provide a very nice interface:
- [ ] defining variables to store
- [ ] efficiently batch writing, storing and reading stats to both files and mongodb
- [ ] visualizing stats.
- [ ] Jeff and Marwan suggest MLFlow instead of WandB
- [ ] Support for slurm clusters
- [ ] Support for cloud computing (AWS)
- [ ] Lazy loading of dependencies (import takes too long)
- [ ] Support bioimage model spec for model dissemination

Non-Goals (for v1.0)
--------------------

- custom dash board
- GUI to run experiments
53 changes: 24 additions & 29 deletions tests/operations/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,35 @@

logging.basicConfig(level=logging.INFO)

from dacapo.experiments.architectures import DummyArchitectureConfig, CNNectomeUNetConfig
from dacapo.experiments.architectures import (
DummyArchitectureConfig,
CNNectomeUNetConfig,
)

import pytest


def unet_architecture(batch_norm, upsample,use_attention, three_d):
def unet_architecture(batch_norm, upsample, use_attention, three_d):
name = "3d_unet" if three_d else "2d_unet"
name = f"{name}_bn" if batch_norm else name
name = f"{name}_up" if upsample else name
name = f"{name}_att" if use_attention else name

if three_d:
return CNNectomeUNetConfig(
name=name,
input_shape=(188, 188, 188),
eval_shape_increase=(72, 72, 72),
fmaps_in=1,
num_fmaps=6,
fmaps_out=6,
fmap_inc_factor=2,
downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)],
constant_upsample=True,
upsample_factors=[(2, 2, 2)] if upsample else [],
batch_norm=batch_norm,
use_attention=use_attention,
)
return CNNectomeUNetConfig(
name=name,
input_shape=(188, 188, 188),
eval_shape_increase=(72, 72, 72),
fmaps_in=1,
num_fmaps=6,
fmaps_out=6,
fmap_inc_factor=2,
downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)],
constant_upsample=True,
upsample_factors=[(2, 2, 2)] if upsample else [],
batch_norm=batch_norm,
use_attention=use_attention,
)
else:
return CNNectomeUNetConfig(
name=name,
Expand All @@ -61,7 +64,6 @@ def unet_architecture(batch_norm, upsample,use_attention, three_d):
)



# skip the test for the Apple Paravirtual device
# that does not support Metal 2.0
@pytest.mark.filterwarnings("ignore:.*Metal 2.0.*:UserWarning")
Expand Down Expand Up @@ -117,19 +119,15 @@ def test_train(
@pytest.mark.parametrize("use_attention", [True, False])
@pytest.mark.parametrize("three_d", [True, False])
def test_train_unet(
datasplit,
task,
trainer,
batch_norm,
upsample,
use_attention,
three_d):

datasplit, task, trainer, batch_norm, upsample, use_attention, three_d
):
store = create_config_store()
stats_store = create_stats_store()
weights_store = create_weights_store()

architecture_config = unet_architecture(batch_norm, upsample,use_attention, three_d)
architecture_config = unet_architecture(
batch_norm, upsample, use_attention, three_d
)

run_config = RunConfig(
name=f"{architecture_config.name}_run",
Expand Down Expand Up @@ -167,6 +165,3 @@ def test_train_unet(
training_stats = stats_store.retrieve_training_stats(run_config.name)

assert training_stats.trained_until() == run_config.num_iterations



0 comments on commit c81f4de

Please sign in to comment.