From f9959f0a1db24eb34e6bddae7396a3bd1cc9596b Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Tue, 10 Sep 2024 11:50:23 -0400 Subject: [PATCH 1/3] support max validation size --- .../datasplits/datasplit_generator.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py index d3a6cb7d6..a212d213e 100644 --- a/dacapo/experiments/datasplits/datasplit_generator.py +++ b/dacapo/experiments/datasplits/datasplit_generator.py @@ -92,6 +92,37 @@ def resize_if_needed( else: return array_config +def limit_validation_crop_size(gt_config, mask_config, max_size): + gt_array = gt_config.array_type(gt_config) + voxel_shape = gt_array.roi.shape / gt_array.voxel_size + crop = False + while np.prod(voxel_shape) > max_size: + crop = True + max_idx = np.argmax(voxel_shape) + voxel_shape = Coordinate( + s if i != max_idx else s // 2 for i, s in enumerate(voxel_shape) + ) + if crop: + crop_roi_shape = voxel_shape * gt_array.voxel_size + context = (gt_array.roi.shape - crop_roi_shape) / 2 + crop_roi = gt_array.roi.grow(-context, -context) + crop_roi = crop_roi.snap_to_grid(gt_array.voxel_size, mode="shrink") + + logger.debug( + f"Cropped {gt_config.name}: original roi: {gt_array.roi}, new_roi: {crop_roi}" + ) + + gt_config = CropArrayConfig( + name=gt_config.name + "_cropped", + source_array_config=gt_config, + roi=crop_roi, + ) + mask_config = CropArrayConfig( + name=mask_config.name + "_cropped", + source_array_config=gt_config, + roi=crop_roi, + ) + return gt_config, mask_config def get_right_resolution_array_config( container: Path, dataset, target_resolution, extra_str="" @@ -441,6 +472,10 @@ class DataSplitGenerator: The maximum raw value. classes_separator_caracter : str The classes separator character. + max_validation_volume_size : int + The maximum validation volume size. Default is None. If None, the validation volume size is not limited. + else, the validation volume size is limited to the specified value. + e.g. 600**3 for 600^3 voxels = 216_000_000 voxels. Methods: __init__(name, datasets, input_resolution, output_resolution, targets, segmentation_type, max_gt_downsample, max_gt_upsample, max_raw_training_downsample, max_raw_training_upsample, max_raw_validation_downsample, max_raw_validation_upsample, min_training_volume_size, raw_min, raw_max, classes_separator_caracter) Initializes the DataSplitGenerator class with the specified name, datasets, input resolution, output resolution, targets, segmentation type, maximum ground truth downsample, maximum ground truth upsample, maximum raw training downsample, maximum raw training upsample, maximum raw validation downsample, maximum raw validation upsample, minimum training volume size, minimum raw value, maximum raw value, and classes separator character. @@ -484,6 +519,7 @@ def __init__( raw_max=255, classes_separator_caracter="&", use_negative_class=False, + max_validation_volume_size=None, ): """ Initializes the DataSplitGenerator class with the specified: @@ -573,6 +609,7 @@ def __init__( self.raw_max = raw_max self.classes_separator_caracter = classes_separator_caracter self.use_negative_class = use_negative_class + self.max_validation_volume_size = max_validation_volume_size if use_negative_class: if targets is None: raise ValueError( @@ -749,6 +786,10 @@ def __generate_semantic_seg_datasplit(self): ) ) else: + if self.max_validation_volume_size is not None: + gt_config, mask_config = limit_validation_crop_size( + gt_config, mask_config, self.max_validation_volume_size + ) validation_dataset_configs.append( RawGTDatasetConfig( name=f"{dataset}_{gt_config.name}_{classes}_{self.output_resolution[0]}nm", From 1cb474810e03a4acc1e64171ecbb08371f012524 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Tue, 10 Sep 2024 11:54:30 -0400 Subject: [PATCH 2/3] fix import --- dacapo/experiments/datasplits/datasplit_generator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py index a212d213e..c530b6dc5 100644 --- a/dacapo/experiments/datasplits/datasplit_generator.py +++ b/dacapo/experiments/datasplits/datasplit_generator.py @@ -6,6 +6,7 @@ import zarr from zarr.n5 import N5FSStore +import numpy as np from dacapo.experiments.datasplits.datasets.arrays import ( ZarrArrayConfig, ZarrArray, @@ -15,11 +16,13 @@ ConcatArrayConfig, LogicalOrArrayConfig, ConstantArrayConfig, + CropArrayConfig, ) from dacapo.experiments.datasplits import TrainValidateDataSplitConfig from dacapo.experiments.datasplits.datasets import RawGTDatasetConfig import logging + logger = logging.getLogger(__name__) From e4a84ac51d5c01abb8fb20e6fb2cd6e909cc391f Mon Sep 17 00:00:00 2001 From: mzouink Date: Tue, 10 Sep 2024 15:55:03 +0000 Subject: [PATCH 3/3] :art: Format Python code with psf/black --- dacapo/experiments/datasplits/datasplit_generator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py index c530b6dc5..ae1dad954 100644 --- a/dacapo/experiments/datasplits/datasplit_generator.py +++ b/dacapo/experiments/datasplits/datasplit_generator.py @@ -95,6 +95,7 @@ def resize_if_needed( else: return array_config + def limit_validation_crop_size(gt_config, mask_config, max_size): gt_array = gt_config.array_type(gt_config) voxel_shape = gt_array.roi.shape / gt_array.voxel_size @@ -127,6 +128,7 @@ def limit_validation_crop_size(gt_config, mask_config, max_size): ) return gt_config, mask_config + def get_right_resolution_array_config( container: Path, dataset, target_resolution, extra_str="" ):