diff --git a/dacapo/blockwise/scheduler.py b/dacapo/blockwise/scheduler.py index b2a015a75..7db2d9c27 100644 --- a/dacapo/blockwise/scheduler.py +++ b/dacapo/blockwise/scheduler.py @@ -74,7 +74,12 @@ def run_blockwise( ) print("Running blockwise with worker_file: ", worker_file) print(f"Using compute context: {create_compute_context()}") - success = daisy.run_blockwise([task]) + compute_context = create_compute_context() + print(f"Using compute context: {compute_context}") + + multiprocessing = compute_context.distribute_workers + + success = daisy.run_blockwise([task], multiprocessing=multiprocessing) return success diff --git a/dacapo/cli.py b/dacapo/cli.py index 2af9aea77..d1f7ab2ae 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -76,8 +76,19 @@ def cli(log_level): @click.option( "-r", "--run-name", required=True, type=str, help="The NAME of the run to train." ) -def train(run_name): - dacapo.train(run_name) # TODO: run with compute_context +@click.option( + "--no-validation", is_flag=True, help="Disable validation after training." +) +def train(run_name, no_validation): + """ + Train a model with the specified run name. + + Args: + run_name (str): The name of the run to train. + no_validation (bool): Flag to disable validation after training. + """ + do_validate = not no_validation + dacapo.train(run_name, do_validate=do_validate) @cli.command() diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index acb345c88..d89e902ac 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -130,6 +130,7 @@ def __init__(self, architecture_config): activation after the upsample operation. - use_attention (optional): Whether or not to use an attention block in the U-Net. + - batch_norm (optional): Whether to use batch normalization. Raises: ValueError: If the input shape is not given. 
Examples: @@ -170,6 +171,7 @@ def __init__(self, architecture_config): self.upsample_factors if self.upsample_factors is not None else [] ) self.use_attention = architecture_config.use_attention + self.batch_norm = architecture_config.batch_norm self.unet = self.module() @@ -261,6 +263,7 @@ def module(self): upsample_channel_contraction=[False] + [True] * (len(downsample_factors) - 1), use_attention=self.use_attention, + batch_norm=self.batch_norm, ) if len(self.upsample_factors) > 0: layers = [unet] @@ -279,6 +282,7 @@ def module(self): self.fmaps_out, [(3,) * len(upsample_factor)] * 2, activation="ReLU", + batch_norm=self.batch_norm, ) layers.append(conv) unet = torch.nn.Sequential(*layers) @@ -455,6 +459,7 @@ def __init__( upsample_channel_contraction=False, activation_on_upsample=False, use_attention=False, + batch_norm=True, ): """ Create a U-Net:: @@ -573,6 +578,7 @@ def __init__( self.dims = len(downsample_factors[0]) self.use_attention = use_attention + self.batch_norm = batch_norm # default arguments @@ -611,6 +617,7 @@ def __init__( kernel_size_down[level], activation=activation, padding=padding, + batch_norm=self.batch_norm, ) for level in range(self.num_levels) ] @@ -668,6 +675,7 @@ def __init__( ), dims=self.dims, upsample_factor=downsample_factors[level], + batch_norm=self.batch_norm, ) for level in range(self.num_levels - 1) ] @@ -694,6 +702,7 @@ def __init__( kernel_size_up[level], activation=activation, padding=padding, + batch_norm=self.batch_norm, ) for level in range(self.num_levels - 1) ] @@ -827,7 +836,13 @@ class ConvPass(torch.nn.Module): """ def __init__( - self, in_channels, out_channels, kernel_sizes, activation, padding="valid" + self, + in_channels, + out_channels, + kernel_sizes, + activation, + padding="valid", + batch_norm=True, ): """ Convolutional pass module. This module performs a series of @@ -869,6 +884,15 @@ def __init__( try: layers.append(conv(in_channels, out_channels, kernel_size, padding=pad)) + if batch_norm: + layers.append( + { + 2: torch.nn.BatchNorm2d, + 3: torch.nn.BatchNorm3d, + }[ + self.dims + ](out_channels) + ) except KeyError: raise RuntimeError("%dD convolution not implemented" % self.dims) @@ -1283,7 +1307,7 @@ class AttentionBlockModule(nn.Module): The AttentionBlockModule is an instance of the ``torch.nn.Module`` class. """ - def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): + def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None, batch_norm=True): """ Initialize the Attention Block Module. 
@@ -1306,13 +1330,19 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): super(AttentionBlockModule, self).__init__() self.dims = dims self.kernel_sizes = [(1,) * self.dims, (1,) * self.dims] + self.batch_norm = batch_norm if upsample_factor is not None: self.upsample_factor = upsample_factor else: self.upsample_factor = (2,) * self.dims self.W_g = ConvPass( - F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same" + F_g, + F_int, + kernel_sizes=self.kernel_sizes, + activation=None, + padding="same", + batch_norm=self.batch_norm, ) self.W_x = nn.Sequential( @@ -1322,6 +1352,7 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): kernel_sizes=self.kernel_sizes, activation=None, padding="same", + batch_norm=self.batch_norm, ), Downsample(upsample_factor), ) @@ -1332,6 +1363,7 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): kernel_sizes=self.kernel_sizes, activation="Sigmoid", padding="same", + batch_norm=self.batch_norm, ) up_mode = {2: "bilinear", 3: "trilinear"}[self.dims] diff --git a/dacapo/experiments/architectures/cnnectome_unet_config.py b/dacapo/experiments/architectures/cnnectome_unet_config.py index 77905d79c..7eab80115 100644 --- a/dacapo/experiments/architectures/cnnectome_unet_config.py +++ b/dacapo/experiments/architectures/cnnectome_unet_config.py @@ -127,3 +127,7 @@ class CNNectomeUNetConfig(ArchitectureConfig): "help_text": "Whether to use attention blocks in the UNet. This is supported for 2D and 3D." }, ) + batch_norm: bool = attr.ib( + default=True, + metadata={"help_text": "Whether to use batch normalization."}, + ) diff --git a/dacapo/experiments/datasplits/__init__.py b/dacapo/experiments/datasplits/__init__.py index 0bc9fe628..ad1ad4880 100644 --- a/dacapo/experiments/datasplits/__init__.py +++ b/dacapo/experiments/datasplits/__init__.py @@ -4,4 +4,4 @@ from .dummy_datasplit_config import DummyDataSplitConfig from .train_validate_datasplit import TrainValidateDataSplit from .train_validate_datasplit_config import TrainValidateDataSplitConfig -from .datasplit_generator import DataSplitGenerator +from .datasplit_generator import DataSplitGenerator, DatasetSpec diff --git a/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py b/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py index ba6fd99f0..86367e50b 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py @@ -171,7 +171,9 @@ def roi(self) -> Roi: This method returns the region of interest of the resampled array. """ - return self._source_array.roi.snap_to_grid(self.voxel_size, mode="shrink") + return self._source_array.roi.snap_to_grid( + np.lcm(self._source_array.voxel_size, self.voxel_size), mode="shrink" + ) @property def writable(self) -> bool: @@ -281,7 +283,9 @@ def __getitem__(self, roi: Roi) -> np.ndarray: Note: This method returns the data of the resampled array within the given region of interest. 
""" - snapped_roi = roi.snap_to_grid(self._source_array.voxel_size, mode="grow") + snapped_roi = roi.snap_to_grid( + np.lcm(self._source_array.voxel_size, self.voxel_size), mode="grow" + ) resampled_array = funlib.persistence.Array( rescale( self._source_array[snapped_roi].astype(np.float32), diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 30c6ac693..f9a26bd09 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -13,8 +13,6 @@ from collections import OrderedDict import logging -from upath import UPath as Path -import json from typing import Dict, Tuple, Any, Optional, List logger = logging.getLogger(__name__) @@ -273,7 +271,9 @@ def roi(self) -> Roi: This method is used to return the region of interest of the array. """ if self.snap_to_grid is not None: - return self._daisy_array.roi.snap_to_grid(self.snap_to_grid, mode="shrink") + return self._daisy_array.roi.snap_to_grid( + np.lcm(self.voxel_size, self.snap_to_grid), mode="shrink" + ) else: return self._daisy_array.roi @@ -469,7 +469,7 @@ def create_from_array_identifier( write_size = Coordinate((axis_length,) * voxel_size.dims) * voxel_size write_size = Coordinate((min(a, b) for a, b in zip(write_size, roi.shape))) zarr_container = zarr.open(array_identifier.container, "a") - if num_channels is None or num_channels == 1: + if num_channels is None: axes = [axis for axis in axes if "c" not in axis] num_channels = None else: diff --git a/dacapo/experiments/datasplits/datasets/dataset.py b/dacapo/experiments/datasplits/datasets/dataset.py index d3591b447..ef8ad2a1d 100644 --- a/dacapo/experiments/datasplits/datasets/dataset.py +++ b/dacapo/experiments/datasplits/datasets/dataset.py @@ -90,7 +90,7 @@ def __repr__(self) -> str: Notes: This method is used to return the official string representation of the dataset object. """ - return f"Dataset({self.name})" + return f"ds_{self.name.replace('/', '_')}" def __str__(self) -> str: """ @@ -109,7 +109,7 @@ def __str__(self) -> str: Notes: This method is used to return the string representation of the dataset object. 
""" - return f"Dataset({self.name})" + return f"ds_{self.name.replace('/', '_')}" def _neuroglancer_layers(self, prefix="", exclude_layers=None): """ diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py index bb1c19472..5858b2a6e 100644 --- a/dacapo/experiments/datasplits/datasplit_generator.py +++ b/dacapo/experiments/datasplits/datasplit_generator.py @@ -85,6 +85,9 @@ def resize_if_needed( f"have different dimensions {zarr_array.dims}" ) if any([u > 1 or d > 1 for u, d in zip(raw_upsample, raw_downsample)]): + print( + f"dataset {array_config} needs resampling to {target_resolution}, upsample: {raw_upsample}, downsample: {raw_downsample}" + ) return ResampledArrayConfig( name=f"{extra_str}_{array_config.name}_{array_config.dataset}_resampled", source_array_config=array_config, @@ -93,6 +96,7 @@ def resize_if_needed( interp_order=False, ) else: + # print(f"dataset {array_config.dataset} does not need resampling found {raw_voxel_size}=={target_resolution}") return array_config @@ -959,23 +963,23 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec): constant=1, ) - if len(target_images) > 1: - gt_config = ConcatArrayConfig( - name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_gt", - channels=[organelle for organelle in current_targets], - # source_array_configs={k: gt for k, gt in target_images.items()}, - source_array_configs={k: target_images[k] for k in current_targets}, - ) - mask_config = ConcatArrayConfig( - name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_mask", - channels=[organelle for organelle in current_targets], - # source_array_configs={k: mask for k, mask in target_masks.items()}, - # to be sure to have the same order - source_array_configs={k: target_masks[k] for k in current_targets}, - ) - else: - gt_config = list(target_images.values())[0] - mask_config = list(target_masks.values())[0] + # if len(target_images) > 1: + gt_config = ConcatArrayConfig( + name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_gt", + channels=[organelle for organelle in current_targets], + # source_array_configs={k: gt for k, gt in target_images.items()}, + source_array_configs={k: target_images[k] for k in current_targets}, + ) + mask_config = ConcatArrayConfig( + name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_mask", + channels=[organelle for organelle in current_targets], + # source_array_configs={k: mask for k, mask in target_masks.items()}, + # to be sure to have the same order + source_array_configs={k: target_masks[k] for k in current_targets}, + ) + # else: + # gt_config = list(target_images.values())[0] + # mask_config = list(target_masks.values())[0] return raw_config, gt_config, mask_config diff --git a/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluator.py b/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluator.py index d6ade542e..178dd8f4b 100644 --- a/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluator.py +++ b/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluator.py @@ -127,8 +127,9 @@ def evaluate(self, output_array_identifier, evaluation_array): This function is used to evaluate the output array against the evaluation array. 
""" output_array = ZarrArray.open_from_array_identifier(output_array_identifier) - evaluation_data = evaluation_array[evaluation_array.roi].squeeze() - output_data = output_array[output_array.roi].squeeze() + # removed the .squeeze() because it was used for batch size and now we are feeding 4d c, z, y, x + evaluation_data = evaluation_array[evaluation_array.roi] + output_data = output_array[output_array.roi] print( f"Evaluating binary segmentations on evaluation_data of shape: {evaluation_data.shape}" ) diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index f99c64d3a..2cf719d44 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -1,12 +1,12 @@ -from upath import UPath as Path -from dacapo.blockwise.scheduler import run_blockwise from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from .threshold_post_processor_parameters import ThresholdPostProcessorParameters from dacapo.store.array_store import LocalArrayIdentifier from .post_processor import PostProcessor -import dacapo.blockwise import numpy as np +import daisy from daisy import Roi, Coordinate +from dacapo.utils.array_utils import to_ndarray, save_ndarray +from funlib.persistence import open_ds from typing import Iterable @@ -43,7 +43,7 @@ def enumerate_parameters(self) -> Iterable["ThresholdPostProcessorParameters"]: Note: This method should return a generator of instances of ``ThresholdPostProcessorParameters``. """ - for i, threshold in enumerate([100, 127, 150]): + for i, threshold in enumerate([127]): yield ThresholdPostProcessorParameters(id=i, threshold=threshold) def set_prediction(self, prediction_array_identifier): @@ -117,28 +117,33 @@ def process( self.prediction_array.num_channels, self.prediction_array.voxel_size, np.uint8, - write_size, ) read_roi = Roi((0, 0, 0), write_size[-self.prediction_array.dims :]) - # run blockwise post-processing - sucess = run_blockwise( - worker_file=str( - Path(Path(dacapo.blockwise.__file__).parent, "threshold_worker.py") - ), + input_array = open_ds( + self.prediction_array_identifier.container.path, + self.prediction_array_identifier.dataset, + ) + + def process_block(block): + data = to_ndarray(input_array, block.read_roi) > parameters.threshold + data = data.astype(np.uint8) + if int(data.max()) == 0: + print("No data in block", block.read_roi) + return + save_ndarray(data, block.write_roi, output_array) + + task = daisy.Task( + f"threshold_{output_array.dataset}", total_roi=self.prediction_array.roi, read_roi=read_roi, write_roi=read_roi, - num_workers=num_workers, - max_retries=2, # TODO: make this an option - timeout=None, # TODO: make this an option - ###### - input_array_identifier=self.prediction_array_identifier, - output_array_identifier=output_array_identifier, - threshold=parameters.threshold, + process_function=process_block, + check_function=None, + read_write_conflict=False, + fit="overhang", + max_retries=0, + timeout=None, ) - if not sucess: - raise RuntimeError("Blockwise post-processing failed.") - - return output_array + return daisy.run_blockwise([task], multiprocessing=False) diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor_parameters.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor_parameters.py index 014fc1ec2..18114a670 100644 --- 
a/dacapo/experiments/tasks/post_processors/threshold_post_processor_parameters.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor_parameters.py @@ -19,3 +19,6 @@ class ThresholdPostProcessorParameters(PostProcessorParameters): """ threshold: float = attr.ib(default=0.0) + + def __str__(self): + return f"threshold_{self.threshold}" diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 78324e3a3..f0fa83a32 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -91,6 +91,7 @@ def __init__(self, trainer_config): self.augments = trainer_config.augments self.mask_integral_downsample_factor = 4 self.clip_raw = trainer_config.clip_raw + self.gt_min_reject = trainer_config.gt_min_reject self.scheduler = None @@ -221,6 +222,8 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): ) dataset_source += gp.Reject(mask_placeholder, 1e-6) + if self.gt_min_reject is not None: + dataset_source += gp.Reject(gt_key, self.gt_min_reject) for augment in self.augments: dataset_source += augment.node(raw_key, gt_key, mask_key) @@ -351,11 +354,15 @@ def iterate(self, num_iterations, model, optimizer, device): snapshot_array_identifier = ( self.snapshot_container.array_identifier(k) ) + if v.num_channels == 1: + channels = None + else: + channels = v.num_channels ZarrArray.create_from_array_identifier( snapshot_array_identifier, v.axes, v.roi, - v.num_channels, + channels, v.voxel_size, v.dtype if not v.dtype == bool else np.float32, model.output_shape * v.voxel_size, diff --git a/dacapo/experiments/trainers/gunpowder_trainer_config.py b/dacapo/experiments/trainers/gunpowder_trainer_config.py index 9793288d8..1469119a2 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer_config.py +++ b/dacapo/experiments/trainers/gunpowder_trainer_config.py @@ -43,3 +43,11 @@ class GunpowderTrainerConfig(TrainerConfig): ) min_masked: Optional[float] = attr.ib(default=0.15) clip_raw: bool = attr.ib(default=False) + gt_min_reject: Optional[float] = attr.ib( + default=None, + metadata={ "help_text": "Minimum fraction of foreground (non-zero) GT voxels required for a batch " "to be accepted (applied via gp.Reject on the GT key). " "If None, no GT-based rejection is applied. e.g. 
0.15" + }, + ) diff --git a/dacapo/options.py b/dacapo/options.py index e52a0c57a..13afe6e68 100644 --- a/dacapo/options.py +++ b/dacapo/options.py @@ -133,8 +133,8 @@ def config_file(cls) -> Optional[Path]: PosixPath('/home/user/.config/dacapo/dacapo.yaml') """ env_dict = dict(os.environ) - if "OPTIONS_FILE" in env_dict: - options_files = [Path(env_dict["OPTIONS_FILE"])] + if "DACAPO_OPTIONS_FILE" in env_dict: + options_files = [Path(env_dict["DACAPO_OPTIONS_FILE"])] else: options_files = [] @@ -147,7 +147,7 @@ def config_file(cls) -> Optional[Path]: ] for path in options_files: if path.exists(): - os.environ["OPTIONS_FILE"] = str(path) + os.environ["DACAPO_OPTIONS_FILE"] = str(path) return path return None diff --git a/dacapo/plot.py b/dacapo/plot.py index d5bfe1d28..8bcdc9dc4 100644 --- a/dacapo/plot.py +++ b/dacapo/plot.py @@ -148,9 +148,7 @@ def bokeh_plot_runs( >>> plot_runs(["run_name"], 100, None, None, [True]) """ - print("PLOTTING RUNS") runs = get_runs_info(run_config_base_names, validation_scores, plot_losses) - print("GOT RUNS INFO") colors = itertools.cycle(palette[20]) loss_tooltips = [ @@ -213,8 +211,6 @@ def bokeh_plot_runs( validation_figure.background_fill_color = "#efefef" validation_figures[dataset.name] = validation_figure - print("VALIDATION SCORES TOOLTIP MADE") - summary_tooltips = [ ("run", "@run"), ("task", "@task"), @@ -246,16 +242,11 @@ def bokeh_plot_runs( iterations = [stat.iteration for stat in run.training_stats.iteration_stats] losses = [stat.loss for stat in run.training_stats.iteration_stats] - print(f"Run {run.name} has {len(losses)} iterations") - if run.plot_loss: include_loss_figure = True smooth = int(np.maximum(len(iterations) / 2500, 1)) - print(f"smoothing: {smooth}") x, _ = smooth_values(iterations, smooth, stride=smooth) y, s = smooth_values(losses, smooth, stride=smooth) - print(x, y) - print(f"plotting {(len(x), len(y))} points") source = bokeh.plotting.ColumnDataSource( { "iteration": x, @@ -284,8 +275,6 @@ def bokeh_plot_runs( alpha=0.3, ) - print("LOSS PLOTTED") - if run.validation_score_name and run.validation_scores.validated_until() > 0: validation_score_data = run.validation_scores.to_xarray().sel( criteria=run.validation_score_name @@ -331,7 +320,6 @@ def bokeh_plot_runs( color=color, alpha=0.7, ) - print("VALIDATION PLOTTED") # Styling # training @@ -383,7 +371,6 @@ def bokeh_plot_runs( plot = bokeh.layouts.column(*figures) plot.sizing_mode = "scale_width" - print("PLOTTING DONE") if return_json: print("Returning JSON") return json.dumps(json_item(plot, "myplot")) @@ -410,9 +397,7 @@ def plot_runs( Returns: None """ - print("PLOTTING RUNS") runs = get_runs_info(run_config_base_names, validation_scores, plot_losses) - print("GOT RUNS INFO") colors = itertools.cycle(plt.cm.tab20.colors) include_validation_figure = False @@ -429,18 +414,12 @@ def plot_runs( iterations = [stat.iteration for stat in run.training_stats.iteration_stats] losses = [stat.loss for stat in run.training_stats.iteration_stats] - print(f"Run {run.name} has {len(losses)} iterations") - if run.plot_loss: include_loss_figure = True smooth = int(np.maximum(len(iterations) / 2500, 1)) - print(f"smoothing: {smooth}") x, _ = smooth_values(iterations, smooth, stride=smooth) y, s = smooth_values(losses, smooth, stride=smooth) - print(x, y) - print(f"plotting {(len(x), len(y))} points") loss_ax.plot(x, y, label=name, color=color) - print("LOSS PLOTTED") if run.validation_score_name and run.validation_scores.validated_until() > 0: validation_score_data = 
run.validation_scores.to_xarray().sel( @@ -463,7 +442,6 @@ def plot_runs( color=cc, alpha=0.5 + 0.2 * i, ) - print("VALIDATION PLOTTED") if include_loss_figure: loss_ax.set_title("Training") diff --git a/dacapo/predict_crop.py b/dacapo/predict_crop.py new file mode 100644 index 000000000..253d72382 --- /dev/null +++ b/dacapo/predict_crop.py @@ -0,0 +1,110 @@ +from dacapo.experiments.model import Model +from dacapo.store.local_array_store import LocalArrayIdentifier +from funlib.persistence import open_ds, prepare_ds, Array +from dacapo.utils.array_utils import to_ndarray +from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray +from funlib.geometry import Coordinate, Roi +import numpy as np +from dacapo.compute_context import create_compute_context +from typing import Optional +import logging +import daisy +import torch +import os +from dacapo.utils.array_utils import to_ndarray, save_ndarray + +logger = logging.getLogger(__name__) + + +def predict( + model: Model, + raw_array_identifier: LocalArrayIdentifier, + prediction_array_identifier: LocalArrayIdentifier, + output_roi: Optional[Roi] = None, +): + # get the model's input and output size + if isinstance(raw_array_identifier, LocalArrayIdentifier): + raw_array = open_ds( + raw_array_identifier.container.path, raw_array_identifier.dataset + ) + else: + raw_array = raw_array_identifier + input_voxel_size = Coordinate(raw_array.voxel_size) + output_voxel_size = model.scale(input_voxel_size) + + input_shape = Coordinate(model.eval_input_shape) + input_size = input_voxel_size * input_shape + output_size = output_voxel_size * model.compute_output_shape(input_shape)[1] + + context = (input_size - output_size) / 2 + + if output_roi is None: + input_roi = raw_array.roi + output_roi = input_roi.grow(-context, -context) + else: + input_roi = output_roi.grow(context, context) + + read_roi = Roi((0, 0, 0), input_size) + write_roi = read_roi.grow(-context, -context) + + axes = ["c", "z", "y", "x"] + + num_channels = model.num_out_channels + + result_dataset = ZarrArray.create_from_array_identifier( + prediction_array_identifier, + axes, + output_roi, + num_channels, + output_voxel_size, + np.uint8, + ) + + logger.info("Total input ROI: %s, output ROI: %s", input_size, output_roi) + logger.info("Block read ROI: %s, write ROI: %s", read_roi, write_roi) + + out_container, out_dataset = ( + prediction_array_identifier.container.path, + prediction_array_identifier.dataset, + ) + compute_context = create_compute_context() + device = compute_context.device + + def predict_fn(block): + raw_input = to_ndarray(raw_array, block.read_roi) + # expend batch dim + # no need to normalize, done by datasplit + raw_input = np.expand_dims(raw_input, (0, 1)) + with torch.no_grad(): + predictions = ( + model.forward(torch.from_numpy(raw_input).float().to(device)) + .detach() + .cpu() + .numpy()[0] + ) + predictions = (predictions + 1) * 255.0 / 2.0 + predictions[predictions > 254] = 0 + predictions = np.round(predictions).astype(np.uint8) + + save_ndarray(predictions, block.write_roi, result_dataset) + # result_dataset[block.write_roi] = predictions + + # fixing the input roi to be a multiple of the output voxel size + input_roi = input_roi.snap_to_grid( + np.lcm(input_voxel_size, output_voxel_size), mode="shrink" + ) + + task = daisy.Task( + f"predict_{out_container}_{out_dataset}", + total_roi=input_roi, + read_roi=Roi((0, 0, 0), input_size), + write_roi=Roi(context, output_size), + process_function=predict_fn, + check_function=None, + 
read_write_conflict=False, + fit="overhang", + max_retries=0, + timeout=None, + ) + + return daisy.run_blockwise([task], multiprocessing=False) diff --git a/dacapo/train.py b/dacapo/train.py index a1d884b3d..eb28a3cf7 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -6,7 +6,7 @@ create_weights_store, ) from dacapo.experiments import Run -from dacapo.validate import validate +from dacapo.validate import validate_run import torch from tqdm import tqdm @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) -def train(run_name: str): +def train(run_name: str, do_validate=True): """ Train a run @@ -44,7 +44,7 @@ def train(run_name: str): run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) - return train_run(run) + return train_run(run, do_validate) def train_run(run: Run, do_validate=True): @@ -82,6 +82,7 @@ def train_run(run: Run, do_validate=True): weights_store = create_weights_store() latest_weights_iteration = weights_store.latest_iteration(run) + weights = None if trained_until > 0: if latest_weights_iteration is None: @@ -104,19 +105,23 @@ def train_run(run: Run, do_validate=True): trained_until = latest_weights_iteration run.training_stats.delete_after(trained_until) run.validation_scores.delete_after(trained_until) - weights_store.retrieve_weights(run, iteration=trained_until) + weights = weights_store.retrieve_weights(run, iteration=trained_until) elif latest_weights_iteration == trained_until: print(f"Resuming training from iteration {trained_until}") - weights_store.retrieve_weights(run, iteration=trained_until) + weights = weights_store.retrieve_weights(run, iteration=trained_until) elif latest_weights_iteration > trained_until: - weights_store.retrieve_weights(run, iteration=latest_weights_iteration) + weights = weights_store.retrieve_weights( + run, iteration=latest_weights_iteration + ) logger.error( f"Found weights for iteration {latest_weights_iteration}, but " f"run {run.name} was only trained until {trained_until}. " ) + if weights is not None: + run.model.load_state_dict(weights.model) # start/resume training @@ -196,7 +201,7 @@ def train_run(run: Run, do_validate=True): ) validate_thread.start() else: - validate( + validate_run( run, iteration_stats.iteration + 1, ) diff --git a/dacapo/utils/array_utils.py b/dacapo/utils/array_utils.py new file mode 100644 index 000000000..6d0293a4c --- /dev/null +++ b/dacapo/utils/array_utils.py @@ -0,0 +1,65 @@ +import numpy as np +from funlib.persistence import Array + + +def to_ndarray(result_data, roi, fill_value=0): + """An alternative implementation of `__getitem__` that supports + using fill values to request data that may extend outside the + roi covered by result_data. + + Args: + + roi (`class:Roi`, optional): + + If given, copy only the data represented by this ROI. This is + equivalent to:: + + array[roi].to_ndarray() + + fill_value (scalar, optional): + + The value to use to fill in values that are outside the ROI + provided by this data. Defaults to 0. 
+ """ + + shape = roi.shape / result_data.voxel_size + data = np.zeros( + result_data[result_data.roi].shape[: result_data.n_channel_dims] + shape, + dtype=result_data.data.dtype, + ) + if fill_value != 0: + data[:] = fill_value + + array = Array(data, roi, result_data.voxel_size) + + shared_roi = result_data.roi.intersect(roi) + + if not shared_roi.empty: + array[shared_roi] = result_data[shared_roi] + + return data + + +def save_ndarray(data, roi, array): + """An alternative implementation of `__setitem__` that supports + using fill values to request data that may extend outside the + roi covered by result_data. + + Args: + + roi (`class:Roi`, optional): + + If given, copy only the data represented by this ROI. This is + equivalent to:: + + array[roi] = data + + fill_value (scalar, optional): + + The value to use to fill in values that are outside the ROI + provided by this data. Defaults to 0. + """ + intersection_roi = roi.intersect(array.roi) + if not intersection_roi.empty: + result_array = Array(data, roi, array.voxel_size) + array[intersection_roi] = result_array[intersection_roi] diff --git a/dacapo/utils/view.py b/dacapo/utils/view.py index 40241e2f3..5cadc29d5 100644 --- a/dacapo/utils/view.py +++ b/dacapo/utils/view.py @@ -452,7 +452,9 @@ def open_from_array_identitifier(self, array_identifier): >>> ds = viewer.open_from_array_identitifier(array_identifier) """ if os.path.exists(array_identifier.container / array_identifier.dataset): - return open_ds(str(array_identifier.container), array_identifier.dataset) + return open_ds( + str(array_identifier.container.path), array_identifier.dataset + ) else: return None diff --git a/dacapo/validate.py b/dacapo/validate.py index 44f091e7a..a78671cfb 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -1,107 +1,72 @@ -from dacapo.compute_context import create_compute_context -from .predict import predict +from .predict_crop import predict + from .experiments import Run, ValidationIterationScores from .experiments.datasplits.datasets.arrays import ZarrArray -from .store.create_store import ( +from dacapo.store.create_store import ( create_array_store, create_config_store, create_stats_store, create_weights_store, ) +import torch from upath import UPath as Path import logging -from warnings import warn +from dacapo.compute_context import create_compute_context logger = logging.getLogger(__name__) -def validate_run( - run_name: str, - iteration: int, - num_workers: int = 1, - output_dtype: str = "uint8", - overwrite: bool = True, -): - """ - Validate a run at a given iteration. Loads the weights from a previously +def validate(run_name: str, iteration: int = 0, datasets_config=None): + """Validate a run at a given iteration. Loads the weights from a previously stored checkpoint. Returns the best parameters and scores for this - iteration. + iteration.""" - Args: - run: The name of the run to validate. - iteration: The iteration to validate. - num_workers: The number of workers to use for validation. - output_dtype: The dtype to use for the output arrays. 
- overwrite: Whether to overwrite existing output arrays + logger.info("Validating run %s at iteration %d...", run_name, iteration) + + # create run - """ - # Load the model and weights config_store = create_config_store() run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) - compute_context = create_compute_context() - if iteration is not None and not compute_context.distribute_workers: - # create weights store - weights_store = create_weights_store() - - # load weights - run.model.load_state_dict( - weights_store.retrieve_weights(run_name, iteration).model - ) - - return validate( - run=run, - iteration=iteration, - num_workers=num_workers, - output_dtype=output_dtype, - overwrite=overwrite, - ) - -def validate( - run: Run, - iteration: int, - num_workers: int = 1, - output_dtype: str = "uint8", - overwrite: bool = True, -): - """ - Validate a run at a given iteration. Loads the weights from a previously - stored checkpoint. Returns the best parameters and scores for this - iteration. - - Args: - run: The run to validate. - iteration: The iteration to validate. - num_workers: The number of workers to use for validation. - output_dtype: The dtype to use for the output arrays. - overwrite: Whether to overwrite existing output arrays - Returns: - The best parameters and scores for this iteration - Raises: - ValueError: If the run does not have a validation dataset or the dataset does not have ground truth. - Example: - validate(my_run, 1000) - """ - - print(f"Validating run {run.name} at iteration {iteration}...") - - run_name = run.name + compute_context = create_compute_context() + device = compute_context.device + run.model.to(device) # read in previous training/validation stats + stats_store = create_stats_store() run.training_stats = stats_store.retrieve_training_stats(run_name) run.validation_scores.scores = stats_store.retrieve_validation_iteration_scores( run_name ) + # create weights store and read weights + if iteration > 0: + weights_store = create_weights_store() + weights = weights_store.retrieve_weights(run, iteration) + run.model.load_state_dict(weights.model) + + return validate_run(run, iteration, datasets_config) + + +def validate_run(run: Run, iteration: int, datasets_config=None): + """Validate an already loaded run at the given iteration. This does not + load the weights of that iteration, it is assumed that the model is already + loaded correctly. Returns the best parameters and scores for this + iteration.""" + # set benchmark flag to True for performance + torch.backends.cudnn.benchmark = True + run.model.eval() + if ( run.datasplit.validate is None or len(run.datasplit.validate) == 0 or run.datasplit.validate[0].gt is None ): - raise ValueError(f"Cannot validate run {run.name} at iteration {iteration}.") + logger.error("Cannot validate run %s. Continuing training!", run.name) + return None, None # get array and weight store array_store = create_array_store() @@ -111,37 +76,44 @@ def validate( post_processor = run.task.post_processor evaluator = run.task.evaluator - # Initialize the evaluator with the best scores seen so far - try: - evaluator.set_best(run.validation_scores) - except ValueError: - logger.warn( - f"Could not set best scores for run {run.name} at iteration {iteration}." 
- ) + input_voxel_size = run.datasplit.train[0].raw.voxel_size + output_voxel_size = run.model.scale(input_voxel_size) - for validation_dataset in run.datasplit.validate: - if validation_dataset.gt is None: - logger.error( - "We do not yet support validating on datasets without ground truth" + # Initialize the evaluator with the best scores seen so far + # evaluator.set_best(run.validation_scores) + if datasets_config is None: + datasets = run.datasplit.validate + else: + from dacapo.experiments.datasplits import DataSplitGenerator + + datasplit_config = ( + DataSplitGenerator( + "", + datasets_config, + input_voxel_size, + output_voxel_size, + targets=run.task.evaluator.channels, ) - raise NotImplementedError - - print(f"Validating run {run.name} on dataset {validation_dataset.name}") + .compute() + .validate_configs + ) + datasets = [ + validate_config.dataset_type(validate_config) + for validate_config in datasplit_config + ] + + for validation_dataset in datasets: + assert ( + validation_dataset.gt is not None + ), "We do not yet support validating on datasets without ground truth" + logger.info( + "Validating run %s on dataset %s", run.name, validation_dataset.name + ) ( input_raw_array_identifier, input_gt_array_identifier, ) = array_store.validation_input_arrays(run.name, validation_dataset.name) - - input_voxel_size = validation_dataset.raw.voxel_size - output_voxel_size = run.model.scale(input_voxel_size) - input_shape = run.model.eval_input_shape - input_size = input_voxel_size * input_shape - output_shape = run.model.compute_output_shape(input_shape)[1] - output_size = output_voxel_size * output_shape - context = (input_size - output_size) / 2 - output_roi = validation_dataset.gt.roi - if ( not Path( f"{input_raw_array_identifier.container}/{input_raw_array_identifier.dataset}" @@ -150,7 +122,14 @@ def validate( f"{input_gt_array_identifier.container}/{input_gt_array_identifier.dataset}" ).exists() ): - print("Copying validation inputs!") + logger.info("Copying validation inputs!") + + input_shape = run.model.eval_input_shape + input_size = input_voxel_size * input_shape + output_shape = run.model.compute_output_shape(input_shape)[1] + output_size = output_voxel_size * output_shape + context = (input_size - output_size) / 2 + output_roi = validation_dataset.gt.roi input_roi = ( output_roi.grow(context, context) @@ -167,7 +146,7 @@ def validate( name=f"{run.name}_validation_raw", write_size=input_size, ) - input_raw[input_roi] = validation_dataset.raw[input_roi].squeeze() + input_raw[input_roi] = validation_dataset.raw[input_roi] input_gt = ZarrArray.create_from_array_identifier( input_gt_array_identifier, validation_dataset.gt.axes, @@ -178,140 +157,81 @@ def validate( name=f"{run.name}_validation_gt", write_size=output_size, ) - input_gt[output_roi] = validation_dataset.gt[output_roi].squeeze() + input_gt[output_roi] = validation_dataset.gt[output_roi] else: - print("validation inputs already copied!") + logger.info("validation inputs already copied!") prediction_array_identifier = array_store.validation_prediction_array( - run.name, iteration, validation_dataset.name + run.name, iteration, validation_dataset ) - compute_context = create_compute_context() - success = predict( - run, - iteration if compute_context.distribute_workers else None, - input_container=input_raw_array_identifier.container, - input_dataset=input_raw_array_identifier.dataset, - output_path=prediction_array_identifier, - output_roi=validation_dataset.gt.roi, # type: ignore - num_workers=num_workers, - 
output_dtype=output_dtype, - overwrite=overwrite, + predict( + run.model, + input_raw_array_identifier, + prediction_array_identifier, + output_roi=validation_dataset.gt.roi, ) - if not success: - logger.error( - f"Could not predict run {run.name} on dataset {validation_dataset.name}." - ) - continue - - print(f"Predicted on dataset {validation_dataset.name}") - post_processor.set_prediction(prediction_array_identifier) - # # set up dict for overall best scores per dataset - # overall_best_scores = {} - # for criterion in run.validation_scores.criteria: - # overall_best_scores[criterion] = evaluator.get_overall_best( - # validation_dataset, criterion - # ) - - # any_overall_best = False - output_array_identifiers = [] dataset_iteration_scores = [] + for parameters in post_processor.enumerate_parameters(): output_array_identifier = array_store.validation_output_array( - run.name, iteration, str(parameters), validation_dataset.name + run.name, iteration, parameters, validation_dataset ) - output_array_identifiers.append(output_array_identifier) + post_processed_array = post_processor.process( parameters, output_array_identifier ) - try: - scores = evaluator.evaluate( - output_array_identifier, validation_dataset.gt # type: ignore - ) - dataset_iteration_scores.append( - [getattr(scores, criterion) for criterion in scores.criteria] - ) - # for criterion in run.validation_scores.criteria: - # # replace predictions in array with the new better predictions - # if evaluator.is_best( - # validation_dataset, - # parameters, - # criterion, - # scores, - # ): - # # then this is the current best score for this parameter, but not necessarily the overall best - # # initial_best_score = overall_best_scores[criterion] - # current_score = getattr(scores, criterion) - # if not overall_best_scores[criterion] or evaluator.compare( - # current_score, overall_best_scores[criterion], criterion - # ): - # any_overall_best = True - # overall_best_scores[criterion] = current_score - - # # For example, if parameter 2 did better this round than it did in other rounds, but it was still worse than parameter 1 - # # the code would have overwritten it below since all parameters write to the same file. 
Now each parameter will be its own file - # # Either we do that, or we only write out the overall best, regardless of parameters - # best_array_identifier = array_store.best_validation_array( - # run.name, - # criterion, - # index=validation_dataset.name, - # ) - # best_array = ZarrArray.create_from_array_identifier( - # best_array_identifier, - # post_processed_array.axes, - # post_processed_array.roi, - # post_processed_array.num_channels, - # post_processed_array.voxel_size, - # post_processed_array.dtype, - # output_size, - # ) - # best_array[best_array.roi] = post_processed_array[ - # post_processed_array.roi - # ] - # best_array.add_metadata( - # { - # "iteration": iteration, - # criterion: getattr(scores, criterion), - # "parameters_id": parameters.id, - # } - # ) - # weights_store.store_best( - # run.name, - # iteration, - # validation_dataset.name, - # criterion, - # ) - except: - logger.error( - f"Could not evaluate run {run.name} on dataset {validation_dataset.name} with parameters {parameters}.", - exc_info=True, - stack_info=True, - ) - - # if not any_overall_best: - # # We only keep the best outputs as determined by the evaluator - # for output_array_identifier in output_array_identifiers: - # array_store.remove(prediction_array_identifier) - # array_store.remove(output_array_identifier) + scores = evaluator.evaluate(output_array_identifier, validation_dataset.gt) + + # for criterion in run.validation_scores.criteria: + # # replace predictions in array with the new better predictions + # if evaluator.is_best( + # validation_dataset, + # parameters, + # criterion, + # scores, + # ): + # best_array_identifier = array_store.best_validation_array( + # run.name, criterion, index=validation_dataset.name + # ) + # best_array = ZarrArray.create_from_array_identifier( + # best_array_identifier, + # post_processed_array.axes, + # post_processed_array.roi, + # post_processed_array.num_channels, + # post_processed_array.voxel_size, + # post_processed_array.dtype, + # ) + # best_array[best_array.roi] = post_processed_array[ + # post_processed_array.roi + # ] + # best_array.add_metadata( + # { + # "iteration": iteration, + # criterion: getattr(scores, criterion), + # "parameters_id": parameters.id, + # } + # ) + # weights_store.store_best( + # run, iteration, validation_dataset.name, criterion + # ) + + # delete current output. 
We only keep the best outputs as determined by + # the evaluator + # array_store.remove(output_array_identifier) + + dataset_iteration_scores.append( + [getattr(scores, criterion) for criterion in scores.criteria] + ) iteration_scores.append(dataset_iteration_scores) + # array_store.remove(prediction_array_identifier) run.validation_scores.add_iteration_scores( ValidationIterationScores(iteration, iteration_scores) ) stats_store = create_stats_store() stats_store.store_validation_iteration_scores(run.name, run.validation_scores) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("run_name", type=str) - parser.add_argument("iteration", type=int) - args = parser.parse_args() - - validate(args.run_name, args.iteration) diff --git a/tests/components/test_options.py b/tests/components/test_options.py index e12d90483..81742aafe 100644 --- a/tests/components/test_options.py +++ b/tests/components/test_options.py @@ -19,8 +19,8 @@ def test_no_config(): # Remove the environment variable env_dict = dict(os.environ) - if "OPTIONS_FILE" in env_dict: - del env_dict["OPTIONS_FILE"] + if "DACAPO_OPTIONS_FILE" in env_dict: + del env_dict["DACAPO_OPTIONS_FILE"] # Parse the options options = Options.instance() @@ -61,7 +61,7 @@ def test_local_config_file(): """ ) ) - os.environ["OPTIONS_FILE"] = str(config_file) + os.environ["DACAPO_OPTIONS_FILE"] = str(config_file) # Parse the options options = Options.instance() diff --git a/tests/operations/test_validate.py b/tests/operations/test_validate.py index 1fc2a6e8b..5d74b3797 100644 --- a/tests/operations/test_validate.py +++ b/tests/operations/test_validate.py @@ -56,13 +56,13 @@ def test_validate( # test validating iterations for which we know there are weights weights_store.store_weights(run, 0) - validate_run(run_config.name, 0, num_workers=4) + validate_run(run_config.name, 0) # weights_store.store_weights(run, 1) - # validate_run(run_config.name, 1, num_workers=4) + # validate_run(run_config.name, 1) # test validating weights that don't exist with pytest.raises(FileNotFoundError): - validate_run(run_config.name, 2, num_workers=4) + validate_run(run_config.name, 2) if debug: os.chdir(old_path)
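Usage sketch (illustrative only, not part of the patch) for the entry points changed above; the run name "my_run", the iteration number, and the options-file path are placeholders:

import os
import dacapo
from dacapo.experiments import Run
from dacapo.store.create_store import create_config_store
from dacapo.validate import validate, validate_run

# the options-file environment variable is now namespaced (see dacapo/options.py)
os.environ["DACAPO_OPTIONS_FILE"] = "/path/to/dacapo.yaml"

# train without the post-training validation step
# (CLI equivalent: dacapo train -r my_run --no-validation)
dacapo.train("my_run", do_validate=False)

# validate from a stored checkpoint: validate() retrieves the weights for `iteration` itself
validate("my_run", iteration=2000)

# validate a run whose model is already loaded: validate_run() does not touch the weights store
config_store = create_config_store()
run = Run(config_store.retrieve_run_config("my_run"))
validate_run(run, 2000)

This mirrors how train_run() now calls validate_run() in-process after storing weights, while the standalone validate() loads the requested checkpoint first.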
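A second sketch (also illustrative; it assumes the in-memory funlib.persistence Array(data, roi, voxel_size) API that this patch itself uses) of the new helpers in dacapo/utils/array_utils.py, mirroring how process_block in ThresholdPostProcessor.process reads, thresholds, and writes one block:

import numpy as np
from funlib.geometry import Coordinate, Roi
from funlib.persistence import Array
from dacapo.utils.array_utils import to_ndarray, save_ndarray

voxel_size = Coordinate((8, 8, 8))
full_roi = Roi((0, 0, 0), (80, 80, 80))
# small in-memory source and destination arrays with the same geometry
source = Array(np.random.rand(10, 10, 10).astype(np.float32), full_roi, voxel_size)
output = Array(np.zeros((10, 10, 10), dtype=np.uint8), full_roi, voxel_size)

# read a block that extends past the source ROI; voxels outside it are filled with 0
block_roi = Roi((40, 40, 40), (80, 80, 80))
data = (to_ndarray(source, block_roi, fill_value=0) > 0.5).astype(np.uint8)

# write back; only the part of block_roi that intersects output.roi is written
save_ndarray(data, block_roi, output)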