Merge branch 'rhoadesj/dev' into merge_fix
mzouink authored Feb 8, 2024
2 parents 5f50f9b + 673484c commit f9f85d3
Showing 31 changed files with 977 additions and 214 deletions.
10 changes: 6 additions & 4 deletions dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
@@ -49,7 +49,7 @@ def axes(self):
        try:
            return self._attributes["axes"]
        except KeyError:
-            logger.debug(
+            logger.info(
                "DaCapo expects Zarr datasets to have an 'axes' attribute!\n"
                f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n"
                f"Using default {['c', 'z', 'y', 'x'][-self.dims::]}",
@@ -58,7 +58,7 @@ def axes(self):

    @property
    def dims(self) -> int:
-        return self.voxel_size.dims
+        return len(self.data.shape)

    @lazy_property.LazyProperty
    def _daisy_array(self) -> funlib.persistence.Array:
@@ -81,7 +81,7 @@ def writable(self) -> bool:

    @property
    def dtype(self) -> Any:
-        return self.data.dtype
+        return self.data.dtype  # TODO: why not use self._daisy_array.dtype?

    @property
    def num_channels(self) -> Optional[int]:
@@ -92,7 +92,7 @@ def spatial_axes(self) -> List[str]:
        return [ax for ax in self.axes if ax not in set(["c", "b"])]

    @property
-    def data(self) -> Any:
+    def data(self) -> Any:  # TODO: why not use self._daisy_array.data?
        zarr_container = zarr.open(str(self.file_name))
        return zarr_container[self.dataset]

@@ -116,6 +116,7 @@ def create_from_array_identifier(
        dtype,
        write_size=None,
        name=None,
+        overwrite=False,
    ):
        """
        Create a new ZarrArray given an array identifier. It is assumed that
@@ -145,6 +146,7 @@
            dtype,
            num_channels=num_channels,
            write_size=write_size,
+            delete=overwrite,
        )
        zarr_dataset = zarr_container[array_identifier.dataset]
        zarr_dataset.attrs["offset"] = (
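For context, a minimal sketch of how the new `overwrite` flag might be used when preparing an output dataset (not part of the commit; the `LocalArrayIdentifier` import path, the parameter names before `dtype`, and all concrete values are assumptions for illustration):

```python
from pathlib import Path

from funlib.geometry import Coordinate, Roi

from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray
from dacapo.store.local_array_store import LocalArrayIdentifier  # assumed import path

# Hypothetical output location for a prediction written during validation.
array_identifier = LocalArrayIdentifier(Path("predictions.zarr"), "run_1/prediction")

prediction_array = ZarrArray.create_from_array_identifier(
    array_identifier,
    axes=["c", "z", "y", "x"],
    roi=Roi((0, 0, 0), (1024, 1024, 1024)),
    num_channels=2,
    voxel_size=Coordinate(8, 8, 8),
    dtype="float32",
    # New in this commit: overwrite is forwarded as delete= to the underlying
    # dataset-creation call, so an existing dataset at this path is replaced
    # instead of raising an error.
    overwrite=True,
)
```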
34 changes: 26 additions & 8 deletions dacapo/experiments/run.py
@@ -6,9 +6,10 @@
from .validation_scores import ValidationScores
from .starts import Start
from .model import Model

+import logging
import torch

+logger = logging.getLogger(__file__)


class Run:
    name: str
@@ -53,14 +54,31 @@ def __init__(self, run_config):
            self.task.parameters, self.datasplit.validate, self.task.evaluation_scores
        )

+        if run_config.start_config is None:
+            return
+        try:
+            from ..store import create_config_store
+            start_config_store = create_config_store()
+            starter_config = start_config_store.retrieve_run_config(run_config.start_config.run)
+        except Exception as e:
+            logger.error(f"could not load start config: {e} Should be added to the database config store RUN")
+            raise e

        # preloaded weights from previous run
-        self.start = (
-            Start(run_config.start_config)
-            if run_config.start_config is not None
-            else None
-        )
-        if self.start is not None:
-            self.start.initialize_weights(self.model)
+        if run_config.task_config.name == starter_config.task_config.name:
+            self.start = Start(run_config.start_config)
+        else:
+            # Match labels between old and new head
+            if hasattr(run_config.task_config, "channels"):
+                # Map old head and new head
+                old_head = starter_config.task_config.channels
+                new_head = run_config.task_config.channels
+                self.start = Start(run_config.start_config, old_head=old_head, new_head=new_head)
+            else:
+                logger.warning("Not implemented channel match for this task")
+                self.start = Start(run_config.start_config, remove_head=True)
+        self.start.initialize_weights(self.model)


    @staticmethod
    def get_validation_scores(run_config) -> ValidationScores:
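The start-config handling above branches three ways depending on whether the new run reuses the same task as the starter run. A small standalone restatement of that decision (a sketch for illustration, using plain dictionaries instead of DaCapo config objects):

```python
def resolve_start_kwargs(new_task, starter_task):
    """Mirror the branching in Run.__init__: decide how Start should treat the old head."""
    if new_task["name"] == starter_task["name"]:
        # Same task: the checkpoint, prediction head included, is loaded as-is.
        return {"old_head": None, "new_head": None, "remove_head": False}
    if "channels" in new_task:
        # Different task, but channels are defined: remap matching labels between
        # the old and new prediction heads (see match_heads in starts/start.py).
        return {
            "old_head": starter_task["channels"],
            "new_head": new_task["channels"],
            "remove_head": False,
        }
    # Different task and no channel information: drop the old head and load
    # the remaining weights non-strictly.
    return {"old_head": None, "new_head": None, "remove_head": True}


# Example: fine-tuning a run onto a task with a different, overlapping label set.
print(resolve_start_kwargs(
    {"name": "organelles_5", "channels": ["mito", "nucleus", "ld"]},
    {"name": "organelles_14", "channels": ["mito", "nucleus", "er"]},
))
# {'old_head': ['mito', 'nucleus', 'er'], 'new_head': ['mito', 'nucleus', 'ld'], 'remove_head': False}
```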
62 changes: 59 additions & 3 deletions dacapo/experiments/starts/start.py
@@ -3,21 +3,77 @@

logger = logging.getLogger(__file__)

+# self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"]
+# self.new_head = ["mito","nucleus","ld","ecs","peroxisome"]


+def match_heads(model, weights, old_head, new_head):
+    # match the heads
+    for label in new_head:
+        if label in old_head:
+            logger.warning(f"matching head for {label}")
+            # find the index of the label in the old_head
+            old_index = old_head.index(label)
+            # find the index of the label in the new_head
+            new_index = new_head.index(label)
+            # get the weight and bias of the old head
+            for key in [
+                "prediction_head.weight",
+                "prediction_head.bias",
+                "chain.1.weight",
+                "chain.1.bias",
+            ]:
+                if key in model.state_dict().keys():
+                    n_val = weights.model[key][old_index]
+                    model.state_dict()[key][new_index] = n_val
+            logger.warning(f"matched head for {label}")
+    return model


class Start(ABC):
-    def __init__(self, start_config):
+    def __init__(self, start_config, remove_head=False, old_head=None, new_head=None):
        self.run = start_config.run
        self.criterion = start_config.criterion
+        self.remove_head = remove_head
+        self.old_head = old_head
+        self.new_head = new_head

    def initialize_weights(self, model):
        from dacapo.store.create_store import create_weights_store

        weights_store = create_weights_store()
        weights = weights_store._retrieve_weights(self.run, self.criterion)

        logger.info(f"loading weights from run {self.run}, criterion: {self.criterion}")

        # load the model weights (taken from torch load_state_dict source)
        try:
-            model.load_state_dict(weights.model)
+            if self.old_head and self.new_head:
+                logger.warning(
+                    f"matching heads from run {self.run}, criterion: {self.criterion}"
+                )
+                logger.info(f"old head: {self.old_head}")
+                logger.info(f"new head: {self.new_head}")
+                model = match_heads(model, weights, self.old_head, self.new_head)
+                logger.warning(
+                    f"matched heads from run {self.run}, criterion: {self.criterion}"
+                )
+                self.remove_head = True
+            if self.remove_head:
+                logger.warning(
+                    f"removing head from run {self.run}, criterion: {self.criterion}"
+                )
+                weights.model.pop("prediction_head.weight", None)
+                weights.model.pop("prediction_head.bias", None)
+                weights.model.pop("chain.1.weight", None)
+                weights.model.pop("chain.1.bias", None)
+                logger.warning(
+                    f"removed head from run {self.run}, criterion: {self.criterion}"
+                )
+                model.load_state_dict(weights.model, strict=False)
+                logger.warning(
+                    f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}"
+                )
+            else:
+                model.load_state_dict(weights.model)
        except RuntimeError as e:
            logger.warning(e)
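A toy demonstration of the row-copying idea behind `match_heads` (a sketch only: the two `nn.Linear` heads and the `SimpleNamespace` checkpoint stand in for DaCapo's real `Model` and weights objects, and the `chain.1.*` keys and logging are omitted):

```python
from types import SimpleNamespace

import torch


def copy_matching_head_rows(model, weights, old_head, new_head):
    """Copy per-label output rows from an old checkpoint into a new prediction head."""
    for label in new_head:
        if label in old_head:
            old_index = old_head.index(label)
            new_index = new_head.index(label)
            for key in ["prediction_head.weight", "prediction_head.bias"]:
                if key in model.state_dict():
                    # state_dict() tensors share storage with the parameters,
                    # so this writes straight into the model.
                    model.state_dict()[key][new_index] = weights.model[key][old_index]
    return model


old_head = ["mito", "nucleus", "er"]   # labels of the pretrained run
new_head = ["nucleus", "ld", "mito"]   # labels of the new run (overlapping, reordered)

old_model = torch.nn.Sequential()
old_model.add_module("prediction_head", torch.nn.Linear(4, len(old_head)))
new_model = torch.nn.Sequential()
new_model.add_module("prediction_head", torch.nn.Linear(4, len(new_head)))

checkpoint = SimpleNamespace(model=old_model.state_dict())
copy_matching_head_rows(new_model, checkpoint, old_head, new_head)

# "mito" moved from row 0 of the old head to row 2 of the new head;
# the unmatched "ld" row keeps its fresh initialization.
assert torch.equal(new_model.prediction_head.weight[2], old_model.prediction_head.weight[0])
```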
46 changes: 46 additions & 0 deletions dacapo/experiments/tasks/evaluators/evaluator.py
@@ -63,6 +63,52 @@ def is_best(
        else:
            return getattr(score, criterion) < previous_best_score

+    def get_overall_best(self, dataset: "Dataset", criterion: str):
+        overall_best = None
+        if self.best_scores:
+            for _, parameter, _ in self.best_scores.keys():
+                score = self.best_scores[(dataset, parameter, criterion)]
+                if score is None:
+                    overall_best = None
+                else:
+                    _, current_parameter_score = score
+                    if overall_best is None:
+                        overall_best = current_parameter_score
+                    else:
+                        if current_parameter_score:
+                            if self.higher_is_better(criterion):
+                                if current_parameter_score > overall_best:
+                                    overall_best = current_parameter_score
+                            else:
+                                if current_parameter_score < overall_best:
+                                    overall_best = current_parameter_score
+        return overall_best

+    def get_overall_best_parameters(self, dataset: "Dataset", criterion: str):
+        overall_best = None
+        overall_best_parameters = None
+        if self.best_scores:
+            for _, parameter, _ in self.best_scores.keys():
+                score = self.best_scores[(dataset, parameter, criterion)]
+                if score is None:
+                    overall_best = None
+                else:
+                    _, current_parameter_score = score
+                    if overall_best is None:
+                        overall_best = current_parameter_score
+                        overall_best_parameters = parameter
+                    else:
+                        if current_parameter_score:
+                            if self.higher_is_better(criterion):
+                                if current_parameter_score > overall_best:
+                                    overall_best = current_parameter_score
+                                    overall_best_parameters = parameter
+                            else:
+                                if current_parameter_score < overall_best:
+                                    overall_best = current_parameter_score
+                                    overall_best_parameters = parameter
+        return overall_best_parameters

    def set_best(self, validation_scores: "ValidationScores") -> None:
        """
        Find the best iteration for each dataset/post_processing_parameter/criterion
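To make the selection concrete, here is a simplified standalone sketch of the same idea over a plain dictionary shaped like `best_scores` (keys of `(dataset, parameter, criterion)` mapping to `(iteration, score)` tuples or `None`, which is inferred from the unpacking above; unlike the method, this version simply skips missing entries):

```python
def overall_best_parameters(best_scores, dataset, criterion, higher_is_better=True):
    """Return the post-processing parameters whose best score is best overall."""
    best_score = None
    best_parameters = None
    for ds, parameters, crit in best_scores:
        if (ds, crit) != (dataset, criterion):
            continue
        entry = best_scores[(dataset, parameters, criterion)]
        if entry is None:
            continue  # this parameter set was never evaluated
        _, score = entry
        better = best_score is None or (
            score > best_score if higher_is_better else score < best_score
        )
        if better:
            best_score, best_parameters = score, parameters
    return best_parameters


best_scores = {
    ("crop_1", "threshold=0.3", "avg_iou"): (5000, 0.61),
    ("crop_1", "threshold=0.5", "avg_iou"): (7500, 0.68),
    ("crop_1", "threshold=0.7", "avg_iou"): None,
}
print(overall_best_parameters(best_scores, "crop_1", "avg_iou"))  # threshold=0.5
```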
dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py
@@ -6,10 +6,11 @@

@attr.s
class InstanceEvaluationScores(EvaluationScores):
-    criteria = ["voi_split", "voi_merge", "voi"]
+    criteria = ["voi_split", "voi_merge", "voi", "avg_iou"]

    voi_split: float = attr.ib(default=float("nan"))
    voi_merge: float = attr.ib(default=float("nan"))
+    avg_iou: float = attr.ib(default=float("nan"))

    @property
    def voi(self):
@@ -21,6 +22,7 @@ def higher_is_better(criterion: str) -> bool:
            "voi_split": False,
            "voi_merge": False,
            "voi": False,
+            "avg_iou": True,
        }
        return mapping[criterion]

@@ -30,6 +32,7 @@ def bounds(criterion: str) -> Tuple[float, float]:
            "voi_split": (0, 1),
            "voi_merge": (0, 1),
            "voi": (0, 1),
+            "avg_iou": (0, None),
        }
        return mapping[criterion]

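A quick sketch of how the extended scores object behaves (values are made up; the import path is assumed from the relative import used in instance_evaluator.py):

```python
from dacapo.experiments.tasks.evaluators.instance_evaluation_scores import (
    InstanceEvaluationScores,
)

# Hypothetical scores for one validation crop.
scores = InstanceEvaluationScores(voi_split=0.12, voi_merge=0.08, avg_iou=0.71)

# avg_iou is the only criterion here where larger is better, and it has no
# fixed upper bound registered.
assert InstanceEvaluationScores.higher_is_better("avg_iou") is True
assert InstanceEvaluationScores.higher_is_better("voi") is False
assert InstanceEvaluationScores.bounds("avg_iou") == (0, None)
```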
32 changes: 29 additions & 3 deletions dacapo/experiments/tasks/evaluators/instance_evaluator.py
@@ -3,22 +3,48 @@
from .evaluator import Evaluator
from .instance_evaluation_scores import InstanceEvaluationScores

-from funlib.evaluate import rand_voi
+from funlib.evaluate import rand_voi, detection_scores

+try:
+    from funlib.segment.arrays import relabel
+
+    iou = True
+except ImportError:
+    iou = False

import numpy as np


class InstanceEvaluator(Evaluator):
-    criteria = ["voi_merge", "voi_split", "voi"]
+    criteria = ["voi_merge", "voi_split", "voi", "avg_iou"]

    def evaluate(self, output_array_identifier, evaluation_array):
        output_array = ZarrArray.open_from_array_identifier(output_array_identifier)
        evaluation_data = evaluation_array[evaluation_array.roi].astype(np.uint64)
        output_data = output_array[output_array.roi].astype(np.uint64)
        results = rand_voi(evaluation_data, output_data)
+        if iou:
+            try:
+                output_data, _ = relabel(output_data)
+                results.update(
+                    detection_scores(
+                        evaluation_data,
+                        output_data,
+                        matching_score="iou",
+                    )
+                )
+            except Exception:
+                results["avg_iou"] = 0
+                logger.warning(
+                    "Could not compute IoU because of an unknown error. Sorry about that."
+                )
+        else:
+            results["avg_iou"] = 0

        return InstanceEvaluationScores(
-            voi_merge=results["voi_merge"], voi_split=results["voi_split"]
+            voi_merge=results["voi_merge"],
+            voi_split=results["voi_split"],
+            avg_iou=results["avg_iou"],
        )

    @property
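The `avg_iou` value appears to come from `funlib.evaluate.detection_scores` with IoU matching; conceptually it is the mean intersection-over-union of ground-truth instances with their best-matching predicted instances. A minimal NumPy sketch of that idea (illustration only, not the funlib implementation; note also that the `except` branch above logs via `logger`, which is not visible as an import in this excerpt):

```python
import numpy as np


def average_iou(gt, pred):
    """Mean IoU over ground-truth instances, each matched greedily to its best prediction."""
    ious = []
    for gt_id in np.unique(gt):
        if gt_id == 0:  # skip background
            continue
        gt_mask = gt == gt_id
        best = 0.0
        for pred_id in np.unique(pred[gt_mask]):  # predictions overlapping this instance
            if pred_id == 0:
                continue
            pred_mask = pred == pred_id
            intersection = np.logical_and(gt_mask, pred_mask).sum()
            union = np.logical_or(gt_mask, pred_mask).sum()
            best = max(best, intersection / union)
        ious.append(best)
    return float(np.mean(ious)) if ious else 0.0


gt = np.array([[1, 1, 0], [0, 2, 2]])
pred = np.array([[1, 1, 1], [0, 0, 2]])
print(average_iou(gt, pred))  # ~0.58 (IoU 2/3 for instance 1, 1/2 for instance 2)
```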
24 changes: 24 additions & 0 deletions dacapo/experiments/tasks/hot_distance_task.py
@@ -0,0 +1,24 @@
from .evaluators import BinarySegmentationEvaluator
from .losses import HotDistanceLoss
from .post_processors import ThresholdPostProcessor
from .predictors import HotDistancePredictor
from .task import Task


class HotDistanceTask(Task):
    """This is just a Hot Distance Task that combines binary and distance prediction."""

    def __init__(self, task_config):
        """Create a `HotDistanceTask` from a `HotDistanceTaskConfig`."""

        self.predictor = HotDistancePredictor(
            channels=task_config.channels,
            scale_factor=task_config.scale_factor,
            mask_distances=task_config.mask_distances,
        )
        self.loss = HotDistanceLoss()
        self.post_processor = ThresholdPostProcessor()
        self.evaluator = BinarySegmentationEvaluator(
            clip_distance=task_config.clip_distance,
            tol_distance=task_config.tol_distance,
            channels=task_config.channels,
        )
46 changes: 46 additions & 0 deletions dacapo/experiments/tasks/hot_distance_task_config.py
@@ -0,0 +1,46 @@
import attr

from .hot_distance_task import HotDistanceTask
from .task_config import TaskConfig

from typing import List


class HotDistanceTaskConfig(TaskConfig):
    """This is a Hot Distance task config used for generating and
    evaluating signed distance transforms as a way of generating
    segmentations.

    The advantage of generating distance transforms over regular
    affinities is that you get a denser signal: a single misclassified
    pixel in an affinity prediction could merge two otherwise very
    distinct objects, which cannot happen with distances.
    """

    task_type = HotDistanceTask

    channels: List[str] = attr.ib(metadata={"help_text": "A list of channel names."})
    clip_distance: float = attr.ib(
        metadata={
            "help_text": "Maximum distance to consider for false positives/negatives."
        },
    )
    tol_distance: float = attr.ib(
        metadata={
            "help_text": "Tolerance distance for counting false positives/negatives."
        },
    )
    scale_factor: float = attr.ib(
        default=1,
        metadata={
            "help_text": "The amount by which to scale distances before applying "
            "a tanh normalization."
        },
    )
    mask_distances: bool = attr.ib(
        default=False,
        metadata={
            "help_text": "Whether or not to mask out regions where the true distance to "
            "the object boundary cannot be known. This is anywhere that the distance to the crop boundary "
            "is less than the distance to the object boundary."
        },
    )
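A minimal sketch of instantiating the new config (illustrative values only; it assumes the base `TaskConfig` takes a `name` argument, as DaCapo's other task configs do):

```python
from dacapo.experiments.tasks.hot_distance_task_config import HotDistanceTaskConfig

task_config = HotDistanceTaskConfig(
    name="hot_distance_organelles",      # assumed to come from the TaskConfig base class
    channels=["mito", "nucleus", "ecs"],
    clip_distance=40.0,   # max distance considered for false positives/negatives
    tol_distance=40.0,    # tolerance when counting false positives/negatives
    scale_factor=50.0,    # distances are scaled by this before the tanh normalization
    mask_distances=True,  # mask voxels whose true distance is unknowable near crop edges
)

# task_config.task_type is HotDistanceTask, which builds the predictor, loss,
# post-processor, and evaluator shown in hot_distance_task.py.
```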
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/losses/__init__.py
@@ -2,3 +2,4 @@
from .mse_loss import MSELoss # noqa
from .loss import Loss # noqa
from .affinities_loss import AffinitiesLoss # noqa
+from .hot_distance_loss import HotDistanceLoss # noqa
