Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Data conventions #335

Merged
merged 10 commits into from
Nov 13, 2024
Prev Previous commit
Next Next commit
Deprecate old data configs, add simplified data config
pattonw committed Nov 12, 2024
commit 565b964c9e57cc268fba479f8631f15d8b8ccb4f
1 change: 1 addition & 0 deletions dacapo/experiments/datasplits/__init__.py
Original file line number Diff line number Diff line change
@@ -5,3 +5,4 @@
from .train_validate_datasplit import TrainValidateDataSplit
from .train_validate_datasplit_config import TrainValidateDataSplitConfig
from .datasplit_generator import DataSplitGenerator, DatasetSpec
from .simple_config import SimpleDataSplitConfig
1 change: 1 addition & 0 deletions dacapo/experiments/datasplits/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -4,3 +4,4 @@
from .dummy_dataset_config import DummyDatasetConfig
from .raw_gt_dataset import RawGTDataset
from .raw_gt_dataset_config import RawGTDatasetConfig
from .simple import SimpleDataset
9 changes: 9 additions & 0 deletions dacapo/experiments/datasplits/datasets/dummy_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from .dataset import Dataset
from funlib.persistence import Array

import warnings


class DummyDataset(Dataset):
"""
@@ -15,6 +17,7 @@ class DummyDataset(Dataset):
Notes:
This class is used to create a dataset with raw data.
"""


raw: Array

@@ -34,5 +37,11 @@ def __init__(self, dataset_config):
This method is used to initialize the dataset.
"""
super().__init__()

warnings.warn(
"DummyDataset is deprecated. Use SimpleDataset instead.",
DeprecationWarning,
)

self.name = dataset_config.name
self.raw = dataset_config.raw_config.array()
7 changes: 7 additions & 0 deletions dacapo/experiments/datasplits/datasets/raw_gt_dataset.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
from funlib.geometry import Coordinate

from typing import Optional, List
import warnings


class RawGTDataset(Dataset):
@@ -48,6 +49,12 @@ def __init__(self, dataset_config):
Notes:
This method is used to initialize the dataset.
"""

warnings.warn(
"RawGTDataset is deprecated. Use SimpleDataset instead.",
DeprecationWarning,
)

self.name = dataset_config.name
self.raw = dataset_config.raw_config.array()
self.gt = dataset_config.gt_config.array()
69 changes: 69 additions & 0 deletions dacapo/experiments/datasplits/datasets/simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from .dataset_config import DatasetConfig

from funlib.persistence import Array, open_ds


import attr

from pathlib import Path
import numpy as np

@attr.s(eq=False, repr=False)
class SimpleDataset(DatasetConfig):
    """
    A convention-over-configuration dataset rooted at a single directory.

    Expects a ``raw`` array (and optionally ``labels`` and ``mask`` arrays)
    to live directly under ``path``; arrays are opened lazily via ``open_ds``.

    NOTE: ``eq=False, repr=False`` is required with legacy ``@attr.s`` —
    otherwise attrs regenerates ``__eq__``/``__repr__`` and sets
    ``__hash__ = None``, silently discarding the name-based dunders
    defined below and making instances unhashable.
    """

    # Directory containing the raw/labels/mask arrays.
    path: Path = attr.ib()
    # Relative sampling weight of this dataset during training.
    weight: int = attr.ib(default=1)
    # Array names, resolved relative to ``path``.
    raw_name: str = attr.ib(default="raw")
    gt_name: str = attr.ib(default="labels")
    mask_name: str = attr.ib(default="mask")

    @staticmethod
    def dataset_type(dataset_config):
        # The config doubles as the dataset instance.
        return dataset_config

    @property
    def raw(self) -> Array:
        """Open the raw array, normalizing integer intensities to [0, 1]."""
        raw_array = open_ds(self.path / self.raw_name)
        dtype = raw_array.dtype
        if dtype == np.uint8:
            raw_array.lazy_op(lambda data: data.astype(np.float32) / 255)
        elif dtype == np.uint16:
            raw_array.lazy_op(lambda data: data.astype(np.float32) / 65535)
        elif np.issubdtype(dtype, np.floating):
            # Floating point data is assumed to already be normalized.
            pass
        elif np.issubdtype(dtype, np.integer):
            # Signed / wider integer types: no obvious normalization rule.
            raise Exception(
                f"Not sure how to normalize intensity data with dtype {dtype}"
            )
        return raw_array

    @property
    def gt(self) -> Array:
        """Open the ground-truth label array."""
        return open_ds(self.path / self.gt_name)

    @property
    def mask(self) -> Array | None:
        """Open the mask as a binary (``> 0``) array, or None if absent."""
        mask_path = self.path / self.mask_name
        if mask_path.exists():
            mask = open_ds(mask_path)
            assert np.issubdtype(mask.dtype, np.integer), "Mask must be integer type"
            mask.lazy_op(lambda data: data > 0)
            return mask
        return None

    @property
    def sample_points(self) -> None:
        # This simple dataset does not provide sample points.
        return None

    def __eq__(self, other) -> bool:
        # Identity is by name only, not by field values.
        return isinstance(other, type(self)) and self.name == other.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __repr__(self) -> str:
        return self.name

    def __str__(self) -> str:
        return self.name
