From 68258d742e21684b42dead502e49d5315c4b84e6 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Sat, 26 Aug 2023 16:06:46 -0400 Subject: [PATCH 01/33] feat: better tqdm, training reloading --- dacapo/train.py | 32 ++++++++++++++++++++++++-------- dacapo/validate.py | 13 +++++++++++++ 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/dacapo/train.py b/dacapo/train.py index 9203c1be3..0874f5002 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -123,6 +123,13 @@ def train_run( ) with run.trainer as trainer: + bar = tqdm( + total=run.train_until, + initial=trained_until, + desc="training", + # unit="iteration", + position=0, + ) while trained_until < run.train_until: # train for at most 100 iterations at a time, then store training stats iterations = min(100, run.train_until - trained_until) @@ -139,6 +146,8 @@ def train_run( iterations, ): run.training_stats.add_iteration_stats(iteration_stats) + bar.update(1) + bar.set_postfix_str(s=f"loss = {iteration_stats['loss']}") if (iteration_stats.iteration + 1) % run.validation_interval == 0: break @@ -161,14 +170,21 @@ def train_run( run.move_optimizer(torch.device("cpu"), empty_cuda_cache=True) weights_store.store_weights(run, iteration_stats.iteration + 1) - validate_run( - run, - iteration_stats.iteration + 1, - compute_context=compute_context, - ) - stats_store.store_validation_iteration_scores( - run.name, run.validation_scores - ) + try: + validate_run( + run, + iteration_stats.iteration + 1, + compute_context=compute_context, + ) + stats_store.store_validation_iteration_scores( + run.name, run.validation_scores + ) + except Exception as e: + logger.error( + f"Validation failed for run {run.name} at iteration " + f"{iteration_stats.iteration + 1}.", + exc_info=e, + ) stats_store.store_training_stats(run.name, run.training_stats) # make sure to move optimizer back to the correct device diff --git a/dacapo/validate.py b/dacapo/validate.py index 25b7463e1..ae07844bd 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -12,6 +12,7 @@ import torch from pathlib import Path +from reloading import reloading import logging logger = logging.getLogger(__name__) @@ -47,6 +48,7 @@ def validate( return validate_run(run, iteration, compute_context=compute_context) +@reloading def validate_run( run: Run, iteration: int, compute_context: ComputeContext = LocalTorch() ): @@ -213,3 +215,14 @@ def validate_run( ) stats_store = create_stats_store() stats_store.store_validation_iteration_scores(run.name, run.validation_scores) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("run_name", type=str) + parser.add_argument("iteration", type=int) + args = parser.parse_args() + + validate(args.run_name, args.iteration) From 9eee5a7a8eee96884a1005deade3750c32203e08 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Tue, 29 Aug 2023 11:59:49 -0400 Subject: [PATCH 02/33] feat: numpy version requirement (1.22.3); watershed Watershed bug fixes for float64, bias hyperparameter, & npi.remap usage --- .../post_processors/watershed_post_processor.py | 13 +++++++++---- setup.py | 2 +- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index 8fa6104bc..36ff106ae 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -44,9 +44,9 @@ def 
process(self, parameters, output_array_identifier): # if a previous segmentation is provided, it must have a "grid graph" # in its metadata. pred_data = self.prediction_array[self.prediction_array.roi] - affs = pred_data[: len(self.offsets)] + affs = pred_data[: len(self.offsets)].astype(np.float64) segmentation = mws.agglom( - affs - 0.5, + affs - parameters.bias, self.offsets, ) # filter fragments @@ -59,12 +59,17 @@ def process(self, parameters, output_array_identifier): for fragment, mean in zip( fragment_ids, measurements.mean(average_affs, segmentation, fragment_ids) ): - if mean < 0.5: + if mean < parameters.bias: filtered_fragments.append(fragment) filtered_fragments = np.array(filtered_fragments, dtype=segmentation.dtype) replace = np.zeros_like(filtered_fragments) - segmentation = npi.remap(segmentation, filtered_fragments, replace) + + # DGA: had to add in flatten and reshape since remap (in particular indices) didn't seem to work with ndarrays for the input + if filtered_fragments.size > 0: + segmentation = npi.remap( + segmentation.flatten(), filtered_fragments, replace + ).reshape(segmentation.shape) output_array[self.prediction_array.roi] = segmentation diff --git a/setup.py b/setup.py index 3ba1f0d0b..1721fa42f 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ entry_points={"console_scripts": ["dacapo=dacapo.cli:cli"]}, include_package_data=True, install_requires=[ - "numpy", + "numpy==1.22.3", "pyyaml", "zarr", "cattrs", From 703a31d58184d77fc2ce381c5deaab6a1ed9a594 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 30 Aug 2023 15:10:44 -0400 Subject: [PATCH 03/33] feat: improve training logging --- dacapo/train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dacapo/train.py b/dacapo/train.py index 0874f5002..82e2c772f 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -10,6 +10,7 @@ import logging logger = logging.getLogger(__name__) +logger.setLevel("INFO") def train(run_name: str, compute_context: ComputeContext = LocalTorch()): @@ -129,6 +130,7 @@ def train_run( desc="training", # unit="iteration", position=0, + leave=True, ) while trained_until < run.train_until: # train for at most 100 iterations at a time, then store training stats @@ -142,8 +144,9 @@ def train_run( run.optimizer, compute_context.device, ), - "training", + "training inner loop", iterations, + position=1, ): run.training_stats.add_iteration_stats(iteration_stats) bar.update(1) From 8a03a57f33fafa93576c79f9c8190126263c0281 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 1 Sep 2023 10:44:57 -0400 Subject: [PATCH 04/33] feat: stdout logging, version --- .../experiments/trainers/gunpowder_trainer.py | 2 +- dacapo/train.py | 23 ++++++------------- setup.py | 6 ++--- 3 files changed, 11 insertions(+), 20 deletions(-) diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index efec630f0..ab6fa7603 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -255,7 +255,7 @@ def iterate(self, num_iterations, model, optimizer, device): } if mask is not None: snapshot_arrays["volumes/mask"] = mask - logger.warning( + logger.info( f"Saving Snapshot. Iteration: {iteration}, " f"Loss: {loss.detach().cpu().numpy().item()}!" 
) diff --git a/dacapo/train.py b/dacapo/train.py index 82e2c772f..a2068dd61 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -124,33 +124,24 @@ def train_run( ) with run.trainer as trainer: - bar = tqdm( - total=run.train_until, - initial=trained_until, - desc="training", - # unit="iteration", - position=0, - leave=True, - ) while trained_until < run.train_until: # train for at most 100 iterations at a time, then store training stats iterations = min(100, run.train_until - trained_until) iteration_stats = None - - for iteration_stats in tqdm( + bar = tqdm( trainer.iterate( iterations, run.model, run.optimizer, compute_context.device, ), - "training inner loop", - iterations, - position=1, - ): + desc=f"training until {iterations + trained_until}", + total=run.train_until, + initial=trained_until, + ) + for iteration_stats in bar: run.training_stats.add_iteration_stats(iteration_stats) - bar.update(1) - bar.set_postfix_str(s=f"loss = {iteration_stats['loss']}") + bar.set_postfix({"loss": iteration_stats.loss}) if (iteration_stats.iteration + 1) % run.validation_interval == 0: break diff --git a/setup.py b/setup.py index 1721fa42f..9d3b75231 100644 --- a/setup.py +++ b/setup.py @@ -5,10 +5,10 @@ description="Framework for easy composition of volumetric machine learning jobs.", long_description=open("README.md", "r").read(), long_description_content_type="text/markdown", - version="0.1", + version="0.1.1", url="https://github.com/funkelab/dacapo", - author="Jan Funke, Will Patton", - author_email="funkej@janelia.hhmi.org, pattonw@janelia.hhmi.org", + author="Jan Funke, Will Patton, Jeff Rhoades", + author_email="funkej@janelia.hhmi.org, pattonw@janelia.hhmi.org, rhoadesj@hhmi.org", license="MIT", packages=find_packages(), entry_points={"console_scripts": ["dacapo=dacapo.cli:cli"]}, From 96269226f4332e394fe2ad6fb40c30c9256a0e1d Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Tue, 5 Sep 2023 10:08:56 -0400 Subject: [PATCH 05/33] feat: validation logging, keep best validation --- .../trainers/gunpowder_trainer_config.py | 2 +- dacapo/validate.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/dacapo/experiments/trainers/gunpowder_trainer_config.py b/dacapo/experiments/trainers/gunpowder_trainer_config.py index ae4243059..a6b90c659 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer_config.py +++ b/dacapo/experiments/trainers/gunpowder_trainer_config.py @@ -28,4 +28,4 @@ class GunpowderTrainerConfig(TrainerConfig): metadata={"help_text": "Number of iterations before saving a new snapshot."}, ) min_masked: Optional[float] = attr.ib(default=0.15) - clip_raw: bool = attr.ib(default=True) + clip_raw: bool = attr.ib(default=False) diff --git a/dacapo/validate.py b/dacapo/validate.py index ae07844bd..855278908 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -81,9 +81,12 @@ def validate_run( evaluator.set_best(run.validation_scores) for validation_dataset in run.datasplit.validate: - assert ( - validation_dataset.gt is not None - ), "We do not yet support validating on datasets without ground truth" + if validation_dataset.gt is None: + logger.error( + "We do not yet support validating on datasets without ground truth" + ) + raise NotImplementedError + logger.info( "Validating run %s on dataset %s", run.name, validation_dataset.name ) @@ -166,6 +169,7 @@ def validate_run( scores = evaluator.evaluate(output_array_identifier, validation_dataset.gt) + any_best = False for criterion in run.validation_scores.criteria: # replace predictions in 
array with the new better predictions if evaluator.is_best( @@ -174,6 +178,7 @@ def validate_run( criterion, scores, ): + any_best = True best_array_identifier = array_store.best_validation_array( run.name, criterion, index=validation_dataset.name ) @@ -201,7 +206,8 @@ def validate_run( # delete current output. We only keep the best outputs as determined by # the evaluator - array_store.remove(output_array_identifier) + if not any_best: + array_store.remove(output_array_identifier) dataset_iteration_scores.append( [getattr(scores, criterion) for criterion in scores.criteria] From 8212b8b553fb12a338dd8413096b84f909d484c0 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Thu, 7 Sep 2023 23:23:01 -0400 Subject: [PATCH 06/33] WIP: apply.py --- dacapo/apply.py | 95 ++++++++++++++++++- .../watershed_post_processor.py | 2 +- dacapo/predict.py | 8 +- dacapo/store/local_weights_store.py | 6 +- dacapo/validate.py | 3 +- 5 files changed, 102 insertions(+), 12 deletions(-) diff --git a/dacapo/apply.py b/dacapo/apply.py index 64f23df3c..04df9e145 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -1,12 +1,101 @@ import logging +from funlib.geometry import Roi + +from dacapo.experiments.tasks.post_processors.post_processor_parameters import PostProcessorParameters +from dacapo.store.array_store import LocalArrayIdentifier +from .predict import predict +from .compute_context import LocalTorch, ComputeContext +from .experiments import Run, ValidationIterationScores +from .experiments.datasplits.datasets.arrays import ZarrArray +from .store import ( + create_array_store, + create_config_store, + create_stats_store, + create_weights_store, +) + +import torch + +from pathlib import Path logger = logging.getLogger(__name__) -def apply(run_name: str, iteration: int, dataset_name: str): +def apply( + run_name: str, + dataset_name: str, + output_path: str, + validation_name: str, + roi: Roi or None = None, + criterion: str or None = "voi", + iteration: int or None = None, + compute_context: ComputeContext = LocalTorch(), +): + """Load weights and apply a model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used.""" + + # create run + + config_store = create_config_store() + run_config = config_store.retrieve_run_config(run_name) + run = Run(run_config) + + # read in previous training/validation stats TODO: is this necessary? + + stats_store = create_stats_store() + run.training_stats = stats_store.retrieve_training_stats(run_name) + run.validation_scores.scores = stats_store.retrieve_validation_iteration_scores( + run_name + ) + + # create weights store + weights_store = create_weights_store() + + # load weights + if iteration is None: + # weights_store._load_best(run, criterion) + iteration = weights_store.retrieve_best(run_name, validation_name, criterion) + weights_store.retrieve_weights(run, iteration) # shouldn't this be load_weights? + + # find the best parameters + scores = [s for s in run.validation_scores.scores if s.iteration == iteration + 1][ + 0 + ].scores + paremeters = ... scores[criterion]??? + + # make array identifiers for input, predictions and outputs + array_store = create_array_store() + input_array_identifier = ... + prediction_array_identifier = LocalArrayIdentifier( + output_path, dataset_name, "prediction"... + ) + output_array_identifier = LocalArrayIdentifier( + output_path, dataset_name, "output", parameters... 
+ ) + logger.info( - "Applying results from run %s at iteration %d to dataset %s", - run_name, + "Applying best results from run %s at iteration %i to dataset %s", + run.name, iteration, dataset_name, ) + return apply_run(run, dataset_name, prediction_array_identifier, output_array_identifier, parameters, roi, compute_context) + + +def apply_run( + run: Run, + parameters: PostProcessorParameters, + input_array_identifier: LocalArrayIdentifier, + prediction_array_identifier: LocalArrayIdentifier, + output_array_identifier: LocalArrayIdentifier, + roi: Roi or None = None, + compute_context: ComputeContext = LocalTorch(), +): + """Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded.""" + + # Find the best parameters + + # set benchmark flag to True for performance + torch.backends.cudnn.benchmark = True + run.model.eval() + + ... diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index 36ff106ae..3cdc74456 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -24,7 +24,7 @@ def enumerate_parameters(self): """Enumerate all possible parameters of this post-processor. Should return instances of ``PostProcessorParameters``.""" - for i, bias in enumerate([0.1, 0.5, 0.9]): + for i, bias in enumerate([0.1, 0.3, 0.5, 0.7, 0.9]): yield WatershedPostProcessorParameters(id=i, bias=bias) def set_prediction(self, prediction_array_identifier): diff --git a/dacapo/predict.py b/dacapo/predict.py index 5a40e303c..221dd219b 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -76,7 +76,7 @@ def predict( gt_padding = (output_size - output_roi.shape) % output_size prediction_roi = output_roi.grow(gt_padding) - + # TODO: Add cache node? # predict pipeline += gp_torch.Predict( model=model, @@ -97,8 +97,6 @@ def predict( pipeline += gp.Squeeze([raw, prediction]) # raw: (c, d, h, w) # prediction: (c, d, h, w) - # raw: (c, d, h, w) - # prediction: (c, d, h, w) # write to zarr pipeline += gp.ZarrWrite( @@ -112,7 +110,9 @@ def predict( ref_request = gp.BatchRequest() ref_request.add(raw, input_size) ref_request.add(prediction, output_size) - pipeline += gp.Scan(ref_request) + pipeline += gp.Scan( + ref_request + ) # TODO: This is a slow implementation for rendering # build pipeline and predict in complete output ROI diff --git a/dacapo/store/local_weights_store.py b/dacapo/store/local_weights_store.py index c5f0ba5ff..dbf45aa2c 100644 --- a/dacapo/store/local_weights_store.py +++ b/dacapo/store/local_weights_store.py @@ -62,7 +62,9 @@ def retrieve_weights(self, run: str, iteration: int) -> Weights: return weights - def _retrieve_weights(self, run: str, key: str) -> Weights: + def _retrieve_weights( + self, run: str, key: str + ) -> Weights: # TODO: redundant with above? 
weights_name = self.__get_weights_dir(run) / key if not weights_name.exists(): weights_name = self.__get_weights_dir(run) / "iterations" / key @@ -111,7 +113,7 @@ def retrieve_best(self, run: str, dataset: str, criterion: str) -> int: return weights_info["iteration"] - def _load_best(self, run: Run, criterion: str): + def _load_best(self, run: Run, criterion: str): # TODO: probably won't work logger.info("Retrieving weights for run %s, criterion %s", run, criterion) weights_name = self.__get_weights_dir(run) / f"{criterion}" diff --git a/dacapo/validate.py b/dacapo/validate.py index 855278908..a50a87df3 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -206,8 +206,7 @@ def validate_run( # delete current output. We only keep the best outputs as determined by # the evaluator - if not any_best: - array_store.remove(output_array_identifier) + array_store.remove(output_array_identifier) dataset_iteration_scores.append( [getattr(scores, criterion) for criterion in scores.criteria] From 8326400bb7c082f68d703441ecbadc110e107342 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 11 Sep 2023 10:55:47 -0400 Subject: [PATCH 07/33] feat: keep actual best scoring dataset --- dacapo/apply.py | 48 ++++++---- .../experiments/tasks/evaluators/evaluator.py | 23 +++++ dacapo/store/local_weights_store.py | 2 +- dacapo/validate.py | 92 ++++++++++++------- 4 files changed, 114 insertions(+), 51 deletions(-) diff --git a/dacapo/apply.py b/dacapo/apply.py index 04df9e145..f94a3c6d8 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -1,13 +1,15 @@ import logging from funlib.geometry import Roi -from dacapo.experiments.tasks.post_processors.post_processor_parameters import PostProcessorParameters +from dacapo.experiments.tasks.post_processors.post_processor_parameters import ( + PostProcessorParameters, +) from dacapo.store.array_store import LocalArrayIdentifier -from .predict import predict -from .compute_context import LocalTorch, ComputeContext -from .experiments import Run, ValidationIterationScores -from .experiments.datasplits.datasets.arrays import ZarrArray -from .store import ( +from dacapo.predict import predict +from dacapo.compute_context import LocalTorch, ComputeContext +from dacapo.experiments import Run, ValidationIterationScores +from dacapo.experiments.datasplits.datasets.arrays import ZarrArray +from dacapo.store import ( create_array_store, create_config_store, create_stats_store, @@ -60,17 +62,17 @@ def apply( scores = [s for s in run.validation_scores.scores if s.iteration == iteration + 1][ 0 ].scores - paremeters = ... scores[criterion]??? - - # make array identifiers for input, predictions and outputs - array_store = create_array_store() - input_array_identifier = ... - prediction_array_identifier = LocalArrayIdentifier( - output_path, dataset_name, "prediction"... - ) - output_array_identifier = LocalArrayIdentifier( - output_path, dataset_name, "output", parameters... - ) + # paremeters = ... scores[criterion]??? + + # # make array identifiers for input, predictions and outputs + # array_store = create_array_store() + # input_array_identifier = ... + # prediction_array_identifier = LocalArrayIdentifier( + # output_path, dataset_name, "prediction"... + # ) + # output_array_identifier = LocalArrayIdentifier( + # output_path, dataset_name, "output", parameters... 
+ # ) logger.info( "Applying best results from run %s at iteration %i to dataset %s", @@ -78,7 +80,15 @@ def apply( iteration, dataset_name, ) - return apply_run(run, dataset_name, prediction_array_identifier, output_array_identifier, parameters, roi, compute_context) + return apply_run( + run, + dataset_name, + prediction_array_identifier, + output_array_identifier, + parameters, + roi, + compute_context, + ) def apply_run( @@ -97,5 +107,5 @@ def apply_run( # set benchmark flag to True for performance torch.backends.cudnn.benchmark = True run.model.eval() - + ... diff --git a/dacapo/experiments/tasks/evaluators/evaluator.py b/dacapo/experiments/tasks/evaluators/evaluator.py index 9d5cbbda0..f28fafca4 100644 --- a/dacapo/experiments/tasks/evaluators/evaluator.py +++ b/dacapo/experiments/tasks/evaluators/evaluator.py @@ -63,6 +63,29 @@ def is_best( else: return getattr(score, criterion) < previous_best_score + def get_overall_best( + self, dataset: "Dataset", criterion: str, higher_is_better: bool + ): + overall_best = None + if self.best_scores: + for _, parameter, _ in self.best_scores.keys(): + score = self.best_scores[(dataset, parameter, criterion)] + if score is None: + overall_best = None + else: + _, current_parameter_score = score + if overall_best is None: + overall_best = current_parameter_score + else: + if current_parameter_score: + if higher_is_better: + if current_parameter_score > overall_best: + overall_best = current_parameter_score + else: + if current_parameter_score < overall_best: + overall_best = current_parameter_score + return overall_best + def set_best(self, validation_scores: "ValidationScores") -> None: """ Find the best iteration for each dataset/post_processing_parameter/criterion diff --git a/dacapo/store/local_weights_store.py b/dacapo/store/local_weights_store.py index dbf45aa2c..c34afc6c3 100644 --- a/dacapo/store/local_weights_store.py +++ b/dacapo/store/local_weights_store.py @@ -106,7 +106,7 @@ def retrieve_best(self, run: str, dataset: str, criterion: str) -> int: logger.info("Retrieving weights for run %s, criterion %s", run, criterion) weights_info = json.loads( - (self.__get_weights_dir(run) / criterion / f"{dataset}.json") + (self.__get_weights_dir(run) / dataset / f"{criterion}.json") .open("r") .read() ) diff --git a/dacapo/validate.py b/dacapo/validate.py index a50a87df3..26a193f0b 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -158,6 +158,17 @@ def validate_run( dataset_iteration_scores = [] + # set up dict for overall best scores + overall_best_scores = {} + for criterion in run.validation_scores.criteria: + overall_best_scores[criterion] = evaluator.get_overall_best( + validation_dataset, + criterion, + run.validation_scores.evaluation_scores.higher_is_better( + criterion + ), # TODO: should be in evaluator + ) + for parameters in post_processor.enumerate_parameters(): output_array_identifier = array_store.validation_output_array( run.name, iteration, parameters, validation_dataset @@ -168,8 +179,6 @@ def validate_run( ) scores = evaluator.evaluate(output_array_identifier, validation_dataset.gt) - - any_best = False for criterion in run.validation_scores.criteria: # replace predictions in array with the new better predictions if evaluator.is_best( @@ -178,35 +187,56 @@ def validate_run( criterion, scores, ): - any_best = True - best_array_identifier = array_store.best_validation_array( - run.name, criterion, index=validation_dataset.name - ) - best_array = ZarrArray.create_from_array_identifier( - best_array_identifier, - 
post_processed_array.axes, - post_processed_array.roi, - post_processed_array.num_channels, - post_processed_array.voxel_size, - post_processed_array.dtype, - ) - best_array[best_array.roi] = post_processed_array[ - post_processed_array.roi - ] - best_array.add_metadata( - { - "iteration": iteration, - criterion: getattr(scores, criterion), - "parameters_id": parameters.id, - } - ) - weights_store.store_best( - run, iteration, validation_dataset.name, criterion - ) - - # delete current output. We only keep the best outputs as determined by - # the evaluator - array_store.remove(output_array_identifier) + # then this is the current best score for this parameter, but not necessarily the overall best + higher_is_better = scores.higher_is_better(criterion) + # initial_best_score = overall_best_scores[criterion] + current_score = getattr(scores, criterion) + if not overall_best_scores[ + criterion + ] or ( # TODO: should be in evaluator + ( + higher_is_better + and current_score > overall_best_scores[criterion] + ) + or ( + not higher_is_better + and current_score < overall_best_scores[criterion] + ) + ): + overall_best_scores[criterion] = current_score + + # For example, if parameter 2 did better this round than it did in other rounds, but it was still worse than parameter 1 + # the code would have overwritten it below since all parameters write to the same file. Now each parameter will be its own file + # Either we do that, or we only write out the overall best, regardless of parameters + best_array_identifier = array_store.best_validation_array( + run.name, criterion, index=validation_dataset.name + ) + best_array = ZarrArray.create_from_array_identifier( + best_array_identifier, + post_processed_array.axes, + post_processed_array.roi, + post_processed_array.num_channels, + post_processed_array.voxel_size, + post_processed_array.dtype, + ) + best_array[best_array.roi] = post_processed_array[ + post_processed_array.roi + ] + best_array.add_metadata( + { + "iteration": iteration, + criterion: getattr(scores, criterion), + "parameters_id": parameters.id, + } + ) + weights_store.store_best( + run, iteration, validation_dataset.name, criterion + ) + else: + # delete current output. 
We only keep the best outputs as determined by + # the evaluator + # remove deletion for now since we have to rerun later to double check things anyway + array_store.remove(output_array_identifier) dataset_iteration_scores.append( [getattr(scores, criterion) for criterion in scores.criteria] From 663fe0fa095e552760bd33e064b1e62b9abdd08b Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 11 Sep 2023 17:21:00 -0400 Subject: [PATCH 08/33] feat: ready to debug apply.py --- dacapo/apply.py | 92 ++++++++++++------- .../experiments/tasks/evaluators/evaluator.py | 31 ++++++- .../watershed_post_processor.py | 4 +- dacapo/experiments/validation_scores.py | 2 +- dacapo/predict.py | 3 +- dacapo/validate.py | 6 +- 6 files changed, 94 insertions(+), 44 deletions(-) diff --git a/dacapo/apply.py b/dacapo/apply.py index f94a3c6d8..0e831bba6 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -1,5 +1,8 @@ import logging +from typing import Optional from funlib.geometry import Roi +import numpy as np +from dacapo.experiments.datasplits.datasets.dataset import Dataset from dacapo.experiments.tasks.post_processors.post_processor_parameters import ( PostProcessorParameters, @@ -25,18 +28,25 @@ def apply( run_name: str, - dataset_name: str, + input_container: str, + input_dataset: str, output_path: str, - validation_name: str, - roi: Roi or None = None, - criterion: str or None = "voi", - iteration: int or None = None, + validation_dataset: Optional[str or Dataset] = None, + criterion: Optional[str] = "voi", + iteration: Optional[int] = None, + roi: Optional[Roi] = None, + num_cpu_workers: int = 4, + output_dtype: Optional[np.dtype or torch.dtype] = np.uint8, compute_context: ComputeContext = LocalTorch(), ): """Load weights and apply a model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used.""" - # create run + assert (validation_dataset is not None and isinstance(criterion, str)) or ( + iteration is not None + ), "Either validation_dataset and criterion, or iteration must be provided." + # retrieving run + logger.info("Loading run %s", run_name) config_store = create_config_store() run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) @@ -55,38 +65,47 @@ def apply( # load weights if iteration is None: # weights_store._load_best(run, criterion) - iteration = weights_store.retrieve_best(run_name, validation_name, criterion) + iteration = weights_store.retrieve_best(run_name, validation_dataset, criterion) + logger.info("Loading weights for iteration %i", iteration) weights_store.retrieve_weights(run, iteration) # shouldn't this be load_weights? # find the best parameters - scores = [s for s in run.validation_scores.scores if s.iteration == iteration + 1][ - 0 - ].scores - # paremeters = ... scores[criterion]??? - - # # make array identifiers for input, predictions and outputs - # array_store = create_array_store() - # input_array_identifier = ... - # prediction_array_identifier = LocalArrayIdentifier( - # output_path, dataset_name, "prediction"... - # ) - # output_array_identifier = LocalArrayIdentifier( - # output_path, dataset_name, "output", parameters... 
- # ) + if isinstance(validation_dataset, str): + val_ds_name = validation_dataset + validation_dataset = [ + dataset for dataset in run.datasplit.validate if dataset.name == val_ds_name + ][0] + logger.info("Finding best parameters for validation dataset %s", validation_dataset) + parameters = run.task.evaluator.get_overall_best_parameters( + validation_dataset, criterion + ) + + # make array identifiers for input, predictions and outputs + array_store = create_array_store() + input_array_identifier = LocalArrayIdentifier(input_container, input_dataset) + output_container = Path(output_path, Path(input_container).name) + prediction_array_identifier = LocalArrayIdentifier( + output_container, f"prediction_{run_name}_{iteration}_{parameters}" + ) + output_array_identifier = LocalArrayIdentifier( + output_container, f"output_{run_name}_{iteration}_{parameters}" + ) logger.info( "Applying best results from run %s at iteration %i to dataset %s", run.name, iteration, - dataset_name, + Path(input_container, input_dataset), ) return apply_run( run, - dataset_name, + parameters, + input_array_identifier prediction_array_identifier, output_array_identifier, - parameters, roi, + num_cpu_workers, + output_dtype, compute_context, ) @@ -97,15 +116,24 @@ def apply_run( input_array_identifier: LocalArrayIdentifier, prediction_array_identifier: LocalArrayIdentifier, output_array_identifier: LocalArrayIdentifier, - roi: Roi or None = None, + roi: Optional[Roi] = None, + num_cpu_workers: int = 4, + output_dtype: Optional[np.dtype or torch.dtype] = np.uint8, compute_context: ComputeContext = LocalTorch(), ): """Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded.""" - # Find the best parameters - - # set benchmark flag to True for performance - torch.backends.cudnn.benchmark = True - run.model.eval() - - ... 
+ # render prediction dataset + logger.info("Predicting on dataset %s", prediction_array_identifier) + predict(run.model, input_array_identifier, prediction_array_identifier, output_roi=roi, num_cpu_workers=num_cpu_workers, output_dtype=output_dtype compute_context=compute_context) + + # post-process the output + logger.info("Post-processing output to dataset %s", output_array_identifier) + post_processor = run.task.post_processor + post_processor.set_prediction(prediction_array_identifier) + post_processed_array = post_processor.process( + parameters, output_array_identifier + ) + + logger.info("Done") + return \ No newline at end of file diff --git a/dacapo/experiments/tasks/evaluators/evaluator.py b/dacapo/experiments/tasks/evaluators/evaluator.py index f28fafca4..24096261a 100644 --- a/dacapo/experiments/tasks/evaluators/evaluator.py +++ b/dacapo/experiments/tasks/evaluators/evaluator.py @@ -63,9 +63,7 @@ def is_best( else: return getattr(score, criterion) < previous_best_score - def get_overall_best( - self, dataset: "Dataset", criterion: str, higher_is_better: bool - ): + def get_overall_best(self, dataset: "Dataset", criterion: str): overall_best = None if self.best_scores: for _, parameter, _ in self.best_scores.keys(): @@ -78,7 +76,7 @@ def get_overall_best( overall_best = current_parameter_score else: if current_parameter_score: - if higher_is_better: + if self.higher_is_better(criterion): if current_parameter_score > overall_best: overall_best = current_parameter_score else: @@ -86,6 +84,31 @@ def get_overall_best( overall_best = current_parameter_score return overall_best + def get_overall_best_parameters(self, dataset: "Dataset", criterion: str): + overall_best = None + overall_best_parameters = None + if self.best_scores: + for _, parameter, _ in self.best_scores.keys(): + score = self.best_scores[(dataset, parameter, criterion)] + if score is None: + overall_best = None + else: + _, current_parameter_score = score + if overall_best is None: + overall_best = current_parameter_score + overall_best_parameters = parameter + else: + if current_parameter_score: + if self.higher_is_better(criterion): + if current_parameter_score > overall_best: + overall_best = current_parameter_score + overall_best_parameters = parameter + else: + if current_parameter_score < overall_best: + overall_best = current_parameter_score + overall_best_parameters = parameter + return overall_best_parameters + def set_best(self, validation_scores: "ValidationScores") -> None: """ Find the best iteration for each dataset/post_processing_parameter/criterion diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index 3cdc74456..2ec78db9b 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -32,7 +32,9 @@ def set_prediction(self, prediction_array_identifier): prediction_array_identifier ) - def process(self, parameters, output_array_identifier): + def process( + self, parameters, output_array_identifier + ): # TODO: will probably break with large arrays... 
output_array = ZarrArray.create_from_array_identifier( output_array_identifier, [axis for axis in self.prediction_array.axes if axis != "c"], diff --git a/dacapo/experiments/validation_scores.py b/dacapo/experiments/validation_scores.py index 8fba05687..23a23d62e 100644 --- a/dacapo/experiments/validation_scores.py +++ b/dacapo/experiments/validation_scores.py @@ -113,7 +113,7 @@ def get_best( best value in two seperate arrays. """ if "criteria" in data.coords.keys(): - if len(data.coords["criteria"].shape) == 1: + if len(data.coords["criteria"].shape) > 1: criteria_bests: List[Tuple[xr.DataArray, xr.DataArray]] = [] for criterion in data.coords["criteria"].values: if self.evaluation_scores.higher_is_better(criterion.item()): diff --git a/dacapo/predict.py b/dacapo/predict.py index 221dd219b..0f912bc8c 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -24,7 +24,8 @@ def predict( num_cpu_workers: int = 4, compute_context: ComputeContext = LocalTorch(), output_roi: Optional[Roi] = None, -): + output_dtype: Optional[np.dtype or torch.dtype] = np.uint8, +): # TODO: Add dtype argument # get the model's input and output size input_voxel_size = Coordinate(raw_array.voxel_size) diff --git a/dacapo/validate.py b/dacapo/validate.py index 26a193f0b..bb630e335 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -162,11 +162,7 @@ def validate_run( overall_best_scores = {} for criterion in run.validation_scores.criteria: overall_best_scores[criterion] = evaluator.get_overall_best( - validation_dataset, - criterion, - run.validation_scores.evaluation_scores.higher_is_better( - criterion - ), # TODO: should be in evaluator + validation_dataset, criterion ) for parameters in post_processor.enumerate_parameters(): From beeed60761f524ed8aae7442108b4714daca514e Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Tue, 12 Sep 2023 16:08:20 -0400 Subject: [PATCH 09/33] feat!: add cli for applying models --- dacapo/apply.py | 36 +++++++++++++++++++------------- dacapo/cli.py | 55 ++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 67 insertions(+), 24 deletions(-) diff --git a/dacapo/apply.py b/dacapo/apply.py index 0e831bba6..56f409cf0 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -28,12 +28,13 @@ def apply( run_name: str, - input_container: str, + input_container: str or Path, input_dataset: str, - output_path: str, + output_container: str or Path, validation_dataset: Optional[str or Dataset] = None, criterion: Optional[str] = "voi", iteration: Optional[int] = None, + parameters: Optional[PostProcessorParameters] = None, roi: Optional[Roi] = None, num_cpu_workers: int = 4, output_dtype: Optional[np.dtype or torch.dtype] = np.uint8, @@ -42,7 +43,7 @@ def apply( """Load weights and apply a model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used.""" assert (validation_dataset is not None and isinstance(criterion, str)) or ( - iteration is not None + isinstance(iteration, int) ), "Either validation_dataset and criterion, or iteration must be provided." 
# retrieving run @@ -76,14 +77,15 @@ def apply( dataset for dataset in run.datasplit.validate if dataset.name == val_ds_name ][0] logger.info("Finding best parameters for validation dataset %s", validation_dataset) - parameters = run.task.evaluator.get_overall_best_parameters( - validation_dataset, criterion - ) + if parameters is None: + parameters = run.task.evaluator.get_overall_best_parameters( + validation_dataset, criterion + ) # make array identifiers for input, predictions and outputs array_store = create_array_store() input_array_identifier = LocalArrayIdentifier(input_container, input_dataset) - output_container = Path(output_path, Path(input_container).name) + output_container = Path(output_container, Path(input_container).name) prediction_array_identifier = LocalArrayIdentifier( output_container, f"prediction_{run_name}_{iteration}_{parameters}" ) @@ -100,7 +102,7 @@ def apply( return apply_run( run, parameters, - input_array_identifier + input_array_identifier, prediction_array_identifier, output_array_identifier, roi, @@ -125,15 +127,21 @@ def apply_run( # render prediction dataset logger.info("Predicting on dataset %s", prediction_array_identifier) - predict(run.model, input_array_identifier, prediction_array_identifier, output_roi=roi, num_cpu_workers=num_cpu_workers, output_dtype=output_dtype compute_context=compute_context) + predict( + run.model, + input_array_identifier, + prediction_array_identifier, + output_roi=roi, + num_cpu_workers=num_cpu_workers, + output_dtype=output_dtype, + compute_context=compute_context, + ) # post-process the output logger.info("Post-processing output to dataset %s", output_array_identifier) post_processor = run.task.post_processor post_processor.set_prediction(prediction_array_identifier) - post_processed_array = post_processor.process( - parameters, output_array_identifier - ) - + post_processed_array = post_processor.process(parameters, output_array_identifier) + logger.info("Done") - return \ No newline at end of file + return diff --git a/dacapo/cli.py b/dacapo/cli.py index 76a5e18e0..715532065 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -43,18 +43,53 @@ def validate(run_name, iteration): "-r", "--run", required=True, type=str, help="The name of the run to use." ) @click.option( - "-i", - "--iteration", + "-ic", + "--input_container", required=True, - type=int, - help="The iteration weights and parameters to use.", + type=click.Path(exists=True, file_okay=False), ) +@click.option("-id", "--input_dataset", required=True, type=str) @click.option( - "-r", - "--dataset", - required=True, + "-oc", "--output_container", required=True, type=click.Path(file_okay=False) +) +@click.option("-vd", "--validation_dataset", type=str, default=None) +@click.option("-c", "--criterion", default="voi") +@click.option("-i", "--iteration", type=int, default=None) +@click.option("-p", "--parameters", type=str, default=None) +@click.option( + "-roi", + "--roi", type=str, - help="The name of the dataset to apply the run to.", + required=False, + help="The roi to predict on. Passed in as [lower:upper, lower:upper, ... 
]", ) -def apply(run_name, iteration, dataset_name): - dacapo.apply(run_name, iteration, dataset_name) +@click.option("-w", "--num_cpu_workers", type=int, default=30) +@click.option("-dt", "--output_dtype", type=str, default="uint8") +def apply( + run_name: str, + input_container: str, + input_dataset: str, + output_path: str, + validation_dataset: Optional[str or Dataset] = None, + criterion: Optional[str] = "voi", + iteration: Optional[int] = None, + parameters: Optional[PostProcessorParameters] = None, + roi: Optional[Roi] = None, + num_cpu_workers: int = 4, + output_dtype: Optional[np.dtype or torch.dtype or str] = np.uint8, +): + if isinstance(output_dtype, str): + output_dtype = np.dtype(output_dtype) + dacapo.apply( + run_name, + input_container, + input_dataset, + output_path, + validation_dataset, + criterion, + iteration, + parameters, + roi, + num_cpu_workers, + output_dtype, + ) From 8553b5ae60e4887a0f1cc1098d2f9aa35b768f1b Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Tue, 12 Sep 2023 16:09:45 -0400 Subject: [PATCH 10/33] feat!: add cli for applying models --- dacapo/cli.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/dacapo/cli.py b/dacapo/cli.py index 715532065..d0a57a724 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -1,6 +1,15 @@ +from typing import Optional + +from funlib.geometry import Roi +import numpy as np import dacapo import click import logging +from dacapo.experiments.datasplits.datasets.dataset import Dataset + +from dacapo.experiments.tasks.post_processors.post_processor_parameters import ( + PostProcessorParameters, +) @click.group() @@ -76,7 +85,7 @@ def apply( parameters: Optional[PostProcessorParameters] = None, roi: Optional[Roi] = None, num_cpu_workers: int = 4, - output_dtype: Optional[np.dtype or torch.dtype or str] = np.uint8, + output_dtype: Optional[np.dtype or str] = np.uint8, ): if isinstance(output_dtype, str): output_dtype = np.dtype(output_dtype) From 273ecb7bb430b0f8bbda55632d913c87bd9d44f9 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Tue, 12 Sep 2023 16:21:59 -0400 Subject: [PATCH 11/33] feat: apply.py roi --- dacapo/apply.py | 28 +++++++++++++++++++++------- dacapo/cli.py | 14 ++++++-------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/dacapo/apply.py b/dacapo/apply.py index 56f409cf0..1bc0ed30a 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -1,6 +1,6 @@ import logging from typing import Optional -from funlib.geometry import Roi +from funlib.geometry import Roi, Coordinate import numpy as np from dacapo.experiments.datasplits.datasets.dataset import Dataset @@ -28,19 +28,33 @@ def apply( run_name: str, - input_container: str or Path, + input_container: Path or str, input_dataset: str, - output_container: str or Path, - validation_dataset: Optional[str or Dataset] = None, + output_container: Path or str, + validation_dataset: Optional[Dataset or str] = None, criterion: Optional[str] = "voi", iteration: Optional[int] = None, - parameters: Optional[PostProcessorParameters] = None, - roi: Optional[Roi] = None, + parameters: Optional[PostProcessorParameters or str] = None, + roi: Optional[Roi or str] = None, num_cpu_workers: int = 4, - output_dtype: Optional[np.dtype or torch.dtype] = np.uint8, + output_dtype: Optional[np.dtype or str] = np.uint8, compute_context: ComputeContext = LocalTorch(), ): """Load weights and apply a model to a dataset. If iteration is None, the best iteration based on the criterion is used. 
If roi is None, the whole input dataset is used.""" + if isinstance(output_dtype, str): + output_dtype = np.dtype(output_dtype) + + if isinstance(roi, str): + start, end = zip( + *[ + tuple(int(coord) for coord in axis.split(":")) + for axis in roi.strip("[]").split(",") + ] + ) + roi = Roi( + Coordinate(start), + Coordinate(end) - Coordinate(start), + ) assert (validation_dataset is not None and isinstance(criterion, str)) or ( isinstance(iteration, int) diff --git a/dacapo/cli.py b/dacapo/cli.py index d0a57a724..dac78f848 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -76,24 +76,22 @@ def validate(run_name, iteration): @click.option("-dt", "--output_dtype", type=str, default="uint8") def apply( run_name: str, - input_container: str, + input_container: Path or str, input_dataset: str, - output_path: str, - validation_dataset: Optional[str or Dataset] = None, + output_container: Path or str, + validation_dataset: Optional[Dataset or str] = None, criterion: Optional[str] = "voi", iteration: Optional[int] = None, - parameters: Optional[PostProcessorParameters] = None, - roi: Optional[Roi] = None, + parameters: Optional[PostProcessorParameters or str] = None, + roi: Optional[Roi or str] = None, num_cpu_workers: int = 4, output_dtype: Optional[np.dtype or str] = np.uint8, ): - if isinstance(output_dtype, str): - output_dtype = np.dtype(output_dtype) dacapo.apply( run_name, input_container, input_dataset, - output_path, + output_container, validation_dataset, criterion, iteration, From 2622b940a25d7aa63d56d0cfe7d766d06548a3ac Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Tue, 12 Sep 2023 22:46:43 -0400 Subject: [PATCH 12/33] feat: postprocess parameter parsing --- dacapo/apply.py | 70 ++++++++++++++++++++++--------- dacapo/cli.py | 27 ++++-------- dacapo/store/local_array_store.py | 2 +- 3 files changed, 61 insertions(+), 38 deletions(-) diff --git a/dacapo/apply.py b/dacapo/apply.py index 1bc0ed30a..119feeb0b 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -3,24 +3,22 @@ from funlib.geometry import Roi, Coordinate import numpy as np from dacapo.experiments.datasplits.datasets.dataset import Dataset +from dacapo.experiments.run import Run from dacapo.experiments.tasks.post_processors.post_processor_parameters import ( PostProcessorParameters, ) +import dacapo.experiments.tasks.post_processors as post_processors from dacapo.store.array_store import LocalArrayIdentifier from dacapo.predict import predict from dacapo.compute_context import LocalTorch, ComputeContext -from dacapo.experiments import Run, ValidationIterationScores from dacapo.experiments.datasplits.datasets.arrays import ZarrArray from dacapo.store import ( - create_array_store, create_config_store, create_stats_store, create_weights_store, ) -import torch - from pathlib import Path logger = logging.getLogger(__name__) @@ -30,7 +28,7 @@ def apply( run_name: str, input_container: Path or str, input_dataset: str, - output_container: Path or str, + output_path: Path or str, validation_dataset: Optional[Dataset or str] = None, criterion: Optional[str] = "voi", iteration: Optional[int] = None, @@ -66,13 +64,12 @@ def apply( run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) - # read in previous training/validation stats TODO: is this necessary? 
- - stats_store = create_stats_store() - run.training_stats = stats_store.retrieve_training_stats(run_name) - run.validation_scores.scores = stats_store.retrieve_validation_iteration_scores( - run_name - ) + # # read in previous training/validation stats TODO: is this necessary? + # stats_store = create_stats_store() + # run.training_stats = stats_store.retrieve_training_stats(run_name) + # run.validation_scores.scores = stats_store.retrieve_validation_iteration_scores( + # run_name + # ) # create weights store weights_store = create_weights_store() @@ -95,13 +92,48 @@ def apply( parameters = run.task.evaluator.get_overall_best_parameters( validation_dataset, criterion ) + assert ( + parameters is not None + ), "Unable to retieve parameters. Parameters must be provided explicitly." + + elif isinstance(parameters, str): + try: + post_processor_name = parameters.split("(")[0] + post_processor_kwargs = parameters.split("(")[1].strip(")").split(",") + post_processor_kwargs = { + key.strip(): value.strip() + for key, value in [arg.split("=") for arg in post_processor_kwargs] + } + for key, value in post_processor_kwargs.items(): + if value.isdigit(): + post_processor_kwargs[key] = int(value) + elif value.replace(".", "", 1).isdigit(): + post_processor_kwargs[key] = float(value) + except: + raise ValueError( + f"Could not parse parameters string {parameters}. Must be of the form 'post_processor_name(arg1=val1, arg2=val2, ...)'" + ) + try: + parameters = getattr(post_processors, post_processor_name)( + **post_processor_kwargs + ) + except Exception as e: + logger.error( + f"Could not instantiate post-processor {post_processor_name} with arguments {post_processor_kwargs}.", + exc_info=True, + ) + raise e + + assert isinstance( + parameters, PostProcessorParameters + ), "Parameters must be parsable to a PostProcessorParameters object." # make array identifiers for input, predictions and outputs - array_store = create_array_store() input_array_identifier = LocalArrayIdentifier(input_container, input_dataset) - output_container = Path(output_container, Path(input_container).name) + input_array = ZarrArray.open_from_array_identifier(input_array_identifier) + output_container = Path(output_path, Path(input_container).name) prediction_array_identifier = LocalArrayIdentifier( - output_container, f"prediction_{run_name}_{iteration}_{parameters}" + output_container, f"prediction_{run_name}_{iteration}" ) output_array_identifier = LocalArrayIdentifier( output_container, f"output_{run_name}_{iteration}_{parameters}" @@ -116,7 +148,7 @@ def apply( return apply_run( run, parameters, - input_array_identifier, + input_array, prediction_array_identifier, output_array_identifier, roi, @@ -129,12 +161,12 @@ def apply( def apply_run( run: Run, parameters: PostProcessorParameters, - input_array_identifier: LocalArrayIdentifier, + input_array: ZarrArray, prediction_array_identifier: LocalArrayIdentifier, output_array_identifier: LocalArrayIdentifier, roi: Optional[Roi] = None, num_cpu_workers: int = 4, - output_dtype: Optional[np.dtype or torch.dtype] = np.uint8, + output_dtype: Optional[np.dtype] = np.uint8, compute_context: ComputeContext = LocalTorch(), ): """Apply the model to a dataset. If roi is None, the whole input dataset is used. 
Assumes model is already loaded.""" @@ -143,7 +175,7 @@ def apply_run( logger.info("Predicting on dataset %s", prediction_array_identifier) predict( run.model, - input_array_identifier, + input_array, prediction_array_identifier, output_roi=roi, num_cpu_workers=num_cpu_workers, diff --git a/dacapo/cli.py b/dacapo/cli.py index dac78f848..92867d46c 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -1,15 +1,8 @@ from typing import Optional -from funlib.geometry import Roi -import numpy as np import dacapo import click import logging -from dacapo.experiments.datasplits.datasets.dataset import Dataset - -from dacapo.experiments.tasks.post_processors.post_processor_parameters import ( - PostProcessorParameters, -) @click.group() @@ -49,7 +42,7 @@ def validate(run_name, iteration): @cli.command() @click.option( - "-r", "--run", required=True, type=str, help="The name of the run to use." + "-r", "--run_name", required=True, type=str, help="The name of the run to use." ) @click.option( "-ic", @@ -58,9 +51,7 @@ def validate(run_name, iteration): type=click.Path(exists=True, file_okay=False), ) @click.option("-id", "--input_dataset", required=True, type=str) -@click.option( - "-oc", "--output_container", required=True, type=click.Path(file_okay=False) -) +@click.option("-op", "--output_path", required=True, type=click.Path(file_okay=False)) @click.option("-vd", "--validation_dataset", type=str, default=None) @click.option("-c", "--criterion", default="voi") @click.option("-i", "--iteration", type=int, default=None) @@ -76,22 +67,22 @@ def validate(run_name, iteration): @click.option("-dt", "--output_dtype", type=str, default="uint8") def apply( run_name: str, - input_container: Path or str, + input_container: str, input_dataset: str, - output_container: Path or str, - validation_dataset: Optional[Dataset or str] = None, + output_path: str, + validation_dataset: Optional[str] = None, criterion: Optional[str] = "voi", iteration: Optional[int] = None, - parameters: Optional[PostProcessorParameters or str] = None, - roi: Optional[Roi or str] = None, + parameters: Optional[str] = None, + roi: Optional[str] = None, num_cpu_workers: int = 4, - output_dtype: Optional[np.dtype or str] = np.uint8, + output_dtype: Optional[str] = "uint8", ): dacapo.apply( run_name, input_container, input_dataset, - output_container, + output_path, validation_dataset, criterion, iteration, diff --git a/dacapo/store/local_array_store.py b/dacapo/store/local_array_store.py index c1581fc7b..9265c8f67 100644 --- a/dacapo/store/local_array_store.py +++ b/dacapo/store/local_array_store.py @@ -55,7 +55,7 @@ def validation_input_arrays( If we don't store these we would have to look up the datasplit config and figure out where to find the inputs for each run. If we write the data then we don't need to search for it. - This convenience comes at the cost of some extra memory usage. + This convenience comes at the cost of extra memory usage. 
#TODO: FIX THIS - refactor """ container = self.validation_container(run_name).container From bd192e5fe6eb16f6ce934d0ec6264a4ffe4a3868 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 13 Sep 2023 17:36:26 -0400 Subject: [PATCH 13/33] feat: overwrite option for predict & postprocess --- dacapo/apply.py | 27 ++++++++++--------- dacapo/cli.py | 2 +- .../datasplits/datasets/arrays/zarr_array.py | 2 ++ .../post_processors/argmax_post_processor.py | 3 ++- .../post_processors/dummy_post_processor.py | 2 +- .../tasks/post_processors/post_processor.py | 1 + .../threshold_post_processor.py | 2 ++ .../watershed_post_processor.py | 3 ++- dacapo/gp/dacapo_array_source.py | 2 +- dacapo/predict.py | 20 ++++++++++---- 10 files changed, 42 insertions(+), 22 deletions(-) diff --git a/dacapo/apply.py b/dacapo/apply.py index 119feeb0b..42379b498 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -2,6 +2,7 @@ from typing import Optional from funlib.geometry import Roi, Coordinate import numpy as np +from dacapo.experiments.datasplits.datasets.arrays.array import Array from dacapo.experiments.datasplits.datasets.dataset import Dataset from dacapo.experiments.run import Run @@ -15,7 +16,6 @@ from dacapo.experiments.datasplits.datasets.arrays import ZarrArray from dacapo.store import ( create_config_store, - create_stats_store, create_weights_store, ) @@ -34,9 +34,10 @@ def apply( iteration: Optional[int] = None, parameters: Optional[PostProcessorParameters or str] = None, roi: Optional[Roi or str] = None, - num_cpu_workers: int = 4, + num_cpu_workers: int = 30, output_dtype: Optional[np.dtype or str] = np.uint8, compute_context: ComputeContext = LocalTorch(), + overwrite: bool = True, ): """Load weights and apply a model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used.""" if isinstance(output_dtype, str): @@ -64,13 +65,6 @@ def apply( run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) - # # read in previous training/validation stats TODO: is this necessary? - # stats_store = create_stats_store() - # run.training_stats = stats_store.retrieve_training_stats(run_name) - # run.validation_scores.scores = stats_store.retrieve_validation_iteration_scores( - # run_name - # ) - # create weights store weights_store = create_weights_store() @@ -131,6 +125,9 @@ def apply( # make array identifiers for input, predictions and outputs input_array_identifier = LocalArrayIdentifier(input_container, input_dataset) input_array = ZarrArray.open_from_array_identifier(input_array_identifier) + roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect( + input_array.roi + ) output_container = Path(output_path, Path(input_container).name) prediction_array_identifier = LocalArrayIdentifier( output_container, f"prediction_{run_name}_{iteration}" @@ -155,21 +152,24 @@ def apply( num_cpu_workers, output_dtype, compute_context, + overwrite, ) def apply_run( run: Run, parameters: PostProcessorParameters, - input_array: ZarrArray, + input_array: Array, prediction_array_identifier: LocalArrayIdentifier, output_array_identifier: LocalArrayIdentifier, roi: Optional[Roi] = None, - num_cpu_workers: int = 4, + num_cpu_workers: int = 30, output_dtype: Optional[np.dtype] = np.uint8, compute_context: ComputeContext = LocalTorch(), + overwrite: bool = True, ): """Apply the model to a dataset. If roi is None, the whole input dataset is used. 
Assumes model is already loaded.""" + run.model.eval() # render prediction dataset logger.info("Predicting on dataset %s", prediction_array_identifier) @@ -181,13 +181,16 @@ def apply_run( num_cpu_workers=num_cpu_workers, output_dtype=output_dtype, compute_context=compute_context, + overwrite=overwrite, ) # post-process the output logger.info("Post-processing output to dataset %s", output_array_identifier) post_processor = run.task.post_processor post_processor.set_prediction(prediction_array_identifier) - post_processed_array = post_processor.process(parameters, output_array_identifier) + post_processed_array = post_processor.process( + parameters, output_array_identifier, overwrite=overwrite + ) logger.info("Done") return diff --git a/dacapo/cli.py b/dacapo/cli.py index 92867d46c..f8f06db54 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -75,7 +75,7 @@ def apply( iteration: Optional[int] = None, parameters: Optional[str] = None, roi: Optional[str] = None, - num_cpu_workers: int = 4, + num_cpu_workers: int = 30, output_dtype: Optional[str] = "uint8", ): dacapo.apply( diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index cadfcb6cd..58d6a1825 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -116,6 +116,7 @@ def create_from_array_identifier( dtype, write_size=None, name=None, + overwrite=False, ): """ Create a new ZarrArray given an array identifier. It is assumed that @@ -145,6 +146,7 @@ def create_from_array_identifier( dtype, num_channels=num_channels, write_size=write_size, + delete=overwrite, ) zarr_dataset = zarr_container[array_identifier.dataset] zarr_dataset.attrs["offset"] = ( diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py index 709d1de34..799f2651e 100644 --- a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py @@ -19,7 +19,7 @@ def set_prediction(self, prediction_array_identifier): prediction_array_identifier ) - def process(self, parameters, output_array_identifier): + def process(self, parameters, output_array_identifier, overwrite: bool = False): output_array = ZarrArray.create_from_array_identifier( output_array_identifier, [dim for dim in self.prediction_array.axes if dim != "c"], @@ -27,6 +27,7 @@ def process(self, parameters, output_array_identifier): None, self.prediction_array.voxel_size, np.uint8, + overwrite=overwrite, ) output_array[self.prediction_array.roi] = np.argmax( diff --git a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py index 5a2c7810a..ddb249539 100644 --- a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py @@ -21,7 +21,7 @@ def enumerate_parameters(self) -> Iterable[DummyPostProcessorParameters]: def set_prediction(self, prediction_array): pass - def process(self, parameters, output_array_identifier): + def process(self, parameters, output_array_identifier, overwrite: bool = False): # store some dummy data f = zarr.open(str(output_array_identifier.container), "a") f[output_array_identifier.dataset] = np.ones((10, 10, 10)) * parameters.min_size diff --git a/dacapo/experiments/tasks/post_processors/post_processor.py 
b/dacapo/experiments/tasks/post_processors/post_processor.py index 020361cb9..eada06c0a 100644 --- a/dacapo/experiments/tasks/post_processors/post_processor.py +++ b/dacapo/experiments/tasks/post_processors/post_processor.py @@ -33,6 +33,7 @@ def process( self, parameters: "PostProcessorParameters", output_array_identifier: "LocalArrayIdentifier", + overwrite: "bool", ) -> "Array": """Convert predictions into the final output.""" pass diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index 67ffdd066..32bf4cfc0 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -28,6 +28,7 @@ def process( self, parameters: "PostProcessorParameters", output_array_identifier: "LocalArrayIdentifier", + overwrite: bool = False, ) -> ZarrArray: # TODO: Investigate Liskov substitution princple and whether it is a problem here # OOP theory states the super class should always be replaceable with its subclasses @@ -47,6 +48,7 @@ def process( self.prediction_array.num_channels, self.prediction_array.voxel_size, np.uint8, + overwrite=overwrite, ) output_array[self.prediction_array.roi] = ( diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index 2ec78db9b..dd138e4be 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -33,7 +33,7 @@ def set_prediction(self, prediction_array_identifier): ) def process( - self, parameters, output_array_identifier + self, parameters, output_array_identifier, overwrite: bool = False ): # TODO: will probably break with large arrays... output_array = ZarrArray.create_from_array_identifier( output_array_identifier, @@ -42,6 +42,7 @@ def process( None, self.prediction_array.voxel_size, np.uint64, + overwrite=overwrite, ) # if a previous segmentation is provided, it must have a "grid graph" # in its metadata. 
diff --git a/dacapo/gp/dacapo_array_source.py b/dacapo/gp/dacapo_array_source.py index c00b2d504..f9cf1446c 100644 --- a/dacapo/gp/dacapo_array_source.py +++ b/dacapo/gp/dacapo_array_source.py @@ -26,7 +26,7 @@ class DaCapoArraySource(gp.BatchProvider): def __init__(self, array: Array, key: gp.ArrayKey): self.array = array self.array_spec = ArraySpec( - roi=self.array.roi, voxel_size=self.array.voxel_size + roi=self.array.roi, voxel_size=self.array.voxel_size, dtype=self.array.dtype ) self.key = key diff --git a/dacapo/predict.py b/dacapo/predict.py index 0f912bc8c..15045eb1d 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -24,8 +24,9 @@ def predict( num_cpu_workers: int = 4, compute_context: ComputeContext = LocalTorch(), output_roi: Optional[Roi] = None, - output_dtype: Optional[np.dtype or torch.dtype] = np.uint8, -): # TODO: Add dtype argument + output_dtype: Optional[np.dtype] = np.uint8, + overwrite: bool = False, +): # get the model's input and output size input_voxel_size = Coordinate(raw_array.voxel_size) @@ -57,7 +58,8 @@ def predict( output_roi, model.num_out_channels, output_voxel_size, - np.float32, + output_dtype, + overwrite=overwrite, ) # create gunpowder keys @@ -69,6 +71,7 @@ def predict( # prepare data source pipeline = DaCapoArraySource(raw_array, raw) + pipeline += gp.Normalize(raw) # raw: (c, d, h, w) pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims)) # raw: (c, d, h, w) @@ -85,7 +88,9 @@ def predict( outputs={0: prediction}, array_specs={ prediction: gp.ArraySpec( - roi=prediction_roi, voxel_size=output_voxel_size, dtype=np.float32 + roi=prediction_roi, + voxel_size=output_voxel_size, + dtype=np.float32, # assumes network output is float32 ) }, spawn_subprocess=False, @@ -99,12 +104,17 @@ def predict( # raw: (c, d, h, w) # prediction: (c, d, h, w) + # convert to uint8 if necessary: + if output_dtype == np.uint8: + pipeline += gp.IntensityScaleShift(prediction, scale=255.0, shift=0.0) + pipeline += gp.AsType(prediction, output_dtype) + # write to zarr pipeline += gp.ZarrWrite( {prediction: prediction_array_identifier.dataset}, prediction_array_identifier.container.parent, prediction_array_identifier.container.name, - dataset_dtypes={prediction: np.float32}, + dataset_dtypes={prediction: output_dtype}, ) # create reference batch request From 6421bd118b309a1dde2e29a8ef14621ffede7c3f Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Thu, 14 Sep 2023 14:38:08 -0400 Subject: [PATCH 14/33] bugfix!?: removed odd zarr creation already handled by funlib.persistence.prepare_ds --- .../datasplits/datasets/arrays/zarr_array.py | 56 +++++-------------- 1 file changed, 14 insertions(+), 42 deletions(-) diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 58d6a1825..b902df706 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -137,48 +137,20 @@ def create_from_array_identifier( write_size = Coordinate((axis_length,) * voxel_size.dims) * voxel_size write_size = Coordinate((min(a, b) for a, b in zip(write_size, roi.shape))) zarr_container = zarr.open(array_identifier.container, "a") - try: - funlib.persistence.prepare_ds( - f"{array_identifier.container}", - array_identifier.dataset, - roi, - voxel_size, - dtype, - num_channels=num_channels, - write_size=write_size, - delete=overwrite, - ) - zarr_dataset = zarr_container[array_identifier.dataset] - zarr_dataset.attrs["offset"] = 
( - roi.offset[::-1] - if array_identifier.container.name.endswith("n5") - else roi.offset - ) - zarr_dataset.attrs["resolution"] = ( - voxel_size[::-1] - if array_identifier.container.name.endswith("n5") - else voxel_size - ) - zarr_dataset.attrs["axes"] = ( - axes[::-1] if array_identifier.container.name.endswith("n5") else axes - ) - except zarr.errors.ContainsArrayError: - zarr_dataset = zarr_container[array_identifier.dataset] - assert ( - tuple(zarr_dataset.attrs["offset"]) == roi.offset - ), f"{zarr_dataset.attrs['offset']}, {roi.offset}" - assert ( - tuple(zarr_dataset.attrs["resolution"]) == voxel_size - ), f"{zarr_dataset.attrs['resolution']}, {voxel_size}" - assert tuple(zarr_dataset.attrs["axes"]) == tuple( - axes - ), f"{zarr_dataset.attrs['axes']}, {axes}" - assert ( - zarr_dataset.shape - == ((num_channels,) if num_channels is not None else ()) - + roi.shape / voxel_size - ), f"{zarr_dataset.shape}, {((num_channels,) if num_channels is not None else ()) + roi.shape / voxel_size}" - zarr_dataset[:] = np.zeros(zarr_dataset.shape, dtype) + funlib.persistence.prepare_ds( + f"{array_identifier.container}", + array_identifier.dataset, + roi, + voxel_size, + dtype, + num_channels=num_channels, + write_size=write_size, + delete=overwrite, + ) + zarr_dataset = zarr_container[array_identifier.dataset] + zarr_dataset.attrs["axes"] = ( + axes[::-1] if array_identifier.container.name.endswith("n5") else axes + ) zarr_array = cls.__new__(cls) zarr_array.file_name = array_identifier.container From 81986484d974bebfe9bbd7de7690da0a0b650602 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 15 Sep 2023 09:09:27 -0400 Subject: [PATCH 15/33] !: attempts to fix roi mismatch failed. committing for reference --- dacapo/apply.py | 4 +- .../datasplits/datasets/arrays/zarr_array.py | 3 +- .../tasks/post_processors/post_processor.py | 1 + .../watershed_post_processor.py | 103 ++++++++++-------- dacapo/predict.py | 13 ++- 5 files changed, 75 insertions(+), 49 deletions(-) diff --git a/dacapo/apply.py b/dacapo/apply.py index 42379b498..802f96705 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -188,8 +188,8 @@ def apply_run( logger.info("Post-processing output to dataset %s", output_array_identifier) post_processor = run.task.post_processor post_processor.set_prediction(prediction_array_identifier) - post_processed_array = post_processor.process( - parameters, output_array_identifier, overwrite=overwrite + post_processor.process( + parameters, output_array_identifier, overwrite=overwrite, blockwise=True ) logger.info("Done") diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index b902df706..3153fef87 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -149,7 +149,8 @@ def create_from_array_identifier( ) zarr_dataset = zarr_container[array_identifier.dataset] zarr_dataset.attrs["axes"] = ( - axes[::-1] if array_identifier.container.name.endswith("n5") else axes + # axes[::-1] if array_identifier.container.name.endswith("n5") else axes + axes ) zarr_array = cls.__new__(cls) diff --git a/dacapo/experiments/tasks/post_processors/post_processor.py b/dacapo/experiments/tasks/post_processors/post_processor.py index eada06c0a..4e4102d6b 100644 --- a/dacapo/experiments/tasks/post_processors/post_processor.py +++ b/dacapo/experiments/tasks/post_processors/post_processor.py @@ -34,6 +34,7 @@ def process( parameters: 
"PostProcessorParameters", output_array_identifier: "LocalArrayIdentifier", overwrite: "bool", + blockwise: "bool", ) -> "Array": """Convert predictions into the final output.""" pass diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index dd138e4be..439b36a2f 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -33,47 +33,64 @@ def set_prediction(self, prediction_array_identifier): ) def process( - self, parameters, output_array_identifier, overwrite: bool = False + self, + parameters, + output_array_identifier, + overwrite: bool = False, + blockwise: bool = False, ): # TODO: will probably break with large arrays... - output_array = ZarrArray.create_from_array_identifier( - output_array_identifier, - [axis for axis in self.prediction_array.axes if axis != "c"], - self.prediction_array.roi, - None, - self.prediction_array.voxel_size, - np.uint64, - overwrite=overwrite, - ) - # if a previous segmentation is provided, it must have a "grid graph" - # in its metadata. - pred_data = self.prediction_array[self.prediction_array.roi] - affs = pred_data[: len(self.offsets)].astype(np.float64) - segmentation = mws.agglom( - affs - parameters.bias, - self.offsets, - ) - # filter fragments - average_affs = np.mean(affs, axis=0) - - filtered_fragments = [] - - fragment_ids = np.unique(segmentation) - - for fragment, mean in zip( - fragment_ids, measurements.mean(average_affs, segmentation, fragment_ids) - ): - if mean < parameters.bias: - filtered_fragments.append(fragment) - - filtered_fragments = np.array(filtered_fragments, dtype=segmentation.dtype) - replace = np.zeros_like(filtered_fragments) - - # DGA: had to add in flatten and reshape since remap (in particular indices) didn't seem to work with ndarrays for the input - if filtered_fragments.size > 0: - segmentation = npi.remap( - segmentation.flatten(), filtered_fragments, replace - ).reshape(segmentation.shape) - - output_array[self.prediction_array.roi] = segmentation - - return output_array + if not blockwise: + output_array = ZarrArray.create_from_array_identifier( + output_array_identifier, + [axis for axis in self.prediction_array.axes if axis != "c"], + self.prediction_array.roi, + None, + self.prediction_array.voxel_size, + np.uint64, + overwrite=overwrite, + ) + # if a previous segmentation is provided, it must have a "grid graph" + # in its metadata. 
+ # pred_data = self.prediction_array[self.prediction_array.roi] + # affs = pred_data[: len(self.offsets)].astype( + # np.float64 + # ) # TODO: shouldn't need to be float64 + affs = self.prediction_array[self.prediction_array.roi][: len(self.offsets)] + if affs.dtype == np.uint8: + affs = affs.astype(np.float64) / 255.0 + else: + affs = affs.astype(np.float64) + segmentation = mws.agglom( + affs - parameters.bias, + self.offsets, + ) + # filter fragments + average_affs = np.mean(affs, axis=0) + + filtered_fragments = [] + + fragment_ids = np.unique(segmentation) + + for fragment, mean in zip( + fragment_ids, + measurements.mean(average_affs, segmentation, fragment_ids), + ): + if mean < parameters.bias: + filtered_fragments.append(fragment) + + filtered_fragments = np.array(filtered_fragments, dtype=segmentation.dtype) + replace = np.zeros_like(filtered_fragments) + + # DGA: had to add in flatten and reshape since remap (in particular indices) didn't seem to work with ndarrays for the input + if filtered_fragments.size > 0: + segmentation = npi.remap( + segmentation.flatten(), filtered_fragments, replace + ).reshape(segmentation.shape) + + output_array[self.prediction_array.roi] = segmentation + + return output_array + else: + raise NotImplementedError( + "Blockwise processing not yet implemented." + ) # TODO: add rusty mws diff --git a/dacapo/predict.py b/dacapo/predict.py index 15045eb1d..aeb3df173 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -51,7 +51,12 @@ def predict( logger.info("Total input ROI: %s, output ROI: %s", input_roi, output_roi) # prepare prediction dataset - axes = ["c"] + [axis for axis in raw_array.axes if axis != "c"] + if raw_array.file_name.endswith( + "zarr" + ) == prediction_array_identifier.container.name.endswith("zarr"): + axes = ["c"] + [axis for axis in raw_array.axes if axis != "c"] + else: + axes = ["c"] + [axis for axis in raw_array.axes[::-1] if axis != "c"] ZarrArray.create_from_array_identifier( prediction_array_identifier, axes, @@ -79,7 +84,7 @@ def predict( # raw: (1, c, d, h, w) gt_padding = (output_size - output_roi.shape) % output_size - prediction_roi = output_roi.grow(gt_padding) + prediction_roi = output_roi.grow(gt_padding) # TODO: are we sure this makes sense? # TODO: Add cache node? # predict pipeline += gp_torch.Predict( @@ -106,7 +111,9 @@ def predict( # convert to uint8 if necessary: if output_dtype == np.uint8: - pipeline += gp.IntensityScaleShift(prediction, scale=255.0, shift=0.0) + pipeline += gp.IntensityScaleShift( + prediction, scale=255.0, shift=0.0 + ) # assumes float32 is [0,1] pipeline += gp.AsType(prediction, output_dtype) # write to zarr From ed0d4add5ae3c9d716f5fdd09464efb2420867b8 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 15 Sep 2023 12:47:02 -0400 Subject: [PATCH 16/33] bugfix: prediction works with zarrs bug seems specific to n5 --- dacapo/apply.py | 6 +- .../datasplits/datasets/arrays/zarr_array.py | 57 ++++++++++++++----- 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/dacapo/apply.py b/dacapo/apply.py index 802f96705..b33cffe46 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -38,6 +38,7 @@ def apply( output_dtype: Optional[np.dtype or str] = np.uint8, compute_context: ComputeContext = LocalTorch(), overwrite: bool = True, + file_format: str = "zarr", ): """Load weights and apply a model to a dataset. If iteration is None, the best iteration based on the criterion is used. 
If roi is None, the whole input dataset is used.""" if isinstance(output_dtype, str): @@ -128,7 +129,10 @@ def apply( roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect( input_array.roi ) - output_container = Path(output_path, Path(input_container).name) + output_container = Path( + output_path, + "".join(Path(input_container).name.split(".")[:-1]) + f".{file_format}", + ) prediction_array_identifier = LocalArrayIdentifier( output_container, f"prediction_{run_name}_{iteration}" ) diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 3153fef87..58d6a1825 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -137,21 +137,48 @@ def create_from_array_identifier( write_size = Coordinate((axis_length,) * voxel_size.dims) * voxel_size write_size = Coordinate((min(a, b) for a, b in zip(write_size, roi.shape))) zarr_container = zarr.open(array_identifier.container, "a") - funlib.persistence.prepare_ds( - f"{array_identifier.container}", - array_identifier.dataset, - roi, - voxel_size, - dtype, - num_channels=num_channels, - write_size=write_size, - delete=overwrite, - ) - zarr_dataset = zarr_container[array_identifier.dataset] - zarr_dataset.attrs["axes"] = ( - # axes[::-1] if array_identifier.container.name.endswith("n5") else axes - axes - ) + try: + funlib.persistence.prepare_ds( + f"{array_identifier.container}", + array_identifier.dataset, + roi, + voxel_size, + dtype, + num_channels=num_channels, + write_size=write_size, + delete=overwrite, + ) + zarr_dataset = zarr_container[array_identifier.dataset] + zarr_dataset.attrs["offset"] = ( + roi.offset[::-1] + if array_identifier.container.name.endswith("n5") + else roi.offset + ) + zarr_dataset.attrs["resolution"] = ( + voxel_size[::-1] + if array_identifier.container.name.endswith("n5") + else voxel_size + ) + zarr_dataset.attrs["axes"] = ( + axes[::-1] if array_identifier.container.name.endswith("n5") else axes + ) + except zarr.errors.ContainsArrayError: + zarr_dataset = zarr_container[array_identifier.dataset] + assert ( + tuple(zarr_dataset.attrs["offset"]) == roi.offset + ), f"{zarr_dataset.attrs['offset']}, {roi.offset}" + assert ( + tuple(zarr_dataset.attrs["resolution"]) == voxel_size + ), f"{zarr_dataset.attrs['resolution']}, {voxel_size}" + assert tuple(zarr_dataset.attrs["axes"]) == tuple( + axes + ), f"{zarr_dataset.attrs['axes']}, {axes}" + assert ( + zarr_dataset.shape + == ((num_channels,) if num_channels is not None else ()) + + roi.shape / voxel_size + ), f"{zarr_dataset.shape}, {((num_channels,) if num_channels is not None else ()) + roi.shape / voxel_size}" + zarr_dataset[:] = np.zeros(zarr_dataset.shape, dtype) zarr_array = cls.__new__(cls) zarr_array.file_name = array_identifier.container From ae231794335e36f61810126e27a83dc50a02aa97 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Tue, 19 Sep 2023 23:20:45 -0400 Subject: [PATCH 17/33] feat: backlogged updates and wip --- .../datasplits/datasets/arrays/zarr_array.py | 4 +- dacapo/experiments/tasks/affinities_task.py | 5 ++- .../tasks/affinities_task_config.py | 13 ++++++ .../watershed_post_processor.py | 4 +- .../tasks/predictors/affinities_predictor.py | 17 +++++-- .../trainers/gp_augments/__init__.py | 1 + .../gp_augments/gaussian_noise_config.py | 22 ++++++++++ .../experiments/trainers/gunpowder_trainer.py | 44 ++++++++++++++----- .../trainers/gunpowder_trainer_config.py | 4 +- 
dacapo/experiments/training_stats.py | 4 +- dacapo/gp/dacapo_create_target.py | 4 +- dacapo/predict.py | 7 +-- dacapo/train.py | 19 +++++--- dacapo/validate.py | 19 ++++---- 14 files changed, 125 insertions(+), 42 deletions(-) create mode 100644 dacapo/experiments/trainers/gp_augments/gaussian_noise_config.py diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 58d6a1825..3a75c6546 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -81,7 +81,7 @@ def writable(self) -> bool: @property def dtype(self) -> Any: - return self.data.dtype + return self.data.dtype # TODO: why not use self._daisy_array.dtype? @property def num_channels(self) -> Optional[int]: @@ -92,7 +92,7 @@ def spatial_axes(self) -> List[str]: return [ax for ax in self.axes if ax not in set(["c", "b"])] @property - def data(self) -> Any: + def data(self) -> Any: # TODO: why not use self._daisy_array.data? zarr_container = zarr.open(str(self.file_name)) return zarr_container[self.dataset] diff --git a/dacapo/experiments/tasks/affinities_task.py b/dacapo/experiments/tasks/affinities_task.py index c1014fd02..818422b2b 100644 --- a/dacapo/experiments/tasks/affinities_task.py +++ b/dacapo/experiments/tasks/affinities_task.py @@ -12,7 +12,10 @@ def __init__(self, task_config): """Create a `DummyTask` from a `DummyTaskConfig`.""" self.predictor = AffinitiesPredictor( - neighborhood=task_config.neighborhood, lsds=task_config.lsds + neighborhood=task_config.neighborhood, + lsds=task_config.lsds, + num_voxels=task_config.num_voxels, + downsample_lsds=task_config.downsample_lsds, ) self.loss = AffinitiesLoss(len(task_config.neighborhood)) self.post_processor = WatershedPostProcessor(offsets=task_config.neighborhood) diff --git a/dacapo/experiments/tasks/affinities_task_config.py b/dacapo/experiments/tasks/affinities_task_config.py index d4b2c6199..264d002c0 100644 --- a/dacapo/experiments/tasks/affinities_task_config.py +++ b/dacapo/experiments/tasks/affinities_task_config.py @@ -30,3 +30,16 @@ class AffinitiesTaskConfig(TaskConfig): "It has been shown that lsds as an auxiliary task can help affinity predictions." }, ) + num_voxels: int = attr.ib( + default=20, + metadata={ + "help_text": "The number of voxels to use for the gaussian sigma when computing lsds." + }, + ) + downsample_lsds: int = attr.ib( + default=1, + metadata={ + "help_text": "The amount to downsample the lsds. " + "This is useful for speeding up training and inference." + }, + ) diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index 439b36a2f..307806772 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -24,7 +24,9 @@ def enumerate_parameters(self): """Enumerate all possible parameters of this post-processor. 
Should return instances of ``PostProcessorParameters``.""" - for i, bias in enumerate([0.1, 0.3, 0.5, 0.7, 0.9]): + for i, bias in enumerate( + [0.1, 0.3, 0.5, 0.7, 0.9] + ): # TODO: add this to the config yield WatershedPostProcessorParameters(id=i, bias=bias) def set_prediction(self, prediction_array_identifier): diff --git a/dacapo/experiments/tasks/predictors/affinities_predictor.py b/dacapo/experiments/tasks/predictors/affinities_predictor.py index 81efb2375..0eaf61c67 100644 --- a/dacapo/experiments/tasks/predictors/affinities_predictor.py +++ b/dacapo/experiments/tasks/predictors/affinities_predictor.py @@ -17,9 +17,16 @@ class AffinitiesPredictor(Predictor): - def __init__(self, neighborhood: List[Coordinate], lsds: bool = True): + def __init__( + self, + neighborhood: List[Coordinate], + lsds: bool = True, + num_voxels: int = 20, + downsample_lsds: int = 1, + ): self.neighborhood = neighborhood self.lsds = lsds + self.num_voxels = num_voxels if lsds: self._extractor = None if self.dims == 2: @@ -30,12 +37,15 @@ def __init__(self, neighborhood: List[Coordinate], lsds: bool = True): raise ValueError( f"Cannot compute lsds on volumes with {self.dims} dimensions" ) + self.downsample_lsds = downsample_lsds else: self.num_lsds = 0 def extractor(self, voxel_size): if self._extractor is None: - self._extractor = LsdExtractor(self.sigma(voxel_size)) + self._extractor = LsdExtractor( + self.sigma(voxel_size), downsample=self.downsample_lsds + ) return self._extractor @@ -45,8 +55,7 @@ def dims(self): def sigma(self, voxel_size): voxel_dist = max(voxel_size) # arbitrarily chosen - num_voxels = 10 # arbitrarily chosen - sigma = voxel_dist * num_voxels + sigma = voxel_dist * self.num_voxels # arbitrarily chosen return Coordinate((sigma,) * self.dims) def lsd_pad(self, voxel_size): diff --git a/dacapo/experiments/trainers/gp_augments/__init__.py b/dacapo/experiments/trainers/gp_augments/__init__.py index 0c93d4603..7e901c76f 100644 --- a/dacapo/experiments/trainers/gp_augments/__init__.py +++ b/dacapo/experiments/trainers/gp_augments/__init__.py @@ -4,3 +4,4 @@ from .gamma_config import GammaAugmentConfig from .intensity_config import IntensityAugmentConfig from .intensity_scale_shift_config import IntensityScaleShiftAugmentConfig +from .gaussian_noise_config import GaussianNoiseAugmentConfig diff --git a/dacapo/experiments/trainers/gp_augments/gaussian_noise_config.py b/dacapo/experiments/trainers/gp_augments/gaussian_noise_config.py new file mode 100644 index 000000000..ab3694d09 --- /dev/null +++ b/dacapo/experiments/trainers/gp_augments/gaussian_noise_config.py @@ -0,0 +1,22 @@ +from .augment_config import AugmentConfig + +import gunpowder as gp + +import attr + + +@attr.s +class GaussianNoiseAugmentConfig(AugmentConfig): + mean: float = attr.ib( + metadata={"help_text": "The mean of the gaussian noise to apply to your data."}, + default=0.0, + ) + var: float = attr.ib( + metadata={"help_text": "The variance of the gaussian noise."}, + default=0.05, + ) + + def node(self, raw_key: gp.ArrayKey, _gt_key=None, _mask_key=None): + return gp.NoiseAugment( + array=raw_key, mode="gaussian", mean=self.mean, var=self.var + ) diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index ab6fa7603..24cfee7dd 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -37,6 +37,8 @@ def __init__(self, trainer_config): self.print_profiling = 100 self.snapshot_iteration = 
trainer_config.snapshot_interval self.min_masked = trainer_config.min_masked + self.reject_probability = trainer_config.reject_probability + self.weighted_reject = trainer_config.weighted_reject self.augments = trainer_config.augments self.mask_integral_downsample_factor = 4 @@ -88,7 +90,7 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): # Get source nodes dataset_sources = [] weights = [] - for dataset in datasets: + for dataset in datasets: # TODO: add automatic resampling? weights.append(dataset.weight) assert isinstance(dataset.weight, int), dataset @@ -141,24 +143,44 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): ) ) - dataset_source += gp.Reject(mask_placeholder, 1e-6) - for augment in self.augments: dataset_source += augment.node(raw_key, gt_key, mask_key) + if self.weighted_reject: + # Add predictor nodes to dataset_source + dataset_source += DaCapoTargetFilter( # TODO: could we add this above reject, and use weights to determine if we should reject? + task.predictor, + gt_key=gt_key, + weights_key=dataset_weight_key, + mask_key=mask_key, + ) - # Add predictor nodes to dataset_source - dataset_source += DaCapoTargetFilter( - task.predictor, - gt_key=gt_key, - weights_key=dataset_weight_key, - mask_key=mask_key, - ) + dataset_source += gp.Reject( + mask=dataset_weight_key, + min_masked=self.min_masked, + reject_probability=self.reject_probability, + ) + else: + dataset_source += gp.Reject( + mask=mask_placeholder, + min_masked=self.min_masked, + reject_probability=self.reject_probability, + ) + # for augment in self.augments: + # dataset_source += augment.node(raw_key, gt_key, mask_key) + + # Add predictor nodes to dataset_source + dataset_source += DaCapoTargetFilter( # TODO: could we add this above reject, and use weights to determine if we should reject? + task.predictor, + gt_key=gt_key, + weights_key=dataset_weight_key, + mask_key=mask_key, + ) dataset_sources.append(dataset_source) pipeline = tuple(dataset_sources) + gp.RandomProvider(weights) # Add predictor nodes to pipeline - pipeline += DaCapoTargetFilter( + pipeline += DaCapoTargetFilter( # TODO: why are there two of these? 
task.predictor, gt_key=gt_key, target_key=target_key, diff --git a/dacapo/experiments/trainers/gunpowder_trainer_config.py b/dacapo/experiments/trainers/gunpowder_trainer_config.py index a6b90c659..032dba23d 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer_config.py +++ b/dacapo/experiments/trainers/gunpowder_trainer_config.py @@ -27,5 +27,7 @@ class GunpowderTrainerConfig(TrainerConfig): default=None, metadata={"help_text": "Number of iterations before saving a new snapshot."}, ) - min_masked: Optional[float] = attr.ib(default=0.15) + min_masked: Optional[float] = attr.ib(default=1e-6) + reject_probability: Optional[float or None] = attr.ib(default=1) + weighted_reject: bool = attr.ib(default=False) clip_raw: bool = attr.ib(default=False) diff --git a/dacapo/experiments/training_stats.py b/dacapo/experiments/training_stats.py index cd3fcd012..72c631ed4 100644 --- a/dacapo/experiments/training_stats.py +++ b/dacapo/experiments/training_stats.py @@ -16,7 +16,9 @@ class TrainingStats: def add_iteration_stats(self, iteration_stats: TrainingIterationStats) -> None: if len(self.iteration_stats) > 0: - assert iteration_stats.iteration == self.iteration_stats[-1].iteration + 1 + assert ( + iteration_stats.iteration == self.iteration_stats[-1].iteration + 1 + ), f"Expected iteration {self.iteration_stats[-1].iteration + 1}, got {iteration_stats.iteration}" self.iteration_stats.append(iteration_stats) diff --git a/dacapo/gp/dacapo_create_target.py b/dacapo/gp/dacapo_create_target.py index f136c5c7b..42358b7b0 100644 --- a/dacapo/gp/dacapo_create_target.py +++ b/dacapo/gp/dacapo_create_target.py @@ -84,7 +84,9 @@ def process(self, batch, request): gt_array = NumpyArray.from_gp_array(batch[self.gt_key]) target_array = self.predictor.create_target(gt_array) - mask_array = NumpyArray.from_gp_array(batch[self.mask_key]) + mask_array = NumpyArray.from_gp_array( + batch[self.mask_key] + ) # TODO: doesn't this require mask_key to be set? if self.target_key is not None: request_spec = request[self.target_key] diff --git a/dacapo/predict.py b/dacapo/predict.py index aeb3df173..07483bea1 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -51,12 +51,7 @@ def predict( logger.info("Total input ROI: %s, output ROI: %s", input_roi, output_roi) # prepare prediction dataset - if raw_array.file_name.endswith( - "zarr" - ) == prediction_array_identifier.container.name.endswith("zarr"): - axes = ["c"] + [axis for axis in raw_array.axes if axis != "c"] - else: - axes = ["c"] + [axis for axis in raw_array.axes[::-1] if axis != "c"] + axes = ["c"] + [axis for axis in raw_array.axes if axis != "c"] ZarrArray.create_from_array_identifier( prediction_array_identifier, axes, diff --git a/dacapo/train.py b/dacapo/train.py index a2068dd61..cef5da5ca 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -1,3 +1,4 @@ +from copy import deepcopy from dacapo.store.create_store import create_array_store from .experiments import Run from .compute_context import LocalTorch, ComputeContext @@ -97,10 +98,19 @@ def train_run( weights_store.retrieve_weights(run, iteration=trained_until) elif latest_weights_iteration > trained_until: - raise RuntimeError( + logger.warn( f"Found weights for iteration {latest_weights_iteration}, but " - f"run {run.name} was only trained until {trained_until}." + f"run {run.name} was only trained until {trained_until}. " + "Filling stats with last observed values." 
) + last_iteration_stats = run.training_stats.iteration_stats[-1] + for i in range( + last_iteration_stats.iteration, latest_weights_iteration - 1 + ): + new_iteration_stats = deepcopy(last_iteration_stats) + new_iteration_stats.iteration = i + 1 + run.training_stats.add_iteration_stats(new_iteration_stats) + trained_until = run.training_stats.trained_until() # start/resume training @@ -163,6 +173,7 @@ def train_run( run.model = run.model.to(torch.device("cpu")) run.move_optimizer(torch.device("cpu"), empty_cuda_cache=True) + stats_store.store_training_stats(run.name, run.training_stats) weights_store.store_weights(run, iteration_stats.iteration + 1) try: validate_run( @@ -179,13 +190,9 @@ def train_run( f"{iteration_stats.iteration + 1}.", exc_info=e, ) - stats_store.store_training_stats(run.name, run.training_stats) # make sure to move optimizer back to the correct device run.move_optimizer(compute_context.device) run.model.train() - weights_store.store_weights(run, run.training_stats.trained_until()) - stats_store.store_training_stats(run.name, run.training_stats) - logger.info("Trained until %d, finished.", trained_until) diff --git a/dacapo/validate.py b/dacapo/validate.py index bb630e335..90aad03e6 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -80,7 +80,7 @@ def validate_run( # Initialize the evaluator with the best scores seen so far evaluator.set_best(run.validation_scores) - for validation_dataset in run.datasplit.validate: + for validation_dataset in reloading(run.datasplit.validate): if validation_dataset.gt is None: logger.error( "We do not yet support validating on datasets without ground truth" @@ -165,11 +165,13 @@ def validate_run( validation_dataset, criterion ) - for parameters in post_processor.enumerate_parameters(): + any_overall_best = False + output_array_identifiers = [] + for parameters in reloading(post_processor.enumerate_parameters()): output_array_identifier = array_store.validation_output_array( run.name, iteration, parameters, validation_dataset ) - + output_array_identifiers.append(output_array_identifier) post_processed_array = post_processor.process( parameters, output_array_identifier ) @@ -199,6 +201,7 @@ def validate_run( and current_score < overall_best_scores[criterion] ) ): + any_overall_best = True overall_best_scores[criterion] = current_score # For example, if parameter 2 did better this round than it did in other rounds, but it was still worse than parameter 1 @@ -228,16 +231,16 @@ def validate_run( weights_store.store_best( run, iteration, validation_dataset.name, criterion ) - else: - # delete current output. 
We only keep the best outputs as determined by - # the evaluator - # remove deletion for now since we have to rerun later to double check things anyway - array_store.remove(output_array_identifier) dataset_iteration_scores.append( [getattr(scores, criterion) for criterion in scores.criteria] ) + if not any_overall_best: + # We only keep the best outputs as determined by the evaluator + for output_array_identifier in output_array_identifiers: + array_store.remove(output_array_identifier) + iteration_scores.append(dataset_iteration_scores) array_store.remove(prediction_array_identifier) From 53b3556e14a8bd13e56182f79bc6efe84db28941 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Thu, 21 Sep 2023 16:21:02 -0400 Subject: [PATCH 18/33] feat: zarr fix dimension detection, simple augment config kwargs, validation bug hanging --- .../datasplits/datasets/arrays/zarr_array.py | 8 +- .../trainers/gp_augments/simple_config.py | 15 +- .../experiments/trainers/gunpowder_trainer.py | 51 ++- dacapo/experiments/validation_scores.py | 5 +- dacapo/validate.py | 373 +++++++++--------- 5 files changed, 236 insertions(+), 216 deletions(-) diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 3a75c6546..081f08e9b 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -49,16 +49,16 @@ def axes(self): try: return self._attributes["axes"] except KeyError: - logger.debug( + logger.info( "DaCapo expects Zarr datasets to have an 'axes' attribute!\n" f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n" - f"Using default {['t', 'z', 'y', 'x'][-self.dims::]}", + f"Using default {['c', 'z', 'y', 'x'][-self.dims::]}", ) - return ["t", "z", "y", "x"][-self.dims : :] + return ["c", "z", "y", "x"][-self.dims : :] @property def dims(self) -> int: - return self.voxel_size.dims + return len(self.data.shape) @lazy_property.LazyProperty def _daisy_array(self) -> funlib.persistence.Array: diff --git a/dacapo/experiments/trainers/gp_augments/simple_config.py b/dacapo/experiments/trainers/gp_augments/simple_config.py index 86de2161c..9f56aa19e 100644 --- a/dacapo/experiments/trainers/gp_augments/simple_config.py +++ b/dacapo/experiments/trainers/gp_augments/simple_config.py @@ -7,5 +7,16 @@ @attr.s class SimpleAugmentConfig(AugmentConfig): - def node(self, _raw_key=None, _gt_key=None, _mask_key=None): - return gp.SimpleAugment() + def node( + self, + _raw_key=None, + _gt_key=None, + _mask_key=None, + mirror_only=None, + transpose_only=None, + mirror_probs=None, + transpose_probs=None, + ): + return gp.SimpleAugment( + mirror_only, transpose_only, mirror_probs, transpose_probs + ) diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 24cfee7dd..e00b05797 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -48,12 +48,14 @@ def __init__(self, trainer_config): def create_optimizer(self, model): optimizer = torch.optim.RAdam(lr=self.learning_rate, params=model.parameters()) - self.scheduler = torch.optim.lr_scheduler.LinearLR( - optimizer, - start_factor=0.01, - end_factor=1.0, - total_iters=1000, - last_epoch=-1, + self.scheduler = ( + torch.optim.lr_scheduler.LinearLR( # TODO: add scheduler to config + optimizer, + start_factor=0.01, + end_factor=1.0, + total_iters=1000, + 
last_epoch=-1, + ) ) return optimizer @@ -62,7 +64,9 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): output_shape = Coordinate(model.output_shape) # get voxel sizes - raw_voxel_size = datasets[0].raw.voxel_size + raw_voxel_size = datasets[ + 0 + ].raw.voxel_size # TODO: make dataset specific / resample prediction_voxel_size = model.scale(raw_voxel_size) # define input and output size: @@ -143,19 +147,21 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): ) ) - for augment in self.augments: - dataset_source += augment.node(raw_key, gt_key, mask_key) if self.weighted_reject: # Add predictor nodes to dataset_source - dataset_source += DaCapoTargetFilter( # TODO: could we add this above reject, and use weights to determine if we should reject? + for augment in self.augments: + dataset_source += augment.node(raw_key, gt_key, mask_key) + + dataset_source += DaCapoTargetFilter( task.predictor, gt_key=gt_key, - weights_key=dataset_weight_key, + weights_key=weight_key, + target_key=target_key, mask_key=mask_key, ) dataset_source += gp.Reject( - mask=dataset_weight_key, + mask=weight_key, min_masked=self.min_masked, reject_probability=self.reject_probability, ) @@ -165,31 +171,22 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): min_masked=self.min_masked, reject_probability=self.reject_probability, ) - # for augment in self.augments: - # dataset_source += augment.node(raw_key, gt_key, mask_key) + + for augment in self.augments: + dataset_source += augment.node(raw_key, gt_key, mask_key) # Add predictor nodes to dataset_source - dataset_source += DaCapoTargetFilter( # TODO: could we add this above reject, and use weights to determine if we should reject? + dataset_source += DaCapoTargetFilter( task.predictor, gt_key=gt_key, - weights_key=dataset_weight_key, + weights_key=weight_key, + target_key=target_key, mask_key=mask_key, ) dataset_sources.append(dataset_source) pipeline = tuple(dataset_sources) + gp.RandomProvider(weights) - # Add predictor nodes to pipeline - pipeline += DaCapoTargetFilter( # TODO: why are there two of these? - task.predictor, - gt_key=gt_key, - target_key=target_key, - weights_key=datasets_weight_key, - mask_key=mask_key, - ) - - pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key) - # Trainer attributes: if self.num_data_fetchers > 1: pipeline += gp.PreCache(num_workers=self.num_data_fetchers) diff --git a/dacapo/experiments/validation_scores.py b/dacapo/experiments/validation_scores.py index 23a23d62e..ec568205c 100644 --- a/dacapo/experiments/validation_scores.py +++ b/dacapo/experiments/validation_scores.py @@ -142,7 +142,10 @@ def get_best( return (da_best_indexes, da_best_scores) else: if self.evaluation_scores.higher_is_better( - data.coords["criteria"].item() + list(data.coords["criteria"].values())[ + 0 + ] # TODO: what is the intended behavior here? 
(hot fix in place) + # data.coords["criteria"].item() ): return ( data.idxmax(dim, skipna=True, fill_value=None), diff --git a/dacapo/validate.py b/dacapo/validate.py index 90aad03e6..ca1052fb6 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -48,7 +48,7 @@ def validate( return validate_run(run, iteration, compute_context=compute_context) -@reloading +@reloading # allows us to fix validation bugs without interrupting training def validate_run( run: Run, iteration: int, compute_context: ComputeContext = LocalTorch() ): @@ -56,199 +56,208 @@ def validate_run( load the weights of that iteration, it is assumed that the model is already loaded correctly. Returns the best parameters and scores for this iteration.""" - # set benchmark flag to True for performance - torch.backends.cudnn.benchmark = True - run.model.eval() - - if ( - run.datasplit.validate is None - or len(run.datasplit.validate) == 0 - or run.datasplit.validate[0].gt is None - ): - logger.info("Cannot validate run %s. Continuing training!", run.name) - return None, None - - # get array and weight store - weights_store = create_weights_store() - array_store = create_array_store() - iteration_scores = [] - - # get post processor and evaluator - post_processor = run.task.post_processor - evaluator = run.task.evaluator - - # Initialize the evaluator with the best scores seen so far - evaluator.set_best(run.validation_scores) - - for validation_dataset in reloading(run.datasplit.validate): - if validation_dataset.gt is None: - logger.error( - "We do not yet support validating on datasets without ground truth" - ) - raise NotImplementedError - - logger.info( - "Validating run %s on dataset %s", run.name, validation_dataset.name - ) + try: # we don't want this to hold up training + # set benchmark flag to True for performance + torch.backends.cudnn.benchmark = True + run.model.to(compute_context.device) + run.model.eval() - ( - input_raw_array_identifier, - input_gt_array_identifier, - ) = array_store.validation_input_arrays(run.name, validation_dataset.name) if ( - not Path( - f"{input_raw_array_identifier.container}/{input_raw_array_identifier.dataset}" - ).exists() - or not Path( - f"{input_gt_array_identifier.container}/{input_gt_array_identifier.dataset}" - ).exists() + run.datasplit.validate is None + or len(run.datasplit.validate) == 0 + or run.datasplit.validate[0].gt is None ): - logger.info("Copying validation inputs!") - input_voxel_size = validation_dataset.raw.voxel_size - output_voxel_size = run.model.scale(input_voxel_size) - input_shape = run.model.eval_input_shape - input_size = input_voxel_size * input_shape - output_shape = run.model.compute_output_shape(input_shape)[1] - output_size = output_voxel_size * output_shape - context = (input_size - output_size) / 2 - output_roi = validation_dataset.gt.roi - - input_roi = ( - output_roi.grow(context, context) - .snap_to_grid(validation_dataset.raw.voxel_size, mode="grow") - .intersect(validation_dataset.raw.roi) + logger.info("Cannot validate run %s. 
Continuing training!", run.name) + return None, None + + # get array and weight store + weights_store = create_weights_store() + array_store = create_array_store() + iteration_scores = [] + + # get post processor and evaluator + post_processor = run.task.post_processor + evaluator = run.task.evaluator + + # Initialize the evaluator with the best scores seen so far + evaluator.set_best(run.validation_scores) + + for validation_dataset in run.datasplit.validate: + if validation_dataset.gt is None: + logger.error( + "We do not yet support validating on datasets without ground truth" + ) + raise NotImplementedError + + logger.info( + "Validating run %s on dataset %s", run.name, validation_dataset.name ) - input_raw = ZarrArray.create_from_array_identifier( + + ( input_raw_array_identifier, - validation_dataset.raw.axes, - input_roi, - validation_dataset.raw.num_channels, - validation_dataset.raw.voxel_size, - validation_dataset.raw.dtype, - name=f"{run.name}_validation_raw", - write_size=input_size, - ) - input_raw[input_roi] = validation_dataset.raw[input_roi] - input_gt = ZarrArray.create_from_array_identifier( input_gt_array_identifier, - validation_dataset.gt.axes, - output_roi, - validation_dataset.gt.num_channels, - validation_dataset.gt.voxel_size, - validation_dataset.gt.dtype, - name=f"{run.name}_validation_gt", - write_size=output_size, + ) = array_store.validation_input_arrays(run.name, validation_dataset.name) + if ( + not Path( + f"{input_raw_array_identifier.container}/{input_raw_array_identifier.dataset}" + ).exists() + or not Path( + f"{input_gt_array_identifier.container}/{input_gt_array_identifier.dataset}" + ).exists() + ): + logger.info("Copying validation inputs!") + input_voxel_size = validation_dataset.raw.voxel_size + output_voxel_size = run.model.scale(input_voxel_size) + input_shape = run.model.eval_input_shape + input_size = input_voxel_size * input_shape + output_shape = run.model.compute_output_shape(input_shape)[1] + output_size = output_voxel_size * output_shape + context = (input_size - output_size) / 2 + output_roi = validation_dataset.gt.roi + + input_roi = ( + output_roi.grow(context, context) + .snap_to_grid(validation_dataset.raw.voxel_size, mode="grow") + .intersect(validation_dataset.raw.roi) + ) + input_raw = ZarrArray.create_from_array_identifier( + input_raw_array_identifier, + validation_dataset.raw.axes, + input_roi, + validation_dataset.raw.num_channels, + validation_dataset.raw.voxel_size, + validation_dataset.raw.dtype, + name=f"{run.name}_validation_raw", + write_size=input_size, + ) + input_raw[input_roi] = validation_dataset.raw[input_roi] + input_gt = ZarrArray.create_from_array_identifier( + input_gt_array_identifier, + validation_dataset.gt.axes, + output_roi, + validation_dataset.gt.num_channels, + validation_dataset.gt.voxel_size, + validation_dataset.gt.dtype, + name=f"{run.name}_validation_gt", + write_size=output_size, + ) + input_gt[output_roi] = validation_dataset.gt[output_roi] + else: + logger.info("validation inputs already copied!") + + prediction_array_identifier = array_store.validation_prediction_array( + run.name, iteration, validation_dataset ) - input_gt[output_roi] = validation_dataset.gt[output_roi] - else: - logger.info("validation inputs already copied!") - - prediction_array_identifier = array_store.validation_prediction_array( - run.name, iteration, validation_dataset - ) - predict( - run.model, - validation_dataset.raw, - prediction_array_identifier, - compute_context=compute_context, - 
output_roi=validation_dataset.gt.roi, - ) - - post_processor.set_prediction(prediction_array_identifier) - - dataset_iteration_scores = [] - - # set up dict for overall best scores - overall_best_scores = {} - for criterion in run.validation_scores.criteria: - overall_best_scores[criterion] = evaluator.get_overall_best( - validation_dataset, criterion + predict( + run.model, + validation_dataset.raw, + prediction_array_identifier, + compute_context=compute_context, + output_roi=validation_dataset.gt.roi, ) - any_overall_best = False - output_array_identifiers = [] - for parameters in reloading(post_processor.enumerate_parameters()): - output_array_identifier = array_store.validation_output_array( - run.name, iteration, parameters, validation_dataset - ) - output_array_identifiers.append(output_array_identifier) - post_processed_array = post_processor.process( - parameters, output_array_identifier - ) + post_processor.set_prediction(prediction_array_identifier) - scores = evaluator.evaluate(output_array_identifier, validation_dataset.gt) + dataset_iteration_scores = [] + + # set up dict for overall best scores + overall_best_scores = {} for criterion in run.validation_scores.criteria: - # replace predictions in array with the new better predictions - if evaluator.is_best( - validation_dataset, - parameters, - criterion, - scores, - ): - # then this is the current best score for this parameter, but not necessarily the overall best - higher_is_better = scores.higher_is_better(criterion) - # initial_best_score = overall_best_scores[criterion] - current_score = getattr(scores, criterion) - if not overall_best_scores[ - criterion - ] or ( # TODO: should be in evaluator - ( - higher_is_better - and current_score > overall_best_scores[criterion] - ) - or ( - not higher_is_better - and current_score < overall_best_scores[criterion] - ) + overall_best_scores[criterion] = evaluator.get_overall_best( + validation_dataset, criterion + ) + + any_overall_best = False + output_array_identifiers = [] + for parameters in post_processor.enumerate_parameters(): + output_array_identifier = array_store.validation_output_array( + run.name, iteration, parameters, validation_dataset + ) + output_array_identifiers.append(output_array_identifier) + post_processed_array = post_processor.process( + parameters, output_array_identifier + ) + + scores = evaluator.evaluate( + output_array_identifier, validation_dataset.gt + ) + for criterion in run.validation_scores.criteria: + # replace predictions in array with the new better predictions + if evaluator.is_best( + validation_dataset, + parameters, + criterion, + scores, ): - any_overall_best = True - overall_best_scores[criterion] = current_score - - # For example, if parameter 2 did better this round than it did in other rounds, but it was still worse than parameter 1 - # the code would have overwritten it below since all parameters write to the same file. 
Now each parameter will be its own file - # Either we do that, or we only write out the overall best, regardless of parameters - best_array_identifier = array_store.best_validation_array( - run.name, criterion, index=validation_dataset.name - ) - best_array = ZarrArray.create_from_array_identifier( - best_array_identifier, - post_processed_array.axes, - post_processed_array.roi, - post_processed_array.num_channels, - post_processed_array.voxel_size, - post_processed_array.dtype, - ) - best_array[best_array.roi] = post_processed_array[ - post_processed_array.roi - ] - best_array.add_metadata( - { - "iteration": iteration, - criterion: getattr(scores, criterion), - "parameters_id": parameters.id, - } - ) - weights_store.store_best( - run, iteration, validation_dataset.name, criterion - ) - - dataset_iteration_scores.append( - [getattr(scores, criterion) for criterion in scores.criteria] - ) - - if not any_overall_best: - # We only keep the best outputs as determined by the evaluator - for output_array_identifier in output_array_identifiers: - array_store.remove(output_array_identifier) - - iteration_scores.append(dataset_iteration_scores) - array_store.remove(prediction_array_identifier) - - run.validation_scores.add_iteration_scores( - ValidationIterationScores(iteration, iteration_scores) - ) - stats_store = create_stats_store() - stats_store.store_validation_iteration_scores(run.name, run.validation_scores) + # then this is the current best score for this parameter, but not necessarily the overall best + higher_is_better = scores.higher_is_better(criterion) + # initial_best_score = overall_best_scores[criterion] + current_score = getattr(scores, criterion) + if not overall_best_scores[ + criterion + ] or ( # TODO: should be in evaluator + ( + higher_is_better + and current_score > overall_best_scores[criterion] + ) + or ( + not higher_is_better + and current_score < overall_best_scores[criterion] + ) + ): + any_overall_best = True + overall_best_scores[criterion] = current_score + + # For example, if parameter 2 did better this round than it did in other rounds, but it was still worse than parameter 1 + # the code would have overwritten it below since all parameters write to the same file. 
Now each parameter will be its own file + # Either we do that, or we only write out the overall best, regardless of parameters + best_array_identifier = array_store.best_validation_array( + run.name, criterion, index=validation_dataset.name + ) + best_array = ZarrArray.create_from_array_identifier( + best_array_identifier, + post_processed_array.axes, + post_processed_array.roi, + post_processed_array.num_channels, + post_processed_array.voxel_size, + post_processed_array.dtype, + ) + best_array[best_array.roi] = post_processed_array[ + post_processed_array.roi + ] + best_array.add_metadata( + { + "iteration": iteration, + criterion: getattr(scores, criterion), + "parameters_id": parameters.id, + } + ) + weights_store.store_best( + run, iteration, validation_dataset.name, criterion + ) + + dataset_iteration_scores.append( + [getattr(scores, criterion) for criterion in scores.criteria] + ) + + if not any_overall_best: + # We only keep the best outputs as determined by the evaluator + for output_array_identifier in output_array_identifiers: + array_store.remove(output_array_identifier) + + iteration_scores.append(dataset_iteration_scores) + array_store.remove(prediction_array_identifier) + + run.validation_scores.add_iteration_scores( + ValidationIterationScores(iteration, iteration_scores) + ) + stats_store = create_stats_store() + stats_store.store_validation_iteration_scores(run.name, run.validation_scores) + except Exception as e: + logger.error( + f"Validation failed for run {run.name} at iteration " f"{iteration}.", + exc_info=e, + ) if __name__ == "__main__": From 8e6dfa351ccdc4ffd196e872a7d687b20e9c0459 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Thu, 21 Sep 2023 16:42:22 -0400 Subject: [PATCH 19/33] bugfix: simple augment config --- .../trainers/gp_augments/simple_config.py | 43 ++++++++++++++++--- setup.py | 1 + 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/dacapo/experiments/trainers/gp_augments/simple_config.py b/dacapo/experiments/trainers/gp_augments/simple_config.py index 9f56aa19e..8859bde38 100644 --- a/dacapo/experiments/trainers/gp_augments/simple_config.py +++ b/dacapo/experiments/trainers/gp_augments/simple_config.py @@ -1,3 +1,4 @@ +from typing import List from .augment_config import AugmentConfig import gunpowder as gp @@ -7,16 +8,48 @@ @attr.s class SimpleAugmentConfig(AugmentConfig): + mirror_only: List[int] = attr.ib( + default=None, + metadata={ + "help_text": ( + "If set, only mirror between the given axes. This is useful to exclude channels that have a set direction, like time." + ) + }, + ) + transpose_only: List[int] = attr.ib( + default=None, + metadata={ + "help_text": ( + "If set, only transpose between the given axes. This is useful to exclude channels that have a set direction, like time." + ) + }, + ) + mirror_probs: List[float] = attr.ib( + default=None, + metadata={ + "help_text": ( + "Probability of mirroring along each axis. Defaults to 0.5 for each axis." + ) + }, + ) + transpose_probs: List[float] = attr.ib( + default=None, + metadata={ + "help_text": ( + "Probability of transposing along each axis. Defaults to 0.5 for each axis." 
+ ) + }, + ) + def node( self, _raw_key=None, _gt_key=None, _mask_key=None, - mirror_only=None, - transpose_only=None, - mirror_probs=None, - transpose_probs=None, ): return gp.SimpleAugment( - mirror_only, transpose_only, mirror_probs, transpose_probs + self.mirror_only, + self.transpose_only, + self.mirror_probs, + self.transpose_probs, ) diff --git a/setup.py b/setup.py index 9d3b75231..e8ccc4ed5 100644 --- a/setup.py +++ b/setup.py @@ -36,5 +36,6 @@ "funlib.evaluate @ git+https://github.com/pattonw/funlib.evaluate", "gunpowder>=1.3", "lsds>=0.1.3", + "reloading", ], ) From 08690159cdb7e37c73581ab1855fa6580dfe69a6 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Thu, 21 Sep 2023 16:47:55 -0400 Subject: [PATCH 20/33] bugfix: simple augment config --- .../experiments/trainers/gp_augments/simple_config.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dacapo/experiments/trainers/gp_augments/simple_config.py b/dacapo/experiments/trainers/gp_augments/simple_config.py index 8859bde38..c5c1d9456 100644 --- a/dacapo/experiments/trainers/gp_augments/simple_config.py +++ b/dacapo/experiments/trainers/gp_augments/simple_config.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from .augment_config import AugmentConfig import gunpowder as gp @@ -8,7 +8,7 @@ @attr.s class SimpleAugmentConfig(AugmentConfig): - mirror_only: List[int] = attr.ib( + mirror_only: Optional[List[int]] = attr.ib( default=None, metadata={ "help_text": ( @@ -16,7 +16,7 @@ class SimpleAugmentConfig(AugmentConfig): ) }, ) - transpose_only: List[int] = attr.ib( + transpose_only: Optional[List[int]] = attr.ib( default=None, metadata={ "help_text": ( @@ -24,7 +24,7 @@ class SimpleAugmentConfig(AugmentConfig): ) }, ) - mirror_probs: List[float] = attr.ib( + mirror_probs: Optional[List[float]] = attr.ib( default=None, metadata={ "help_text": ( @@ -32,7 +32,7 @@ class SimpleAugmentConfig(AugmentConfig): ) }, ) - transpose_probs: List[float] = attr.ib( + transpose_probs: Optional[List[float]] = attr.ib( default=None, metadata={ "help_text": ( From 922ba62a4db80632d511580cc4bdd5e08f6f9e58 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 25 Sep 2023 10:45:59 -0400 Subject: [PATCH 21/33] feat: black format, iou score --- dacapo/experiments/tasks/affinities_task.py | 1 + .../tasks/affinities_task_config.py | 7 ++++++ .../evaluators/instance_evaluation_scores.py | 5 +++- .../tasks/evaluators/instance_evaluator.py | 14 +++++++++-- .../tasks/predictors/affinities_predictor.py | 16 +++++++++---- dacapo/experiments/validation_scores.py | 2 +- dacapo/gp/elastic_augment_fuse.py | 2 +- docs/source/conf.py | 23 ++++++++++--------- 8 files changed, 49 insertions(+), 21 deletions(-) diff --git a/dacapo/experiments/tasks/affinities_task.py b/dacapo/experiments/tasks/affinities_task.py index 818422b2b..4a1b8cc4a 100644 --- a/dacapo/experiments/tasks/affinities_task.py +++ b/dacapo/experiments/tasks/affinities_task.py @@ -16,6 +16,7 @@ def __init__(self, task_config): lsds=task_config.lsds, num_voxels=task_config.num_voxels, downsample_lsds=task_config.downsample_lsds, + grow_boundary_iterations=task_config.grow_boundary_iterations, ) self.loss = AffinitiesLoss(len(task_config.neighborhood)) self.post_processor = WatershedPostProcessor(offsets=task_config.neighborhood) diff --git a/dacapo/experiments/tasks/affinities_task_config.py b/dacapo/experiments/tasks/affinities_task_config.py index 264d002c0..0a94db79d 100644 --- a/dacapo/experiments/tasks/affinities_task_config.py +++ 
b/dacapo/experiments/tasks/affinities_task_config.py @@ -43,3 +43,10 @@ class AffinitiesTaskConfig(TaskConfig): "This is useful for speeding up training and inference." }, ) + grow_boundary_iterations: int = attr.ib( + default=0, + metadata={ + "help_text": "The number of iterations to run the grow boundaries algorithm. " + "This is useful for refining the boundaries of the affinities, and reducing merging of adjacent objects." + }, + ) diff --git a/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py b/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py index 7de54d99c..16eac194c 100644 --- a/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py +++ b/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py @@ -6,10 +6,11 @@ @attr.s class InstanceEvaluationScores(EvaluationScores): - criteria = ["voi_split", "voi_merge", "voi"] + criteria = ["voi_split", "voi_merge", "voi", "avg_iou"] voi_split: float = attr.ib(default=float("nan")) voi_merge: float = attr.ib(default=float("nan")) + avg_iou: float = attr.ib(default=float("nan")) @property def voi(self): @@ -21,6 +22,7 @@ def higher_is_better(criterion: str) -> bool: "voi_split": False, "voi_merge": False, "voi": False, + "avg_iou": True, } return mapping[criterion] @@ -30,6 +32,7 @@ def bounds(criterion: str) -> Tuple[float, float]: "voi_split": (0, 1), "voi_merge": (0, 1), "voi": (0, 1), + "avg_iou": (0, None), } return mapping[criterion] diff --git a/dacapo/experiments/tasks/evaluators/instance_evaluator.py b/dacapo/experiments/tasks/evaluators/instance_evaluator.py index 0f3427a40..5650d44b4 100644 --- a/dacapo/experiments/tasks/evaluators/instance_evaluator.py +++ b/dacapo/experiments/tasks/evaluators/instance_evaluator.py @@ -3,7 +3,7 @@ from .evaluator import Evaluator from .instance_evaluation_scores import InstanceEvaluationScores -from funlib.evaluate import rand_voi +from funlib.evaluate import rand_voi, detection_scores import numpy as np @@ -16,9 +16,19 @@ def evaluate(self, output_array_identifier, evaluation_array): evaluation_data = evaluation_array[evaluation_array.roi].astype(np.uint64) output_data = output_array[output_array.roi].astype(np.uint64) results = rand_voi(evaluation_data, output_data) + results.update( + detection_scores( + evaluation_data, + output_data, + matching_score="iou", + voxel_size=output_array.voxel_size, + ) + ) return InstanceEvaluationScores( - voi_merge=results["voi_merge"], voi_split=results["voi_split"] + voi_merge=results["voi_merge"], + voi_split=results["voi_split"], + avg_iou=results["avg_iou"], ) @property diff --git a/dacapo/experiments/tasks/predictors/affinities_predictor.py b/dacapo/experiments/tasks/predictors/affinities_predictor.py index 0eaf61c67..40d81f5da 100644 --- a/dacapo/experiments/tasks/predictors/affinities_predictor.py +++ b/dacapo/experiments/tasks/predictors/affinities_predictor.py @@ -23,6 +23,7 @@ def __init__( lsds: bool = True, num_voxels: int = 20, downsample_lsds: int = 1, + grow_boundary_iterations: int = 0, ): self.neighborhood = neighborhood self.lsds = lsds @@ -40,6 +41,7 @@ def __init__( self.downsample_lsds = downsample_lsds else: self.num_lsds = 0 + self.grow_boundary_iterations = grow_boundary_iterations def extractor(self, voxel_size): if self._extractor is None: @@ -127,7 +129,9 @@ def _grow_boundaries(self, mask, slab): slice(start[d], start[d] + slab[d]) for d in range(len(slab)) ) mask_slab = mask[slices] - dilated_mask_slab = ndimage.binary_dilation(mask_slab, iterations=1) + dilated_mask_slab = 
ndimage.binary_dilation( + mask_slab, iterations=self.grow_boundary_iterations + ) foreground[slices] = dilated_mask_slab # label new background @@ -139,10 +143,12 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): (moving_class_counts, moving_lsd_class_counts) = ( moving_class_counts if moving_class_counts is not None else (None, None) ) - # mask_data = self._grow_boundaries( - # mask[target.roi], slab=tuple(1 if c == "c" else -1 for c in target.axes) - # ) - mask_data = mask[target.roi] + if self.grow_boundary_iterations > 0: + mask_data = self._grow_boundaries( + mask[target.roi], slab=tuple(1 if c == "c" else -1 for c in target.axes) + ) + else: + mask_data = mask[target.roi] aff_weights, moving_class_counts = balance_weights( target[target.roi][: self.num_channels - self.num_lsds].astype(np.uint8), 2, diff --git a/dacapo/experiments/validation_scores.py b/dacapo/experiments/validation_scores.py index ec568205c..17727cc22 100644 --- a/dacapo/experiments/validation_scores.py +++ b/dacapo/experiments/validation_scores.py @@ -142,7 +142,7 @@ def get_best( return (da_best_indexes, da_best_scores) else: if self.evaluation_scores.higher_is_better( - list(data.coords["criteria"].values())[ + list(data.coords["criteria"].values)[ 0 ] # TODO: what is the intended behavior here? (hot fix in place) # data.coords["criteria"].item() diff --git a/dacapo/gp/elastic_augment_fuse.py b/dacapo/gp/elastic_augment_fuse.py index c7163f68d..70d4bc059 100644 --- a/dacapo/gp/elastic_augment_fuse.py +++ b/dacapo/gp/elastic_augment_fuse.py @@ -138,7 +138,7 @@ def _min_max_mean_std(ndarray, prefix=""): return "" -class ElasticAugment(BatchFilter): +class ElasticAugment(BatchFilter): # TODO: replace DeformAugment node from gunpowder """ Elasticly deform a batch. Requests larger batches upstream to avoid data loss due to rotation and jitter. diff --git a/docs/source/conf.py b/docs/source/conf.py index cd5823612..7df2f563b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,14 +12,15 @@ # import os import sys -sys.path.insert(0, os.path.abspath('../..')) + +sys.path.insert(0, os.path.abspath("../..")) # -- Project information ----------------------------------------------------- -project = 'DaCapo' -copyright = '2022, William Patton, David Ackerman, Jan Funke' -author = 'William Patton, David Ackerman, Jan Funke' +project = "DaCapo" +copyright = "2022, William Patton, David Ackerman, Jan Funke" +author = "William Patton, David Ackerman, Jan Funke" # -- General configuration --------------------------------------------------- @@ -27,15 +28,15 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx_autodoc_typehints'] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx_autodoc_typehints"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- @@ -43,12 +44,12 @@ # The theme to use for HTML and HTML Help pages. 
See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_material' +html_theme = "sphinx_material" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] html_css_files = [ - 'css/custom.css', -] \ No newline at end of file + "css/custom.css", +] From 7d954f9ff005e7f833fa56630d32756f3514b4af Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 25 Sep 2023 11:09:04 -0400 Subject: [PATCH 22/33] bugfix: remove array overspecification --- dacapo/gp/dacapo_array_source.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dacapo/gp/dacapo_array_source.py b/dacapo/gp/dacapo_array_source.py index f9cf1446c..c00b2d504 100644 --- a/dacapo/gp/dacapo_array_source.py +++ b/dacapo/gp/dacapo_array_source.py @@ -26,7 +26,7 @@ class DaCapoArraySource(gp.BatchProvider): def __init__(self, array: Array, key: gp.ArrayKey): self.array = array self.array_spec = ArraySpec( - roi=self.array.roi, voxel_size=self.array.voxel_size, dtype=self.array.dtype + roi=self.array.roi, voxel_size=self.array.voxel_size ) self.key = key From ec1b0d870c8c68d29b9d087f7380301203697a73 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Mon, 25 Sep 2023 11:57:49 -0400 Subject: [PATCH 23/33] hot distance loss function --- .../tasks/losses/hot_distance_loss.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 dacapo/experiments/tasks/losses/hot_distance_loss.py diff --git a/dacapo/experiments/tasks/losses/hot_distance_loss.py b/dacapo/experiments/tasks/losses/hot_distance_loss.py new file mode 100644 index 000000000..77f34fd08 --- /dev/null +++ b/dacapo/experiments/tasks/losses/hot_distance_loss.py @@ -0,0 +1,24 @@ +from .loss import Loss +import torch + +# HotDistance is used for predicting hot and distance maps at the same time. +# The first half of the channels are the hot maps, the second half are the distance maps. +# The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance maps. +# Model should predict twice the number of channels as the target. +class HotDistanceLoss(Loss): + def compute(self, prediction, target, weight): + target_hot, target_distance = self.split(target) + prediction_hot, prediction_distance = self.split(prediction) + weight_hot, weight_distance = self.split(weight) + return self.hot_loss(prediction_hot, target_hot, weight_hot) + self.distance_loss(prediction_distance, target_distance, weight_distance) + + def hot_loss(self, prediction, target, weight): + return torch.nn.BCELoss().forward(prediction * weight, target * weight) + + def distance_loss(self, prediction, target, weight): + return torch.nn.MSELoss().forward(prediction * weight, target * weight) + + def split(self, x): + assert x.shape[0] % 2 == 0, f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance." 
+ mid = x.shape[0] // 2 + return x[:mid], x[-mid:] \ No newline at end of file From be0f6db3d5af9278d69f285523e389686afba3f4 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 25 Sep 2023 12:05:55 -0400 Subject: [PATCH 24/33] feat: hotdistance predictor, model/target --- .../predictors/hot_distance_predictor.py | 268 ++++++++++++++++++ 1 file changed, 268 insertions(+) create mode 100644 dacapo/experiments/tasks/predictors/hot_distance_predictor.py diff --git a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py new file mode 100644 index 000000000..746cc998d --- /dev/null +++ b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py @@ -0,0 +1,268 @@ +from dacapo.experiments.arraytypes.probabilities import ProbabilityArray +from .predictor import Predictor +from dacapo.experiments import Model +from dacapo.experiments.arraytypes import DistanceArray +from dacapo.experiments.datasplits.datasets.arrays import NumpyArray +from dacapo.utils.balance_weights import balance_weights + +from funlib.geometry import Coordinate + +from scipy.ndimage.morphology import distance_transform_edt +import numpy as np +import torch + +import logging +from typing import List + +logger = logging.getLogger(__name__) + + +class HotDistancePredictor(Predictor): + """ + Predict signed distances and one hot embedding (as a proxy task) for a binary segmentation task. + Distances deep within background are pushed to -inf, distances deep within + the foreground object are pushed to inf. After distances have been + calculated they are passed through a tanh so that distances saturate at +-1. + Multiple classes can be predicted via multiple distance channels. The names + of each class that is being segmented can be passed in as a list of strings + in the channels argument. 
+ """ + + def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool): + self.channels = ( + channels * 2 + ) # one hot + distance (TODO: add hot/distance to channel names) + self.norm = "tanh" + self.dt_scale_factor = scale_factor + self.mask_distances = mask_distances + + self.max_distance = 1 * scale_factor + self.epsilon = 5e-2 # TODO: should be a config parameter + self.threshold = 0.8 # TODO: should be a config parameter + + @property + def embedding_dims(self): + return len(self.channels) + + def create_model(self, architecture): + if architecture.dims == 2: + head = torch.nn.Conv2d( + architecture.num_out_channels, self.embedding_dims, kernel_size=3 + ) + elif architecture.dims == 3: + head = torch.nn.Conv3d( + architecture.num_out_channels, self.embedding_dims, kernel_size=3 + ) + + return Model(architecture, head) + + def create_target(self, gt): + target = self.process(gt.data, gt.voxel_size, self.norm, self.dt_scale_factor) + return NumpyArray.from_np_array( + target, + gt.roi, + gt.voxel_size, + gt.axes, + ) + + def create_weight(self, gt, target, mask, moving_class_counts=None): + # balance weights independently for each channel + if self.mask_distances: + distance_mask = self.create_distance_mask( + target[target.roi], + mask[target.roi], + target.voxel_size, + self.norm, + self.dt_scale_factor, + ) + else: + distance_mask = np.ones_like(target.data) + + weights, moving_class_counts = balance_weights( + gt[target.roi], + 2, + slab=tuple(1 if c == "c" else -1 for c in gt.axes), + masks=[mask[target.roi], distance_mask], + moving_counts=moving_class_counts, + ) + return ( + NumpyArray.from_np_array( + weights, + gt.roi, + gt.voxel_size, + gt.axes, + ), + moving_class_counts, + ) + + @property + def output_array_type(self): + # technically this is a probability array + distance array, but it is only ever referenced for interpolatability (which is true for both) (TODO) + return ProbabilityArray(self.embedding_dims) + + def create_distance_mask( + self, + distances: np.ndarray, + mask: np.ndarray, + voxel_size: Coordinate, + normalize=None, + normalize_args=None, + ): + mask_output = mask.copy() + for i, (channel_distance, channel_mask) in enumerate(zip(distances, mask)): + tmp = np.zeros( + np.array(channel_mask.shape) + np.array((2,) * channel_mask.ndim), + dtype=channel_mask.dtype, + ) + slices = tmp.ndim * (slice(1, -1),) + tmp[slices] = channel_mask + boundary_distance = distance_transform_edt( + tmp, + sampling=voxel_size, + ) + if self.epsilon is None: + add = 0 + else: + add = self.epsilon + boundary_distance = self.__normalize( + boundary_distance[slices], normalize, normalize_args + ) + + channel_mask_output = mask_output[i] + logging.debug( + "Total number of masked in voxels before distance masking {0:}".format( + np.sum(channel_mask_output) + ) + ) + channel_mask_output[ + np.logical_and( + np.clip(abs(channel_distance) + add, 0, self.threshold) + >= boundary_distance, + channel_distance >= 0, + ) + ] = 0 + logging.debug( + "Total number of masked in voxels after postive distance masking {0:}".format( + np.sum(channel_mask_output) + ) + ) + channel_mask_output[ + np.logical_and( + np.clip(abs(channel_distance) + add, 0, self.threshold) + >= boundary_distance, + channel_distance <= 0, + ) + ] = 0 + logging.debug( + "Total number of masked in voxels after negative distance masking {0:}".format( + np.sum(channel_mask_output) + ) + ) + return mask_output + + def process( + self, + labels: np.ndarray, + voxel_size: Coordinate, + normalize=None, + 
normalize_args=None, + ): + all_distances = np.zeros(labels.shape, dtype=np.float32) - 1 + one_hots = np.zeros( + (self.embedding_dims // 2,) + labels.shape[1:], dtype=np.uint8 + ) + # TODO: Assumes labels has a singleton channel dim and channel dim is first + for ii, channel in enumerate(labels): + boundaries = self.__find_boundaries(channel) + + # mark boundaries with 0 (not 1) + boundaries = 1.0 - boundaries + + if np.sum(boundaries == 0) == 0: + max_distance = min( + dim * vs / 2 for dim, vs in zip(channel.shape, voxel_size) + ) + if np.sum(channel) == 0: + distances = -np.ones(channel.shape, dtype=np.float32) * max_distance + else: + distances = np.ones(channel.shape, dtype=np.float32) * max_distance + else: + # get distances (voxel_size/2 because image is doubled) + distances = distance_transform_edt( + boundaries, sampling=tuple(float(v) / 2 for v in voxel_size) + ) + distances = distances.astype(np.float32) + + # restore original shape + downsample = (slice(None, None, 2),) * len(voxel_size) + distances = distances[downsample] + + # todo: inverted distance + distances[channel == 0] = -distances[channel == 0] + + if normalize is not None: + distances = self.__normalize(distances, normalize, normalize_args) + + all_distances[ii] = distances + one_hots[ii] += labels[0] == ii + + return np.concatenate(all_distances, one_hots) + + def __find_boundaries(self, labels): + # labels: 1 1 1 1 0 0 2 2 2 2 3 3 n + # shift : 1 1 1 1 0 0 2 2 2 2 3 n - 1 + # diff : 0 0 0 1 0 1 0 0 0 1 0 n - 1 + # bound.: 00000001000100000001000 2n - 1 + + logger.debug("computing boundaries for %s", labels.shape) + + dims = len(labels.shape) + in_shape = labels.shape + out_shape = tuple(2 * s - 1 for s in in_shape) + + boundaries = np.zeros(out_shape, dtype=bool) + + logger.debug("boundaries shape is %s", boundaries.shape) + + for d in range(dims): + logger.debug("processing dimension %d", d) + + shift_p = [slice(None)] * dims + shift_p[d] = slice(1, in_shape[d]) + + shift_n = [slice(None)] * dims + shift_n[d] = slice(0, in_shape[d] - 1) + + diff = (labels[tuple(shift_p)] - labels[tuple(shift_n)]) != 0 + + logger.debug("diff shape is %s", diff.shape) + + target = [slice(None, None, 2)] * dims + target[d] = slice(1, out_shape[d], 2) + + logger.debug("target slices are %s", target) + + boundaries[tuple(target)] = diff + + return boundaries + + def __normalize(self, distances, norm, normalize_args): + if norm == "tanh": + scale = normalize_args + return np.tanh(distances / scale) + else: + raise ValueError("Only tanh is supported for normalization") + + def gt_region_for_roi(self, target_spec): + if self.mask_distances: + gt_spec = target_spec.copy() + gt_spec.roi = gt_spec.roi.grow( + Coordinate((self.max_distance,) * gt_spec.voxel_size.dims), + Coordinate((self.max_distance,) * gt_spec.voxel_size.dims), + ).snap_to_grid(gt_spec.voxel_size, mode="shrink") + else: + gt_spec = target_spec.copy() + return gt_spec + + def padding(self, gt_voxel_size: Coordinate) -> Coordinate: + return Coordinate((self.max_distance,) * gt_voxel_size.dims) From da48612801e890502dfc0c6a23163ad15481fe26 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Mon, 25 Sep 2023 12:06:40 -0400 Subject: [PATCH 25/33] hot distance task --- dacapo/experiments/tasks/hot_distance_task.py | 24 ++++++++++ .../tasks/hot_distance_task_config.py | 46 +++++++++++++++++++ dacapo/experiments/tasks/losses/__init__.py | 1 + 3 files changed, 71 insertions(+) create mode 100644 dacapo/experiments/tasks/hot_distance_task.py create mode 100644 
dacapo/experiments/tasks/hot_distance_task_config.py diff --git a/dacapo/experiments/tasks/hot_distance_task.py b/dacapo/experiments/tasks/hot_distance_task.py new file mode 100644 index 000000000..7f1e4dd96 --- /dev/null +++ b/dacapo/experiments/tasks/hot_distance_task.py @@ -0,0 +1,24 @@ +from .evaluators import BinarySegmentationEvaluator +from .losses import HotDistanceLoss +from .post_processors import ThresholdPostProcessor +from .predictors import HotDistancePredictor +from .task import Task + +class HotDistanceTask(Task): + """This is just a Hot Distance Task that combine Binary and distance prediction.""" + + def __init__(self, task_config): + """Create a `HotDistanceTask` from a `HotDistanceTaskConfig`.""" + + self.predictor = HotDistancePredictor( + channels=task_config.channels, + scale_factor=task_config.scale_factor, + mask_distances=task_config.mask_distances, + ) + self.loss = HotDistanceLoss() + self.post_processor = ThresholdPostProcessor() + self.evaluator = BinarySegmentationEvaluator( + clip_distance=task_config.clip_distance, + tol_distance=task_config.tol_distance, + channels=task_config.channels, + ) \ No newline at end of file diff --git a/dacapo/experiments/tasks/hot_distance_task_config.py b/dacapo/experiments/tasks/hot_distance_task_config.py new file mode 100644 index 000000000..aab2b01d6 --- /dev/null +++ b/dacapo/experiments/tasks/hot_distance_task_config.py @@ -0,0 +1,46 @@ +import attr + +from .hot_distance_task import HotDistanceTask +from .task_config import TaskConfig + +from typing import List + +class HotDistanceTaskConfig(TaskConfig): + """This is a Hot Distance task config used for generating and + evaluating signed distance transforms as a way of generating + segmentations. + + The advantage of generating distance transforms over regular + affinities is you can get a denser signal, i.e. 1 misclassified + pixel in an affinity prediction could merge 2 otherwise very + distinct objects, this cannot happen with distances. + """ + + task_type = HotDistanceTask + + channels: List[str] = attr.ib(metadata={"help_text": "A list of channel names."}) + clip_distance: float = attr.ib( + metadata={ + "help_text": "Maximum distance to consider for false positive/negatives." + }, + ) + tol_distance: float = attr.ib( + metadata={ + "help_text": "Tolerance distance for counting false positives/negatives" + }, + ) + scale_factor: float = attr.ib( + default=1, + metadata={ + "help_text": "The amount by which to scale distances before applying " + "a tanh normalization." + }, + ) + mask_distances: bool = attr.ib( + default=False, + metadata={ + "help_text": "Whether or not to mask out regions where the true distance to " + "object boundary cannot be known. This is anywhere that the distance to crop boundary " + "is less than the distance to object boundary." 
+ }, + ) \ No newline at end of file diff --git a/dacapo/experiments/tasks/losses/__init__.py b/dacapo/experiments/tasks/losses/__init__.py index b675faa96..f1db3586b 100644 --- a/dacapo/experiments/tasks/losses/__init__.py +++ b/dacapo/experiments/tasks/losses/__init__.py @@ -2,3 +2,4 @@ from .mse_loss import MSELoss # noqa from .loss import Loss # noqa from .affinities_loss import AffinitiesLoss # noqa +from .hot_distance_loss import HotDistanceLoss # noqa From 9567ec496418236bfa39d5b3a9ac366d69440402 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 25 Sep 2023 12:40:13 -0400 Subject: [PATCH 26/33] feat: hot_distance_predictor target/weight --- .../predictors/hot_distance_predictor.py | 30 +++++++++++++------ .../experiments/trainers/gunpowder_trainer.py | 2 +- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py index 746cc998d..fc73cb0ea 100644 --- a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py +++ b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py @@ -44,6 +44,10 @@ def __init__(self, channels: List[str], scale_factor: float, mask_distances: boo def embedding_dims(self): return len(self.channels) + @property + def classes(self): + return len(self.channels) // 2 + def create_model(self, architecture): if architecture.dims == 2: head = torch.nn.Conv2d( @@ -67,9 +71,17 @@ def create_target(self, gt): def create_weight(self, gt, target, mask, moving_class_counts=None): # balance weights independently for each channel + one_hot_weights, one_hot_moving_class_counts = balance_weights( + gt[target.roi], + 2, + slab=tuple(1 if c == "c" else -1 for c in gt.axes), + masks=[mask[target.roi]], + moving_counts=moving_class_counts[: self.classes], + ) + if self.mask_distances: distance_mask = self.create_distance_mask( - target[target.roi], + target[target.roi][-self.classes :], mask[target.roi], target.voxel_size, self.norm, @@ -78,12 +90,17 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): else: distance_mask = np.ones_like(target.data) - weights, moving_class_counts = balance_weights( + distance_weights, distance_moving_class_counts = balance_weights( gt[target.roi], 2, slab=tuple(1 if c == "c" else -1 for c in gt.axes), masks=[mask[target.roi], distance_mask], - moving_counts=moving_class_counts, + moving_counts=moving_class_counts[-self.classes :], + ) + + weights = np.concatenate((one_hot_weights, distance_weights)) + moving_class_counts = np.concatenate( + (one_hot_moving_class_counts, distance_moving_class_counts) ) return ( NumpyArray.from_np_array( @@ -168,10 +185,6 @@ def process( normalize_args=None, ): all_distances = np.zeros(labels.shape, dtype=np.float32) - 1 - one_hots = np.zeros( - (self.embedding_dims // 2,) + labels.shape[1:], dtype=np.uint8 - ) - # TODO: Assumes labels has a singleton channel dim and channel dim is first for ii, channel in enumerate(labels): boundaries = self.__find_boundaries(channel) @@ -204,9 +217,8 @@ def process( distances = self.__normalize(distances, normalize, normalize_args) all_distances[ii] = distances - one_hots[ii] += labels[0] == ii - return np.concatenate(all_distances, one_hots) + return np.concatenate((labels, all_distances)) def __find_boundaries(self, labels): # labels: 1 1 1 1 0 0 2 2 2 2 3 3 n diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index e00b05797..889ab292c 100644 --- 
a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -86,7 +86,7 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): mask_placeholder = gp.ArrayKey("MASK_PLACEHOLDER") target_key = gp.ArrayKey("TARGET") - dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT") + dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT") # TODO: put these back in datasets_weight_key = gp.ArrayKey("DATASETS_WEIGHT") weight_key = gp.ArrayKey("WEIGHT") sample_points_key = gp.GraphKey("SAMPLE_POINTS") From 37de7e0f8f68b825df3fac5d9012b6f3585fd32f Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Tue, 26 Sep 2023 12:46:32 -0400 Subject: [PATCH 27/33] init show predictor --- dacapo/experiments/tasks/predictors/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dacapo/experiments/tasks/predictors/__init__.py b/dacapo/experiments/tasks/predictors/__init__.py index 76f82138d..044bb1881 100644 --- a/dacapo/experiments/tasks/predictors/__init__.py +++ b/dacapo/experiments/tasks/predictors/__init__.py @@ -3,3 +3,4 @@ from .one_hot_predictor import OneHotPredictor # noqa from .predictor import Predictor # noqa from .affinities_predictor import AffinitiesPredictor # noqa +from .hotspot_predictor import HotspotPredictor # noqa \ No newline at end of file From ee03505b8e8db74898a77aed54f6af2df91c68cf Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Tue, 26 Sep 2023 18:09:24 -0400 Subject: [PATCH 28/33] fix hotdistance bugs --- dacapo/experiments/tasks/__init__.py | 1 + .../experiments/tasks/hot_distance_task_config.py | 1 + .../experiments/tasks/losses/hot_distance_loss.py | 14 +++++++++----- dacapo/experiments/tasks/predictors/__init__.py | 2 +- .../tasks/predictors/hot_distance_predictor.py | 4 ++-- 5 files changed, 14 insertions(+), 8 deletions(-) diff --git a/dacapo/experiments/tasks/__init__.py b/dacapo/experiments/tasks/__init__.py index 780f343d1..6ecea4863 100644 --- a/dacapo/experiments/tasks/__init__.py +++ b/dacapo/experiments/tasks/__init__.py @@ -5,3 +5,4 @@ from .one_hot_task_config import OneHotTaskConfig, OneHotTask # noqa from .pretrained_task_config import PretrainedTaskConfig, PretrainedTask # noqa from .affinities_task_config import AffinitiesTaskConfig, AffinitiesTask # noqa +from .hot_distance_task_config import HotDistanceTaskConfig, HotDistanceTask # noqa \ No newline at end of file diff --git a/dacapo/experiments/tasks/hot_distance_task_config.py b/dacapo/experiments/tasks/hot_distance_task_config.py index aab2b01d6..833c21b72 100644 --- a/dacapo/experiments/tasks/hot_distance_task_config.py +++ b/dacapo/experiments/tasks/hot_distance_task_config.py @@ -5,6 +5,7 @@ from typing import List +@attr.s class HotDistanceTaskConfig(TaskConfig): """This is a Hot Distance task config used for generating and evaluating signed distance transforms as a way of generating diff --git a/dacapo/experiments/tasks/losses/hot_distance_loss.py b/dacapo/experiments/tasks/losses/hot_distance_loss.py index 77f34fd08..cbcac982b 100644 --- a/dacapo/experiments/tasks/losses/hot_distance_loss.py +++ b/dacapo/experiments/tasks/losses/hot_distance_loss.py @@ -13,12 +13,16 @@ def compute(self, prediction, target, weight): return self.hot_loss(prediction_hot, target_hot, weight_hot) + self.distance_loss(prediction_distance, target_distance, weight_distance) def hot_loss(self, prediction, target, weight): - return torch.nn.BCELoss().forward(prediction * weight, target * weight) + loss = torch.nn.BCELoss() + return loss(prediction * weight, 
target * weight) def distance_loss(self, prediction, target, weight): - return torch.nn.MSELoss().forward(prediction * weight, target * weight) + loss = torch.nn.MSELoss() + return loss(prediction * weight, target * weight) def split(self, x): - assert x.shape[0] % 2 == 0, f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance." - mid = x.shape[0] // 2 - return x[:mid], x[-mid:] \ No newline at end of file + # Shape[0] is the batch size and Shape[1] is the number of channels. + assert x.shape[1] % 2 == 0, f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance." + mid = x.shape[1] // 2 + return torch.split(x,mid,dim=1) + # return x[:,:mid], x[:,-mid:] \ No newline at end of file diff --git a/dacapo/experiments/tasks/predictors/__init__.py b/dacapo/experiments/tasks/predictors/__init__.py index 044bb1881..73a906315 100644 --- a/dacapo/experiments/tasks/predictors/__init__.py +++ b/dacapo/experiments/tasks/predictors/__init__.py @@ -3,4 +3,4 @@ from .one_hot_predictor import OneHotPredictor # noqa from .predictor import Predictor # noqa from .affinities_predictor import AffinitiesPredictor # noqa -from .hotspot_predictor import HotspotPredictor # noqa \ No newline at end of file +from .hot_distance_predictor import HotDistancePredictor # noqa \ No newline at end of file diff --git a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py index fc73cb0ea..5fe2e8c28 100644 --- a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py +++ b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py @@ -76,7 +76,7 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): 2, slab=tuple(1 if c == "c" else -1 for c in gt.axes), masks=[mask[target.roi]], - moving_counts=moving_class_counts[: self.classes], + moving_counts=None if moving_class_counts is None else moving_class_counts[: self.classes], ) if self.mask_distances: @@ -95,7 +95,7 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): 2, slab=tuple(1 if c == "c" else -1 for c in gt.axes), masks=[mask[target.roi], distance_mask], - moving_counts=moving_class_counts[-self.classes :], + moving_counts=None if moving_class_counts is None else moving_class_counts[-self.classes :], ) weights = np.concatenate((one_hot_weights, distance_weights)) From cd4077def724efe7438bf679e721ba1b62587fc4 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 27 Sep 2023 13:07:49 -0400 Subject: [PATCH 29/33] fix bce loss --- dacapo/experiments/tasks/losses/hot_distance_loss.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dacapo/experiments/tasks/losses/hot_distance_loss.py b/dacapo/experiments/tasks/losses/hot_distance_loss.py index cbcac982b..eebd4165c 100644 --- a/dacapo/experiments/tasks/losses/hot_distance_loss.py +++ b/dacapo/experiments/tasks/losses/hot_distance_loss.py @@ -13,8 +13,9 @@ def compute(self, prediction, target, weight): return self.hot_loss(prediction_hot, target_hot, weight_hot) + self.distance_loss(prediction_distance, target_distance, weight_distance) def hot_loss(self, prediction, target, weight): - loss = torch.nn.BCELoss() - return loss(prediction * weight, target * weight) + loss = torch.nn.BCEWithLogitsLoss(reduction='none') + return torch.mean(loss(prediction , target) * weight) + # return abs(prediction * weight - target * weight).sum() def distance_loss(self, prediction, target, weight): loss = torch.nn.MSELoss() 
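
The sketch below is not part of any patch in this series; it is a minimal illustration of how the hot-distance pieces introduced in patches 23-29 (HotDistanceTaskConfig/HotDistanceTask, HotDistanceLoss, HotDistancePredictor) are meant to fit together once the series is applied. The "name" keyword is assumed to be inherited from the base TaskConfig, and the channel names, distances, and tensor shapes are made-up example values rather than values taken from the patches.

import torch
from dacapo.experiments.tasks import HotDistanceTaskConfig

# Hypothetical example values; only the keyword names below come from
# hot_distance_task_config.py as added in the patches above.
task_config = HotDistanceTaskConfig(
    name="hot_distance_example",  # assumed: inherited from the base TaskConfig
    channels=["mito", "er"],      # one hot channel + one distance channel per class
    clip_distance=40.0,           # max distance considered for false positives/negatives
    tol_distance=40.0,            # tolerance distance for FP/FN counting
    scale_factor=50.0,            # tanh scaling applied to distances
    mask_distances=True,
)
# task_type is HotDistanceTask, which wires up predictor, loss,
# post-processor, and evaluator as shown in patch 25.
task = task_config.task_type(task_config)

# HotDistanceLoss splits prediction/target/weight along the channel axis:
# the first half are one-hot maps (BCE-style loss), the second half are
# distance maps (MSE), and the two losses are summed.
batch, spatial = 1, (8, 8, 8)
num_out_channels = 2 * len(task_config.channels)  # hot + distance per class
prediction = torch.randn(batch, num_out_channels, *spatial)
target = torch.rand(batch, num_out_channels, *spatial)
weight = torch.ones_like(target)
print(task.loss.compute(prediction, target, weight))
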
From 448f766e3bc8ccd05b1271098fd0dad3efa7aa24 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 9 Feb 2024 13:47:46 -0500 Subject: [PATCH 30/33] =?UTF-8?q?feat:=20=E2=9A=A1=EF=B8=8F=20Incorporate?= =?UTF-8?q?=20hot=5Fdistance=20related=20changes=20from=20rhoadesj/dev?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dacapo/experiments/tasks/hot_distance_task.py | 25 ++ .../tasks/hot_distance_task_config.py | 47 +++ dacapo/experiments/tasks/losses/__init__.py | 1 + .../tasks/losses/hot_distance_loss.py | 29 ++ .../predictors/hot_distance_predictor.py | 280 ++++++++++++++++++ 5 files changed, 382 insertions(+) create mode 100644 dacapo/experiments/tasks/hot_distance_task.py create mode 100644 dacapo/experiments/tasks/hot_distance_task_config.py create mode 100644 dacapo/experiments/tasks/losses/hot_distance_loss.py create mode 100644 dacapo/experiments/tasks/predictors/hot_distance_predictor.py diff --git a/dacapo/experiments/tasks/hot_distance_task.py b/dacapo/experiments/tasks/hot_distance_task.py new file mode 100644 index 000000000..ef0d03229 --- /dev/null +++ b/dacapo/experiments/tasks/hot_distance_task.py @@ -0,0 +1,25 @@ +from .evaluators import BinarySegmentationEvaluator +from .losses import HotDistanceLoss +from .post_processors import ThresholdPostProcessor +from .predictors import HotDistancePredictor +from .task import Task + + +class HotDistanceTask(Task): + """This is just a Hot Distance Task that combine Binary and distance prediction.""" + + def __init__(self, task_config): + """Create a `HotDistanceTask` from a `HotDistanceTaskConfig`.""" + + self.predictor = HotDistancePredictor( + channels=task_config.channels, + scale_factor=task_config.scale_factor, + mask_distances=task_config.mask_distances, + ) + self.loss = HotDistanceLoss() + self.post_processor = ThresholdPostProcessor() + self.evaluator = BinarySegmentationEvaluator( + clip_distance=task_config.clip_distance, + tol_distance=task_config.tol_distance, + channels=task_config.channels, + ) diff --git a/dacapo/experiments/tasks/hot_distance_task_config.py b/dacapo/experiments/tasks/hot_distance_task_config.py new file mode 100644 index 000000000..951226476 --- /dev/null +++ b/dacapo/experiments/tasks/hot_distance_task_config.py @@ -0,0 +1,47 @@ +import attr + +from .hot_distance_task import HotDistanceTask +from .task_config import TaskConfig + +from typing import List + + +class HotDistanceTaskConfig(TaskConfig): + """This is a Hot Distance task config used for generating and + evaluating signed distance transforms as a way of generating + segmentations. + + The advantage of generating distance transforms over regular + affinities is you can get a denser signal, i.e. 1 misclassified + pixel in an affinity prediction could merge 2 otherwise very + distinct objects, this cannot happen with distances. + """ + + task_type = HotDistanceTask + + channels: List[str] = attr.ib(metadata={"help_text": "A list of channel names."}) + clip_distance: float = attr.ib( + metadata={ + "help_text": "Maximum distance to consider for false positive/negatives." + }, + ) + tol_distance: float = attr.ib( + metadata={ + "help_text": "Tolerance distance for counting false positives/negatives" + }, + ) + scale_factor: float = attr.ib( + default=1, + metadata={ + "help_text": "The amount by which to scale distances before applying " + "a tanh normalization." 
+ }, + ) + mask_distances: bool = attr.ib( + default=False, + metadata={ + "help_text": "Whether or not to mask out regions where the true distance to " + "object boundary cannot be known. This is anywhere that the distance to crop boundary " + "is less than the distance to object boundary." + }, + ) diff --git a/dacapo/experiments/tasks/losses/__init__.py b/dacapo/experiments/tasks/losses/__init__.py index b675faa96..f1db3586b 100644 --- a/dacapo/experiments/tasks/losses/__init__.py +++ b/dacapo/experiments/tasks/losses/__init__.py @@ -2,3 +2,4 @@ from .mse_loss import MSELoss # noqa from .loss import Loss # noqa from .affinities_loss import AffinitiesLoss # noqa +from .hot_distance_loss import HotDistanceLoss # noqa diff --git a/dacapo/experiments/tasks/losses/hot_distance_loss.py b/dacapo/experiments/tasks/losses/hot_distance_loss.py new file mode 100644 index 000000000..2e99ab5e1 --- /dev/null +++ b/dacapo/experiments/tasks/losses/hot_distance_loss.py @@ -0,0 +1,29 @@ +from .loss import Loss +import torch + + +# HotDistance is used for predicting hot and distance maps at the same time. +# The first half of the channels are the hot maps, the second half are the distance maps. +# The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance maps. +# Model should predict twice the number of channels as the target. +class HotDistanceLoss(Loss): + def compute(self, prediction, target, weight): + target_hot, target_distance = self.split(target) + prediction_hot, prediction_distance = self.split(prediction) + weight_hot, weight_distance = self.split(weight) + return self.hot_loss( + prediction_hot, target_hot, weight_hot + ) + self.distance_loss(prediction_distance, target_distance, weight_distance) + + def hot_loss(self, prediction, target, weight): + return torch.nn.BCELoss().forward(prediction * weight, target * weight) + + def distance_loss(self, prediction, target, weight): + return torch.nn.MSELoss().forward(prediction * weight, target * weight) + + def split(self, x): + assert ( + x.shape[0] % 2 == 0 + ), f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance." + mid = x.shape[0] // 2 + return x[:mid], x[-mid:] diff --git a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py new file mode 100644 index 000000000..fc73cb0ea --- /dev/null +++ b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py @@ -0,0 +1,280 @@ +from dacapo.experiments.arraytypes.probabilities import ProbabilityArray +from .predictor import Predictor +from dacapo.experiments import Model +from dacapo.experiments.arraytypes import DistanceArray +from dacapo.experiments.datasplits.datasets.arrays import NumpyArray +from dacapo.utils.balance_weights import balance_weights + +from funlib.geometry import Coordinate + +from scipy.ndimage.morphology import distance_transform_edt +import numpy as np +import torch + +import logging +from typing import List + +logger = logging.getLogger(__name__) + + +class HotDistancePredictor(Predictor): + """ + Predict signed distances and one hot embedding (as a proxy task) for a binary segmentation task. + Distances deep within background are pushed to -inf, distances deep within + the foreground object are pushed to inf. After distances have been + calculated they are passed through a tanh so that distances saturate at +-1. + Multiple classes can be predicted via multiple distance channels. 
The names + of each class that is being segmented can be passed in as a list of strings + in the channels argument. + """ + + def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool): + self.channels = ( + channels * 2 + ) # one hot + distance (TODO: add hot/distance to channel names) + self.norm = "tanh" + self.dt_scale_factor = scale_factor + self.mask_distances = mask_distances + + self.max_distance = 1 * scale_factor + self.epsilon = 5e-2 # TODO: should be a config parameter + self.threshold = 0.8 # TODO: should be a config parameter + + @property + def embedding_dims(self): + return len(self.channels) + + @property + def classes(self): + return len(self.channels) // 2 + + def create_model(self, architecture): + if architecture.dims == 2: + head = torch.nn.Conv2d( + architecture.num_out_channels, self.embedding_dims, kernel_size=3 + ) + elif architecture.dims == 3: + head = torch.nn.Conv3d( + architecture.num_out_channels, self.embedding_dims, kernel_size=3 + ) + + return Model(architecture, head) + + def create_target(self, gt): + target = self.process(gt.data, gt.voxel_size, self.norm, self.dt_scale_factor) + return NumpyArray.from_np_array( + target, + gt.roi, + gt.voxel_size, + gt.axes, + ) + + def create_weight(self, gt, target, mask, moving_class_counts=None): + # balance weights independently for each channel + one_hot_weights, one_hot_moving_class_counts = balance_weights( + gt[target.roi], + 2, + slab=tuple(1 if c == "c" else -1 for c in gt.axes), + masks=[mask[target.roi]], + moving_counts=moving_class_counts[: self.classes], + ) + + if self.mask_distances: + distance_mask = self.create_distance_mask( + target[target.roi][-self.classes :], + mask[target.roi], + target.voxel_size, + self.norm, + self.dt_scale_factor, + ) + else: + distance_mask = np.ones_like(target.data) + + distance_weights, distance_moving_class_counts = balance_weights( + gt[target.roi], + 2, + slab=tuple(1 if c == "c" else -1 for c in gt.axes), + masks=[mask[target.roi], distance_mask], + moving_counts=moving_class_counts[-self.classes :], + ) + + weights = np.concatenate((one_hot_weights, distance_weights)) + moving_class_counts = np.concatenate( + (one_hot_moving_class_counts, distance_moving_class_counts) + ) + return ( + NumpyArray.from_np_array( + weights, + gt.roi, + gt.voxel_size, + gt.axes, + ), + moving_class_counts, + ) + + @property + def output_array_type(self): + # technically this is a probability array + distance array, but it is only ever referenced for interpolatability (which is true for both) (TODO) + return ProbabilityArray(self.embedding_dims) + + def create_distance_mask( + self, + distances: np.ndarray, + mask: np.ndarray, + voxel_size: Coordinate, + normalize=None, + normalize_args=None, + ): + mask_output = mask.copy() + for i, (channel_distance, channel_mask) in enumerate(zip(distances, mask)): + tmp = np.zeros( + np.array(channel_mask.shape) + np.array((2,) * channel_mask.ndim), + dtype=channel_mask.dtype, + ) + slices = tmp.ndim * (slice(1, -1),) + tmp[slices] = channel_mask + boundary_distance = distance_transform_edt( + tmp, + sampling=voxel_size, + ) + if self.epsilon is None: + add = 0 + else: + add = self.epsilon + boundary_distance = self.__normalize( + boundary_distance[slices], normalize, normalize_args + ) + + channel_mask_output = mask_output[i] + logging.debug( + "Total number of masked in voxels before distance masking {0:}".format( + np.sum(channel_mask_output) + ) + ) + channel_mask_output[ + np.logical_and( + 
np.clip(abs(channel_distance) + add, 0, self.threshold) + >= boundary_distance, + channel_distance >= 0, + ) + ] = 0 + logging.debug( + "Total number of masked in voxels after postive distance masking {0:}".format( + np.sum(channel_mask_output) + ) + ) + channel_mask_output[ + np.logical_and( + np.clip(abs(channel_distance) + add, 0, self.threshold) + >= boundary_distance, + channel_distance <= 0, + ) + ] = 0 + logging.debug( + "Total number of masked in voxels after negative distance masking {0:}".format( + np.sum(channel_mask_output) + ) + ) + return mask_output + + def process( + self, + labels: np.ndarray, + voxel_size: Coordinate, + normalize=None, + normalize_args=None, + ): + all_distances = np.zeros(labels.shape, dtype=np.float32) - 1 + for ii, channel in enumerate(labels): + boundaries = self.__find_boundaries(channel) + + # mark boundaries with 0 (not 1) + boundaries = 1.0 - boundaries + + if np.sum(boundaries == 0) == 0: + max_distance = min( + dim * vs / 2 for dim, vs in zip(channel.shape, voxel_size) + ) + if np.sum(channel) == 0: + distances = -np.ones(channel.shape, dtype=np.float32) * max_distance + else: + distances = np.ones(channel.shape, dtype=np.float32) * max_distance + else: + # get distances (voxel_size/2 because image is doubled) + distances = distance_transform_edt( + boundaries, sampling=tuple(float(v) / 2 for v in voxel_size) + ) + distances = distances.astype(np.float32) + + # restore original shape + downsample = (slice(None, None, 2),) * len(voxel_size) + distances = distances[downsample] + + # todo: inverted distance + distances[channel == 0] = -distances[channel == 0] + + if normalize is not None: + distances = self.__normalize(distances, normalize, normalize_args) + + all_distances[ii] = distances + + return np.concatenate((labels, all_distances)) + + def __find_boundaries(self, labels): + # labels: 1 1 1 1 0 0 2 2 2 2 3 3 n + # shift : 1 1 1 1 0 0 2 2 2 2 3 n - 1 + # diff : 0 0 0 1 0 1 0 0 0 1 0 n - 1 + # bound.: 00000001000100000001000 2n - 1 + + logger.debug("computing boundaries for %s", labels.shape) + + dims = len(labels.shape) + in_shape = labels.shape + out_shape = tuple(2 * s - 1 for s in in_shape) + + boundaries = np.zeros(out_shape, dtype=bool) + + logger.debug("boundaries shape is %s", boundaries.shape) + + for d in range(dims): + logger.debug("processing dimension %d", d) + + shift_p = [slice(None)] * dims + shift_p[d] = slice(1, in_shape[d]) + + shift_n = [slice(None)] * dims + shift_n[d] = slice(0, in_shape[d] - 1) + + diff = (labels[tuple(shift_p)] - labels[tuple(shift_n)]) != 0 + + logger.debug("diff shape is %s", diff.shape) + + target = [slice(None, None, 2)] * dims + target[d] = slice(1, out_shape[d], 2) + + logger.debug("target slices are %s", target) + + boundaries[tuple(target)] = diff + + return boundaries + + def __normalize(self, distances, norm, normalize_args): + if norm == "tanh": + scale = normalize_args + return np.tanh(distances / scale) + else: + raise ValueError("Only tanh is supported for normalization") + + def gt_region_for_roi(self, target_spec): + if self.mask_distances: + gt_spec = target_spec.copy() + gt_spec.roi = gt_spec.roi.grow( + Coordinate((self.max_distance,) * gt_spec.voxel_size.dims), + Coordinate((self.max_distance,) * gt_spec.voxel_size.dims), + ).snap_to_grid(gt_spec.voxel_size, mode="shrink") + else: + gt_spec = target_spec.copy() + return gt_spec + + def padding(self, gt_voxel_size: Coordinate) -> Coordinate: + return Coordinate((self.max_distance,) * gt_voxel_size.dims) From 
c810a0e757ad7174366c4aa463e2dffc806535b2 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Sun, 11 Feb 2024 14:13:31 -0500 Subject: [PATCH 31/33] Revert GunpowderTrainer class and configuration to main --- .../experiments/trainers/gunpowder_trainer.py | 52 +++++-------------- .../trainers/gunpowder_trainer_config.py | 2 +- 2 files changed, 14 insertions(+), 40 deletions(-) diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 72427951e..f5d8fcd52 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -37,8 +37,6 @@ def __init__(self, trainer_config): self.print_profiling = 100 self.snapshot_iteration = trainer_config.snapshot_interval self.min_masked = trainer_config.min_masked - self.reject_probability = trainer_config.reject_probability - self.weighted_reject = trainer_config.weighted_reject self.augments = trainer_config.augments self.mask_integral_downsample_factor = 4 @@ -53,14 +51,12 @@ def __init__(self, trainer_config): def create_optimizer(self, model): optimizer = torch.optim.RAdam(lr=self.learning_rate, params=model.parameters()) - self.scheduler = ( - torch.optim.lr_scheduler.LinearLR( # TODO: add scheduler to config - optimizer, - start_factor=0.01, - end_factor=1.0, - total_iters=1000, - last_epoch=-1, - ) + self.scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.01, + end_factor=1.0, + total_iters=1000, + last_epoch=-1, ) return optimizer @@ -69,9 +65,7 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): output_shape = Coordinate(model.output_shape) # get voxel sizes - raw_voxel_size = datasets[ - 0 - ].raw.voxel_size # TODO: make dataset specific / resample + raw_voxel_size = datasets[0].raw.voxel_size prediction_voxel_size = model.scale(raw_voxel_size) # define input and output size: @@ -91,7 +85,7 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): mask_placeholder = gp.ArrayKey("MASK_PLACEHOLDER") target_key = gp.ArrayKey("TARGET") - dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT") # TODO: put these back in + dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT") datasets_weight_key = gp.ArrayKey("DATASETS_WEIGHT") weight_key = gp.ArrayKey("WEIGHT") sample_points_key = gp.GraphKey("SAMPLE_POINTS") @@ -99,7 +93,7 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): # Get source nodes dataset_sources = [] weights = [] - for dataset in datasets: # TODO: add automatic resampling? 
+ for dataset in datasets: weights.append(dataset.weight) assert isinstance(dataset.weight, int), dataset @@ -152,30 +146,10 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): ) ) - if self.weighted_reject: - # Add predictor nodes to dataset_source - for augment in self.augments: - dataset_source += augment.node(raw_key, gt_key, mask_key) + dataset_source += gp.Reject(mask_placeholder, 1e-6) - dataset_source += DaCapoTargetFilter( - task.predictor, - gt_key=gt_key, - weights_key=weight_key, - target_key=target_key, - mask_key=mask_key, - ) - - dataset_source += gp.Reject( - mask=weight_key, - min_masked=self.min_masked, - reject_probability=self.reject_probability, - ) - else: - dataset_source += gp.Reject( - mask=mask_placeholder, - min_masked=self.min_masked, - reject_probability=self.reject_probability, - ) + for augment in self.augments: + dataset_source += augment.node(raw_key, gt_key, mask_key) if self.add_predictor_nodes_to_dataset: # Add predictor nodes to dataset_source @@ -290,7 +264,7 @@ def iterate(self, num_iterations, model, optimizer, device): } if mask is not None: snapshot_arrays["volumes/mask"] = mask - logger.info( + logger.warning( f"Saving Snapshot. Iteration: {iteration}, " f"Loss: {loss.detach().cpu().numpy().item()}!" ) diff --git a/dacapo/experiments/trainers/gunpowder_trainer_config.py b/dacapo/experiments/trainers/gunpowder_trainer_config.py index 255c73ad6..539e3c5e1 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer_config.py +++ b/dacapo/experiments/trainers/gunpowder_trainer_config.py @@ -28,7 +28,7 @@ class GunpowderTrainerConfig(TrainerConfig): metadata={"help_text": "Number of iterations before saving a new snapshot."}, ) min_masked: Optional[float] = attr.ib(default=0.15) - clip_raw: bool = attr.ib(default=False) + clip_raw: bool = attr.ib(default=True) add_predictor_nodes_to_dataset: Optional[bool] = attr.ib( default=True, From 5f99dd40cc2433f6f7ba1db4a3b16baf1fc86fd7 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Sun, 11 Feb 2024 19:13:56 +0000 Subject: [PATCH 32/33] :art: Format Python code with psf/black --- dacapo/experiments/tasks/__init__.py | 2 +- .../tasks/hot_distance_task_config.py | 1 + .../tasks/losses/hot_distance_loss.py | 16 +++++++++------- dacapo/experiments/tasks/predictors/__init__.py | 2 +- .../tasks/predictors/hot_distance_predictor.py | 8 ++++++-- dacapo/train.py | 2 +- 6 files changed, 19 insertions(+), 12 deletions(-) diff --git a/dacapo/experiments/tasks/__init__.py b/dacapo/experiments/tasks/__init__.py index 6ecea4863..1327962e3 100644 --- a/dacapo/experiments/tasks/__init__.py +++ b/dacapo/experiments/tasks/__init__.py @@ -5,4 +5,4 @@ from .one_hot_task_config import OneHotTaskConfig, OneHotTask # noqa from .pretrained_task_config import PretrainedTaskConfig, PretrainedTask # noqa from .affinities_task_config import AffinitiesTaskConfig, AffinitiesTask # noqa -from .hot_distance_task_config import HotDistanceTaskConfig, HotDistanceTask # noqa \ No newline at end of file +from .hot_distance_task_config import HotDistanceTaskConfig, HotDistanceTask # noqa diff --git a/dacapo/experiments/tasks/hot_distance_task_config.py b/dacapo/experiments/tasks/hot_distance_task_config.py index b7f276892..559d283de 100644 --- a/dacapo/experiments/tasks/hot_distance_task_config.py +++ b/dacapo/experiments/tasks/hot_distance_task_config.py @@ -5,6 +5,7 @@ from typing import List + @attr.s class HotDistanceTaskConfig(TaskConfig): """This is a Hot Distance task config used for generating and diff --git 
a/dacapo/experiments/tasks/losses/hot_distance_loss.py b/dacapo/experiments/tasks/losses/hot_distance_loss.py index 0394df8e7..784176bd0 100644 --- a/dacapo/experiments/tasks/losses/hot_distance_loss.py +++ b/dacapo/experiments/tasks/losses/hot_distance_loss.py @@ -14,17 +14,19 @@ def compute(self, prediction, target, weight): return self.hot_loss( prediction_hot, target_hot, weight_hot ) + self.distance_loss(prediction_distance, target_distance, weight_distance) - + def hot_loss(self, prediction, target, weight): - loss = torch.nn.BCEWithLogitsLoss(reduction='none') - return torch.mean(loss(prediction , target) * weight) - + loss = torch.nn.BCEWithLogitsLoss(reduction="none") + return torch.mean(loss(prediction, target) * weight) + def distance_loss(self, prediction, target, weight): loss = torch.nn.MSELoss() return loss(prediction * weight, target * weight) - + def split(self, x): # Shape[0] is the batch size and Shape[1] is the number of channels. - assert x.shape[1] % 2 == 0, f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance." + assert ( + x.shape[1] % 2 == 0 + ), f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance." mid = x.shape[1] // 2 - return torch.split(x,mid,dim=1) + return torch.split(x, mid, dim=1) diff --git a/dacapo/experiments/tasks/predictors/__init__.py b/dacapo/experiments/tasks/predictors/__init__.py index 73a906315..3fe61ae03 100644 --- a/dacapo/experiments/tasks/predictors/__init__.py +++ b/dacapo/experiments/tasks/predictors/__init__.py @@ -3,4 +3,4 @@ from .one_hot_predictor import OneHotPredictor # noqa from .predictor import Predictor # noqa from .affinities_predictor import AffinitiesPredictor # noqa -from .hot_distance_predictor import HotDistancePredictor # noqa \ No newline at end of file +from .hot_distance_predictor import HotDistancePredictor # noqa diff --git a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py index 5fe2e8c28..96a100c92 100644 --- a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py +++ b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py @@ -76,7 +76,9 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): 2, slab=tuple(1 if c == "c" else -1 for c in gt.axes), masks=[mask[target.roi]], - moving_counts=None if moving_class_counts is None else moving_class_counts[: self.classes], + moving_counts=None + if moving_class_counts is None + else moving_class_counts[: self.classes], ) if self.mask_distances: @@ -95,7 +97,9 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): 2, slab=tuple(1 if c == "c" else -1 for c in gt.axes), masks=[mask[target.roi], distance_mask], - moving_counts=None if moving_class_counts is None else moving_class_counts[-self.classes :], + moving_counts=None + if moving_class_counts is None + else moving_class_counts[-self.classes :], ) weights = np.concatenate((one_hot_weights, distance_weights)) diff --git a/dacapo/train.py b/dacapo/train.py index 6f0e1cd29..cbd025a98 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -98,7 +98,7 @@ def train_run( weights_store.retrieve_weights(run, iteration=trained_until) - elif latest_weights_iteration > trained_until: + elif latest_weights_iteration > trained_until: logger.warn( f"Found weights for iteration {latest_weights_iteration}, but " f"run {run.name} was only trained until {trained_until}. 
" From 06f2dc3b52abb80548d398b8f2491ead2e71244b Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 14 Feb 2024 13:18:57 -0500 Subject: [PATCH 33/33] remove irrelevant stuff to hotdistance --- .../datasplits/datasets/arrays/zarr_array.py | 10 +- .../experiments/tasks/evaluators/evaluator.py | 46 --- .../evaluators/instance_evaluation_scores.py | 5 +- .../tasks/evaluators/instance_evaluator.py | 14 +- .../post_processors/argmax_post_processor.py | 3 +- .../post_processors/dummy_post_processor.py | 2 +- .../tasks/post_processors/post_processor.py | 2 - .../threshold_post_processor.py | 2 - .../trainers/gp_augments/__init__.py | 1 - .../gp_augments/gaussian_noise_config.py | 22 -- .../trainers/gp_augments/simple_config.py | 48 +-- dacapo/experiments/validation_scores.py | 7 +- dacapo/gp/dacapo_create_target.py | 4 +- dacapo/gp/elastic_augment_fuse.py | 2 +- dacapo/predict.py | 1 - dacapo/store/local_array_store.py | 2 +- dacapo/store/local_weights_store.py | 8 +- dacapo/train.py | 14 +- dacapo/validate.py | 348 ++++++++---------- 19 files changed, 168 insertions(+), 373 deletions(-) delete mode 100644 dacapo/experiments/trainers/gp_augments/gaussian_noise_config.py diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 081f08e9b..25f2c224e 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -49,7 +49,7 @@ def axes(self): try: return self._attributes["axes"] except KeyError: - logger.info( + logger.debug( "DaCapo expects Zarr datasets to have an 'axes' attribute!\n" f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n" f"Using default {['c', 'z', 'y', 'x'][-self.dims::]}", @@ -58,7 +58,7 @@ def axes(self): @property def dims(self) -> int: - return len(self.data.shape) + return self.voxel_size.dims @lazy_property.LazyProperty def _daisy_array(self) -> funlib.persistence.Array: @@ -81,7 +81,7 @@ def writable(self) -> bool: @property def dtype(self) -> Any: - return self.data.dtype # TODO: why not use self._daisy_array.dtype? + return self.data.dtype @property def num_channels(self) -> Optional[int]: @@ -92,7 +92,7 @@ def spatial_axes(self) -> List[str]: return [ax for ax in self.axes if ax not in set(["c", "b"])] @property - def data(self) -> Any: # TODO: why not use self._daisy_array.data? + def data(self) -> Any: zarr_container = zarr.open(str(self.file_name)) return zarr_container[self.dataset] @@ -116,7 +116,6 @@ def create_from_array_identifier( dtype, write_size=None, name=None, - overwrite=False, ): """ Create a new ZarrArray given an array identifier. 
It is assumed that @@ -146,7 +145,6 @@ def create_from_array_identifier( dtype, num_channels=num_channels, write_size=write_size, - delete=overwrite, ) zarr_dataset = zarr_container[array_identifier.dataset] zarr_dataset.attrs["offset"] = ( diff --git a/dacapo/experiments/tasks/evaluators/evaluator.py b/dacapo/experiments/tasks/evaluators/evaluator.py index 24096261a..9d5cbbda0 100644 --- a/dacapo/experiments/tasks/evaluators/evaluator.py +++ b/dacapo/experiments/tasks/evaluators/evaluator.py @@ -63,52 +63,6 @@ def is_best( else: return getattr(score, criterion) < previous_best_score - def get_overall_best(self, dataset: "Dataset", criterion: str): - overall_best = None - if self.best_scores: - for _, parameter, _ in self.best_scores.keys(): - score = self.best_scores[(dataset, parameter, criterion)] - if score is None: - overall_best = None - else: - _, current_parameter_score = score - if overall_best is None: - overall_best = current_parameter_score - else: - if current_parameter_score: - if self.higher_is_better(criterion): - if current_parameter_score > overall_best: - overall_best = current_parameter_score - else: - if current_parameter_score < overall_best: - overall_best = current_parameter_score - return overall_best - - def get_overall_best_parameters(self, dataset: "Dataset", criterion: str): - overall_best = None - overall_best_parameters = None - if self.best_scores: - for _, parameter, _ in self.best_scores.keys(): - score = self.best_scores[(dataset, parameter, criterion)] - if score is None: - overall_best = None - else: - _, current_parameter_score = score - if overall_best is None: - overall_best = current_parameter_score - overall_best_parameters = parameter - else: - if current_parameter_score: - if self.higher_is_better(criterion): - if current_parameter_score > overall_best: - overall_best = current_parameter_score - overall_best_parameters = parameter - else: - if current_parameter_score < overall_best: - overall_best = current_parameter_score - overall_best_parameters = parameter - return overall_best_parameters - def set_best(self, validation_scores: "ValidationScores") -> None: """ Find the best iteration for each dataset/post_processing_parameter/criterion diff --git a/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py b/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py index 16eac194c..7de54d99c 100644 --- a/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py +++ b/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py @@ -6,11 +6,10 @@ @attr.s class InstanceEvaluationScores(EvaluationScores): - criteria = ["voi_split", "voi_merge", "voi", "avg_iou"] + criteria = ["voi_split", "voi_merge", "voi"] voi_split: float = attr.ib(default=float("nan")) voi_merge: float = attr.ib(default=float("nan")) - avg_iou: float = attr.ib(default=float("nan")) @property def voi(self): @@ -22,7 +21,6 @@ def higher_is_better(criterion: str) -> bool: "voi_split": False, "voi_merge": False, "voi": False, - "avg_iou": True, } return mapping[criterion] @@ -32,7 +30,6 @@ def bounds(criterion: str) -> Tuple[float, float]: "voi_split": (0, 1), "voi_merge": (0, 1), "voi": (0, 1), - "avg_iou": (0, None), } return mapping[criterion] diff --git a/dacapo/experiments/tasks/evaluators/instance_evaluator.py b/dacapo/experiments/tasks/evaluators/instance_evaluator.py index 5650d44b4..0f3427a40 100644 --- a/dacapo/experiments/tasks/evaluators/instance_evaluator.py +++ b/dacapo/experiments/tasks/evaluators/instance_evaluator.py @@ -3,7 +3,7 @@ 
from .evaluator import Evaluator from .instance_evaluation_scores import InstanceEvaluationScores -from funlib.evaluate import rand_voi, detection_scores +from funlib.evaluate import rand_voi import numpy as np @@ -16,19 +16,9 @@ def evaluate(self, output_array_identifier, evaluation_array): evaluation_data = evaluation_array[evaluation_array.roi].astype(np.uint64) output_data = output_array[output_array.roi].astype(np.uint64) results = rand_voi(evaluation_data, output_data) - results.update( - detection_scores( - evaluation_data, - output_data, - matching_score="iou", - voxel_size=output_array.voxel_size, - ) - ) return InstanceEvaluationScores( - voi_merge=results["voi_merge"], - voi_split=results["voi_split"], - avg_iou=results["avg_iou"], + voi_merge=results["voi_merge"], voi_split=results["voi_split"] ) @property diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py index 799f2651e..709d1de34 100644 --- a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py @@ -19,7 +19,7 @@ def set_prediction(self, prediction_array_identifier): prediction_array_identifier ) - def process(self, parameters, output_array_identifier, overwrite: bool = False): + def process(self, parameters, output_array_identifier): output_array = ZarrArray.create_from_array_identifier( output_array_identifier, [dim for dim in self.prediction_array.axes if dim != "c"], @@ -27,7 +27,6 @@ def process(self, parameters, output_array_identifier, overwrite: bool = False): None, self.prediction_array.voxel_size, np.uint8, - overwrite=overwrite, ) output_array[self.prediction_array.roi] = np.argmax( diff --git a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py index ddb249539..5a2c7810a 100644 --- a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py @@ -21,7 +21,7 @@ def enumerate_parameters(self) -> Iterable[DummyPostProcessorParameters]: def set_prediction(self, prediction_array): pass - def process(self, parameters, output_array_identifier, overwrite: bool = False): + def process(self, parameters, output_array_identifier): # store some dummy data f = zarr.open(str(output_array_identifier.container), "a") f[output_array_identifier.dataset] = np.ones((10, 10, 10)) * parameters.min_size diff --git a/dacapo/experiments/tasks/post_processors/post_processor.py b/dacapo/experiments/tasks/post_processors/post_processor.py index 4e4102d6b..020361cb9 100644 --- a/dacapo/experiments/tasks/post_processors/post_processor.py +++ b/dacapo/experiments/tasks/post_processors/post_processor.py @@ -33,8 +33,6 @@ def process( self, parameters: "PostProcessorParameters", output_array_identifier: "LocalArrayIdentifier", - overwrite: "bool", - blockwise: "bool", ) -> "Array": """Convert predictions into the final output.""" pass diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index 32bf4cfc0..67ffdd066 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -28,7 +28,6 @@ def process( self, parameters: "PostProcessorParameters", output_array_identifier: "LocalArrayIdentifier", - overwrite: bool = False, ) 
-> ZarrArray: # TODO: Investigate Liskov substitution princple and whether it is a problem here # OOP theory states the super class should always be replaceable with its subclasses @@ -48,7 +47,6 @@ def process( self.prediction_array.num_channels, self.prediction_array.voxel_size, np.uint8, - overwrite=overwrite, ) output_array[self.prediction_array.roi] = ( diff --git a/dacapo/experiments/trainers/gp_augments/__init__.py b/dacapo/experiments/trainers/gp_augments/__init__.py index 7e901c76f..0c93d4603 100644 --- a/dacapo/experiments/trainers/gp_augments/__init__.py +++ b/dacapo/experiments/trainers/gp_augments/__init__.py @@ -4,4 +4,3 @@ from .gamma_config import GammaAugmentConfig from .intensity_config import IntensityAugmentConfig from .intensity_scale_shift_config import IntensityScaleShiftAugmentConfig -from .gaussian_noise_config import GaussianNoiseAugmentConfig diff --git a/dacapo/experiments/trainers/gp_augments/gaussian_noise_config.py b/dacapo/experiments/trainers/gp_augments/gaussian_noise_config.py deleted file mode 100644 index ab3694d09..000000000 --- a/dacapo/experiments/trainers/gp_augments/gaussian_noise_config.py +++ /dev/null @@ -1,22 +0,0 @@ -from .augment_config import AugmentConfig - -import gunpowder as gp - -import attr - - -@attr.s -class GaussianNoiseAugmentConfig(AugmentConfig): - mean: float = attr.ib( - metadata={"help_text": "The mean of the gaussian noise to apply to your data."}, - default=0.0, - ) - var: float = attr.ib( - metadata={"help_text": "The variance of the gaussian noise."}, - default=0.05, - ) - - def node(self, raw_key: gp.ArrayKey, _gt_key=None, _mask_key=None): - return gp.NoiseAugment( - array=raw_key, mode="gaussian", mean=self.mean, var=self.var - ) diff --git a/dacapo/experiments/trainers/gp_augments/simple_config.py b/dacapo/experiments/trainers/gp_augments/simple_config.py index c5c1d9456..86de2161c 100644 --- a/dacapo/experiments/trainers/gp_augments/simple_config.py +++ b/dacapo/experiments/trainers/gp_augments/simple_config.py @@ -1,4 +1,3 @@ -from typing import List, Optional from .augment_config import AugmentConfig import gunpowder as gp @@ -8,48 +7,5 @@ @attr.s class SimpleAugmentConfig(AugmentConfig): - mirror_only: Optional[List[int]] = attr.ib( - default=None, - metadata={ - "help_text": ( - "If set, only mirror between the given axes. This is useful to exclude channels that have a set direction, like time." - ) - }, - ) - transpose_only: Optional[List[int]] = attr.ib( - default=None, - metadata={ - "help_text": ( - "If set, only transpose between the given axes. This is useful to exclude channels that have a set direction, like time." - ) - }, - ) - mirror_probs: Optional[List[float]] = attr.ib( - default=None, - metadata={ - "help_text": ( - "Probability of mirroring along each axis. Defaults to 0.5 for each axis." - ) - }, - ) - transpose_probs: Optional[List[float]] = attr.ib( - default=None, - metadata={ - "help_text": ( - "Probability of transposing along each axis. Defaults to 0.5 for each axis." 
- ) - }, - ) - - def node( - self, - _raw_key=None, - _gt_key=None, - _mask_key=None, - ): - return gp.SimpleAugment( - self.mirror_only, - self.transpose_only, - self.mirror_probs, - self.transpose_probs, - ) + def node(self, _raw_key=None, _gt_key=None, _mask_key=None): + return gp.SimpleAugment() diff --git a/dacapo/experiments/validation_scores.py b/dacapo/experiments/validation_scores.py index 17727cc22..8fba05687 100644 --- a/dacapo/experiments/validation_scores.py +++ b/dacapo/experiments/validation_scores.py @@ -113,7 +113,7 @@ def get_best( best value in two seperate arrays. """ if "criteria" in data.coords.keys(): - if len(data.coords["criteria"].shape) > 1: + if len(data.coords["criteria"].shape) == 1: criteria_bests: List[Tuple[xr.DataArray, xr.DataArray]] = [] for criterion in data.coords["criteria"].values: if self.evaluation_scores.higher_is_better(criterion.item()): @@ -142,10 +142,7 @@ def get_best( return (da_best_indexes, da_best_scores) else: if self.evaluation_scores.higher_is_better( - list(data.coords["criteria"].values)[ - 0 - ] # TODO: what is the intended behavior here? (hot fix in place) - # data.coords["criteria"].item() + data.coords["criteria"].item() ): return ( data.idxmax(dim, skipna=True, fill_value=None), diff --git a/dacapo/gp/dacapo_create_target.py b/dacapo/gp/dacapo_create_target.py index 42358b7b0..f136c5c7b 100644 --- a/dacapo/gp/dacapo_create_target.py +++ b/dacapo/gp/dacapo_create_target.py @@ -84,9 +84,7 @@ def process(self, batch, request): gt_array = NumpyArray.from_gp_array(batch[self.gt_key]) target_array = self.predictor.create_target(gt_array) - mask_array = NumpyArray.from_gp_array( - batch[self.mask_key] - ) # TODO: doesn't this require mask_key to be set? + mask_array = NumpyArray.from_gp_array(batch[self.mask_key]) if self.target_key is not None: request_spec = request[self.target_key] diff --git a/dacapo/gp/elastic_augment_fuse.py b/dacapo/gp/elastic_augment_fuse.py index 66dec97f0..b070d20ab 100644 --- a/dacapo/gp/elastic_augment_fuse.py +++ b/dacapo/gp/elastic_augment_fuse.py @@ -138,7 +138,7 @@ def _min_max_mean_std(ndarray, prefix=""): return "" -class ElasticAugment(BatchFilter): # TODO: replace DeformAugment node from gunpowder +class ElasticAugment(BatchFilter): """ Elasticly deform a batch. Requests larger batches upstream to avoid data loss due to rotation and jitter. diff --git a/dacapo/predict.py b/dacapo/predict.py index 491fd96fb..1df4d779e 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -70,7 +70,6 @@ def predict( # prepare data source pipeline = DaCapoArraySource(raw_array, raw) - pipeline += gp.Normalize(raw) # raw: (c, d, h, w) pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims)) # raw: (c, d, h, w) diff --git a/dacapo/store/local_array_store.py b/dacapo/store/local_array_store.py index 9265c8f67..c1581fc7b 100644 --- a/dacapo/store/local_array_store.py +++ b/dacapo/store/local_array_store.py @@ -55,7 +55,7 @@ def validation_input_arrays( If we don't store these we would have to look up the datasplit config and figure out where to find the inputs for each run. If we write the data then we don't need to search for it. - This convenience comes at the cost of extra memory usage. #TODO: FIX THIS - refactor + This convenience comes at the cost of some extra memory usage. 
""" container = self.validation_container(run_name).container diff --git a/dacapo/store/local_weights_store.py b/dacapo/store/local_weights_store.py index c34afc6c3..c5f0ba5ff 100644 --- a/dacapo/store/local_weights_store.py +++ b/dacapo/store/local_weights_store.py @@ -62,9 +62,7 @@ def retrieve_weights(self, run: str, iteration: int) -> Weights: return weights - def _retrieve_weights( - self, run: str, key: str - ) -> Weights: # TODO: redundant with above? + def _retrieve_weights(self, run: str, key: str) -> Weights: weights_name = self.__get_weights_dir(run) / key if not weights_name.exists(): weights_name = self.__get_weights_dir(run) / "iterations" / key @@ -106,14 +104,14 @@ def retrieve_best(self, run: str, dataset: str, criterion: str) -> int: logger.info("Retrieving weights for run %s, criterion %s", run, criterion) weights_info = json.loads( - (self.__get_weights_dir(run) / dataset / f"{criterion}.json") + (self.__get_weights_dir(run) / criterion / f"{dataset}.json") .open("r") .read() ) return weights_info["iteration"] - def _load_best(self, run: Run, criterion: str): # TODO: probably won't work + def _load_best(self, run: Run, criterion: str): logger.info("Retrieving weights for run %s, criterion %s", run, criterion) weights_name = self.__get_weights_dir(run) / f"{criterion}" diff --git a/dacapo/train.py b/dacapo/train.py index cbd025a98..7beb096b4 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -1,4 +1,3 @@ -from copy import deepcopy from dacapo.store.create_store import create_array_store from .experiments import Run from .compute_context import LocalTorch, ComputeContext @@ -11,7 +10,6 @@ import logging logger = logging.getLogger(__name__) -logger.setLevel("INFO") def train(run_name: str, compute_context: ComputeContext = LocalTorch()): @@ -99,19 +97,11 @@ def train_run( weights_store.retrieve_weights(run, iteration=trained_until) elif latest_weights_iteration > trained_until: - logger.warn( + weights_store.retrieve_weights(run, iteration=latest_weights_iteration) + logger.error( f"Found weights for iteration {latest_weights_iteration}, but " f"run {run.name} was only trained until {trained_until}. " - "Filling stats with last observed values." ) - last_iteration_stats = run.training_stats.iteration_stats[-1] - for i in range( - last_iteration_stats.iteration, latest_weights_iteration - 1 - ): - new_iteration_stats = deepcopy(last_iteration_stats) - new_iteration_stats.iteration = i + 1 - run.training_stats.add_iteration_stats(new_iteration_stats) - trained_until = run.training_stats.trained_until() # start/resume training diff --git a/dacapo/validate.py b/dacapo/validate.py index ca1052fb6..a1cf9da7d 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -12,7 +12,6 @@ import torch from pathlib import Path -from reloading import reloading import logging logger = logging.getLogger(__name__) @@ -48,7 +47,6 @@ def validate( return validate_run(run, iteration, compute_context=compute_context) -@reloading # allows us to fix validation bugs without interrupting training def validate_run( run: Run, iteration: int, compute_context: ComputeContext = LocalTorch() ): @@ -56,216 +54,164 @@ def validate_run( load the weights of that iteration, it is assumed that the model is already loaded correctly. 
Returns the best parameters and scores for this iteration.""" - try: # we don't want this to hold up training - # set benchmark flag to True for performance - torch.backends.cudnn.benchmark = True - run.model.to(compute_context.device) - run.model.eval() + # set benchmark flag to True for performance + torch.backends.cudnn.benchmark = True + run.model.eval() + + if ( + run.datasplit.validate is None + or len(run.datasplit.validate) == 0 + or run.datasplit.validate[0].gt is None + ): + logger.info("Cannot validate run %s. Continuing training!", run.name) + return None, None + + # get array and weight store + weights_store = create_weights_store() + array_store = create_array_store() + iteration_scores = [] + + # get post processor and evaluator + post_processor = run.task.post_processor + evaluator = run.task.evaluator + + # Initialize the evaluator with the best scores seen so far + evaluator.set_best(run.validation_scores) + + for validation_dataset in run.datasplit.validate: + assert ( + validation_dataset.gt is not None + ), "We do not yet support validating on datasets without ground truth" + logger.info( + "Validating run %s on dataset %s", run.name, validation_dataset.name + ) + ( + input_raw_array_identifier, + input_gt_array_identifier, + ) = array_store.validation_input_arrays(run.name, validation_dataset.name) if ( - run.datasplit.validate is None - or len(run.datasplit.validate) == 0 - or run.datasplit.validate[0].gt is None + not Path( + f"{input_raw_array_identifier.container}/{input_raw_array_identifier.dataset}" + ).exists() + or not Path( + f"{input_gt_array_identifier.container}/{input_gt_array_identifier.dataset}" + ).exists() ): - logger.info("Cannot validate run %s. Continuing training!", run.name) - return None, None - - # get array and weight store - weights_store = create_weights_store() - array_store = create_array_store() - iteration_scores = [] - - # get post processor and evaluator - post_processor = run.task.post_processor - evaluator = run.task.evaluator - - # Initialize the evaluator with the best scores seen so far - evaluator.set_best(run.validation_scores) - - for validation_dataset in run.datasplit.validate: - if validation_dataset.gt is None: - logger.error( - "We do not yet support validating on datasets without ground truth" - ) - raise NotImplementedError - - logger.info( - "Validating run %s on dataset %s", run.name, validation_dataset.name + logger.info("Copying validation inputs!") + input_voxel_size = validation_dataset.raw.voxel_size + output_voxel_size = run.model.scale(input_voxel_size) + input_shape = run.model.eval_input_shape + input_size = input_voxel_size * input_shape + output_shape = run.model.compute_output_shape(input_shape)[1] + output_size = output_voxel_size * output_shape + context = (input_size - output_size) / 2 + output_roi = validation_dataset.gt.roi + + input_roi = ( + output_roi.grow(context, context) + .snap_to_grid(validation_dataset.raw.voxel_size, mode="grow") + .intersect(validation_dataset.raw.roi) ) - - ( + input_raw = ZarrArray.create_from_array_identifier( input_raw_array_identifier, - input_gt_array_identifier, - ) = array_store.validation_input_arrays(run.name, validation_dataset.name) - if ( - not Path( - f"{input_raw_array_identifier.container}/{input_raw_array_identifier.dataset}" - ).exists() - or not Path( - f"{input_gt_array_identifier.container}/{input_gt_array_identifier.dataset}" - ).exists() - ): - logger.info("Copying validation inputs!") - input_voxel_size = validation_dataset.raw.voxel_size - 
output_voxel_size = run.model.scale(input_voxel_size) - input_shape = run.model.eval_input_shape - input_size = input_voxel_size * input_shape - output_shape = run.model.compute_output_shape(input_shape)[1] - output_size = output_voxel_size * output_shape - context = (input_size - output_size) / 2 - output_roi = validation_dataset.gt.roi - - input_roi = ( - output_roi.grow(context, context) - .snap_to_grid(validation_dataset.raw.voxel_size, mode="grow") - .intersect(validation_dataset.raw.roi) - ) - input_raw = ZarrArray.create_from_array_identifier( - input_raw_array_identifier, - validation_dataset.raw.axes, - input_roi, - validation_dataset.raw.num_channels, - validation_dataset.raw.voxel_size, - validation_dataset.raw.dtype, - name=f"{run.name}_validation_raw", - write_size=input_size, - ) - input_raw[input_roi] = validation_dataset.raw[input_roi] - input_gt = ZarrArray.create_from_array_identifier( - input_gt_array_identifier, - validation_dataset.gt.axes, - output_roi, - validation_dataset.gt.num_channels, - validation_dataset.gt.voxel_size, - validation_dataset.gt.dtype, - name=f"{run.name}_validation_gt", - write_size=output_size, - ) - input_gt[output_roi] = validation_dataset.gt[output_roi] - else: - logger.info("validation inputs already copied!") - - prediction_array_identifier = array_store.validation_prediction_array( - run.name, iteration, validation_dataset + validation_dataset.raw.axes, + input_roi, + validation_dataset.raw.num_channels, + validation_dataset.raw.voxel_size, + validation_dataset.raw.dtype, + name=f"{run.name}_validation_raw", + write_size=input_size, ) - predict( - run.model, - validation_dataset.raw, - prediction_array_identifier, - compute_context=compute_context, - output_roi=validation_dataset.gt.roi, + input_raw[input_roi] = validation_dataset.raw[input_roi] + input_gt = ZarrArray.create_from_array_identifier( + input_gt_array_identifier, + validation_dataset.gt.axes, + output_roi, + validation_dataset.gt.num_channels, + validation_dataset.gt.voxel_size, + validation_dataset.gt.dtype, + name=f"{run.name}_validation_gt", + write_size=output_size, ) + input_gt[output_roi] = validation_dataset.gt[output_roi] + else: + logger.info("validation inputs already copied!") - post_processor.set_prediction(prediction_array_identifier) - - dataset_iteration_scores = [] - - # set up dict for overall best scores - overall_best_scores = {} - for criterion in run.validation_scores.criteria: - overall_best_scores[criterion] = evaluator.get_overall_best( - validation_dataset, criterion - ) - - any_overall_best = False - output_array_identifiers = [] - for parameters in post_processor.enumerate_parameters(): - output_array_identifier = array_store.validation_output_array( - run.name, iteration, parameters, validation_dataset - ) - output_array_identifiers.append(output_array_identifier) - post_processed_array = post_processor.process( - parameters, output_array_identifier - ) - - scores = evaluator.evaluate( - output_array_identifier, validation_dataset.gt - ) - for criterion in run.validation_scores.criteria: - # replace predictions in array with the new better predictions - if evaluator.is_best( - validation_dataset, - parameters, - criterion, - scores, - ): - # then this is the current best score for this parameter, but not necessarily the overall best - higher_is_better = scores.higher_is_better(criterion) - # initial_best_score = overall_best_scores[criterion] - current_score = getattr(scores, criterion) - if not overall_best_scores[ - criterion - ] or ( # TODO: 
should be in evaluator - ( - higher_is_better - and current_score > overall_best_scores[criterion] - ) - or ( - not higher_is_better - and current_score < overall_best_scores[criterion] - ) - ): - any_overall_best = True - overall_best_scores[criterion] = current_score - - # For example, if parameter 2 did better this round than it did in other rounds, but it was still worse than parameter 1 - # the code would have overwritten it below since all parameters write to the same file. Now each parameter will be its own file - # Either we do that, or we only write out the overall best, regardless of parameters - best_array_identifier = array_store.best_validation_array( - run.name, criterion, index=validation_dataset.name - ) - best_array = ZarrArray.create_from_array_identifier( - best_array_identifier, - post_processed_array.axes, - post_processed_array.roi, - post_processed_array.num_channels, - post_processed_array.voxel_size, - post_processed_array.dtype, - ) - best_array[best_array.roi] = post_processed_array[ - post_processed_array.roi - ] - best_array.add_metadata( - { - "iteration": iteration, - criterion: getattr(scores, criterion), - "parameters_id": parameters.id, - } - ) - weights_store.store_best( - run, iteration, validation_dataset.name, criterion - ) - - dataset_iteration_scores.append( - [getattr(scores, criterion) for criterion in scores.criteria] - ) - - if not any_overall_best: - # We only keep the best outputs as determined by the evaluator - for output_array_identifier in output_array_identifiers: - array_store.remove(output_array_identifier) - - iteration_scores.append(dataset_iteration_scores) - array_store.remove(prediction_array_identifier) - - run.validation_scores.add_iteration_scores( - ValidationIterationScores(iteration, iteration_scores) + prediction_array_identifier = array_store.validation_prediction_array( + run.name, iteration, validation_dataset ) - stats_store = create_stats_store() - stats_store.store_validation_iteration_scores(run.name, run.validation_scores) - except Exception as e: - logger.error( - f"Validation failed for run {run.name} at iteration " f"{iteration}.", - exc_info=e, + logger.info("Predicting on dataset %s", validation_dataset.name) + predict( + run.model, + validation_dataset.raw, + prediction_array_identifier, + compute_context=compute_context, + output_roi=validation_dataset.gt.roi, ) + logger.info("Predicted on dataset %s", validation_dataset.name) + + post_processor.set_prediction(prediction_array_identifier) + dataset_iteration_scores = [] -if __name__ == "__main__": - import argparse + for parameters in post_processor.enumerate_parameters(): + output_array_identifier = array_store.validation_output_array( + run.name, iteration, parameters, validation_dataset + ) + + post_processed_array = post_processor.process( + parameters, output_array_identifier + ) - parser = argparse.ArgumentParser() - parser.add_argument("run_name", type=str) - parser.add_argument("iteration", type=int) - args = parser.parse_args() + scores = evaluator.evaluate(output_array_identifier, validation_dataset.gt) - validate(args.run_name, args.iteration) + for criterion in run.validation_scores.criteria: + # replace predictions in array with the new better predictions + if evaluator.is_best( + validation_dataset, + parameters, + criterion, + scores, + ): + best_array_identifier = array_store.best_validation_array( + run.name, criterion, index=validation_dataset.name + ) + best_array = ZarrArray.create_from_array_identifier( + best_array_identifier, + 
post_processed_array.axes, + post_processed_array.roi, + post_processed_array.num_channels, + post_processed_array.voxel_size, + post_processed_array.dtype, + ) + best_array[best_array.roi] = post_processed_array[ + post_processed_array.roi + ] + best_array.add_metadata( + { + "iteration": iteration, + criterion: getattr(scores, criterion), + "parameters_id": parameters.id, + } + ) + weights_store.store_best( + run, iteration, validation_dataset.name, criterion + ) + + # delete current output. We only keep the best outputs as determined by + # the evaluator + array_store.remove(output_array_identifier) + + dataset_iteration_scores.append( + [getattr(scores, criterion) for criterion in scores.criteria] + ) + + iteration_scores.append(dataset_iteration_scores) + array_store.remove(prediction_array_identifier) + + run.validation_scores.add_iteration_scores( + ValidationIterationScores(iteration, iteration_scores) + ) + stats_store = create_stats_store() + stats_store.store_validation_iteration_scores(run.name, run.validation_scores)
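
Not part of any patch above: a minimal, self-contained PyTorch sketch of the hot-distance loss layout that the psf/black patch ([PATCH 32/33]) reformats. It assumes only the convention the loss itself asserts: dim 0 is the batch, dim 1 is the channel axis with an even count, the first half holding one-hot logits and the second half signed distances. Shapes and tensor values below are illustrative and not taken from the patches.

    import torch

    # Illustrative shapes: 4 samples, 3 classes -> 6 channels (hot + distance).
    prediction = torch.randn(4, 6, 32, 32)
    target = torch.rand(4, 6, 32, 32)
    weight = torch.rand(4, 6, 32, 32)

    # The channel axis must be even so it can be halved into the two heads.
    mid = prediction.shape[1] // 2
    pred_hot, pred_dist = torch.split(prediction, mid, dim=1)
    tgt_hot, tgt_dist = torch.split(target, mid, dim=1)
    w_hot, w_dist = torch.split(weight, mid, dim=1)

    # Per-voxel weighted binary cross-entropy on the one-hot half ...
    hot_loss = torch.mean(
        torch.nn.BCEWithLogitsLoss(reduction="none")(pred_hot, tgt_hot) * w_hot
    )
    # ... plus weighted mean-squared error on the distance half.
    dist_loss = torch.nn.MSELoss()(pred_dist * w_dist, tgt_dist * w_dist)
    loss = hot_loss + dist_loss
    print(loss.item())

The single combined scalar reflects how the two heads share one optimizer step while their targets and per-voxel weights stay independent per channel half.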