Mzouink #18 (Closed)
wants to merge 14 commits
@@ -5,6 +5,9 @@
 import numpy as np

 from typing import Dict, Any
+import logging
+
+logger = logging.getLogger(__file__)


 class ConcatArray(Array):
@@ -116,5 +119,7 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
             axis=0,
         )
         if concatenated.shape[0] == 1:
-            raise Exception(f"{concatenated.shape}, shapes")
+            logger.info(
+                f"Concatenated array has only one channel: {self.name} {concatenated.shape}"
+            )
         return concatenated
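
Review note: `logging.getLogger(__file__)` keys the logger by file path rather than by module name. A minimal sketch of the more conventional spelling, in case the path-based name is unintended:

```python
import logging

# getLogger(__name__) follows the usual module-hierarchy convention, so any
# logging config for the "dacapo" namespace applies to this logger too;
# getLogger(__file__) instead names the logger after the source file's path.
logger = logging.getLogger(__name__)
```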
@@ -41,7 +41,7 @@ def attrs(self):

     @property
     def axes(self):
-        return ["t", "z", "y", "x"][-self.dims :]
+        return ["c", "z", "y", "x"][-self.dims :]

     @property
     def dims(self) -> int:
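
For reference, what the renamed default resolves to per dimensionality; this is plain Python, nothing dacapo-specific:

```python
# The negative slice keeps the trailing axis names, so "c" only appears
# for 4D arrays; 2D and 3D arrays stay purely spatial.
for dims in (2, 3, 4):
    print(dims, ["c", "z", "y", "x"][-dims:])
# 2 ['y', 'x']
# 3 ['z', 'y', 'x']
# 4 ['c', 'z', 'y', 'x']
```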
@@ -35,7 +35,7 @@ def from_gp_array(cls, array: gp.Array):
             ((["b", "c"] if len(array.data.shape) == instance.dims + 2 else []))
             + (["c"] if len(array.data.shape) == instance.dims + 1 else [])
             + [
-                "t",
+                "c",
                 "z",
                 "y",
                 "x",
@@ -54,7 +54,7 @@ def axes(self):
                 f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n"
                 f"Using default {['t', 'z', 'y', 'x'][-self.dims::]}",
             )
-            return ["t", "z", "y", "x"][-self.dims : :]
+            return ["c", "z", "y", "x"][-self.dims : :]

     @property
     def dims(self) -> int:
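
Review note: the fallback warning above still prints `['t', 'z', 'y', 'x']` as the default while the return value now starts with `"c"`; presumably the log message should be updated to match.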
2 changes: 1 addition & 1 deletion dacapo/experiments/model.py

@@ -24,7 +24,7 @@ def __init__(
         self,
         architecture: Architecture,
         prediction_head: torch.nn.Module,
-        eval_activation: torch.nn.Module = None,
+        eval_activation: torch.nn.Module | None = None,
     ):
         super().__init__()

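
Review note: the `torch.nn.Module | None` annotation is evaluated when the class is defined, so it raises a TypeError on Python 3.9 and older unless the module opts into postponed evaluation. A minimal sketch of the two compatible spellings (the class name here is illustrative, not dacapo's actual Model):

```python
from __future__ import annotations  # makes `X | None` safe on Python 3.7+

from typing import Optional

import torch


class HeadedModel(torch.nn.Module):
    # Optional[...] is the equivalent pre-3.10 spelling of `X | None`.
    def __init__(self, eval_activation: Optional[torch.nn.Module] = None):
        super().__init__()
        self.eval_activation = eval_activation
```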
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/distance_task.py

@@ -15,6 +15,7 @@ def __init__(self, task_config):
             channels=task_config.channels,
             scale_factor=task_config.scale_factor,
             mask_distances=task_config.mask_distances,
+            extra_conv=task_config.extra_conv,
         )
         self.loss = MSELoss()
         self.post_processor = ThresholdPostProcessor()
7 changes: 7 additions & 0 deletions dacapo/experiments/tasks/distance_task_config.py

@@ -46,3 +46,10 @@ class DistanceTaskConfig(TaskConfig):
             "is less than the distance to object boundary."
         },
     )
+
+    extra_conv: bool = attr.ib(
+        default=False,
+        metadata={
+            "help_text": "Whether or not to add an extra conv layer before the head"
+        },
+    )
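
For reviewers, a sketch of how the new flag is toggled; only fields visible in this diff are shown, and the import path, name, and values are assumptions:

```python
from dacapo.experiments.tasks import DistanceTaskConfig  # import path assumed

# The real config may require further fields; these come from this PR's diff.
task_config = DistanceTaskConfig(
    name="distances_extra_conv",  # hypothetical
    channels=["cell"],            # hypothetical
    scale_factor=50.0,            # hypothetical
    mask_distances=True,
    extra_conv=True,  # new: insert a 3x3(x3) conv before the 1x1 head
)
```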
56 changes: 47 additions & 9 deletions dacapo/experiments/tasks/predictors/distance_predictor.py

@@ -27,7 +27,13 @@ class DistancePredictor(Predictor):
     in the channels argument.
     """

-    def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool):
+    def __init__(
+        self,
+        channels: List[str],
+        scale_factor: float,
+        mask_distances: bool,
+        extra_conv: bool,
+    ):
         self.channels = channels
         self.norm = "tanh"
         self.dt_scale_factor = scale_factor
@@ -36,20 +42,52 @@ def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool):
         self.max_distance = 1 * scale_factor
         self.epsilon = 5e-2
         self.threshold = 0.8
+        self.extra_conv = extra_conv
+        self.extra_conv_dims = len(self.channels) * 2

     @property
     def embedding_dims(self):
         return len(self.channels)

     def create_model(self, architecture):
-        if architecture.dims == 2:
-            head = torch.nn.Conv2d(
-                architecture.num_out_channels, self.embedding_dims, kernel_size=1
-            )
-        elif architecture.dims == 3:
-            head = torch.nn.Conv3d(
-                architecture.num_out_channels, self.embedding_dims, kernel_size=1
-            )
+        if self.extra_conv:
+            if architecture.dims == 2:
+                head = torch.nn.Sequential(
+                    torch.nn.Conv2d(
+                        architecture.num_out_channels,
+                        self.extra_conv_dims,
+                        kernel_size=3,
+                        padding=1,
+                    ),
+                    torch.nn.Conv2d(
+                        self.extra_conv_dims,
+                        self.embedding_dims,
+                        kernel_size=1,
+                    ),
+                )
+            elif architecture.dims == 3:
+                head = torch.nn.Sequential(
+                    torch.nn.Conv3d(
+                        architecture.num_out_channels,
+                        self.extra_conv_dims,
+                        kernel_size=3,
+                        padding=1,
+                    ),
+                    torch.nn.Conv3d(
+                        self.extra_conv_dims,
+                        self.embedding_dims,
+                        kernel_size=1,
+                    ),
+                )
+        else:
+            if architecture.dims == 2:
+                head = torch.nn.Conv2d(
+                    architecture.num_out_channels, self.embedding_dims, kernel_size=1
+                )
+            elif architecture.dims == 3:
+                head = torch.nn.Conv3d(
+                    architecture.num_out_channels, self.embedding_dims, kernel_size=1
+                )

         return Model(architecture, head)

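
A standalone sanity check of the 3D branch; the channel counts here are made up, the real values come from `architecture.num_out_channels` and `len(self.channels)`:

```python
import torch

num_out_channels, embedding_dims = 12, 2  # hypothetical values
extra_conv_dims = embedding_dims * 2      # mirrors len(channels) * 2

head = torch.nn.Sequential(
    torch.nn.Conv3d(num_out_channels, extra_conv_dims, kernel_size=3, padding=1),
    torch.nn.Conv3d(extra_conv_dims, embedding_dims, kernel_size=1),
)

x = torch.randn(1, num_out_channels, 8, 8, 8)
assert head(x).shape == (1, embedding_dims, 8, 8, 8)  # padding=1 keeps size
```

Since `padding=1` keeps the spatial shape, the extra layer is drop-in compatible with the plain 1x1 head; whether that padding interacts with valid-convolution pipelines may be worth a look.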
47 changes: 34 additions & 13 deletions dacapo/experiments/trainers/gunpowder_trainer.py

@@ -42,10 +42,26 @@ def __init__(self, trainer_config):
         self.mask_integral_downsample_factor = 4
         self.clip_raw = trainer_config.clip_raw

+        # Testing out if calculating multiple times and multiplying is necessary
+        self.add_predictor_nodes_to_dataset = (
+            trainer_config.add_predictor_nodes_to_dataset
+        )
+        self.finetune_head_only = trainer_config.finetune_head_only
+
         self.scheduler = None

     def create_optimizer(self, model):
-        optimizer = torch.optim.RAdam(lr=self.learning_rate, params=model.parameters())
+        if self.finetune_head_only:
+            logger.warning("Finetuning head only")
+            parameters = []
+            for name, param in model.named_parameters():
+                if "prediction_head" in name:
+                    parameters.append(param)
+                else:
+                    param.requires_grad = False
+        else:
+            parameters = model.parameters()
+        optimizer = torch.optim.RAdam(lr=self.learning_rate, params=parameters)
         self.scheduler = torch.optim.lr_scheduler.LinearLR(
             optimizer,
             start_factor=0.01,
@@ -146,13 +162,14 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
             for augment in self.augments:
                 dataset_source += augment.node(raw_key, gt_key, mask_key)

-            # Add predictor nodes to dataset_source
-            dataset_source += DaCapoTargetFilter(
-                task.predictor,
-                gt_key=gt_key,
-                weights_key=dataset_weight_key,
-                mask_key=mask_key,
-            )
+            if self.add_predictor_nodes_to_dataset:
+                # Add predictor nodes to dataset_source
+                dataset_source += DaCapoTargetFilter(
+                    task.predictor,
+                    gt_key=gt_key,
+                    weights_key=dataset_weight_key,
+                    mask_key=mask_key,
+                )

             dataset_sources.append(dataset_source)
         pipeline = tuple(dataset_sources) + gp.RandomProvider(weights)
@@ -162,11 +179,14 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
             task.predictor,
             gt_key=gt_key,
             target_key=target_key,
-            weights_key=datasets_weight_key,
+            weights_key=datasets_weight_key
+            if self.add_predictor_nodes_to_dataset
+            else weight_key,
             mask_key=mask_key,
         )

-        pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key)
+        if self.add_predictor_nodes_to_dataset:
+            pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key)

         # Trainer attributes:
         if self.num_data_fetchers > 1:
@@ -208,15 +228,15 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
     def iterate(self, num_iterations, model, optimizer, device):
         t_start_fetch = time.time()

-        logger.info("Starting iteration!")
-
         for iteration in range(self.iteration, self.iteration + num_iterations):
             raw, gt, target, weight, mask = self.next()
             logger.debug(
                 f"Trainer fetch batch took {time.time() - t_start_fetch} seconds"
             )

-            for param in model.parameters():
+            for (
+                param
+            ) in model.parameters():  # TODO: get parameters from optimizer instead
                 param.grad = None

             t_start_prediction = time.time()
@@ -227,6 +247,7 @@ def iterate(self, num_iterations, model, optimizer, device):
                 torch.as_tensor(target[target.roi]).to(device).float(),
                 torch.as_tensor(weight[weight.roi]).to(device).float(),
            )
+
             loss.backward()
             optimizer.step()

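
A minimal reproduction of the new freeze logic, runnable outside dacapo; the `prediction_head` module name matches the `Model` constructor argument in this PR, everything else is made up:

```python
import torch

model = torch.nn.Sequential()
model.add_module("backbone", torch.nn.Linear(8, 8))
model.add_module("prediction_head", torch.nn.Linear(8, 2))

parameters = []
for name, param in model.named_parameters():
    if "prediction_head" in name:
        parameters.append(param)
    else:
        param.requires_grad = False  # frozen, and excluded from the optimizer

optimizer = torch.optim.RAdam(parameters, lr=1e-4)
assert len(parameters) == 2  # prediction_head weight + bias
```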
12 changes: 12 additions & 0 deletions dacapo/experiments/trainers/gunpowder_trainer_config.py

@@ -29,3 +29,15 @@ class GunpowderTrainerConfig(TrainerConfig):
     )
     min_masked: Optional[float] = attr.ib(default=0.15)
     clip_raw: bool = attr.ib(default=True)
+
+    add_predictor_nodes_to_dataset: Optional[bool] = attr.ib(
+        default=True,
+        metadata={
+            "help_text": "Whether to add a predictor node to dataset_source and apply product of weights"
+        },
+    )
+
+    finetune_head_only: Optional[bool] = attr.ib(
+        default=False,
+        metadata={"help_text": "Whether to fine-tune head only or all layers"},
+    )
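
Sketch of a trainer config exercising both new flags; constructor fields other than the two added here are assumptions:

```python
from dacapo.experiments.trainers import GunpowderTrainerConfig  # path assumed

trainer_config = GunpowderTrainerConfig(
    name="finetune_head_only",            # hypothetical
    learning_rate=1e-5,                   # hypothetical
    add_predictor_nodes_to_dataset=True,  # keep per-dataset weight Product
    finetune_head_only=True,              # freeze all but the prediction head
)
```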
10 changes: 8 additions & 2 deletions dacapo/train.py

@@ -16,6 +16,7 @@ def train(run_name: str, compute_context: ComputeContext = LocalTorch()):
    """Train a run"""

    if compute_context.train(run_name):
+        logger.error("Run %s is already being trained", run_name)
        # if compute context runs train in some other process
        # we are done here.
        return
@@ -96,10 +97,15 @@ def train_run(
        weights_store.retrieve_weights(run, iteration=trained_until)

    elif latest_weights_iteration > trained_until:
-        raise RuntimeError(
+        weights_store.retrieve_weights(run, iteration=latest_weights_iteration)
+        logger.error(
            f"Found weights for iteration {latest_weights_iteration}, but "
            f"run {run.name} was only trained until {trained_until}."
        )
+        # raise RuntimeError(
+        #     f"Found weights for iteration {latest_weights_iteration}, but "
+        #     f"run {run.name} was only trained until {trained_until}."
+        # )

    # start/resume training

@@ -157,7 +163,7 @@

        run.model.eval()
        # free up optimizer memory to allow larger validation blocks
-        run.model = run.model.to(torch.device("cpu"))
+        # run.model = run.model.to(torch.device("cpu"))
        run.move_optimizer(torch.device("cpu"), empty_cuda_cache=True)

        weights_store.store_weights(run, iteration_stats.iteration + 1)
2 changes: 2 additions & 0 deletions dacapo/validate.py

@@ -141,13 +141,15 @@ def validate_run(
        prediction_array_identifier = array_store.validation_prediction_array(
            run.name, iteration, validation_dataset
        )
+        logger.info("Predicting on dataset %s", validation_dataset.name)
        predict(
            run.model,
            validation_dataset.raw,
            prediction_array_identifier,
            compute_context=compute_context,
            output_roi=validation_dataset.gt.roi,
        )
+        logger.info("Predicted on dataset %s", validation_dataset.name)

        post_processor.set_prediction(prediction_array_identifier)

23 changes: 12 additions & 11 deletions docs/source/conf.py

@@ -12,43 +12,44 @@
 #
 import os
 import sys
-sys.path.insert(0, os.path.abspath('../..'))
+
+sys.path.insert(0, os.path.abspath("../.."))


 # -- Project information -----------------------------------------------------

-project = 'DaCapo'
-copyright = '2022, William Patton, David Ackerman, Jan Funke'
-author = 'William Patton, David Ackerman, Jan Funke'
+project = "DaCapo"
+copyright = "2022, William Patton, David Ackerman, Jan Funke"
+author = "William Patton, David Ackerman, Jan Funke"


 # -- General configuration ---------------------------------------------------

 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx_autodoc_typehints']
+extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx_autodoc_typehints"]

 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]

 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]


 # -- Options for HTML output -------------------------------------------------

 # The theme to use for HTML and HTML Help pages. See the documentation for
 # a list of builtin themes.
 #
-html_theme = 'sphinx_material'
+html_theme = "sphinx_material"

 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
 html_css_files = [
-    'css/custom.css',
-]
+    "css/custom.css",
+]
4 changes: 4 additions & 0 deletions setup.py

@@ -36,5 +36,9 @@
        "funlib.evaluate @ git+https://github.com/pattonw/funlib.evaluate",
        "gunpowder>=1.3",
        "lsds>=0.1.3",
+        "xarray",
+        "cattrs",
+        "numpy-indexed",
+        "click",
    ],
)