diff --git a/.github/workflows/black.yaml b/.github/workflows/black.yaml index a9ebfdec7..ad03af004 100644 --- a/.github/workflows/black.yaml +++ b/.github/workflows/black.yaml @@ -1,27 +1,17 @@ -name: black-action +name: Python Black on: [push, pull_request] jobs: - linter_name: - name: runner / black + lint: + name: Python Lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: Check files using the black formatter - uses: rickstaa/action-black@v1 - id: action_black - with: - black_args: "." - - name: Create Pull Request - if: steps.action_black.outputs.is_formatted == 'true' - uses: peter-evans/create-pull-request@v3 - with: - token: ${{ secrets.GITHUB_TOKEN }} - title: "Format Python code with psf/black push" - commit-message: ":art: Format Python code with psf/black" - body: | - There appear to be some python formatting errors in ${{ github.sha }}. This pull request - uses the [psf/black](https://github.com/psf/black) formatter to fix these issues. - base: ${{ github.head_ref }} # Creates pull request onto pull request or commit branch - branch: actions/black \ No newline at end of file + - name: Setup Python + uses: actions/setup-python@v1 + - name: Setup checkout + uses: actions/checkout@master + - name: Lint with Black + run: | + pip install black + black -v --check dacapo tests diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index d8d7b388d..5a84cc86b 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -1,7 +1,8 @@ -name: Generate Pages - -on: [push, pull_request] - +name: Pages +on: + push: + branches: + - master jobs: docs: runs-on: ubuntu-latest diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml new file mode 100644 index 000000000..58d200cff --- /dev/null +++ b/.github/workflows/publish.yaml @@ -0,0 +1,34 @@ +name: Publish + +on: + push: + tags: "*" + +jobs: + build-n-publish: + name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@master + - name: Set up Python 3.10 + uses: actions/setup-python@v3 + with: + python-version: "3.10" + - name: Install pypa/build + run: >- + python -m + pip install + build + --user + - name: Build a binary wheel and a source tarball + run: >- + python -m + build + --sdist + --wheel + --outdir dist/ + - name: Publish distribution 📦 to PyPI + if: startsWith(github.ref, 'refs/tags') + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.PYPI_API_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 2ecaf3f05..020ca3074 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -1,6 +1,7 @@ name: Test -on: [push, pull_request] +on: + push: jobs: test: @@ -8,7 +9,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10"] + python-version: ["3.9", "3.10"] steps: - uses: actions/checkout@v2 @@ -22,4 +23,4 @@ jobs: pip install -r requirements-dev.txt - name: Test with pytest run: | - pytest tests + pytest tests \ No newline at end of file diff --git a/README.md b/README.md index 64d35064a..a51d4f996 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# DaCapo ![DaCapo](docs/source/_static/icon_dacapo.png) +![DaCapo](docs/source/_static/dacapo.svg) [![tests](https://github.com/funkelab/dacapo/actions/workflows/tests.yaml/badge.svg)](https://github.com/funkelab/dacapo/actions/workflows/tests.yaml) 
[![black](https://github.com/funkelab/dacapo/actions/workflows/black.yaml/badge.svg)](https://github.com/funkelab/dacapo/actions/workflows/black.yaml) diff --git a/care_train.py b/care_train.py new file mode 100644 index 000000000..399bc0dd0 --- /dev/null +++ b/care_train.py @@ -0,0 +1,152 @@ +import dacapo +import logging +import math +import torch +from torchsummary import summary + +# CARE task specific elements +from dacapo.experiments.datasplits.datasets.arrays import ( + ZarrArrayConfig, + IntensitiesArrayConfig, +) +from dacapo.experiments.datasplits.datasets import RawGTDatasetConfig +from dacapo.experiments.datasplits import TrainValidateDataSplitConfig +from dacapo.experiments.architectures import CNNectomeUNetConfig +from dacapo.experiments.tasks import CARETaskConfig + +from dacapo.experiments.trainers import GunpowderTrainerConfig +from dacapo.experiments.trainers.gp_augments import ( + SimpleAugmentConfig, + ElasticAugmentConfig, + IntensityAugmentConfig, +) +from funlib.geometry import Coordinate +from dacapo.experiments.run_config import RunConfig +from dacapo.experiments.run import Run +from dacapo.store.create_store import create_config_store +from dacapo.train import train + + +# set basic login configs +logging.basicConfig(level=logging.INFO) + +raw_array_config_zarr = ZarrArrayConfig( + name="raw", + file_name="/n/groups/htem/users/br128/data/CBvBottom/CBxs_lobV_bottomp100um_training_0.n5", + dataset="volumes/raw_30nm", +) + +gt_array_config_zarr = ZarrArrayConfig( + name="gt", + file_name="/n/groups/htem/users/br128/data/CBvBottom/CBxs_lobV_bottomp100um_training_0.n5", + dataset="volumes/interpolated_90nm_aligned", +) + +raw_array_config_int = IntensitiesArrayConfig( + name="raw_norm", source_array_config=raw_array_config_zarr, min=0.0, max=1.0 +) + +gt_array_config_int = IntensitiesArrayConfig( + name="gt_norm", source_array_config=gt_array_config_zarr, min=0.0, max=1.0 +) + +dataset_config = RawGTDatasetConfig( + name="CBxs_lobV_bottomp100um_CARE_0", + raw_config=raw_array_config_int, + gt_config=gt_array_config_int, +) + +# TODO: check datasplit config, this honestly might work +datasplit_config = TrainValidateDataSplitConfig( + name="CBxs_lobV_bottomp100um_training_0.n5", + train_configs=[dataset_config], + validate_configs=[dataset_config], +) +""" +kernel size 3 +2 conv passes per block + +1 -- 100%, lose 4 pix - 286 pix +2 -- 50%, lose 8 pix - 142 pix +3 -- 25%, lose 16 pix - 32 pix +""" +# UNET config +architecture_config = CNNectomeUNetConfig( + name="small_unet", + input_shape=Coordinate(156, 156, 156), + # eval_shape_increase=Coordinate(72, 72, 72), + fmaps_in=1, + num_fmaps=8, + fmaps_out=32, + fmap_inc_factor=4, + downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)], + constant_upsample=True, +) + + +# CARE task +task_config = CARETaskConfig(name="CAREModel", num_channels=1, dims=3) + + +# trainier +trainer_config = GunpowderTrainerConfig( + name="gunpowder", + batch_size=2, + learning_rate=0.0001, + augments=[ + SimpleAugmentConfig(), + ElasticAugmentConfig( + control_point_spacing=(100, 100, 100), + control_point_displacement_sigma=(10.0, 10.0, 10.0), + rotation_interval=(0, math.pi / 2.0), + subsample=8, + uniform_3d_rotation=True, + ), + IntensityAugmentConfig( + scale=(0.25, 1.75), + shift=(-0.5, 0.35), + clip=False, + ), + ], + num_data_fetchers=20, + snapshot_interval=10000, + min_masked=0.15, +) + + +# run config +run_config = RunConfig( + name="CARE_train", + task_config=task_config, + architecture_config=architecture_config, + 
trainer_config=trainer_config, + datasplit_config=datasplit_config, + repetition=0, + num_iterations=100000, + validation_interval=1000, +) + +run = Run(run_config) + +# run summary TODO create issue +print(summary(run.model, (1, 156, 156, 156))) + + +# store configs, then train +config_store = create_config_store() + +config_store.store_datasplit_config(datasplit_config) +config_store.store_architecture_config(architecture_config) +config_store.store_task_config(task_config) +config_store.store_trainer_config(trainer_config) +config_store.store_run_config(run_config) + +# Optional start training by config name: +train(run_config.name) + +# CLI dacapo train -r {run_config.name} + + +""" +RuntimeError: Can not downsample shape torch.Size([1, 128, 47, 47, 47]) with factor (2, 2, 2), mismatch in spatial dimension 2 +""" diff --git a/dacapo/apply.py b/dacapo/apply.py index 434002ef6..64f23df3c 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -10,4 +10,3 @@ def apply(run_name: str, iteration: int, dataset_name: str): iteration, dataset_name, ) - raise NotImplementedError("This function is not yet implemented.") diff --git a/dacapo/cli.py b/dacapo/cli.py index be59df0c0..76a5e18e0 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -40,7 +40,7 @@ def validate(run_name, iteration): @cli.command() @click.option( - "-r", "--run-name", required=True, type=str, help="The name of the run to use." + "-r", "--run", required=True, type=str, help="The name of the run to use." ) @click.option( "-i", diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index 01a261d09..338ac6ea1 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -273,9 +273,11 @@ def __init__( self.l_conv = nn.ModuleList( [ ConvPass( - in_channels - if level == 0 - else num_fmaps * fmap_inc_factor ** (level - 1), + ( + in_channels + if level == 0 + else num_fmaps * fmap_inc_factor ** (level - 1) + ), num_fmaps * fmap_inc_factor**level, kernel_size_down[level], activation=activation, @@ -327,9 +329,11 @@ def __init__( + num_fmaps * fmap_inc_factor ** (level + (1 - upsample_channel_contraction[level])), - num_fmaps * fmap_inc_factor**level - if num_fmaps_out is None or level != 0 - else num_fmaps_out, + ( + num_fmaps * fmap_inc_factor**level + if num_fmaps_out is None or level != 0 + else num_fmaps_out + ), kernel_size_up[level], activation=activation, padding=padding, diff --git a/dacapo/experiments/architectures/nlayer_discriminator.py b/dacapo/experiments/architectures/nlayer_discriminator.py new file mode 100644 index 000000000..c203fdbfb --- /dev/null +++ b/dacapo/experiments/architectures/nlayer_discriminator.py @@ -0,0 +1,79 @@ +from .architecture import Architecture + +import torch +import torch.nn as nn +import functools + + +class NLayerDiscriminator(Architecture): + """Defines a PatchGAN discriminator""" + + def __init__(self, architecture_config): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ngf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super().__init__() + + input_nc: int = architecture_config.input_nc + ngf: int = architecture_config.ngf + n_layers: int = architecture_config.n_layers + norm_layer = architecture_config.norm_layer + + if ( + type(norm_layer) == functools.partial + ): # no need to use 
bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + kw = 4 + padw = 1 + sequence = [ + nn.Conv2d(input_nc, ngf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True), + ] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d( + ngf * nf_mult_prev, + ngf * nf_mult, + kernel_size=kw, + stride=2, + padding=padw, + bias=use_bias, + ), + norm_layer(ngf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d( + ngf * nf_mult_prev, + ngf * nf_mult, + kernel_size=kw, + stride=1, + padding=padw, + bias=use_bias, + ), + norm_layer(ngf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + sequence += [ + nn.Conv2d(ngf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + ] # output 1 channel prediction map + self.model = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.model(input) diff --git a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py index 1475c7b97..aceda2e77 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py @@ -5,9 +5,6 @@ import numpy as np from typing import Dict, Any -import logging - -logger = logging.getLogger(__file__) class ConcatArray(Array): @@ -102,9 +99,11 @@ def __getitem__(self, roi: Roi) -> np.ndarray: else self.default_array[roi] ) arrays = [ - self.source_arrays[channel][roi] - if channel in self.source_arrays - else default + ( + self.source_arrays[channel][roi] + if channel in self.source_arrays + else default + ) for channel in self.channels ] shapes = [array.shape for array in arrays] @@ -119,7 +118,5 @@ def __getitem__(self, roi: Roi) -> np.ndarray: axis=0, ) if concatenated.shape[0] == 1: - logger.info( - f"Concatenated array has only one channel: {self.name} {concatenated.shape}" - ) + raise Exception(f"{concatenated.shape}, shapes") return concatenated diff --git a/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py b/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py index e08ffe562..beaa474d1 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py @@ -41,7 +41,7 @@ def attrs(self): @property def axes(self): - return ["c", "z", "y", "x"][-self.dims :] + return ["t", "z", "y", "x"][-self.dims :] @property def dims(self) -> int: diff --git a/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py b/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py index 5f2bc0483..7101d737e 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py @@ -35,7 +35,7 @@ def from_gp_array(cls, array: gp.Array): ((["b", "c"] if len(array.data.shape) == instance.dims + 2 else [])) + (["c"] if len(array.data.shape) == instance.dims + 1 else []) + [ - "c", + "t", "z", "y", "x", diff --git a/dacapo/experiments/datasplits/datasets/arrays/tiff_array.py b/dacapo/experiments/datasplits/datasets/arrays/tiff_array.py index ccdf50376..e16ef26e0 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/tiff_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/tiff_array.py @@ -56,7 
+56,7 @@ def voxel_size(self) -> Coordinate: @lazy_property.LazyProperty def roi(self) -> Roi: - return Roi(self._offset, self.shape) + return Roi(self._offset * self.shape) @property def writable(self) -> bool: diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 25f2c224e..cadfcb6cd 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -52,9 +52,9 @@ def axes(self): logger.debug( "DaCapo expects Zarr datasets to have an 'axes' attribute!\n" f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n" - f"Using default {['c', 'z', 'y', 'x'][-self.dims::]}", + f"Using default {['t', 'z', 'y', 'x'][-self.dims::]}", ) - return ["c", "z", "y", "x"][-self.dims : :] + return ["t", "z", "y", "x"][-self.dims : :] @property def dims(self) -> int: diff --git a/dacapo/experiments/model.py b/dacapo/experiments/model.py index 8ca2b2b9e..bbaacb2dc 100644 --- a/dacapo/experiments/model.py +++ b/dacapo/experiments/model.py @@ -24,7 +24,7 @@ def __init__( self, architecture: Architecture, prediction_head: torch.nn.Module, - eval_activation: torch.nn.Module | None = None, + eval_activation: torch.nn.Module = None, ): super().__init__() @@ -46,7 +46,7 @@ def forward(self, x): result = self.eval_activation(result) return result - def compute_output_shape(self, input_shape: Coordinate) -> Tuple[int, Coordinate]: + def compute_output_shape(self, input_shape: Coordinate) -> Coordinate: """Compute the spatial shape (i.e., not accounting for channels and batch dimensions) of this model, when fed a tensor of the given spatial shape as input.""" diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py index da7badbf9..1dfefbeee 100644 --- a/dacapo/experiments/starts/start.py +++ b/dacapo/experiments/starts/start.py @@ -15,6 +15,7 @@ def initialize_weights(self, model): weights_store = create_weights_store() weights = weights_store._retrieve_weights(self.run, self.criterion) logger.info(f"loading weights from run {self.run}, criterion: {self.criterion}") + # load the model weights (taken from torch load_state_dict source) try: model.load_state_dict(weights.model) diff --git a/dacapo/experiments/tasks/CARE_task.py b/dacapo/experiments/tasks/CARE_task.py new file mode 100644 index 000000000..519cdf701 --- /dev/null +++ b/dacapo/experiments/tasks/CARE_task.py @@ -0,0 +1,18 @@ +from .evaluators import IntensitiesEvaluator +from .losses import MSELoss +from .post_processors import CAREPostProcessor +from .predictors import CAREPredictor +from .task import Task + + +class CARETask(Task): + """CAREPredictor.""" + + def __init__(self, task_config) -> None: + """Create a `CARETask`.""" + self.predictor = CAREPredictor( + num_channels=task_config.num_channels, dims=task_config.dims + ) + self.loss = MSELoss() + self.post_processor = CAREPostProcessor() + self.evaluator = IntensitiesEvaluator() diff --git a/dacapo/experiments/tasks/CARE_task_config.py b/dacapo/experiments/tasks/CARE_task_config.py new file mode 100644 index 000000000..fccae9333 --- /dev/null +++ b/dacapo/experiments/tasks/CARE_task_config.py @@ -0,0 +1,28 @@ +import attr + +from .CARE_task import CARETask +from .task_config import TaskConfig + + +@attr.s +class CARETaskConfig(TaskConfig): + """This is a CARE task config used for generating and + evaluating voxel affinities for instance segmentations. 
+ """ + + task_type = CARETask + num_channels: int = attr.ib( + default=2, + metadata={ + "help_text": "Number of output channels for the image. " + "Number of output channels should match the number of channels in the ground truth." + }, + ) + + dims: int = attr.ib( + default=2, + metadata={ + "help_text": "Number of UNet dimensions. " + "Number of dimensions should match the number of channels in the ground truth." + }, + ) diff --git a/dacapo/experiments/tasks/CycleGAN_task.py b/dacapo/experiments/tasks/CycleGAN_task.py new file mode 100644 index 000000000..cfbda72f3 --- /dev/null +++ b/dacapo/experiments/tasks/CycleGAN_task.py @@ -0,0 +1,17 @@ +from .evaluators import IntensitiesEvaluator +from .losses import GANLoss +from .post_processors import CycleGANPostProcessor +from .predictors import CycleGANPredictor +from .task import Task + + +class CycleGANTask(Task): + """CycleGAN Task.""" + + def __init__(self, task_config) -> None: + """Create a `CycleGAN Task`.""" + + self.predictor = CycleGANPredictor(num_channels=task_config.num_channels) + self.loss = GANLoss() + self.post_processor = CycleGANPostProcessor() + self.evaluator = IntensitiesEvaluator() diff --git a/dacapo/experiments/tasks/CycleGAN_task_config.py b/dacapo/experiments/tasks/CycleGAN_task_config.py new file mode 100644 index 000000000..f63e90b31 --- /dev/null +++ b/dacapo/experiments/tasks/CycleGAN_task_config.py @@ -0,0 +1,21 @@ +import attr + +from .CycleGAN_task import CycleGANTask +from .task_config import TaskConfig + + +@attr.s +class CycleGANTaskConfig(TaskConfig): + """This is a CycleGAN task config used for unpaired + image-to-image translation between intensity volumes. + """ + + task_type = CycleGANTask + + num_channels: int = attr.ib( + default=1, + metadata={ + "help_text": "Number of output channels for the image. " + "Number of output channels should match the number of channels in the ground truth." + }, + ) diff --git a/dacapo/experiments/tasks/Pix2Pix_task.py b/dacapo/experiments/tasks/Pix2Pix_task.py new file mode 100644 index 000000000..9a52fc77e --- /dev/null +++ b/dacapo/experiments/tasks/Pix2Pix_task.py @@ -0,0 +1,18 @@ +from .evaluators import IntensitiesEvaluator +from .losses import MSELoss +from .post_processors import CAREPostProcessor +from .predictors.Pix2Pix_predictor import Pix2PixPredictor +from .task import Task + + +class Pix2PixTask(Task): + """Pix2Pix Task.""" + + def __init__(self, task_config) -> None: + """Create a `Pix2PixTask`.""" + self.predictor = Pix2PixPredictor( + num_channels=task_config.num_channels, dims=task_config.dims + ) + self.loss = MSELoss() # TODO: change losses + self.post_processor = CAREPostProcessor() # TODO: change post processor + self.evaluator = IntensitiesEvaluator() diff --git a/dacapo/experiments/tasks/Pix2Pix_task_config.py b/dacapo/experiments/tasks/Pix2Pix_task_config.py new file mode 100644 index 000000000..ca5751fad --- /dev/null +++ b/dacapo/experiments/tasks/Pix2Pix_task_config.py @@ -0,0 +1,28 @@ +import attr + +from .Pix2Pix_task import Pix2PixTask +from .task_config import TaskConfig + + +@attr.s +class Pix2PixTaskConfig(TaskConfig): + """This is a Pix2Pix task config used for paired + image-to-image translation between intensity volumes. + """ + + task_type = Pix2PixTask + num_channels: int = attr.ib( + default=2, + metadata={ + "help_text": "Number of output channels for the image. " + "Number of output channels should match the number of channels in the ground truth." 
+ }, + ) + + dims: int = attr.ib( + default=2, + metadata={ + "help_text": "Number of UNet dimensions. " + "Number of dimensions should match the number of channels in the ground truth." + }, + ) diff --git a/dacapo/experiments/tasks/__init__.py b/dacapo/experiments/tasks/__init__.py index 780f343d1..65ce71a5a 100644 --- a/dacapo/experiments/tasks/__init__.py +++ b/dacapo/experiments/tasks/__init__.py @@ -5,3 +5,5 @@ from .one_hot_task_config import OneHotTaskConfig, OneHotTask # noqa from .pretrained_task_config import PretrainedTaskConfig, PretrainedTask # noqa from .affinities_task_config import AffinitiesTaskConfig, AffinitiesTask # noqa +from .CARE_task_config import CARETaskConfig, CARETask # noqa +from .CycleGAN_task_config import CycleGANTaskConfig, CycleGANTask # noqa diff --git a/dacapo/experiments/tasks/affinities_task.py b/dacapo/experiments/tasks/affinities_task.py index 859494e7e..c1014fd02 100644 --- a/dacapo/experiments/tasks/affinities_task.py +++ b/dacapo/experiments/tasks/affinities_task.py @@ -14,8 +14,6 @@ def __init__(self, task_config): self.predictor = AffinitiesPredictor( neighborhood=task_config.neighborhood, lsds=task_config.lsds ) - self.loss = AffinitiesLoss( - len(task_config.neighborhood), task_config.lsds_to_affs_weight_ratio - ) + self.loss = AffinitiesLoss(len(task_config.neighborhood)) self.post_processor = WatershedPostProcessor(offsets=task_config.neighborhood) self.evaluator = InstanceEvaluator() diff --git a/dacapo/experiments/tasks/affinities_task_config.py b/dacapo/experiments/tasks/affinities_task_config.py index a50c2141e..d4b2c6199 100644 --- a/dacapo/experiments/tasks/affinities_task_config.py +++ b/dacapo/experiments/tasks/affinities_task_config.py @@ -30,9 +30,3 @@ class AffinitiesTaskConfig(TaskConfig): "It has been shown that lsds as an auxiliary task can help affinity predictions." }, ) - lsds_to_affs_weight_ratio: float = attr.ib( - default=1, - metadata={ - "help_text": "If training with lsds, set how much they should be weighted compared to affs." - }, - ) diff --git a/dacapo/experiments/tasks/arraytypes/__init__.py b/dacapo/experiments/tasks/arraytypes/__init__.py new file mode 100644 index 000000000..456d192e5 --- /dev/null +++ b/dacapo/experiments/tasks/arraytypes/__init__.py @@ -0,0 +1,6 @@ +from .annotations import AnnotationArray +from .intensities import IntensitiesArray +from .distances import DistanceArray +from .mask import Mask +from .embedding import EmbeddingArray +from .probabilities import ProbabilityArray diff --git a/dacapo/experiments/tasks/arraytypes/annotations.py b/dacapo/experiments/tasks/arraytypes/annotations.py new file mode 100644 index 000000000..f7fc2f9b1 --- /dev/null +++ b/dacapo/experiments/tasks/arraytypes/annotations.py @@ -0,0 +1,23 @@ +from .arraytype import ArrayType + +import attr +from typing import Dict + + +@attr.s +class AnnotationArray(ArrayType): + """ + An AnnotationArray is a uint8, uint16, uint32 or uint64 Array where each + voxel has a value associated with its class. + """ + + classes: Dict[int, str] = attr.ib( + metadata={ + "help_text": "A mapping from class label to class name. " + "For example {1:'mitochondria', 2:'membrane'} etc." 
+ } + ) + + @property + def interpolatable(self): + return False diff --git a/dacapo/experiments/tasks/arraytypes/arraytype.py b/dacapo/experiments/tasks/arraytypes/arraytype.py new file mode 100644 index 000000000..783519bbb --- /dev/null +++ b/dacapo/experiments/tasks/arraytypes/arraytype.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod + + +# TODO: Should be read only +class ArrayType(ABC): + """ + The type of data provided by an array. The ArrayType class helps to keep + track of the semantic meaning of an Array. Additionally the ArrayType + keeps track of metadata that is specific to this datatype such as + num_classes for an annotated volume or channel names for intensity + arrays. + """ + + @property + @abstractmethod + def interpolatable(self) -> bool: + pass diff --git a/dacapo/experiments/tasks/arraytypes/binary.py b/dacapo/experiments/tasks/arraytypes/binary.py new file mode 100644 index 000000000..9dc6eb3fd --- /dev/null +++ b/dacapo/experiments/tasks/arraytypes/binary.py @@ -0,0 +1,23 @@ +from .arraytype import ArrayType + +import attr + +from typing import Dict + + +@attr.s +class BinaryArray(ArrayType): + """ + An BinaryArray is a bool or uint8 Array where each + voxel is either 1 or 0. + """ + + channels: Dict[int, str] = attr.ib( + metadata={ + "help_text": "A mapping from channel to class for the binary classification." + } + ) + + @property + def interpolatable(self) -> bool: + return False diff --git a/dacapo/experiments/tasks/arraytypes/distances.py b/dacapo/experiments/tasks/arraytypes/distances.py new file mode 100644 index 000000000..057f8f1b2 --- /dev/null +++ b/dacapo/experiments/tasks/arraytypes/distances.py @@ -0,0 +1,23 @@ +from .arraytype import ArrayType + +import attr + +from typing import Dict + + +@attr.s +class DistanceArray(ArrayType): + """ + An array containing signed distances to the nearest boundary voxel for a particular label class. + Distances should be positive outside an object and negative inside an object. + """ + + classes: Dict[int, str] = attr.ib( + metadata={ + "help_text": "A mapping from channel to class on which distances were calculated" + } + ) + + @property + def interpolatable(self) -> bool: + return True diff --git a/dacapo/experiments/tasks/arraytypes/embedding.py b/dacapo/experiments/tasks/arraytypes/embedding.py new file mode 100644 index 000000000..e18ae1ea2 --- /dev/null +++ b/dacapo/experiments/tasks/arraytypes/embedding.py @@ -0,0 +1,21 @@ +from .arraytype import ArrayType + +import attr + +from typing import Dict + + +@attr.s +class EmbeddingArray(ArrayType): + """ + A generic output of a model that could represent almost anything. Assumed to be + float, interpolatable, and have sum number of channels. + """ + + embedding_dims: int = attr.ib( + metadata={"help_text": "The dimension of your embedding."} + ) + + @property + def interpolatable(self) -> bool: + return True diff --git a/dacapo/experiments/tasks/arraytypes/intensities.py b/dacapo/experiments/tasks/arraytypes/intensities.py new file mode 100644 index 000000000..e31650eef --- /dev/null +++ b/dacapo/experiments/tasks/arraytypes/intensities.py @@ -0,0 +1,30 @@ +from .arraytype import ArrayType + +import numpy as np + +import attr + +from typing import Dict, Union + + +@attr.s +class IntensitiesArray(ArrayType): + """ + An IntensitiesArray is an Array of measured intensities. + """ + + channels: Dict[int, str] = attr.ib( + metadata={ + "help_text": "A mapping from channel to a name describing that channel." 
+ } + ) + min: float = attr.ib( + metadata={"help_text": "The minimum possible value of your intensities."} + ) + max: float = attr.ib( + metadata={"help_text": "The maximum possible value of your intensities."} + ) + + @property + def interpolatable(self) -> bool: + return True diff --git a/dacapo/experiments/tasks/arraytypes/mask.py b/dacapo/experiments/tasks/arraytypes/mask.py new file mode 100644 index 000000000..f3ad62c0c --- /dev/null +++ b/dacapo/experiments/tasks/arraytypes/mask.py @@ -0,0 +1,10 @@ +from .arraytype import ArrayType + +import attr + + +@attr.s +class Mask(ArrayType): + @property + def interpolatable(self) -> bool: + return False diff --git a/dacapo/experiments/tasks/arraytypes/probabilities.py b/dacapo/experiments/tasks/arraytypes/probabilities.py new file mode 100644 index 000000000..16896ff71 --- /dev/null +++ b/dacapo/experiments/tasks/arraytypes/probabilities.py @@ -0,0 +1,25 @@ +from .arraytype import ArrayType + +import attr + +from typing import List + + +@attr.s +class ProbabilityArray(ArrayType): + """ + An array containing probabilities for each voxel. I.e. each voxel has a vector + of length `c` where `c` is the number of classes. The l1 norm of this vector should + always be 1. The class of this voxel can be determined by simply taking the + argmax. + """ + + classes: List[str] = attr.ib( + metadata={ + "help_text": "A mapping from channel to class on which distances were calculated" + } + ) + + @property + def interpolatable(self) -> bool: + return True diff --git a/dacapo/experiments/tasks/evaluators/__init__.py b/dacapo/experiments/tasks/evaluators/__init__.py index 19badc8d5..2daf37545 100644 --- a/dacapo/experiments/tasks/evaluators/__init__.py +++ b/dacapo/experiments/tasks/evaluators/__init__.py @@ -9,3 +9,7 @@ from .binary_segmentation_evaluator import BinarySegmentationEvaluator # noqa from .instance_evaluation_scores import InstanceEvaluationScores # noqa from .instance_evaluator import InstanceEvaluator # noqa + + +from .intensities_evaluation_scores import IntensitiesEvaluationScores # noqa +from .intensities_evaluator import IntensitiesEvaluator # noqa diff --git a/dacapo/experiments/tasks/evaluators/intensities_evaluation_scores.py b/dacapo/experiments/tasks/evaluators/intensities_evaluation_scores.py new file mode 100644 index 000000000..60dd56a13 --- /dev/null +++ b/dacapo/experiments/tasks/evaluators/intensities_evaluation_scores.py @@ -0,0 +1,35 @@ +from .evaluation_scores import EvaluationScores +import attr + +from typing import Tuple + + +@attr.s +class IntensitiesEvaluationScores(EvaluationScores): + criteria: property = ["ssim", "psnr", "nrmse"] + + ssim: float = attr.ib(default=float("nan")) + psnr: float = attr.ib(default=float("nan")) + nrmse: float = attr.ib(default=float("nan")) + + @staticmethod + def higher_is_better(criterion: str) -> bool: + mapping: dict[str, bool] = { + "ssim": True, + "psnr": True, + "nrmse": False, + } + return mapping[criterion] + + @staticmethod + def bounds(criterion: str) -> Tuple[float, float]: + mapping: dict[str, tuple] = { + "ssim": (0, None), + "psnr": (0, None), + "nrmse": (0, None), + } + return mapping[criterion] + + @staticmethod + def store_best(criterion: str) -> bool: + return True diff --git a/dacapo/experiments/tasks/evaluators/intensities_evaluator.py b/dacapo/experiments/tasks/evaluators/intensities_evaluator.py new file mode 100644 index 000000000..81b8a25d5 --- /dev/null +++ b/dacapo/experiments/tasks/evaluators/intensities_evaluator.py @@ -0,0 +1,37 @@ +import xarray as xr +from 
dacapo.experiments.datasplits.datasets.arrays import ZarrArray +import numpy as np +from skimage.metrics import ( + structural_similarity, + peak_signal_noise_ratio, + normalized_root_mse, +) + +from .evaluator import Evaluator +from .intensities_evaluation_scores import IntensitiesEvaluationScores + + +class IntensitiesEvaluator(Evaluator): + """IntensitiesEvaluator Class + + An evaluator takes a post-processor's output and compares it against + ground-truth. + """ + + criteria = ["ssim", "psnr", "nrmse"] + + def evaluate( + self, output_array_identifier, evaluation_array + ) -> IntensitiesEvaluationScores: + output_array = ZarrArray.open_from_array_identifier(output_array_identifier) + evaluation_data = evaluation_array[evaluation_array.roi].astype(np.uint64) + output_data = output_array[output_array.roi].astype(np.uint64) + return IntensitiesEvaluationScores( + ssim=structural_similarity(evaluation_data, output_data), + psnr=peak_signal_noise_ratio(evaluation_data, output_data), + nrmse=normalized_root_mse(evaluation_data, output_data), + ) + + @property + def score(self) -> IntensitiesEvaluationScores: + return IntensitiesEvaluationScores() diff --git a/dacapo/experiments/tasks/losses/GANLoss.py b/dacapo/experiments/tasks/losses/GANLoss.py new file mode 100644 index 000000000..57bcf41f8 --- /dev/null +++ b/dacapo/experiments/tasks/losses/GANLoss.py @@ -0,0 +1,66 @@ +from .loss import Loss +import torch +import torch.nn as nn + + +# TODO: refactor for Dacapo +class GANLoss(Loss): + """Define different GAN objectives. + The GANLoss class abstracts away the need to create the target label tensor + that has the same size as the input. + """ + + def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): + """Initialize the GANLoss class. + Parameters: + gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. + target_real_label (bool) - - label for a real image + target_fake_label (bool) - - label of a fake image + Note: Do not use sigmoid as the last layer of Discriminator. + LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. + """ + super(GANLoss, self).__init__() + self.register_buffer("real_label", torch.tensor(target_real_label)) + self.register_buffer("fake_label", torch.tensor(target_fake_label)) + self.gan_mode = gan_mode + if gan_mode == "lsgan": + self.loss = nn.MSELoss() + elif gan_mode == "vanilla": + self.loss = nn.BCEWithLogitsLoss() + elif gan_mode in ["wgangp"]: + self.loss = None + else: + raise NotImplementedError("gan mode %s not implemented" % gan_mode) + + def get_target_tensor(self, prediction, target_is_real): + """Create label tensors with the same size as the input. + Parameters: + prediction (tensor) - - typically the prediction from a discriminator + target_is_real (bool) - - if the ground truth label is for real images or fake images + Returns: + A label tensor filled with ground truth label, and with the size of the input + """ + + if target_is_real: + target_tensor = self.real_label + else: + target_tensor = self.fake_label + return target_tensor.expand_as(prediction) + + def __call__(self, prediction, target_is_real): + """Calculate loss given Discriminator's output and ground truth labels. + Parameters: + prediction (tensor) - - typically the prediction output from a discriminator + target_is_real (bool) - - if the ground truth label is for real images or fake images + Returns: + the calculated loss. 
+ """ + if self.gan_mode in ["lsgan", "vanilla"]: + target_tensor = self.get_target_tensor(prediction, target_is_real) + loss = self.loss(prediction, target_tensor) + elif self.gan_mode == "wgangp": + if target_is_real: + loss = -prediction.mean() + else: + loss = prediction.mean() + return loss diff --git a/dacapo/experiments/tasks/losses/__init__.py b/dacapo/experiments/tasks/losses/__init__.py index b675faa96..f14fae958 100644 --- a/dacapo/experiments/tasks/losses/__init__.py +++ b/dacapo/experiments/tasks/losses/__init__.py @@ -2,3 +2,4 @@ from .mse_loss import MSELoss # noqa from .loss import Loss # noqa from .affinities_loss import AffinitiesLoss # noqa +from .GANLoss import GANLoss # noqa diff --git a/dacapo/experiments/tasks/losses/affinities_loss.py b/dacapo/experiments/tasks/losses/affinities_loss.py index 74fc7fe67..65ada8843 100644 --- a/dacapo/experiments/tasks/losses/affinities_loss.py +++ b/dacapo/experiments/tasks/losses/affinities_loss.py @@ -3,9 +3,8 @@ class AffinitiesLoss(Loss): - def __init__(self, num_affinities: int, lsds_to_affs_weight_ratio: float): + def __init__(self, num_affinities: int): self.num_affinities = num_affinities - self.lsds_to_affs_weight_ratio = lsds_to_affs_weight_ratio def compute(self, prediction, target, weight): affs, affs_target, affs_weight = ( @@ -22,7 +21,7 @@ def compute(self, prediction, target, weight): return ( torch.nn.BCEWithLogitsLoss(reduction="none")(affs, affs_target) * affs_weight - ).mean() + self.lsds_to_affs_weight_ratio * ( + ).mean() + ( torch.nn.MSELoss(reduction="none")(torch.nn.Sigmoid()(aux), aux_target) * aux_weight ).mean() diff --git a/dacapo/experiments/tasks/post_processors/CARE_post_processor.py b/dacapo/experiments/tasks/post_processors/CARE_post_processor.py new file mode 100644 index 000000000..cc0af3742 --- /dev/null +++ b/dacapo/experiments/tasks/post_processors/CARE_post_processor.py @@ -0,0 +1,47 @@ +from typing import Iterable +from dacapo.experiments.datasplits.datasets.arrays import ZarrArray + +from .CARE_post_processor_parameters import CAREPostProcessorParameters +from .post_processor import PostProcessor +import numpy as np +import zarr + +from typing import TYPE_CHECKING, Iterable + +if TYPE_CHECKING: + from dacapo.store.local_array_store import LocalArrayIdentifier + from dacapo.experiments.tasks.post_processors import PostProcessorParameters + + +class CAREPostProcessor(PostProcessor): + def __init__(self) -> None: + super().__init__() + + def enumerate_parameters(self) -> Iterable[CAREPostProcessorParameters]: + """Enumerate all possible parameters of this post-processor.""" + + yield CAREPostProcessorParameters(id=1) + + def set_prediction( + self, prediction_array_identifier: "LocalArrayIdentifier" + ): # TODO + self.prediction_array = ZarrArray.open_from_array_identifier( + prediction_array_identifier + ) + + def process( + self, + parameters: "PostProcessorParameters", + output_array_identifier: "LocalArrayIdentifier", + ) -> ZarrArray: + + output_array: ZarrArray = ZarrArray.create_from_array_identifier( + output_array_identifier, + self.prediction_array.axes, + self.prediction_array.roi, + self.prediction_array.num_channels, + self.prediction_array.voxel_size, + np.uint8, + ) + + return output_array diff --git a/dacapo/experiments/tasks/post_processors/CARE_post_processor_parameters.py b/dacapo/experiments/tasks/post_processors/CARE_post_processor_parameters.py new file mode 100644 index 000000000..ff1b6d584 --- /dev/null +++ 
b/dacapo/experiments/tasks/post_processors/CARE_post_processor_parameters.py @@ -0,0 +1,7 @@ +from .post_processor_parameters import PostProcessorParameters +import attr + + +@attr.s(frozen=True) +class CAREPostProcessorParameters(PostProcessorParameters): + pass diff --git a/dacapo/experiments/tasks/post_processors/CycleGAN_post_processor.py b/dacapo/experiments/tasks/post_processors/CycleGAN_post_processor.py new file mode 100644 index 000000000..8928c1330 --- /dev/null +++ b/dacapo/experiments/tasks/post_processors/CycleGAN_post_processor.py @@ -0,0 +1,47 @@ +from typing import Iterable +from dacapo.experiments.datasplits.datasets.arrays import ZarrArray + +from .CycleGAN_post_processor_parameters import CycleGANPostProcessorParameters +from .post_processor import PostProcessor +import numpy as np +import zarr + +from typing import TYPE_CHECKING, Iterable + +if TYPE_CHECKING: + from dacapo.store.local_array_store import LocalArrayIdentifier + from dacapo.experiments.tasks.post_processors import PostProcessorParameters + + +class CycleGANPostProcessor(PostProcessor): + def __init__(self) -> None: + super().__init__() + + def enumerate_parameters(self) -> Iterable[CycleGANPostProcessorParameters]: + """Enumerate all possible parameters of this post-processor.""" + + yield CycleGANPostProcessorParameters(id=1) + + def set_prediction( + self, prediction_array_identifier: "LocalArrayIdentifier" + ): # TODO + self.prediction_array = ZarrArray.open_from_array_identifier( + prediction_array_identifier + ) + + def process( + self, + parameters: "PostProcessorParameters", + output_array_identifier: "LocalArrayIdentifier", + ) -> ZarrArray: + + output_array: ZarrArray = ZarrArray.create_from_array_identifier( + output_array_identifier, + self.prediction_array.axes, + self.prediction_array.roi, + self.prediction_array.num_channels, + self.prediction_array.voxel_size, + np.uint8, + ) + + return output_array diff --git a/dacapo/experiments/tasks/post_processors/CycleGAN_post_processor_parameters.py b/dacapo/experiments/tasks/post_processors/CycleGAN_post_processor_parameters.py new file mode 100644 index 000000000..f1b091847 --- /dev/null +++ b/dacapo/experiments/tasks/post_processors/CycleGAN_post_processor_parameters.py @@ -0,0 +1,7 @@ +from .post_processor_parameters import PostProcessorParameters +import attr + + +@attr.s(frozen=True) +class CycleGANPostProcessorParameters(PostProcessorParameters): + pass diff --git a/dacapo/experiments/tasks/post_processors/__init__.py b/dacapo/experiments/tasks/post_processors/__init__.py index fe0cde3d9..e3794594d 100644 --- a/dacapo/experiments/tasks/post_processors/__init__.py +++ b/dacapo/experiments/tasks/post_processors/__init__.py @@ -12,3 +12,9 @@ from .watershed_post_processor_parameters import ( WatershedPostProcessorParameters, ) # noqa + +from .CARE_post_processor import CAREPostProcessor # noqa +from .CARE_post_processor_parameters import CAREPostProcessorParameters # noqa + +from .CycleGAN_post_processor import CycleGANPostProcessor # noqa +from .CycleGAN_post_processor_parameters import CycleGANPostProcessorParameters # noqa diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index 1a7c4627b..8fa6104bc 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -24,7 +24,7 @@ def enumerate_parameters(self): """Enumerate all possible parameters of 
this post-processor. Should return instances of ``PostProcessorParameters``.""" - for i, bias in enumerate([0.1, 0.25, 0.5, 0.75, 0.9]): + for i, bias in enumerate([0.1, 0.5, 0.9]): yield WatershedPostProcessorParameters(id=i, bias=bias) def set_prediction(self, prediction_array_identifier): @@ -44,9 +44,9 @@ def process(self, parameters, output_array_identifier): # if a previous segmentation is provided, it must have a "grid graph" # in its metadata. pred_data = self.prediction_array[self.prediction_array.roi] - affs = pred_data[: len(self.offsets)].astype(np.float64) + affs = pred_data[: len(self.offsets)] segmentation = mws.agglom( - affs - parameters.bias, + affs - 0.5, self.offsets, ) # filter fragments @@ -59,17 +59,12 @@ def process(self, parameters, output_array_identifier): for fragment, mean in zip( fragment_ids, measurements.mean(average_affs, segmentation, fragment_ids) ): - if mean < parameters.bias: + if mean < 0.5: filtered_fragments.append(fragment) filtered_fragments = np.array(filtered_fragments, dtype=segmentation.dtype) replace = np.zeros_like(filtered_fragments) - - # DGA: had to add in flatten and reshape since remap (in particular indices) didn't seem to work with ndarrays for the input - if filtered_fragments.size > 0: - segmentation = npi.remap( - segmentation.flatten(), filtered_fragments, replace - ).reshape(segmentation.shape) + segmentation = npi.remap(segmentation, filtered_fragments, replace) output_array[self.prediction_array.roi] = segmentation diff --git a/dacapo/experiments/tasks/predictors/CARE_predictor.py b/dacapo/experiments/tasks/predictors/CARE_predictor.py new file mode 100644 index 000000000..a52fa2ef2 --- /dev/null +++ b/dacapo/experiments/tasks/predictors/CARE_predictor.py @@ -0,0 +1,54 @@ +from .predictor import Predictor +from dacapo.experiments import Model +from dacapo.experiments.arraytypes import IntensitiesArray +from dacapo.experiments.datasplits.datasets.arrays import NumpyArray, ZarrArray + +from funlib.geometry import Coordinate + +import numpy as np +import torch + + +class CAREPredictor(Predictor): + def __init__(self, num_channels, dims): + self.num_channels = num_channels + self.dims = dims + + def create_model(self, architecture): + if self.dims == 2: + head = torch.nn.Conv2d( + architecture.num_out_channels, self.num_channels, kernel_size=1 + ) + elif self.dims == 3: + head = torch.nn.Conv3d( + architecture.num_out_channels, self.num_channels, kernel_size=1 + ) + else: + raise NotImplementedError( + f"CAREPredictor not implemented for {self.dims} dimensions" + ) + + return Model(architecture, head) + + def create_target(self, gt): + return gt + + def create_weight(self, gt, target=None, mask=None): + if mask is None: + # array of ones + return NumpyArray.from_np_array( + np.ones(gt.data.shape), + gt.roi, + gt.voxel_size, + gt.axes, + ) + else: + return mask + + @property + def output_array_type(self): + return IntensitiesArray( + {"channels": {n: str(n) for n in range(self.num_channels)}}, + min=0.0, + max=1.0, + ) diff --git a/dacapo/experiments/tasks/predictors/CycleGANPredictor.py b/dacapo/experiments/tasks/predictors/CycleGANPredictor.py new file mode 100644 index 000000000..8a75482d4 --- /dev/null +++ b/dacapo/experiments/tasks/predictors/CycleGANPredictor.py @@ -0,0 +1,66 @@ +from .predictor import Predictor +from dacapo.experiments import Model +from dacapo.experiments.datasplits.datasets.arrays import NumpyArray, ZarrArray + +from funlib.geometry import Coordinate # TODO: pip install + +import numpy as np +import torch + 
+ +class CycleGANPredictor(Predictor): + def __init__(self, num_channels): + self.num_channels = num_channels + + def create_model(self, netG1, netG2): + if self.dims == 2: + netG1 = torch.nn.Conv2d( + netG1.num_out_channels, self.num_channels, kernel_size=1 + ) + netG2 = torch.nn.Conv2d( + netG2.num_out_channels, self.num_channels, kernel_size=1 + ) + elif self.dims == 3: + netG1 = torch.nn.Conv3d( + netG1.num_out_channels, self.num_channels, kernel_size=1 + ) + netG2 = torch.nn.Conv3d( + netG2.num_out_channels, self.num_channels, kernel_size=1 + ) + else: + raise NotImplementedError( + f"CycleGANPredictor not implemented for {self.dims} dimensions" + ) + # TODO: + # return Model(architecture, head) + pass + + def create_target(self, gt): + return gt + + def create_weight(self, gt): + # ones + return NumpyArray.from_np_array( + np.ones(gt.data.shape), + gt.roi, + gt.voxel_size, + gt.axes, + ) + + @property + def output_array_type(self): + return ZarrArray(self.num_channels) + + def gt_region_for_roi(self, target_spec): + if self.mask_distances: + gt_spec = target_spec.copy() + gt_spec.roi = gt_spec.roi.grow( + Coordinate((self.max_distance,) * gt_spec.voxel_size.dims), + Coordinate((self.max_distance,) * gt_spec.voxel_size.dims), + ).snap_to_grid(gt_spec.voxel_size, mode="shrink") + else: + gt_spec = target_spec.copy() + return gt_spec + + def padding(self, gt_voxel_size: Coordinate) -> Coordinate: + return Coordinate((self.max_distance,) * gt_voxel_size.dims) diff --git a/dacapo/experiments/tasks/predictors/Pix2Pix_predictor.py b/dacapo/experiments/tasks/predictors/Pix2Pix_predictor.py new file mode 100644 index 000000000..f3a014313 --- /dev/null +++ b/dacapo/experiments/tasks/predictors/Pix2Pix_predictor.py @@ -0,0 +1,73 @@ +from .predictor import Predictor +from dacapo.experiments import Model +from dacapo.experiments.arraytypes import IntensitiesArray +from dacapo.experiments.datasplits.datasets.arrays import NumpyArray, ZarrArray +from dacapo.experiments.architectures import CNNectomeUNet, NLayerDiscriminator + +from funlib.geometry import Coordinate + +import numpy as np +import torch + + +class Pix2PixPredictor(Predictor): + def __init__(self, num_channels, dims): + self.num_channels = num_channels + self.dims = dims + + def create_model(self, g_architecture, d_architecture=None, is_train=True): + generator: CNNectomeUNet = CNNectomeUNet( + input_shape=g_architecture.input_shape, + fmaps_out=g_architecture.fmaps_out, + fmaps_in=g_architecture.fmaps_in, + num_fmaps=g_architecture.num_fmaps, + fmap_inc_factor=g_architecture.fmap_inc_factor, + downsample_factors=g_architecture.downsample_factors, + constant_upsample=g_architecture.constant_upsample, + padding=g_architecture.padding + ) + + if self.dims == 2: + head = torch.nn.Conv2d( + g_architecture.num_out_channels, self.num_channels, kernel_size=1 + ) + elif self.dims == 3: + head = torch.nn.Conv3d( + g_architecture.num_out_channels, self.num_channels, kernel_size=1 + ) + else: + raise NotImplementedError( + f"Pix2PixPredictor not implemented for {self.dims} dimensions" + ) + + if is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc + try: + discriminator: NLayerDiscriminator = NLayerDiscriminator(d_architecture.input_nc, + d_architecture.ngf, + d_architecture.n_layers, + d_architecture.norm_layer) + except Exception as e: + return Model(g_architecture, head), Model(d_architecture, head) + + return Model(g_architecture, head) + + def 
create_target(self, gt): + return gt + + def create_weight(self, gt, target=None, mask=None): + if mask is None: + # array of ones + return NumpyArray.from_np_array( + np.ones(gt.data.shape), + gt.roi, + gt.voxel_size, + gt.axes, + ) + else: + return mask + + @property + def output_array_type(self): + return IntensitiesArray({"channels": {n: str(n) for n in range(self.num_channels)}}, min=0., max=1.) + + diff --git a/dacapo/experiments/tasks/predictors/__init__.py b/dacapo/experiments/tasks/predictors/__init__.py index 76f82138d..040a9b874 100644 --- a/dacapo/experiments/tasks/predictors/__init__.py +++ b/dacapo/experiments/tasks/predictors/__init__.py @@ -3,3 +3,4 @@ from .one_hot_predictor import OneHotPredictor # noqa from .predictor import Predictor # noqa from .affinities_predictor import AffinitiesPredictor # noqa +from .CARE_predictor import CAREPredictor # noqa diff --git a/dacapo/experiments/tasks/predictors/affinities_predictor.py b/dacapo/experiments/tasks/predictors/affinities_predictor.py index 40d81f5da..81efb2375 100644 --- a/dacapo/experiments/tasks/predictors/affinities_predictor.py +++ b/dacapo/experiments/tasks/predictors/affinities_predictor.py @@ -17,17 +17,9 @@ class AffinitiesPredictor(Predictor): - def __init__( - self, - neighborhood: List[Coordinate], - lsds: bool = True, - num_voxels: int = 20, - downsample_lsds: int = 1, - grow_boundary_iterations: int = 0, - ): + def __init__(self, neighborhood: List[Coordinate], lsds: bool = True): self.neighborhood = neighborhood self.lsds = lsds - self.num_voxels = num_voxels if lsds: self._extractor = None if self.dims == 2: @@ -38,16 +30,12 @@ def __init__( raise ValueError( f"Cannot compute lsds on volumes with {self.dims} dimensions" ) - self.downsample_lsds = downsample_lsds else: self.num_lsds = 0 - self.grow_boundary_iterations = grow_boundary_iterations def extractor(self, voxel_size): if self._extractor is None: - self._extractor = LsdExtractor( - self.sigma(voxel_size), downsample=self.downsample_lsds - ) + self._extractor = LsdExtractor(self.sigma(voxel_size)) return self._extractor @@ -57,7 +45,8 @@ def dims(self): def sigma(self, voxel_size): voxel_dist = max(voxel_size) # arbitrarily chosen - sigma = voxel_dist * self.num_voxels # arbitrarily chosen + num_voxels = 10 # arbitrarily chosen + sigma = voxel_dist * num_voxels return Coordinate((sigma,) * self.dims) def lsd_pad(self, voxel_size): @@ -129,9 +118,7 @@ def _grow_boundaries(self, mask, slab): slice(start[d], start[d] + slab[d]) for d in range(len(slab)) ) mask_slab = mask[slices] - dilated_mask_slab = ndimage.binary_dilation( - mask_slab, iterations=self.grow_boundary_iterations - ) + dilated_mask_slab = ndimage.binary_dilation(mask_slab, iterations=1) foreground[slices] = dilated_mask_slab # label new background @@ -143,12 +130,10 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): (moving_class_counts, moving_lsd_class_counts) = ( moving_class_counts if moving_class_counts is not None else (None, None) ) - if self.grow_boundary_iterations > 0: - mask_data = self._grow_boundaries( - mask[target.roi], slab=tuple(1 if c == "c" else -1 for c in target.axes) - ) - else: - mask_data = mask[target.roi] + # mask_data = self._grow_boundaries( + # mask[target.roi], slab=tuple(1 if c == "c" else -1 for c in target.axes) + # ) + mask_data = mask[target.roi] aff_weights, moving_class_counts = balance_weights( target[target.roi][: self.num_channels - self.num_lsds].astype(np.uint8), 2, diff --git 
a/dacapo/experiments/tasks/predictors/distance_predictor.py b/dacapo/experiments/tasks/predictors/distance_predictor.py index 70c2bde4a..a8fa44492 100644 --- a/dacapo/experiments/tasks/predictors/distance_predictor.py +++ b/dacapo/experiments/tasks/predictors/distance_predictor.py @@ -27,7 +27,7 @@ class DistancePredictor(Predictor): in the channels argument. """ - def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool): + def __init__(self, channels: List[str], scale_factor: float, mask_distances=bool): self.channels = channels self.norm = "tanh" self.dt_scale_factor = scale_factor diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index f5d8fcd52..75f06bbd1 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -42,11 +42,6 @@ def __init__(self, trainer_config): self.mask_integral_downsample_factor = 4 self.clip_raw = trainer_config.clip_raw - # Testing out if calculating multiple times and multiplying is necessary - self.add_predictor_nodes_to_dataset = ( - trainer_config.add_predictor_nodes_to_dataset - ) - self.scheduler = None def create_optimizer(self, model): @@ -137,12 +132,12 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): + gp.Pad(gt_key, None, 0) + gp.Pad(mask_key, None, 0) + gp.RandomLocation( - ensure_nonempty=sample_points_key - if points_source is not None - else None, - ensure_centered=sample_points_key - if points_source is not None - else None, + ensure_nonempty=( + sample_points_key if points_source is not None else None + ), + ensure_centered=( + sample_points_key if points_source is not None else None + ), ) ) @@ -151,14 +146,13 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): for augment in self.augments: dataset_source += augment.node(raw_key, gt_key, mask_key) - if self.add_predictor_nodes_to_dataset: - # Add predictor nodes to dataset_source - dataset_source += DaCapoTargetFilter( - task.predictor, - gt_key=gt_key, - weights_key=dataset_weight_key, - mask_key=mask_key, - ) + # Add predictor nodes to dataset_source + dataset_source += DaCapoTargetFilter( + task.predictor, + gt_key=gt_key, + weights_key=dataset_weight_key, + mask_key=mask_key, + ) dataset_sources.append(dataset_source) pipeline = tuple(dataset_sources) + gp.RandomProvider(weights) @@ -168,14 +162,11 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): task.predictor, gt_key=gt_key, target_key=target_key, - weights_key=datasets_weight_key - if self.add_predictor_nodes_to_dataset - else weight_key, + weights_key=datasets_weight_key, mask_key=mask_key, ) - if self.add_predictor_nodes_to_dataset: - pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key) + pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key) # Trainer attributes: if self.num_data_fetchers > 1: @@ -332,9 +323,11 @@ def next(self): NumpyArray.from_gp_array(batch[self._gt_key]), NumpyArray.from_gp_array(batch[self._target_key]), NumpyArray.from_gp_array(batch[self._weight_key]), - NumpyArray.from_gp_array(batch[self._mask_key]) - if self._mask_key is not None - else None, + ( + NumpyArray.from_gp_array(batch[self._mask_key]) + if self._mask_key is not None + else None + ), ) def __enter__(self): diff --git a/dacapo/experiments/trainers/gunpowder_trainer_config.py b/dacapo/experiments/trainers/gunpowder_trainer_config.py index 539e3c5e1..ae4243059 100644 
--- a/dacapo/experiments/trainers/gunpowder_trainer_config.py +++ b/dacapo/experiments/trainers/gunpowder_trainer_config.py @@ -29,10 +29,3 @@ class GunpowderTrainerConfig(TrainerConfig): ) min_masked: Optional[float] = attr.ib(default=0.15) clip_raw: bool = attr.ib(default=True) - - add_predictor_nodes_to_dataset: Optional[bool] = attr.ib( - default=True, - metadata={ - "help_text": "Whether to add a predictor node to dataset_source and apply product of weights" - }, - ) diff --git a/dacapo/experiments/training_stats.py b/dacapo/experiments/training_stats.py index 72c631ed4..cd3fcd012 100644 --- a/dacapo/experiments/training_stats.py +++ b/dacapo/experiments/training_stats.py @@ -16,9 +16,7 @@ class TrainingStats: def add_iteration_stats(self, iteration_stats: TrainingIterationStats) -> None: if len(self.iteration_stats) > 0: - assert ( - iteration_stats.iteration == self.iteration_stats[-1].iteration + 1 - ), f"Expected iteration {self.iteration_stats[-1].iteration + 1}, got {iteration_stats.iteration}" + assert iteration_stats.iteration == self.iteration_stats[-1].iteration + 1 self.iteration_stats.append(iteration_stats) diff --git a/dacapo/plot.py b/dacapo/plot.py index c1e02ec95..fcd1c6ee2 100644 --- a/dacapo/plot.py +++ b/dacapo/plot.py @@ -74,9 +74,11 @@ def get_runs_info( run_config.architecture_config.name, run_config.trainer_config.name, run_config.datasplit_config.name, - stats_store.retrieve_training_stats(run_config_name, subsample=True) - if plot_loss - else None, + ( + stats_store.retrieve_training_stats(run_config_name, subsample=True) + if plot_loss + else None + ), validation_scores, validation_score_name, plot_loss, diff --git a/dacapo/predict.py b/dacapo/predict.py index 1df4d779e..5a40e303c 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -24,8 +24,6 @@ def predict( num_cpu_workers: int = 4, compute_context: ComputeContext = LocalTorch(), output_roi: Optional[Roi] = None, - output_dtype: np.dtype = np.float32, # type: ignore - overwrite: bool = False, ): # get the model's input and output size @@ -58,7 +56,7 @@ def predict( output_roi, model.num_out_channels, output_voxel_size, - output_dtype, + np.float32, ) # create gunpowder keys @@ -77,8 +75,8 @@ def predict( # raw: (1, c, d, h, w) gt_padding = (output_size - output_roi.shape) % output_size - prediction_roi = output_roi.grow(gt_padding) # TODO: are we sure this makes sense? - # TODO: Add cache node? 
diff --git a/dacapo/predict.py b/dacapo/predict.py
index 1df4d779e..5a40e303c 100644
--- a/dacapo/predict.py
+++ b/dacapo/predict.py
@@ -24,8 +24,6 @@ def predict(
     num_cpu_workers: int = 4,
     compute_context: ComputeContext = LocalTorch(),
     output_roi: Optional[Roi] = None,
-    output_dtype: np.dtype = np.float32,  # type: ignore
-    overwrite: bool = False,
 ):
     # get the model's input and output size

@@ -58,7 +56,7 @@ def predict(
         output_roi,
         model.num_out_channels,
         output_voxel_size,
-        output_dtype,
+        np.float32,
     )

     # create gunpowder keys
@@ -77,8 +75,8 @@ def predict(
     # raw: (1, c, d, h, w)

     gt_padding = (output_size - output_roi.shape) % output_size
-    prediction_roi = output_roi.grow(gt_padding)  # TODO: are we sure this makes sense?
-    # TODO: Add cache node?
+    prediction_roi = output_roi.grow(gt_padding)
+
     # predict
     pipeline += gp_torch.Predict(
         model=model,
@@ -86,9 +84,7 @@ def predict(
         outputs={0: prediction},
         array_specs={
             prediction: gp.ArraySpec(
-                roi=prediction_roi,
-                voxel_size=output_voxel_size,
-                dtype=np.float32,  # assumes network output is float32
+                roi=prediction_roi, voxel_size=output_voxel_size, dtype=np.float32
             )
         },
         spawn_subprocess=False,
@@ -101,29 +97,22 @@ def predict(
         pipeline += gp.Squeeze([raw, prediction])
         # raw: (c, d, h, w)
         # prediction: (c, d, h, w)
-
-    # convert to uint8 if necessary:
-    if output_dtype == np.uint8:
-        pipeline += gp.IntensityScaleShift(
-            prediction, scale=255.0, shift=0.0
-        )  # assumes float32 is [0,1]
-        pipeline += gp.AsType(prediction, output_dtype)
+        # raw: (c, d, h, w)
+        # prediction: (c, d, h, w)

     # write to zarr
     pipeline += gp.ZarrWrite(
         {prediction: prediction_array_identifier.dataset},
         prediction_array_identifier.container.parent,
         prediction_array_identifier.container.name,
-        dataset_dtypes={prediction: output_dtype},
+        dataset_dtypes={prediction: np.float32},
     )

     # create reference batch request
     ref_request = gp.BatchRequest()
     ref_request.add(raw, input_size)
     ref_request.add(prediction, output_size)
-    pipeline += gp.Scan(
-        ref_request
-    )  # TODO: This is a slow implementation for rendering
+    pipeline += gp.Scan(ref_request)

     # build pipeline and predict in complete output ROI
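With output_dtype and overwrite removed, predict() always writes float32 predictions. A minimal call sketch based on the keyword arguments that remain in this diff; the run, dataset, and array-identifier objects are placeholders supplied by the caller, and the import paths are assumed to follow dacapo's module layout:

from dacapo.predict import predict
from dacapo.compute_context import LocalTorch


def predict_validation(run, validation_dataset, prediction_array_identifier):
    # sketch: the prediction is written as float32 to the given array identifier
    predict(
        run.model,
        validation_dataset.raw,
        prediction_array_identifier,
        compute_context=LocalTorch(),
        output_roi=validation_dataset.gt.roi,
    )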
diff --git a/dacapo/train.py b/dacapo/train.py
index 7beb096b4..9203c1be3 100644
--- a/dacapo/train.py
+++ b/dacapo/train.py
@@ -16,7 +16,6 @@ def train(run_name: str, compute_context: ComputeContext = LocalTorch()):
     """Train a run"""

     if compute_context.train(run_name):
-        logger.error("Run %s is already being trained", run_name)
         # if compute context runs train in some other process
         # we are done here.
         return
@@ -97,10 +96,9 @@ def train_run(
         weights_store.retrieve_weights(run, iteration=trained_until)

     elif latest_weights_iteration > trained_until:
-        weights_store.retrieve_weights(run, iteration=latest_weights_iteration)
-        logger.error(
+        raise RuntimeError(
             f"Found weights for iteration {latest_weights_iteration}, but "
-            f"run {run.name} was only trained until {trained_until}. "
+            f"run {run.name} was only trained until {trained_until}."
         )

     # start/resume training
@@ -129,20 +127,18 @@ def train_run(
         # train for at most 100 iterations at a time, then store training stats
         iterations = min(100, run.train_until - trained_until)
         iteration_stats = None
-        bar = tqdm(
+
+        for iteration_stats in tqdm(
             trainer.iterate(
                 iterations,
                 run.model,
                 run.optimizer,
                 compute_context.device,
             ),
-            desc=f"training until {iterations + trained_until}",
-            total=run.train_until,
-            initial=trained_until,
-        )
-        for iteration_stats in bar:
+            "training",
+            iterations,
+        ):
             run.training_stats.add_iteration_stats(iteration_stats)
-            bar.set_postfix({"loss": iteration_stats.loss})

             if (iteration_stats.iteration + 1) % run.validation_interval == 0:
                 break
@@ -164,26 +160,22 @@ def train_run(
         run.model = run.model.to(torch.device("cpu"))
         run.move_optimizer(torch.device("cpu"), empty_cuda_cache=True)

-        stats_store.store_training_stats(run.name, run.training_stats)
         weights_store.store_weights(run, iteration_stats.iteration + 1)
-        try:
-            validate_run(
-                run,
-                iteration_stats.iteration + 1,
-                compute_context=compute_context,
-            )
-            stats_store.store_validation_iteration_scores(
-                run.name, run.validation_scores
-            )
-        except Exception as e:
-            logger.error(
-                f"Validation failed for run {run.name} at iteration "
-                f"{iteration_stats.iteration + 1}.",
-                exc_info=e,
-            )
+        validate_run(
+            run,
+            iteration_stats.iteration + 1,
+            compute_context=compute_context,
+        )
+        stats_store.store_validation_iteration_scores(
+            run.name, run.validation_scores
+        )
+        stats_store.store_training_stats(run.name, run.training_stats)

         # make sure to move optimizer back to the correct device
         run.move_optimizer(compute_context.device)
         run.model.train()
+
+    weights_store.store_weights(run, run.training_stats.trained_until())
+    stats_store.store_training_stats(run.name, run.training_stats)
+    logger.info("Trained until %d, finished.", trained_until)
diff --git a/dacapo/validate.py b/dacapo/validate.py
index a1cf9da7d..25b7463e1 100644
--- a/dacapo/validate.py
+++ b/dacapo/validate.py
@@ -141,7 +141,6 @@ def validate_run(
         prediction_array_identifier = array_store.validation_prediction_array(
             run.name, iteration, validation_dataset
         )
-        logger.info("Predicting on dataset %s", validation_dataset.name)
         predict(
             run.model,
             validation_dataset.raw,
@@ -149,7 +148,6 @@ def validate_run(
             prediction_array_identifier,
             compute_context=compute_context,
             output_roi=validation_dataset.gt.roi,
         )
-        logger.info("Predicted on dataset %s", validation_dataset.name)

         post_processor.set_prediction(prediction_array_identifier)
diff --git a/docs/source/_static/icon_dacapo.png b/docs/source/_static/icon_dacapo.png
deleted file mode 100644
index f04fc9315..000000000
Binary files a/docs/source/_static/icon_dacapo.png and /dev/null differ
diff --git a/mypy.ini b/mypy.ini
index aadc732e4..722c11df8 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -68,7 +68,4 @@ ignore_missing_imports = True
 ignore_missing_imports = True

 [mypy-mwatershed.*]
-ignore_missing_imports = True
-
-[mypy-numpy_indexed.*]
-ignore_missing_imports = True
+ignore_missing_imports = True
\ No newline at end of file
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 12afa83a4..492c8e6f4 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,5 +1,5 @@
 black
 mypy
-pytest==7.4.4
+pytest
 pytest-cov
 pytest-lazy-fixture
\ No newline at end of file
diff --git a/requirments.txt b/requirments.txt
new file mode 100644
index 000000000..73c085cf0
--- /dev/null
+++ b/requirments.txt
@@ -0,0 +1,37 @@
+bokeh
+configargparse
+pymongo
+tqdm>=4.63.2
+zarr>=2.11.2
+click>=8.1.2
+flask>=2.1.1
+flask_wtf
+flask_login
+wtforms
+attr
+Flask-SocketIO==4.3.1
+python-engineio==3.13.2
+python-socketio==4.6.0
+Werkzeug==2.0
+
+# dacapo requirements.txt
+numpy >= 1.22.2
+pyyaml >= 6.0
+zarr >= 2.10.3
+cattrs >= 1.8.0
+pymongo >= 3.12.0
+tqdm >= 4.62.3
+simpleitk >= 2.1.1
+lazy-property >= 0.0.1
+neuroglancer >= 2.22
+torch >= 1.9.1
+fibsem_tools >= 0.2.11
+
+git+https://github.com/funkelab/daisy
+git+https://github.com/funkelab/funlib.math@0c623f71c083d33184cac40ef7b1b995216be8ef
+git+https://github.com/pattonw/funlib.evaluate
+git+https://github.com/funkelab/funlib.geometry@cf30e4d74eb860e46de40533c4f8278dc25147b1
+git+https://github.com/cremi/cremi_python@python3
+git+https://github.com/funkey/gunpowder@v1.3-dev
+
+git+https://github.com/funkelab/lsd@publish
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 3e6f51064..f111ed47d 100644
--- a/setup.py
+++ b/setup.py
@@ -5,16 +5,16 @@
     description="Framework for easy composition of volumetric machine learning jobs.",
     long_description=open("README.md", "r").read(),
     long_description_content_type="text/markdown",
-    version="0.2.0",
+    version="0.1",
     url="https://github.com/janelia-cellmap/dacapo",
-    author="Jan Funke, Will Patton, Jeff Rhoades, Marwan Zouinkhi",
-    author_email="funkej@janelia.hhmi.org, pattonw@janelia.hhmi.org, rhoadesj@hhmi.org, zouinkhim@hhmi.org",
+    author="Jan Funke, Will Patton",
+    author_email="funkej@janelia.hhmi.org, pattonw@janelia.hhmi.org",
     license="MIT",
     packages=find_packages(),
     entry_points={"console_scripts": ["dacapo=dacapo.cli:cli"]},
     include_package_data=True,
     install_requires=[
-        "numpy==1.22.3",
+        "numpy",
         "pyyaml",
         "zarr",
         "cattrs",
@@ -32,14 +32,9 @@
         "funlib.math>=0.1",
         "funlib.geometry>=0.2",
         "mwatershed>=0.1",
-        "funlib.persistence @ git+https://github.com/janelia-cellmap/funlib.persistence",
+        "funlib.persistence>=0.1",
         "funlib.evaluate @ git+https://github.com/pattonw/funlib.evaluate",
         "gunpowder>=1.3",
-        # "lsds>=0.1.3",
-        "lsds @ git+https://github.com/funkelab/lsd",
-        "xarray",
-        "cattrs",
-        "numpy-indexed",
-        "click",
+        "lsds>=0.1.3",
     ],
 )