5 changes: 5 additions & 0 deletions dacapo/experiments/datasplits/dummy_datasplit.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
from .datasets import Dataset

from typing import List
import warnings


class DummyDataSplit(DataSplit):
@@ -41,6 +42,10 @@ def __init__(self, datasplit_config):
This function is called by the DummyDataSplit class to initialize the DummyDataSplit class with specified config to split the data into training and validation datasets.
"""
super().__init__()
warnings.warn(
"TrainValidateDataSplit is deprecated. Use SimpleDataSplitConfig instead.",
DeprecationWarning,
)

self.train = [
datasplit_config.train_config.dataset_type(datasplit_config.train_config)
69 changes: 69 additions & 0 deletions dacapo/experiments/datasplits/simple_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from .datasets.simple import SimpleDataset
from .datasplit_config import DataSplitConfig

import attr

from pathlib import Path

import glob

@attr.s
class SimpleDataSplitConfig(DataSplitConfig):
    """
    A convention over configuration datasplit that can handle many of the most
    basic cases.

    Data is expected under ``path`` in exactly one of three layouts:

    * ``{path}/{raw_name}``                    (single dataset, no split)
    * ``{path}/{group}/{raw_name}``            (one dataset per group)
    * ``{path}/{group}/<dataset>/{raw_name}``  (many datasets per group)
    """

    # Root directory containing the data.
    path: Path = attr.ib()
    # Unique name for this datasplit.
    name: str = attr.ib()
    # Sub-directory names for the train/validate groups.
    train_group_name: str = attr.ib(default="train")
    validate_group_name: str = attr.ib(default="test")
    # Array names within each dataset directory.
    raw_name: str = attr.ib(default="raw")
    gt_name: str = attr.ib(default="labels")
    mask_name: str = attr.ib(default="mask")

    @staticmethod
    def datasplit_type(datasplit_config):
        # The config doubles as the datasplit instance.
        return datasplit_config

    def get_paths(self, group_name: str) -> list[Path]:
        """Return the dataset directories found for *group_name*.

        Exactly one of the three supported layouts must match: mixing
        layouts raises an AssertionError; no match raises an Exception.
        """
        level_0 = f"{self.path}/{self.raw_name}"
        level_1 = f"{self.path}/{group_name}/{self.raw_name}"
        # NOTE(review): glob treats "**" like "*" unless recursive=True is
        # passed, so this matches exactly one intermediate directory level —
        # confirm that is the intent before relying on deeper nesting.
        level_2 = f"{self.path}/{group_name}/**/{self.raw_name}"
        level_0_matches = glob.glob(level_0)
        level_1_matches = glob.glob(level_1)
        level_2_matches = glob.glob(level_2)
        if len(level_0_matches) > 0:
            assert (
                len(level_1_matches) == len(level_2_matches) == 0
            ), f"Found raw data at {level_0} and {level_1} and {level_2}"
            return [Path(x).parent for x in level_0_matches]
        elif len(level_1_matches) > 0:
            assert (
                len(level_2_matches) == 0
            ), f"Found raw data at {level_1} and {level_2}"
            return [Path(x).parent for x in level_1_matches]
        elif len(level_2_matches) > 0:
            # Bug fix: the original read ``len(level_2_matches).parent > 0``
            # (AttributeError on an int) and returned the raw-array paths
            # themselves; ``.parent`` belongs on each match so we return the
            # dataset directory, consistent with the branches above.
            return [Path(x).parent for x in level_2_matches]

        raise Exception(f"No raw data found at {level_0} or {level_1} or {level_2}")

    @property
    def train(self) -> list[SimpleDataset]:
        """Datasets discovered under the training group."""
        return [
            SimpleDataset(
                name=x.stem,
                path=x,
            )
            for x in self.get_paths(self.train_group_name)
        ]

    @property
    def validate(self) -> list[SimpleDataset]:
        """Datasets discovered under the validation group."""
        return [
            SimpleDataset(
                name=x.stem,
                path=x,
            )
            for x in self.get_paths(self.validate_group_name)
        ]
5 changes: 5 additions & 0 deletions dacapo/experiments/datasplits/train_validate_datasplit.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
from .datasets import Dataset

from typing import List
import warnings


class TrainValidateDataSplit(DataSplit):
@@ -47,6 +48,10 @@ def __init__(self, datasplit_config):
into training and validation datasets.
"""
super().__init__()
warnings.warn(
"TrainValidateDataSplit is deprecated. Use SimpleDataSplitConfig instead.",
DeprecationWarning,
)

self.train = [
train_config.dataset_type(train_config)