From a1e46bd073592dd990b6592ae1dd4adb4f6f45f2 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Mon, 23 Sep 2024 17:36:29 -0400 Subject: [PATCH 01/17] local blockwise --- dacapo/blockwise/scheduler.py | 7 +- dacapo/cli.py | 15 +- .../datasets/arrays/resampled_array.py | 10 +- .../datasplits/datasets/arrays/zarr_array.py | 81 ++--- .../datasplits/datasplit_generator.py | 34 +- .../threshold_post_processor.py | 45 +-- .../experiments/trainers/gunpowder_trainer.py | 22 +- dacapo/predict_crop.py | 104 ++++++ dacapo/train.py | 4 +- dacapo/utils/array_utils.py | 62 ++++ dacapo/utils/view.py | 2 +- dacapo/validate.py | 314 ++++++------------ 12 files changed, 375 insertions(+), 325 deletions(-) create mode 100644 dacapo/predict_crop.py create mode 100644 dacapo/utils/array_utils.py diff --git a/dacapo/blockwise/scheduler.py b/dacapo/blockwise/scheduler.py index b2a015a75..7db2d9c27 100644 --- a/dacapo/blockwise/scheduler.py +++ b/dacapo/blockwise/scheduler.py @@ -74,7 +74,12 @@ def run_blockwise( ) print("Running blockwise with worker_file: ", worker_file) print(f"Using compute context: {create_compute_context()}") - success = daisy.run_blockwise([task]) + compute_context = create_compute_context() + print(f"Using compute context: {compute_context}") + + multiprocessing = compute_context.distribute_workers + + success = daisy.run_blockwise([task], multiprocessing=multiprocessing) return success diff --git a/dacapo/cli.py b/dacapo/cli.py index 2af9aea77..a7966e236 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -76,8 +76,19 @@ def cli(log_level): @click.option( "-r", "--run-name", required=True, type=str, help="The NAME of the run to train." ) -def train(run_name): - dacapo.train(run_name) # TODO: run with compute_context +@click.option( + "--no-validation", is_flag=True, help="Disable validation after training." +) +def train(run_name, no_validation): + """ + Train a model with the specified run name. + + Args: + run_name (str): The name of the run to train. + no_validation (bool): Flag to disable validation after training. + """ + do_validate = not no_validation + dacapo.train(run_name, do_validate=do_validate) @cli.command() diff --git a/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py b/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py index ba6fd99f0..4a5dc0208 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py @@ -171,7 +171,9 @@ def roi(self) -> Roi: This method returns the region of interest of the resampled array. """ - return self._source_array.roi.snap_to_grid(self.voxel_size, mode="shrink") + return self._source_array.roi.snap_to_grid( + np.lcm(self._source_array.voxel_size, self.voxel_size), mode="shrink" + ) @property def writable(self) -> bool: @@ -281,7 +283,9 @@ def __getitem__(self, roi: Roi) -> np.ndarray: Note: This method returns the data of the resampled array within the given region of interest. """ - snapped_roi = roi.snap_to_grid(self._source_array.voxel_size, mode="grow") + snapped_roi = roi.snap_to_grid( + np.lcm(self._source_array.voxel_size, self.voxel_size), mode="grow" + ) resampled_array = funlib.persistence.Array( rescale( self._source_array[snapped_roi].astype(np.float32), @@ -352,4 +356,4 @@ def _source_name(self): Note: This method returns the name of the source array. """ - return self._source_array._source_name() + return self._source_array._source_name() \ No newline at end of file diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 30c6ac693..3413108e8 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -14,6 +14,7 @@ from collections import OrderedDict import logging from upath import UPath as Path +import os import json from typing import Dict, Tuple, Any, Optional, List @@ -273,7 +274,9 @@ def roi(self) -> Roi: This method is used to return the region of interest of the array. """ if self.snap_to_grid is not None: - return self._daisy_array.roi.snap_to_grid(self.snap_to_grid, mode="shrink") + return self._daisy_array.roi.snap_to_grid( + np.lcm(self.voxel_size, self.snap_to_grid), mode="shrink" + ) else: return self._daisy_array.roi @@ -426,33 +429,12 @@ def create_from_array_identifier( num_channels, voxel_size, dtype, - mode="a", write_size=None, name=None, - overwrite=False, ): """ Create a new ZarrArray given an array identifier. It is assumed that - this array_identifier points to a dataset that does not yet exist. - - Args: - array_identifier (ArrayIdentifier): The array identifier. - axes (List[str]): The axes of the array. - roi (Roi): The region of interest. - num_channels (int): The number of channels. - voxel_size (Coordinate): The voxel size. - dtype (Any): The data type. - write_size (Optional[Coordinate]): The write size. - name (Optional[str]): The name of the array. - overwrite (bool): The boolean value to overwrite the array. - Returns: - ZarrArray: The ZarrArray. - Raises: - NotImplementedError - Examples: - >>> create_from_array_identifier(array_identifier, axes, roi, num_channels, voxel_size, dtype, write_size=None, name=None, overwrite=False) - Notes: - This method is used to create a new ZarrArray given an array identifier. + this array_identifier points to a dataset that does not yet exist """ if write_size is None: # total storage per block is approx c*x*y*z*dtype_size @@ -469,11 +451,6 @@ def create_from_array_identifier( write_size = Coordinate((axis_length,) * voxel_size.dims) * voxel_size write_size = Coordinate((min(a, b) for a, b in zip(write_size, roi.shape))) zarr_container = zarr.open(array_identifier.container, "a") - if num_channels is None or num_channels == 1: - axes = [axis for axis in axes if "c" not in axis] - num_channels = None - else: - axes = ["c"] + [axis for axis in axes if "c" not in axis] try: funlib.persistence.prepare_ds( f"{array_identifier.container}", @@ -483,41 +460,21 @@ def create_from_array_identifier( dtype, num_channels=num_channels, write_size=write_size, - delete=overwrite, - force_exact_write_size=True, ) zarr_dataset = zarr_container[array_identifier.dataset] - if array_identifier.container.name.endswith("n5"): - zarr_dataset.attrs["offset"] = roi.offset[::-1] - zarr_dataset.attrs["resolution"] = voxel_size[::-1] - zarr_dataset.attrs["axes"] = axes[::-1] - # to make display right in neuroglancer: TODO ADD CHANNELS - zarr_dataset.attrs["dimension_units"] = [ - f"{size} nm" for size in voxel_size[::-1] - ] - zarr_dataset.attrs["_ARRAY_DIMENSIONS"] = [ - a if a != "c" else "c^" for a in axes[::-1] - ] - else: - zarr_dataset.attrs["offset"] = roi.offset - zarr_dataset.attrs["resolution"] = voxel_size - zarr_dataset.attrs["axes"] = axes - # to make display right in neuroglancer: TODO ADD CHANNELS - zarr_dataset.attrs["dimension_units"] = [ - f"{size} nm" for size in voxel_size - ] - zarr_dataset.attrs["_ARRAY_DIMENSIONS"] = [ - a if a != "c" else "c^" for a in axes - ] - if "c" in axes: - if axes.index("c") == 0: - zarr_dataset.attrs["dimension_units"] = [ - str(num_channels) - ] + zarr_dataset.attrs["dimension_units"] - else: - zarr_dataset.attrs["dimension_units"] = zarr_dataset.attrs[ - "dimension_units" - ] + [str(num_channels)] + zarr_dataset.attrs["offset"] = ( + roi.offset[::-1] + if array_identifier.container.name.endswith("n5") + else roi.offset + ) + zarr_dataset.attrs["resolution"] = ( + voxel_size[::-1] + if array_identifier.container.name.endswith("n5") + else voxel_size + ) + zarr_dataset.attrs["axes"] = ( + axes[::-1] if array_identifier.container.name.endswith("n5") else axes + ) except zarr.errors.ContainsArrayError: zarr_dataset = zarr_container[array_identifier.dataset] assert ( @@ -733,4 +690,4 @@ def add_metadata(self, metadata: Dict[str, Any]) -> None: """ dataset = zarr.open(self.file_name, mode="a")[self.dataset] for k, v in metadata.items(): - dataset.attrs[k] = v + dataset.attrs[k] = v \ No newline at end of file diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py index 6c1a214cd..ec37b0747 100644 --- a/dacapo/experiments/datasplits/datasplit_generator.py +++ b/dacapo/experiments/datasplits/datasplit_generator.py @@ -913,23 +913,23 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec): constant=1, ) - if len(target_images) > 1: - gt_config = ConcatArrayConfig( - name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_gt", - channels=[organelle for organelle in current_targets], - # source_array_configs={k: gt for k, gt in target_images.items()}, - source_array_configs={k: target_images[k] for k in current_targets}, - ) - mask_config = ConcatArrayConfig( - name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_mask", - channels=[organelle for organelle in current_targets], - # source_array_configs={k: mask for k, mask in target_masks.items()}, - # to be sure to have the same order - source_array_configs={k: target_masks[k] for k in current_targets}, - ) - else: - gt_config = list(target_images.values())[0] - mask_config = list(target_masks.values())[0] + # if len(target_images) > 1: + gt_config = ConcatArrayConfig( + name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_gt", + channels=[organelle for organelle in current_targets], + # source_array_configs={k: gt for k, gt in target_images.items()}, + source_array_configs={k: target_images[k] for k in current_targets}, + ) + mask_config = ConcatArrayConfig( + name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_mask", + channels=[organelle for organelle in current_targets], + # source_array_configs={k: mask for k, mask in target_masks.items()}, + # to be sure to have the same order + source_array_configs={k: target_masks[k] for k in current_targets}, + ) + # else: + # gt_config = list(target_images.values())[0] + # mask_config = list(target_masks.values())[0] return raw_config, gt_config, mask_config diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index f99c64d3a..aaa69e1e9 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -1,12 +1,12 @@ -from upath import UPath as Path -from dacapo.blockwise.scheduler import run_blockwise from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from .threshold_post_processor_parameters import ThresholdPostProcessorParameters from dacapo.store.array_store import LocalArrayIdentifier from .post_processor import PostProcessor -import dacapo.blockwise import numpy as np +import daisy from daisy import Roi, Coordinate +from dacapo.utils.array_utils import to_ndarray, save_ndarray +from funlib.persistence import open_ds from typing import Iterable @@ -43,7 +43,7 @@ def enumerate_parameters(self) -> Iterable["ThresholdPostProcessorParameters"]: Note: This method should return a generator of instances of ``ThresholdPostProcessorParameters``. """ - for i, threshold in enumerate([100, 127, 150]): + for i, threshold in enumerate([127]): yield ThresholdPostProcessorParameters(id=i, threshold=threshold) def set_prediction(self, prediction_array_identifier): @@ -117,28 +117,31 @@ def process( self.prediction_array.num_channels, self.prediction_array.voxel_size, np.uint8, - write_size, ) + read_roi = Roi((0, 0, 0), write_size[-self.prediction_array.dims :]) - # run blockwise post-processing - sucess = run_blockwise( - worker_file=str( - Path(Path(dacapo.blockwise.__file__).parent, "threshold_worker.py") - ), + input_array = open_ds(self.prediction_array_identifier.container.path,self.prediction_array_identifier.dataset) + + def process_block(block): + print("Predicting block", block.read_roi) + data = to_ndarray(input_array,block.read_roi) > parameters.threshold + if int(data.max()) == 0: + print("No data in block", block.read_roi) + return + save_ndarray(data, block.write_roi, output_array) + + task = daisy.Task( + f"threshold_{output_array.dataset}", total_roi=self.prediction_array.roi, read_roi=read_roi, write_roi=read_roi, - num_workers=num_workers, - max_retries=2, # TODO: make this an option - timeout=None, # TODO: make this an option - ###### - input_array_identifier=self.prediction_array_identifier, - output_array_identifier=output_array_identifier, - threshold=parameters.threshold, + process_function=process_block, + check_function=None, + read_write_conflict=False, + fit="overhang", + max_retries=0, + timeout=None, ) - if not sucess: - raise RuntimeError("Blockwise post-processing failed.") - - return output_array + return daisy.run_blockwise([task], multiprocessing=False) \ No newline at end of file diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 6b9ecf205..372964df7 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -111,7 +111,7 @@ def create_optimizer(self, model): optimizer = torch.optim.RAdam( lr=self.learning_rate, params=model.parameters(), - decoupled_weight_decay=True, + # decoupled_weight_decay=True, ) self.scheduler = torch.optim.lr_scheduler.LinearLR( optimizer, @@ -161,6 +161,8 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): mask_placeholder = gp.ArrayKey("MASK_PLACEHOLDER") target_key = gp.ArrayKey("TARGET") + dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT") + datasets_weight_key = gp.ArrayKey("DATASETS_WEIGHT") weight_key = gp.ArrayKey("WEIGHT") sample_points_key = gp.GraphKey("SAMPLE_POINTS") @@ -207,9 +209,9 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): mask_placeholder, drop_channels=True, ) - + gp.Pad(raw_key, None) - + gp.Pad(gt_key, None) - + gp.Pad(mask_key, None) + + gp.Pad(raw_key, None, mode="constant", value=0) + + gp.Pad(gt_key, None, mode="constant", value=0) + + gp.Pad(mask_key, None, mode="constant", value=0) + gp.RandomLocation( ensure_nonempty=( sample_points_key if points_source is not None else None @@ -225,6 +227,14 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): for augment in self.augments: dataset_source += augment.node(raw_key, gt_key, mask_key) + # Add predictor nodes to dataset_source + dataset_source += DaCapoTargetFilter( + task.predictor, + gt_key=gt_key, + weights_key=dataset_weight_key, + mask_key=mask_key, + ) + dataset_sources.append(dataset_source) pipeline = tuple(dataset_sources) + gp.RandomProvider(weights) @@ -233,10 +243,12 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): task.predictor, gt_key=gt_key, target_key=target_key, - weights_key=weight_key, + weights_key=datasets_weight_key, mask_key=mask_key, ) + pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key) + # Trainer attributes: if self.num_data_fetchers > 1: pipeline += gp.PreCache(num_workers=self.num_data_fetchers) diff --git a/dacapo/predict_crop.py b/dacapo/predict_crop.py new file mode 100644 index 000000000..6024bb1a5 --- /dev/null +++ b/dacapo/predict_crop.py @@ -0,0 +1,104 @@ +from dacapo.experiments.model import Model +from dacapo.store.local_array_store import LocalArrayIdentifier +from funlib.persistence import open_ds, prepare_ds, Array +from dacapo.utils.array_utils import to_ndarray +from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray +from funlib.geometry import Coordinate, Roi +import numpy as np +from dacapo.compute_context import create_compute_context +from typing import Optional +import logging +import daisy +import torch +import os +from dacapo.utils.array_utils import to_ndarray, save_ndarray + +logger = logging.getLogger(__name__) + + +def predict( + model: Model, + raw_array_identifier: LocalArrayIdentifier, + prediction_array_identifier: LocalArrayIdentifier, + output_roi: Optional[Roi] = None, + min_raw: float = 0, + max_raw: float = 255, +): + shift = min_raw + scale = max_raw - min_raw + # get the model's input and output size + raw_array = open_ds(raw_array_identifier.container.path, raw_array_identifier.dataset) + input_voxel_size = Coordinate(raw_array.voxel_size) + output_voxel_size = model.scale(input_voxel_size) + + input_shape = Coordinate(model.eval_input_shape) + input_size = input_voxel_size * input_shape + output_size = output_voxel_size * model.compute_output_shape(input_shape)[1] + + context = (input_size - output_size) / 2 + + if output_roi is None: + input_roi = raw_array.roi + output_roi = input_roi.grow(-context, -context) + else: + input_roi = output_roi.grow(context, context) + + + read_roi = Roi((0,0,0), input_size) + write_roi = read_roi.grow(-context, -context) + + axes = ["c", "z", "y", "x"] + + num_channels = model.num_out_channels + + result_dataset = ZarrArray.create_from_array_identifier( + prediction_array_identifier, + axes, + output_roi, + num_channels, + output_voxel_size, + np.float32, + ) + + + logger.info("Total input ROI: %s, output ROI: %s", input_size, output_roi) + logger.info("Block read ROI: %s, write ROI: %s", read_roi, write_roi) + + out_container, out_dataset = ( + prediction_array_identifier.container.path, + prediction_array_identifier.dataset, + ) + compute_context = create_compute_context() + device = compute_context.device + + + def predict_fn(block): + raw_input = to_ndarray(raw_array,block.read_roi) + raw_input = 2.0 * (raw_input.astype(np.float32) - shift )/ scale - 1.0 + if len(raw_input.shape) == 3: + raw_input = np.expand_dims(raw_input, (0, 1)) + with torch.no_grad(): + predictions = model.forward(torch.from_numpy(raw_input).float().to(device)).detach().cpu().numpy()[0] + + predictions = (predictions + 1) * 255.0 / 2.0 + print(f"Predicting block {block.read_roi} uniques: {np.unique(predictions)}") + save_ndarray(predictions, block.write_roi, result_dataset) + # result_dataset[block.write_roi] = predictions + + # fixing the input roi to be a multiple of the output voxel size + input_roi = input_roi.snap_to_grid(np.lcm(input_voxel_size, output_voxel_size), mode="shrink") + + task = daisy.Task( + f"predict_{out_container}_{out_dataset}", + total_roi=input_roi, + read_roi=Roi((0, 0, 0), input_size), + write_roi=Roi(context, output_size), + process_function=predict_fn, + check_function=None, + read_write_conflict=False, + fit="overhang", + max_retries=0, + timeout=None, + ) + + return daisy.run_blockwise([task], multiprocessing=False) diff --git a/dacapo/train.py b/dacapo/train.py index a1d884b3d..a9e599970 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) -def train(run_name: str): +def train(run_name: str, do_validate=True): """ Train a run @@ -44,7 +44,7 @@ def train(run_name: str): run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) - return train_run(run) + return train_run(run, do_validate) def train_run(run: Run, do_validate=True): diff --git a/dacapo/utils/array_utils.py b/dacapo/utils/array_utils.py new file mode 100644 index 000000000..d3efaac79 --- /dev/null +++ b/dacapo/utils/array_utils.py @@ -0,0 +1,62 @@ +import numpy as np +from funlib.persistence import Array + +def to_ndarray(result_data, roi, fill_value=0): + """An alternative implementation of `__getitem__` that supports + using fill values to request data that may extend outside the + roi covered by result_data. + + Args: + + roi (`class:Roi`, optional): + + If given, copy only the data represented by this ROI. This is + equivalent to:: + + array[roi].to_ndarray() + + fill_value (scalar, optional): + + The value to use to fill in values that are outside the ROI + provided by this data. Defaults to 0. + """ + + shape = roi.shape / result_data.voxel_size + data = np.zeros( + result_data[result_data.roi].shape[: result_data.n_channel_dims] + shape, dtype=result_data.data.dtype + ) + if fill_value != 0: + data[:] = fill_value + + array = Array(data, roi, result_data.voxel_size) + + shared_roi = result_data.roi.intersect(roi) + + if not shared_roi.empty: + array[shared_roi] = result_data[shared_roi] + + return data + +def save_ndarray(data, roi, array): + """An alternative implementation of `__setitem__` that supports + using fill values to request data that may extend outside the + roi covered by result_data. + + Args: + + roi (`class:Roi`, optional): + + If given, copy only the data represented by this ROI. This is + equivalent to:: + + array[roi] = data + + fill_value (scalar, optional): + + The value to use to fill in values that are outside the ROI + provided by this data. Defaults to 0. + """ + intersection_roi = roi.intersect(array.roi) + if not intersection_roi.empty: + result_array = Array(data, roi, array.voxel_size) + array[intersection_roi] = result_array[intersection_roi] \ No newline at end of file diff --git a/dacapo/utils/view.py b/dacapo/utils/view.py index 203f98cf4..2018ff2a9 100644 --- a/dacapo/utils/view.py +++ b/dacapo/utils/view.py @@ -440,7 +440,7 @@ def open_from_array_identitifier(self, array_identifier): >>> ds = viewer.open_from_array_identitifier(array_identifier) """ if os.path.exists(array_identifier.container / array_identifier.dataset): - return open_ds(str(array_identifier.container), array_identifier.dataset) + return open_ds(str(array_identifier.container.path), array_identifier.dataset) else: return None diff --git a/dacapo/validate.py b/dacapo/validate.py index 44f091e7a..b31a036b1 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -1,109 +1,69 @@ -from dacapo.compute_context import create_compute_context -from .predict import predict +from .predict_crop import predict + from .experiments import Run, ValidationIterationScores from .experiments.datasplits.datasets.arrays import ZarrArray -from .store.create_store import ( +from dacapo.store.create_store import ( create_array_store, create_config_store, create_stats_store, create_weights_store, ) +import torch -from upath import UPath as Path +from pathlib import Path import logging -from warnings import warn logger = logging.getLogger(__name__) -def validate_run( - run_name: str, - iteration: int, - num_workers: int = 1, - output_dtype: str = "uint8", - overwrite: bool = True, -): - """ - Validate a run at a given iteration. Loads the weights from a previously +def validate(run_name: str, iteration: int = 0): + """Validate a run at a given iteration. Loads the weights from a previously stored checkpoint. Returns the best parameters and scores for this - iteration. + iteration.""" + + logger.info("Validating run %s at iteration %d...", run_name, iteration) - Args: - run: The name of the run to validate. - iteration: The iteration to validate. - num_workers: The number of workers to use for validation. - output_dtype: The dtype to use for the output arrays. - overwrite: Whether to overwrite existing output arrays + # create run - """ - # Load the model and weights config_store = create_config_store() run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) - compute_context = create_compute_context() - if iteration is not None and not compute_context.distribute_workers: - # create weights store - weights_store = create_weights_store() - - # load weights - run.model.load_state_dict( - weights_store.retrieve_weights(run_name, iteration).model - ) - - return validate( - run=run, - iteration=iteration, - num_workers=num_workers, - output_dtype=output_dtype, - overwrite=overwrite, - ) - - -def validate( - run: Run, - iteration: int, - num_workers: int = 1, - output_dtype: str = "uint8", - overwrite: bool = True, -): - """ - Validate a run at a given iteration. Loads the weights from a previously - stored checkpoint. Returns the best parameters and scores for this - iteration. - - Args: - run: The run to validate. - iteration: The iteration to validate. - num_workers: The number of workers to use for validation. - output_dtype: The dtype to use for the output arrays. - overwrite: Whether to overwrite existing output arrays - Returns: - The best parameters and scores for this iteration - Raises: - ValueError: If the run does not have a validation dataset or the dataset does not have ground truth. - Example: - validate(my_run, 1000) - """ - - print(f"Validating run {run.name} at iteration {iteration}...") - - run_name = run.name # read in previous training/validation stats + stats_store = create_stats_store() run.training_stats = stats_store.retrieve_training_stats(run_name) run.validation_scores.scores = stats_store.retrieve_validation_iteration_scores( run_name ) + # create weights store and read weights + if iteration > 0: + weights_store = create_weights_store() + weights_store.retrieve_weights(run, iteration) + + return validate_run(run, iteration) + + +def validate_run(run: Run, iteration: int): + """Validate an already loaded run at the given iteration. This does not + load the weights of that iteration, it is assumed that the model is already + loaded correctly. Returns the best parameters and scores for this + iteration.""" + # set benchmark flag to True for performance + torch.backends.cudnn.benchmark = True + run.model.eval() + if ( run.datasplit.validate is None or len(run.datasplit.validate) == 0 or run.datasplit.validate[0].gt is None ): - raise ValueError(f"Cannot validate run {run.name} at iteration {iteration}.") + logger.info("Cannot validate run %s. Continuing training!", run.name) + return None, None # get array and weight store + weights_store = create_weights_store() array_store = create_array_store() iteration_scores = [] @@ -112,36 +72,19 @@ def validate( evaluator = run.task.evaluator # Initialize the evaluator with the best scores seen so far - try: - evaluator.set_best(run.validation_scores) - except ValueError: - logger.warn( - f"Could not set best scores for run {run.name} at iteration {iteration}." - ) - + evaluator.set_best(run.validation_scores) for validation_dataset in run.datasplit.validate: - if validation_dataset.gt is None: - logger.error( - "We do not yet support validating on datasets without ground truth" - ) - raise NotImplementedError - - print(f"Validating run {run.name} on dataset {validation_dataset.name}") + assert ( + validation_dataset.gt is not None + ), "We do not yet support validating on datasets without ground truth" + logger.info( + "Validating run %s on dataset %s", run.name, validation_dataset.name + ) ( input_raw_array_identifier, input_gt_array_identifier, ) = array_store.validation_input_arrays(run.name, validation_dataset.name) - - input_voxel_size = validation_dataset.raw.voxel_size - output_voxel_size = run.model.scale(input_voxel_size) - input_shape = run.model.eval_input_shape - input_size = input_voxel_size * input_shape - output_shape = run.model.compute_output_shape(input_shape)[1] - output_size = output_voxel_size * output_shape - context = (input_size - output_size) / 2 - output_roi = validation_dataset.gt.roi - if ( not Path( f"{input_raw_array_identifier.container}/{input_raw_array_identifier.dataset}" @@ -150,7 +93,15 @@ def validate( f"{input_gt_array_identifier.container}/{input_gt_array_identifier.dataset}" ).exists() ): - print("Copying validation inputs!") + logger.info("Copying validation inputs!") + input_voxel_size = validation_dataset.raw.voxel_size + output_voxel_size = run.model.scale(input_voxel_size) + input_shape = run.model.eval_input_shape + input_size = input_voxel_size * input_shape + output_shape = run.model.compute_output_shape(input_shape)[1] + output_size = output_voxel_size * output_shape + context = (input_size - output_size) / 2 + output_roi = validation_dataset.gt.roi input_roi = ( output_roi.grow(context, context) @@ -167,7 +118,7 @@ def validate( name=f"{run.name}_validation_raw", write_size=input_size, ) - input_raw[input_roi] = validation_dataset.raw[input_roi].squeeze() + input_raw[input_roi] = validation_dataset.raw[input_roi] input_gt = ZarrArray.create_from_array_identifier( input_gt_array_identifier, validation_dataset.gt.axes, @@ -178,140 +129,81 @@ def validate( name=f"{run.name}_validation_gt", write_size=output_size, ) - input_gt[output_roi] = validation_dataset.gt[output_roi].squeeze() + input_gt[output_roi] = validation_dataset.gt[output_roi] else: - print("validation inputs already copied!") + logger.info("validation inputs already copied!") prediction_array_identifier = array_store.validation_prediction_array( - run.name, iteration, validation_dataset.name + run.name, iteration, validation_dataset ) - compute_context = create_compute_context() - success = predict( - run, - iteration if compute_context.distribute_workers else None, - input_container=input_raw_array_identifier.container, - input_dataset=input_raw_array_identifier.dataset, - output_path=prediction_array_identifier, - output_roi=validation_dataset.gt.roi, # type: ignore - num_workers=num_workers, - output_dtype=output_dtype, - overwrite=overwrite, + predict( + run.model, + input_raw_array_identifier, + prediction_array_identifier, + output_roi=validation_dataset.gt.roi, ) - if not success: - logger.error( - f"Could not predict run {run.name} on dataset {validation_dataset.name}." - ) - continue - - print(f"Predicted on dataset {validation_dataset.name}") - post_processor.set_prediction(prediction_array_identifier) - # # set up dict for overall best scores per dataset - # overall_best_scores = {} - # for criterion in run.validation_scores.criteria: - # overall_best_scores[criterion] = evaluator.get_overall_best( - # validation_dataset, criterion - # ) - - # any_overall_best = False - output_array_identifiers = [] dataset_iteration_scores = [] + for parameters in post_processor.enumerate_parameters(): output_array_identifier = array_store.validation_output_array( - run.name, iteration, str(parameters), validation_dataset.name + run.name, iteration, parameters, validation_dataset ) - output_array_identifiers.append(output_array_identifier) + post_processed_array = post_processor.process( parameters, output_array_identifier ) - try: - scores = evaluator.evaluate( - output_array_identifier, validation_dataset.gt # type: ignore - ) - dataset_iteration_scores.append( - [getattr(scores, criterion) for criterion in scores.criteria] - ) - # for criterion in run.validation_scores.criteria: - # # replace predictions in array with the new better predictions - # if evaluator.is_best( - # validation_dataset, - # parameters, - # criterion, - # scores, - # ): - # # then this is the current best score for this parameter, but not necessarily the overall best - # # initial_best_score = overall_best_scores[criterion] - # current_score = getattr(scores, criterion) - # if not overall_best_scores[criterion] or evaluator.compare( - # current_score, overall_best_scores[criterion], criterion - # ): - # any_overall_best = True - # overall_best_scores[criterion] = current_score - - # # For example, if parameter 2 did better this round than it did in other rounds, but it was still worse than parameter 1 - # # the code would have overwritten it below since all parameters write to the same file. Now each parameter will be its own file - # # Either we do that, or we only write out the overall best, regardless of parameters - # best_array_identifier = array_store.best_validation_array( - # run.name, - # criterion, - # index=validation_dataset.name, - # ) - # best_array = ZarrArray.create_from_array_identifier( - # best_array_identifier, - # post_processed_array.axes, - # post_processed_array.roi, - # post_processed_array.num_channels, - # post_processed_array.voxel_size, - # post_processed_array.dtype, - # output_size, - # ) - # best_array[best_array.roi] = post_processed_array[ - # post_processed_array.roi - # ] - # best_array.add_metadata( - # { - # "iteration": iteration, - # criterion: getattr(scores, criterion), - # "parameters_id": parameters.id, - # } - # ) - # weights_store.store_best( - # run.name, - # iteration, - # validation_dataset.name, - # criterion, - # ) - except: - logger.error( - f"Could not evaluate run {run.name} on dataset {validation_dataset.name} with parameters {parameters}.", - exc_info=True, - stack_info=True, - ) - - # if not any_overall_best: - # # We only keep the best outputs as determined by the evaluator - # for output_array_identifier in output_array_identifiers: - # array_store.remove(prediction_array_identifier) - # array_store.remove(output_array_identifier) + scores = evaluator.evaluate(output_array_identifier, validation_dataset.gt) + + for criterion in run.validation_scores.criteria: + # replace predictions in array with the new better predictions + if evaluator.is_best( + validation_dataset, + parameters, + criterion, + scores, + ): + best_array_identifier = array_store.best_validation_array( + run.name, criterion, index=validation_dataset.name + ) + best_array = ZarrArray.create_from_array_identifier( + best_array_identifier, + post_processed_array.axes, + post_processed_array.roi, + post_processed_array.num_channels, + post_processed_array.voxel_size, + post_processed_array.dtype, + ) + best_array[best_array.roi] = post_processed_array[ + post_processed_array.roi + ] + best_array.add_metadata( + { + "iteration": iteration, + criterion: getattr(scores, criterion), + "parameters_id": parameters.id, + } + ) + weights_store.store_best( + run, iteration, validation_dataset.name, criterion + ) + + # delete current output. We only keep the best outputs as determined by + # the evaluator + array_store.remove(output_array_identifier) + + dataset_iteration_scores.append( + [getattr(scores, criterion) for criterion in scores.criteria] + ) iteration_scores.append(dataset_iteration_scores) + array_store.remove(prediction_array_identifier) run.validation_scores.add_iteration_scores( ValidationIterationScores(iteration, iteration_scores) ) stats_store = create_stats_store() stats_store.store_validation_iteration_scores(run.name, run.validation_scores) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("run_name", type=str) - parser.add_argument("iteration", type=int) - args = parser.parse_args() - - validate(args.run_name, args.iteration) From a81983b3bdae88f76497217f857f592b7884fa0e Mon Sep 17 00:00:00 2001 From: mzouink Date: Mon, 23 Sep 2024 21:37:40 +0000 Subject: [PATCH 02/17] :art: Format Python code with psf/black --- dacapo/cli.py | 2 +- .../datasets/arrays/resampled_array.py | 2 +- .../datasplits/datasets/arrays/zarr_array.py | 8 +++--- .../threshold_post_processor.py | 10 ++++--- dacapo/predict_crop.py | 28 ++++++++++++------- dacapo/utils/array_utils.py | 7 +++-- dacapo/utils/view.py | 4 ++- 7 files changed, 38 insertions(+), 23 deletions(-) diff --git a/dacapo/cli.py b/dacapo/cli.py index a7966e236..d1f7ab2ae 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -88,7 +88,7 @@ def train(run_name, no_validation): no_validation (bool): Flag to disable validation after training. """ do_validate = not no_validation - dacapo.train(run_name, do_validate=do_validate) + dacapo.train(run_name, do_validate=do_validate) @cli.command() diff --git a/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py b/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py index 4a5dc0208..86367e50b 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py @@ -356,4 +356,4 @@ def _source_name(self): Note: This method returns the name of the source array. """ - return self._source_array._source_name() \ No newline at end of file + return self._source_array._source_name() diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 3413108e8..4379f153f 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -14,7 +14,7 @@ from collections import OrderedDict import logging from upath import UPath as Path -import os +import os import json from typing import Dict, Tuple, Any, Optional, List @@ -275,8 +275,8 @@ def roi(self) -> Roi: """ if self.snap_to_grid is not None: return self._daisy_array.roi.snap_to_grid( - np.lcm(self.voxel_size, self.snap_to_grid), mode="shrink" - ) + np.lcm(self.voxel_size, self.snap_to_grid), mode="shrink" + ) else: return self._daisy_array.roi @@ -690,4 +690,4 @@ def add_metadata(self, metadata: Dict[str, Any]) -> None: """ dataset = zarr.open(self.file_name, mode="a")[self.dataset] for k, v in metadata.items(): - dataset.attrs[k] = v \ No newline at end of file + dataset.attrs[k] = v diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index aaa69e1e9..cb8e1226d 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -119,13 +119,15 @@ def process( np.uint8, ) - read_roi = Roi((0, 0, 0), write_size[-self.prediction_array.dims :]) - input_array = open_ds(self.prediction_array_identifier.container.path,self.prediction_array_identifier.dataset) + input_array = open_ds( + self.prediction_array_identifier.container.path, + self.prediction_array_identifier.dataset, + ) def process_block(block): print("Predicting block", block.read_roi) - data = to_ndarray(input_array,block.read_roi) > parameters.threshold + data = to_ndarray(input_array, block.read_roi) > parameters.threshold if int(data.max()) == 0: print("No data in block", block.read_roi) return @@ -144,4 +146,4 @@ def process_block(block): timeout=None, ) - return daisy.run_blockwise([task], multiprocessing=False) \ No newline at end of file + return daisy.run_blockwise([task], multiprocessing=False) diff --git a/dacapo/predict_crop.py b/dacapo/predict_crop.py index 6024bb1a5..23556ad67 100644 --- a/dacapo/predict_crop.py +++ b/dacapo/predict_crop.py @@ -27,7 +27,9 @@ def predict( shift = min_raw scale = max_raw - min_raw # get the model's input and output size - raw_array = open_ds(raw_array_identifier.container.path, raw_array_identifier.dataset) + raw_array = open_ds( + raw_array_identifier.container.path, raw_array_identifier.dataset + ) input_voxel_size = Coordinate(raw_array.voxel_size) output_voxel_size = model.scale(input_voxel_size) @@ -43,8 +45,7 @@ def predict( else: input_roi = output_roi.grow(context, context) - - read_roi = Roi((0,0,0), input_size) + read_roi = Roi((0, 0, 0), input_size) write_roi = read_roi.grow(-context, -context) axes = ["c", "z", "y", "x"] @@ -60,7 +61,6 @@ def predict( np.float32, ) - logger.info("Total input ROI: %s, output ROI: %s", input_size, output_roi) logger.info("Block read ROI: %s, write ROI: %s", read_roi, write_roi) @@ -71,22 +71,30 @@ def predict( compute_context = create_compute_context() device = compute_context.device - def predict_fn(block): - raw_input = to_ndarray(raw_array,block.read_roi) - raw_input = 2.0 * (raw_input.astype(np.float32) - shift )/ scale - 1.0 + raw_input = to_ndarray(raw_array, block.read_roi) + raw_input = 2.0 * (raw_input.astype(np.float32) - shift) / scale - 1.0 if len(raw_input.shape) == 3: raw_input = np.expand_dims(raw_input, (0, 1)) with torch.no_grad(): - predictions = model.forward(torch.from_numpy(raw_input).float().to(device)).detach().cpu().numpy()[0] + predictions = ( + model.forward(torch.from_numpy(raw_input).float().to(device)) + .detach() + .cpu() + .numpy()[0] + ) predictions = (predictions + 1) * 255.0 / 2.0 - print(f"Predicting block {block.read_roi} uniques: {np.unique(predictions)}") + print( + f"Predicting block {block.read_roi} uniques: {np.unique(predictions)}" + ) save_ndarray(predictions, block.write_roi, result_dataset) # result_dataset[block.write_roi] = predictions # fixing the input roi to be a multiple of the output voxel size - input_roi = input_roi.snap_to_grid(np.lcm(input_voxel_size, output_voxel_size), mode="shrink") + input_roi = input_roi.snap_to_grid( + np.lcm(input_voxel_size, output_voxel_size), mode="shrink" + ) task = daisy.Task( f"predict_{out_container}_{out_dataset}", diff --git a/dacapo/utils/array_utils.py b/dacapo/utils/array_utils.py index d3efaac79..6d0293a4c 100644 --- a/dacapo/utils/array_utils.py +++ b/dacapo/utils/array_utils.py @@ -1,6 +1,7 @@ import numpy as np from funlib.persistence import Array + def to_ndarray(result_data, roi, fill_value=0): """An alternative implementation of `__getitem__` that supports using fill values to request data that may extend outside the @@ -23,7 +24,8 @@ def to_ndarray(result_data, roi, fill_value=0): shape = roi.shape / result_data.voxel_size data = np.zeros( - result_data[result_data.roi].shape[: result_data.n_channel_dims] + shape, dtype=result_data.data.dtype + result_data[result_data.roi].shape[: result_data.n_channel_dims] + shape, + dtype=result_data.data.dtype, ) if fill_value != 0: data[:] = fill_value @@ -37,6 +39,7 @@ def to_ndarray(result_data, roi, fill_value=0): return data + def save_ndarray(data, roi, array): """An alternative implementation of `__setitem__` that supports using fill values to request data that may extend outside the @@ -59,4 +62,4 @@ def save_ndarray(data, roi, array): intersection_roi = roi.intersect(array.roi) if not intersection_roi.empty: result_array = Array(data, roi, array.voxel_size) - array[intersection_roi] = result_array[intersection_roi] \ No newline at end of file + array[intersection_roi] = result_array[intersection_roi] diff --git a/dacapo/utils/view.py b/dacapo/utils/view.py index 2018ff2a9..ebbc1b61b 100644 --- a/dacapo/utils/view.py +++ b/dacapo/utils/view.py @@ -440,7 +440,9 @@ def open_from_array_identitifier(self, array_identifier): >>> ds = viewer.open_from_array_identitifier(array_identifier) """ if os.path.exists(array_identifier.container / array_identifier.dataset): - return open_ds(str(array_identifier.container.path), array_identifier.dataset) + return open_ds( + str(array_identifier.container.path), array_identifier.dataset + ) else: return None From 473bc38ca4b7373ea216070725210c601015c7af Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Thu, 26 Sep 2024 11:43:56 -0400 Subject: [PATCH 03/17] all local changes --- dacapo/__init__.py | 1 + .../architectures/cnnectome_unet.py | 7 +- dacapo/experiments/datasplits/__init__.py | 2 +- .../datasplits/datasets/dataset.py | 4 +- .../datasplits/datasplit_generator.py | 2 + .../binary_segmentation_evaluator.py | 5 +- .../threshold_post_processor.py | 2 +- .../threshold_post_processor_parameters.py | 3 + .../experiments/trainers/gunpowder_trainer.py | 31 ++---- dacapo/predict_crop.py | 17 ++- dacapo/submit_predict.py | 79 ++++++++++++++ dacapo/train.py | 4 +- dacapo/validate.py | 103 ++++++++++-------- 13 files changed, 177 insertions(+), 83 deletions(-) create mode 100644 dacapo/submit_predict.py diff --git a/dacapo/__init__.py b/dacapo/__init__.py index f54a1e06d..9051c1b9e 100644 --- a/dacapo/__init__.py +++ b/dacapo/__init__.py @@ -8,3 +8,4 @@ from .validate import validate, validate_run # noqa from .predict import predict # noqa from .blockwise import run_blockwise, segment_blockwise # noqa +from .submit_predict import full_predict # noqa \ No newline at end of file diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index acb345c88..dbd084c4c 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -827,7 +827,7 @@ class ConvPass(torch.nn.Module): """ def __init__( - self, in_channels, out_channels, kernel_sizes, activation, padding="valid" + self, in_channels, out_channels, kernel_sizes, activation, padding="valid", batch_normalize=True ): """ Convolutional pass module. This module performs a series of @@ -869,6 +869,11 @@ def __init__( try: layers.append(conv(in_channels, out_channels, kernel_size, padding=pad)) + if batch_normalize: + layers.append({ + 2: torch.nn.BatchNorm2d, + 3: torch.nn.BatchNorm3d, + }[self.dims](out_channels)) except KeyError: raise RuntimeError("%dD convolution not implemented" % self.dims) diff --git a/dacapo/experiments/datasplits/__init__.py b/dacapo/experiments/datasplits/__init__.py index 0bc9fe628..ad1ad4880 100644 --- a/dacapo/experiments/datasplits/__init__.py +++ b/dacapo/experiments/datasplits/__init__.py @@ -4,4 +4,4 @@ from .dummy_datasplit_config import DummyDataSplitConfig from .train_validate_datasplit import TrainValidateDataSplit from .train_validate_datasplit_config import TrainValidateDataSplitConfig -from .datasplit_generator import DataSplitGenerator +from .datasplit_generator import DataSplitGenerator, DatasetSpec diff --git a/dacapo/experiments/datasplits/datasets/dataset.py b/dacapo/experiments/datasplits/datasets/dataset.py index d3591b447..ef8ad2a1d 100644 --- a/dacapo/experiments/datasplits/datasets/dataset.py +++ b/dacapo/experiments/datasplits/datasets/dataset.py @@ -90,7 +90,7 @@ def __repr__(self) -> str: Notes: This method is used to return the official string representation of the dataset object. """ - return f"Dataset({self.name})" + return f"ds_{self.name.replace('/', '_')}" def __str__(self) -> str: """ @@ -109,7 +109,7 @@ def __str__(self) -> str: Notes: This method is used to return the string representation of the dataset object. """ - return f"Dataset({self.name})" + return f"ds_{self.name.replace('/', '_')}" def _neuroglancer_layers(self, prefix="", exclude_layers=None): """ diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py index ec37b0747..a68cf3c97 100644 --- a/dacapo/experiments/datasplits/datasplit_generator.py +++ b/dacapo/experiments/datasplits/datasplit_generator.py @@ -82,6 +82,7 @@ def resize_if_needed( f"have different dimensions {zarr_array.dims}" ) if any([u > 1 or d > 1 for u, d in zip(raw_upsample, raw_downsample)]): + print(f"dataset {array_config} needs resampling to {target_resolution}, upsample: {raw_upsample}, downsample: {raw_downsample}") return ResampledArrayConfig( name=f"{extra_str}_{array_config.name}_{array_config.dataset}_resampled", source_array_config=array_config, @@ -90,6 +91,7 @@ def resize_if_needed( interp_order=False, ) else: + # print(f"dataset {array_config.dataset} does not need resampling found {raw_voxel_size}=={target_resolution}") return array_config diff --git a/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluator.py b/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluator.py index d6ade542e..178dd8f4b 100644 --- a/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluator.py +++ b/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluator.py @@ -127,8 +127,9 @@ def evaluate(self, output_array_identifier, evaluation_array): This function is used to evaluate the output array against the evaluation array. """ output_array = ZarrArray.open_from_array_identifier(output_array_identifier) - evaluation_data = evaluation_array[evaluation_array.roi].squeeze() - output_data = output_array[output_array.roi].squeeze() + # removed the .squeeze() because it was used for batch size and now we are feeding 4d c, z, y, x + evaluation_data = evaluation_array[evaluation_array.roi] + output_data = output_array[output_array.roi] print( f"Evaluating binary segmentations on evaluation_data of shape: {evaluation_data.shape}" ) diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index aaa69e1e9..8b00ef9e3 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -124,8 +124,8 @@ def process( input_array = open_ds(self.prediction_array_identifier.container.path,self.prediction_array_identifier.dataset) def process_block(block): - print("Predicting block", block.read_roi) data = to_ndarray(input_array,block.read_roi) > parameters.threshold + data = data.astype(np.uint8) if int(data.max()) == 0: print("No data in block", block.read_roi) return diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor_parameters.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor_parameters.py index 014fc1ec2..18114a670 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor_parameters.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor_parameters.py @@ -19,3 +19,6 @@ class ThresholdPostProcessorParameters(PostProcessorParameters): """ threshold: float = attr.ib(default=0.0) + + def __str__(self): + return f"threshold_{self.threshold}" diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 372964df7..9a86303cf 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -111,7 +111,7 @@ def create_optimizer(self, model): optimizer = torch.optim.RAdam( lr=self.learning_rate, params=model.parameters(), - # decoupled_weight_decay=True, + decoupled_weight_decay=True, ) self.scheduler = torch.optim.lr_scheduler.LinearLR( optimizer, @@ -161,8 +161,6 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): mask_placeholder = gp.ArrayKey("MASK_PLACEHOLDER") target_key = gp.ArrayKey("TARGET") - dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT") - datasets_weight_key = gp.ArrayKey("DATASETS_WEIGHT") weight_key = gp.ArrayKey("WEIGHT") sample_points_key = gp.GraphKey("SAMPLE_POINTS") @@ -209,9 +207,9 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): mask_placeholder, drop_channels=True, ) - + gp.Pad(raw_key, None, mode="constant", value=0) - + gp.Pad(gt_key, None, mode="constant", value=0) - + gp.Pad(mask_key, None, mode="constant", value=0) + + gp.Pad(raw_key, None) + + gp.Pad(gt_key, None) + + gp.Pad(mask_key, None) + gp.RandomLocation( ensure_nonempty=( sample_points_key if points_source is not None else None @@ -223,18 +221,11 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): ) dataset_source += gp.Reject(mask_placeholder, 1e-6) + dataset_source += gp.Reject(gt_key,0.1) for augment in self.augments: dataset_source += augment.node(raw_key, gt_key, mask_key) - # Add predictor nodes to dataset_source - dataset_source += DaCapoTargetFilter( - task.predictor, - gt_key=gt_key, - weights_key=dataset_weight_key, - mask_key=mask_key, - ) - dataset_sources.append(dataset_source) pipeline = tuple(dataset_sources) + gp.RandomProvider(weights) @@ -243,12 +234,10 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): task.predictor, gt_key=gt_key, target_key=target_key, - weights_key=datasets_weight_key, + weights_key=weight_key, mask_key=mask_key, ) - pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key) - # Trainer attributes: if self.num_data_fetchers > 1: pipeline += gp.PreCache(num_workers=self.num_data_fetchers) @@ -363,11 +352,15 @@ def iterate(self, num_iterations, model, optimizer, device): snapshot_array_identifier = ( self.snapshot_container.array_identifier(k) ) + if v.num_channels == 1: + channels = None + else: + channels = v.num_channels ZarrArray.create_from_array_identifier( snapshot_array_identifier, v.axes, v.roi, - v.num_channels, + channels, v.voxel_size, v.dtype if not v.dtype == bool else np.float32, model.output_shape * v.voxel_size, @@ -587,4 +580,4 @@ def load_batch(event): print(viewer) load_batch(None) - input("Enter to quit!") + input("Enter to quit!") \ No newline at end of file diff --git a/dacapo/predict_crop.py b/dacapo/predict_crop.py index 6024bb1a5..33e26a42e 100644 --- a/dacapo/predict_crop.py +++ b/dacapo/predict_crop.py @@ -21,13 +21,12 @@ def predict( raw_array_identifier: LocalArrayIdentifier, prediction_array_identifier: LocalArrayIdentifier, output_roi: Optional[Roi] = None, - min_raw: float = 0, - max_raw: float = 255, ): - shift = min_raw - scale = max_raw - min_raw # get the model's input and output size - raw_array = open_ds(raw_array_identifier.container.path, raw_array_identifier.dataset) + if isinstance(raw_array_identifier, LocalArrayIdentifier): + raw_array = open_ds(raw_array_identifier.container.path, raw_array_identifier.dataset) + else: + raw_array = raw_array_identifier input_voxel_size = Coordinate(raw_array.voxel_size) output_voxel_size = model.scale(input_voxel_size) @@ -74,14 +73,12 @@ def predict( def predict_fn(block): raw_input = to_ndarray(raw_array,block.read_roi) - raw_input = 2.0 * (raw_input.astype(np.float32) - shift )/ scale - 1.0 - if len(raw_input.shape) == 3: - raw_input = np.expand_dims(raw_input, (0, 1)) + # expend batch dim + # no need to normalize, done by datasplit + raw_input = np.expand_dims(raw_input, (0, 1)) with torch.no_grad(): predictions = model.forward(torch.from_numpy(raw_input).float().to(device)).detach().cpu().numpy()[0] - predictions = (predictions + 1) * 255.0 / 2.0 - print(f"Predicting block {block.read_roi} uniques: {np.unique(predictions)}") save_ndarray(predictions, block.write_roi, result_dataset) # result_dataset[block.write_roi] = predictions diff --git a/dacapo/submit_predict.py b/dacapo/submit_predict.py new file mode 100644 index 000000000..00b3c9907 --- /dev/null +++ b/dacapo/submit_predict.py @@ -0,0 +1,79 @@ +from .predict_crop import predict + +from .experiments import Run, ValidationIterationScores +from .experiments.datasplits.datasets.arrays import ZarrArray +from dacapo.store.create_store import ( + create_array_store, + create_config_store, + create_stats_store, + create_weights_store, +) +import torch + +from pathlib import Path +import logging + +logger = logging.getLogger(__name__) + + +def full_predict(run_name: str, iteration: int = 0, datasets_config=0): + """Validate a run at a given iteration. Loads the weights from a previously + stored checkpoint. Returns the best parameters and scores for this + iteration.""" + array_store = create_array_store() + + logger.info("Validating run %s at iteration %d...", run_name, iteration) + + # create run + + config_store = create_config_store() + run_config = config_store.retrieve_run_config(run_name) + run = Run(run_config) + + # read in previous training/validation stats + + stats_store = create_stats_store() + run.training_stats = stats_store.retrieve_training_stats(run_name) + run.validation_scores.scores = stats_store.retrieve_validation_iteration_scores( + run_name + ) + + # create weights store and read weights + if iteration > 0: + weights_store = create_weights_store() + weights_store.retrieve_weights(run, iteration) + + + # set benchmark flag to True for performance + torch.backends.cudnn.benchmark = True + run.model.eval() + + if (run.datasplit.validate is None + or len(run.datasplit.validate) == 0 + or run.datasplit.validate[0].gt is None + ): + logger.error("Cannot validate run %s. Continuing training!", run.name) + return None, None + + # get array and weight store + + validation_dataset = run.datasplit.validate[datasets_config] + + assert ( + validation_dataset.gt is not None + ), "We do not yet support validating on datasets without ground truth" + logger.info( + "Validating run %s on dataset %s", run.name, validation_dataset.name + ) + + + + prediction_array_identifier = array_store.validation_prediction_array( + run.name, iteration+21, validation_dataset + ) + predict( + run.model, + validation_dataset.raw, + prediction_array_identifier, + output_roi=validation_dataset.raw.roi, + ) diff --git a/dacapo/train.py b/dacapo/train.py index a9e599970..eb0b5ffd2 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -6,7 +6,7 @@ create_weights_store, ) from dacapo.experiments import Run -from dacapo.validate import validate +from dacapo.validate import validate_run import torch from tqdm import tqdm @@ -196,7 +196,7 @@ def train_run(run: Run, do_validate=True): ) validate_thread.start() else: - validate( + validate_run( run, iteration_stats.iteration + 1, ) diff --git a/dacapo/validate.py b/dacapo/validate.py index b31a036b1..5e088b31a 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) -def validate(run_name: str, iteration: int = 0): +def validate(run_name: str, iteration: int = 0, datasets_config=None): """Validate a run at a given iteration. Loads the weights from a previously stored checkpoint. Returns the best parameters and scores for this iteration.""" @@ -42,10 +42,10 @@ def validate(run_name: str, iteration: int = 0): weights_store = create_weights_store() weights_store.retrieve_weights(run, iteration) - return validate_run(run, iteration) + return validate_run(run, iteration, datasets_config) -def validate_run(run: Run, iteration: int): +def validate_run(run: Run, iteration: int, datasets_config=None): """Validate an already loaded run at the given iteration. This does not load the weights of that iteration, it is assumed that the model is already loaded correctly. Returns the best parameters and scores for this @@ -54,16 +54,14 @@ def validate_run(run: Run, iteration: int): torch.backends.cudnn.benchmark = True run.model.eval() - if ( - run.datasplit.validate is None + if (run.datasplit.validate is None or len(run.datasplit.validate) == 0 or run.datasplit.validate[0].gt is None ): - logger.info("Cannot validate run %s. Continuing training!", run.name) + logger.error("Cannot validate run %s. Continuing training!", run.name) return None, None # get array and weight store - weights_store = create_weights_store() array_store = create_array_store() iteration_scores = [] @@ -71,9 +69,25 @@ def validate_run(run: Run, iteration: int): post_processor = run.task.post_processor evaluator = run.task.evaluator + input_voxel_size = run.datasplit.train[0].raw.voxel_size + output_voxel_size = run.model.scale(input_voxel_size) + # Initialize the evaluator with the best scores seen so far - evaluator.set_best(run.validation_scores) - for validation_dataset in run.datasplit.validate: + # evaluator.set_best(run.validation_scores) + if datasets_config is None: + datasets = run.datasplit.validate + else: + from dacapo.experiments.datasplits import DataSplitGenerator + datasplit_config = DataSplitGenerator( + "", + datasets_config, + input_voxel_size, + output_voxel_size, + targets = run.task.evaluator.channels + ).compute().validate_configs + datasets = [validate_config.dataset_type(validate_config) for validate_config in datasplit_config] + + for validation_dataset in datasets: assert ( validation_dataset.gt is not None ), "We do not yet support validating on datasets without ground truth" @@ -94,8 +108,7 @@ def validate_run(run: Run, iteration: int): ).exists() ): logger.info("Copying validation inputs!") - input_voxel_size = validation_dataset.raw.voxel_size - output_voxel_size = run.model.scale(input_voxel_size) + input_shape = run.model.eval_input_shape input_size = input_voxel_size * input_shape output_shape = run.model.compute_output_shape(input_shape)[1] @@ -158,49 +171,49 @@ def validate_run(run: Run, iteration: int): scores = evaluator.evaluate(output_array_identifier, validation_dataset.gt) - for criterion in run.validation_scores.criteria: - # replace predictions in array with the new better predictions - if evaluator.is_best( - validation_dataset, - parameters, - criterion, - scores, - ): - best_array_identifier = array_store.best_validation_array( - run.name, criterion, index=validation_dataset.name - ) - best_array = ZarrArray.create_from_array_identifier( - best_array_identifier, - post_processed_array.axes, - post_processed_array.roi, - post_processed_array.num_channels, - post_processed_array.voxel_size, - post_processed_array.dtype, - ) - best_array[best_array.roi] = post_processed_array[ - post_processed_array.roi - ] - best_array.add_metadata( - { - "iteration": iteration, - criterion: getattr(scores, criterion), - "parameters_id": parameters.id, - } - ) - weights_store.store_best( - run, iteration, validation_dataset.name, criterion - ) + # for criterion in run.validation_scores.criteria: + # # replace predictions in array with the new better predictions + # if evaluator.is_best( + # validation_dataset, + # parameters, + # criterion, + # scores, + # ): + # best_array_identifier = array_store.best_validation_array( + # run.name, criterion, index=validation_dataset.name + # ) + # best_array = ZarrArray.create_from_array_identifier( + # best_array_identifier, + # post_processed_array.axes, + # post_processed_array.roi, + # post_processed_array.num_channels, + # post_processed_array.voxel_size, + # post_processed_array.dtype, + # ) + # best_array[best_array.roi] = post_processed_array[ + # post_processed_array.roi + # ] + # best_array.add_metadata( + # { + # "iteration": iteration, + # criterion: getattr(scores, criterion), + # "parameters_id": parameters.id, + # } + # ) + # weights_store.store_best( + # run, iteration, validation_dataset.name, criterion + # ) # delete current output. We only keep the best outputs as determined by # the evaluator - array_store.remove(output_array_identifier) + # array_store.remove(output_array_identifier) dataset_iteration_scores.append( [getattr(scores, criterion) for criterion in scores.criteria] ) iteration_scores.append(dataset_iteration_scores) - array_store.remove(prediction_array_identifier) + # array_store.remove(prediction_array_identifier) run.validation_scores.add_iteration_scores( ValidationIterationScores(iteration, iteration_scores) From 1eca4e8255f0f523785781d727b7dca92bc9a62f Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Thu, 26 Sep 2024 17:12:19 -0400 Subject: [PATCH 04/17] local changes --- dacapo/predict_crop.py | 4 +- dacapo/submit_predict.py | 90 ++++++++++++++++++++++++++++++---------- dacapo/validate.py | 13 ++++-- 3 files changed, 81 insertions(+), 26 deletions(-) diff --git a/dacapo/predict_crop.py b/dacapo/predict_crop.py index b1bb31f42..a53f38a5a 100644 --- a/dacapo/predict_crop.py +++ b/dacapo/predict_crop.py @@ -55,7 +55,7 @@ def predict( output_roi, num_channels, output_voxel_size, - np.float32, + np.uint8, ) logger.info("Total input ROI: %s, output ROI: %s", input_size, output_roi) @@ -76,6 +76,8 @@ def predict_fn(block): with torch.no_grad(): predictions = model.forward(torch.from_numpy(raw_input).float().to(device)).detach().cpu().numpy()[0] predictions = (predictions + 1) * 255.0 / 2.0 + predictions[predictions> 250] = 0 + predictions = np.round(predictions).astype(np.uint8) save_ndarray(predictions, block.write_roi, result_dataset) # result_dataset[block.write_roi] = predictions diff --git a/dacapo/submit_predict.py b/dacapo/submit_predict.py index 00b3c9907..178a90d7b 100644 --- a/dacapo/submit_predict.py +++ b/dacapo/submit_predict.py @@ -12,15 +12,14 @@ from pathlib import Path import logging - +from dacapo.compute_context import create_compute_context logger = logging.getLogger(__name__) -def full_predict(run_name: str, iteration: int = 0, datasets_config=0): +def full_predict(run_name: str, iteration: int , roi): """Validate a run at a given iteration. Loads the weights from a previously stored checkpoint. Returns the best parameters and scores for this iteration.""" - array_store = create_array_store() logger.info("Validating run %s at iteration %d...", run_name, iteration) @@ -30,6 +29,10 @@ def full_predict(run_name: str, iteration: int = 0, datasets_config=0): run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) + compute_context = create_compute_context() + device = compute_context.device + run.model.to(device) + # read in previous training/validation stats stats_store = create_stats_store() @@ -41,9 +44,17 @@ def full_predict(run_name: str, iteration: int = 0, datasets_config=0): # create weights store and read weights if iteration > 0: weights_store = create_weights_store() - weights_store.retrieve_weights(run, iteration) + weights = weights_store.retrieve_weights(run, iteration) + run.model.load_state_dict(weights.model) + return full_predict_run(run, iteration,roi) + +def full_predict_run(run: Run, iteration: int,roi): + """Validate an already loaded run at the given iteration. This does not + load the weights of that iteration, it is assumed that the model is already + loaded correctly. Returns the best parameters and scores for this + iteration.""" # set benchmark flag to True for performance torch.backends.cudnn.benchmark = True run.model.eval() @@ -56,24 +67,61 @@ def full_predict(run_name: str, iteration: int = 0, datasets_config=0): return None, None # get array and weight store + array_store = create_array_store() - validation_dataset = run.datasplit.validate[datasets_config] + input_voxel_size = run.datasplit.train[0].raw.voxel_size + output_voxel_size = run.model.scale(input_voxel_size) - assert ( - validation_dataset.gt is not None - ), "We do not yet support validating on datasets without ground truth" - logger.info( - "Validating run %s on dataset %s", run.name, validation_dataset.name - ) + # Initialize the evaluator with the best scores seen so far + # evaluator.set_best(run.validation_scores) - + datasets = run.datasplit.validate + for validation_dataset in datasets: + assert ( + validation_dataset.gt is not None + ), "We do not yet support validating on datasets without ground truth" + logger.info( + "Validating run %s on dataset %s", run.name, validation_dataset.name + ) - prediction_array_identifier = array_store.validation_prediction_array( - run.name, iteration+21, validation_dataset - ) - predict( - run.model, - validation_dataset.raw, - prediction_array_identifier, - output_roi=validation_dataset.raw.roi, - ) + ( + input_raw_array_identifier, + input_gt_array_identifier, + ) = array_store.validation_input_arrays(run.name, validation_dataset.name) + + logger.info("Copying validation inputs!") + + input_shape = run.model.eval_input_shape + input_size = input_voxel_size * input_shape + output_shape = run.model.compute_output_shape(input_shape)[1] + output_size = output_voxel_size * output_shape + context = (input_size - output_size) / 2 + output_roi = roi + + input_roi = ( + output_roi.grow(context, context) + .snap_to_grid(validation_dataset.raw.voxel_size, mode="grow") + .intersect(validation_dataset.raw.roi) + ) + input_raw_array_identifier.dataset = "tmp_"+input_raw_array_identifier.dataset + input_raw = ZarrArray.create_from_array_identifier( + input_raw_array_identifier, + validation_dataset.raw.axes, + input_roi, + validation_dataset.raw.num_channels, + validation_dataset.raw.voxel_size, + validation_dataset.raw.dtype, + name=f"{run.name}_validation_raw", + write_size=input_size, + ) + input_raw[input_roi] = validation_dataset.raw[input_roi] + + prediction_array_identifier = array_store.validation_prediction_array( + run.name, iteration+224, validation_dataset + ) + predict( + run.model, + input_raw_array_identifier, + prediction_array_identifier, + output_roi=roi, + ) \ No newline at end of file diff --git a/dacapo/validate.py b/dacapo/validate.py index 5e088b31a..e90fa8984 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -12,7 +12,7 @@ from pathlib import Path import logging - +from dacapo.compute_context import create_compute_context logger = logging.getLogger(__name__) @@ -29,6 +29,10 @@ def validate(run_name: str, iteration: int = 0, datasets_config=None): run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) + compute_context = create_compute_context() + device = compute_context.device + run.model.to(device) + # read in previous training/validation stats stats_store = create_stats_store() @@ -40,7 +44,8 @@ def validate(run_name: str, iteration: int = 0, datasets_config=None): # create weights store and read weights if iteration > 0: weights_store = create_weights_store() - weights_store.retrieve_weights(run, iteration) + weights = weights_store.retrieve_weights(run, iteration) + run.model.load_state_dict(weights.model) return validate_run(run, iteration, datasets_config) @@ -147,7 +152,7 @@ def validate_run(run: Run, iteration: int, datasets_config=None): logger.info("validation inputs already copied!") prediction_array_identifier = array_store.validation_prediction_array( - run.name, iteration, validation_dataset + run.name, iteration+3, validation_dataset ) predict( run.model, @@ -162,7 +167,7 @@ def validate_run(run: Run, iteration: int, datasets_config=None): for parameters in post_processor.enumerate_parameters(): output_array_identifier = array_store.validation_output_array( - run.name, iteration, parameters, validation_dataset + run.name, iteration+3, parameters, validation_dataset ) post_processed_array = post_processor.process( From 4520c3f462eabe3f108731ea8f62dcd66ff077f5 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 2 Oct 2024 11:40:34 -0400 Subject: [PATCH 05/17] remove plot debug prints --- dacapo/plot.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/dacapo/plot.py b/dacapo/plot.py index d5bfe1d28..4795a4198 100644 --- a/dacapo/plot.py +++ b/dacapo/plot.py @@ -148,9 +148,7 @@ def bokeh_plot_runs( >>> plot_runs(["run_name"], 100, None, None, [True]) """ - print("PLOTTING RUNS") runs = get_runs_info(run_config_base_names, validation_scores, plot_losses) - print("GOT RUNS INFO") colors = itertools.cycle(palette[20]) loss_tooltips = [ @@ -213,7 +211,6 @@ def bokeh_plot_runs( validation_figure.background_fill_color = "#efefef" validation_figures[dataset.name] = validation_figure - print("VALIDATION SCORES TOOLTIP MADE") summary_tooltips = [ ("run", "@run"), @@ -246,16 +243,12 @@ def bokeh_plot_runs( iterations = [stat.iteration for stat in run.training_stats.iteration_stats] losses = [stat.loss for stat in run.training_stats.iteration_stats] - print(f"Run {run.name} has {len(losses)} iterations") if run.plot_loss: include_loss_figure = True smooth = int(np.maximum(len(iterations) / 2500, 1)) - print(f"smoothing: {smooth}") x, _ = smooth_values(iterations, smooth, stride=smooth) y, s = smooth_values(losses, smooth, stride=smooth) - print(x, y) - print(f"plotting {(len(x), len(y))} points") source = bokeh.plotting.ColumnDataSource( { "iteration": x, @@ -284,8 +277,6 @@ def bokeh_plot_runs( alpha=0.3, ) - print("LOSS PLOTTED") - if run.validation_score_name and run.validation_scores.validated_until() > 0: validation_score_data = run.validation_scores.to_xarray().sel( criteria=run.validation_score_name @@ -331,7 +322,6 @@ def bokeh_plot_runs( color=color, alpha=0.7, ) - print("VALIDATION PLOTTED") # Styling # training @@ -383,7 +373,6 @@ def bokeh_plot_runs( plot = bokeh.layouts.column(*figures) plot.sizing_mode = "scale_width" - print("PLOTTING DONE") if return_json: print("Returning JSON") return json.dumps(json_item(plot, "myplot")) @@ -410,9 +399,7 @@ def plot_runs( Returns: None """ - print("PLOTTING RUNS") runs = get_runs_info(run_config_base_names, validation_scores, plot_losses) - print("GOT RUNS INFO") colors = itertools.cycle(plt.cm.tab20.colors) include_validation_figure = False @@ -429,18 +416,13 @@ def plot_runs( iterations = [stat.iteration for stat in run.training_stats.iteration_stats] losses = [stat.loss for stat in run.training_stats.iteration_stats] - print(f"Run {run.name} has {len(losses)} iterations") if run.plot_loss: include_loss_figure = True smooth = int(np.maximum(len(iterations) / 2500, 1)) - print(f"smoothing: {smooth}") x, _ = smooth_values(iterations, smooth, stride=smooth) y, s = smooth_values(losses, smooth, stride=smooth) - print(x, y) - print(f"plotting {(len(x), len(y))} points") loss_ax.plot(x, y, label=name, color=color) - print("LOSS PLOTTED") if run.validation_score_name and run.validation_scores.validated_until() > 0: validation_score_data = run.validation_scores.to_xarray().sel( @@ -463,7 +445,6 @@ def plot_runs( color=cc, alpha=0.5 + 0.2 * i, ) - print("VALIDATION PLOTTED") if include_loss_figure: loss_ax.set_title("Training") From aca7f684348ea4c495e88b4a14119721a3968d9d Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 2 Oct 2024 11:42:35 -0400 Subject: [PATCH 06/17] env DACAPO_OPTIONS_FILE --- dacapo/options.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dacapo/options.py b/dacapo/options.py index e52a0c57a..13afe6e68 100644 --- a/dacapo/options.py +++ b/dacapo/options.py @@ -133,8 +133,8 @@ def config_file(cls) -> Optional[Path]: PosixPath('/home/user/.config/dacapo/dacapo.yaml') """ env_dict = dict(os.environ) - if "OPTIONS_FILE" in env_dict: - options_files = [Path(env_dict["OPTIONS_FILE"])] + if "DACAPO_OPTIONS_FILE" in env_dict: + options_files = [Path(env_dict["DACAPO_OPTIONS_FILE"])] else: options_files = [] @@ -147,7 +147,7 @@ def config_file(cls) -> Optional[Path]: ] for path in options_files: if path.exists(): - os.environ["OPTIONS_FILE"] = str(path) + os.environ["DACAPO_OPTIONS_FILE"] = str(path) return path return None From 9dfbdc1f6c6f41b975de6505502aa615acc573e5 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 2 Oct 2024 11:43:20 -0400 Subject: [PATCH 07/17] remove trash --- dacapo/submit_predict.py | 127 --------------------------------------- 1 file changed, 127 deletions(-) delete mode 100644 dacapo/submit_predict.py diff --git a/dacapo/submit_predict.py b/dacapo/submit_predict.py deleted file mode 100644 index 178a90d7b..000000000 --- a/dacapo/submit_predict.py +++ /dev/null @@ -1,127 +0,0 @@ -from .predict_crop import predict - -from .experiments import Run, ValidationIterationScores -from .experiments.datasplits.datasets.arrays import ZarrArray -from dacapo.store.create_store import ( - create_array_store, - create_config_store, - create_stats_store, - create_weights_store, -) -import torch - -from pathlib import Path -import logging -from dacapo.compute_context import create_compute_context -logger = logging.getLogger(__name__) - - -def full_predict(run_name: str, iteration: int , roi): - """Validate a run at a given iteration. Loads the weights from a previously - stored checkpoint. Returns the best parameters and scores for this - iteration.""" - - logger.info("Validating run %s at iteration %d...", run_name, iteration) - - # create run - - config_store = create_config_store() - run_config = config_store.retrieve_run_config(run_name) - run = Run(run_config) - - compute_context = create_compute_context() - device = compute_context.device - run.model.to(device) - - # read in previous training/validation stats - - stats_store = create_stats_store() - run.training_stats = stats_store.retrieve_training_stats(run_name) - run.validation_scores.scores = stats_store.retrieve_validation_iteration_scores( - run_name - ) - - # create weights store and read weights - if iteration > 0: - weights_store = create_weights_store() - weights = weights_store.retrieve_weights(run, iteration) - run.model.load_state_dict(weights.model) - - return full_predict_run(run, iteration,roi) - - -def full_predict_run(run: Run, iteration: int,roi): - """Validate an already loaded run at the given iteration. This does not - load the weights of that iteration, it is assumed that the model is already - loaded correctly. Returns the best parameters and scores for this - iteration.""" - # set benchmark flag to True for performance - torch.backends.cudnn.benchmark = True - run.model.eval() - - if (run.datasplit.validate is None - or len(run.datasplit.validate) == 0 - or run.datasplit.validate[0].gt is None - ): - logger.error("Cannot validate run %s. Continuing training!", run.name) - return None, None - - # get array and weight store - array_store = create_array_store() - - input_voxel_size = run.datasplit.train[0].raw.voxel_size - output_voxel_size = run.model.scale(input_voxel_size) - - # Initialize the evaluator with the best scores seen so far - # evaluator.set_best(run.validation_scores) - - datasets = run.datasplit.validate - for validation_dataset in datasets: - assert ( - validation_dataset.gt is not None - ), "We do not yet support validating on datasets without ground truth" - logger.info( - "Validating run %s on dataset %s", run.name, validation_dataset.name - ) - - ( - input_raw_array_identifier, - input_gt_array_identifier, - ) = array_store.validation_input_arrays(run.name, validation_dataset.name) - - logger.info("Copying validation inputs!") - - input_shape = run.model.eval_input_shape - input_size = input_voxel_size * input_shape - output_shape = run.model.compute_output_shape(input_shape)[1] - output_size = output_voxel_size * output_shape - context = (input_size - output_size) / 2 - output_roi = roi - - input_roi = ( - output_roi.grow(context, context) - .snap_to_grid(validation_dataset.raw.voxel_size, mode="grow") - .intersect(validation_dataset.raw.roi) - ) - input_raw_array_identifier.dataset = "tmp_"+input_raw_array_identifier.dataset - input_raw = ZarrArray.create_from_array_identifier( - input_raw_array_identifier, - validation_dataset.raw.axes, - input_roi, - validation_dataset.raw.num_channels, - validation_dataset.raw.voxel_size, - validation_dataset.raw.dtype, - name=f"{run.name}_validation_raw", - write_size=input_size, - ) - input_raw[input_roi] = validation_dataset.raw[input_roi] - - prediction_array_identifier = array_store.validation_prediction_array( - run.name, iteration+224, validation_dataset - ) - predict( - run.model, - input_raw_array_identifier, - prediction_array_identifier, - output_roi=roi, - ) \ No newline at end of file From 5de82f2a6942fac77cf6fcf974279fe9e9e614b3 Mon Sep 17 00:00:00 2001 From: mzouink Date: Wed, 2 Oct 2024 15:43:52 +0000 Subject: [PATCH 08/17] :art: Format Python code with psf/black --- dacapo/__init__.py | 2 +- .../architectures/cnnectome_unet.py | 20 ++++++++--- .../datasplits/datasplit_generator.py | 4 ++- .../threshold_post_processor.py | 2 +- .../experiments/trainers/gunpowder_trainer.py | 4 +-- dacapo/plot.py | 3 -- dacapo/predict_crop.py | 15 +++++--- dacapo/validate.py | 34 ++++++++++++------- 8 files changed, 55 insertions(+), 29 deletions(-) diff --git a/dacapo/__init__.py b/dacapo/__init__.py index 9051c1b9e..b8b6960d1 100644 --- a/dacapo/__init__.py +++ b/dacapo/__init__.py @@ -8,4 +8,4 @@ from .validate import validate, validate_run # noqa from .predict import predict # noqa from .blockwise import run_blockwise, segment_blockwise # noqa -from .submit_predict import full_predict # noqa \ No newline at end of file +from .submit_predict import full_predict # noqa diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index dbd084c4c..32c4b822b 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -827,7 +827,13 @@ class ConvPass(torch.nn.Module): """ def __init__( - self, in_channels, out_channels, kernel_sizes, activation, padding="valid", batch_normalize=True + self, + in_channels, + out_channels, + kernel_sizes, + activation, + padding="valid", + batch_normalize=True, ): """ Convolutional pass module. This module performs a series of @@ -870,10 +876,14 @@ def __init__( try: layers.append(conv(in_channels, out_channels, kernel_size, padding=pad)) if batch_normalize: - layers.append({ - 2: torch.nn.BatchNorm2d, - 3: torch.nn.BatchNorm3d, - }[self.dims](out_channels)) + layers.append( + { + 2: torch.nn.BatchNorm2d, + 3: torch.nn.BatchNorm3d, + }[ + self.dims + ](out_channels) + ) except KeyError: raise RuntimeError("%dD convolution not implemented" % self.dims) diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py index a68cf3c97..798559d50 100644 --- a/dacapo/experiments/datasplits/datasplit_generator.py +++ b/dacapo/experiments/datasplits/datasplit_generator.py @@ -82,7 +82,9 @@ def resize_if_needed( f"have different dimensions {zarr_array.dims}" ) if any([u > 1 or d > 1 for u, d in zip(raw_upsample, raw_downsample)]): - print(f"dataset {array_config} needs resampling to {target_resolution}, upsample: {raw_upsample}, downsample: {raw_downsample}") + print( + f"dataset {array_config} needs resampling to {target_resolution}, upsample: {raw_upsample}, downsample: {raw_downsample}" + ) return ResampledArrayConfig( name=f"{extra_str}_{array_config.name}_{array_config.dataset}_resampled", source_array_config=array_config, diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index a0f5541c4..2cf719d44 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -126,7 +126,7 @@ def process( ) def process_block(block): - data = to_ndarray(input_array,block.read_roi) > parameters.threshold + data = to_ndarray(input_array, block.read_roi) > parameters.threshold data = data.astype(np.uint8) if int(data.max()) == 0: print("No data in block", block.read_roi) diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 9a86303cf..e7705b4e1 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -221,7 +221,7 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): ) dataset_source += gp.Reject(mask_placeholder, 1e-6) - dataset_source += gp.Reject(gt_key,0.1) + dataset_source += gp.Reject(gt_key, 0.1) for augment in self.augments: dataset_source += augment.node(raw_key, gt_key, mask_key) @@ -580,4 +580,4 @@ def load_batch(event): print(viewer) load_batch(None) - input("Enter to quit!") \ No newline at end of file + input("Enter to quit!") diff --git a/dacapo/plot.py b/dacapo/plot.py index 4795a4198..8bcdc9dc4 100644 --- a/dacapo/plot.py +++ b/dacapo/plot.py @@ -211,7 +211,6 @@ def bokeh_plot_runs( validation_figure.background_fill_color = "#efefef" validation_figures[dataset.name] = validation_figure - summary_tooltips = [ ("run", "@run"), ("task", "@task"), @@ -243,7 +242,6 @@ def bokeh_plot_runs( iterations = [stat.iteration for stat in run.training_stats.iteration_stats] losses = [stat.loss for stat in run.training_stats.iteration_stats] - if run.plot_loss: include_loss_figure = True smooth = int(np.maximum(len(iterations) / 2500, 1)) @@ -416,7 +414,6 @@ def plot_runs( iterations = [stat.iteration for stat in run.training_stats.iteration_stats] losses = [stat.loss for stat in run.training_stats.iteration_stats] - if run.plot_loss: include_loss_figure = True smooth = int(np.maximum(len(iterations) / 2500, 1)) diff --git a/dacapo/predict_crop.py b/dacapo/predict_crop.py index a53f38a5a..8a47bd4f6 100644 --- a/dacapo/predict_crop.py +++ b/dacapo/predict_crop.py @@ -24,7 +24,9 @@ def predict( ): # get the model's input and output size if isinstance(raw_array_identifier, LocalArrayIdentifier): - raw_array = open_ds(raw_array_identifier.container.path, raw_array_identifier.dataset) + raw_array = open_ds( + raw_array_identifier.container.path, raw_array_identifier.dataset + ) else: raw_array = raw_array_identifier input_voxel_size = Coordinate(raw_array.voxel_size) @@ -69,14 +71,19 @@ def predict( device = compute_context.device def predict_fn(block): - raw_input = to_ndarray(raw_array,block.read_roi) + raw_input = to_ndarray(raw_array, block.read_roi) # expend batch dim # no need to normalize, done by datasplit raw_input = np.expand_dims(raw_input, (0, 1)) with torch.no_grad(): - predictions = model.forward(torch.from_numpy(raw_input).float().to(device)).detach().cpu().numpy()[0] + predictions = ( + model.forward(torch.from_numpy(raw_input).float().to(device)) + .detach() + .cpu() + .numpy()[0] + ) predictions = (predictions + 1) * 255.0 / 2.0 - predictions[predictions> 250] = 0 + predictions[predictions > 250] = 0 predictions = np.round(predictions).astype(np.uint8) save_ndarray(predictions, block.write_roi, result_dataset) diff --git a/dacapo/validate.py b/dacapo/validate.py index e90fa8984..4e7902a19 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -13,6 +13,7 @@ from pathlib import Path import logging from dacapo.compute_context import create_compute_context + logger = logging.getLogger(__name__) @@ -59,7 +60,8 @@ def validate_run(run: Run, iteration: int, datasets_config=None): torch.backends.cudnn.benchmark = True run.model.eval() - if (run.datasplit.validate is None + if ( + run.datasplit.validate is None or len(run.datasplit.validate) == 0 or run.datasplit.validate[0].gt is None ): @@ -83,14 +85,22 @@ def validate_run(run: Run, iteration: int, datasets_config=None): datasets = run.datasplit.validate else: from dacapo.experiments.datasplits import DataSplitGenerator - datasplit_config = DataSplitGenerator( - "", - datasets_config, - input_voxel_size, - output_voxel_size, - targets = run.task.evaluator.channels - ).compute().validate_configs - datasets = [validate_config.dataset_type(validate_config) for validate_config in datasplit_config] + + datasplit_config = ( + DataSplitGenerator( + "", + datasets_config, + input_voxel_size, + output_voxel_size, + targets=run.task.evaluator.channels, + ) + .compute() + .validate_configs + ) + datasets = [ + validate_config.dataset_type(validate_config) + for validate_config in datasplit_config + ] for validation_dataset in datasets: assert ( @@ -113,7 +123,7 @@ def validate_run(run: Run, iteration: int, datasets_config=None): ).exists() ): logger.info("Copying validation inputs!") - + input_shape = run.model.eval_input_shape input_size = input_voxel_size * input_shape output_shape = run.model.compute_output_shape(input_shape)[1] @@ -152,7 +162,7 @@ def validate_run(run: Run, iteration: int, datasets_config=None): logger.info("validation inputs already copied!") prediction_array_identifier = array_store.validation_prediction_array( - run.name, iteration+3, validation_dataset + run.name, iteration + 3, validation_dataset ) predict( run.model, @@ -167,7 +177,7 @@ def validate_run(run: Run, iteration: int, datasets_config=None): for parameters in post_processor.enumerate_parameters(): output_array_identifier = array_store.validation_output_array( - run.name, iteration+3, parameters, validation_dataset + run.name, iteration + 3, parameters, validation_dataset ) post_processed_array = post_processor.process( From 2775317e20fd711e3482bde2e511d213521f14e9 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 2 Oct 2024 11:59:44 -0400 Subject: [PATCH 09/17] fix restart run --- dacapo/train.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/dacapo/train.py b/dacapo/train.py index eb0b5ffd2..bfd06eeb1 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -82,6 +82,7 @@ def train_run(run: Run, do_validate=True): weights_store = create_weights_store() latest_weights_iteration = weights_store.latest_iteration(run) + weights = None if trained_until > 0: if latest_weights_iteration is None: @@ -104,19 +105,21 @@ def train_run(run: Run, do_validate=True): trained_until = latest_weights_iteration run.training_stats.delete_after(trained_until) run.validation_scores.delete_after(trained_until) - weights_store.retrieve_weights(run, iteration=trained_until) + weights = weights_store.retrieve_weights(run, iteration=trained_until) elif latest_weights_iteration == trained_until: print(f"Resuming training from iteration {trained_until}") - weights_store.retrieve_weights(run, iteration=trained_until) + weights = weights_store.retrieve_weights(run, iteration=trained_until) elif latest_weights_iteration > trained_until: - weights_store.retrieve_weights(run, iteration=latest_weights_iteration) + weights = weights_store.retrieve_weights(run, iteration=latest_weights_iteration) logger.error( f"Found weights for iteration {latest_weights_iteration}, but " f"run {run.name} was only trained until {trained_until}. " ) + if weights is not None: + run.model.load_state_dict(weights.model) # start/resume training From 35d864456e6150046e77a534cf22056dd1fa0072 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 2 Oct 2024 12:12:53 -0400 Subject: [PATCH 10/17] batch norm params --- dacapo/__init__.py | 1 - .../architectures/cnnectome_unet.py | 20 +++++++++++++++---- .../architectures/cnnectome_unet_config.py | 4 ++++ 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/dacapo/__init__.py b/dacapo/__init__.py index b8b6960d1..f54a1e06d 100644 --- a/dacapo/__init__.py +++ b/dacapo/__init__.py @@ -8,4 +8,3 @@ from .validate import validate, validate_run # noqa from .predict import predict # noqa from .blockwise import run_blockwise, segment_blockwise # noqa -from .submit_predict import full_predict # noqa diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index 32c4b822b..d1bcd208c 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -130,6 +130,7 @@ def __init__(self, architecture_config): activation after the upsample operation. - use_attention (optional): Whether or not to use an attention block in the U-Net. + - batch_norm (optional): Whether to use batch normalization. Raises: ValueError: If the input shape is not given. Examples: @@ -170,6 +171,7 @@ def __init__(self, architecture_config): self.upsample_factors if self.upsample_factors is not None else [] ) self.use_attention = architecture_config.use_attention + self.batch_norm = architecture_config.batch_norm self.unet = self.module() @@ -261,6 +263,7 @@ def module(self): upsample_channel_contraction=[False] + [True] * (len(downsample_factors) - 1), use_attention=self.use_attention, + batch_norm=self.batch_norm, ) if len(self.upsample_factors) > 0: layers = [unet] @@ -279,6 +282,7 @@ def module(self): self.fmaps_out, [(3,) * len(upsample_factor)] * 2, activation="ReLU", + batch_norm= self.batch_norm, ) layers.append(conv) unet = torch.nn.Sequential(*layers) @@ -455,6 +459,7 @@ def __init__( upsample_channel_contraction=False, activation_on_upsample=False, use_attention=False, + batch_norm=True, ): """ Create a U-Net:: @@ -573,6 +578,7 @@ def __init__( self.dims = len(downsample_factors[0]) self.use_attention = use_attention + self.batch_norm = batch_norm # default arguments @@ -611,6 +617,7 @@ def __init__( kernel_size_down[level], activation=activation, padding=padding, + batch_norm=self.batch_norm, ) for level in range(self.num_levels) ] @@ -668,6 +675,7 @@ def __init__( ), dims=self.dims, upsample_factor=downsample_factors[level], + batch_norm=self.batch_norm, ) for level in range(self.num_levels - 1) ] @@ -694,6 +702,7 @@ def __init__( kernel_size_up[level], activation=activation, padding=padding, + batch_norm=self.batch_norm, ) for level in range(self.num_levels - 1) ] @@ -833,7 +842,7 @@ def __init__( kernel_sizes, activation, padding="valid", - batch_normalize=True, + batch_norm=True, ): """ Convolutional pass module. This module performs a series of @@ -875,7 +884,7 @@ def __init__( try: layers.append(conv(in_channels, out_channels, kernel_size, padding=pad)) - if batch_normalize: + if batch_norm: layers.append( { 2: torch.nn.BatchNorm2d, @@ -1298,7 +1307,7 @@ class AttentionBlockModule(nn.Module): The AttentionBlockModule is an instance of the ``torch.nn.Module`` class. """ - def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): + def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None,batch_norm=True): """ Initialize the Attention Block Module. @@ -1321,13 +1330,14 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): super(AttentionBlockModule, self).__init__() self.dims = dims self.kernel_sizes = [(1,) * self.dims, (1,) * self.dims] + self.batch_norm = batch_norm if upsample_factor is not None: self.upsample_factor = upsample_factor else: self.upsample_factor = (2,) * self.dims self.W_g = ConvPass( - F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same" + F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same",batch_norm=self.batch_norm ) self.W_x = nn.Sequential( @@ -1337,6 +1347,7 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): kernel_sizes=self.kernel_sizes, activation=None, padding="same", + batch_norm=self.batch_norm ), Downsample(upsample_factor), ) @@ -1347,6 +1358,7 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): kernel_sizes=self.kernel_sizes, activation="Sigmoid", padding="same", + batch_norm=self.batch_norm, ) up_mode = {2: "bilinear", 3: "trilinear"}[self.dims] diff --git a/dacapo/experiments/architectures/cnnectome_unet_config.py b/dacapo/experiments/architectures/cnnectome_unet_config.py index 77905d79c..7eab80115 100644 --- a/dacapo/experiments/architectures/cnnectome_unet_config.py +++ b/dacapo/experiments/architectures/cnnectome_unet_config.py @@ -127,3 +127,7 @@ class CNNectomeUNetConfig(ArchitectureConfig): "help_text": "Whether to use attention blocks in the UNet. This is supported for 2D and 3D." }, ) + batch_norm: bool = attr.ib( + default=True, + metadata={"help_text": "Whether to use batch normalization."}, + ) From d372cd8802fb68f3cfbe2f6ad6d74cb7378f0350 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 2 Oct 2024 12:34:47 -0400 Subject: [PATCH 11/17] revert changes --- .../datasplits/datasets/arrays/zarr_array.py | 76 ++++++++++++++----- 1 file changed, 59 insertions(+), 17 deletions(-) diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 4379f153f..120f17788 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -13,9 +13,6 @@ from collections import OrderedDict import logging -from upath import UPath as Path -import os -import json from typing import Dict, Tuple, Any, Optional, List logger = logging.getLogger(__name__) @@ -429,12 +426,32 @@ def create_from_array_identifier( num_channels, voxel_size, dtype, + mode="a", write_size=None, name=None, + overwrite=False, ): """ Create a new ZarrArray given an array identifier. It is assumed that - this array_identifier points to a dataset that does not yet exist + this array_identifier points to a dataset that does not yet exist. + Args: + array_identifier (ArrayIdentifier): The array identifier. + axes (List[str]): The axes of the array. + roi (Roi): The region of interest. + num_channels (int): The number of channels. + voxel_size (Coordinate): The voxel size. + dtype (Any): The data type. + write_size (Optional[Coordinate]): The write size. + name (Optional[str]): The name of the array. + overwrite (bool): The boolean value to overwrite the array. + Returns: + ZarrArray: The ZarrArray. + Raises: + NotImplementedError + Examples: + >>> create_from_array_identifier(array_identifier, axes, roi, num_channels, voxel_size, dtype, write_size=None, name=None, overwrite=False) + Notes: + This method is used to create a new ZarrArray given an array identifier. """ if write_size is None: # total storage per block is approx c*x*y*z*dtype_size @@ -451,6 +468,11 @@ def create_from_array_identifier( write_size = Coordinate((axis_length,) * voxel_size.dims) * voxel_size write_size = Coordinate((min(a, b) for a, b in zip(write_size, roi.shape))) zarr_container = zarr.open(array_identifier.container, "a") + if num_channels is None: + axes = [axis for axis in axes if "c" not in axis] + num_channels = None + else: + axes = ["c"] + [axis for axis in axes if "c" not in axis] try: funlib.persistence.prepare_ds( f"{array_identifier.container}", @@ -460,21 +482,41 @@ def create_from_array_identifier( dtype, num_channels=num_channels, write_size=write_size, + delete=overwrite, + force_exact_write_size=True, ) zarr_dataset = zarr_container[array_identifier.dataset] - zarr_dataset.attrs["offset"] = ( - roi.offset[::-1] - if array_identifier.container.name.endswith("n5") - else roi.offset - ) - zarr_dataset.attrs["resolution"] = ( - voxel_size[::-1] - if array_identifier.container.name.endswith("n5") - else voxel_size - ) - zarr_dataset.attrs["axes"] = ( - axes[::-1] if array_identifier.container.name.endswith("n5") else axes - ) + if array_identifier.container.name.endswith("n5"): + zarr_dataset.attrs["offset"] = roi.offset[::-1] + zarr_dataset.attrs["resolution"] = voxel_size[::-1] + zarr_dataset.attrs["axes"] = axes[::-1] + # to make display right in neuroglancer: TODO ADD CHANNELS + zarr_dataset.attrs["dimension_units"] = [ + f"{size} nm" for size in voxel_size[::-1] + ] + zarr_dataset.attrs["_ARRAY_DIMENSIONS"] = [ + a if a != "c" else "c^" for a in axes[::-1] + ] + else: + zarr_dataset.attrs["offset"] = roi.offset + zarr_dataset.attrs["resolution"] = voxel_size + zarr_dataset.attrs["axes"] = axes + # to make display right in neuroglancer: TODO ADD CHANNELS + zarr_dataset.attrs["dimension_units"] = [ + f"{size} nm" for size in voxel_size + ] + zarr_dataset.attrs["_ARRAY_DIMENSIONS"] = [ + a if a != "c" else "c^" for a in axes + ] + if "c" in axes: + if axes.index("c") == 0: + zarr_dataset.attrs["dimension_units"] = [ + str(num_channels) + ] + zarr_dataset.attrs["dimension_units"] + else: + zarr_dataset.attrs["dimension_units"] = zarr_dataset.attrs[ + "dimension_units" + ] + [str(num_channels)] except zarr.errors.ContainsArrayError: zarr_dataset = zarr_container[array_identifier.dataset] assert ( From e9285c8e2be5b0325967c21d98aa27f2845f0830 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 2 Oct 2024 12:40:26 -0400 Subject: [PATCH 12/17] gunpowder trainer reject min option --- .../experiments/datasplits/datasets/arrays/zarr_array.py | 3 ++- dacapo/experiments/trainers/gunpowder_trainer.py | 4 +++- dacapo/experiments/trainers/gunpowder_trainer_config.py | 8 ++++++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 120f17788..88359c783 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -433,7 +433,8 @@ def create_from_array_identifier( ): """ Create a new ZarrArray given an array identifier. It is assumed that - this array_identifier points to a dataset that does not yet exist. + this array_identifier points to a dataset that does not yet exist. + Args: array_identifier (ArrayIdentifier): The array identifier. axes (List[str]): The axes of the array. diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index e14b07811..f0fa83a32 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -91,6 +91,7 @@ def __init__(self, trainer_config): self.augments = trainer_config.augments self.mask_integral_downsample_factor = 4 self.clip_raw = trainer_config.clip_raw + self.gt_min_reject = trainer_config.gt_min_reject self.scheduler = None @@ -221,7 +222,8 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): ) dataset_source += gp.Reject(mask_placeholder, 1e-6) - dataset_source += gp.Reject(gt_key, 0.1) + if self.gt_min_reject is not None: + dataset_source += gp.Reject(gt_key, self.gt_min_reject) for augment in self.augments: dataset_source += augment.node(raw_key, gt_key, mask_key) diff --git a/dacapo/experiments/trainers/gunpowder_trainer_config.py b/dacapo/experiments/trainers/gunpowder_trainer_config.py index 9793288d8..1469119a2 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer_config.py +++ b/dacapo/experiments/trainers/gunpowder_trainer_config.py @@ -43,3 +43,11 @@ class GunpowderTrainerConfig(TrainerConfig): ) min_masked: Optional[float] = attr.ib(default=0.15) clip_raw: bool = attr.ib(default=False) + gt_min_reject: Optional[float] = attr.ib( + default=None, + metadata={ + "help_text": "The value to use for the GT mask. If None, the value will be " + "determined by the GT mask augment." + "e.g. 0.15" + }, + ) From ae93dbc330c2064e9f2d492a880c5642afb6e0ba Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 2 Oct 2024 12:43:39 -0400 Subject: [PATCH 13/17] remove + uint8 --- dacapo/predict_crop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dacapo/predict_crop.py b/dacapo/predict_crop.py index 8a47bd4f6..253d72382 100644 --- a/dacapo/predict_crop.py +++ b/dacapo/predict_crop.py @@ -83,7 +83,7 @@ def predict_fn(block): .numpy()[0] ) predictions = (predictions + 1) * 255.0 / 2.0 - predictions[predictions > 250] = 0 + predictions[predictions > 254] = 0 predictions = np.round(predictions).astype(np.uint8) save_ndarray(predictions, block.write_roi, result_dataset) From 6d7d5d8d6ea059dcb9ff10c21667983cff4ceda1 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 2 Oct 2024 12:48:24 -0400 Subject: [PATCH 14/17] fix validate --- dacapo/validate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dacapo/validate.py b/dacapo/validate.py index 4e7902a19..a78671cfb 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -10,7 +10,7 @@ ) import torch -from pathlib import Path +from upath import UPath as Path import logging from dacapo.compute_context import create_compute_context @@ -162,7 +162,7 @@ def validate_run(run: Run, iteration: int, datasets_config=None): logger.info("validation inputs already copied!") prediction_array_identifier = array_store.validation_prediction_array( - run.name, iteration + 3, validation_dataset + run.name, iteration, validation_dataset ) predict( run.model, @@ -177,7 +177,7 @@ def validate_run(run: Run, iteration: int, datasets_config=None): for parameters in post_processor.enumerate_parameters(): output_array_identifier = array_store.validation_output_array( - run.name, iteration + 3, parameters, validation_dataset + run.name, iteration, parameters, validation_dataset ) post_processed_array = post_processor.process( From 6c922daba817c8053b57f2334fe0558cd2caba39 Mon Sep 17 00:00:00 2001 From: mzouink Date: Wed, 2 Oct 2024 16:48:53 +0000 Subject: [PATCH 15/17] :art: Format Python code with psf/black --- dacapo/experiments/architectures/cnnectome_unet.py | 13 +++++++++---- .../datasplits/datasets/arrays/zarr_array.py | 2 +- dacapo/train.py | 4 +++- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index d1bcd208c..d89e902ac 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -282,7 +282,7 @@ def module(self): self.fmaps_out, [(3,) * len(upsample_factor)] * 2, activation="ReLU", - batch_norm= self.batch_norm, + batch_norm=self.batch_norm, ) layers.append(conv) unet = torch.nn.Sequential(*layers) @@ -1307,7 +1307,7 @@ class AttentionBlockModule(nn.Module): The AttentionBlockModule is an instance of the ``torch.nn.Module`` class. """ - def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None,batch_norm=True): + def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None, batch_norm=True): """ Initialize the Attention Block Module. @@ -1337,7 +1337,12 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None,batch_norm=True): self.upsample_factor = (2,) * self.dims self.W_g = ConvPass( - F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same",batch_norm=self.batch_norm + F_g, + F_int, + kernel_sizes=self.kernel_sizes, + activation=None, + padding="same", + batch_norm=self.batch_norm, ) self.W_x = nn.Sequential( @@ -1347,7 +1352,7 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None,batch_norm=True): kernel_sizes=self.kernel_sizes, activation=None, padding="same", - batch_norm=self.batch_norm + batch_norm=self.batch_norm, ), Downsample(upsample_factor), ) diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 88359c783..f9a26bd09 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -434,7 +434,7 @@ def create_from_array_identifier( """ Create a new ZarrArray given an array identifier. It is assumed that this array_identifier points to a dataset that does not yet exist. - + Args: array_identifier (ArrayIdentifier): The array identifier. axes (List[str]): The axes of the array. diff --git a/dacapo/train.py b/dacapo/train.py index bfd06eeb1..eb28a3cf7 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -113,7 +113,9 @@ def train_run(run: Run, do_validate=True): weights = weights_store.retrieve_weights(run, iteration=trained_until) elif latest_weights_iteration > trained_until: - weights = weights_store.retrieve_weights(run, iteration=latest_weights_iteration) + weights = weights_store.retrieve_weights( + run, iteration=latest_weights_iteration + ) logger.error( f"Found weights for iteration {latest_weights_iteration}, but " f"run {run.name} was only trained until {trained_until}. " From 6c539e393ded54d358bf4dd10e06df942ee9c3e6 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 16 Oct 2024 15:24:23 -0400 Subject: [PATCH 16/17] fix tests --- tests/components/test_options.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/components/test_options.py b/tests/components/test_options.py index e12d90483..81742aafe 100644 --- a/tests/components/test_options.py +++ b/tests/components/test_options.py @@ -19,8 +19,8 @@ def test_no_config(): # Remove the environment variable env_dict = dict(os.environ) - if "OPTIONS_FILE" in env_dict: - del env_dict["OPTIONS_FILE"] + if "DACAPO_OPTIONS_FILE" in env_dict: + del env_dict["DACAPO_OPTIONS_FILE"] # Parse the options options = Options.instance() @@ -61,7 +61,7 @@ def test_local_config_file(): """ ) ) - os.environ["OPTIONS_FILE"] = str(config_file) + os.environ["DACAPO_OPTIONS_FILE"] = str(config_file) # Parse the options options = Options.instance() From 27f531cba829eb95e90deab2b8869faf1c52762d Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Wed, 16 Oct 2024 15:26:32 -0400 Subject: [PATCH 17/17] fix validate tests --- tests/operations/test_validate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/operations/test_validate.py b/tests/operations/test_validate.py index 1fc2a6e8b..5d74b3797 100644 --- a/tests/operations/test_validate.py +++ b/tests/operations/test_validate.py @@ -56,13 +56,13 @@ def test_validate( # test validating iterations for which we know there are weights weights_store.store_weights(run, 0) - validate_run(run_config.name, 0, num_workers=4) + validate_run(run_config.name, 0) # weights_store.store_weights(run, 1) - # validate_run(run_config.name, 1, num_workers=4) + # validate_run(run_config.name, 1) # test validating weights that don't exist with pytest.raises(FileNotFoundError): - validate_run(run_config.name, 2, num_workers=4) + validate_run(run_config.name, 2) if debug: os.chdir(old_path)