diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 25f2c224e..081f08e9b 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -49,7 +49,7 @@ def axes(self): try: return self._attributes["axes"] except KeyError: - logger.debug( + logger.info( "DaCapo expects Zarr datasets to have an 'axes' attribute!\n" f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n" f"Using default {['c', 'z', 'y', 'x'][-self.dims::]}", @@ -58,7 +58,7 @@ def axes(self): @property def dims(self) -> int: - return self.voxel_size.dims + return len(self.data.shape) @lazy_property.LazyProperty def _daisy_array(self) -> funlib.persistence.Array: @@ -81,7 +81,7 @@ def writable(self) -> bool: @property def dtype(self) -> Any: - return self.data.dtype + return self.data.dtype # TODO: why not use self._daisy_array.dtype? @property def num_channels(self) -> Optional[int]: @@ -92,7 +92,7 @@ def spatial_axes(self) -> List[str]: return [ax for ax in self.axes if ax not in set(["c", "b"])] @property - def data(self) -> Any: + def data(self) -> Any: # TODO: why not use self._daisy_array.data? zarr_container = zarr.open(str(self.file_name)) return zarr_container[self.dataset] @@ -116,6 +116,7 @@ def create_from_array_identifier( dtype, write_size=None, name=None, + overwrite=False, ): """ Create a new ZarrArray given an array identifier. It is assumed that @@ -145,6 +146,7 @@ def create_from_array_identifier( dtype, num_channels=num_channels, write_size=write_size, + delete=overwrite, ) zarr_dataset = zarr_container[array_identifier.dataset] zarr_dataset.attrs["offset"] = ( diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py index 129f947ab..9ea496758 100644 --- a/dacapo/experiments/run.py +++ b/dacapo/experiments/run.py @@ -6,9 +6,11 @@ from .validation_scores import ValidationScores from .starts import Start from .model import Model - +import logging import torch +logger = logging.getLogger(__file__) + class Run: name: str @@ -53,14 +55,37 @@ def __init__(self, run_config): self.task.parameters, self.datasplit.validate, self.task.evaluation_scores ) + if run_config.start_config is None: + return + try: + from ..store import create_config_store + + start_config_store = create_config_store() + starter_config = start_config_store.retrieve_run_config( + run_config.start_config.run + ) + except Exception as e: + logger.error( + f"could not load start config: {e} Should be added to the database config store RUN" + ) + raise e + # preloaded weights from previous run - self.start = ( - Start(run_config.start_config) - if run_config.start_config is not None - else None - ) - if self.start is not None: - self.start.initialize_weights(self.model) + if run_config.task_config.name == starter_config.task_config.name: + self.start = Start(run_config.start_config) + else: + # Match labels between old and new head + if hasattr(run_config.task_config, "channels"): + # Map old head and new head + old_head = starter_config.task_config.channels + new_head = run_config.task_config.channels + self.start = Start( + run_config.start_config, old_head=old_head, new_head=new_head + ) + else: + logger.warning("Not implemented channel match for this task") + self.start = Start(run_config.start_config, remove_head=True) + self.start.initialize_weights(self.model) @staticmethod def get_validation_scores(run_config) -> ValidationScores: diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py index a5b68069c..bb634ff88 100644 --- a/dacapo/experiments/starts/start.py +++ b/dacapo/experiments/starts/start.py @@ -3,21 +3,77 @@ logger = logging.getLogger(__file__) +# self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"] +# self.new_head = ["mito","nucleus","ld","ecs","peroxisome"] + + +def match_heads(model, weights, old_head, new_head): + # match the heads + for label in new_head: + if label in old_head: + logger.warning(f"matching head for {label}") + # find the index of the label in the old_head + old_index = old_head.index(label) + # find the index of the label in the new_head + new_index = new_head.index(label) + # get the weight and bias of the old head + for key in [ + "prediction_head.weight", + "prediction_head.bias", + "chain.1.weight", + "chain.1.bias", + ]: + if key in model.state_dict().keys(): + n_val = weights.model[key][old_index] + model.state_dict()[key][new_index] = n_val + logger.warning(f"matched head for {label}") + return model + class Start(ABC): - def __init__(self, start_config): + def __init__(self, start_config, remove_head=False, old_head=None, new_head=None): self.run = start_config.run self.criterion = start_config.criterion + self.remove_head = remove_head + self.old_head = old_head + self.new_head = new_head def initialize_weights(self, model): from dacapo.store.create_store import create_weights_store weights_store = create_weights_store() weights = weights_store._retrieve_weights(self.run, self.criterion) + logger.info(f"loading weights from run {self.run}, criterion: {self.criterion}") - # load the model weights (taken from torch load_state_dict source) try: - model.load_state_dict(weights.model) + if self.old_head and self.new_head: + logger.warning( + f"matching heads from run {self.run}, criterion: {self.criterion}" + ) + logger.info(f"old head: {self.old_head}") + logger.info(f"new head: {self.new_head}") + model = match_heads(model, weights, self.old_head, self.new_head) + logger.warning( + f"matched heads from run {self.run}, criterion: {self.criterion}" + ) + self.remove_head = True + if self.remove_head: + logger.warning( + f"removing head from run {self.run}, criterion: {self.criterion}" + ) + weights.model.pop("prediction_head.weight", None) + weights.model.pop("prediction_head.bias", None) + weights.model.pop("chain.1.weight", None) + weights.model.pop("chain.1.bias", None) + logger.warning( + f"removed head from run {self.run}, criterion: {self.criterion}" + ) + model.load_state_dict(weights.model, strict=False) + logger.warning( + f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}" + ) + else: + model.load_state_dict(weights.model) except RuntimeError as e: logger.warning(e) diff --git a/dacapo/experiments/tasks/evaluators/evaluator.py b/dacapo/experiments/tasks/evaluators/evaluator.py index 9d5cbbda0..24096261a 100644 --- a/dacapo/experiments/tasks/evaluators/evaluator.py +++ b/dacapo/experiments/tasks/evaluators/evaluator.py @@ -63,6 +63,52 @@ def is_best( else: return getattr(score, criterion) < previous_best_score + def get_overall_best(self, dataset: "Dataset", criterion: str): + overall_best = None + if self.best_scores: + for _, parameter, _ in self.best_scores.keys(): + score = self.best_scores[(dataset, parameter, criterion)] + if score is None: + overall_best = None + else: + _, current_parameter_score = score + if overall_best is None: + overall_best = current_parameter_score + else: + if current_parameter_score: + if self.higher_is_better(criterion): + if current_parameter_score > overall_best: + overall_best = current_parameter_score + else: + if current_parameter_score < overall_best: + overall_best = current_parameter_score + return overall_best + + def get_overall_best_parameters(self, dataset: "Dataset", criterion: str): + overall_best = None + overall_best_parameters = None + if self.best_scores: + for _, parameter, _ in self.best_scores.keys(): + score = self.best_scores[(dataset, parameter, criterion)] + if score is None: + overall_best = None + else: + _, current_parameter_score = score + if overall_best is None: + overall_best = current_parameter_score + overall_best_parameters = parameter + else: + if current_parameter_score: + if self.higher_is_better(criterion): + if current_parameter_score > overall_best: + overall_best = current_parameter_score + overall_best_parameters = parameter + else: + if current_parameter_score < overall_best: + overall_best = current_parameter_score + overall_best_parameters = parameter + return overall_best_parameters + def set_best(self, validation_scores: "ValidationScores") -> None: """ Find the best iteration for each dataset/post_processing_parameter/criterion diff --git a/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py b/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py index 7de54d99c..16eac194c 100644 --- a/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py +++ b/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py @@ -6,10 +6,11 @@ @attr.s class InstanceEvaluationScores(EvaluationScores): - criteria = ["voi_split", "voi_merge", "voi"] + criteria = ["voi_split", "voi_merge", "voi", "avg_iou"] voi_split: float = attr.ib(default=float("nan")) voi_merge: float = attr.ib(default=float("nan")) + avg_iou: float = attr.ib(default=float("nan")) @property def voi(self): @@ -21,6 +22,7 @@ def higher_is_better(criterion: str) -> bool: "voi_split": False, "voi_merge": False, "voi": False, + "avg_iou": True, } return mapping[criterion] @@ -30,6 +32,7 @@ def bounds(criterion: str) -> Tuple[float, float]: "voi_split": (0, 1), "voi_merge": (0, 1), "voi": (0, 1), + "avg_iou": (0, None), } return mapping[criterion] diff --git a/dacapo/experiments/tasks/evaluators/instance_evaluator.py b/dacapo/experiments/tasks/evaluators/instance_evaluator.py index 0f3427a40..d2fc91678 100644 --- a/dacapo/experiments/tasks/evaluators/instance_evaluator.py +++ b/dacapo/experiments/tasks/evaluators/instance_evaluator.py @@ -3,22 +3,48 @@ from .evaluator import Evaluator from .instance_evaluation_scores import InstanceEvaluationScores -from funlib.evaluate import rand_voi +from funlib.evaluate import rand_voi, detection_scores + +try: + from funlib.segment.arrays import relabel + + iou = True +except ImportError: + iou = False import numpy as np class InstanceEvaluator(Evaluator): - criteria = ["voi_merge", "voi_split", "voi"] + criteria = ["voi_merge", "voi_split", "voi", "avg_iou"] def evaluate(self, output_array_identifier, evaluation_array): output_array = ZarrArray.open_from_array_identifier(output_array_identifier) evaluation_data = evaluation_array[evaluation_array.roi].astype(np.uint64) output_data = output_array[output_array.roi].astype(np.uint64) results = rand_voi(evaluation_data, output_data) + if iou: + try: + output_data, _ = relabel(output_data) + results.update( + detection_scores( + evaluation_data, + output_data, + matching_score="iou", + ) + ) + except Exception: + results["avg_iou"] = 0 + logger.warning( + "Could not compute IoU because of an unknown error. Sorry about that." + ) + else: + results["avg_iou"] = 0 return InstanceEvaluationScores( - voi_merge=results["voi_merge"], voi_split=results["voi_split"] + voi_merge=results["voi_merge"], + voi_split=results["voi_split"], + avg_iou=results["avg_iou"], ) @property diff --git a/dacapo/experiments/tasks/hot_distance_task.py b/dacapo/experiments/tasks/hot_distance_task.py new file mode 100644 index 000000000..ef0d03229 --- /dev/null +++ b/dacapo/experiments/tasks/hot_distance_task.py @@ -0,0 +1,25 @@ +from .evaluators import BinarySegmentationEvaluator +from .losses import HotDistanceLoss +from .post_processors import ThresholdPostProcessor +from .predictors import HotDistancePredictor +from .task import Task + + +class HotDistanceTask(Task): + """This is just a Hot Distance Task that combine Binary and distance prediction.""" + + def __init__(self, task_config): + """Create a `HotDistanceTask` from a `HotDistanceTaskConfig`.""" + + self.predictor = HotDistancePredictor( + channels=task_config.channels, + scale_factor=task_config.scale_factor, + mask_distances=task_config.mask_distances, + ) + self.loss = HotDistanceLoss() + self.post_processor = ThresholdPostProcessor() + self.evaluator = BinarySegmentationEvaluator( + clip_distance=task_config.clip_distance, + tol_distance=task_config.tol_distance, + channels=task_config.channels, + ) diff --git a/dacapo/experiments/tasks/hot_distance_task_config.py b/dacapo/experiments/tasks/hot_distance_task_config.py new file mode 100644 index 000000000..951226476 --- /dev/null +++ b/dacapo/experiments/tasks/hot_distance_task_config.py @@ -0,0 +1,47 @@ +import attr + +from .hot_distance_task import HotDistanceTask +from .task_config import TaskConfig + +from typing import List + + +class HotDistanceTaskConfig(TaskConfig): + """This is a Hot Distance task config used for generating and + evaluating signed distance transforms as a way of generating + segmentations. + + The advantage of generating distance transforms over regular + affinities is you can get a denser signal, i.e. 1 misclassified + pixel in an affinity prediction could merge 2 otherwise very + distinct objects, this cannot happen with distances. + """ + + task_type = HotDistanceTask + + channels: List[str] = attr.ib(metadata={"help_text": "A list of channel names."}) + clip_distance: float = attr.ib( + metadata={ + "help_text": "Maximum distance to consider for false positive/negatives." + }, + ) + tol_distance: float = attr.ib( + metadata={ + "help_text": "Tolerance distance for counting false positives/negatives" + }, + ) + scale_factor: float = attr.ib( + default=1, + metadata={ + "help_text": "The amount by which to scale distances before applying " + "a tanh normalization." + }, + ) + mask_distances: bool = attr.ib( + default=False, + metadata={ + "help_text": "Whether or not to mask out regions where the true distance to " + "object boundary cannot be known. This is anywhere that the distance to crop boundary " + "is less than the distance to object boundary." + }, + ) diff --git a/dacapo/experiments/tasks/losses/__init__.py b/dacapo/experiments/tasks/losses/__init__.py index b675faa96..f1db3586b 100644 --- a/dacapo/experiments/tasks/losses/__init__.py +++ b/dacapo/experiments/tasks/losses/__init__.py @@ -2,3 +2,4 @@ from .mse_loss import MSELoss # noqa from .loss import Loss # noqa from .affinities_loss import AffinitiesLoss # noqa +from .hot_distance_loss import HotDistanceLoss # noqa diff --git a/dacapo/experiments/tasks/losses/hot_distance_loss.py b/dacapo/experiments/tasks/losses/hot_distance_loss.py new file mode 100644 index 000000000..2e99ab5e1 --- /dev/null +++ b/dacapo/experiments/tasks/losses/hot_distance_loss.py @@ -0,0 +1,29 @@ +from .loss import Loss +import torch + + +# HotDistance is used for predicting hot and distance maps at the same time. +# The first half of the channels are the hot maps, the second half are the distance maps. +# The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance maps. +# Model should predict twice the number of channels as the target. +class HotDistanceLoss(Loss): + def compute(self, prediction, target, weight): + target_hot, target_distance = self.split(target) + prediction_hot, prediction_distance = self.split(prediction) + weight_hot, weight_distance = self.split(weight) + return self.hot_loss( + prediction_hot, target_hot, weight_hot + ) + self.distance_loss(prediction_distance, target_distance, weight_distance) + + def hot_loss(self, prediction, target, weight): + return torch.nn.BCELoss().forward(prediction * weight, target * weight) + + def distance_loss(self, prediction, target, weight): + return torch.nn.MSELoss().forward(prediction * weight, target * weight) + + def split(self, x): + assert ( + x.shape[0] % 2 == 0 + ), f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance." + mid = x.shape[0] // 2 + return x[:mid], x[-mid:] diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py index 709d1de34..799f2651e 100644 --- a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py @@ -19,7 +19,7 @@ def set_prediction(self, prediction_array_identifier): prediction_array_identifier ) - def process(self, parameters, output_array_identifier): + def process(self, parameters, output_array_identifier, overwrite: bool = False): output_array = ZarrArray.create_from_array_identifier( output_array_identifier, [dim for dim in self.prediction_array.axes if dim != "c"], @@ -27,6 +27,7 @@ def process(self, parameters, output_array_identifier): None, self.prediction_array.voxel_size, np.uint8, + overwrite=overwrite, ) output_array[self.prediction_array.roi] = np.argmax( diff --git a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py index 5a2c7810a..ddb249539 100644 --- a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py @@ -21,7 +21,7 @@ def enumerate_parameters(self) -> Iterable[DummyPostProcessorParameters]: def set_prediction(self, prediction_array): pass - def process(self, parameters, output_array_identifier): + def process(self, parameters, output_array_identifier, overwrite: bool = False): # store some dummy data f = zarr.open(str(output_array_identifier.container), "a") f[output_array_identifier.dataset] = np.ones((10, 10, 10)) * parameters.min_size diff --git a/dacapo/experiments/tasks/post_processors/post_processor.py b/dacapo/experiments/tasks/post_processors/post_processor.py index 020361cb9..4e4102d6b 100644 --- a/dacapo/experiments/tasks/post_processors/post_processor.py +++ b/dacapo/experiments/tasks/post_processors/post_processor.py @@ -33,6 +33,8 @@ def process( self, parameters: "PostProcessorParameters", output_array_identifier: "LocalArrayIdentifier", + overwrite: "bool", + blockwise: "bool", ) -> "Array": """Convert predictions into the final output.""" pass diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index 67ffdd066..32bf4cfc0 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -28,6 +28,7 @@ def process( self, parameters: "PostProcessorParameters", output_array_identifier: "LocalArrayIdentifier", + overwrite: bool = False, ) -> ZarrArray: # TODO: Investigate Liskov substitution princple and whether it is a problem here # OOP theory states the super class should always be replaceable with its subclasses @@ -47,6 +48,7 @@ def process( self.prediction_array.num_channels, self.prediction_array.voxel_size, np.uint8, + overwrite=overwrite, ) output_array[self.prediction_array.roi] = ( diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index 8fa6104bc..307806772 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -24,7 +24,9 @@ def enumerate_parameters(self): """Enumerate all possible parameters of this post-processor. Should return instances of ``PostProcessorParameters``.""" - for i, bias in enumerate([0.1, 0.5, 0.9]): + for i, bias in enumerate( + [0.1, 0.3, 0.5, 0.7, 0.9] + ): # TODO: add this to the config yield WatershedPostProcessorParameters(id=i, bias=bias) def set_prediction(self, prediction_array_identifier): @@ -32,40 +34,65 @@ def set_prediction(self, prediction_array_identifier): prediction_array_identifier ) - def process(self, parameters, output_array_identifier): - output_array = ZarrArray.create_from_array_identifier( - output_array_identifier, - [axis for axis in self.prediction_array.axes if axis != "c"], - self.prediction_array.roi, - None, - self.prediction_array.voxel_size, - np.uint64, - ) - # if a previous segmentation is provided, it must have a "grid graph" - # in its metadata. - pred_data = self.prediction_array[self.prediction_array.roi] - affs = pred_data[: len(self.offsets)] - segmentation = mws.agglom( - affs - 0.5, - self.offsets, - ) - # filter fragments - average_affs = np.mean(affs, axis=0) - - filtered_fragments = [] - - fragment_ids = np.unique(segmentation) - - for fragment, mean in zip( - fragment_ids, measurements.mean(average_affs, segmentation, fragment_ids) - ): - if mean < 0.5: - filtered_fragments.append(fragment) - - filtered_fragments = np.array(filtered_fragments, dtype=segmentation.dtype) - replace = np.zeros_like(filtered_fragments) - segmentation = npi.remap(segmentation, filtered_fragments, replace) - - output_array[self.prediction_array.roi] = segmentation - - return output_array + def process( + self, + parameters, + output_array_identifier, + overwrite: bool = False, + blockwise: bool = False, + ): # TODO: will probably break with large arrays... + if not blockwise: + output_array = ZarrArray.create_from_array_identifier( + output_array_identifier, + [axis for axis in self.prediction_array.axes if axis != "c"], + self.prediction_array.roi, + None, + self.prediction_array.voxel_size, + np.uint64, + overwrite=overwrite, + ) + # if a previous segmentation is provided, it must have a "grid graph" + # in its metadata. + # pred_data = self.prediction_array[self.prediction_array.roi] + # affs = pred_data[: len(self.offsets)].astype( + # np.float64 + # ) # TODO: shouldn't need to be float64 + affs = self.prediction_array[self.prediction_array.roi][: len(self.offsets)] + if affs.dtype == np.uint8: + affs = affs.astype(np.float64) / 255.0 + else: + affs = affs.astype(np.float64) + segmentation = mws.agglom( + affs - parameters.bias, + self.offsets, + ) + # filter fragments + average_affs = np.mean(affs, axis=0) + + filtered_fragments = [] + + fragment_ids = np.unique(segmentation) + + for fragment, mean in zip( + fragment_ids, + measurements.mean(average_affs, segmentation, fragment_ids), + ): + if mean < parameters.bias: + filtered_fragments.append(fragment) + + filtered_fragments = np.array(filtered_fragments, dtype=segmentation.dtype) + replace = np.zeros_like(filtered_fragments) + + # DGA: had to add in flatten and reshape since remap (in particular indices) didn't seem to work with ndarrays for the input + if filtered_fragments.size > 0: + segmentation = npi.remap( + segmentation.flatten(), filtered_fragments, replace + ).reshape(segmentation.shape) + + output_array[self.prediction_array.roi] = segmentation + + return output_array + else: + raise NotImplementedError( + "Blockwise processing not yet implemented." + ) # TODO: add rusty mws diff --git a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py new file mode 100644 index 000000000..fc73cb0ea --- /dev/null +++ b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py @@ -0,0 +1,280 @@ +from dacapo.experiments.arraytypes.probabilities import ProbabilityArray +from .predictor import Predictor +from dacapo.experiments import Model +from dacapo.experiments.arraytypes import DistanceArray +from dacapo.experiments.datasplits.datasets.arrays import NumpyArray +from dacapo.utils.balance_weights import balance_weights + +from funlib.geometry import Coordinate + +from scipy.ndimage.morphology import distance_transform_edt +import numpy as np +import torch + +import logging +from typing import List + +logger = logging.getLogger(__name__) + + +class HotDistancePredictor(Predictor): + """ + Predict signed distances and one hot embedding (as a proxy task) for a binary segmentation task. + Distances deep within background are pushed to -inf, distances deep within + the foreground object are pushed to inf. After distances have been + calculated they are passed through a tanh so that distances saturate at +-1. + Multiple classes can be predicted via multiple distance channels. The names + of each class that is being segmented can be passed in as a list of strings + in the channels argument. + """ + + def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool): + self.channels = ( + channels * 2 + ) # one hot + distance (TODO: add hot/distance to channel names) + self.norm = "tanh" + self.dt_scale_factor = scale_factor + self.mask_distances = mask_distances + + self.max_distance = 1 * scale_factor + self.epsilon = 5e-2 # TODO: should be a config parameter + self.threshold = 0.8 # TODO: should be a config parameter + + @property + def embedding_dims(self): + return len(self.channels) + + @property + def classes(self): + return len(self.channels) // 2 + + def create_model(self, architecture): + if architecture.dims == 2: + head = torch.nn.Conv2d( + architecture.num_out_channels, self.embedding_dims, kernel_size=3 + ) + elif architecture.dims == 3: + head = torch.nn.Conv3d( + architecture.num_out_channels, self.embedding_dims, kernel_size=3 + ) + + return Model(architecture, head) + + def create_target(self, gt): + target = self.process(gt.data, gt.voxel_size, self.norm, self.dt_scale_factor) + return NumpyArray.from_np_array( + target, + gt.roi, + gt.voxel_size, + gt.axes, + ) + + def create_weight(self, gt, target, mask, moving_class_counts=None): + # balance weights independently for each channel + one_hot_weights, one_hot_moving_class_counts = balance_weights( + gt[target.roi], + 2, + slab=tuple(1 if c == "c" else -1 for c in gt.axes), + masks=[mask[target.roi]], + moving_counts=moving_class_counts[: self.classes], + ) + + if self.mask_distances: + distance_mask = self.create_distance_mask( + target[target.roi][-self.classes :], + mask[target.roi], + target.voxel_size, + self.norm, + self.dt_scale_factor, + ) + else: + distance_mask = np.ones_like(target.data) + + distance_weights, distance_moving_class_counts = balance_weights( + gt[target.roi], + 2, + slab=tuple(1 if c == "c" else -1 for c in gt.axes), + masks=[mask[target.roi], distance_mask], + moving_counts=moving_class_counts[-self.classes :], + ) + + weights = np.concatenate((one_hot_weights, distance_weights)) + moving_class_counts = np.concatenate( + (one_hot_moving_class_counts, distance_moving_class_counts) + ) + return ( + NumpyArray.from_np_array( + weights, + gt.roi, + gt.voxel_size, + gt.axes, + ), + moving_class_counts, + ) + + @property + def output_array_type(self): + # technically this is a probability array + distance array, but it is only ever referenced for interpolatability (which is true for both) (TODO) + return ProbabilityArray(self.embedding_dims) + + def create_distance_mask( + self, + distances: np.ndarray, + mask: np.ndarray, + voxel_size: Coordinate, + normalize=None, + normalize_args=None, + ): + mask_output = mask.copy() + for i, (channel_distance, channel_mask) in enumerate(zip(distances, mask)): + tmp = np.zeros( + np.array(channel_mask.shape) + np.array((2,) * channel_mask.ndim), + dtype=channel_mask.dtype, + ) + slices = tmp.ndim * (slice(1, -1),) + tmp[slices] = channel_mask + boundary_distance = distance_transform_edt( + tmp, + sampling=voxel_size, + ) + if self.epsilon is None: + add = 0 + else: + add = self.epsilon + boundary_distance = self.__normalize( + boundary_distance[slices], normalize, normalize_args + ) + + channel_mask_output = mask_output[i] + logging.debug( + "Total number of masked in voxels before distance masking {0:}".format( + np.sum(channel_mask_output) + ) + ) + channel_mask_output[ + np.logical_and( + np.clip(abs(channel_distance) + add, 0, self.threshold) + >= boundary_distance, + channel_distance >= 0, + ) + ] = 0 + logging.debug( + "Total number of masked in voxels after postive distance masking {0:}".format( + np.sum(channel_mask_output) + ) + ) + channel_mask_output[ + np.logical_and( + np.clip(abs(channel_distance) + add, 0, self.threshold) + >= boundary_distance, + channel_distance <= 0, + ) + ] = 0 + logging.debug( + "Total number of masked in voxels after negative distance masking {0:}".format( + np.sum(channel_mask_output) + ) + ) + return mask_output + + def process( + self, + labels: np.ndarray, + voxel_size: Coordinate, + normalize=None, + normalize_args=None, + ): + all_distances = np.zeros(labels.shape, dtype=np.float32) - 1 + for ii, channel in enumerate(labels): + boundaries = self.__find_boundaries(channel) + + # mark boundaries with 0 (not 1) + boundaries = 1.0 - boundaries + + if np.sum(boundaries == 0) == 0: + max_distance = min( + dim * vs / 2 for dim, vs in zip(channel.shape, voxel_size) + ) + if np.sum(channel) == 0: + distances = -np.ones(channel.shape, dtype=np.float32) * max_distance + else: + distances = np.ones(channel.shape, dtype=np.float32) * max_distance + else: + # get distances (voxel_size/2 because image is doubled) + distances = distance_transform_edt( + boundaries, sampling=tuple(float(v) / 2 for v in voxel_size) + ) + distances = distances.astype(np.float32) + + # restore original shape + downsample = (slice(None, None, 2),) * len(voxel_size) + distances = distances[downsample] + + # todo: inverted distance + distances[channel == 0] = -distances[channel == 0] + + if normalize is not None: + distances = self.__normalize(distances, normalize, normalize_args) + + all_distances[ii] = distances + + return np.concatenate((labels, all_distances)) + + def __find_boundaries(self, labels): + # labels: 1 1 1 1 0 0 2 2 2 2 3 3 n + # shift : 1 1 1 1 0 0 2 2 2 2 3 n - 1 + # diff : 0 0 0 1 0 1 0 0 0 1 0 n - 1 + # bound.: 00000001000100000001000 2n - 1 + + logger.debug("computing boundaries for %s", labels.shape) + + dims = len(labels.shape) + in_shape = labels.shape + out_shape = tuple(2 * s - 1 for s in in_shape) + + boundaries = np.zeros(out_shape, dtype=bool) + + logger.debug("boundaries shape is %s", boundaries.shape) + + for d in range(dims): + logger.debug("processing dimension %d", d) + + shift_p = [slice(None)] * dims + shift_p[d] = slice(1, in_shape[d]) + + shift_n = [slice(None)] * dims + shift_n[d] = slice(0, in_shape[d] - 1) + + diff = (labels[tuple(shift_p)] - labels[tuple(shift_n)]) != 0 + + logger.debug("diff shape is %s", diff.shape) + + target = [slice(None, None, 2)] * dims + target[d] = slice(1, out_shape[d], 2) + + logger.debug("target slices are %s", target) + + boundaries[tuple(target)] = diff + + return boundaries + + def __normalize(self, distances, norm, normalize_args): + if norm == "tanh": + scale = normalize_args + return np.tanh(distances / scale) + else: + raise ValueError("Only tanh is supported for normalization") + + def gt_region_for_roi(self, target_spec): + if self.mask_distances: + gt_spec = target_spec.copy() + gt_spec.roi = gt_spec.roi.grow( + Coordinate((self.max_distance,) * gt_spec.voxel_size.dims), + Coordinate((self.max_distance,) * gt_spec.voxel_size.dims), + ).snap_to_grid(gt_spec.voxel_size, mode="shrink") + else: + gt_spec = target_spec.copy() + return gt_spec + + def padding(self, gt_voxel_size: Coordinate) -> Coordinate: + return Coordinate((self.max_distance,) * gt_voxel_size.dims) diff --git a/dacapo/experiments/trainers/gp_augments/__init__.py b/dacapo/experiments/trainers/gp_augments/__init__.py index 0c93d4603..7e901c76f 100644 --- a/dacapo/experiments/trainers/gp_augments/__init__.py +++ b/dacapo/experiments/trainers/gp_augments/__init__.py @@ -4,3 +4,4 @@ from .gamma_config import GammaAugmentConfig from .intensity_config import IntensityAugmentConfig from .intensity_scale_shift_config import IntensityScaleShiftAugmentConfig +from .gaussian_noise_config import GaussianNoiseAugmentConfig diff --git a/dacapo/experiments/trainers/gp_augments/gaussian_noise_config.py b/dacapo/experiments/trainers/gp_augments/gaussian_noise_config.py new file mode 100644 index 000000000..ab3694d09 --- /dev/null +++ b/dacapo/experiments/trainers/gp_augments/gaussian_noise_config.py @@ -0,0 +1,22 @@ +from .augment_config import AugmentConfig + +import gunpowder as gp + +import attr + + +@attr.s +class GaussianNoiseAugmentConfig(AugmentConfig): + mean: float = attr.ib( + metadata={"help_text": "The mean of the gaussian noise to apply to your data."}, + default=0.0, + ) + var: float = attr.ib( + metadata={"help_text": "The variance of the gaussian noise."}, + default=0.05, + ) + + def node(self, raw_key: gp.ArrayKey, _gt_key=None, _mask_key=None): + return gp.NoiseAugment( + array=raw_key, mode="gaussian", mean=self.mean, var=self.var + ) diff --git a/dacapo/experiments/trainers/gp_augments/simple_config.py b/dacapo/experiments/trainers/gp_augments/simple_config.py index 86de2161c..c5c1d9456 100644 --- a/dacapo/experiments/trainers/gp_augments/simple_config.py +++ b/dacapo/experiments/trainers/gp_augments/simple_config.py @@ -1,3 +1,4 @@ +from typing import List, Optional from .augment_config import AugmentConfig import gunpowder as gp @@ -7,5 +8,48 @@ @attr.s class SimpleAugmentConfig(AugmentConfig): - def node(self, _raw_key=None, _gt_key=None, _mask_key=None): - return gp.SimpleAugment() + mirror_only: Optional[List[int]] = attr.ib( + default=None, + metadata={ + "help_text": ( + "If set, only mirror between the given axes. This is useful to exclude channels that have a set direction, like time." + ) + }, + ) + transpose_only: Optional[List[int]] = attr.ib( + default=None, + metadata={ + "help_text": ( + "If set, only transpose between the given axes. This is useful to exclude channels that have a set direction, like time." + ) + }, + ) + mirror_probs: Optional[List[float]] = attr.ib( + default=None, + metadata={ + "help_text": ( + "Probability of mirroring along each axis. Defaults to 0.5 for each axis." + ) + }, + ) + transpose_probs: Optional[List[float]] = attr.ib( + default=None, + metadata={ + "help_text": ( + "Probability of transposing along each axis. Defaults to 0.5 for each axis." + ) + }, + ) + + def node( + self, + _raw_key=None, + _gt_key=None, + _mask_key=None, + ): + return gp.SimpleAugment( + self.mirror_only, + self.transpose_only, + self.mirror_probs, + self.transpose_probs, + ) diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index efec630f0..9a445932d 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -37,6 +37,8 @@ def __init__(self, trainer_config): self.print_profiling = 100 self.snapshot_iteration = trainer_config.snapshot_interval self.min_masked = trainer_config.min_masked + self.reject_probability = trainer_config.reject_probability + self.weighted_reject = trainer_config.weighted_reject self.augments = trainer_config.augments self.mask_integral_downsample_factor = 4 @@ -46,12 +48,14 @@ def __init__(self, trainer_config): def create_optimizer(self, model): optimizer = torch.optim.RAdam(lr=self.learning_rate, params=model.parameters()) - self.scheduler = torch.optim.lr_scheduler.LinearLR( - optimizer, - start_factor=0.01, - end_factor=1.0, - total_iters=1000, - last_epoch=-1, + self.scheduler = ( + torch.optim.lr_scheduler.LinearLR( # TODO: add scheduler to config + optimizer, + start_factor=0.01, + end_factor=1.0, + total_iters=1000, + last_epoch=-1, + ) ) return optimizer @@ -60,6 +64,7 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): output_shape = Coordinate(model.output_shape) # get voxel sizes + # TODO: make dataset specific / resample raw_voxel_size = datasets[0].raw.voxel_size prediction_voxel_size = model.scale(raw_voxel_size) @@ -80,7 +85,7 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): mask_placeholder = gp.ArrayKey("MASK_PLACEHOLDER") target_key = gp.ArrayKey("TARGET") - dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT") + dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT") # TODO: put these back in datasets_weight_key = gp.ArrayKey("DATASETS_WEIGHT") weight_key = gp.ArrayKey("WEIGHT") sample_points_key = gp.GraphKey("SAMPLE_POINTS") @@ -88,7 +93,7 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): # Get source nodes dataset_sources = [] weights = [] - for dataset in datasets: + for dataset in datasets: # TODO: add automatic resampling? weights.append(dataset.weight) assert isinstance(dataset.weight, int), dataset @@ -141,24 +146,46 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): ) ) - dataset_source += gp.Reject(mask_placeholder, 1e-6) + if self.weighted_reject: + # Add predictor nodes to dataset_source + for augment in self.augments: + dataset_source += augment.node(raw_key, gt_key, mask_key) - for augment in self.augments: - dataset_source += augment.node(raw_key, gt_key, mask_key) + dataset_source += DaCapoTargetFilter( + task.predictor, + gt_key=gt_key, + weights_key=dataset_weight_key, + mask_key=mask_key, + ) - # Add predictor nodes to dataset_source - dataset_source += DaCapoTargetFilter( - task.predictor, - gt_key=gt_key, - weights_key=dataset_weight_key, - mask_key=mask_key, - ) + dataset_source += gp.Reject( + mask=dataset_weight_key, + min_masked=self.min_masked, + reject_probability=self.reject_probability, + ) + else: + dataset_source += gp.Reject( + mask=mask_placeholder, + min_masked=self.min_masked, + reject_probability=self.reject_probability, + ) + + for augment in self.augments: + dataset_source += augment.node(raw_key, gt_key, mask_key) + + # Add predictor nodes to dataset_source + dataset_source += DaCapoTargetFilter( + task.predictor, + gt_key=gt_key, + weights_key=dataset_weight_key, + mask_key=mask_key, + ) dataset_sources.append(dataset_source) pipeline = tuple(dataset_sources) + gp.RandomProvider(weights) # Add predictor nodes to pipeline - pipeline += DaCapoTargetFilter( + pipeline += DaCapoTargetFilter( # TODO: why are there two of these? task.predictor, gt_key=gt_key, target_key=target_key, @@ -255,7 +282,7 @@ def iterate(self, num_iterations, model, optimizer, device): } if mask is not None: snapshot_arrays["volumes/mask"] = mask - logger.warning( + logger.info( f"Saving Snapshot. Iteration: {iteration}, " f"Loss: {loss.detach().cpu().numpy().item()}!" ) diff --git a/dacapo/experiments/trainers/gunpowder_trainer_config.py b/dacapo/experiments/trainers/gunpowder_trainer_config.py index ae4243059..032dba23d 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer_config.py +++ b/dacapo/experiments/trainers/gunpowder_trainer_config.py @@ -27,5 +27,7 @@ class GunpowderTrainerConfig(TrainerConfig): default=None, metadata={"help_text": "Number of iterations before saving a new snapshot."}, ) - min_masked: Optional[float] = attr.ib(default=0.15) - clip_raw: bool = attr.ib(default=True) + min_masked: Optional[float] = attr.ib(default=1e-6) + reject_probability: Optional[float or None] = attr.ib(default=1) + weighted_reject: bool = attr.ib(default=False) + clip_raw: bool = attr.ib(default=False) diff --git a/dacapo/experiments/validation_scores.py b/dacapo/experiments/validation_scores.py index 8fba05687..17727cc22 100644 --- a/dacapo/experiments/validation_scores.py +++ b/dacapo/experiments/validation_scores.py @@ -113,7 +113,7 @@ def get_best( best value in two seperate arrays. """ if "criteria" in data.coords.keys(): - if len(data.coords["criteria"].shape) == 1: + if len(data.coords["criteria"].shape) > 1: criteria_bests: List[Tuple[xr.DataArray, xr.DataArray]] = [] for criterion in data.coords["criteria"].values: if self.evaluation_scores.higher_is_better(criterion.item()): @@ -142,7 +142,10 @@ def get_best( return (da_best_indexes, da_best_scores) else: if self.evaluation_scores.higher_is_better( - data.coords["criteria"].item() + list(data.coords["criteria"].values)[ + 0 + ] # TODO: what is the intended behavior here? (hot fix in place) + # data.coords["criteria"].item() ): return ( data.idxmax(dim, skipna=True, fill_value=None), diff --git a/dacapo/gp/dacapo_create_target.py b/dacapo/gp/dacapo_create_target.py index f136c5c7b..42358b7b0 100644 --- a/dacapo/gp/dacapo_create_target.py +++ b/dacapo/gp/dacapo_create_target.py @@ -84,7 +84,9 @@ def process(self, batch, request): gt_array = NumpyArray.from_gp_array(batch[self.gt_key]) target_array = self.predictor.create_target(gt_array) - mask_array = NumpyArray.from_gp_array(batch[self.mask_key]) + mask_array = NumpyArray.from_gp_array( + batch[self.mask_key] + ) # TODO: doesn't this require mask_key to be set? if self.target_key is not None: request_spec = request[self.target_key] diff --git a/dacapo/gp/elastic_augment_fuse.py b/dacapo/gp/elastic_augment_fuse.py index c7163f68d..70d4bc059 100644 --- a/dacapo/gp/elastic_augment_fuse.py +++ b/dacapo/gp/elastic_augment_fuse.py @@ -138,7 +138,7 @@ def _min_max_mean_std(ndarray, prefix=""): return "" -class ElasticAugment(BatchFilter): +class ElasticAugment(BatchFilter): # TODO: replace DeformAugment node from gunpowder """ Elasticly deform a batch. Requests larger batches upstream to avoid data loss due to rotation and jitter. diff --git a/dacapo/predict.py b/dacapo/predict.py index 340517528..4fed2d484 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -71,6 +71,7 @@ def predict( # prepare data source pipeline = DaCapoArraySource(raw_array, raw) + pipeline += gp.Normalize(raw) # raw: (c, d, h, w) pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims)) # raw: (c, d, h, w) diff --git a/dacapo/store/local_array_store.py b/dacapo/store/local_array_store.py index c1581fc7b..9265c8f67 100644 --- a/dacapo/store/local_array_store.py +++ b/dacapo/store/local_array_store.py @@ -55,7 +55,7 @@ def validation_input_arrays( If we don't store these we would have to look up the datasplit config and figure out where to find the inputs for each run. If we write the data then we don't need to search for it. - This convenience comes at the cost of some extra memory usage. + This convenience comes at the cost of extra memory usage. #TODO: FIX THIS - refactor """ container = self.validation_container(run_name).container diff --git a/dacapo/store/local_weights_store.py b/dacapo/store/local_weights_store.py index c5f0ba5ff..c34afc6c3 100644 --- a/dacapo/store/local_weights_store.py +++ b/dacapo/store/local_weights_store.py @@ -62,7 +62,9 @@ def retrieve_weights(self, run: str, iteration: int) -> Weights: return weights - def _retrieve_weights(self, run: str, key: str) -> Weights: + def _retrieve_weights( + self, run: str, key: str + ) -> Weights: # TODO: redundant with above? weights_name = self.__get_weights_dir(run) / key if not weights_name.exists(): weights_name = self.__get_weights_dir(run) / "iterations" / key @@ -104,14 +106,14 @@ def retrieve_best(self, run: str, dataset: str, criterion: str) -> int: logger.info("Retrieving weights for run %s, criterion %s", run, criterion) weights_info = json.loads( - (self.__get_weights_dir(run) / criterion / f"{dataset}.json") + (self.__get_weights_dir(run) / dataset / f"{criterion}.json") .open("r") .read() ) return weights_info["iteration"] - def _load_best(self, run: Run, criterion: str): + def _load_best(self, run: Run, criterion: str): # TODO: probably won't work logger.info("Retrieving weights for run %s, criterion %s", run, criterion) weights_name = self.__get_weights_dir(run) / f"{criterion}" diff --git a/dacapo/train.py b/dacapo/train.py index 7beb096b4..1c104a55f 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -1,3 +1,4 @@ +from copy import deepcopy from dacapo.store.create_store import create_array_store from .experiments import Run from .compute_context import LocalTorch, ComputeContext @@ -10,6 +11,7 @@ import logging logger = logging.getLogger(__name__) +logger.setLevel("INFO") def train(run_name: str, compute_context: ComputeContext = LocalTorch()): @@ -101,7 +103,16 @@ def train_run( logger.error( f"Found weights for iteration {latest_weights_iteration}, but " f"run {run.name} was only trained until {trained_until}. " + "Filling stats with last observed values." ) + last_iteration_stats = run.training_stats.iteration_stats[-1] + for i in range( + last_iteration_stats.iteration, latest_weights_iteration - 1 + ): + new_iteration_stats = deepcopy(last_iteration_stats) + new_iteration_stats.iteration = i + 1 + run.training_stats.add_iteration_stats(new_iteration_stats) + trained_until = run.training_stats.trained_until() # start/resume training diff --git a/dacapo/utils/balance_weights.py b/dacapo/utils/balance_weights.py index f5adcffca..5cd5ee597 100644 --- a/dacapo/utils/balance_weights.py +++ b/dacapo/utils/balance_weights.py @@ -12,6 +12,7 @@ def balance_weights( clipmin: float = 0.05, clipmax: float = 0.95, moving_counts: Optional[List[Dict[int, Tuple[int, int]]]] = None, + cross_class: bool = True, ): if moving_counts is None: moving_counts = [] @@ -29,10 +30,6 @@ def balance_weights( # initialize error scale with 1s error_scale = np.ones(label_data.shape, dtype=np.float32) - # set error_scale to 0 in masked-out areas - for mask in masks: - error_scale = error_scale * mask - if slab is None: slab = error_scale.shape else: @@ -77,4 +74,14 @@ def balance_weights( # scale_slab the masked-in scale_slab with the class weights scale_slab *= np.take(w, labels_slab) + if cross_class: + # get maximum error scale using first dimension + shape = error_scale.shape + error_scale = np.max(error_scale, axis=0) + error_scale = np.broadcast_to(error_scale, shape) + + # set error_scale to 0 in masked-out areas + for mask in masks: + error_scale = error_scale * mask + return error_scale, moving_counts diff --git a/dacapo/utils/voi.py b/dacapo/utils/voi.py index e5399a443..e64135bbc 100644 --- a/dacapo/utils/voi.py +++ b/dacapo/utils/voi.py @@ -10,6 +10,7 @@ import scipy.sparse as sparse +# TODO: Why are we not using this? def voi(reconstruction, groundtruth, ignore_reconstruction=[], ignore_groundtruth=[0]): """Return the conditional entropies of the variation of information metric. [1] diff --git a/dacapo/validate.py b/dacapo/validate.py index a1cf9da7d..39b231b8e 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -12,6 +12,7 @@ import torch from pathlib import Path +from reloading import reloading import logging logger = logging.getLogger(__name__) @@ -47,6 +48,7 @@ def validate( return validate_run(run, iteration, compute_context=compute_context) +# @reloading # allows us to fix validation bugs without interrupting training def validate_run( run: Run, iteration: int, compute_context: ComputeContext = LocalTorch() ): @@ -54,85 +56,97 @@ def validate_run( load the weights of that iteration, it is assumed that the model is already loaded correctly. Returns the best parameters and scores for this iteration.""" - # set benchmark flag to True for performance - torch.backends.cudnn.benchmark = True - run.model.eval() - - if ( - run.datasplit.validate is None - or len(run.datasplit.validate) == 0 - or run.datasplit.validate[0].gt is None - ): - logger.info("Cannot validate run %s. Continuing training!", run.name) - return None, None - - # get array and weight store - weights_store = create_weights_store() - array_store = create_array_store() - iteration_scores = [] - - # get post processor and evaluator - post_processor = run.task.post_processor - evaluator = run.task.evaluator - - # Initialize the evaluator with the best scores seen so far - evaluator.set_best(run.validation_scores) - - for validation_dataset in run.datasplit.validate: - assert ( - validation_dataset.gt is not None - ), "We do not yet support validating on datasets without ground truth" - logger.info( - "Validating run %s on dataset %s", run.name, validation_dataset.name - ) + try: # we don't want this to hold up training + # set benchmark flag to True for performance + torch.backends.cudnn.benchmark = True + run.model.to(compute_context.device) + run.model.eval() - ( - input_raw_array_identifier, - input_gt_array_identifier, - ) = array_store.validation_input_arrays(run.name, validation_dataset.name) if ( - not Path( - f"{input_raw_array_identifier.container}/{input_raw_array_identifier.dataset}" - ).exists() - or not Path( - f"{input_gt_array_identifier.container}/{input_gt_array_identifier.dataset}" - ).exists() + run.datasplit.validate is None + or len(run.datasplit.validate) == 0 + or run.datasplit.validate[0].gt is None ): - logger.info("Copying validation inputs!") - input_voxel_size = validation_dataset.raw.voxel_size - output_voxel_size = run.model.scale(input_voxel_size) - input_shape = run.model.eval_input_shape - input_size = input_voxel_size * input_shape - output_shape = run.model.compute_output_shape(input_shape)[1] - output_size = output_voxel_size * output_shape - context = (input_size - output_size) / 2 - output_roi = validation_dataset.gt.roi - - input_roi = ( - output_roi.grow(context, context) - .snap_to_grid(validation_dataset.raw.voxel_size, mode="grow") - .intersect(validation_dataset.raw.roi) + logger.info("Cannot validate run %s. Continuing training!", run.name) + return None, None + + # get array and weight store + weights_store = create_weights_store() + array_store = create_array_store() + iteration_scores = [] + + # get post processor and evaluator + post_processor = run.task.post_processor + evaluator = run.task.evaluator + + # Initialize the evaluator with the best scores seen so far + evaluator.set_best(run.validation_scores) + + for validation_dataset in run.datasplit.validate: + if validation_dataset.gt is None: + logger.error( + "We do not yet support validating on datasets without ground truth" + ) + raise NotImplementedError + + logger.info( + "Validating run %s on dataset %s", run.name, validation_dataset.name ) - input_raw = ZarrArray.create_from_array_identifier( + + ( input_raw_array_identifier, - validation_dataset.raw.axes, - input_roi, - validation_dataset.raw.num_channels, - validation_dataset.raw.voxel_size, - validation_dataset.raw.dtype, - name=f"{run.name}_validation_raw", - write_size=input_size, - ) - input_raw[input_roi] = validation_dataset.raw[input_roi] - input_gt = ZarrArray.create_from_array_identifier( input_gt_array_identifier, - validation_dataset.gt.axes, - output_roi, - validation_dataset.gt.num_channels, - validation_dataset.gt.voxel_size, - validation_dataset.gt.dtype, - name=f"{run.name}_validation_gt", - write_size=output_size, + ) = array_store.validation_input_arrays(run.name, validation_dataset.name) + if ( + not Path( + f"{input_raw_array_identifier.container}/{input_raw_array_identifier.dataset}" + ).exists() + or not Path( + f"{input_gt_array_identifier.container}/{input_gt_array_identifier.dataset}" + ).exists() + ): + logger.info("Copying validation inputs!") + input_voxel_size = validation_dataset.raw.voxel_size + output_voxel_size = run.model.scale(input_voxel_size) + input_shape = run.model.eval_input_shape + input_size = input_voxel_size * input_shape + output_shape = run.model.compute_output_shape(input_shape)[1] + output_size = output_voxel_size * output_shape + context = (input_size - output_size) / 2 + output_roi = validation_dataset.gt.roi + + input_roi = ( + output_roi.grow(context, context) + .snap_to_grid(validation_dataset.raw.voxel_size, mode="grow") + .intersect(validation_dataset.raw.roi) + ) + input_raw = ZarrArray.create_from_array_identifier( + input_raw_array_identifier, + validation_dataset.raw.axes, + input_roi, + validation_dataset.raw.num_channels, + validation_dataset.raw.voxel_size, + validation_dataset.raw.dtype, + name=f"{run.name}_validation_raw", + write_size=input_size, + ) + input_raw[input_roi] = validation_dataset.raw[input_roi] + input_gt = ZarrArray.create_from_array_identifier( + input_gt_array_identifier, + validation_dataset.gt.axes, + output_roi, + validation_dataset.gt.num_channels, + validation_dataset.gt.voxel_size, + validation_dataset.gt.dtype, + name=f"{run.name}_validation_gt", + write_size=output_size, + ) + input_gt[output_roi] = validation_dataset.gt[output_roi] + else: + logger.info("validation inputs already copied!") + + prediction_array_identifier = array_store.validation_prediction_array( + run.name, iteration, validation_dataset ) input_gt[output_roi] = validation_dataset.gt[output_roi] else: @@ -160,58 +174,126 @@ def validate_run( run.name, iteration, parameters, validation_dataset ) - post_processed_array = post_processor.process( - parameters, output_array_identifier - ) + post_processor.set_prediction(prediction_array_identifier) - scores = evaluator.evaluate(output_array_identifier, validation_dataset.gt) + dataset_iteration_scores = [] + # set up dict for overall best scores + overall_best_scores = {} for criterion in run.validation_scores.criteria: - # replace predictions in array with the new better predictions - if evaluator.is_best( - validation_dataset, - parameters, - criterion, - scores, - ): - best_array_identifier = array_store.best_validation_array( - run.name, criterion, index=validation_dataset.name - ) - best_array = ZarrArray.create_from_array_identifier( - best_array_identifier, - post_processed_array.axes, - post_processed_array.roi, - post_processed_array.num_channels, - post_processed_array.voxel_size, - post_processed_array.dtype, + overall_best_scores[criterion] = evaluator.get_overall_best( + validation_dataset, criterion + ) + + any_overall_best = False + output_array_identifiers = [] + for parameters in post_processor.enumerate_parameters(): + output_array_identifier = array_store.validation_output_array( + run.name, iteration, parameters, validation_dataset + ) + output_array_identifiers.append(output_array_identifier) + post_processed_array = post_processor.process( + parameters, output_array_identifier + ) + + try: + scores = evaluator.evaluate( + output_array_identifier, validation_dataset.gt ) - best_array[best_array.roi] = post_processed_array[ - post_processed_array.roi - ] - best_array.add_metadata( - { - "iteration": iteration, - criterion: getattr(scores, criterion), - "parameters_id": parameters.id, - } - ) - weights_store.store_best( - run, iteration, validation_dataset.name, criterion + for criterion in run.validation_scores.criteria: + # replace predictions in array with the new better predictions + if evaluator.is_best( + validation_dataset, + parameters, + criterion, + scores, + ): + # then this is the current best score for this parameter, but not necessarily the overall best + higher_is_better = scores.higher_is_better(criterion) + # initial_best_score = overall_best_scores[criterion] + current_score = getattr(scores, criterion) + if not overall_best_scores[ + criterion + ] or ( # TODO: should be in evaluator + ( + higher_is_better + and current_score > overall_best_scores[criterion] + ) + or ( + not higher_is_better + and current_score < overall_best_scores[criterion] + ) + ): + any_overall_best = True + overall_best_scores[criterion] = current_score + + # For example, if parameter 2 did better this round than it did in other rounds, but it was still worse than parameter 1 + # the code would have overwritten it below since all parameters write to the same file. Now each parameter will be its own file + # Either we do that, or we only write out the overall best, regardless of parameters + best_array_identifier = ( + array_store.best_validation_array( + run.name, + criterion, + index=validation_dataset.name, + ) + ) + best_array = ZarrArray.create_from_array_identifier( + best_array_identifier, + post_processed_array.axes, + post_processed_array.roi, + post_processed_array.num_channels, + post_processed_array.voxel_size, + post_processed_array.dtype, + ) + best_array[best_array.roi] = post_processed_array[ + post_processed_array.roi + ] + best_array.add_metadata( + { + "iteration": iteration, + criterion: getattr(scores, criterion), + "parameters_id": parameters.id, + } + ) + weights_store.store_best( + run, iteration, validation_dataset.name, criterion + ) + except: + logger.error( + f"Could not evaluate run {run.name} on dataset {validation_dataset.name} with parameters {parameters}.", + exc_info=True, ) - # delete current output. We only keep the best outputs as determined by - # the evaluator - array_store.remove(output_array_identifier) + dataset_iteration_scores.append( + [getattr(scores, criterion) for criterion in scores.criteria] + ) - dataset_iteration_scores.append( - [getattr(scores, criterion) for criterion in scores.criteria] - ) + if not any_overall_best: + # We only keep the best outputs as determined by the evaluator + for output_array_identifier in output_array_identifiers: + array_store.remove(prediction_array_identifier) + array_store.remove(output_array_identifier) - iteration_scores.append(dataset_iteration_scores) - array_store.remove(prediction_array_identifier) + iteration_scores.append(dataset_iteration_scores) - run.validation_scores.add_iteration_scores( - ValidationIterationScores(iteration, iteration_scores) - ) - stats_store = create_stats_store() - stats_store.store_validation_iteration_scores(run.name, run.validation_scores) + run.validation_scores.add_iteration_scores( + ValidationIterationScores(iteration, iteration_scores) + ) + stats_store = create_stats_store() + stats_store.store_validation_iteration_scores(run.name, run.validation_scores) + except Exception as e: + logger.error( + f"Validation failed for run {run.name} at iteration " f"{iteration}.", + exc_info=e, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("run_name", type=str) + parser.add_argument("iteration", type=int) + args = parser.parse_args() + + validate(args.run_name, args.iteration)