From 8c29de8e942f1224c2a83eb5831bc98cedb42878 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Thu, 9 May 2024 11:55:34 -0400 Subject: [PATCH 1/4] cellpose_unet_model --- dacapo/experiments/architectures/__init__.py | 1 + .../architectures/cellpose_unet.py | 76 +++++++++++++++++++ .../architectures/cellpose_unet_config.py | 41 ++++++++++ tests/fixtures/__init__.py | 4 +- tests/fixtures/architectures.py | 10 ++- tests/fixtures/runs.py | 17 +++++ 6 files changed, 146 insertions(+), 3 deletions(-) create mode 100644 dacapo/experiments/architectures/cellpose_unet.py create mode 100644 dacapo/experiments/architectures/cellpose_unet_config.py diff --git a/dacapo/experiments/architectures/__init__.py b/dacapo/experiments/architectures/__init__.py index 6125893c1..f21fe05f1 100644 --- a/dacapo/experiments/architectures/__init__.py +++ b/dacapo/experiments/architectures/__init__.py @@ -5,3 +5,4 @@ DummyArchitecture, ) # noqa from .cnnectome_unet_config import CNNectomeUNetConfig, CNNectomeUNet # noqa +from .cellpose_unet_config import CellposUNetConfig, CellposeUnet # noqa diff --git a/dacapo/experiments/architectures/cellpose_unet.py b/dacapo/experiments/architectures/cellpose_unet.py new file mode 100644 index 000000000..c1645f1f4 --- /dev/null +++ b/dacapo/experiments/architectures/cellpose_unet.py @@ -0,0 +1,76 @@ +from cellpose.resnet_torch import CPnet +from .architecture import Architecture +from funlib.geometry import Coordinate + + # example + # nout = 4 + # sz = 3 + # self.net = CPnet( + # nbase, nout, sz, mkldnn=False, conv_3D=True, max_pool=True, diam_mean=30.0 + # ) +# currently the input channels are embedded in nbdase, but they should be passed as a separate parameternbase = [in_chan, 32, 64, 128, 256] +class CellposeUnet(Architecture): + def __init__(self, architecture_config): + super().__init__() + self._input_shape = Coordinate(architecture_config.input_shape) + self._nbase = architecture_config.nbase + self._sz = self._input_shape.dims + self._eval_shape_increase = Coordinate((0,) * self._sz) + self._nout = architecture_config.nout + print("conv_3D:",architecture_config.conv_3D) + self.unet = CPnet( + architecture_config.nbase, + architecture_config.nout, + self._sz, + architecture_config.mkldnn, + architecture_config.conv_3D, + architecture_config.max_pool, + architecture_config.diam_mean, + ) + print(self.unet) + + def forward(self, data): + """ + Forward pass of the CPnet model. + + Args: + data (torch.Tensor): Input data. + + Returns: + tuple: A tuple containing the output tensor, style tensor, and downsampled tensors. + """ + if self.unet.mkldnn: + data = data.to_mkldnn() + T0 = self.unet.downsample(data) + if self.unet.mkldnn: + style = self.unet.make_style(T0[-1].to_dense()) + else: + style = self.unet.make_style(T0[-1]) + # style0 = style + if not self.unet.style_on: + style = style * 0 + T1 = self.unet.upsample(style, T0, self.unet.mkldnn) + # head layer + # T1 = self.unet.output(T1) + if self.unet.mkldnn: + T0 = [t0.to_dense() for t0 in T0] + T1 = T1.to_dense() + return T1 + + @property + def input_shape(self): + return self._input_shape + + @property + def num_in_channels(self) -> int: + return self._nbase[0] + + @property + def num_out_channels(self) -> int: + return self._nout + + @property + def eval_shape_increase(self): + return self._eval_shape_increase + + diff --git a/dacapo/experiments/architectures/cellpose_unet_config.py b/dacapo/experiments/architectures/cellpose_unet_config.py new file mode 100644 index 000000000..3d3338042 --- /dev/null +++ b/dacapo/experiments/architectures/cellpose_unet_config.py @@ -0,0 +1,41 @@ +import attr + +from .architecture_config import ArchitectureConfig +from .cellpose_unet import CellposeUnet + +from funlib.geometry import Coordinate + +from typing import List, Optional + + +@attr.s +class CellposUNetConfig(ArchitectureConfig): + """This class configures the CellPose based on + https://github.com/MouseLand/cellpose/blob/main/cellpose/resnet_torch.py + """ + + architecture_type = CellposeUnet + + input_shape: Coordinate = attr.ib( + metadata={ + "help_text": "The shape of the data passed into the network during training." + } + ) + nbase: List[int] = attr.ib( + metadata={ + "help_text": "List of integers representing the number of channels in each layer of the downsample path." + } + ) + nout: int = attr.ib(metadata={"help_text": "Number of output channels."}) + mkldnn: Optional[bool] = attr.ib( + default=False, metadata={"help_text": "Whether to use MKL-DNN acceleration."} + ) + conv_3D: bool = attr.ib( + default=False, metadata={"help_text": "Whether to use 3D convolution."} + ) + max_pool: Optional[bool] = attr.ib( + default=True, metadata={"help_text": "Whether to use max pooling."} + ) + diam_mean: Optional[float] = attr.ib( + default=30.0, metadata={"help_text": "Mean diameter of the cells."} + ) \ No newline at end of file diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index 3ea282acc..ca5ef6365 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -1,11 +1,11 @@ from .db import options -from .architectures import dummy_architecture +from .architectures import dummy_architecture, cellpose_architecture from .arrays import dummy_array, zarr_array, cellmap_array from .datasplits import dummy_datasplit, twelve_class_datasplit, six_class_datasplit from .evaluators import binary_3_channel_evaluator from .losses import dummy_loss from .post_processors import argmax, threshold from .predictors import distance_predictor, onehot_predictor -from .runs import dummy_run, distance_run, onehot_run +from .runs import dummy_run, distance_run, onehot_run, cellpose_run from .tasks import dummy_task, distance_task, onehot_task from .trainers import dummy_trainer, gunpowder_trainer diff --git a/tests/fixtures/architectures.py b/tests/fixtures/architectures.py index 6980c8f6b..89c66cc7a 100644 --- a/tests/fixtures/architectures.py +++ b/tests/fixtures/architectures.py @@ -1,4 +1,4 @@ -from dacapo.experiments.architectures import DummyArchitectureConfig +from dacapo.experiments.architectures import DummyArchitectureConfig, CellposUNetConfig import pytest @@ -8,3 +8,11 @@ def dummy_architecture(): yield DummyArchitectureConfig( name="dummy_architecture", num_in_channels=1, num_out_channels=12 ) + +@pytest.fixture +def cellpose_architecture(): + yield CellposUNetConfig( + name="cellpose_architecture", input_shape=(216, 216, 216), + nbase=[1, 12, 24, 48, 96], nout = 12, conv_3D = True + # nbase=[1, 32, 64, 128, 256], nout = 32, conv_3D = True + ) \ No newline at end of file diff --git a/tests/fixtures/runs.py b/tests/fixtures/runs.py index 99c4d3269..a732b2ead 100644 --- a/tests/fixtures/runs.py +++ b/tests/fixtures/runs.py @@ -55,3 +55,20 @@ def onehot_run( repetition=0, num_iterations=100, ) + +@pytest.fixture() +def cellpose_run( + dummy_datasplit, + cellpose_architecture, + dummy_task, + dummy_trainer, +): + yield RunConfig( + name="cellpose_run", + task_config=dummy_task, + architecture_config=cellpose_architecture, + trainer_config=dummy_trainer, + datasplit_config=dummy_datasplit, + repetition=0, + num_iterations=100, + ) \ No newline at end of file From 92c7a44f1ea6a360a96d8285ce265f7987ff0ac2 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Thu, 9 May 2024 11:56:39 -0400 Subject: [PATCH 2/4] cellpose_unet --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 0ab64cdff..e15fefa87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ dependencies = [ "scipy", "upath", "boto3", + "cellpose", ] # extras From ec221b8641c544ac556c49840b1eb2bd82c593c5 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Thu, 9 May 2024 11:59:46 -0400 Subject: [PATCH 3/4] pytest --- tests/operations/test_cellpose.py | 61 +++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 tests/operations/test_cellpose.py diff --git a/tests/operations/test_cellpose.py b/tests/operations/test_cellpose.py new file mode 100644 index 000000000..2ac98fd07 --- /dev/null +++ b/tests/operations/test_cellpose.py @@ -0,0 +1,61 @@ +import numpy as np +from dacapo.store.create_store import create_stats_store +from ..fixtures import * + +from dacapo.experiments import Run +from dacapo.store.create_store import create_config_store, create_weights_store +from dacapo.train import train_run +from pytest_lazy_fixtures import lf +import pytest + +import logging + +logging.basicConfig(level=logging.INFO) + + +# skip the test for the Apple Paravirtual device +# that does not support Metal 2.0 +@pytest.mark.filterwarnings("ignore:.*Metal 2.0.*:UserWarning") +@pytest.mark.parametrize( + "run_config", + [ + lf("cellpose_run"), + ], +) +def test_train( + run_config, +): + print("Test train") + # create a store + + store = create_config_store() + stats_store = create_stats_store() + weights_store = create_weights_store() + + # store the configs + + store.store_run_config(run_config) + run = Run(run_config) + print("Run created ") + print(run.model) + + # # ------------------------------------- + + # # train + + # weights_store.store_weights(run, 0) + # print("Weights stored") + # train_run(run) + + # init_weights = weights_store.retrieve_weights(run.name, 0) + # final_weights = weights_store.retrieve_weights(run.name, run.train_until) + + # for name, weight in init_weights.model.items(): + # weight_diff = (weight - final_weights.model[name]).sum() + # assert abs(weight_diff) > np.finfo(weight_diff.numpy().dtype).eps, weight_diff + + # # assert train_stats and validation_scores are available + + # training_stats = stats_store.retrieve_training_stats(run_config.name) + + # assert training_stats.trained_until() == run_config.num_iterations From c698dc5da3fa5e49f9b64bb469e059415ac4ef4a Mon Sep 17 00:00:00 2001 From: mzouink Date: Thu, 9 May 2024 16:00:49 +0000 Subject: [PATCH 4/4] :art: Format Python code with psf/black --- .../architectures/cellpose_unet.py | 23 +++++++++---------- .../architectures/cellpose_unet_config.py | 2 +- tests/fixtures/architectures.py | 10 +++++--- tests/fixtures/runs.py | 3 ++- tests/operations/test_cellpose.py | 2 +- 5 files changed, 22 insertions(+), 18 deletions(-) diff --git a/dacapo/experiments/architectures/cellpose_unet.py b/dacapo/experiments/architectures/cellpose_unet.py index c1645f1f4..cbdf707ad 100644 --- a/dacapo/experiments/architectures/cellpose_unet.py +++ b/dacapo/experiments/architectures/cellpose_unet.py @@ -2,12 +2,13 @@ from .architecture import Architecture from funlib.geometry import Coordinate - # example - # nout = 4 - # sz = 3 - # self.net = CPnet( - # nbase, nout, sz, mkldnn=False, conv_3D=True, max_pool=True, diam_mean=30.0 - # ) + +# example +# nout = 4 +# sz = 3 +# self.net = CPnet( +# nbase, nout, sz, mkldnn=False, conv_3D=True, max_pool=True, diam_mean=30.0 +# ) # currently the input channels are embedded in nbdase, but they should be passed as a separate parameternbase = [in_chan, 32, 64, 128, 256] class CellposeUnet(Architecture): def __init__(self, architecture_config): @@ -17,7 +18,7 @@ def __init__(self, architecture_config): self._sz = self._input_shape.dims self._eval_shape_increase = Coordinate((0,) * self._sz) self._nout = architecture_config.nout - print("conv_3D:",architecture_config.conv_3D) + print("conv_3D:", architecture_config.conv_3D) self.unet = CPnet( architecture_config.nbase, architecture_config.nout, @@ -50,13 +51,13 @@ def forward(self, data): if not self.unet.style_on: style = style * 0 T1 = self.unet.upsample(style, T0, self.unet.mkldnn) - # head layer + # head layer # T1 = self.unet.output(T1) if self.unet.mkldnn: T0 = [t0.to_dense() for t0 in T0] T1 = T1.to_dense() return T1 - + @property def input_shape(self): return self._input_shape @@ -68,9 +69,7 @@ def num_in_channels(self) -> int: @property def num_out_channels(self) -> int: return self._nout - + @property def eval_shape_increase(self): return self._eval_shape_increase - - diff --git a/dacapo/experiments/architectures/cellpose_unet_config.py b/dacapo/experiments/architectures/cellpose_unet_config.py index 3d3338042..63d71c83d 100644 --- a/dacapo/experiments/architectures/cellpose_unet_config.py +++ b/dacapo/experiments/architectures/cellpose_unet_config.py @@ -38,4 +38,4 @@ class CellposUNetConfig(ArchitectureConfig): ) diam_mean: Optional[float] = attr.ib( default=30.0, metadata={"help_text": "Mean diameter of the cells."} - ) \ No newline at end of file + ) diff --git a/tests/fixtures/architectures.py b/tests/fixtures/architectures.py index 89c66cc7a..0c67ae15d 100644 --- a/tests/fixtures/architectures.py +++ b/tests/fixtures/architectures.py @@ -9,10 +9,14 @@ def dummy_architecture(): name="dummy_architecture", num_in_channels=1, num_out_channels=12 ) + @pytest.fixture def cellpose_architecture(): yield CellposUNetConfig( - name="cellpose_architecture", input_shape=(216, 216, 216), - nbase=[1, 12, 24, 48, 96], nout = 12, conv_3D = True + name="cellpose_architecture", + input_shape=(216, 216, 216), + nbase=[1, 12, 24, 48, 96], + nout=12, + conv_3D=True # nbase=[1, 32, 64, 128, 256], nout = 32, conv_3D = True - ) \ No newline at end of file + ) diff --git a/tests/fixtures/runs.py b/tests/fixtures/runs.py index a732b2ead..16b00d746 100644 --- a/tests/fixtures/runs.py +++ b/tests/fixtures/runs.py @@ -56,6 +56,7 @@ def onehot_run( num_iterations=100, ) + @pytest.fixture() def cellpose_run( dummy_datasplit, @@ -71,4 +72,4 @@ def cellpose_run( datasplit_config=dummy_datasplit, repetition=0, num_iterations=100, - ) \ No newline at end of file + ) diff --git a/tests/operations/test_cellpose.py b/tests/operations/test_cellpose.py index 2ac98fd07..e55cc3321 100644 --- a/tests/operations/test_cellpose.py +++ b/tests/operations/test_cellpose.py @@ -23,7 +23,7 @@ ], ) def test_train( - run_config, + run_config, ): print("Test train") # create a store