Incorporate CARE, Pix2Pix, and CycleGAN #55

Closed
wants to merge 16 commits into from
155 changes: 155 additions & 0 deletions care_train.py
@@ -0,0 +1,155 @@
import dacapo
import logging
import math
import torch
from torchsummary import summary

# CARE task specific elements
from dacapo.experiments.datasplits.datasets.arrays import ZarrArrayConfig, IntensitiesArrayConfig
from dacapo.experiments.datasplits.datasets import RawGTDatasetConfig
from dacapo.experiments.datasplits import TrainValidateDataSplitConfig
from dacapo.experiments.architectures import CNNectomeUNetConfig
from dacapo.experiments.tasks import CARETaskConfig

from dacapo.experiments.trainers import GunpowderTrainerConfig
from dacapo.experiments.trainers.gp_augments import (
SimpleAugmentConfig,
ElasticAugmentConfig,
IntensityAugmentConfig,
)
from funlib.geometry import Coordinate
from dacapo.experiments.run_config import RunConfig
from dacapo.experiments.run import Run
from dacapo.store.create_store import create_config_store
from dacapo.train import train


# set basic logging config
logging.basicConfig(level=logging.INFO)

raw_array_config_zarr = ZarrArrayConfig(
name="raw",
file_name="/n/groups/htem/users/br128/data/CBvBottom/CBxs_lobV_bottomp100um_training_0.n5",
dataset="volumes/raw_30nm",
)

gt_array_config_zarr = ZarrArrayConfig(
name="gt",
file_name="/n/groups/htem/users/br128/data/CBvBottom/CBxs_lobV_bottomp100um_training_0.n5",
dataset="volumes/interpolated_90nm_aligned",
)

raw_array_config_int = IntensitiesArrayConfig(
    name="raw_norm",
    source_array_config=raw_array_config_zarr,
    min=0.0,
    max=1.0,
)

gt_array_config_int = IntensitiesArrayConfig(
    name="gt_norm",
    source_array_config=gt_array_config_zarr,
    min=0.0,
    max=1.0,
)

dataset_config = RawGTDatasetConfig(
name="CBxs_lobV_bottomp100um_CARE_0",
raw_config=raw_array_config_int,
gt_config=gt_array_config_int,
)

# TODO: check datasplit config, this honestly might work
datasplit_config = TrainValidateDataSplitConfig(
name="CBxs_lobV_bottomp100um_training_0.n5",
train_configs=[dataset_config],
validate_configs=[dataset_config],
)
"""
kernel size 3
2 conv passes per block

1 -- 100%, lose 4 pix - 286 pix
2 -- 50%, lose 8 pix - 142 pix
3 -- 25%, lose 16 pix - 32 pix
"""
# UNET config
architecture_config = CNNectomeUNetConfig(
name="small_unet",
input_shape=Coordinate(156, 156, 156),
# eval_shape_increase=Coordinate(72, 72, 72),
fmaps_in=1,
num_fmaps=8,
fmaps_out=32,
fmap_inc_factor=4,
downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)],
constant_upsample=True,
)


# CARE task
task_config = CARETaskConfig(name="CAREModel", num_channels=1, dims=3)


# trainer
trainer_config = GunpowderTrainerConfig(
name="gunpowder",
batch_size=2,
learning_rate=0.0001,
augments=[
SimpleAugmentConfig(),
ElasticAugmentConfig(
control_point_spacing=(100, 100, 100),
control_point_displacement_sigma=(10.0, 10.0, 10.0),
rotation_interval=(0, math.pi / 2.0),
subsample=8,
uniform_3d_rotation=True,
),
IntensityAugmentConfig(
scale=(0.25, 1.75),
shift=(-0.5, 0.35),
clip=False,
),
],
num_data_fetchers=20,
snapshot_interval=10000,
min_masked=0.15,
)


# run config
run_config = RunConfig(
name="CARE_train",
task_config=task_config,
architecture_config=architecture_config,
trainer_config=trainer_config,
datasplit_config=datasplit_config,
repetition=0,
num_iterations=100000,
validation_interval=1000,
)

run = Run(run_config)

# print a model summary (TODO: create issue)
summary(run.model, (1, 156, 156, 156))


# store configs, then train
config_store = create_config_store()

config_store.store_datasplit_config(datasplit_config)
config_store.store_architecture_config(architecture_config)
config_store.store_task_config(task_config)
config_store.store_trainer_config(trainer_config)
config_store.store_run_config(run_config)

# Optionally, start training directly by config name:
train(run_config.name)

# or via the CLI: dacapo train -r {run_config.name}


"""
RuntimeError: Can not downsample shape torch.Size([1, 128, 47, 47, 47]) with factor (2, 2, 2), mismatch in spatial dimension 2
"""
57 changes: 57 additions & 0 deletions dacapo/experiments/architectures/nlayer_discriminator.py
@@ -0,0 +1,57 @@
from .architecture import Architecture

import torch
import torch.nn as nn
import functools

class NLayerDiscriminator(Architecture):
"""Defines a PatchGAN discriminator"""

def __init__(self, architecture_config):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ngf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super().__init__()

input_nc: int = architecture_config.input_nc
ngf: int = architecture_config.ngf
n_layers: int = architecture_config.n_layers
norm_layer = architecture_config.norm_layer

if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        # `ngf` (read from the architecture config above) sets the filter count of the first conv layer
        sequence = [nn.Conv2d(input_nc, ngf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ngf * nf_mult_prev, ngf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ngf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ngf * nf_mult_prev, ngf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ngf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        sequence += [nn.Conv2d(ngf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)
self.model = nn.Sequential(*sequence)

def forward(self, input):
"""Standard forward."""
return self.model(input)
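A minimal standalone shape check, not part of the PR: the same PatchGAN layer stack rebuilt directly with torch.nn (assuming ngf=64, n_layers=3, and BatchNorm2d, which are common pix2pix defaults) to illustrate that the discriminator emits a one-channel patch map rather than a single scalar score.

import torch
import torch.nn as nn

ngf, n_layers, kw, padw = 64, 3, 4, 1
layers = [nn.Conv2d(1, ngf, kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
nf_mult = 1
for n in range(1, n_layers):  # three stride-2 convs in total
    nf_prev, nf_mult = nf_mult, min(2 ** n, 8)
    layers += [
        nn.Conv2d(ngf * nf_prev, ngf * nf_mult, kw, stride=2, padding=padw),
        nn.BatchNorm2d(ngf * nf_mult),
        nn.LeakyReLU(0.2, True),
    ]
nf_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
layers += [
    nn.Conv2d(ngf * nf_prev, ngf * nf_mult, kw, stride=1, padding=padw),
    nn.BatchNorm2d(ngf * nf_mult),
    nn.LeakyReLU(0.2, True),
    nn.Conv2d(ngf * nf_mult, 1, kw, stride=1, padding=padw),  # 1-channel patch map
]
patch_disc = nn.Sequential(*layers)

with torch.no_grad():
    out = patch_disc(torch.randn(1, 1, 256, 256))
print(out.shape)  # torch.Size([1, 1, 30, 30]) -- one score per ~70x70 input patch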
@@ -19,3 +19,4 @@ class IntensitiesArrayConfig(ArrayConfig):

min: float = attr.ib(metadata={"help_text": "The minimum intensity in your data"})
max: float = attr.ib(metadata={"help_text": "The maximum intensity in your data"})

15 changes: 15 additions & 0 deletions dacapo/experiments/tasks/CARE_task.py
@@ -0,0 +1,15 @@
from .evaluators import IntensitiesEvaluator
from .losses import MSELoss
from .post_processors import CAREPostProcessor
from .predictors import CAREPredictor
from .task import Task

class CARETask(Task):
"""CAREPredictor."""

def __init__(self, task_config) -> None:
"""Create a `CARETask`."""
self.predictor = CAREPredictor(num_channels=task_config.num_channels, dims=task_config.dims)
self.loss = MSELoss()
self.post_processor = CAREPostProcessor()
self.evaluator = IntensitiesEvaluator()
27 changes: 27 additions & 0 deletions dacapo/experiments/tasks/CARE_task_config.py
@@ -0,0 +1,27 @@
import attr

from .CARE_task import CARETask
from .task_config import TaskConfig


@attr.s
class CARETaskConfig(TaskConfig):
"""This is a CARE task config used for generating and
evaluating voxel affinities for instance segmentations.
"""

task_type = CARETask
num_channels: int = attr.ib(
default=2,
metadata={
"help_text": "Number of output channels for the image. "
"Number of ouptut channels should match the number of channels in the ground truth."
})

dims: int = attr.ib(
default=2,
metadata={
"help_text": "Number of UNet dimensions. "
"Number of dimensions should match the number of channels in the ground truth."
}
)
17 changes: 17 additions & 0 deletions dacapo/experiments/tasks/CycleGAN_task.py
@@ -0,0 +1,17 @@
from .evaluators import IntensitiesEvaluator
from .losses import GANLoss
from .post_processors import CycleGANPostProcessor
from .predictors import CycleGANPredictor
from .task import Task


class CycleGANTask(Task):
"""CycleGAN Task."""

def __init__(self, task_config) -> None:
"""Create a `CycleGAN Task`."""

self.predictor = CycleGANPredictor(num_channels=task_config.num_channels)
self.loss = GANLoss()
self.post_processor = CycleGANPostProcessor()
self.evaluator = IntensitiesEvaluator()
21 changes: 21 additions & 0 deletions dacapo/experiments/tasks/CycleGAN_task_config.py
@@ -0,0 +1,21 @@
import attr

from .CycleGAN_task import CycleGANTask
from .task_config import TaskConfig


@attr.s
class CycleGANTaskConfig(TaskConfig):
"""This is a Affinities task config used for generating and
evaluating voxel affinities for instance segmentations.
"""

task_type = CycleGANTask

num_channels: int = attr.ib(
default=1,
metadata={
"help_text": "Number of output channels for the image. "
"Number of ouptut channels should match the number of channels in the ground truth."
}
)
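A hedged usage sketch, not part of the PR: instantiating this config mirrors the CARETaskConfig call in care_train.py; the name below is a placeholder, and the resulting task_config would be passed to a RunConfig and stored exactly as in that script.

from dacapo.experiments.tasks import CycleGANTaskConfig

task_config = CycleGANTaskConfig(name="CycleGANModel", num_channels=1)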
15 changes: 15 additions & 0 deletions dacapo/experiments/tasks/Pix2Pix_task.py
@@ -0,0 +1,15 @@
from .evaluators import IntensitiesEvaluator
from .losses import MSELoss
from .post_processors import CAREPostProcessor
from .predictors import CAREPredictor
from .task import Task

class Pix2PixTask(Task):
"""Pix2Pix Predictor."""

def __init__(self, task_config) -> None:
"""Create a `Pix2PixTask`."""
        self.predictor = CAREPredictor(num_channels=task_config.num_channels, dims=task_config.dims)  # TODO: swap in a Pix2Pix-specific predictor
self.loss = MSELoss() # TODO: change losses
self.post_processor = CAREPostProcessor() # TODO: change post processor
self.evaluator = IntensitiesEvaluator()
27 changes: 27 additions & 0 deletions dacapo/experiments/tasks/Pix2Pix_task_config.py
@@ -0,0 +1,27 @@
import attr

from .Pix2Pix_task import Pix2PixTask
from .task_config import TaskConfig


@attr.s
class Pix2PixTaskConfig(TaskConfig):
"""This is a Pix2Pix task config used for generating and
evaluating voxel affinities for instance segmentations.
"""

task_type = Pix2PixTask
num_channels: int = attr.ib(
default=2,
metadata={
"help_text": "Number of output channels for the image. "
"Number of ouptut channels should match the number of channels in the ground truth."
})

dims: int = attr.ib(
default=2,
metadata={
"help_text": "Number of UNet dimensions. "
"Number of dimensions should match the number of channels in the ground truth."
}
)
2 changes: 2 additions & 0 deletions dacapo/experiments/tasks/__init__.py
@@ -5,3 +5,5 @@
from .one_hot_task_config import OneHotTaskConfig, OneHotTask # noqa
from .pretrained_task_config import PretrainedTaskConfig, PretrainedTask # noqa
from .affinities_task_config import AffinitiesTaskConfig, AffinitiesTask # noqa
from .CARE_task_config import CARETaskConfig, CARETask # noqa
from .CycleGAN_task_config import CycleGANTaskConfig, CycleGANTask # noqa
6 changes: 6 additions & 0 deletions dacapo/experiments/tasks/arraytypes/__init__.py
@@ -0,0 +1,6 @@
from .annotations import AnnotationArray
from .intensities import IntensitiesArray
from .distances import DistanceArray
from .mask import Mask
from .embedding import EmbeddingArray
from .probabilities import ProbabilityArray
23 changes: 23 additions & 0 deletions dacapo/experiments/tasks/arraytypes/annotations.py
@@ -0,0 +1,23 @@
from .arraytype import ArrayType

import attr
from typing import Dict


@attr.s
class AnnotationArray(ArrayType):
"""
An AnnotationArray is a uint8, uint16, uint32 or uint64 Array where each
voxel has a value associated with its class.
"""

classes: Dict[int, str] = attr.ib(
metadata={
"help_text": "A mapping from class label to class name. "
"For example {1:'mitochondria', 2:'membrane'} etc."
}
)

@property
def interpolatable(self):
return False