diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py
index 6c1a214cd..bb1c19472 100644
--- a/dacapo/experiments/datasplits/datasplit_generator.py
+++ b/dacapo/experiments/datasplits/datasplit_generator.py
@@ -6,6 +6,7 @@
 import zarr
 from zarr.n5 import N5FSStore
+import numpy as np
 from dacapo.experiments.datasplits.datasets.arrays import (
     ZarrArrayConfig,
     ZarrArray,
@@ -15,11 +16,13 @@
     ConcatArrayConfig,
     LogicalOrArrayConfig,
     ConstantArrayConfig,
+    CropArrayConfig,
 )
 from dacapo.experiments.datasplits import TrainValidateDataSplitConfig
 from dacapo.experiments.datasplits.datasets import RawGTDatasetConfig
 import logging
 
+
 logger = logging.getLogger(__name__)
@@ -93,6 +96,39 @@ def resize_if_needed(
     return array_config
 
 
+def limit_validation_crop_size(gt_config, mask_config, max_size):
+    gt_array = gt_config.array_type(gt_config)
+    voxel_shape = gt_array.roi.shape / gt_array.voxel_size
+    crop = False
+    while np.prod(voxel_shape) > max_size:
+        crop = True
+        max_idx = np.argmax(voxel_shape)
+        voxel_shape = Coordinate(
+            s if i != max_idx else s // 2 for i, s in enumerate(voxel_shape)
+        )
+    if crop:
+        crop_roi_shape = voxel_shape * gt_array.voxel_size
+        context = (gt_array.roi.shape - crop_roi_shape) / 2
+        crop_roi = gt_array.roi.grow(-context, -context)
+        crop_roi = crop_roi.snap_to_grid(gt_array.voxel_size, mode="shrink")
+
+        logger.debug(
+            f"Cropped {gt_config.name}: original roi: {gt_array.roi}, new_roi: {crop_roi}"
+        )
+
+        gt_config = CropArrayConfig(
+            name=gt_config.name + "_cropped",
+            source_array_config=gt_config,
+            roi=crop_roi,
+        )
+        mask_config = CropArrayConfig(
+            name=mask_config.name + "_cropped",
+            source_array_config=mask_config,
+            roi=crop_roi,
+        )
+    return gt_config, mask_config
+
+
 def get_right_resolution_array_config(
     container: Path, dataset, target_resolution, extra_str=""
 ):
@@ -441,6 +477,10 @@ class DataSplitGenerator:
         The maximum raw value.
     classes_separator_character : str
         The classes separator character.
+    max_validation_volume_size : int
+        The maximum validation volume size in voxels. Default is None, in which case
+        the validation volume size is not limited; otherwise the validation volume is
+        cropped to at most the specified number of voxels, e.g. 600**3 = 216_000_000 voxels.
     Methods:
         __init__(name, datasets, input_resolution, output_resolution, targets, segmentation_type, max_gt_downsample, max_gt_upsample, max_raw_training_downsample, max_raw_training_upsample, max_raw_validation_downsample, max_raw_validation_upsample, min_training_volume_size, raw_min, raw_max, classes_separator_character)
             Initializes the DataSplitGenerator class with the specified name, datasets, input resolution, output resolution, targets, segmentation type, maximum ground truth downsample, maximum ground truth upsample, maximum raw training downsample, maximum raw training upsample, maximum raw validation downsample, maximum raw validation upsample, minimum training volume size, minimum raw value, maximum raw value, and classes separator character.
@@ -484,6 +524,7 @@ def __init__(
         raw_max=255,
         classes_separator_character="&",
         use_negative_class=False,
+        max_validation_volume_size=None,
         binarize_gt=False,
     ):
         """
@@ -580,6 +621,7 @@ def __init__(
         self.raw_max = raw_max
         self.classes_separator_character = classes_separator_character
         self.use_negative_class = use_negative_class
+        self.max_validation_volume_size = max_validation_volume_size
         self.binarize_gt = binarize_gt
         if use_negative_class:
             if targets is None:
@@ -757,6 +799,10 @@ def __generate_semantic_seg_datasplit(self):
                     )
                 )
             else:
+                if self.max_validation_volume_size is not None:
+                    gt_config, mask_config = limit_validation_crop_size(
+                        gt_config, mask_config, self.max_validation_volume_size
+                    )
                 validation_dataset_configs.append(
                     RawGTDatasetConfig(
                         name=f"{dataset}_{gt_config.name}_{classes}_{self.output_resolution[0]}nm",
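
For reference, the new limit_validation_crop_size helper repeatedly halves the
largest axis of the ground-truth volume (in voxels) until the total voxel count
fits within max_validation_volume_size, then centers the cropped ROI inside the
original one and snaps it to the voxel grid. Below is a minimal, self-contained
sketch of just the size computation, using plain Python/numpy; the hypothetical
limited_voxel_shape helper stands in for the Coordinate/ROI machinery in the
patch above.

import numpy as np

def limited_voxel_shape(voxel_shape, max_size):
    # Halve the largest axis until the total voxel count fits the budget,
    # mirroring the while-loop in limit_validation_crop_size.
    voxel_shape = list(voxel_shape)
    while np.prod(voxel_shape) > max_size:
        max_idx = int(np.argmax(voxel_shape))
        voxel_shape[max_idx] //= 2
    return tuple(voxel_shape)

# e.g. a 1200 x 1200 x 300 voxel validation volume, limited to
# max_validation_volume_size = 600**3 = 216_000_000 voxels:
print(limited_voxel_shape((1200, 1200, 300), 600**3))  # -> (600, 1200, 300)

The actual helper then converts the resulting voxel shape back to world units
(voxel_shape * voxel_size), shrinks the original ROI symmetrically by half the
size difference on each side, and wraps both the gt and mask configs in
CropArrayConfig, so validation volumes never exceed the configured budget.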