diff --git a/dacapo/blockwise/scheduler.py b/dacapo/blockwise/scheduler.py index b2a015a75..7db2d9c27 100644 --- a/dacapo/blockwise/scheduler.py +++ b/dacapo/blockwise/scheduler.py @@ -74,7 +74,12 @@ def run_blockwise( ) print("Running blockwise with worker_file: ", worker_file) print(f"Using compute context: {create_compute_context()}") - success = daisy.run_blockwise([task]) + compute_context = create_compute_context() + print(f"Using compute context: {compute_context}") + + multiprocessing = compute_context.distribute_workers + + success = daisy.run_blockwise([task], multiprocessing=multiprocessing) return success diff --git a/dacapo/cli.py b/dacapo/cli.py index 2af9aea77..a7966e236 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -76,8 +76,19 @@ def cli(log_level): @click.option( "-r", "--run-name", required=True, type=str, help="The NAME of the run to train." ) -def train(run_name): - dacapo.train(run_name) # TODO: run with compute_context +@click.option( + "--no-validation", is_flag=True, help="Disable validation after training." +) +def train(run_name, no_validation): + """ + Train a model with the specified run name. + + Args: + run_name (str): The name of the run to train. + no_validation (bool): Flag to disable validation after training. + """ + do_validate = not no_validation + dacapo.train(run_name, do_validate=do_validate) @cli.command() diff --git a/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py b/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py index ba6fd99f0..4a5dc0208 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py @@ -171,7 +171,9 @@ def roi(self) -> Roi: This method returns the region of interest of the resampled array. """ - return self._source_array.roi.snap_to_grid(self.voxel_size, mode="shrink") + return self._source_array.roi.snap_to_grid( + np.lcm(self._source_array.voxel_size, self.voxel_size), mode="shrink" + ) @property def writable(self) -> bool: @@ -281,7 +283,9 @@ def __getitem__(self, roi: Roi) -> np.ndarray: Note: This method returns the data of the resampled array within the given region of interest. """ - snapped_roi = roi.snap_to_grid(self._source_array.voxel_size, mode="grow") + snapped_roi = roi.snap_to_grid( + np.lcm(self._source_array.voxel_size, self.voxel_size), mode="grow" + ) resampled_array = funlib.persistence.Array( rescale( self._source_array[snapped_roi].astype(np.float32), @@ -352,4 +356,4 @@ def _source_name(self): Note: This method returns the name of the source array. """ - return self._source_array._source_name() + return self._source_array._source_name() \ No newline at end of file diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 30c6ac693..3413108e8 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -14,6 +14,7 @@ from collections import OrderedDict import logging from upath import UPath as Path +import os import json from typing import Dict, Tuple, Any, Optional, List @@ -273,7 +274,9 @@ def roi(self) -> Roi: This method is used to return the region of interest of the array. """ if self.snap_to_grid is not None: - return self._daisy_array.roi.snap_to_grid(self.snap_to_grid, mode="shrink") + return self._daisy_array.roi.snap_to_grid( + np.lcm(self.voxel_size, self.snap_to_grid), mode="shrink" + ) else: return self._daisy_array.roi @@ -426,33 +429,12 @@ def create_from_array_identifier( num_channels, voxel_size, dtype, - mode="a", write_size=None, name=None, - overwrite=False, ): """ Create a new ZarrArray given an array identifier. It is assumed that - this array_identifier points to a dataset that does not yet exist. - - Args: - array_identifier (ArrayIdentifier): The array identifier. - axes (List[str]): The axes of the array. - roi (Roi): The region of interest. - num_channels (int): The number of channels. - voxel_size (Coordinate): The voxel size. - dtype (Any): The data type. - write_size (Optional[Coordinate]): The write size. - name (Optional[str]): The name of the array. - overwrite (bool): The boolean value to overwrite the array. - Returns: - ZarrArray: The ZarrArray. - Raises: - NotImplementedError - Examples: - >>> create_from_array_identifier(array_identifier, axes, roi, num_channels, voxel_size, dtype, write_size=None, name=None, overwrite=False) - Notes: - This method is used to create a new ZarrArray given an array identifier. + this array_identifier points to a dataset that does not yet exist """ if write_size is None: # total storage per block is approx c*x*y*z*dtype_size @@ -469,11 +451,6 @@ def create_from_array_identifier( write_size = Coordinate((axis_length,) * voxel_size.dims) * voxel_size write_size = Coordinate((min(a, b) for a, b in zip(write_size, roi.shape))) zarr_container = zarr.open(array_identifier.container, "a") - if num_channels is None or num_channels == 1: - axes = [axis for axis in axes if "c" not in axis] - num_channels = None - else: - axes = ["c"] + [axis for axis in axes if "c" not in axis] try: funlib.persistence.prepare_ds( f"{array_identifier.container}", @@ -483,41 +460,21 @@ def create_from_array_identifier( dtype, num_channels=num_channels, write_size=write_size, - delete=overwrite, - force_exact_write_size=True, ) zarr_dataset = zarr_container[array_identifier.dataset] - if array_identifier.container.name.endswith("n5"): - zarr_dataset.attrs["offset"] = roi.offset[::-1] - zarr_dataset.attrs["resolution"] = voxel_size[::-1] - zarr_dataset.attrs["axes"] = axes[::-1] - # to make display right in neuroglancer: TODO ADD CHANNELS - zarr_dataset.attrs["dimension_units"] = [ - f"{size} nm" for size in voxel_size[::-1] - ] - zarr_dataset.attrs["_ARRAY_DIMENSIONS"] = [ - a if a != "c" else "c^" for a in axes[::-1] - ] - else: - zarr_dataset.attrs["offset"] = roi.offset - zarr_dataset.attrs["resolution"] = voxel_size - zarr_dataset.attrs["axes"] = axes - # to make display right in neuroglancer: TODO ADD CHANNELS - zarr_dataset.attrs["dimension_units"] = [ - f"{size} nm" for size in voxel_size - ] - zarr_dataset.attrs["_ARRAY_DIMENSIONS"] = [ - a if a != "c" else "c^" for a in axes - ] - if "c" in axes: - if axes.index("c") == 0: - zarr_dataset.attrs["dimension_units"] = [ - str(num_channels) - ] + zarr_dataset.attrs["dimension_units"] - else: - zarr_dataset.attrs["dimension_units"] = zarr_dataset.attrs[ - "dimension_units" - ] + [str(num_channels)] + zarr_dataset.attrs["offset"] = ( + roi.offset[::-1] + if array_identifier.container.name.endswith("n5") + else roi.offset + ) + zarr_dataset.attrs["resolution"] = ( + voxel_size[::-1] + if array_identifier.container.name.endswith("n5") + else voxel_size + ) + zarr_dataset.attrs["axes"] = ( + axes[::-1] if array_identifier.container.name.endswith("n5") else axes + ) except zarr.errors.ContainsArrayError: zarr_dataset = zarr_container[array_identifier.dataset] assert ( @@ -733,4 +690,4 @@ def add_metadata(self, metadata: Dict[str, Any]) -> None: """ dataset = zarr.open(self.file_name, mode="a")[self.dataset] for k, v in metadata.items(): - dataset.attrs[k] = v + dataset.attrs[k] = v \ No newline at end of file diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py index 6c1a214cd..ec37b0747 100644 --- a/dacapo/experiments/datasplits/datasplit_generator.py +++ b/dacapo/experiments/datasplits/datasplit_generator.py @@ -913,23 +913,23 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec): constant=1, ) - if len(target_images) > 1: - gt_config = ConcatArrayConfig( - name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_gt", - channels=[organelle for organelle in current_targets], - # source_array_configs={k: gt for k, gt in target_images.items()}, - source_array_configs={k: target_images[k] for k in current_targets}, - ) - mask_config = ConcatArrayConfig( - name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_mask", - channels=[organelle for organelle in current_targets], - # source_array_configs={k: mask for k, mask in target_masks.items()}, - # to be sure to have the same order - source_array_configs={k: target_masks[k] for k in current_targets}, - ) - else: - gt_config = list(target_images.values())[0] - mask_config = list(target_masks.values())[0] + # if len(target_images) > 1: + gt_config = ConcatArrayConfig( + name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_gt", + channels=[organelle for organelle in current_targets], + # source_array_configs={k: gt for k, gt in target_images.items()}, + source_array_configs={k: target_images[k] for k in current_targets}, + ) + mask_config = ConcatArrayConfig( + name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_mask", + channels=[organelle for organelle in current_targets], + # source_array_configs={k: mask for k, mask in target_masks.items()}, + # to be sure to have the same order + source_array_configs={k: target_masks[k] for k in current_targets}, + ) + # else: + # gt_config = list(target_images.values())[0] + # mask_config = list(target_masks.values())[0] return raw_config, gt_config, mask_config diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index f99c64d3a..aaa69e1e9 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -1,12 +1,12 @@ -from upath import UPath as Path -from dacapo.blockwise.scheduler import run_blockwise from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from .threshold_post_processor_parameters import ThresholdPostProcessorParameters from dacapo.store.array_store import LocalArrayIdentifier from .post_processor import PostProcessor -import dacapo.blockwise import numpy as np +import daisy from daisy import Roi, Coordinate +from dacapo.utils.array_utils import to_ndarray, save_ndarray +from funlib.persistence import open_ds from typing import Iterable @@ -43,7 +43,7 @@ def enumerate_parameters(self) -> Iterable["ThresholdPostProcessorParameters"]: Note: This method should return a generator of instances of ``ThresholdPostProcessorParameters``. """ - for i, threshold in enumerate([100, 127, 150]): + for i, threshold in enumerate([127]): yield ThresholdPostProcessorParameters(id=i, threshold=threshold) def set_prediction(self, prediction_array_identifier): @@ -117,28 +117,31 @@ def process( self.prediction_array.num_channels, self.prediction_array.voxel_size, np.uint8, - write_size, ) + read_roi = Roi((0, 0, 0), write_size[-self.prediction_array.dims :]) - # run blockwise post-processing - sucess = run_blockwise( - worker_file=str( - Path(Path(dacapo.blockwise.__file__).parent, "threshold_worker.py") - ), + input_array = open_ds(self.prediction_array_identifier.container.path,self.prediction_array_identifier.dataset) + + def process_block(block): + print("Predicting block", block.read_roi) + data = to_ndarray(input_array,block.read_roi) > parameters.threshold + if int(data.max()) == 0: + print("No data in block", block.read_roi) + return + save_ndarray(data, block.write_roi, output_array) + + task = daisy.Task( + f"threshold_{output_array.dataset}", total_roi=self.prediction_array.roi, read_roi=read_roi, write_roi=read_roi, - num_workers=num_workers, - max_retries=2, # TODO: make this an option - timeout=None, # TODO: make this an option - ###### - input_array_identifier=self.prediction_array_identifier, - output_array_identifier=output_array_identifier, - threshold=parameters.threshold, + process_function=process_block, + check_function=None, + read_write_conflict=False, + fit="overhang", + max_retries=0, + timeout=None, ) - if not sucess: - raise RuntimeError("Blockwise post-processing failed.") - - return output_array + return daisy.run_blockwise([task], multiprocessing=False) \ No newline at end of file diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 6b9ecf205..372964df7 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -111,7 +111,7 @@ def create_optimizer(self, model): optimizer = torch.optim.RAdam( lr=self.learning_rate, params=model.parameters(), - decoupled_weight_decay=True, + # decoupled_weight_decay=True, ) self.scheduler = torch.optim.lr_scheduler.LinearLR( optimizer, @@ -161,6 +161,8 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): mask_placeholder = gp.ArrayKey("MASK_PLACEHOLDER") target_key = gp.ArrayKey("TARGET") + dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT") + datasets_weight_key = gp.ArrayKey("DATASETS_WEIGHT") weight_key = gp.ArrayKey("WEIGHT") sample_points_key = gp.GraphKey("SAMPLE_POINTS") @@ -207,9 +209,9 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): mask_placeholder, drop_channels=True, ) - + gp.Pad(raw_key, None) - + gp.Pad(gt_key, None) - + gp.Pad(mask_key, None) + + gp.Pad(raw_key, None, mode="constant", value=0) + + gp.Pad(gt_key, None, mode="constant", value=0) + + gp.Pad(mask_key, None, mode="constant", value=0) + gp.RandomLocation( ensure_nonempty=( sample_points_key if points_source is not None else None @@ -225,6 +227,14 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): for augment in self.augments: dataset_source += augment.node(raw_key, gt_key, mask_key) + # Add predictor nodes to dataset_source + dataset_source += DaCapoTargetFilter( + task.predictor, + gt_key=gt_key, + weights_key=dataset_weight_key, + mask_key=mask_key, + ) + dataset_sources.append(dataset_source) pipeline = tuple(dataset_sources) + gp.RandomProvider(weights) @@ -233,10 +243,12 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): task.predictor, gt_key=gt_key, target_key=target_key, - weights_key=weight_key, + weights_key=datasets_weight_key, mask_key=mask_key, ) + pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key) + # Trainer attributes: if self.num_data_fetchers > 1: pipeline += gp.PreCache(num_workers=self.num_data_fetchers) diff --git a/dacapo/predict_crop.py b/dacapo/predict_crop.py new file mode 100644 index 000000000..6024bb1a5 --- /dev/null +++ b/dacapo/predict_crop.py @@ -0,0 +1,104 @@ +from dacapo.experiments.model import Model +from dacapo.store.local_array_store import LocalArrayIdentifier +from funlib.persistence import open_ds, prepare_ds, Array +from dacapo.utils.array_utils import to_ndarray +from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray +from funlib.geometry import Coordinate, Roi +import numpy as np +from dacapo.compute_context import create_compute_context +from typing import Optional +import logging +import daisy +import torch +import os +from dacapo.utils.array_utils import to_ndarray, save_ndarray + +logger = logging.getLogger(__name__) + + +def predict( + model: Model, + raw_array_identifier: LocalArrayIdentifier, + prediction_array_identifier: LocalArrayIdentifier, + output_roi: Optional[Roi] = None, + min_raw: float = 0, + max_raw: float = 255, +): + shift = min_raw + scale = max_raw - min_raw + # get the model's input and output size + raw_array = open_ds(raw_array_identifier.container.path, raw_array_identifier.dataset) + input_voxel_size = Coordinate(raw_array.voxel_size) + output_voxel_size = model.scale(input_voxel_size) + + input_shape = Coordinate(model.eval_input_shape) + input_size = input_voxel_size * input_shape + output_size = output_voxel_size * model.compute_output_shape(input_shape)[1] + + context = (input_size - output_size) / 2 + + if output_roi is None: + input_roi = raw_array.roi + output_roi = input_roi.grow(-context, -context) + else: + input_roi = output_roi.grow(context, context) + + + read_roi = Roi((0,0,0), input_size) + write_roi = read_roi.grow(-context, -context) + + axes = ["c", "z", "y", "x"] + + num_channels = model.num_out_channels + + result_dataset = ZarrArray.create_from_array_identifier( + prediction_array_identifier, + axes, + output_roi, + num_channels, + output_voxel_size, + np.float32, + ) + + + logger.info("Total input ROI: %s, output ROI: %s", input_size, output_roi) + logger.info("Block read ROI: %s, write ROI: %s", read_roi, write_roi) + + out_container, out_dataset = ( + prediction_array_identifier.container.path, + prediction_array_identifier.dataset, + ) + compute_context = create_compute_context() + device = compute_context.device + + + def predict_fn(block): + raw_input = to_ndarray(raw_array,block.read_roi) + raw_input = 2.0 * (raw_input.astype(np.float32) - shift )/ scale - 1.0 + if len(raw_input.shape) == 3: + raw_input = np.expand_dims(raw_input, (0, 1)) + with torch.no_grad(): + predictions = model.forward(torch.from_numpy(raw_input).float().to(device)).detach().cpu().numpy()[0] + + predictions = (predictions + 1) * 255.0 / 2.0 + print(f"Predicting block {block.read_roi} uniques: {np.unique(predictions)}") + save_ndarray(predictions, block.write_roi, result_dataset) + # result_dataset[block.write_roi] = predictions + + # fixing the input roi to be a multiple of the output voxel size + input_roi = input_roi.snap_to_grid(np.lcm(input_voxel_size, output_voxel_size), mode="shrink") + + task = daisy.Task( + f"predict_{out_container}_{out_dataset}", + total_roi=input_roi, + read_roi=Roi((0, 0, 0), input_size), + write_roi=Roi(context, output_size), + process_function=predict_fn, + check_function=None, + read_write_conflict=False, + fit="overhang", + max_retries=0, + timeout=None, + ) + + return daisy.run_blockwise([task], multiprocessing=False) diff --git a/dacapo/train.py b/dacapo/train.py index a1d884b3d..a9e599970 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) -def train(run_name: str): +def train(run_name: str, do_validate=True): """ Train a run @@ -44,7 +44,7 @@ def train(run_name: str): run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) - return train_run(run) + return train_run(run, do_validate) def train_run(run: Run, do_validate=True): diff --git a/dacapo/utils/array_utils.py b/dacapo/utils/array_utils.py new file mode 100644 index 000000000..d3efaac79 --- /dev/null +++ b/dacapo/utils/array_utils.py @@ -0,0 +1,62 @@ +import numpy as np +from funlib.persistence import Array + +def to_ndarray(result_data, roi, fill_value=0): + """An alternative implementation of `__getitem__` that supports + using fill values to request data that may extend outside the + roi covered by result_data. + + Args: + + roi (`class:Roi`, optional): + + If given, copy only the data represented by this ROI. This is + equivalent to:: + + array[roi].to_ndarray() + + fill_value (scalar, optional): + + The value to use to fill in values that are outside the ROI + provided by this data. Defaults to 0. + """ + + shape = roi.shape / result_data.voxel_size + data = np.zeros( + result_data[result_data.roi].shape[: result_data.n_channel_dims] + shape, dtype=result_data.data.dtype + ) + if fill_value != 0: + data[:] = fill_value + + array = Array(data, roi, result_data.voxel_size) + + shared_roi = result_data.roi.intersect(roi) + + if not shared_roi.empty: + array[shared_roi] = result_data[shared_roi] + + return data + +def save_ndarray(data, roi, array): + """An alternative implementation of `__setitem__` that supports + using fill values to request data that may extend outside the + roi covered by result_data. + + Args: + + roi (`class:Roi`, optional): + + If given, copy only the data represented by this ROI. This is + equivalent to:: + + array[roi] = data + + fill_value (scalar, optional): + + The value to use to fill in values that are outside the ROI + provided by this data. Defaults to 0. + """ + intersection_roi = roi.intersect(array.roi) + if not intersection_roi.empty: + result_array = Array(data, roi, array.voxel_size) + array[intersection_roi] = result_array[intersection_roi] \ No newline at end of file diff --git a/dacapo/utils/view.py b/dacapo/utils/view.py index 203f98cf4..2018ff2a9 100644 --- a/dacapo/utils/view.py +++ b/dacapo/utils/view.py @@ -440,7 +440,7 @@ def open_from_array_identitifier(self, array_identifier): >>> ds = viewer.open_from_array_identitifier(array_identifier) """ if os.path.exists(array_identifier.container / array_identifier.dataset): - return open_ds(str(array_identifier.container), array_identifier.dataset) + return open_ds(str(array_identifier.container.path), array_identifier.dataset) else: return None diff --git a/dacapo/validate.py b/dacapo/validate.py index 44f091e7a..b31a036b1 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -1,109 +1,69 @@ -from dacapo.compute_context import create_compute_context -from .predict import predict +from .predict_crop import predict + from .experiments import Run, ValidationIterationScores from .experiments.datasplits.datasets.arrays import ZarrArray -from .store.create_store import ( +from dacapo.store.create_store import ( create_array_store, create_config_store, create_stats_store, create_weights_store, ) +import torch -from upath import UPath as Path +from pathlib import Path import logging -from warnings import warn logger = logging.getLogger(__name__) -def validate_run( - run_name: str, - iteration: int, - num_workers: int = 1, - output_dtype: str = "uint8", - overwrite: bool = True, -): - """ - Validate a run at a given iteration. Loads the weights from a previously +def validate(run_name: str, iteration: int = 0): + """Validate a run at a given iteration. Loads the weights from a previously stored checkpoint. Returns the best parameters and scores for this - iteration. + iteration.""" + + logger.info("Validating run %s at iteration %d...", run_name, iteration) - Args: - run: The name of the run to validate. - iteration: The iteration to validate. - num_workers: The number of workers to use for validation. - output_dtype: The dtype to use for the output arrays. - overwrite: Whether to overwrite existing output arrays + # create run - """ - # Load the model and weights config_store = create_config_store() run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) - compute_context = create_compute_context() - if iteration is not None and not compute_context.distribute_workers: - # create weights store - weights_store = create_weights_store() - - # load weights - run.model.load_state_dict( - weights_store.retrieve_weights(run_name, iteration).model - ) - - return validate( - run=run, - iteration=iteration, - num_workers=num_workers, - output_dtype=output_dtype, - overwrite=overwrite, - ) - - -def validate( - run: Run, - iteration: int, - num_workers: int = 1, - output_dtype: str = "uint8", - overwrite: bool = True, -): - """ - Validate a run at a given iteration. Loads the weights from a previously - stored checkpoint. Returns the best parameters and scores for this - iteration. - - Args: - run: The run to validate. - iteration: The iteration to validate. - num_workers: The number of workers to use for validation. - output_dtype: The dtype to use for the output arrays. - overwrite: Whether to overwrite existing output arrays - Returns: - The best parameters and scores for this iteration - Raises: - ValueError: If the run does not have a validation dataset or the dataset does not have ground truth. - Example: - validate(my_run, 1000) - """ - - print(f"Validating run {run.name} at iteration {iteration}...") - - run_name = run.name # read in previous training/validation stats + stats_store = create_stats_store() run.training_stats = stats_store.retrieve_training_stats(run_name) run.validation_scores.scores = stats_store.retrieve_validation_iteration_scores( run_name ) + # create weights store and read weights + if iteration > 0: + weights_store = create_weights_store() + weights_store.retrieve_weights(run, iteration) + + return validate_run(run, iteration) + + +def validate_run(run: Run, iteration: int): + """Validate an already loaded run at the given iteration. This does not + load the weights of that iteration, it is assumed that the model is already + loaded correctly. Returns the best parameters and scores for this + iteration.""" + # set benchmark flag to True for performance + torch.backends.cudnn.benchmark = True + run.model.eval() + if ( run.datasplit.validate is None or len(run.datasplit.validate) == 0 or run.datasplit.validate[0].gt is None ): - raise ValueError(f"Cannot validate run {run.name} at iteration {iteration}.") + logger.info("Cannot validate run %s. Continuing training!", run.name) + return None, None # get array and weight store + weights_store = create_weights_store() array_store = create_array_store() iteration_scores = [] @@ -112,36 +72,19 @@ def validate( evaluator = run.task.evaluator # Initialize the evaluator with the best scores seen so far - try: - evaluator.set_best(run.validation_scores) - except ValueError: - logger.warn( - f"Could not set best scores for run {run.name} at iteration {iteration}." - ) - + evaluator.set_best(run.validation_scores) for validation_dataset in run.datasplit.validate: - if validation_dataset.gt is None: - logger.error( - "We do not yet support validating on datasets without ground truth" - ) - raise NotImplementedError - - print(f"Validating run {run.name} on dataset {validation_dataset.name}") + assert ( + validation_dataset.gt is not None + ), "We do not yet support validating on datasets without ground truth" + logger.info( + "Validating run %s on dataset %s", run.name, validation_dataset.name + ) ( input_raw_array_identifier, input_gt_array_identifier, ) = array_store.validation_input_arrays(run.name, validation_dataset.name) - - input_voxel_size = validation_dataset.raw.voxel_size - output_voxel_size = run.model.scale(input_voxel_size) - input_shape = run.model.eval_input_shape - input_size = input_voxel_size * input_shape - output_shape = run.model.compute_output_shape(input_shape)[1] - output_size = output_voxel_size * output_shape - context = (input_size - output_size) / 2 - output_roi = validation_dataset.gt.roi - if ( not Path( f"{input_raw_array_identifier.container}/{input_raw_array_identifier.dataset}" @@ -150,7 +93,15 @@ def validate( f"{input_gt_array_identifier.container}/{input_gt_array_identifier.dataset}" ).exists() ): - print("Copying validation inputs!") + logger.info("Copying validation inputs!") + input_voxel_size = validation_dataset.raw.voxel_size + output_voxel_size = run.model.scale(input_voxel_size) + input_shape = run.model.eval_input_shape + input_size = input_voxel_size * input_shape + output_shape = run.model.compute_output_shape(input_shape)[1] + output_size = output_voxel_size * output_shape + context = (input_size - output_size) / 2 + output_roi = validation_dataset.gt.roi input_roi = ( output_roi.grow(context, context) @@ -167,7 +118,7 @@ def validate( name=f"{run.name}_validation_raw", write_size=input_size, ) - input_raw[input_roi] = validation_dataset.raw[input_roi].squeeze() + input_raw[input_roi] = validation_dataset.raw[input_roi] input_gt = ZarrArray.create_from_array_identifier( input_gt_array_identifier, validation_dataset.gt.axes, @@ -178,140 +129,81 @@ def validate( name=f"{run.name}_validation_gt", write_size=output_size, ) - input_gt[output_roi] = validation_dataset.gt[output_roi].squeeze() + input_gt[output_roi] = validation_dataset.gt[output_roi] else: - print("validation inputs already copied!") + logger.info("validation inputs already copied!") prediction_array_identifier = array_store.validation_prediction_array( - run.name, iteration, validation_dataset.name + run.name, iteration, validation_dataset ) - compute_context = create_compute_context() - success = predict( - run, - iteration if compute_context.distribute_workers else None, - input_container=input_raw_array_identifier.container, - input_dataset=input_raw_array_identifier.dataset, - output_path=prediction_array_identifier, - output_roi=validation_dataset.gt.roi, # type: ignore - num_workers=num_workers, - output_dtype=output_dtype, - overwrite=overwrite, + predict( + run.model, + input_raw_array_identifier, + prediction_array_identifier, + output_roi=validation_dataset.gt.roi, ) - if not success: - logger.error( - f"Could not predict run {run.name} on dataset {validation_dataset.name}." - ) - continue - - print(f"Predicted on dataset {validation_dataset.name}") - post_processor.set_prediction(prediction_array_identifier) - # # set up dict for overall best scores per dataset - # overall_best_scores = {} - # for criterion in run.validation_scores.criteria: - # overall_best_scores[criterion] = evaluator.get_overall_best( - # validation_dataset, criterion - # ) - - # any_overall_best = False - output_array_identifiers = [] dataset_iteration_scores = [] + for parameters in post_processor.enumerate_parameters(): output_array_identifier = array_store.validation_output_array( - run.name, iteration, str(parameters), validation_dataset.name + run.name, iteration, parameters, validation_dataset ) - output_array_identifiers.append(output_array_identifier) + post_processed_array = post_processor.process( parameters, output_array_identifier ) - try: - scores = evaluator.evaluate( - output_array_identifier, validation_dataset.gt # type: ignore - ) - dataset_iteration_scores.append( - [getattr(scores, criterion) for criterion in scores.criteria] - ) - # for criterion in run.validation_scores.criteria: - # # replace predictions in array with the new better predictions - # if evaluator.is_best( - # validation_dataset, - # parameters, - # criterion, - # scores, - # ): - # # then this is the current best score for this parameter, but not necessarily the overall best - # # initial_best_score = overall_best_scores[criterion] - # current_score = getattr(scores, criterion) - # if not overall_best_scores[criterion] or evaluator.compare( - # current_score, overall_best_scores[criterion], criterion - # ): - # any_overall_best = True - # overall_best_scores[criterion] = current_score - - # # For example, if parameter 2 did better this round than it did in other rounds, but it was still worse than parameter 1 - # # the code would have overwritten it below since all parameters write to the same file. Now each parameter will be its own file - # # Either we do that, or we only write out the overall best, regardless of parameters - # best_array_identifier = array_store.best_validation_array( - # run.name, - # criterion, - # index=validation_dataset.name, - # ) - # best_array = ZarrArray.create_from_array_identifier( - # best_array_identifier, - # post_processed_array.axes, - # post_processed_array.roi, - # post_processed_array.num_channels, - # post_processed_array.voxel_size, - # post_processed_array.dtype, - # output_size, - # ) - # best_array[best_array.roi] = post_processed_array[ - # post_processed_array.roi - # ] - # best_array.add_metadata( - # { - # "iteration": iteration, - # criterion: getattr(scores, criterion), - # "parameters_id": parameters.id, - # } - # ) - # weights_store.store_best( - # run.name, - # iteration, - # validation_dataset.name, - # criterion, - # ) - except: - logger.error( - f"Could not evaluate run {run.name} on dataset {validation_dataset.name} with parameters {parameters}.", - exc_info=True, - stack_info=True, - ) - - # if not any_overall_best: - # # We only keep the best outputs as determined by the evaluator - # for output_array_identifier in output_array_identifiers: - # array_store.remove(prediction_array_identifier) - # array_store.remove(output_array_identifier) + scores = evaluator.evaluate(output_array_identifier, validation_dataset.gt) + + for criterion in run.validation_scores.criteria: + # replace predictions in array with the new better predictions + if evaluator.is_best( + validation_dataset, + parameters, + criterion, + scores, + ): + best_array_identifier = array_store.best_validation_array( + run.name, criterion, index=validation_dataset.name + ) + best_array = ZarrArray.create_from_array_identifier( + best_array_identifier, + post_processed_array.axes, + post_processed_array.roi, + post_processed_array.num_channels, + post_processed_array.voxel_size, + post_processed_array.dtype, + ) + best_array[best_array.roi] = post_processed_array[ + post_processed_array.roi + ] + best_array.add_metadata( + { + "iteration": iteration, + criterion: getattr(scores, criterion), + "parameters_id": parameters.id, + } + ) + weights_store.store_best( + run, iteration, validation_dataset.name, criterion + ) + + # delete current output. We only keep the best outputs as determined by + # the evaluator + array_store.remove(output_array_identifier) + + dataset_iteration_scores.append( + [getattr(scores, criterion) for criterion in scores.criteria] + ) iteration_scores.append(dataset_iteration_scores) + array_store.remove(prediction_array_identifier) run.validation_scores.add_iteration_scores( ValidationIterationScores(iteration, iteration_scores) ) stats_store = create_stats_store() stats_store.store_validation_iteration_scores(run.name, run.validation_scores) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("run_name", type=str) - parser.add_argument("iteration", type=int) - args = parser.parse_args() - - validate(args.run_name, args.iteration)