Skip to content

Commit

Permalink
Feat: datasplit limit validation size (#289)
Browse files Browse the repository at this point in the history
Sometimes I want to speed up validation, so as an option we can limit the
validation size:
`max_validation_volume_size = 600**3`
  • Loading branch information
mzouink authored Sep 12, 2024
2 parents 8d0dbf6 + c0b09bd commit 3430d6a
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions dacapo/experiments/datasplits/datasplit_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import zarr
from zarr.n5 import N5FSStore
import numpy as np
from dacapo.experiments.datasplits.datasets.arrays import (
ZarrArrayConfig,
ZarrArray,
Expand All @@ -15,11 +16,13 @@
ConcatArrayConfig,
LogicalOrArrayConfig,
ConstantArrayConfig,
CropArrayConfig,
)
from dacapo.experiments.datasplits import TrainValidateDataSplitConfig
from dacapo.experiments.datasplits.datasets import RawGTDatasetConfig
import logging


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -93,6 +96,39 @@ def resize_if_needed(
return array_config


def limit_validation_crop_size(gt_config, mask_config, max_size):
    """
    Crop a validation ground truth and its mask so they hold at most ``max_size`` voxels.

    The ground-truth shape in voxels is derived from its ROI and voxel size. While
    the number of voxels exceeds ``max_size``, the currently largest axis is halved.
    If at least one halving was needed, both configs are wrapped in a
    ``CropArrayConfig`` that shares one ROI: the original ground-truth ROI shrunk
    equally on both sides and snapped to the voxel grid. If the volume already
    fits, both configs are returned untouched.

    Args:
        gt_config: Array config of the ground truth. It is instantiated through
            ``gt_config.array_type(gt_config)`` to read ``roi`` and ``voxel_size``.
        mask_config: Array config of the mask that accompanies the ground truth.
        max_size (int): Maximum number of voxels allowed, e.g. ``600**3``.
            Must be at least 1.
    Returns:
        tuple: ``(gt_config, mask_config)`` - either the inputs unchanged, or two
        new ``CropArrayConfig`` objects (named ``<original name>_cropped``), each
        wrapping its own original config.
    Raises:
        ValueError: If ``max_size`` is smaller than 1.
    """
    # A limit below one voxel can never be met sensibly: a negative limit would
    # keep the halving loop below running forever (the product bottoms out at 0),
    # and a limit of 0 would yield a ROI with a zero-sized axis.
    if max_size < 1:
        raise ValueError(
            f"max_size must be a positive number of voxels, got {max_size}"
        )

    gt_array = gt_config.array_type(gt_config)
    voxel_shape = gt_array.roi.shape / gt_array.voxel_size

    crop = False
    while np.prod(voxel_shape) > max_size:
        crop = True
        # Halve only the largest axis each round so the crop stays as close to
        # isotropic as possible.
        max_idx = np.argmax(voxel_shape)
        voxel_shape = Coordinate(
            s if i != max_idx else s // 2 for i, s in enumerate(voxel_shape)
        )

    if not crop:
        return gt_config, mask_config

    # Convert the reduced voxel shape back to world units and remove the excess
    # evenly from both ends of every axis.
    crop_roi_shape = voxel_shape * gt_array.voxel_size
    context = (gt_array.roi.shape - crop_roi_shape) / 2
    crop_roi = gt_array.roi.grow(-context, -context)
    crop_roi = crop_roi.snap_to_grid(gt_array.voxel_size, mode="shrink")

    logger.debug(
        f"Cropped {gt_config.name}: original roi: {gt_array.roi}, new_roi: {crop_roi}"
    )

    cropped_gt_config = CropArrayConfig(
        name=gt_config.name + "_cropped",
        source_array_config=gt_config,
        roi=crop_roi,
    )
    # The mask crop must wrap the mask itself. Rebinding ``gt_config`` before
    # building the mask crop previously made the mask read from the cropped
    # ground truth instead.
    cropped_mask_config = CropArrayConfig(
        name=mask_config.name + "_cropped",
        source_array_config=mask_config,
        roi=crop_roi,
    )
    return cropped_gt_config, cropped_mask_config


def get_right_resolution_array_config(
container: Path, dataset, target_resolution, extra_str=""
):
Expand Down Expand Up @@ -441,6 +477,10 @@ class DataSplitGenerator:
The maximum raw value.
classes_separator_character : str
The classes separator character.
max_validation_volume_size : int
The maximum validation volume size. Default is None. If None, the validation volume size is not limited.
else, the validation volume size is limited to the specified value.
e.g. 600**3 for 600^3 voxels = 216_000_000 voxels.
Methods:
__init__(name, datasets, input_resolution, output_resolution, targets, segmentation_type, max_gt_downsample, max_gt_upsample, max_raw_training_downsample, max_raw_training_upsample, max_raw_validation_downsample, max_raw_validation_upsample, min_training_volume_size, raw_min, raw_max, classes_separator_character)
Initializes the DataSplitGenerator class with the specified name, datasets, input resolution, output resolution, targets, segmentation type, maximum ground truth downsample, maximum ground truth upsample, maximum raw training downsample, maximum raw training upsample, maximum raw validation downsample, maximum raw validation upsample, minimum training volume size, minimum raw value, maximum raw value, and classes separator character.
Expand Down Expand Up @@ -484,6 +524,7 @@ def __init__(
raw_max=255,
classes_separator_character="&",
use_negative_class=False,
max_validation_volume_size=None,
binarize_gt=False,
):
"""
Expand Down Expand Up @@ -580,6 +621,7 @@ def __init__(
self.raw_max = raw_max
self.classes_separator_character = classes_separator_character
self.use_negative_class = use_negative_class
self.max_validation_volume_size = max_validation_volume_size
self.binarize_gt = binarize_gt
if use_negative_class:
if targets is None:
Expand Down Expand Up @@ -757,6 +799,10 @@ def __generate_semantic_seg_datasplit(self):
)
)
else:
if self.max_validation_volume_size is not None:
gt_config, mask_config = limit_validation_crop_size(
gt_config, mask_config, self.max_validation_volume_size
)
validation_dataset_configs.append(
RawGTDatasetConfig(
name=f"{dataset}_{gt_config.name}_{classes}_{self.output_resolution[0]}nm",
Expand Down

0 comments on commit 3430d6a

Please sign in to comment.