Skip to content

Commit

Permalink
Specify resolution with sequences (#251)
Browse files Browse the repository at this point in the history
Let's specify resolution also directly like `(8, 8, 8)`, in addition to
`Coordinate(8, 8, 8)`?
```python
datasplit_config = DataSplitGenerator.generate_from_csv(
    'test.csv',
    input_resolution=(8, 8, 8),  # This works.
    output_resolution=Coordinate(4, 4, 4),  # And this works.
)
```
  • Loading branch information
mzouink authored May 9, 2024
2 parents 9d2384a + c306a7c commit 3c7e309
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions dacapo/experiments/datasplits/datasplit_generator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from dacapo.experiments.tasks import TaskConfig
from upath import UPath as Path
from typing import List
from typing import List, Union, Optional, Sequence
from enum import Enum, EnumMeta
from funlib.geometry import Coordinate
from typing import Union, Optional

import zarr
from zarr.n5 import N5FSStore
Expand Down Expand Up @@ -389,12 +388,14 @@ def generate_dataspec_from_csv(csv_path: Path):


class DataSplitGenerator:
"""
Generates DataSplitConfig for a given task config and datasets. A csv file can be generated
from the DataSplitConfig and used to generate the DataSplitConfig again.
"""Generates DataSplitConfig for a given task config and datasets.
Class names in gt_dataset should be within [] e.g. [mito&peroxisome&er] for
multiple classes or [mito] for one class.
Currently only supports semantic segmentation.
Supports:
Currently only supports:
- semantic segmentation.
Supports:
- 2D and 3D datasets.
- Zarr, N5 and OME-Zarr datasets.
- Multi class targets.
Expand Down Expand Up @@ -462,8 +463,8 @@ def __init__(
self,
name: str,
datasets: List[DatasetSpec],
input_resolution: Coordinate,
output_resolution: Coordinate,
input_resolution: Union[Sequence[int], Coordinate],
output_resolution: Union[Sequence[int], Coordinate],
targets: Optional[List[str]] = None,
segmentation_type: Union[str, SegmentationType] = "semantic",
max_gt_downsample=32,
Expand Down Expand Up @@ -540,16 +541,19 @@ def __init__(
This function is used to initialize the DataSplitGenerator class with the specified name, datasets, input resolution, output resolution, targets, segmentation type, maximum ground truth downsample, maximum ground truth upsample, maximum raw training downsample, maximum raw training upsample, maximum raw validation downsample, maximum raw validation upsample, minimum training volume size, minimum raw value, maximum raw value, and classes separator character.
"""
if not isinstance(input_resolution, Coordinate):
input_resolution = Coordinate(input_resolution)
if not isinstance(output_resolution, Coordinate):
output_resolution = Coordinate(output_resolution)
if isinstance(segmentation_type, str):
segmentation_type = SegmentationType[segmentation_type.lower()]

self.name = name
self.datasets = datasets
self.input_resolution = input_resolution
self.output_resolution = output_resolution
self.targets = targets
self._class_name = None

if isinstance(segmentation_type, str):
segmentation_type = SegmentationType[segmentation_type.lower()]

self.segmentation_type = segmentation_type
self.max_gt_downsample = max_gt_downsample
self.max_gt_upsample = max_gt_upsample
Expand Down Expand Up @@ -844,8 +848,8 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
@staticmethod
def generate_from_csv(
csv_path: Path,
input_resolution: Coordinate,
output_resolution: Coordinate,
input_resolution: Union[Sequence[int], Coordinate],
output_resolution: Union[Sequence[int], Coordinate],
name: Optional[str] = None,
**kwargs,
):
Expand Down

0 comments on commit 3c7e309

Please sign in to comment.