diff --git a/dacapo/experiments/datasplits/__init__.py b/dacapo/experiments/datasplits/__init__.py index ad1ad4880..f70ec1a71 100644 --- a/dacapo/experiments/datasplits/__init__.py +++ b/dacapo/experiments/datasplits/__init__.py @@ -5,3 +5,4 @@ from .train_validate_datasplit import TrainValidateDataSplit from .train_validate_datasplit_config import TrainValidateDataSplitConfig from .datasplit_generator import DataSplitGenerator, DatasetSpec +from .simple_config import SimpleDataSplitConfig \ No newline at end of file diff --git a/dacapo/experiments/datasplits/datasets/__init__.py b/dacapo/experiments/datasplits/datasets/__init__.py index edcffd8ef..c886eea19 100644 --- a/dacapo/experiments/datasplits/datasets/__init__.py +++ b/dacapo/experiments/datasplits/datasets/__init__.py @@ -4,3 +4,4 @@ from .dummy_dataset_config import DummyDatasetConfig from .raw_gt_dataset import RawGTDataset from .raw_gt_dataset_config import RawGTDatasetConfig +from .simple import SimpleDataset \ No newline at end of file diff --git a/dacapo/experiments/datasplits/datasets/dummy_dataset.py b/dacapo/experiments/datasplits/datasets/dummy_dataset.py index b8e6a2ae0..b73f1a051 100644 --- a/dacapo/experiments/datasplits/datasets/dummy_dataset.py +++ b/dacapo/experiments/datasplits/datasets/dummy_dataset.py @@ -1,6 +1,8 @@ from .dataset import Dataset from funlib.persistence import Array +import warnings + class DummyDataset(Dataset): """ @@ -15,6 +17,7 @@ class DummyDataset(Dataset): Notes: This class is used to create a dataset with raw data. """ + raw: Array @@ -34,5 +37,11 @@ def __init__(self, dataset_config): This method is used to initialize the dataset. """ super().__init__() + + warnings.warn( + "DummyDataset is deprecated. 
Use SimpleDataset instead.", + DeprecationWarning, + ) + self.name = dataset_config.name self.raw = dataset_config.raw_config.array() diff --git a/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py b/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py index 8af1068f9..6da920ec0 100644 --- a/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py +++ b/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py @@ -4,6 +4,7 @@ from funlib.geometry import Coordinate from typing import Optional, List +import warnings class RawGTDataset(Dataset): @@ -48,6 +49,12 @@ def __init__(self, dataset_config): Notes: This method is used to initialize the dataset. """ + + warnings.warn( + "RawGTDataset is deprecated. Use SimpleDataset instead.", + DeprecationWarning, + ) + self.name = dataset_config.name self.raw = dataset_config.raw_config.array() self.gt = dataset_config.gt_config.array() diff --git a/dacapo/experiments/datasplits/datasets/simple.py b/dacapo/experiments/datasplits/datasets/simple.py new file mode 100644 index 000000000..5c73c2537 --- /dev/null +++ b/dacapo/experiments/datasplits/datasets/simple.py @@ -0,0 +1,69 @@ +from .dataset_config import DatasetConfig + +from funlib.persistence import Array, open_ds + + +import attr + +from pathlib import Path +import numpy as np + +@attr.s +class SimpleDataset(DatasetConfig): + + path: Path = attr.ib() + weight: int = attr.ib(default=1) + raw_name: str = attr.ib(default="raw") + gt_name: str = attr.ib(default="labels") + mask_name: str = attr.ib(default="mask") + + @staticmethod + def dataset_type(dataset_config): + return dataset_config + + @property + def raw(self) -> Array: + raw_array = open_ds(self.path / self.raw_name) + dtype = raw_array.dtype + if dtype == np.uint8: + raw_array.lazy_op(lambda data: data.astype(np.float32) / 255) + elif dtype == np.uint16: + raw_array.lazy_op(lambda data: data.astype(np.float32) / 65535) + elif np.issubdtype(dtype, np.floating): + pass + elif np.issubdtype(dtype, 
np.integer): + raise Exception( + f"Not sure how to normalize intensity data with dtype {dtype}" + ) + return raw_array + + @property + def gt(self) -> Array: + return open_ds(self.path / self.gt_name) + + @property + def mask(self) -> Array | None: + mask_path = self.path / self.mask_name + if mask_path.exists(): + mask = open_ds(mask_path) + assert np.issubdtype(mask.dtype, np.integer), "Mask must be integer type" + mask.lazy_op(lambda data: data > 0) + return mask + return None + + @property + def sample_points(self) -> None: + return None + + + def __eq__(self, other) -> bool: + return isinstance(other, type(self)) and self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + + def __repr__(self) -> str: + return self.name + + def __str__(self) -> str: + return self.name \ No newline at end of file diff --git a/dacapo/experiments/datasplits/dummy_datasplit.py b/dacapo/experiments/datasplits/dummy_datasplit.py index b8bde7327..20342040d 100644 --- a/dacapo/experiments/datasplits/dummy_datasplit.py +++ b/dacapo/experiments/datasplits/dummy_datasplit.py @@ -2,6 +2,7 @@ from .datasets import Dataset from typing import List +import warnings class DummyDataSplit(DataSplit): @@ -41,6 +42,10 @@ def __init__(self, datasplit_config): This function is called by the DummyDataSplit class to initialize the DummyDataSplit class with specified config to split the data into training and validation datasets. """ super().__init__() + warnings.warn( + "DummyDataSplit is deprecated. 
Use SimpleDataSplitConfig instead.", + DeprecationWarning, + ) self.train = [ datasplit_config.train_config.dataset_type(datasplit_config.train_config) diff --git a/dacapo/experiments/datasplits/simple_config.py b/dacapo/experiments/datasplits/simple_config.py new file mode 100644 index 000000000..8e65f56b8 --- /dev/null +++ b/dacapo/experiments/datasplits/simple_config.py @@ -0,0 +1,69 @@ +from .datasets.simple import SimpleDataset +from .datasplit_config import DataSplitConfig + +import attr + +from pathlib import Path + +import glob + +@attr.s +class SimpleDataSplitConfig(DataSplitConfig): + """ + A convention over configuration datasplit that can handle many of the most + basic cases. + """ + + path: Path = attr.ib() + name: str = attr.ib() + train_group_name: str = attr.ib(default="train") + validate_group_name: str = attr.ib(default="test") + raw_name: str = attr.ib(default="raw") + gt_name: str = attr.ib(default="labels") + mask_name: str = attr.ib(default="mask") + + @staticmethod + def datasplit_type(datasplit_config): + return datasplit_config + + def get_paths(self, group_name: str) -> list[Path]: + level_0 = f"{self.path}/{self.raw_name}" + level_1 = f"{self.path}/{group_name}/{self.raw_name}" + level_2 = f"{self.path}/{group_name}/**/{self.raw_name}" + level_0_matches = glob.glob(level_0) + level_1_matches = glob.glob(level_1) + level_2_matches = glob.glob(level_2) + if len(level_0_matches) > 0: + assert ( + len(level_1_matches) == len(level_2_matches) == 0 + ), f"Found raw data at {level_0} and {level_1} and {level_2}" + return [Path(x).parent for x in level_0_matches] + elif len(level_1_matches) > 0: + assert ( + len(level_2_matches) == 0 + ), f"Found raw data at {level_1} and {level_2}" + return [Path(x).parent for x in level_1_matches] + elif len(level_2_matches) > 0: + return [Path(x).parent for x in level_2_matches] + + raise Exception(f"No raw data found at {level_0} or {level_1} or {level_2}") + + @property + def train(self) -> 
list[SimpleDataset]: + return [ + SimpleDataset( + name=x.stem, + path=x, + ) + for x in self.get_paths(self.train_group_name) + ] + + @property + def validate(self) -> list[SimpleDataset]: + return [ + SimpleDataset( + name=x.stem, + path=x, + ) + for x in self.get_paths(self.validate_group_name) + ] diff --git a/dacapo/experiments/datasplits/train_validate_datasplit.py b/dacapo/experiments/datasplits/train_validate_datasplit.py index 0b93663a3..abf57e9b4 100644 --- a/dacapo/experiments/datasplits/train_validate_datasplit.py +++ b/dacapo/experiments/datasplits/train_validate_datasplit.py @@ -2,6 +2,7 @@ from .datasets import Dataset from typing import List +import warnings class TrainValidateDataSplit(DataSplit): @@ -47,6 +48,10 @@ def __init__(self, datasplit_config): into training and validation datasets. """ super().__init__() + warnings.warn( + "TrainValidateDataSplit is deprecated. Use SimpleDataSplitConfig instead.", + DeprecationWarning, + ) self.train = [ train_config.dataset_type(train_config)