Deprecate old data configs, add simplified data config
pattonw committed Nov 12, 2024
1 parent d59bd8f commit 565b964
Showing 8 changed files with 166 additions and 0 deletions.
1 change: 1 addition & 0 deletions dacapo/experiments/datasplits/__init__.py
@@ -5,3 +5,4 @@
from .train_validate_datasplit import TrainValidateDataSplit
from .train_validate_datasplit_config import TrainValidateDataSplitConfig
from .datasplit_generator import DataSplitGenerator, DatasetSpec
from .simple_config import SimpleDataSplitConfig
1 change: 1 addition & 0 deletions dacapo/experiments/datasplits/datasets/__init__.py
@@ -4,3 +4,4 @@
from .dummy_dataset_config import DummyDatasetConfig
from .raw_gt_dataset import RawGTDataset
from .raw_gt_dataset_config import RawGTDatasetConfig
from .simple import SimpleDataset
9 changes: 9 additions & 0 deletions dacapo/experiments/datasplits/datasets/dummy_dataset.py
@@ -1,6 +1,8 @@
from .dataset import Dataset
from funlib.persistence import Array

import warnings


class DummyDataset(Dataset):
"""
@@ -15,6 +17,7 @@ class DummyDataset(Dataset):
Notes:
This class is used to create a dataset with raw data.
"""


raw: Array

@@ -34,5 +37,11 @@ def __init__(self, dataset_config):
This method is used to initialize the dataset.
"""
super().__init__()

warnings.warn(
"DummyDataset is deprecated. Use SimpleDataset instead.",
DeprecationWarning,
)

self.name = dataset_config.name
self.raw = dataset_config.raw_config.array()
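
Note: DeprecationWarning is ignored by default in most Python configurations, so users of the old configs may not see these new messages. A minimal sketch, using only the standard-library warnings module, of how a downstream project could opt in (or fail fast in CI):

import warnings

# Show the migration hints added in this commit (hidden by default):
warnings.filterwarnings("default", category=DeprecationWarning)

# Or turn them into errors in CI so deprecated configs get migrated:
# warnings.filterwarnings("error", category=DeprecationWarning)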
7 changes: 7 additions & 0 deletions dacapo/experiments/datasplits/datasets/raw_gt_dataset.py
@@ -4,6 +4,7 @@
from funlib.geometry import Coordinate

from typing import Optional, List
import warnings


class RawGTDataset(Dataset):
@@ -48,6 +49,12 @@ def __init__(self, dataset_config):
Notes:
This method is used to initialize the dataset.
"""

warnings.warn(
"RawGTDataset is deprecated. Use SimpleDataset instead.",
DeprecationWarning,
)

self.name = dataset_config.name
self.raw = dataset_config.raw_config.array()
self.gt = dataset_config.gt_config.array()
69 changes: 69 additions & 0 deletions dacapo/experiments/datasplits/datasets/simple.py
@@ -0,0 +1,69 @@
from .dataset_config import DatasetConfig

from funlib.persistence import Array, open_ds


import attr

from pathlib import Path
import numpy as np

@attr.s
class SimpleDataset(DatasetConfig):

path: Path = attr.ib()
weight: int = attr.ib(default=1)
raw_name: str = attr.ib(default="raw")
gt_name: str = attr.ib(default="labels")
mask_name: str = attr.ib(default="mask")

@staticmethod
def dataset_type(dataset_config):
return dataset_config

@property
def raw(self) -> Array:
raw_array = open_ds(self.path / self.raw_name)
dtype = raw_array.dtype
if dtype == np.uint8:
raw_array.lazy_op(lambda data: data.astype(np.float32) / 255)
elif dtype == np.uint16:
raw_array.lazy_op(lambda data: data.astype(np.float32) / 65535)
elif np.issubdtype(dtype, np.floating):
pass
elif np.issubdtype(dtype, np.integer):
raise Exception(
f"Not sure how to normalize intensity data with dtype {dtype}"
)
return raw_array

@property
def gt(self) -> Array:
return open_ds(self.path / self.gt_name)

@property
def mask(self) -> Array | None:
mask_path = self.path / self.mask_name
if mask_path.exists():
mask = open_ds(mask_path)
assert np.issubdtype(mask.dtype, np.integer), "Mask must be integer type"
mask.lazy_op(lambda data: data > 0)
return mask
return None

@property
def sample_points(self) -> None:
return None


def __eq__(self, other) -> bool:
return isinstance(other, type(self)) and self.name == other.name

def __hash__(self) -> int:
return hash(self.name)

def __repr__(self) -> str:
return self.name

def __str__(self) -> str:
return self.name
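
For context, a sketch of how the new dataset might be used; the container path here is hypothetical, and assumes a zarr container laid out to match the raw_name/gt_name/mask_name defaults above:

from pathlib import Path

from dacapo.experiments.datasplits.datasets import SimpleDataset

# Hypothetical container holding "raw", "labels", and optionally "mask" arrays.
dataset = SimpleDataset(name="crop_1", path=Path("/data/crop_1.zarr"))

raw = dataset.raw    # uint8/uint16 data is lazily rescaled to float32 in [0, 1]
gt = dataset.gt      # opened as-is
mask = dataset.mask  # binarized (> 0) if <path>/mask exists, otherwise None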
5 changes: 5 additions & 0 deletions dacapo/experiments/datasplits/dummy_datasplit.py
@@ -2,6 +2,7 @@
from .datasets import Dataset

from typing import List
import warnings


class DummyDataSplit(DataSplit):
@@ -41,6 +42,10 @@ def __init__(self, datasplit_config):
This function is called by the DummyDataSplit class to initialize the DummyDataSplit class with specified config to split the data into training and validation datasets.
"""
super().__init__()
warnings.warn(
"TrainValidateDataSplit is deprecated. Use SimpleDataSplitConfig instead.",
DeprecationWarning,
)

self.train = [
datasplit_config.train_config.dataset_type(datasplit_config.train_config)
69 changes: 69 additions & 0 deletions dacapo/experiments/datasplits/simple_config.py
@@ -0,0 +1,69 @@
from .datasets.simple import SimpleDataset
from .datasplit_config import DataSplitConfig

import attr

from pathlib import Path

import glob

@attr.s
class SimpleDataSplitConfig(DataSplitConfig):
"""
A convention over configuration datasplit that can handle many of the most
basic cases.
"""

path: Path = attr.ib()
name: str = attr.ib()
train_group_name: str = attr.ib(default="train")
validate_group_name: str = attr.ib(default="test")
raw_name: str = attr.ib(default="raw")
gt_name: str = attr.ib(default="labels")
mask_name: str = attr.ib(default="mask")

@staticmethod
def datasplit_type(datasplit_config):
return datasplit_config

def get_paths(self, group_name: str) -> list[Path]:
level_0 = f"{self.path}/{self.raw_name}"
level_1 = f"{self.path}/{group_name}/{self.raw_name}"
level_2 = f"{self.path}/{group_name}/**/{self.raw_name}"
level_0_matches = glob.glob(level_0)
level_1_matches = glob.glob(level_1)
level_2_matches = glob.glob(level_2)
if len(level_0_matches) > 0:
assert (
len(level_1_matches) == len(level_2_matches) == 0
), f"Found raw data at {level_0} and {level_1} and {level_2}"
return [Path(x).parent for x in level_0_matches]
elif len(level_1_matches) > 0:
assert (
len(level_2_matches) == 0
), f"Found raw data at {level_1} and {level_2}"
return [Path(x).parent for x in level_1_matches]
elif len(level_2_matches) > 0:
return [Path(x).parent for x in level_2_matches]

raise Exception(f"No raw data found at {level_0} or {level_1} or {level_2}")

@property
def train(self) -> list[SimpleDataset]:
return [
SimpleDataset(
name=x.stem,
path=x,
)
for x in self.get_paths(self.train_group_name)
]

@property
def validate(self) -> list[SimpleDataset]:
return [
SimpleDataset(
name=x.stem,
path=x,
)
for x in self.get_paths(self.validate_group_name)
]
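
The three glob levels in get_paths correspond to three container layouts. A short sketch of the convention and a hypothetical instantiation (the container path is an assumption):

# Layouts resolved by get_paths, using the defaults above:
#   level 0: <path>/raw                  a single unsplit dataset
#   level 1: <path>/train/raw            one dataset per group
#   level 2: <path>/train/<crop>/raw     one dataset per crop in the group

from pathlib import Path

from dacapo.experiments.datasplits import SimpleDataSplitConfig

datasplit = SimpleDataSplitConfig(name="example", path=Path("/data/project.zarr"))

train_datasets = datasplit.train        # one SimpleDataset per matched path
validate_datasets = datasplit.validate  # uses validate_group_name ("test")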
5 changes: 5 additions & 0 deletions dacapo/experiments/datasplits/train_validate_datasplit.py
@@ -2,6 +2,7 @@
from .datasets import Dataset

from typing import List
import warnings


class TrainValidateDataSplit(DataSplit):
@@ -47,6 +48,10 @@ def __init__(self, datasplit_config):
into training and validation datasets.
"""
super().__init__()
warnings.warn(
"TrainValidateDataSplit is deprecated. Use SimpleDataSplitConfig instead.",
DeprecationWarning,
)

self.train = [
train_config.dataset_type(train_config)
