Rhoadesj/dev #21

Closed · wants to merge 44 commits

Commits (44):
68258d7  feat: better tqdm, training reloading (rhoadesScholar, Aug 26, 2023)
9eee5a7  feat: numpy version requirement (1.22.3); watershed (rhoadesScholar, Aug 29, 2023)
703a31d  feat: improve training logging (rhoadesScholar, Aug 30, 2023)
b97f3d7  Merge branch 'master' into rhoadesj/dev (rhoadesScholar, Aug 30, 2023)
8a03a57  feat: stdout logging, version (rhoadesScholar, Sep 1, 2023)
9626922  feat: validation logging, keep best validation (rhoadesScholar, Sep 5, 2023)
8212b8b  WIP: apply.py (rhoadesScholar, Sep 8, 2023)
8326400  feat: keep actual best scoring dataset (rhoadesScholar, Sep 11, 2023)
663fe0f  feat: ready to debug apply.py (rhoadesScholar, Sep 11, 2023)
beeed60  feat!: add cli for applying models (rhoadesScholar, Sep 12, 2023)
8553b5a  feat!: add cli for applying models (rhoadesScholar, Sep 12, 2023)
273ecb7  feat: apply.py roi (rhoadesScholar, Sep 12, 2023)
2622b94  feat: postprocess parameter parsing (rhoadesScholar, Sep 13, 2023)
bd192e5  feat: overwrite option for predict & postprocess (rhoadesScholar, Sep 13, 2023)
6421bd1  bugfix!?: removed odd zarr creation (rhoadesScholar, Sep 14, 2023)
8198648  !: attempts to fix roi mismatch failed. (rhoadesScholar, Sep 15, 2023)
ed0d4ad  bugfix: prediction works with zarrs (rhoadesScholar, Sep 15, 2023)
ae23179  feat: backlogged updates and wip (rhoadesScholar, Sep 20, 2023)
53b3556  feat: zarr fix dimension detection, simple augment config kwargs, val… (rhoadesScholar, Sep 21, 2023)
8e6dfa3  bugfix: simple augment config (rhoadesScholar, Sep 21, 2023)
0869015  bugfix: simple augment config (rhoadesScholar, Sep 21, 2023)
922ba62  feat: black format, iou score (rhoadesScholar, Sep 25, 2023)
7d954f9  bugfix: remove array overspecification (rhoadesScholar, Sep 25, 2023)
ec1b0d8  hot distance loss function (mzouink, Sep 25, 2023)
be0f6db  feat: hotdistance predictor, model/target (rhoadesScholar, Sep 25, 2023)
da48612  hot distance task (mzouink, Sep 25, 2023)
a6714c0  Merge branch 'hot_distance' of github.com:janelia-cellmap/dacapo into… (rhoadesScholar, Sep 25, 2023)
9567ec4  feat: hot_distance_predictor target/weight (rhoadesScholar, Sep 25, 2023)
bfcc339  bugfix: temporarily remove IOU (rhoadesScholar, Sep 26, 2023)
d8c8517  feat: avg_iou (rhoadesScholar, Sep 28, 2023)
976abe6  feat: revert weighting, add torch.compile(model) (rhoadesScholar, Oct 10, 2023)
bc56278  bugfix: revert torch.compile addition (rhoadesScholar, Oct 10, 2023)
dc62447  feat: handling buggy iou evaluation (rhoadesScholar, Oct 11, 2023)
82e2743  fix use coord (mzouink, Oct 15, 2023)
8f648cd  fix use coord (mzouink, Oct 15, 2023)
d95cf7a  weight cross class (mzouink, Oct 15, 2023)
a8884a1  start head matching (mzouink, Oct 30, 2023)
e8aaf06  Merge branch 'mzouink' into rhoadesj/dev (rhoadesScholar, Oct 31, 2023)
19e4a58  fix: ✨ Add Marwan's start-setup, and remove reloading on validation. (rhoadesScholar, Nov 10, 2023)
673484c  Merge branch 'main' into rhoadesj/dev (rhoadesScholar, Feb 8, 2024)
f9f85d3  Merge branch 'rhoadesj/dev' into merge_fix (mzouink, Feb 8, 2024)
c3a81b7  :art: Format Python code with psf/black (mzouink, Feb 8, 2024)
75eaff4  Merge pull request #27 from janelia-cellmap/actions/black (mzouink, Feb 8, 2024)
290e57f  Merge pull request #26 from janelia-cellmap/merge_fix (mzouink, Feb 8, 2024)
10 changes: 6 additions & 4 deletions dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
@@ -49,7 +49,7 @@ def axes(self):
         try:
             return self._attributes["axes"]
         except KeyError:
-            logger.debug(
+            logger.info(
                 "DaCapo expects Zarr datasets to have an 'axes' attribute!\n"
                 f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n"
                 f"Using default {['c', 'z', 'y', 'x'][-self.dims::]}",
@@ -58,7 +58,7 @@

     @property
     def dims(self) -> int:
-        return self.voxel_size.dims
+        return len(self.data.shape)

     @lazy_property.LazyProperty
     def _daisy_array(self) -> funlib.persistence.Array:
@@ -81,7 +81,7 @@ def writable(self) -> bool:

     @property
     def dtype(self) -> Any:
-        return self.data.dtype
+        return self.data.dtype  # TODO: why not use self._daisy_array.dtype?

     @property
     def num_channels(self) -> Optional[int]:
@@ -92,7 +92,7 @@ def spatial_axes(self) -> List[str]:
         return [ax for ax in self.axes if ax not in set(["c", "b"])]

     @property
-    def data(self) -> Any:
+    def data(self) -> Any:  # TODO: why not use self._daisy_array.data?
         zarr_container = zarr.open(str(self.file_name))
         return zarr_container[self.dataset]

@@ -116,6 +116,7 @@ def create_from_array_identifier(
         dtype,
         write_size=None,
         name=None,
+        overwrite=False,
     ):
         """
         Create a new ZarrArray given an array identifier. It is assumed that
@@ -145,6 +146,7 @@
             dtype,
             num_channels=num_channels,
             write_size=write_size,
+            delete=overwrite,
         )
         zarr_dataset = zarr_container[array_identifier.dataset]
         zarr_dataset.attrs["offset"] = (
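The new `overwrite` flag is forwarded to `funlib.persistence` as `delete=overwrite`, so re-running a prediction can replace a stale dataset instead of failing on the existing one. A minimal sketch of the intended call, assuming the usual positional parameters ahead of `dtype` and an `array_identifier` with `container`/`dataset` fields (all values here are illustrative):

```python
from funlib.geometry import Coordinate, Roi

from dacapo.experiments.datasplits.datasets.arrays import ZarrArray

# array_identifier is assumed to point at e.g. predictions.zarr / "run_1/probs"
array = ZarrArray.create_from_array_identifier(
    array_identifier,
    axes=["c", "z", "y", "x"],
    roi=Roi((0, 0, 0), (4000, 4000, 4000)),
    num_channels=2,
    voxel_size=Coordinate(8, 8, 8),
    dtype="float32",
    overwrite=True,  # deletes any dataset already present at this path
)
```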
41 changes: 33 additions & 8 deletions dacapo/experiments/run.py
@@ -6,9 +6,11 @@
 from .validation_scores import ValidationScores
 from .starts import Start
 from .model import Model

 import logging
 import torch

+logger = logging.getLogger(__file__)
+

 class Run:
     name: str
@@ -53,14 +55,37 @@ def __init__(self, run_config):
             self.task.parameters, self.datasplit.validate, self.task.evaluation_scores
         )

+        if run_config.start_config is None:
+            return
+        try:
+            from ..store import create_config_store
+
+            start_config_store = create_config_store()
+            starter_config = start_config_store.retrieve_run_config(
+                run_config.start_config.run
+            )
+        except Exception as e:
+            logger.error(
+                f"could not load start config: {e} Should be added to the database config store RUN"
+            )
+            raise e
+
         # preloaded weights from previous run
-        self.start = (
-            Start(run_config.start_config)
-            if run_config.start_config is not None
-            else None
-        )
-        if self.start is not None:
-            self.start.initialize_weights(self.model)
+        if run_config.task_config.name == starter_config.task_config.name:
+            self.start = Start(run_config.start_config)
+        else:
+            # Match labels between old and new head
+            if hasattr(run_config.task_config, "channels"):
+                # Map old head and new head
+                old_head = starter_config.task_config.channels
+                new_head = run_config.task_config.channels
+                self.start = Start(
+                    run_config.start_config, old_head=old_head, new_head=new_head
+                )
+            else:
+                logger.warning("Not implemented channel match for this task")
+                self.start = Start(run_config.start_config, remove_head=True)
+        self.start.initialize_weights(self.model)

     @staticmethod
     def get_validation_scores(run_config) -> ValidationScores:
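In effect, `Run.__init__` now distinguishes three transfer-learning cases when a `start_config` is present: same task (plain weight loading), different task with declared `channels` (head matching), and different task without channels (head removal). A condensed, illustrative sketch of just that branch, not the module itself:

```python
from dacapo.experiments.starts import Start


def choose_start(run_config, starter_config):
    """Illustrative reduction of the branch added to Run.__init__."""
    if run_config.task_config.name == starter_config.task_config.name:
        # Same task: weights, including the prediction head, load directly.
        return Start(run_config.start_config)
    if hasattr(run_config.task_config, "channels"):
        # Different task, but both declare channels: remap matching labels.
        return Start(
            run_config.start_config,
            old_head=starter_config.task_config.channels,
            new_head=run_config.task_config.channels,
        )
    # No channel information: drop the old head and load the trunk only.
    return Start(run_config.start_config, remove_head=True)
```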
62 changes: 59 additions & 3 deletions dacapo/experiments/starts/start.py
@@ -3,21 +3,77 @@

 logger = logging.getLogger(__file__)

+# self.old_head = ["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"]
+# self.new_head = ["mito","nucleus","ld","ecs","peroxisome"]
+
+
+def match_heads(model, weights, old_head, new_head):
+    # match the heads
+    for label in new_head:
+        if label in old_head:
+            logger.warning(f"matching head for {label}")
+            # find the index of the label in the old_head
+            old_index = old_head.index(label)
+            # find the index of the label in the new_head
+            new_index = new_head.index(label)
+            # get the weight and bias of the old head
+            for key in [
+                "prediction_head.weight",
+                "prediction_head.bias",
+                "chain.1.weight",
+                "chain.1.bias",
+            ]:
+                if key in model.state_dict().keys():
+                    n_val = weights.model[key][old_index]
+                    model.state_dict()[key][new_index] = n_val
+            logger.warning(f"matched head for {label}")
+    return model
+

 class Start(ABC):
-    def __init__(self, start_config):
+    def __init__(self, start_config, remove_head=False, old_head=None, new_head=None):
         self.run = start_config.run
         self.criterion = start_config.criterion
+        self.remove_head = remove_head
+        self.old_head = old_head
+        self.new_head = new_head

     def initialize_weights(self, model):
         from dacapo.store.create_store import create_weights_store

         weights_store = create_weights_store()
         weights = weights_store._retrieve_weights(self.run, self.criterion)

+        logger.info(f"loading weights from run {self.run}, criterion: {self.criterion}")
+
         # load the model weights (taken from torch load_state_dict source)
         try:
-            model.load_state_dict(weights.model)
+            if self.old_head and self.new_head:
+                logger.warning(
+                    f"matching heads from run {self.run}, criterion: {self.criterion}"
+                )
+                logger.info(f"old head: {self.old_head}")
+                logger.info(f"new head: {self.new_head}")
+                model = match_heads(model, weights, self.old_head, self.new_head)
+                logger.warning(
+                    f"matched heads from run {self.run}, criterion: {self.criterion}"
+                )
+                self.remove_head = True
+            if self.remove_head:
+                logger.warning(
+                    f"removing head from run {self.run}, criterion: {self.criterion}"
+                )
+                weights.model.pop("prediction_head.weight", None)
+                weights.model.pop("prediction_head.bias", None)
+                weights.model.pop("chain.1.weight", None)
+                weights.model.pop("chain.1.bias", None)
+                logger.warning(
+                    f"removed head from run {self.run}, criterion: {self.criterion}"
+                )
+                model.load_state_dict(weights.model, strict=False)
+                logger.warning(
+                    f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}"
+                )
+            else:
+                model.load_state_dict(weights.model)
         except RuntimeError as e:
             logger.warning(e)
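Concretely, `match_heads` copies row `old_head.index(label)` of each head tensor into row `new_head.index(label)` for every label the two lists share. A self-contained toy illustration of that index remapping:

```python
import torch

# Toy head lists mirroring the commented example above
old_head = ["ecs", "plasma_membrane", "mito", "nucleus"]
new_head = ["mito", "nucleus", "ld", "ecs"]

# Stand-ins for the stored head weights and a freshly initialized head
old_weight = torch.arange(4 * 8, dtype=torch.float32).reshape(4, 8)
new_weight = torch.zeros(4, 8)

for label in new_head:
    if label in old_head:
        # copy the row learned for this label into its new position
        new_weight[new_head.index(label)] = old_weight[old_head.index(label)]

# "mito" moves from old row 2 to new row 0, "nucleus" from 3 to 1,
# "ecs" from 0 to 3; "ld" keeps its fresh initialization (no match).
```

Note that assigning into `model.state_dict()[key]` works here because `state_dict()` returns references to the live parameter tensors, so the row copy mutates the model in place.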
46 changes: 46 additions & 0 deletions dacapo/experiments/tasks/evaluators/evaluator.py
@@ -63,6 +63,52 @@ def is_best(
         else:
             return getattr(score, criterion) < previous_best_score

+    def get_overall_best(self, dataset: "Dataset", criterion: str):
+        overall_best = None
+        if self.best_scores:
+            for _, parameter, _ in self.best_scores.keys():
+                score = self.best_scores[(dataset, parameter, criterion)]
+                if score is None:
+                    overall_best = None
+                else:
+                    _, current_parameter_score = score
+                    if overall_best is None:
+                        overall_best = current_parameter_score
+                    else:
+                        if current_parameter_score:
+                            if self.higher_is_better(criterion):
+                                if current_parameter_score > overall_best:
+                                    overall_best = current_parameter_score
+                            else:
+                                if current_parameter_score < overall_best:
+                                    overall_best = current_parameter_score
+        return overall_best
+
+    def get_overall_best_parameters(self, dataset: "Dataset", criterion: str):
+        overall_best = None
+        overall_best_parameters = None
+        if self.best_scores:
+            for _, parameter, _ in self.best_scores.keys():
+                score = self.best_scores[(dataset, parameter, criterion)]
+                if score is None:
+                    overall_best = None
+                else:
+                    _, current_parameter_score = score
+                    if overall_best is None:
+                        overall_best = current_parameter_score
+                        overall_best_parameters = parameter
+                    else:
+                        if current_parameter_score:
+                            if self.higher_is_better(criterion):
+                                if current_parameter_score > overall_best:
+                                    overall_best = current_parameter_score
+                                    overall_best_parameters = parameter
+                            else:
+                                if current_parameter_score < overall_best:
+                                    overall_best = current_parameter_score
+                                    overall_best_parameters = parameter
+        return overall_best_parameters
+
     def set_best(self, validation_scores: "ValidationScores") -> None:
         """
         Find the best iteration for each dataset/post_processing_parameter/criterion
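These helpers aggregate across post-processing parameters, complementing `is_best`, which compares individual scores. A hedged sketch of how they might be consulted downstream, e.g. when apply.py needs the best parameter set for a validation dataset (`run` and `validation_dataset` are placeholders):

```python
evaluator = run.task.evaluator

# Populate evaluator.best_scores from the run's validation history
evaluator.set_best(run.validation_scores)

criterion = "voi"
best_score = evaluator.get_overall_best(validation_dataset, criterion)
best_parameters = evaluator.get_overall_best_parameters(validation_dataset, criterion)
print(f"best {criterion}: {best_score} using {best_parameters}")
```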
dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py
@@ -6,10 +6,11 @@

 @attr.s
 class InstanceEvaluationScores(EvaluationScores):
-    criteria = ["voi_split", "voi_merge", "voi"]
+    criteria = ["voi_split", "voi_merge", "voi", "avg_iou"]

     voi_split: float = attr.ib(default=float("nan"))
     voi_merge: float = attr.ib(default=float("nan"))
+    avg_iou: float = attr.ib(default=float("nan"))

     @property
     def voi(self):
@@ -21,6 +22,7 @@ def higher_is_better(criterion: str) -> bool:
             "voi_split": False,
             "voi_merge": False,
             "voi": False,
+            "avg_iou": True,
         }
         return mapping[criterion]

@@ -30,6 +32,7 @@ def bounds(criterion: str) -> Tuple[float, float]:
             "voi_split": (0, 1),
             "voi_merge": (0, 1),
             "voi": (0, 1),
+            "avg_iou": (0, None),
         }
         return mapping[criterion]
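Because `avg_iou` is the first instance criterion where larger is better, comparison logic has to consult `higher_is_better` rather than assume the VOI convention. A quick illustration, assuming the methods are static as in the other score classes:

```python
from dacapo.experiments.tasks.evaluators.instance_evaluation_scores import (
    InstanceEvaluationScores,
)

assert InstanceEvaluationScores.higher_is_better("voi") is False     # lower VOI is better
assert InstanceEvaluationScores.higher_is_better("avg_iou") is True  # higher IoU is better
print(InstanceEvaluationScores.bounds("avg_iou"))  # (0, None): no upper bound recorded
```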
32 changes: 29 additions & 3 deletions dacapo/experiments/tasks/evaluators/instance_evaluator.py
@@ -3,22 +3,48 @@
 from .evaluator import Evaluator
 from .instance_evaluation_scores import InstanceEvaluationScores

-from funlib.evaluate import rand_voi
+from funlib.evaluate import rand_voi, detection_scores
+
+try:
+    from funlib.segment.arrays import relabel
+
+    iou = True
+except ImportError:
+    iou = False

 import numpy as np


 class InstanceEvaluator(Evaluator):
-    criteria = ["voi_merge", "voi_split", "voi"]
+    criteria = ["voi_merge", "voi_split", "voi", "avg_iou"]

     def evaluate(self, output_array_identifier, evaluation_array):
         output_array = ZarrArray.open_from_array_identifier(output_array_identifier)
         evaluation_data = evaluation_array[evaluation_array.roi].astype(np.uint64)
         output_data = output_array[output_array.roi].astype(np.uint64)
         results = rand_voi(evaluation_data, output_data)
+        if iou:
+            try:
+                output_data, _ = relabel(output_data)
+                results.update(
+                    detection_scores(
+                        evaluation_data,
+                        output_data,
+                        matching_score="iou",
+                    )
+                )
+            except Exception:
+                results["avg_iou"] = 0
+                logger.warning(
+                    "Could not compute IoU because of an unknown error. Sorry about that."
+                )
+        else:
+            results["avg_iou"] = 0

         return InstanceEvaluationScores(
-            voi_merge=results["voi_merge"], voi_split=results["voi_split"]
+            voi_merge=results["voi_merge"],
+            voi_split=results["voi_split"],
+            avg_iou=results["avg_iou"],
         )

     @property
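Note the soft dependency: if `funlib.segment` is not installed, or `detection_scores` raises, `avg_iou` is reported as 0 rather than NaN, which is indistinguishable from a genuinely zero overlap. A minimal usage sketch with placeholder identifiers:

```python
from dacapo.experiments.tasks.evaluators.instance_evaluator import InstanceEvaluator

evaluator = InstanceEvaluator()

# output_array_identifier locates the post-processed label zarr;
# evaluation_array wraps the ground-truth labels over the same ROI.
scores = evaluator.evaluate(output_array_identifier, evaluation_array)
print(scores.voi_split, scores.voi_merge, scores.avg_iou)
```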
25 changes: 25 additions & 0 deletions dacapo/experiments/tasks/hot_distance_task.py
@@ -0,0 +1,25 @@
+from .evaluators import BinarySegmentationEvaluator
+from .losses import HotDistanceLoss
+from .post_processors import ThresholdPostProcessor
+from .predictors import HotDistancePredictor
+from .task import Task
+
+
+class HotDistanceTask(Task):
+    """A Hot Distance Task that combines binary and distance prediction."""
+
+    def __init__(self, task_config):
+        """Create a `HotDistanceTask` from a `HotDistanceTaskConfig`."""
+
+        self.predictor = HotDistancePredictor(
+            channels=task_config.channels,
+            scale_factor=task_config.scale_factor,
+            mask_distances=task_config.mask_distances,
+        )
+        self.loss = HotDistanceLoss()
+        self.post_processor = ThresholdPostProcessor()
+        self.evaluator = BinarySegmentationEvaluator(
+            clip_distance=task_config.clip_distance,
+            tol_distance=task_config.tol_distance,
+            channels=task_config.channels,
+        )
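As with the other tasks, code normally reaches this class through the config's `task_type` attribute rather than constructing it directly. A short sketch of that indirection and the four components it wires up:

```python
# config is a HotDistanceTaskConfig (see the next file); DaCapo's registry
# pattern turns a task config into a task via its task_type attribute.
task = config.task_type(config)  # equivalent to HotDistanceTask(config)

task.predictor       # HotDistancePredictor: builds the combined hot/distance targets
task.loss            # HotDistanceLoss
task.post_processor  # ThresholdPostProcessor
task.evaluator       # BinarySegmentationEvaluator
```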
47 changes: 47 additions & 0 deletions dacapo/experiments/tasks/hot_distance_task_config.py
@@ -0,0 +1,47 @@
+import attr
+
+from .hot_distance_task import HotDistanceTask
+from .task_config import TaskConfig
+
+from typing import List
+
+
+class HotDistanceTaskConfig(TaskConfig):
+    """This is a Hot Distance task config used for generating and
+    evaluating signed distance transforms as a way of generating
+    segmentations.
+
+    The advantage of generating distance transforms over regular
+    affinities is that you get a denser signal: a single misclassified
+    pixel in an affinity prediction could merge two otherwise very
+    distinct objects, which cannot happen with distances.
+    """
+
+    task_type = HotDistanceTask
+
+    channels: List[str] = attr.ib(metadata={"help_text": "A list of channel names."})
+    clip_distance: float = attr.ib(
+        metadata={
+            "help_text": "Maximum distance to consider for false positives/negatives."
+        },
+    )
+    tol_distance: float = attr.ib(
+        metadata={
+            "help_text": "Tolerance distance for counting false positives/negatives."
+        },
+    )
+    scale_factor: float = attr.ib(
+        default=1,
+        metadata={
+            "help_text": "The amount by which to scale distances before applying "
+            "a tanh normalization."
+        },
+    )
+    mask_distances: bool = attr.ib(
+        default=False,
+        metadata={
+            "help_text": "Whether or not to mask out regions where the true distance to "
+            "object boundary cannot be known. This is anywhere that the distance to crop boundary "
+            "is less than the distance to object boundary."
+        },
+    )
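A hedged construction example (values are illustrative; `name` comes from the `TaskConfig` base). Note that keyword construction like this relies on the class being an attrs class: the diff as shown lacks the `@attr.s` decorator that the sibling task configs carry, so the sketch assumes it is applied.

```python
from dacapo.experiments.tasks.hot_distance_task_config import HotDistanceTaskConfig

config = HotDistanceTaskConfig(
    name="hot_distance_example",
    channels=["mito", "nucleus"],
    clip_distance=40.0,
    tol_distance=40.0,
    scale_factor=50.0,    # scaling applied before the tanh normalization
    mask_distances=True,  # mask where the true boundary distance is unknowable
)
task = config.task_type(config)  # -> HotDistanceTask
```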