Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: datasplit limit validation size #289

Merged
merged 5 commits into from
Sep 12, 2024
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions dacapo/experiments/datasplits/datasplit_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import zarr
from zarr.n5 import N5FSStore
import numpy as np
from dacapo.experiments.datasplits.datasets.arrays import (
ZarrArrayConfig,
ZarrArray,
Expand All @@ -15,11 +16,13 @@
ConcatArrayConfig,
LogicalOrArrayConfig,
ConstantArrayConfig,
CropArrayConfig,
)
from dacapo.experiments.datasplits import TrainValidateDataSplitConfig
from dacapo.experiments.datasplits.datasets import RawGTDatasetConfig
import logging


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -93,6 +96,39 @@ def resize_if_needed(
return array_config


def limit_validation_crop_size(gt_config, mask_config, max_size):
    """Crop a validation GT/mask pair down to at most ``max_size`` voxels.

    Repeatedly shrinks the largest axis of the GT ROI until the total voxel
    count is <= ``max_size``, then wraps both configs in ``CropArrayConfig``s
    sharing the same centered, grid-snapped ROI.

    Args:
        gt_config: Array config for the ground-truth volume.
        mask_config: Array config for the matching mask volume
            (assumed non-None here; the caller only reaches this path with a
            mask present — TODO confirm).
        max_size (int): Maximum allowed number of voxels (e.g. ``600**3``).

    Returns:
        (gt_config, mask_config): the original configs if no crop was needed,
        otherwise new ``CropArrayConfig``s restricted to the computed ROI.
    """
    gt_array = gt_config.array_type(gt_config)
    voxel_shape = gt_array.roi.shape / gt_array.voxel_size
    cropped = False
    while np.prod(voxel_shape) > max_size:
        cropped = True
        max_idx = int(np.argmax(voxel_shape))
        # Exact size the largest axis would need for the volume to just fit,
        # given the other axes stay fixed.
        other_axes_prod = int(
            np.prod([s for i, s in enumerate(voxel_shape) if i != max_idx])
        )
        exact_fit = max_size // other_axes_prod
        # Halve the largest axis, but never shrink past the exact-fit size:
        # e.g. 100x100x110 with max_size=100**3 becomes 100x100x100 rather
        # than overshooting to 100x100x55. Floor at 1 so no axis collapses.
        voxel_shape = Coordinate(
            s if i != max_idx else max(s // 2, exact_fit, 1)
            for i, s in enumerate(voxel_shape)
        )
    if not cropped:
        return gt_config, mask_config

    crop_roi_shape = voxel_shape * gt_array.voxel_size
    # Center the cropped ROI inside the original ROI, then snap inward so it
    # stays aligned to the voxel grid.
    context = (gt_array.roi.shape - crop_roi_shape) / 2
    crop_roi = gt_array.roi.grow(-context, -context)
    crop_roi = crop_roi.snap_to_grid(gt_array.voxel_size, mode="shrink")

    logger.debug(
        f"Cropped {gt_config.name}: original roi: {gt_array.roi}, new_roi: {crop_roi}"
    )

    cropped_gt_config = CropArrayConfig(
        name=gt_config.name + "_cropped",
        source_array_config=gt_config,
        roi=crop_roi,
    )
    cropped_mask_config = CropArrayConfig(
        name=mask_config.name + "_cropped",
        # Crop the mask itself — previously this wrapped the (already
        # cropped) gt_config, so the mask crop never saw the mask data.
        source_array_config=mask_config,
        roi=crop_roi,
    )
    return cropped_gt_config, cropped_mask_config


def get_right_resolution_array_config(
container: Path, dataset, target_resolution, extra_str=""
):
Expand Down Expand Up @@ -441,6 +477,10 @@ class DataSplitGenerator:
The maximum raw value.
classes_separator_caracter : str
The classes separator character.
max_validation_volume_size : int
The maximum validation volume size, in voxels. Default is None. If None, the validation volume size is not limited;
otherwise, the validation volume is cropped down to at most the specified number of voxels.
e.g. 600**3 for 600^3 voxels = 216_000_000 voxels.
Methods:
__init__(name, datasets, input_resolution, output_resolution, targets, segmentation_type, max_gt_downsample, max_gt_upsample, max_raw_training_downsample, max_raw_training_upsample, max_raw_validation_downsample, max_raw_validation_upsample, min_training_volume_size, raw_min, raw_max, classes_separator_caracter)
Initializes the DataSplitGenerator class with the specified name, datasets, input resolution, output resolution, targets, segmentation type, maximum ground truth downsample, maximum ground truth upsample, maximum raw training downsample, maximum raw training upsample, maximum raw validation downsample, maximum raw validation upsample, minimum training volume size, minimum raw value, maximum raw value, and classes separator character.
Expand Down Expand Up @@ -484,6 +524,7 @@ def __init__(
raw_max=255,
classes_separator_caracter="&",
use_negative_class=False,
max_validation_volume_size=None,
):
"""
Initializes the DataSplitGenerator class with the specified:
Expand Down Expand Up @@ -573,6 +614,7 @@ def __init__(
self.raw_max = raw_max
self.classes_separator_caracter = classes_separator_caracter
self.use_negative_class = use_negative_class
self.max_validation_volume_size = max_validation_volume_size
if use_negative_class:
if targets is None:
raise ValueError(
Expand Down Expand Up @@ -749,6 +791,10 @@ def __generate_semantic_seg_datasplit(self):
)
)
else:
if self.max_validation_volume_size is not None:
gt_config, mask_config = limit_validation_crop_size(
gt_config, mask_config, self.max_validation_volume_size
)
validation_dataset_configs.append(
RawGTDatasetConfig(
name=f"{dataset}_{gt_config.name}_{classes}_{self.output_resolution[0]}nm",
Expand Down
Loading