From d2b4b6211640d20a4359559874ad76fd195b7d64 Mon Sep 17 00:00:00 2001
From: Anwai Archit <52396323+anwai98@users.noreply.github.com>
Date: Wed, 9 Oct 2024 22:35:45 +0200
Subject: [PATCH] Add function in datasets to get input paths (#367)

* Add function in datasets to get input paths

* Add _paths functionality to all microscopy datasets

* Make download optional for livecell functions
---
 .../data/datasets/electron_microscopy/asem.py |  47 +++---
 .../electron_microscopy/axondeepseg.py        |  69 +++++---
 .../data/datasets/electron_microscopy/cem.py  | 157 ++++++++++--------
 .../datasets/electron_microscopy/cremi.py     |  62 ++++---
 .../datasets/electron_microscopy/deepict.py   |  52 ++++--
 .../datasets/electron_microscopy/isbi2012.py  |  62 ++++---
 .../datasets/electron_microscopy/kasthuri.py  |  49 ++++--
 .../datasets/electron_microscopy/lucchi.py    |  52 ++++--
 .../datasets/electron_microscopy/mitoem.py    |  60 ++++---
 .../datasets/electron_microscopy/nuc_mm.py    |  47 ++++--
 .../electron_microscopy/platynereis.py        | 154 +++++++++++------
 .../datasets/electron_microscopy/snemi.py     |  41 +++--
 .../datasets/electron_microscopy/sponge_em.py |  46 +++--
 .../datasets/electron_microscopy/uro_cell.py  |  52 ++++--
 .../data/datasets/electron_microscopy/vnc.py  |  36 +++-
 .../datasets/light_microscopy/cellpose.py     |  53 +++---
 .../datasets/light_microscopy/cellseg_3d.py   |  48 ++++--
 .../datasets/light_microscopy/covid_if.py     |  64 ++++---
 .../data/datasets/light_microscopy/ctc.py     |  77 ++++++---
 .../datasets/light_microscopy/deepbacs.py     |  39 ++++-
 .../datasets/light_microscopy/dic_hepg2.py    |  34 ++--
 .../data/datasets/light_microscopy/dsb.py     |  42 ++++-
 .../light_microscopy/dynamicnuclearnet.py     |  40 +++--
 .../light_microscopy/embedseg_data.py         |  51 ++++--
 .../datasets/light_microscopy/gonuclear.py    |  58 +++++--
 .../data/datasets/light_microscopy/hpa.py     |  47 ++++--
 .../datasets/light_microscopy/livecell.py     |  66 ++++----
 .../datasets/light_microscopy/mouse_embryo.py |  61 +++++--
 .../light_microscopy/neurips_cell_seg.py      |  24 ++-
 .../datasets/light_microscopy/omnipose.py     |  37 +++--
 .../datasets/light_microscopy/organoidnet.py  |  88 +++++++---
 .../datasets/light_microscopy/orgasegment.py  |  29 +++-
 .../datasets/light_microscopy/plantseg.py     |  52 ++++--
 .../datasets/light_microscopy/tissuenet.py    |  46 +++--
 .../datasets/light_microscopy/vgg_hela.py     |  49 ++++--
 35 files changed, 1347 insertions(+), 644 deletions(-)

diff --git a/torch_em/data/datasets/electron_microscopy/asem.py b/torch_em/data/datasets/electron_microscopy/asem.py
index 50ad29f7..99d6dfe8 100644
--- a/torch_em/data/datasets/electron_microscopy/asem.py
+++ b/torch_em/data/datasets/electron_microscopy/asem.py
@@ -1,4 +1,5 @@
 """ASEM is a dataset for segmentation of cellular structures in FIB-SEM.
+
 The dataset was publised in https://doi.org/10.1083/jcb.202208005.
 Please cite this publication if you use the dataset in your research.
 """
@@ -8,6 +9,8 @@
 
 import numpy as np
 
+from torch.utils.data import Dataset, DataLoader
+
 import torch_em
 
 from .. import util
@@ -54,11 +57,7 @@
 }
 
 
-def get_asem_data(
-    path: Union[os.PathLike, str],
-    volume_ids: List[str],
-    download: bool = False
-):
+def get_asem_data(path: Union[os.PathLike, str], volume_ids: List[str], download: bool = False):
     """Download the ASEM dataset.
 
     The dataset is located at https://open.quiltdata.com/b/asem-project.
@@ -67,19 +66,14 @@ def get_asem_data(
         path: Filepath to a folder where the downloaded data will be saved.
         volume_ids: List of volumes to download.
         download: Whether to download the data if it is not present.
-
-    Returns:
-        List of paths for all volume ids.
     """
     if download and not have_quilt:
         raise ModuleNotFoundError("Please install quilt3: 'pip install quilt3'.")
 
     b = q3.Bucket("s3://asem-project")
 
-    volume_paths = []
     for volume_id in volume_ids:
         volume_path = os.path.join(path, VOLUMES[volume_id])
-        volume_paths.append(volume_path)
         if os.path.exists(volume_path):
             continue
 
@@ -100,6 +94,20 @@ def get_asem_data(
     b.fetch(key=f"datasets/{VOLUMES[volume_id]}/.zgroup", path=f"{volume_path}/")
     b.fetch(key=f"datasets/{VOLUMES[volume_id]}/volumes/.zgroup", path=f"{volume_path}/volumes/")
 
+
+def get_asem_paths(path: Union[os.PathLike, str], volume_ids: List[str], download: bool = False) -> List[str]:
+    """Get paths to the ASEM data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        volume_ids: List of volumes to download.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of paths for all volume ids.
+    """
+    get_asem_data(path, volume_ids, download)
+    volume_paths = [os.path.join(path, VOLUMES[vol_id]) for vol_id in volume_ids]
     return volume_paths
 
 
@@ -170,14 +178,14 @@ def get_asem_dataset(
     organelles: Optional[Union[List[str], str]] = None,
     volume_ids: Optional[Union[List[str], str]] = None,
     **kwargs
-):
+) -> Dataset:
     """Get dataset for segmentation of organelles in FIB-SEM cells.
 
     Args:
         path: Filepath to a folder where the downloaded data will be saved.
         patch_shape: The patch shape to use for training.
         download: Whether to download the data if it is not present.
-        orgnalles: The choice of organelles.
+        organelles: The choice of organelles.
         volume_ids: The choice of volumes.
         kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
 
@@ -200,16 +208,17 @@ def get_asem_dataset(
             assert volume_id in ORGANELLES[organelle], \
                 f"The chosen volume and organelle combination does not match: '{volume_id}' & '{organelle}'"
 
-    volume_paths = get_asem_data(path, volume_ids, download)
+    volume_paths = get_asem_paths(path, volume_ids, download)
 
     for volume_path in volume_paths:
         have_volumes_inconsistent = _make_volumes_consistent(volume_path, organelle)
-        raw_key = f"volumes/raw_{organelle}" if have_volumes_inconsistent else "volumes/raw"
 
         dataset = torch_em.default_segmentation_dataset(
-            volume_path, raw_key,
-            volume_path, f"volumes/labels/{organelle}",
-            patch_shape,
+            raw_paths=volume_path,
+            raw_key=f"volumes/raw_{organelle}" if have_volumes_inconsistent else "volumes/raw",
+            label_paths=volume_path,
+            label_key=f"volumes/labels/{organelle}",
+            patch_shape=patch_shape,
             is_seg_dataset=True,
             **kwargs
         )
@@ -227,7 +236,7 @@ def get_asem_loader(
     organelles: Optional[Union[List[str], str]] = None,
     volume_ids: Optional[Union[List[str], str]] = None,
     **kwargs
-):
+) -> DataLoader:
    """Get dataloader for the segmentation of organelles in FIB-SEM cells.
 
     Args:
@@ -235,7 +244,7 @@ def get_asem_loader(
         path: Filepath to a folder where the downloaded data will be saved.
         patch_shape: The patch shape to use for training.
         batch_size: The batch size for training.
         download: Whether to download the data if it is not present.
-        orgnalles: The choice of organelles.
+        organelles: The choice of organelles.
         volume_ids: The choice of volumes.
         kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`
             or for the PyTorch DataLoader.
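The new `get_asem_paths` helper above separates path resolution from dataset construction, which is the pattern this patch applies to every dataset. A minimal usage sketch (the target folder and the "cell_6" volume id are illustrative; the valid ids are the keys of the `VOLUMES` dict in asem.py):

    from torch_em.data.datasets.electron_microscopy.asem import get_asem_paths

    # Download the requested volumes if they are missing and return the
    # filepaths to the zarr containers on disk.
    volume_paths = get_asem_paths("./data/asem", volume_ids=["cell_6"], download=True)
    print(volume_paths)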
diff --git a/torch_em/data/datasets/electron_microscopy/axondeepseg.py b/torch_em/data/datasets/electron_microscopy/axondeepseg.py
index 19eb7629..4fb386c0 100644
--- a/torch_em/data/datasets/electron_microscopy/axondeepseg.py
+++ b/torch_em/data/datasets/electron_microscopy/axondeepseg.py
@@ -1,5 +1,6 @@
 """AxonDeepSeg is a dataset for the segmentation of myelinated axons in EM.
 It contains two different data types: TEM and SEM.
+
 The dataset was published in https://doi.org/10.1038/s41598-018-22181-4.
 Please cite this publication if you use the dataset in your research.
 """
@@ -7,7 +8,7 @@
 import os
 from glob import glob
 from shutil import rmtree
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, Literal, List
 
 import imageio
 import numpy as np
@@ -116,8 +117,8 @@ def _preprocess_tem_data(out_path):
     rmtree(data_root)
 
 
-def get_axondeepseg_data(path: Union[str, os.PathLike], name: str, download: bool) -> str:
-    """Download the axondeepseg data.
+def get_axondeepseg_data(path: Union[str, os.PathLike], name: Literal["sem", "tem"], download: bool = False) -> str:
+    """Download the AxonDeepSeg data.
 
     Args:
         path: Filepath to a folder where the downloaded data will be saved.
@@ -149,14 +150,47 @@ def get_axondeepseg_data(path: Union[str, os.PathLike], name: str, download: boo
     return out_path
 
 
+def get_axondeepseg_paths(
+    path: Union[str, os.PathLike],
+    name: Union[str, List[str]],
+    download: bool = False,
+    val_fraction: Optional[float] = None,
+    split: Optional[str] = None,
+) -> List[str]:
+    """Get paths to the AxonDeepSeg data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        name: The name of the dataset to download. Can be 'sem' or 'tem', or a list of these names.
+        download: Whether to download the data if it is not present.
+        val_fraction: The fraction of the data to use for validation.
+        split: The data split. Either 'train' or 'val'.
+
+    Returns:
+        List of paths for all the data.
+    """
+    if isinstance(name, str):
+        name = [name]
+
+    all_paths = []
+    for nn in name:
+        data_root = get_axondeepseg_data(path, nn, download)
+        paths = glob(os.path.join(data_root, "*.h5"))
+        paths.sort()
+        if val_fraction is not None:
+            assert split is not None
+            n_samples = int(len(paths) * (1 - val_fraction))
+            paths = paths[:n_samples] if split == "train" else paths[n_samples:]
+        all_paths.extend(paths)
+
+    return all_paths
+
+
 def get_axondeepseg_dataset(
     path: Union[str, os.PathLike],
-    name: str,
+    name: Literal["sem", "tem"],
     patch_shape: Tuple[int, int],
     download: bool = False,
     one_hot_encoding: bool = False,
     val_fraction: Optional[float] = None,
-    split: Optional[str] = None,
+    split: Optional[Literal['train', 'val']] = None,
     **kwargs,
 ) -> Dataset:
     """Get dataset for segmentation of myelinated axons.
@@ -178,16 +212,7 @@ def get_axondeepseg_dataset(
         name = [name]
     assert isinstance(name, (tuple, list))
 
-    all_paths = []
-    for nn in name:
-        data_root = get_axondeepseg_data(path, nn, download)
-        paths = glob(os.path.join(data_root, "*.h5"))
-        paths.sort()
-        if val_fraction is not None:
-            assert split is not None
-            n_samples = int(len(paths) * (1 - val_fraction))
-            paths = paths[:n_samples] if split == "train" else paths[n_samples:]
-        all_paths.extend(paths)
+    all_paths = get_axondeepseg_paths(path, name, download, val_fraction, split)
 
     if one_hot_encoding:
         if isinstance(one_hot_encoding, bool):
@@ -205,19 +230,25 @@ def get_axondeepseg_dataset(
             msg = "'one_hot' is set to True, but 'label_transform' is in the kwargs. It will be over-ridden."
             kwargs = util.update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
 
-    raw_key, label_key = "raw", "labels"
-    return torch_em.default_segmentation_dataset(all_paths, raw_key, all_paths, label_key, patch_shape, **kwargs)
+    return torch_em.default_segmentation_dataset(
+        raw_paths=all_paths,
+        raw_key="raw",
+        label_paths=all_paths,
+        label_key="labels",
+        patch_shape=patch_shape,
+        **kwargs
+    )
 
 
 def get_axondeepseg_loader(
     path: Union[str, os.PathLike],
-    name: str,
+    name: Literal["sem", "tem"],
     patch_shape: Tuple[int, int],
     batch_size: int,
     download: bool = False,
     one_hot_encoding: bool = False,
     val_fraction: Optional[float] = None,
-    split: Optional[str] = None,
+    split: Optional[Literal["train", "val"]] = None,
     **kwargs
 ) -> DataLoader:
     """Get dataloader for the segmentation of myelinated axons.
diff --git a/torch_em/data/datasets/electron_microscopy/cem.py b/torch_em/data/datasets/electron_microscopy/cem.py
index b79fa361..d8342539 100644
--- a/torch_em/data/datasets/electron_microscopy/cem.py
+++ b/torch_em/data/datasets/electron_microscopy/cem.py
@@ -28,7 +28,7 @@
 import os
 import json
 from glob import glob
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Literal
 
 import numpy as np
 import imageio.v3 as imageio
@@ -40,6 +40,7 @@
 
 from .. import util
 
+
 BENCHMARK_DATASETS = {
     1: "mito_benchmarks/c_elegans",
     2: "mito_benchmarks/fly_brain",
@@ -60,20 +61,6 @@
 }
 
 
-def _get_mitolab_data(path, download):
-    access_id = "11037"
-    data_path = util.download_source_empiar(path, access_id, download)
-
-    zip_path = os.path.join(data_path, "data/cem_mitolab.zip")
-    if os.path.exists(zip_path):
-        util.unzip(zip_path, data_path, remove=True)
-
-    data_root = os.path.join(data_path, "cem_mitolab")
-    assert os.path.exists(data_root)
-
-    return data_root
-
-
 def _get_all_images(path):
     raw_paths, label_paths = [], []
     folders = glob(os.path.join(path, "*"))
@@ -124,14 +111,37 @@ def _get_non_empty_images(path):
     return raw_paths, label_paths
 
 
-def get_mitolab_data(
+def get_mitolab_data(path: Union[os.PathLike, str], download: bool = False) -> str:
+    """Download the MitoLab training data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath for the downloaded data.
+    """
+    access_id = "11037"
+    data_path = util.download_source_empiar(path, access_id, download)
+
+    zip_path = os.path.join(data_path, "data/cem_mitolab.zip")
+    if os.path.exists(zip_path):
+        util.unzip(zip_path, data_path, remove=True)
+
+    data_root = os.path.join(data_path, "cem_mitolab")
+    assert os.path.exists(data_root)
+
+    return data_root
+
+
+def get_mitolab_paths(
     path: Union[os.PathLike, str],
-    split: str,
-    val_fraction: float,
-    download: bool,
-    discard_empty_images: bool
+    split: Literal['train', 'val'],
+    val_fraction: float = 0.05,
+    download: bool = False,
+    discard_empty_images: bool = True,
 ) -> Tuple[List[str], List[str]]:
-    """Download the mitolab training data.
+    """Get the paths to MitoLab training data.
 
     Args:
         path: Filepath to a folder where the downloaded data will be saved.
@@ -144,7 +154,8 @@ def get_mitolab_paths(
         List of the image data paths.
         List of the label data paths.
""" - data_path = _get_mitolab_data(path, download) + data_path = get_mitolab_data(path, download) + if discard_empty_images: raw_paths, label_paths = _get_non_empty_images(data_path) else: @@ -164,14 +175,27 @@ def get_mitolab_data( return raw_paths, label_paths -def get_benchmark_data( - path: Union[os.PathLike, str], - dataset_id: int, - download: bool -) -> Tuple[ - List[str], List[str], str, str, bool -]: - """Download the mitolab benchmark data. +def get_benchmark_data(path: Union[os.PathLike, str], dataset_id: int, download: bool = False) -> str: + """Download the MitoLab benchmark data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + dataset_id: The id of the benchmark dataset to download. + download: Whether to download the data if it is not present. + + Returns: + The filepath for the stored data. + """ + access_id = "10982" + data_path = util.download_source_empiar(path, access_id, download) + dataset_path = os.path.join(data_path, "data", BENCHMARK_DATASETS[dataset_id]) + return dataset_path + + +def get_benchmark_paths( + path: Union[os.PathLike, str], dataset_id: int, download: bool = False +) -> Tuple[List[str], List[str], str, str, bool]: + """Get paths to the MitoLab benchmark data. Args: path: Filepath to a folder where the downloaded data will be saved. @@ -185,9 +209,7 @@ def get_benchmark_data( The label data key. Whether this is a segmentation dataset. """ - access_id = "10982" - data_path = util.download_source_empiar(path, access_id, download) - dataset_path = os.path.join(data_path, "data", BENCHMARK_DATASETS[dataset_id]) + dataset_path = get_benchmark_data(path, dataset_id, download) # these are the 3d datasets if dataset_id in range(1, 7): @@ -214,14 +236,14 @@ def get_benchmark_data( def get_mitolab_dataset( path: Union[os.PathLike, str], - split: str, + split: Literal['train', 'val'], patch_shape: Tuple[int, int] = (224, 224), val_fraction: float = 0.05, download: bool = False, discard_empty_images: bool = True, **kwargs ) -> Dataset: - """Get the dataset for the mitolab training data. + """Get the dataset for the MitoLab training data. Args: path: Filepath to a folder where the downloaded data will be saved. @@ -237,11 +259,18 @@ def get_mitolab_dataset( """ assert split in ("train", "val", None) assert os.path.exists(path) - raw_paths, label_paths = get_mitolab_data(path, split, val_fraction, download, discard_empty_images) + + raw_paths, label_paths = get_mitolab_paths(path, split, val_fraction, download, discard_empty_images) + return torch_em.default_segmentation_dataset( - raw_paths=raw_paths, raw_key=None, - label_paths=label_paths, label_key=None, - patch_shape=patch_shape, is_seg_dataset=False, ndim=2, **kwargs + raw_paths=raw_paths, + raw_key=None, + label_paths=label_paths, + label_key=None, + patch_shape=patch_shape, + is_seg_dataset=False, + ndim=2, + **kwargs ) @@ -250,11 +279,7 @@ def get_cem15m_dataset(path): def get_benchmark_dataset( - path, - dataset_id, - patch_shape, - download=False, - **kwargs, + path: Union[os.PathLike, str], dataset_id: int, patch_shape: Tuple[int, int], download: bool = False, **kwargs ) -> Dataset: """Get the dataset for one of the mitolab benchmark datasets. 
@@ -270,12 +295,17 @@ def get_benchmark_dataset(
     """
     if dataset_id not in range(1, 8):
         raise ValueError(f"Invalid dataset id {dataset_id}, expected id in range [1, 7].")
-    raw_paths, label_paths, raw_key, label_key, is_seg_dataset = get_benchmark_data(path, dataset_id, download)
+
+    raw_paths, label_paths, raw_key, label_key, is_seg_dataset = get_benchmark_paths(path, dataset_id, download)
+
     return torch_em.default_segmentation_dataset(
-        raw_paths=raw_paths, raw_key=raw_key,
-        label_paths=label_paths, label_key=label_key,
+        raw_paths=raw_paths,
+        raw_key=raw_key,
+        label_paths=label_paths,
+        label_key=label_key,
         patch_shape=patch_shape,
-        is_seg_dataset=is_seg_dataset, **kwargs,
+        is_seg_dataset=is_seg_dataset,
+        **kwargs,
     )
 
 
@@ -294,7 +324,7 @@ def get_mitolab_loader(
     download: bool = False,
     **kwargs
 ) -> DataLoader:
-    """Get the dataloader for the mitolab training data.
+    """Get the dataloader for the MitoLab training data.
 
     Args:
         path: Filepath to a folder where the downloaded data will be saved.
@@ -309,14 +339,17 @@ def get_mitolab_loader(
     Returns:
         The PyTorch DataLoader.
     """
-    ds_kwargs, loader_kwargs = util.split_kwargs(
-        torch_em.default_segmentation_dataset, **kwargs
-    )
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
     dataset = get_mitolab_dataset(
-        path, split, patch_shape, download=download, discard_empty_images=discard_empty_images, **ds_kwargs
+        path=path,
+        split=split,
+        patch_shape=patch_shape,
+        val_fraction=val_fraction,
+        download=download,
+        discard_empty_images=discard_empty_images,
+        **ds_kwargs
     )
-    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
-    return loader
+    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
 
 
 def get_cem15m_loader(path):
@@ -331,7 +364,7 @@ def get_benchmark_loader(
     download: bool = False,
     **kwargs
 ) -> DataLoader:
-    """Get the dataloader for one of the mitolab benchmark datasets.
+    """Get the dataloader for one of the MitoLab benchmark datasets.
 
     Args:
         path: Filepath to a folder where the downloaded data will be saved.
@@ -344,12 +377,6 @@ def get_benchmark_loader(
     Returns:
         The DataLoader.
     """
-    ds_kwargs, loader_kwargs = util.split_kwargs(
-        torch_em.default_segmentation_dataset, **kwargs
-    )
-    dataset = get_benchmark_dataset(
-        path, dataset_id,
-        patch_shape=patch_shape, download=download, **ds_kwargs
-    )
-    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
-    return loader
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_benchmark_dataset(path, dataset_id, patch_shape=patch_shape, download=download, **ds_kwargs)
+    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/electron_microscopy/cremi.py b/torch_em/data/datasets/electron_microscopy/cremi.py
index 5d9942e0..c72f94f9 100644
--- a/torch_em/data/datasets/electron_microscopy/cremi.py
+++ b/torch_em/data/datasets/electron_microscopy/cremi.py
@@ -2,6 +2,7 @@
 It contains three annotated volumes from the adult fruit-fly brain.
 It was held as a challenge at MICCAI 2016. For details on the dataset check out https://cremi.org/.
+Please cite the challenge if you use the dataset in your research.
 """
 # TODO add support for realigned volumes
 
@@ -16,6 +17,7 @@
 from .. import util
 
+
 CREMI_URLS = {
     "original": {
         "A": "https://cremi.org/static/data/sample_A_20160501.hdf",
@@ -36,12 +38,7 @@
 }
 
 
-def get_cremi_data(
-    path: Union[os.PathLike, str],
-    samples: Tuple[str],
-    download: bool,
-    use_realigned: bool = False,
-) -> List[str]:
+def get_cremi_data(path: Union[os.PathLike, str], samples: Tuple[str], download: bool, use_realigned: bool = False):
     """Download the CREMI training data.
 
     Args:
@@ -49,9 +46,6 @@ def get_cremi_data(
         samples: The CREMI samples to use. The available samples are 'A', 'B', 'C'.
         download: Whether to download the data if it is not present.
         use_realigned: Use the realigned instead of the original training data.
-
-    Returns:
-        The filepaths to the training data.
     """
     if use_realigned:
         # we need to sample batches in this case
@@ -62,14 +56,33 @@ def get_cremi_data(
         checksums = CHECKSUMS["original"]
 
     os.makedirs(path, exist_ok=True)
-    data_paths = []
     for name in samples:
         url = urls[name]
         checksum = checksums[name]
         data_path = os.path.join(path, f"sample{name}.h5")
         # CREMI SSL certificates expired, so we need to disable verification
         util.download_source(data_path, url, download, checksum, verify=False)
-        data_paths.append(data_path)
+
+
+def get_cremi_paths(
+    path: Union[os.PathLike, str],
+    samples: Tuple[str, ...] = ("A", "B", "C"),
+    use_realigned: bool = False,
+    download: bool = False
+) -> List[str]:
+    """Get paths to the CREMI data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        samples: The CREMI samples to use. The available samples are 'A', 'B', 'C'.
+        use_realigned: Use the realigned instead of the original training data.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepaths to the training data.
+    """
+    get_cremi_data(path, samples, download, use_realigned)
+    data_paths = [os.path.join(path, f"sample{name}.h5") for name in samples]
     return data_paths
 
 
@@ -111,7 +124,7 @@ def get_cremi_dataset(
     if rois is not None:
         assert isinstance(rois, dict)
 
-    data_paths = get_cremi_data(path, samples, download, use_realigned)
+    data_paths = get_cremi_paths(path, samples, use_realigned, download)
     data_rois = [rois.get(name, np.s_[:, :, :]) for name in samples]
 
     if defect_augmentation_kwargs is not None and "artifact_source" not in defect_augmentation_kwargs:
@@ -121,15 +134,14 @@ def get_cremi_dataset(
         defect_path = os.path.join(path, "cremi_defects.h5")
         util.download_source(defect_path, url, download, checksum)
         defect_patch_shape = (1,) + tuple(patch_shape[1:])
-        artifact_source = torch_em.transform.get_artifact_source(defect_path, defect_patch_shape,
-                                                                 min_mask_fraction=0.75,
-                                                                 raw_key="defect_sections/raw",
-                                                                 mask_key="defect_sections/mask")
+        artifact_source = torch_em.transform.get_artifact_source(
+            defect_path, defect_patch_shape,
+            min_mask_fraction=0.75,
+            raw_key="defect_sections/raw",
+            mask_key="defect_sections/mask"
+        )
         defect_augmentation_kwargs.update({"artifact_source": artifact_source})
 
-    raw_key = "volumes/raw"
-    label_key = "volumes/labels/neuron_ids"
-
     # defect augmentations
     if defect_augmentation_kwargs is not None:
         raw_transform = torch_em.transform.get_raw_transform(
@@ -142,7 +154,13 @@ def get_cremi_dataset(
         )
 
     return torch_em.default_segmentation_dataset(
-        data_paths, raw_key, data_paths, label_key, patch_shape, rois=data_rois, **kwargs
+        raw_paths=data_paths,
+        raw_key="volumes/raw",
+        label_paths=data_paths,
+        label_key="volumes/labels/neuron_ids",
+        patch_shape=patch_shape,
+        rois=data_rois,
+        **kwargs
    )
 
 
@@ -182,9 +200,7 @@ def get_cremi_loader(
     Returns:
         The DataLoader.
     """
-    dataset_kwargs, loader_kwargs = util.split_kwargs(
-        torch_em.default_segmentation_dataset, **kwargs
-    )
+    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
     ds = get_cremi_dataset(
         path=path,
         patch_shape=patch_shape,
diff --git a/torch_em/data/datasets/electron_microscopy/deepict.py b/torch_em/data/datasets/electron_microscopy/deepict.py
index 5b7fe91a..b9b9f023 100644
--- a/torch_em/data/datasets/electron_microscopy/deepict.py
+++ b/torch_em/data/datasets/electron_microscopy/deepict.py
@@ -1,7 +1,7 @@
 """Dataset for segmentation of structures in Cryo ET.
-
 The DeePict dataset contains annotations for several structures in CryoET.
 The dataset implemented here currently only provides access to the actin annotations.
+
 The dataset is part of the publication https://doi.org/10.1038/s41592-022-01746-2.
 Plase cite it if you use this dataset in your research.
 """
@@ -9,7 +9,9 @@
 import os
 from glob import glob
 from shutil import rmtree
-from typing import Tuple, Union
+from typing import Tuple, Union, List
+
+from torch.utils.data import Dataset, DataLoader
 
 try:
     import mrcfile
@@ -73,7 +75,7 @@ def _process_deepict_actin(input_path, output_path):
 
 
 def get_deepict_actin_data(path: Union[os.PathLike, str], download: bool) -> str:
-    """Download the deepict actin dataset.
+    """Download the DeePict actin dataset.
 
     Args:
         path: Filepath to a folder where the downloaded data will be saved.
@@ -99,13 +101,28 @@
     return dataset_path
 
 
+def get_deepict_actin_paths(path: Union[os.PathLike, str], download: bool = False) -> List[str]:
+    """Get paths to DeePict actin data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepaths to the stored data.
+    """
+    get_deepict_actin_data(path, download)
+    data_paths = sorted(glob(os.path.join(path, "deepict_actin", "*.h5")))
+    return data_paths
+
+
 def get_deepict_actin_dataset(
     path: Union[os.PathLike, str],
     patch_shape: Tuple[int, int, int],
     label_key: str = "labels/actin",
     download: bool = False,
     **kwargs
-):
+) -> Dataset:
     """Get the dataset for actin segmentation in Cryo ET data.
 
     Args:
@@ -120,11 +137,17 @@ def get_deepict_actin_dataset(
         The segmentation dataset.
     """
     assert len(patch_shape) == 3
-    data_path = get_deepict_actin_data(path, download)
-    data_paths = sorted(glob(os.path.join(data_path, "*.h5")))
-    raw_key = "raw"
+
+    data_paths = get_deepict_actin_paths(path, download)
+
     return torch_em.default_segmentation_dataset(
-        data_paths, raw_key, data_paths, label_key, patch_shape, is_seg_dataset=True, **kwargs
+        raw_paths=data_paths,
+        raw_key="raw",
+        label_paths=data_paths,
+        label_key=label_key,
+        patch_shape=patch_shape,
+        is_seg_dataset=True,
+        **kwargs
     )
 
 
@@ -135,7 +158,7 @@ def get_deepict_actin_loader(
     label_key: str = "labels/actin",
     download: bool = False,
     **kwargs
-):
+) -> DataLoader:
     """Get the DataLoader for actin segmentation in CryoET data.
 
     Args:
@@ -150,11 +173,6 @@ def get_deepict_actin_loader(
     Returns:
         The DataLoader.
""" - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - dataset = get_deepict_actin_dataset( - path, patch_shape, label_key=label_key, download=download, **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_deepict_actin_dataset(path, patch_shape, label_key=label_key, download=download, **ds_kwargs) + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/electron_microscopy/isbi2012.py b/torch_em/data/datasets/electron_microscopy/isbi2012.py index 5f4f0d33..b117d76f 100644 --- a/torch_em/data/datasets/electron_microscopy/isbi2012.py +++ b/torch_em/data/datasets/electron_microscopy/isbi2012.py @@ -1,35 +1,46 @@ """The ISBI2012 dataset was the first neuron segmentation challenge, held at the ISBI 2012 competition. +It contains a small annotated EM volume from the fruit-fly brain. -It contains a small annotated EM volume from the fruit-fly brain. If you use this dataset in -your research please cite the following publication: https://doi.org/10.3389/fnana.2015.00142. +If you use this dataset in your research please cite the following publication: +https://doi.org/10.3389/fnana.2015.00142. """ import os from typing import List, Optional, Tuple, Union +from torch.utils.data import Dataset, DataLoader + import torch_em from .. import util + ISBI_URL = "https://oc.embl.de/index.php/s/h0TkwqxU0PJDdMd/download" CHECKSUM = "0e10fe909a1243084d91773470856993b7d40126a12e85f0f1345a7a9e512f29" -def get_isbi_data(path: Union[os.PathLike, str], download: bool) -> str: +def get_isbi_data(path: Union[os.PathLike, str], download: bool = False): """Download the ISBI2012 dataset. + Args: + path: Filepath to a folder where the downloaded data will be saved. + download: Whether to download the data if it is not present. + """ + os.makedirs(path, exist_ok=True) + util.download_source(os.path.join(path, "isbi.h5"), ISBI_URL, download, CHECKSUM) + + +def get_isbi_paths(path: Union[os.PathLike, str], download: bool = False) -> str: + """Get path to ISBI data. + Args: path: Filepath to a folder where the downloaded data will be saved. download: Whether to download the data if it is not present. Returns: - The path to the downloaded data. + The filepath for the stored data. """ - if path.endswith(".h5"): - volume_path = path - else: - os.makedirs(path, exist_ok=True) - volume_path = os.path.join(path, "isbi.h5") - util.download_source(volume_path, ISBI_URL, download, CHECKSUM) + get_isbi_data(path, download) + volume_path = os.path.join(path, "isbi.h5") return volume_path @@ -41,7 +52,7 @@ def get_isbi_dataset( boundaries: bool = False, use_original_labels: bool = False, **kwargs -): +) -> Dataset: """Get the dataset for EM neuron segmentation in ISBI 2012. Args: @@ -57,7 +68,8 @@ def get_isbi_dataset( The segmentation dataset. 
""" assert len(patch_shape) == 3 - volume_path = get_isbi_data(path, download) + + volume_path = get_isbi_paths(path, download) ndim = 2 if patch_shape[0] == 1 else 3 kwargs = util.update_kwargs(kwargs, "ndim", ndim) @@ -66,10 +78,14 @@ def get_isbi_dataset( kwargs, add_binary_target=False, boundaries=boundaries, offsets=offsets ) - raw_key = "raw" - label_key = "labels/membranes" if use_original_labels else "labels/gt_segmentation" - - return torch_em.default_segmentation_dataset(volume_path, raw_key, volume_path, label_key, patch_shape, **kwargs) + return torch_em.default_segmentation_dataset( + raw_paths=volume_path, + raw_key="raw", + label_paths=volume_path, + label_key="labels/membranes" if use_original_labels else "labels/gt_segmentation", + patch_shape=patch_shape, + **kwargs + ) def get_isbi_loader( @@ -81,7 +97,7 @@ def get_isbi_loader( boundaries: bool = False, use_original_labels: bool = False, **kwargs -): +) -> DataLoader: """Get the DataLoader for EM neuron segmentation in ISBI 2012. Args: @@ -97,13 +113,9 @@ def get_isbi_loader( Returns: The DataLoader. """ - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_isbi_dataset( - path, patch_shape, download=download, - offsets=offsets, boundaries=boundaries, use_original_labels=use_original_labels, - **ds_kwargs + path, patch_shape, download=download, offsets=offsets, + boundaries=boundaries, use_original_labels=use_original_labels, **ds_kwargs ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/electron_microscopy/kasthuri.py b/torch_em/data/datasets/electron_microscopy/kasthuri.py index ba8fb246..2e391e8f 100644 --- a/torch_em/data/datasets/electron_microscopy/kasthuri.py +++ b/torch_em/data/datasets/electron_microscopy/kasthuri.py @@ -21,6 +21,7 @@ from .. import util + URL = "http://www.casser.io/files/kasthuri_pp.zip " CHECKSUM = "bbb78fd205ec9b57feb8f93ebbdf1666261cbc3e0305e7f11583ab5157a3d792" @@ -69,7 +70,7 @@ def _create_data(root, inputs, out_path): f.create_dataset("labels", data=labels, compression="gzip") -def get_kasthuri_data(path: Union[os.PathLike, str], download: bool) -> str: +def get_kasthuri_data(path: Union[os.PathLike, str], download: bool = False) -> str: """Download the kasthuri dataset. Args: @@ -100,12 +101,25 @@ def get_kasthuri_data(path: Union[os.PathLike, str], download: bool) -> str: return path +def get_kasthuri_paths(path: Union[os.PathLike, str], split: str, download: bool = False) -> str: + """Get paths to the Kasthuri data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The data split. Either 'train' or 'test'. + download: Whether to download the data if it is not present. + + Returns: + The filepath to the stored data. + """ + get_kasthuri_data(path, download) + data_path = os.path.join(path, f"kasthuri_{split}.h5") + assert os.path.exists(data_path), data_path + return data_path + + def get_kasthuri_dataset( - path: Union[os.PathLike, str], - split: str, - patch_shape: Tuple[int, int, int], - download: bool = False, - **kwargs + path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int, int], download: bool = False, **kwargs ) -> Dataset: """Get dataset for EM mitochondrion segmentation in the kasthuri dataset. 
@@ -120,11 +134,17 @@ def get_kasthuri_dataset(
         The segmentation dataset.
     """
     assert split in ("train", "test")
-    get_kasthuri_data(path, download)
-    data_path = os.path.join(path, f"kasthuri_{split}.h5")
-    assert os.path.exists(data_path), data_path
-    raw_key, label_key = "raw", "labels"
-    return torch_em.default_segmentation_dataset(data_path, raw_key, data_path, label_key, patch_shape, **kwargs)
+
+    data_path = get_kasthuri_paths(path, split, download)
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=data_path,
+        raw_key="raw",
+        label_paths=data_path,
+        label_key="labels",
+        patch_shape=patch_shape,
+        **kwargs
+    )
 
 
 def get_kasthuri_loader(
@@ -148,9 +168,6 @@ def get_kasthuri_loader(
     Returns:
         The PyTorch DataLoader.
     """
-    ds_kwargs, loader_kwargs = util.split_kwargs(
-        torch_em.default_segmentation_dataset, **kwargs
-    )
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
     dataset = get_kasthuri_dataset(path, split, patch_shape, download=download, **ds_kwargs)
-    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
-    return loader
+    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/electron_microscopy/lucchi.py b/torch_em/data/datasets/electron_microscopy/lucchi.py
index 3b63ba51..e4d5dec8 100644
--- a/torch_em/data/datasets/electron_microscopy/lucchi.py
+++ b/torch_em/data/datasets/electron_microscopy/lucchi.py
@@ -10,7 +10,7 @@
 from tqdm import tqdm
 from shutil import rmtree
 from concurrent import futures
-from typing import Tuple, Union
+from typing import Tuple, Union, Literal
 
 import imageio
 import numpy as np
@@ -21,6 +21,7 @@
 
 from .. import util
 
+
 URL = "http://www.casser.io/files/lucchi_pp.zip"
 CHECKSUM = "770ce9e98fc6f29c1b1a250c637e6c5125f2b5f1260e5a7687b55a79e2e8844d"
 
@@ -64,8 +65,8 @@ def _create_data(root, inputs, out_path):
     f.create_dataset("labels", data=labels.astype("uint8"), compression="gzip")
 
 
-def get_lucchi_data(path: Union[os.PathLike, str], split: str, download: bool) -> str:
-    """Download the lucchi dataset.
+def get_lucchi_data(path: Union[os.PathLike, str], split: Literal["train", "test"], download: bool = False) -> str:
+    """Download the Lucchi dataset.
 
     Args:
         path: Filepath to a folder where the downloaded data will be saved.
@@ -98,14 +99,30 @@ def get_lucchi_data(
     return data_path
 
 
+def get_lucchi_paths(path: Union[os.PathLike, str], split: Literal["train", "test"], download: bool = False) -> str:
+    """Get paths to the Lucchi data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The data split. Either 'train' or 'test'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath for the stored data.
+    """
+    get_lucchi_data(path, split, download)
+    data_path = os.path.join(path, f"lucchi_{split}.h5")
+    return data_path
+
+
 def get_lucchi_dataset(
     path: Union[os.PathLike, str],
-    split: str,
+    split: Literal["train", "test"],
     patch_shape: Tuple[int, int, int],
     download: bool = False,
     **kwargs
 ) -> Dataset:
-    """Get dataset for EM mitochondrion segmentation in the lucchi dataset.
+    """Get dataset for EM mitochondrion segmentation in the Lucchi dataset.
 
     Args:
         path: Filepath to a folder where the downloaded data will be saved.
@@ -118,20 +135,28 @@ def get_lucchi_dataset(
         The segmentation dataset.
""" assert split in ("train", "test") - data_path = get_lucchi_data(path, split, download) - raw_key, label_key = "raw", "labels" - return torch_em.default_segmentation_dataset(data_path, raw_key, data_path, label_key, patch_shape, **kwargs) + + data_path = get_lucchi_paths(path, split, download) + + return torch_em.default_segmentation_dataset( + raw_paths=data_path, + raw_key="raw", + label_paths=data_path, + label_key="labels", + patch_shape=patch_shape, + **kwargs + ) def get_lucchi_loader( path: Union[os.PathLike, str], - split: str, + split: Literal["train", "test"], patch_shape: Tuple[int, int, int], batch_size: int, download: bool = False, **kwargs ) -> DataLoader: - """Get dataloader for EM mitochondrion segmentation in the lucchi dataset. + """Get dataloader for EM mitochondrion segmentation in the Lucchi dataset. Args: path: Filepath to a folder where the downloaded data will be saved. @@ -144,9 +169,6 @@ def get_lucchi_loader( Returns: The PyTorch DataLoader. """ - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_lucchi_dataset(path, split, patch_shape, download=download, **ds_kwargs) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/electron_microscopy/mitoem.py b/torch_em/data/datasets/electron_microscopy/mitoem.py index a95eff4a..04716efb 100644 --- a/torch_em/data/datasets/electron_microscopy/mitoem.py +++ b/torch_em/data/datasets/electron_microscopy/mitoem.py @@ -1,7 +1,7 @@ """MitoEM is a dataset for segmenting mitochondria in electron microscopy. - It contains two large annotated volumes, one from rat cortex, the other from human cortex. This dataset was used for a segmentation challenge at ISBI 2022. + If you use it in your research then please cite https://doi.org/10.1007/978-3-030-59722-1_7. """ @@ -21,6 +21,7 @@ from .. import util + URLS = { "raw": { "human": "https://www.dropbox.com/s/z41qtu4y735j95e/EM30-H-im.zip?dl=1", @@ -136,9 +137,7 @@ def _require_mitoem_sample(path, sample, download): rmtree(val_folder) -def get_mitoem_data( - path: Union[os.PathLike, str], samples: Sequence[str], splits: Sequence[str], download: bool -) -> List[str]: +def get_mitoem_data(path: Union[os.PathLike, str], samples: Sequence[str], splits: Sequence[str], download: bool): """Download the MitoEM training data. Args: @@ -146,9 +145,6 @@ def get_mitoem_data( samples: The samples to download. The available samples are 'human' and 'rat'. splits: The data splits to download. The available splits are 'train', 'val' and 'test'. download: Whether to download the data if it is not present. - - Returns: - The paths to the downloaded and converted files. 
""" if isinstance(splits, str): splits = [splits] @@ -156,7 +152,6 @@ def get_mitoem_data( assert len(set(samples) - {"human", "rat"}) == 0, f"{samples}" os.makedirs(path, exist_ok=True) - data_paths = [] for sample in samples: if not _check_data(path, sample): print("The MitoEM data for sample", sample, "is not available yet and will be downloaded and created.") @@ -167,7 +162,27 @@ def get_mitoem_data( for split in splits: split_path = os.path.join(path, f"{sample}_{split}.n5") assert os.path.exists(split_path), split_path - data_paths.append(split_path) + + +def get_mitoem_paths( + path: Union[os.PathLike, str], + splits: Sequence[str], + samples: Sequence[str] = ("human", "rat"), + download: bool = False, +) -> List[str]: + """Get paths for MitoEM data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + samples: The samples to download. The available samples are 'human' and 'rat'. + splits: The data splits to download. The available splits are 'train', 'val' and 'test'. + download: Whether to download the data if it is not present. + + Returns: + The filepaths for the stored data. + """ + get_mitoem_data(path, samples, splits, download) + data_paths = [os.path.join(path, f"{sample}_{split}.n5") for split in splits for sample in samples] return data_paths @@ -200,14 +215,20 @@ def get_mitoem_dataset( """ assert len(patch_shape) == 3 - data_paths = get_mitoem_data(path, samples, splits, download) + data_paths = get_mitoem_paths(path, samples, splits, download) kwargs, _ = util.add_instance_label_transform( kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets ) - raw_key = "raw" - label_key = "labels" - return torch_em.default_segmentation_dataset(data_paths, raw_key, data_paths, label_key, patch_shape, **kwargs) + + return torch_em.default_segmentation_dataset( + raw_paths=data_paths, + raw_key="raw", + label_paths=data_paths, + label_key="labels", + patch_shape=patch_shape, + **kwargs + ) def get_mitoem_loader( @@ -239,14 +260,9 @@ def get_mitoem_loader( Returns: The DataLoader. """ - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_mitoem_dataset( - path, splits, patch_shape, - samples=samples, download=download, - offsets=offsets, boundaries=boundaries, binary=binary, - **ds_kwargs + path, splits, patch_shape, samples=samples, download=download, + offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/electron_microscopy/nuc_mm.py b/torch_em/data/datasets/electron_microscopy/nuc_mm.py index b42927a6..a9736a3c 100644 --- a/torch_em/data/datasets/electron_microscopy/nuc_mm.py +++ b/torch_em/data/datasets/electron_microscopy/nuc_mm.py @@ -4,10 +4,9 @@ Please cite it if you use this dataset for a publication. """ - import os from glob import glob -from typing import Tuple, Union, Literal +from typing import Tuple, Union, Literal, List import torch_em @@ -15,6 +14,7 @@ from .. 
 
+
 URL = "https://drive.google.com/drive/folders/1_4CrlYvzx0ITnGlJOHdgcTRgeSkm9wT8"
 
 
@@ -37,11 +37,7 @@ def _extract_split(image_folder, label_folder, output_folder):
             f.create_dataset("labels", data=seg, compression="gzip")
 
 
-def get_nuc_mm_data(
-    path: Union[os.PathLike, str],
-    sample: Literal['mouse', 'zebrafish'],
-    download: bool
-) -> str:
+def get_nuc_mm_data(path: Union[os.PathLike, str], sample: Literal['mouse', 'zebrafish'], download: bool) -> str:
     """Download the NucMM training data.
 
     Args:
@@ -79,6 +75,26 @@ def get_nuc_mm_data(
     return sample_folder
 
 
+def get_nuc_mm_paths(
+    path: Union[os.PathLike, str], sample: Literal['mouse', 'zebrafish'], split: str, download: bool = False,
+) -> List[str]:
+    """Get paths to the NucMM data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        sample: The NucMM samples to use. The available samples are 'mouse' and 'zebrafish'.
+        split: The split for the dataset, either 'train' or 'val'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepaths to the stored data.
+    """
+    get_nuc_mm_data(path, sample, download)
+    split_folder = os.path.join(path, sample, split)
+    paths = sorted(glob(os.path.join(split_folder, "*.h5")))
+    return paths
+
+
 def get_nuc_mm_dataset(
     path: Union[os.PathLike, str],
     sample: Literal['mouse', 'zebrafish'],
@@ -102,13 +118,16 @@ def get_nuc_mm_dataset(
     """
     assert split in ("train", "val")
 
-    sample_folder = get_nuc_mm_data(path, sample, download)
-    split_folder = os.path.join(sample_folder, split)
-    paths = sorted(glob(os.path.join(split_folder, "*.h5")))
+    paths = get_nuc_mm_paths(path, sample, split, download)
 
-    raw_key, label_key = "raw", "labels"
     return torch_em.default_segmentation_dataset(
-        paths, raw_key, paths, label_key, patch_shape, is_seg_dataset=True, **kwargs
+        raw_paths=paths,
+        raw_key="raw",
+        label_paths=paths,
+        label_key="labels",
+        patch_shape=patch_shape,
+        is_seg_dataset=True,
+        **kwargs
    )
 
 
@@ -135,8 +154,6 @@ def get_nuc_mm_loader(
     Returns:
         The segmentation dataset.
     """
-    ds_kwargs, loader_kwargs = util.split_kwargs(
-        torch_em.default_segmentation_dataset, **kwargs
-    )
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
     ds = get_nuc_mm_dataset(path, sample, split, patch_shape, download, **ds_kwargs)
     return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/electron_microscopy/platynereis.py b/torch_em/data/datasets/electron_microscopy/platynereis.py
index 86432f31..5ff737fe 100644
--- a/torch_em/data/datasets/electron_microscopy/platynereis.py
+++ b/torch_em/data/datasets/electron_microscopy/platynereis.py
@@ -21,6 +21,7 @@
 
 from .. import util
 
+
 URLS = {
     "cells": "https://zenodo.org/record/3675220/files/membrane.zip",
     "nuclei": "https://zenodo.org/record/3675220/files/nuclei.zip",
@@ -49,17 +50,6 @@ def _check_data(path, prefix, extension, n_files):
     return len(files) == n_files
 
 
-def _get_paths_and_rois(sample_ids, n_files, template, rois):
-    if sample_ids is None:
-        sample_ids = list(range(1, n_files + 1))
-    else:
-        assert min(sample_ids) >= 1 and max(sample_ids) <= n_files
-    sample_ids.sort()
-    paths = [template % sample for sample in sample_ids]
-    data_rois = [rois.get(sample, np.s_[:, :, :]) for sample in sample_ids]
-    return paths, data_rois
-
-
 def get_platynereis_data(path: Union[os.PathLike, str], name: str, download: bool) -> Tuple[str, int]:
     """Download the platynereis dataset.
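The nuc_mm helper above composes the same way as the others. A sketch, assuming `./data/nuc_mm` as the target folder and a working Google Drive download:

    from torch_em.data.datasets.electron_microscopy.nuc_mm import get_nuc_mm_paths

    # One hdf5 file per annotated block of the chosen sample and split.
    paths = get_nuc_mm_paths("./data/nuc_mm", sample="mouse", split="train", download=True)
    print(len(paths), "volumes found")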
@@ -101,6 +91,37 @@ def get_platynereis_data(path: Union[os.PathLike, str], name: str, download: boo
     return data_root, n_files
 
 
+def get_platynereis_paths(path, sample_ids, name, file_template, rois={}, download=False, return_rois=False):
+    """Get paths to the platynereis data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        sample_ids: The sample ids to use for the dataset.
+        name: Name of the segmentation task. Available tasks: 'cuticle', 'cilia', 'cells' or 'nuclei'.
+        file_template: The file name template for the volumes of this task.
+        rois: The region of interests to use for the data blocks.
+        download: Whether to download the data if it is not present.
+        return_rois: Whether to return the extracted rois.
+
+    Returns:
+        The filepaths for the stored data.
+    """
+    root, n_files = get_platynereis_data(path, name, download)
+    template = os.path.join(root, file_template)
+
+    if sample_ids is None:
+        sample_ids = list(range(1, n_files + 1))
+    else:
+        assert min(sample_ids) >= 1 and max(sample_ids) <= n_files
+    sample_ids.sort()
+    paths = [template % sample for sample in sample_ids]
+    data_rois = [rois.get(sample, np.s_[:, :, :]) for sample in sample_ids]
+
+    if return_rois:
+        return paths, data_rois
+    else:
+        return paths
+
+
 def get_platynereis_cuticle_dataset(
     path: Union[os.PathLike, str],
     patch_shape: Tuple[int, int, int],
@@ -122,12 +143,23 @@ def get_platynereis_cuticle_dataset(
     Returns:
         The segmentation dataset.
     """
-    cuticle_root, n_files = get_platynereis_data(path, "cuticle", download)
-
-    paths, data_rois = _get_paths_and_rois(sample_ids, n_files, os.path.join(cuticle_root, "train_data_%02i.n5"), rois)
-    raw_key, label_key = "volumes/raw", "volumes/labels/segmentation"
+    paths, data_rois = get_platynereis_paths(
+        path=path,
+        sample_ids=sample_ids,
+        name="cuticle",
+        file_template="train_data_%02i.n5",
+        rois=rois,
+        download=download,
+        return_rois=True,
+    )
     return torch_em.default_segmentation_dataset(
-        paths, raw_key, paths, label_key, patch_shape, rois=data_rois, **kwargs
+        raw_paths=paths,
+        raw_key="volumes/raw",
+        label_paths=paths,
+        label_key="volumes/labels/segmentation",
+        patch_shape=patch_shape,
+        rois=data_rois,
+        **kwargs
    )
 
 
@@ -154,9 +186,7 @@ def get_platynereis_cuticle_loader(
     Returns:
         The DataLoader.
     """
-    ds_kwargs, loader_kwargs = util.split_kwargs(
-        torch_em.default_segmentation_dataset, **kwargs
-    )
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
     ds = get_platynereis_cuticle_dataset(
         path, patch_shape, sample_ids=sample_ids, download=download, rois=rois, **ds_kwargs,
     )
@@ -190,17 +220,27 @@ def get_platynereis_cilia_dataset(
     Returns:
         The segmentation dataset.
""" - cilia_root, n_files = get_platynereis_data(path, "cilia", download) - - paths, rois = _get_paths_and_rois(sample_ids, n_files, os.path.join(cilia_root, "train_data_cilia_%02i.h5"), rois) - raw_key = "volumes/raw" - label_key = "volumes/labels/segmentation" - + paths, rois = get_platynereis_paths( + path=path, + sample_ids=sample_ids, + name="cilia", + file_template="train_data_cilia_%02i.h5", + rois=rois, + download=download, + return_rois=True, + ) kwargs = util.update_kwargs(kwargs, "rois", rois) kwargs, _ = util.add_instance_label_transform( kwargs, add_binary_target=True, boundaries=boundaries, offsets=offsets, binary=binary, ) - return torch_em.default_segmentation_dataset(paths, raw_key, paths, label_key, patch_shape, **kwargs) + return torch_em.default_segmentation_dataset( + raw_paths=paths, + raw_key="volumes/raw", + label_paths=paths, + label_key="volumes/labels/segmentation", + patch_shape=patch_shape, + **kwargs + ) def get_platynereis_cilia_loader( @@ -232,9 +272,7 @@ def get_platynereis_cilia_loader( Returns: The DataLoader. """ - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) ds = get_platynereis_cilia_dataset( path, patch_shape, sample_ids=sample_ids, offsets=offsets, boundaries=boundaries, binary=binary, @@ -268,19 +306,29 @@ def get_platynereis_cell_dataset( Returns: The segmentation dataset. """ - cell_root, n_files = get_platynereis_data(path, "cells", download) - - template = os.path.join(cell_root, "train_data_membrane_%02i.n5") - data_paths, data_rois = _get_paths_and_rois(sample_ids, n_files, template, rois) + data_paths, data_rois = get_platynereis_paths( + path=path, + sample_ids=sample_ids, + name="cells", + file_template="train_data_membrane_%02i.n5", + rois=rois, + download=download, + return_rois=True, + ) kwargs = util.update_kwargs(kwargs, "rois", data_rois) kwargs, _ = util.add_instance_label_transform( kwargs, add_binary_target=False, boundaries=boundaries, offsets=offsets, ) - raw_key = "volumes/raw/s1" - label_key = "volumes/labels/segmentation/s1" - return torch_em.default_segmentation_dataset(data_paths, raw_key, data_paths, label_key, patch_shape, **kwargs) + return torch_em.default_segmentation_dataset( + raw_paths=data_paths, + raw_key="volumes/raw/s1", + label_paths=data_paths, + label_key="volumes/labels/segmentation/s1", + patch_shape=patch_shape, + **kwargs + ) def get_platynereis_cell_loader( @@ -310,9 +358,7 @@ def get_platynereis_cell_loader( Returns: The DataLoader. """ - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) ds = get_platynereis_cell_dataset( path, patch_shape, sample_ids, rois=rois, offsets=offsets, boundaries=boundaries, download=download, @@ -348,15 +394,22 @@ def get_platynereis_nuclei_dataset( Returns: The segmentation dataset. 
""" - nuc_root, n_files = get_platynereis_data(path, "nuclei", download) + _, n_files = get_platynereis_data(path, "nuclei", download) if sample_ids is None: sample_ids = list(range(1, n_files + 1)) assert min(sample_ids) >= 1 and max(sample_ids) <= n_files sample_ids.sort() - template = os.path.join(nuc_root, "train_data_nuclei_%02i.h5") - data_paths, data_rois = _get_paths_and_rois(sample_ids, n_files, template, rois) + data_paths, data_rois = get_platynereis_paths( + path=path, + sample_ids=sample_ids, + name="nuclei", + file_template="train_data_nuclei_%02i.h5", + rois=rois, + download=download, + return_rois=True, + ) kwargs = util.update_kwargs(kwargs, "is_seg_dataset", True) kwargs = util.update_kwargs(kwargs, "rois", data_rois) @@ -364,9 +417,14 @@ def get_platynereis_nuclei_dataset( kwargs, add_binary_target=True, boundaries=boundaries, offsets=offsets, binary=binary, ) - raw_key = "volumes/raw" - label_key = "volumes/labels/nucleus_instance_labels" - return torch_em.default_segmentation_dataset(data_paths, raw_key, data_paths, label_key, patch_shape, **kwargs) + return torch_em.default_segmentation_dataset( + raw_paths=data_paths, + raw_key="volumes/raw", + label_paths=data_paths, + label_key="volumes/labels/nucleus_instance_labels", + patch_shape=patch_shape, + **kwargs + ) def get_platynereis_nuclei_loader( @@ -380,7 +438,7 @@ def get_platynereis_nuclei_loader( rois: Dict[int, Any] = {}, download: bool = False, **kwargs -): +) -> DataLoader: """Get the dataloader for nucleus segmentation in platynereis. Args: @@ -398,9 +456,7 @@ def get_platynereis_nuclei_loader( Returns: The DataLoader. """ - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) ds = get_platynereis_nuclei_dataset( path, patch_shape, sample_ids=sample_ids, rois=rois, offsets=offsets, boundaries=boundaries, binary=binary, download=download, diff --git a/torch_em/data/datasets/electron_microscopy/snemi.py b/torch_em/data/datasets/electron_microscopy/snemi.py index c49df190..05f36638 100644 --- a/torch_em/data/datasets/electron_microscopy/snemi.py +++ b/torch_em/data/datasets/electron_microscopy/snemi.py @@ -1,6 +1,6 @@ """SNEMI is a dataset for neuron segmentation in EM. - It contains an annotated volumes from the mouse brain. + The data is part of the publication https://doi.org/10.1016/j.cell.2015.06.054. Please cite it if you use this dataset for a publication. """ @@ -14,6 +14,7 @@ from .. import util + SNEMI_URLS = { "train": "https://oc.embl.de/index.php/s/43iMotlXPyAB39z/download", "test": "https://oc.embl.de/index.php/s/aRhphk35H23De2s/download" @@ -24,20 +25,32 @@ } -def get_snemi_data(path: Union[os.PathLike, str], sample: str, download: bool) -> str: +def get_snemi_data(path: Union[os.PathLike, str], sample: str, download: bool = False): """Download the SNEMI training data. Args: path: Filepath to a folder where the downloaded data will be saved. sample: The sample to download, either 'train' or 'test'. download: Whether to download the data if it is not present. - - Returns: - The path to the downloaded data. """ os.makedirs(path, exist_ok=True) data_path = os.path.join(path, f"snemi_{sample}.h5") util.download_source(data_path, SNEMI_URLS[sample], download, CHECKSUMS[sample]) + + +def get_snemi_paths(path: Union[os.PathLike, str], sample: str, download: bool = False) -> str: + """Get path to the SNEMI data. 
+
+    Args:
+        path: Filepath to a folder where the downloaded data is saved.
+        sample: The sample to download, either 'train' or 'test'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath for the stored data.
+    """
+    get_snemi_data(path, sample, download)
+    data_path = os.path.join(path, f"snemi_{sample}.h5")
     assert os.path.exists(data_path), data_path
     return data_path
 
@@ -66,16 +79,22 @@ def get_snemi_dataset(
         The segmentation dataset.
     """
     assert len(patch_shape) == 3
-    data_path = get_snemi_data(path, sample, download)
+
+    data_path = get_snemi_paths(path, sample, download)
 
     kwargs = util.update_kwargs(kwargs, "is_seg_dataset", True)
     kwargs, _ = util.add_instance_label_transform(
         kwargs, add_binary_target=False, boundaries=boundaries, offsets=offsets
     )
 
-    raw_key = "volumes/raw"
-    label_key = "volumes/labels/neuron_ids"
-    return torch_em.default_segmentation_dataset(data_path, raw_key, data_path, label_key, patch_shape, **kwargs)
+    return torch_em.default_segmentation_dataset(
+        raw_paths=data_path,
+        raw_key="volumes/raw",
+        label_paths=data_path,
+        label_key="volumes/labels/neuron_ids",
+        patch_shape=patch_shape,
+        **kwargs
+    )
 
 
 def get_snemi_loader(
@@ -103,9 +122,7 @@ def get_snemi_loader(
     Returns:
         The DataLoader.
     """
-    ds_kwargs, loader_kwargs = util.split_kwargs(
-        torch_em.default_segmentation_dataset, **kwargs
-    )
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
     ds = get_snemi_dataset(
         path=path,
         patch_shape=patch_shape,
diff --git a/torch_em/data/datasets/electron_microscopy/sponge_em.py b/torch_em/data/datasets/electron_microscopy/sponge_em.py
index ad0f1ba2..635b916d 100644
--- a/torch_em/data/datasets/electron_microscopy/sponge_em.py
+++ b/torch_em/data/datasets/electron_microscopy/sponge_em.py
@@ -8,7 +8,7 @@
 
 import os
 from glob import glob
-from typing import Optional, Sequence, Tuple, Union
+from typing import Optional, Sequence, Tuple, Union, List
 
 from torch.utils.data import Dataset, DataLoader
 
@@ -16,6 +16,7 @@
 
 from .. import util
 
+
 URL = "https://zenodo.org/record/8150818/files/sponge_em_train_data.zip?download=1"
 CHECKSUM = "f1df616cd60f81b91d7642933e9edd74dc6c486b2e546186a7c1e54c67dd32a5"
 
@@ -51,6 +52,28 @@ def get_sponge_em_data(path: Union[os.PathLike, str], download: bool) -> Tuple[s
     return path, n_files
 
 
+def get_sponge_em_paths(
+    path: Union[os.PathLike, str], sample_ids: Optional[Sequence[int]], download: bool = False
+) -> List[str]:
+    """Get paths to the SpongeEM data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        sample_ids: The sample ids to download; valid ids are 1, 2 and 3.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepaths to the stored data.
+    """
+    data_folder, n_files = get_sponge_em_data(path, download)
+
+    if sample_ids is None:
+        sample_ids = range(1, n_files + 1)
+
+    paths = [os.path.join(data_folder, f"train_data_0{i}.h5") for i in sample_ids]
+    return paths
+
+
 def get_sponge_em_dataset(
     path: Union[os.PathLike, str],
     mode: str,
@@ -72,17 +95,18 @@ def get_sponge_em_dataset(
     Returns:
         The segmentation dataset.
""" - assert mode in ("semantic", "instances") - data_folder, n_files = get_sponge_em_data(path, download) - if sample_ids is None: - sample_ids = range(1, n_files + 1) - paths = [os.path.join(data_folder, f"train_data_0{i}.h5") for i in sample_ids] + paths = get_sponge_em_paths(path, sample_ids, download) - raw_key = "volumes/raw" - label_key = f"volumes/labels/{mode}" - return torch_em.default_segmentation_dataset(paths, raw_key, paths, label_key, patch_shape, **kwargs) + return torch_em.default_segmentation_dataset( + raw_paths=paths, + raw_key="volumes/raw", + label_paths=paths, + label_key=f"volumes/labels/{mode}", + patch_shape=patch_shape, + **kwargs + ) def get_sponge_em_loader( @@ -108,8 +132,6 @@ def get_sponge_em_loader( Returns: The DataLoader. """ - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) ds = get_sponge_em_dataset(path, mode, patch_shape, sample_ids=sample_ids, download=download, **ds_kwargs) return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/electron_microscopy/uro_cell.py b/torch_em/data/datasets/electron_microscopy/uro_cell.py index 4c0994a8..bd8e12ad 100644 --- a/torch_em/data/datasets/electron_microscopy/uro_cell.py +++ b/torch_em/data/datasets/electron_microscopy/uro_cell.py @@ -3,8 +3,8 @@ - Golgi Apparatus - Lysosomes - Mitochondria - It contains several FIB-SEM volumes with annotations. + This dataset is from the publication https://doi.org/10.1016/j.compbiomed.2020.103693. Please cite it if you use this dataset for a publication. """ @@ -26,7 +26,7 @@ CHECKSUM = "a48cf31b06114d7def642742b4fcbe76103483c069122abe10f377d71a1acabc" -def get_uro_cell_data(path: Union[os.PathLike, str], download: bool) -> str: +def get_uro_cell_data(path: Union[os.PathLike, str], download: bool = False) -> str: """Download the UroCell training data. Args: @@ -95,14 +95,34 @@ def get_uro_cell_data(path: Union[os.PathLike, str], download: bool) -> str: return path -def _get_paths(path, target): +def get_uro_cell_paths( + path: Union[os.PathLike], target: str, download: bool = False, return_label_key: bool = False, +) -> List[str]: + """Get paths to the UroCell data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + target: The segmentation target, corresponding to the organelle to segment. + Available organelles are 'fv', 'golgi', 'lyso' and 'mito'. + download: Whether to download the data if it is not present. + return_label_key: Whether to return the label key. + + Returns: + List of filepaths to the stored data. + """ import h5py + get_uro_cell_data(path, download) + label_key = f"labels/{target}" all_paths = glob(os.path.join(path, "*.h5")) all_paths.sort() paths = [path for path in all_paths if label_key in h5py.File(path, "r")] - return paths, label_key + + if return_label_key: + return paths, label_key + else: + return paths def get_uro_cell_dataset( @@ -132,8 +152,8 @@ def get_uro_cell_dataset( The segmentation dataset. 
""" assert target in ("fv", "golgi", "lyso", "mito") - get_uro_cell_data(path, download) - paths, label_key = _get_paths(path, target) + + paths, label_key = get_uro_cell_paths(path, target, download, return_label_key=True) assert sum((offsets is not None, boundaries, binary)) <= 1, f"{offsets}, {boundaries}, {binary}" if offsets is not None: @@ -142,10 +162,9 @@ def get_uro_cell_dataset( f"{target} does not have instance labels, affinities will be computed based on binary segmentation." ) # we add a binary target channel for foreground background segmentation - label_transform = torch_em.transform.label.AffinityTransform(offsets=offsets, - ignore_label=None, - add_binary_target=True, - add_mask=True) + label_transform = torch_em.transform.label.AffinityTransform( + offsets=offsets, ignore_label=None, add_binary_target=True, add_mask=True + ) msg = "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden." kwargs = util.update_kwargs(kwargs, 'label_transform2', label_transform, msg=msg) elif boundaries: @@ -161,9 +180,14 @@ def get_uro_cell_dataset( msg = "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden." kwargs = util.update_kwargs(kwargs, 'label_transform', label_transform, msg=msg) - raw_key = "raw" return torch_em.default_segmentation_dataset( - paths, raw_key, paths, label_key, patch_shape, is_seg_dataset=True, **kwargs + raw_paths=paths, + raw_key="raw", + label_paths=paths, + label_key=label_key, + patch_shape=patch_shape, + is_seg_dataset=True, + **kwargs ) @@ -195,9 +219,7 @@ def get_uro_cell_loader( Returns: The DataLoader. """ - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) ds = get_uro_cell_dataset( path, target, patch_shape, download=download, offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs ) diff --git a/torch_em/data/datasets/electron_microscopy/vnc.py b/torch_em/data/datasets/electron_microscopy/vnc.py index 7644e0ad..b6fd8568 100644 --- a/torch_em/data/datasets/electron_microscopy/vnc.py +++ b/torch_em/data/datasets/electron_microscopy/vnc.py @@ -1,6 +1,6 @@ """The VNC dataset contains segmentation annotations for mitochondria in EM. - It contains two volumes from TEM of the drosophila brain. + Please cite https://doi.org/10.6084/m9.figshare.856713.v1 if you use this dataset in your publication. """ @@ -19,6 +19,7 @@ from .. import util + URL = "https://github.com/unidesigner/groundtruth-drosophila-vnc/archive/refs/heads/master.zip" CHECKSUM = "f7bd0db03c86b64440a16b60360ad60c0a4411f89e2c021c7ee2c8d6af3d7e86" @@ -71,6 +72,21 @@ def get_vnc_data(path: Union[os.PathLike, str], download: bool) -> str: return path +def get_vnc_mito_paths(path: Union[os.PathLike, str], download: bool = False) -> str: + """Get path to the VNC data. + + Args: + path: Filepath to a folder where the downloaded data is saved. + download: Whether to download the data if it is not present. + + Returns: + The filepath to the stored data. + """ + get_vnc_data(path, download) + data_path = os.path.join(path, "vnc_train.h5") + return data_path + + def get_vnc_mito_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int, int], @@ -94,16 +110,20 @@ def get_vnc_mito_dataset( Returns: The segmentation dataset. 
""" - get_vnc_data(path, download) - data_path = os.path.join(path, "vnc_train.h5") + data_path = get_vnc_mito_paths(path, download) kwargs, _ = util.add_instance_label_transform( kwargs, add_binary_target=True, boundaries=boundaries, offsets=offsets, binary=binary, ) - raw_key = "raw" - label_key = "labels/mitochondria" - return torch_em.default_segmentation_dataset(data_path, raw_key, data_path, label_key, patch_shape, **kwargs) + return torch_em.default_segmentation_dataset( + raw_paths=data_path, + raw_key="raw", + label_paths=data_path, + label_key="labels/mitochondria", + patch_shape=patch_shape, + **kwargs + ) def get_vnc_mito_loader( @@ -131,9 +151,7 @@ def get_vnc_mito_loader( Returns: The DataLoader. """ - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) ds = get_vnc_mito_dataset( path, patch_shape, download=download, offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs ) diff --git a/torch_em/data/datasets/light_microscopy/cellpose.py b/torch_em/data/datasets/light_microscopy/cellpose.py index 243e9483..841bcda7 100644 --- a/torch_em/data/datasets/light_microscopy/cellpose.py +++ b/torch_em/data/datasets/light_microscopy/cellpose.py @@ -10,7 +10,7 @@ import os from glob import glob from natsort import natsorted -from typing import Union, Tuple, Literal, Optional +from typing import Union, Tuple, Literal, Optional, List import torch_em @@ -23,19 +23,12 @@ AVAILABLE_CHOICES = ["cyto", "cyto2"] -def _get_cellpose_paths(data_dir): - image_paths = natsorted(glob(os.path.join(data_dir, "*_img.png"))) - gt_paths = natsorted(glob(os.path.join(data_dir, "*_masks.png"))) - - return image_paths, gt_paths - - def get_cellpose_data( path: Union[os.PathLike, str], split: Literal["train", "test"], choice: Literal["cyto", "cyto2"], download: bool = False, -): +) -> str: """Instruction to download CellPose data. NOTE: Please download the dataset from "https://www.cellpose.org/dataset". @@ -73,6 +66,32 @@ def get_cellpose_data( return data_dir +def get_cellpose_paths( + path: Union[os.PathLike, str], + split: Literal['train', 'test'], + choice: Literal["cyto", "cyto2"], + download: bool = False, +) -> Tuple[List[str], List[str]]: + """Get paths to the CellPose data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The data split to use. Either 'train', or 'test'. + choice: The choice of dataset. Either 'cyto' or 'cyto2'. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the image data. + List of filepaths for the label data. 
+    """
+    data_dir = get_cellpose_data(path=path, split=split, choice=choice, download=download)
+
+    image_paths = natsorted(glob(os.path.join(data_dir, "*_img.png")))
+    gt_paths = natsorted(glob(os.path.join(data_dir, "*_masks.png")))
+
+    return image_paths, gt_paths
+
+
 def get_cellpose_dataset(
     path: Union[os.PathLike, str],
     split: Literal["train", "test"],
@@ -105,8 +124,7 @@ def get_cellpose_dataset(
     image_paths, gt_paths = [], []
     for per_choice in choice:
         assert per_choice in AVAILABLE_CHOICES
-        data_dir = get_cellpose_data(path=path, split=split, choice=per_choice, download=download)
-        per_image_paths, per_gt_paths = _get_cellpose_paths(data_dir=data_dir)
+        per_image_paths, per_gt_paths = get_cellpose_paths(path, split, per_choice, download)
         image_paths.extend(per_image_paths)
         gt_paths.extend(per_gt_paths)
 
@@ -116,7 +134,7 @@ def get_cellpose_dataset(
     if "transform" not in kwargs:
         transform = torch_em.transform.get_augmentations(ndim=2)
 
-    dataset = torch_em.default_segmentation_dataset(
+    return torch_em.default_segmentation_dataset(
         raw_paths=image_paths,
         raw_key=None,
         label_paths=gt_paths,
@@ -127,7 +145,6 @@ def get_cellpose_dataset(
         transform=transform,
         **kwargs
     )
-    return dataset
 
 
 def get_cellpose_loader(
@@ -155,12 +172,6 @@ def get_cellpose_loader(
     """
     ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
     dataset = get_cellpose_dataset(
-        path=path,
-        split=split,
-        patch_shape=patch_shape,
-        choice=choice,
-        download=download,
-        **ds_kwargs
+        path=path, split=split, patch_shape=patch_shape, choice=choice, download=download, **ds_kwargs
     )
-    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
-    return loader
+    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/light_microscopy/cellseg_3d.py b/torch_em/data/datasets/light_microscopy/cellseg_3d.py
index be7cab23..69482e89 100644
--- a/torch_em/data/datasets/light_microscopy/cellseg_3d.py
+++ b/torch_em/data/datasets/light_microscopy/cellseg_3d.py
@@ -6,7 +6,7 @@
 
 import os
 from glob import glob
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, List
 
 from torch.utils.data import Dataset, DataLoader
 
@@ -14,11 +14,12 @@
 
 from .. import util
 
+
 URL = "https://zenodo.org/records/11095111/files/DATASET_WITH_GT.zip?download=1"
 CHECKSUM = "6d8e8d778e479000161fdfea70201a6ded95b3958a703f69def63e69bbddf9d6"
 
 
-def get_cellseg_3d_data(path: Union[os.PathLike, str], download: bool) -> str:
+def get_cellseg_3d_data(path: Union[os.PathLike, str], download: bool = False) -> str:
     """Download the CellSeg3d training data.
 
     Args:
@@ -43,6 +44,26 @@ def get_cellseg_3d_data(path: Union[os.PathLike, str], download: bool) -> str:
     return data_path
 
 
+def get_cellseg_3d_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[List[str], List[str]]:
+    """Get paths to the CellSeg3d data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the image data.
+        List of filepaths for the label data.
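+
+    Example:
+        A minimal sketch with an illustrative target folder:
+
+        >>> raw_paths, label_paths = get_cellseg_3d_paths("./data/cellseg_3d", download=True)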
+ """ + data_root = get_cellseg_3d_data(path, download) + + raw_paths = sorted(glob(os.path.join(data_root, "*.tif"))) + label_paths = sorted(glob(os.path.join(data_root, "labels", "*.tif"))) + assert len(raw_paths) == len(label_paths) + + return raw_paths, label_paths + + def get_cellseg_3d_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], @@ -62,20 +83,20 @@ def get_cellseg_3d_dataset( Returns: The segmentation dataset. """ - data_root = get_cellseg_3d_data(path, download) + raw_paths, label_paths = get_cellseg_3d_paths(path, download) - raw_paths = sorted(glob(os.path.join(data_root, "*.tif"))) - label_paths = sorted(glob(os.path.join(data_root, "labels", "*.tif"))) - assert len(raw_paths) == len(label_paths) if sample_ids is not None: assert all(sid < len(raw_paths) for sid in sample_ids) raw_paths = [raw_paths[i] for i in sample_ids] label_paths = [label_paths[i] for i in sample_ids] - raw_key, label_key = None, None - return torch_em.default_segmentation_dataset( - raw_paths, raw_key, label_paths, label_key, patch_shape, **kwargs + raw_paths=raw_paths, + raw_key=None, + label_paths=label_paths, + label_key=None, + patch_shape=patch_shape, + **kwargs ) @@ -87,7 +108,7 @@ def get_cellseg_3d_loader( download: bool = False, **kwargs ) -> DataLoader: - """Get the CellSeg3d dataloder for segmenting nuclei in 3d fluorescence microscopy. + """Get the CellSeg3d dataloader for segmenting nuclei in 3d fluorescence microscopy. Args: path: Filepath to a folder where the downloaded data will be saved. @@ -101,8 +122,5 @@ def get_cellseg_3d_loader( The DataLoader. """ ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_cellseg_3d_dataset( - path, patch_shape, sample_ids=sample_ids, download=download, **ds_kwargs, - ) - loader = torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs) - return loader + dataset = get_cellseg_3d_dataset(path, patch_shape, sample_ids=sample_ids, download=download, **ds_kwargs) + return torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/covid_if.py b/torch_em/data/datasets/light_microscopy/covid_if.py index 90824def..59bbb74f 100644 --- a/torch_em/data/datasets/light_microscopy/covid_if.py +++ b/torch_em/data/datasets/light_microscopy/covid_if.py @@ -15,12 +15,13 @@ from .. import util + COVID_IF_URL = "https://zenodo.org/record/5092850/files/covid-if-groundtruth.zip?download=1" CHECKSUM = "d9cd6c85a19b802c771fb4ff928894b19a8fab0e0af269c49235fdac3f7a60e1" -def get_covid_if_data(path: Union[os.PathLike, str], download: bool) -> str: - """Download the CovidIF training data. +def get_covid_if_data(path: Union[os.PathLike, str], download: bool = False) -> str: + """Download the Covid-IF training data. Args: path: Filepath to a folder where the downloaded data will be saved. @@ -43,6 +44,36 @@ def get_covid_if_data(path: Union[os.PathLike, str], download: bool) -> str: return path +def get_covid_if_paths( + path: Union[os.PathLike, str], + sample_range: Optional[Tuple[int, int]] = None, + download: bool = False +) -> List[str]: + """Get paths to the Covid-IF data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + sample_range: Id range of samples to load from the training dataset. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths to the stored data. 
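+
+    Example:
+        A minimal sketch; folder and sample range are illustrative:
+
+        >>> file_paths = get_covid_if_paths("./data/covid_if", sample_range=(0, 5), download=True)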
+    """
+    get_covid_if_data(path, download)
+
+    file_paths = sorted(glob(os.path.join(path, "*.h5")))
+    if sample_range is not None:
+        start, stop = sample_range
+        if start is None:
+            start = 0
+        if stop is None:
+            stop = len(file_paths)
+        file_paths = [os.path.join(path, f"gt_image_{idx:03}.h5") for idx in range(start, stop)]
+        assert all(os.path.exists(fp) for fp in file_paths), f"Invalid sample range {sample_range}"
+
+    return file_paths
+
+
 def get_covid_if_dataset(
     path: Union[os.PathLike, str],
     patch_shape: Tuple[int, int],
@@ -54,7 +85,7 @@ def get_covid_if_dataset(
     binary: bool = False,
     **kwargs
 ) -> Dataset:
-    """Get the CovidIF dataset for segmenting nuclei or cells in immunofluorescence microscopy.
+    """Get the Covid-IF dataset for segmenting nuclei or cells in immunofluorescence microscopy.
 
     Args:
         path: Filepath to a folder where the downloaded data will be saved.
@@ -73,7 +104,6 @@ def get_covid_if_dataset(
     available_targets = ("cells", "nuclei")
     # TODO also support infected_cells
     # available_targets = ("cells", "nuclei", "infected_cells")
-    assert target in available_targets, f"{target} not found in {available_targets}"
 
     if target == "cells":
         raw_key = "raw/serum_IgG/s0"
@@ -81,18 +111,10 @@ def get_covid_if_dataset(
     elif target == "nuclei":
         raw_key = "raw/nuclei/s0"
         label_key = "labels/nuclei/s0"
+    else:
+        raise ValueError(f"{target} not found in {available_targets}")
 
-    get_covid_if_data(path, download)
-
-    file_paths = sorted(glob(os.path.join(path, "*.h5")))
-    if sample_range is not None:
-        start, stop = sample_range
-        if start is None:
-            start = 0
-        if stop is None:
-            stop = len(file_paths)
-        file_paths = [os.path.join(path, f"gt_image_{idx:03}.h5") for idx in range(start, stop)]
-        assert all(os.path.exists(fp) for fp in file_paths), f"Invalid sample range {sample_range}"
+    file_paths = get_covid_if_paths(path, sample_range, download)
 
     kwargs, _ = util.add_instance_label_transform(
         kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets
@@ -100,7 +122,12 @@ def get_covid_if_dataset(
     kwargs = util.update_kwargs(kwargs, "ndim", 2)
 
     return torch_em.default_segmentation_dataset(
-        file_paths, raw_key, file_paths, label_key, patch_shape, **kwargs
+        raw_paths=file_paths,
+        raw_key=raw_key,
+        label_paths=file_paths,
+        label_key=label_key,
+        patch_shape=patch_shape,
+        **kwargs
     )
 
 
@@ -116,7 +143,7 @@ def get_covid_if_loader(
     binary: bool = False,
     **kwargs
 ) -> DataLoader:
-    """Get the CovidIF dataloder for segmenting nuclei or cells in immunofluorescence microscopy.
+    """Get the Covid-IF dataloader for segmenting nuclei or cells in immunofluorescence microscopy.
 
     Args:
         path: Filepath to a folder where the downloaded data will be saved.
@@ -138,5 +165,4 @@ def get_covid_if_loader(
         path, patch_shape, sample_range=sample_range, target=target, download=download,
         offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs,
     )
-    loader = torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
-    return loader
+    return torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/light_microscopy/ctc.py b/torch_em/data/datasets/light_microscopy/ctc.py
index efea0106..d6ed6e16 100644
--- a/torch_em/data/datasets/light_microscopy/ctc.py
+++ b/torch_em/data/datasets/light_microscopy/ctc.py
@@ -1,6 +1,6 @@
 """The Cell Tracking Challenge contains annotated data for cell segmentation and tracking.
+We currently provide the 2d datasets with segmentation annotations.
 
-We currently cprovide the 2d datasets with segmentation annotations.
 If you use this data in your research please cite https://doi.org/10.1038/nmeth.4473.
 """
@@ -55,20 +55,17 @@ def _get_ctc_url_and_checksum(dataset_name, split):
     return url, checksum
 
 
-def get_ctc_data(
-    path: Union[os.PathLike, str],
-    dataset_name: str,
-    download: bool,
-    split: str
+def get_ctc_segmentation_data(
+    path: Union[os.PathLike, str], dataset_name: str, split: str, download: bool = False,
 ) -> str:
-    f"""Download training data from the cell tracking challenge.
+    f"""Download training data from the Cell Tracking Challenge.
 
     Args:
         path: Filepath to a folder where the downloaded data will be saved.
         dataset_name: Name of the dataset to be downloaded. The available datasets are:
            {', '.join(CTC_CHECKSUMS['train'].keys())}
-        download: Whether to download the data if it is not present.
         split: The split to download. Either 'train' or 'test'.
+        download: Whether to download the data if it is not present.
 
     Returns:
         The filepath to the training data.
@@ -123,6 +120,40 @@ def _require_gt_images(data_path, vol_ids):
     return image_paths, label_paths
 
 
+def get_ctc_segmentation_paths(
+    path: Union[os.PathLike, str],
+    dataset_name: str,
+    split: str = "train",
+    vol_id: Optional[int] = None,
+    download: bool = False,
+) -> Tuple[str, str]:
+    f"""Get paths to the Cell Tracking Challenge data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        dataset_name: Name of the dataset to be downloaded. The available datasets are:
+            {', '.join(CTC_CHECKSUMS['train'].keys())}
+        split: The split to download. Currently only supports 'train'.
+        vol_id: The volume id to load.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        Filepath to the folder where image data is stored.
+        Filepath to the folder where label data is stored.
+    """
+    data_path = get_ctc_segmentation_data(path, dataset_name, split, download)
+
+    if vol_id is None:
+        vol_ids = glob(os.path.join(data_path, "*_GT"))
+        vol_ids = [os.path.basename(vol_id) for vol_id in vol_ids]
+        vol_ids = [vol_id.rstrip("_GT") for vol_id in vol_ids]
+    else:
+        vol_ids = vol_id
+
+    image_path, label_path = _require_gt_images(data_path, vol_ids)
+    return image_path, label_path
+
+
 def get_ctc_segmentation_dataset(
     path: Union[os.PathLike, str],
     dataset_name: str,
@@ -149,20 +180,18 @@ def get_ctc_segmentation_dataset(
     """
     assert split in ["train"]
 
-    data_path = get_ctc_data(path, dataset_name, download, split)
-
-    if vol_id is None:
-        vol_ids = glob(os.path.join(data_path, "*_GT"))
-        vol_ids = [os.path.basename(vol_id) for vol_id in vol_ids]
-        vol_ids = [vol_id.rstrip("_GT") for vol_id in vol_ids]
-    else:
-        vol_ids = vol_id
-
-    image_path, label_path = _require_gt_images(data_path, vol_ids)
+    image_path, label_path = get_ctc_segmentation_paths(path, dataset_name, split, vol_id, download)
 
     kwargs = util.update_kwargs(kwargs, "ndim", 2)
 
     return torch_em.default_segmentation_dataset(
-        image_path, "*.tif", label_path, "*.tif", patch_shape, is_seg_dataset=True, **kwargs
+        raw_paths=image_path,
+        raw_key="*.tif",
+        label_paths=label_path,
+        label_key="*.tif",
+        patch_shape=patch_shape,
+        is_seg_dataset=True,
+        **kwargs
     )
 
 
@@ -176,7 +205,7 @@ def get_ctc_segmentation_loader(
     download: bool = False,
     **kwargs,
 ) -> DataLoader:
-    """Get the CTC dataloader for cell segmentation.
+    f"""Get the CTC dataloader for cell segmentation.
 
     Args:
         path: Filepath to a folder where the downloaded data will be saved.
@@ -192,12 +221,8 @@ def get_ctc_segmentation_loader(
     Returns:
         The DataLoader.
     """
-    ds_kwargs, loader_kwargs = util.split_kwargs(
-        torch_em.default_segmentation_dataset, **kwargs
-    )
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
     dataset = get_ctc_segmentation_dataset(
         path, dataset_name, patch_shape, split=split, vol_id=vol_id, download=download, **ds_kwargs,
     )
-
-    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
-    return loader
+    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/light_microscopy/deepbacs.py b/torch_em/data/datasets/light_microscopy/deepbacs.py
index 83702627..24be8468 100644
--- a/torch_em/data/datasets/light_microscopy/deepbacs.py
+++ b/torch_em/data/datasets/light_microscopy/deepbacs.py
@@ -100,7 +100,24 @@ def get_deepbacs_data(path: Union[os.PathLike, str], bac_type: str, download: bo
     return data_folder
 
 
-def _get_paths(path, bac_type, split):
+def get_deepbacs_paths(
+    path: Union[os.PathLike, str], bac_type: str, split: str, download: bool = False
+) -> Tuple[str, str]:
+    f"""Get paths to the DeepBacs data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        bac_type: The bacteria type. The available types are:
+            {', '.join(URLS.keys())}
+        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        Filepath to the folder where image data is stored.
+        Filepath to the folder where label data is stored.
+    """
+    get_deepbacs_data(path, bac_type, download)
+
     # the bacteria types other than mixed are a bit more complicated so we don't have the dataloaders for them yet
     # mixed is the combination of all other types
     if split == "train":
@@ -110,8 +127,10 @@
     if bac_type != "mixed":
         raise NotImplementedError(f"Currently only the bacteria type 'mixed' is supported, not {bac_type}")
+
     image_folder = os.path.join(path, bac_type, dir_choice, "source")
     label_folder = os.path.join(path, bac_type, dir_choice, "target")
+
     return image_folder, label_folder
 
 
@@ -138,12 +157,17 @@ def get_deepbacs_dataset(
         The segmentation dataset.
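+
+    Example:
+        A minimal sketch; only the 'mixed' bacteria type is currently supported:
+
+        >>> ds = get_deepbacs_dataset("./data/deepbacs", split="train", patch_shape=(256, 256), bac_type="mixed", download=True)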
""" assert split in ("train", "val", "test") - get_deepbacs_data(path, bac_type, download) - image_folder, label_folder = _get_paths(path, bac_type, split) - dataset = torch_em.default_segmentation_dataset( - image_folder, "*.tif", label_folder, "*.tif", patch_shape=patch_shape, **kwargs + + image_folder, label_folder = get_deepbacs_paths(path, bac_type, split, download) + + return torch_em.default_segmentation_dataset( + raw_paths=image_folder, + raw_key="*.tif", + label_paths=label_folder, + label_key="*.tif", + patch_shape=patch_shape, + **kwargs ) - return dataset def get_deepbacs_loader( @@ -172,5 +196,4 @@ def get_deepbacs_loader( """ ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_deepbacs_dataset(path, split, patch_shape, bac_type=bac_type, download=download, **ds_kwargs) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/dic_hepg2.py b/torch_em/data/datasets/light_microscopy/dic_hepg2.py index cea5e20e..30b5160e 100644 --- a/torch_em/data/datasets/light_microscopy/dic_hepg2.py +++ b/torch_em/data/datasets/light_microscopy/dic_hepg2.py @@ -31,10 +31,7 @@ CHECKSUM = "42b939d01c5fc2517dc3ad34bde596ac38dbeba2a96173f37e1b6dfe14cbe3a2" -def get_dic_hepg2_data( - path: Union[str, os.PathLike], - download: bool = False, -) -> str: +def get_dic_hepg2_data(path: Union[str, os.PathLike], download: bool = False) -> str: """Download the DIC HepG2 dataset. Args: @@ -42,7 +39,7 @@ def get_dic_hepg2_data( download: Whether to download the data if it is not present. Returns: - Path to the folder where data is stored. + The path to the folder where data is stored. """ if os.path.exists(path): return path @@ -86,10 +83,26 @@ def _create_segmentations_from_coco_annotation(path, split): return image_folder, gt_folder -def _get_dic_hepg2_paths(path, split): +def get_dic_hepg2_paths( + path: Union[os.PathLike, str], split: str, download: bool = False +) -> Tuple[List[str], List[str]]: + """Get paths to DIC HepG2 data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The data split to use. Either 'train', 'val' or 'test'. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the image data. + List of filepaths for the label data. + """ + path = get_dic_hepg2_data(path=path, download=download) + image_folder, gt_folder = _create_segmentations_from_coco_annotation(path=path, split=split) gt_paths = natsorted(glob(os.path.join(gt_folder, "*.tif"))) image_paths = [os.path.join(image_folder, f"{Path(gt_path).stem}.png") for gt_path in gt_paths] + return image_paths, gt_paths @@ -118,15 +131,14 @@ def get_dic_hepg2_dataset( Returns: The segmentation dataset. 
""" - path = get_dic_hepg2_data(path=path, download=download) - image_paths, gt_paths = _get_dic_hepg2_paths(path=path, split=split) + image_paths, gt_paths = get_dic_hepg2_paths(path=path, split=split) kwargs = util.ensure_transforms(ndim=2, **kwargs) kwargs, _ = util.add_instance_label_transform( kwargs, add_binary_target=True, offsets=offsets, boundaries=boundaries, binary=binary ) - dataset = torch_em.default_segmentation_dataset( + return torch_em.default_segmentation_dataset( raw_paths=image_paths, raw_key=None, label_paths=gt_paths, @@ -135,7 +147,6 @@ def get_dic_hepg2_dataset( is_seg_dataset=False, **kwargs ) - return dataset def get_dic_hepg2_loader( @@ -176,5 +187,4 @@ def get_dic_hepg2_loader( download=download, **ds_kwargs ) - loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) - return loader + return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/dsb.py b/torch_em/data/datasets/light_microscopy/dsb.py index 5c38680b..be18fd86 100644 --- a/torch_em/data/datasets/light_microscopy/dsb.py +++ b/torch_em/data/datasets/light_microscopy/dsb.py @@ -15,6 +15,7 @@ from .. import util + DSB_URLS = { "full": "", # TODO "reduced": "https://github.com/stardist/stardist/releases/download/0.1.0/dsb2018.zip" @@ -56,6 +57,28 @@ def get_dsb_data(path: Union[os.PathLike, str], source: str, download: bool) -> return path +def get_dsb_paths(path: Union[os.PathLike, str], split: str, source: str, download: bool = False) -> Tuple[str, str]: + """Get paths to the DSB data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split to use for the dataset. Either 'train' or 'test'. + source: The source of the dataset. Can either be 'full' for the complete dataset, + or 'reduced' for the dataset excluding histopathology images. + download: Whether to download the data if it is not present. + + Returns: + Filepath for the folder where the images are stored. + Filepath for the folder where the labels are stored. + """ + get_dsb_data(path, source, download) + + image_path = os.path.join(path, split, "images") + label_path = os.path.join(path, split, "masks") + + return image_path, label_path + + def get_dsb_dataset( path: Union[os.PathLike, str], split: str, @@ -85,17 +108,21 @@ def get_dsb_dataset( The segmentation dataset. """ assert split in ("test", "train"), split - get_dsb_data(path, source, download) - image_path = os.path.join(path, split, "images") - label_path = os.path.join(path, split, "masks") + image_path, label_path = get_dsb_paths(path, split, source, download) kwargs, _ = util.add_instance_label_transform( kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets ) kwargs = util.update_kwargs(kwargs, "ndim", 2) + return torch_em.default_segmentation_dataset( - image_path, "*.tif", label_path, "*.tif", patch_shape, **kwargs + raw_paths=image_path, + raw_key="*.tif", + label_paths=label_path, + label_key="*.tif", + patch_shape=patch_shape, + **kwargs ) @@ -129,13 +156,10 @@ def get_dsb_loader( Returns: The DataLoader. 
""" - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_dsb_dataset( path, split, patch_shape, download=download, offsets=offsets, boundaries=boundaries, binary=binary, source=source, **ds_kwargs, ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/dynamicnuclearnet.py b/torch_em/data/datasets/light_microscopy/dynamicnuclearnet.py index 42fb39ce..e88fa085 100644 --- a/torch_em/data/datasets/light_microscopy/dynamicnuclearnet.py +++ b/torch_em/data/datasets/light_microscopy/dynamicnuclearnet.py @@ -11,7 +11,7 @@ import os from tqdm import tqdm from glob import glob -from typing import Tuple, Union +from typing import Tuple, Union, List import numpy as np import pandas as pd @@ -94,6 +94,25 @@ def get_dynamicnuclearnet_data( return split_folder +def get_dynamicnuclearnet_paths(path: Union[os.PathLike, str], split: str, download: bool = False) -> List[str]: + """Get paths to the DynamicNuclearNet data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split to use for the dataset. Either 'train', 'val' or 'test'. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the stored data. + """ + split_folder = get_dynamicnuclearnet_data(path, split, download) + assert os.path.exists(split_folder) + data_paths = glob(os.path.join(split_folder, "*.zarr")) + assert len(data_paths) > 0 + + return data_paths + + def get_dynamicnuclearnet_dataset( path: Union[os.PathLike, str], split: str, @@ -113,15 +132,17 @@ def get_dynamicnuclearnet_dataset( Returns: The segmentation dataset. """ - split_folder = get_dynamicnuclearnet_data(path, split, download) - assert os.path.exists(split_folder) - data_path = glob(os.path.join(split_folder, "*.zarr")) - assert len(data_path) > 0 - - raw_key, label_key = "raw", "labels" + data_paths = get_dynamicnuclearnet_paths(path, split, download) return torch_em.default_segmentation_dataset( - data_path, raw_key, data_path, label_key, patch_shape, is_seg_dataset=True, ndim=2, **kwargs + raw_paths=data_paths, + raw_key="raw", + label_paths=data_paths, + label_key="labels", + patch_shape=patch_shape, + is_seg_dataset=True, + ndim=2, + **kwargs ) @@ -148,5 +169,4 @@ def get_dynamicnuclearnet_loader( """ ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_dynamicnuclearnet_dataset(path, split, patch_shape, download, **ds_kwargs) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/embedseg_data.py b/torch_em/data/datasets/light_microscopy/embedseg_data.py index 91e96759..28268389 100644 --- a/torch_em/data/datasets/light_microscopy/embedseg_data.py +++ b/torch_em/data/datasets/light_microscopy/embedseg_data.py @@ -7,7 +7,7 @@ import os from glob import glob -from typing import Tuple, Union +from typing import Tuple, Union, List from torch.utils.data import Dataset, DataLoader @@ -15,6 +15,7 @@ from .. 
import util
 
+
 URLS = {
     "Mouse-Organoid-Cells-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Mouse-Organoid-Cells-CBG.zip",  # noqa
     "Mouse-Skull-Nuclei-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Mouse-Skull-Nuclei-CBG.zip",
@@ -58,6 +59,31 @@ def get_embedseg_data(path: Union[os.PathLike, str], name: str, download: bool)
     return data_path
 
 
+def get_embedseg_paths(
+    path: Union[os.PathLike, str], name: str, split: str, download: bool = False
+) -> Tuple[List[str], List[str]]:
+    """Get paths to the EmbedSeg data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        name: Name of the dataset to download.
+        split: The split to use for the dataset.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the image data.
+        List of filepaths for the label data.
+    """
+    data_root = get_embedseg_data(path, name, download)
+
+    raw_paths = sorted(glob(os.path.join(data_root, split, "images", "*.tif")))
+    label_paths = sorted(glob(os.path.join(data_root, split, "masks", "*.tif")))
+    assert len(raw_paths) > 0
+    assert len(raw_paths) == len(label_paths)
+
+    return raw_paths, label_paths
+
+
 def get_embedseg_dataset(
     path: Union[os.PathLike, str],
     patch_shape: Tuple[int, int],
@@ -66,7 +92,7 @@ def get_embedseg_dataset(
     download: bool = False,
     **kwargs
 ) -> Dataset:
-    """Get an EmbedSeg dataset for 3d fluorescence microscopy segmentation.
+    """Get the EmbedSeg dataset for 3d fluorescence microscopy segmentation.
 
     Args:
         path: Filepath to a folder where the downloaded data will be saved.
@@ -79,17 +105,15 @@ def get_embedseg_dataset(
     Returns:
         The segmentation dataset.
     """
-    data_root = get_embedseg_data(path, name, download)
-
-    raw_paths = sorted(glob(os.path.join(data_root, split, "images", "*.tif")))
-    label_paths = sorted(glob(os.path.join(data_root, split, "masks", "*.tif")))
-    assert len(raw_paths) > 0
-    assert len(raw_paths) == len(label_paths)
-
-    raw_key, label_key = None, None
+    raw_paths, label_paths = get_embedseg_paths(path, name, split, download)
 
     return torch_em.default_segmentation_dataset(
-        raw_paths, raw_key, label_paths, label_key, patch_shape, **kwargs
+        raw_paths=raw_paths,
+        raw_key=None,
+        label_paths=label_paths,
+        label_key=None,
+        patch_shape=patch_shape,
+        **kwargs
     )
 
 
@@ -102,7 +126,7 @@ def get_embedseg_loader(
     download: bool = False,
     **kwargs
 ) -> DataLoader:
-    """Get an EmbedSeg dataloader for 3d fluorescence microscopy segmentation.
+    """Get the EmbedSeg dataloader for 3d fluorescence microscopy segmentation.
 
     Args:
         path: Filepath to a folder where the downloaded data will be saved.
@@ -120,5 +144,4 @@ def get_embedseg_loader(
     dataset = get_embedseg_dataset(
         path, name=name, split=split, patch_shape=patch_shape, download=download, **ds_kwargs,
     )
-    loader = torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
-    return loader
+    return torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/light_microscopy/gonuclear.py b/torch_em/data/datasets/light_microscopy/gonuclear.py
index 1895a83e..b4378112 100644
--- a/torch_em/data/datasets/light_microscopy/gonuclear.py
+++ b/torch_em/data/datasets/light_microscopy/gonuclear.py
@@ -98,16 +98,13 @@ def get_gonuclear_data(path: Union[os.PathLike, str], download: bool) -> str:
 
     Returns:
         The filepath to the training data.
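+
+    Example:
+        A minimal sketch with an illustrative target folder:
+
+        >>> data_root = get_gonuclear_data("./data/gonuclear", download=True)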
""" - url = URL - checksum = CHECKSUM - data_path = os.path.join(path, "gonuclear_datasets") if os.path.exists(data_path): return data_path os.makedirs(path, exist_ok=True) zip_path = os.path.join(path, "gonuclear.zip") - util.download_source(zip_path, url, download, checksum) + util.download_source(zip_path, URL, download, CHECKSUM) util.unzip(zip_path, path, True) extracted_path = os.path.join(path, "Training image dataset_Tiff Files") @@ -119,6 +116,37 @@ def get_gonuclear_data(path: Union[os.PathLike, str], download: bool) -> str: return data_path +def get_gonuclear_paths( + path: Union[os.PathLike, str], + sample_ids: Optional[Union[int, Tuple[int, ...]]] = None, + download: bool = False +) -> List[str]: + """Get paths to the GoNuclear data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + sample_ids: The sample ids to load. The valid sample ids are: + 1135, 1136, 1137, 1139, 1170. If none is given all samples will be loaded. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the stored data. + """ + data_root = get_gonuclear_data(path, download) + + if sample_ids is None: + paths = sorted(glob(os.path.join(data_root, "*.h5"))) + else: + paths = [] + for sample_id in sample_ids: + sample_path = os.path.join(data_root, f"{sample_id}.h5") + if not os.path.exists(sample_path): + raise ValueError(f"Invalid sample id {sample_id}.") + paths.append(sample_path) + + return paths + + def get_gonuclear_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], @@ -147,17 +175,7 @@ def get_gonuclear_dataset( Returns: The segmentation dataset. """ - data_root = get_gonuclear_data(path, download) - - if sample_ids is None: - paths = sorted(glob(os.path.join(data_root, "*.h5"))) - else: - paths = [] - for sample_id in sample_ids: - sample_path = os.path.join(data_root, f"{sample_id}.h5") - if not os.path.exists(sample_path): - raise ValueError(f"Invalid sample id {sample_id}.") - paths.append(sample_path) + paths = get_gonuclear_paths(path, sample_ids, download) if segmentation_task == "nuclei": raw_key = "raw/nuclei" @@ -173,7 +191,12 @@ def get_gonuclear_dataset( ) return torch_em.default_segmentation_dataset( - paths, raw_key, paths, label_key, patch_shape, **kwargs + raw_paths=paths, + raw_key=raw_key, + label_paths=paths, + label_key=label_key, + patch_shape=patch_shape, + **kwargs ) @@ -219,5 +242,4 @@ def get_gonuclear_loader( download=download, **ds_kwargs, ) - loader = torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs) - return loader + return torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/hpa.py b/torch_em/data/datasets/light_microscopy/hpa.py index 1458f27e..bd713166 100644 --- a/torch_em/data/datasets/light_microscopy/hpa.py +++ b/torch_em/data/datasets/light_microscopy/hpa.py @@ -321,11 +321,7 @@ def _check_data(path): return have_train and have_test and have_val -def get_hpa_segmentation_data( - path: Union[os.PathLike, str], - download: bool, - n_workers_preproc: int = 8 -) -> str: +def get_hpa_segmentation_data(path: Union[os.PathLike, str], download: bool, n_workers_preproc: int = 8) -> str: """Download the HPA training data. Args: @@ -343,6 +339,25 @@ def get_hpa_segmentation_data( return path +def get_hpa_segmentation_paths( + path: Union[os.PathLike, str], split: str, download: bool = False, n_workers_preproc: int = 8, +) -> List[str]: + """Get paths to the HPA data. 
+ + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split for the dataset. Available splits are 'train', 'val' or 'test'. + download: Whether to download the data if it is not present. + n_workers_preproc: The number of workers to use for preprocessing. + + Returns: + List of filepaths to the stored data. + """ + get_hpa_segmentation_data(path, download, n_workers_preproc) + paths = glob(os.path.join(path, split, "*.h5")) + return paths + + def get_hpa_segmentation_dataset( path: Union[os.PathLike, str], split: str, @@ -378,19 +393,22 @@ def get_hpa_segmentation_dataset( if chan not in VALID_CHANNELS: raise ValueError(f"'{chan}' is not a valid channel for HPA dataset.") - get_hpa_segmentation_data(path, download, n_workers_preproc) - kwargs, _ = util.add_instance_label_transform( kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets ) kwargs = util.update_kwargs(kwargs, "ndim", 2) kwargs = util.update_kwargs(kwargs, "with_channels", True) - paths = glob(os.path.join(path, split, "*.h5")) - raw_key = [f"raw/{chan}" for chan in channels] - label_key = "labels" + paths = get_hpa_segmentation_paths(path, split, download, n_workers_preproc) - return torch_em.default_segmentation_dataset(paths, raw_key, paths, label_key, patch_shape, **kwargs) + return torch_em.default_segmentation_dataset( + raw_paths=paths, + raw_key=[f"raw/{chan}" for chan in channels], + label_paths=paths, + label_key="labels", + patch_shape=patch_shape, + **kwargs + ) def get_hpa_segmentation_loader( @@ -425,14 +443,11 @@ def get_hpa_segmentation_loader( Returns: The DataLoader. """ - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_hpa_segmentation_dataset( path, split, patch_shape, offsets=offsets, boundaries=boundaries, binary=binary, channels=channels, download=download, n_workers_preproc=n_workers_preproc, **ds_kwargs ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/livecell.py b/torch_em/data/datasets/light_microscopy/livecell.py index bb2c2121..d0a79449 100644 --- a/torch_em/data/datasets/light_microscopy/livecell.py +++ b/torch_em/data/datasets/light_microscopy/livecell.py @@ -14,12 +14,13 @@ import numpy as np import imageio.v3 as imageio -import torch.utils.data +import torch from torch.utils.data import Dataset, DataLoader import torch_em from .. import util +from ... 
import ImageCollectionDataset
 
 try:
     from pycocotools.coco import COCO
@@ -39,20 +40,6 @@
 CHECKSUM = None
 
 
-def _download_livecell_images(path, download):
-    os.makedirs(path, exist_ok=True)
-    image_path = os.path.join(path, "images")
-
-    if os.path.exists(image_path):
-        return
-
-    url = URLS["images"]
-    checksum = CHECKSUM
-    zip_path = os.path.join(path, "livecell.zip")
-    util.download_source(zip_path, url, download, checksum)
-    util.unzip(zip_path, path, True)
-
-
 # TODO use download flag
 def _download_annotation_file(path, split, download):
     annotation_file = os.path.join(path, f"{split}.json")
@@ -156,14 +143,34 @@ def _download_livecell_annotations(path, split, download, cell_types, label_path
     return _create_segmentations_from_annotations(annotation_file, image_folder, seg_folder, cell_types)
 
 
-def get_livecell_data(
+def get_livecell_data(path: Union[os.PathLike, str], download: bool = False):
+    """Download the LIVECell dataset.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+    """
+    os.makedirs(path, exist_ok=True)
+    image_path = os.path.join(path, "images")
+
+    if os.path.exists(image_path):
+        return
+
+    url = URLS["images"]
+    checksum = CHECKSUM
+    zip_path = os.path.join(path, "livecell.zip")
+    util.download_source(zip_path, url, download, checksum)
+    util.unzip(zip_path, path, True)
+
+
+def get_livecell_paths(
     path: Union[os.PathLike, str],
     split: str,
-    download: bool,
+    download: bool = False,
     cell_types: Optional[Sequence[str]] = None,
     label_path: Optional[Union[os.PathLike, str]] = None
 ) -> Tuple[List[str], List[str]]:
-    """Download the LIVECell dataset.
+    """Get paths to the LIVECell data.
 
     Args:
         path: Filepath to a folder where the downloaded data will be saved.
@@ -173,10 +180,10 @@ def get_livecell_data(
         label_path: Optional path for loading the label data.
 
     Returns:
-        The paths to the image data.
-        The paths to the label data.
+        List of filepaths for the image data.
+        List of filepaths for the label data.
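+
+    Example:
+        A minimal sketch; note that creating the label data requires pycocotools:
+
+        >>> image_paths, seg_paths = get_livecell_paths("./data/livecell", split="train", download=True)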
""" - _download_livecell_images(path, download) + get_livecell_data(path, download) image_paths, seg_paths = _download_livecell_annotations(path, split, download, cell_types, label_path) return image_paths, seg_paths @@ -217,18 +224,20 @@ def get_livecell_dataset( assert isinstance(cell_types, (list, tuple)), \ f"cell_types must be passed as a list or tuple instead of {cell_types}" - image_paths, seg_paths = get_livecell_data(path, split, download, cell_types, label_path) + image_paths, seg_paths = get_livecell_paths(path, split, download, cell_types, label_path) kwargs = util.ensure_transforms(ndim=2, **kwargs) kwargs, label_dtype = util.add_instance_label_transform( - kwargs, add_binary_target=True, label_dtype=label_dtype, - offsets=offsets, boundaries=boundaries, binary=binary + kwargs, add_binary_target=True, label_dtype=label_dtype, offsets=offsets, boundaries=boundaries, binary=binary ) - dataset = torch_em.data.ImageCollectionDataset( - image_paths, seg_paths, patch_shape=patch_shape, label_dtype=label_dtype, **kwargs + return ImageCollectionDataset( + raw_image_paths=image_paths, + label_image_paths=seg_paths, + patch_shape=patch_shape, + label_dtype=label_dtype, + **kwargs ) - return dataset def get_livecell_loader( @@ -269,5 +278,4 @@ def get_livecell_loader( path, split, patch_shape, download=download, offsets=offsets, boundaries=boundaries, binary=binary, cell_types=cell_types, label_path=label_path, label_dtype=label_dtype, **ds_kwargs ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/mouse_embryo.py b/torch_em/data/datasets/light_microscopy/mouse_embryo.py index c2503808..1259c1ad 100644 --- a/torch_em/data/datasets/light_microscopy/mouse_embryo.py +++ b/torch_em/data/datasets/light_microscopy/mouse_embryo.py @@ -15,6 +15,7 @@ from .. import util + URL = "https://zenodo.org/record/6546550/files/MouseEmbryos.zip?download=1" CHECKSUM = "bf24df25e5f919489ce9e674876ff27e06af84445c48cf2900f1ab590a042622" @@ -40,6 +41,29 @@ def get_mouse_embryo_data(path: Union[os.PathLike, str], download: bool) -> str: return path +def get_mouse_embryo_paths(path: Union[os.PathLike, str], name: str, split: str, download: bool = False) -> List[str]: + """Get paths to the Mouse Embryo data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + name: The name of the segmentation task. Either 'membrane' or 'nuclei'. + split: The split to use for the dataset. Either 'train' or 'val'. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the stored data. 
+ """ + get_mouse_embryo_data(path, download) + + # the naming of the data is inconsistent: membrane has val, nuclei has test; + # we treat nuclei:test as val + split_ = "test" if name == "nuclei" and split == "val" else split + file_paths = glob(os.path.join(path, name.capitalize(), split_, "*.h5")) + file_paths.sort() + + return file_paths + + def get_mouse_embryo_dataset( path: Union[os.PathLike, str], name: str, @@ -70,21 +94,26 @@ def get_mouse_embryo_dataset( assert name in ("membrane", "nuclei") assert split in ("train", "val") assert len(patch_shape) == 3 - get_mouse_embryo_data(path, download) - # the naming of the data is inconsistent: membrane has val, nuclei has test; - # we treat nuclei:test as val - split_ = "test" if name == "nuclei" and split == "val" else split - file_paths = glob(os.path.join(path, name.capitalize(), split_, "*.h5")) - file_paths.sort() + file_paths = get_mouse_embryo_paths(path, name, split, download) kwargs, _ = util.add_instance_label_transform( - kwargs, add_binary_target=binary, binary=binary, boundaries=boundaries, - offsets=offsets, binary_is_exclusive=False + kwargs, + add_binary_target=binary, + binary=binary, + boundaries=boundaries, + offsets=offsets, + binary_is_exclusive=False ) - raw_key, label_key = "raw", "label" - return torch_em.default_segmentation_dataset(file_paths, raw_key, file_paths, label_key, patch_shape, **kwargs) + return torch_em.default_segmentation_dataset( + raw_paths=file_paths, + raw_key="raw", + label_paths=file_paths, + label_key="label", + patch_shape=patch_shape, + **kwargs + ) def get_mouse_embryo_loader( @@ -116,13 +145,9 @@ def get_mouse_embryo_loader( Returns: The DataLoader. """ - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_mouse_embryo_dataset( - path, name, split, patch_shape, - download=download, offsets=offsets, boundaries=boundaries, binary=binary, - **ds_kwargs + path, name, split, patch_shape, download=download, offsets=offsets, + boundaries=boundaries, binary=binary, **ds_kwargs ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/neurips_cell_seg.py b/torch_em/data/datasets/light_microscopy/neurips_cell_seg.py index b75224df..89be5e1b 100644 --- a/torch_em/data/datasets/light_microscopy/neurips_cell_seg.py +++ b/torch_em/data/datasets/light_microscopy/neurips_cell_seg.py @@ -10,7 +10,7 @@ import os from glob import glob -from typing import Union, Tuple, Any, Optional +from typing import Union, Tuple, Any, Optional, List import numpy as np @@ -87,7 +87,21 @@ def get_neurips_cellseg_data(root: Union[os.PathLike, str], split: str, download return target_dir -def _get_image_and_label_paths(root, split, download): +def get_neurips_cellseg_paths( + root: Union[os.PathLike, str], split: str, download: bool = False +) -> Tuple[List[str], List[str]]: + f"""Get paths to NeurIPS CellSeg Challenge data. + + Args: + root: Filepath to a folder where the downloaded data will be saved. + split: The data split to download. Available splits are: + {', '.join(URL.keys())} + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the image data. + List of filepaths for the label data. 
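+
+    Example:
+        A minimal sketch; the root folder is illustrative:
+
+        >>> image_paths, label_paths = get_neurips_cellseg_paths("./data/neurips_cellseg", split="train", download=True)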
+ """ path = get_neurips_cellseg_data(root, split, download) image_folder = os.path.join(path, "images") @@ -138,15 +152,16 @@ def get_neurips_cellseg_supervised_dataset( The segmentation dataset. """ assert split in ("train", "val", "test"), split - image_paths, label_paths = _get_image_and_label_paths(root, split, download) + image_paths, label_paths = get_neurips_cellseg_paths(root, split, download) if raw_transform is None: trafo = to_rgb if make_rgb else None raw_transform = torch_em.transform.get_raw_transform(augmentation2=trafo) + if transform is None: transform = torch_em.transform.get_augmentations(ndim=2) - ds = ImageCollectionDataset( + return ImageCollectionDataset( raw_image_paths=image_paths, label_image_paths=label_paths, patch_shape=patch_shape, @@ -158,7 +173,6 @@ def get_neurips_cellseg_supervised_dataset( n_samples=n_samples, sampler=sampler ) - return ds def get_neurips_cellseg_supervised_loader( diff --git a/torch_em/data/datasets/light_microscopy/omnipose.py b/torch_em/data/datasets/light_microscopy/omnipose.py index d5082b8f..ecb49ac9 100644 --- a/torch_em/data/datasets/light_microscopy/omnipose.py +++ b/torch_em/data/datasets/light_microscopy/omnipose.py @@ -28,10 +28,7 @@ DATA_CHOICES = ["bact_fluor", "bact_phase", "worm", "worm_high_res"] -def get_omnipose_data( - path: Union[os.PathLike, str], - download: bool = False, -): +def get_omnipose_data(path: Union[os.PathLike, str], download: bool = False) -> str: """Download the OmniPose dataset. Args: @@ -39,7 +36,7 @@ def get_omnipose_data( download: Whether to download the data if it is not present. Return: - The filepath to the data. + The filepath where the data is downloaded. """ os.makedirs(path, exist_ok=True) @@ -54,7 +51,25 @@ def get_omnipose_data( return data_dir -def _get_omnipose_paths(path, split, data_choice, download): +def get_omnipose_paths( + path: Union[os.PathLike, str], + split: str, + data_choice: Optional[Union[str, List[str]]] = None, + download: bool = False +) -> Tuple[List[str], List[str]]: + """Get paths to the OmniPose data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The data split to use. Either 'train' or 'test'. + data_choice: The choice of specific data. + Either 'bact_fluor', 'bact_phase', 'worm' or 'worm_high_res'. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the image data. + List of filepaths for the label data. + """ data_dir = get_omnipose_data(path=path, download=download) if split not in ["train", "test"]: @@ -114,9 +129,9 @@ def get_omnipose_dataset( Returns: The segmentation dataset. 
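+
+    Example:
+        A minimal sketch; folder and data choice are illustrative:
+
+        >>> ds = get_omnipose_dataset("./data/omnipose", split="train", patch_shape=(512, 512), data_choice="bact_phase", download=True)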
""" - image_paths, gt_paths = _get_omnipose_paths(path, split, data_choice, download) - print(len(image_paths), len(gt_paths)) - dataset = torch_em.default_segmentation_dataset( + image_paths, gt_paths = get_omnipose_paths(path, split, data_choice, download) + + return torch_em.default_segmentation_dataset( raw_paths=image_paths, raw_key=None, label_paths=gt_paths, @@ -125,7 +140,6 @@ def get_omnipose_dataset( patch_shape=patch_shape, **kwargs ) - return dataset def get_omnipose_loader( @@ -156,5 +170,4 @@ def get_omnipose_loader( dataset = get_omnipose_dataset( path=path, patch_shape=patch_shape, split=split, data_choice=data_choice, download=download, **ds_kwargs ) - loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) - return loader + return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/organoidnet.py b/torch_em/data/datasets/light_microscopy/organoidnet.py index 3656f727..a48191fc 100644 --- a/torch_em/data/datasets/light_microscopy/organoidnet.py +++ b/torch_em/data/datasets/light_microscopy/organoidnet.py @@ -1,8 +1,17 @@ +"""The OrganoIDNet dataset contains annotations of panceratic organoids. + +This dataset is from the publication https://doi.org/10.1007/s13402-024-00958-2. +Please cite it if you use this dataset for a publication. +""" + + import os import shutil import zipfile from glob import glob -from typing import Tuple, Union +from typing import Tuple, Union, List + +from torch.utils.data import Dataset, DataLoader import torch_em @@ -13,7 +22,17 @@ CHECKSUM = "3cd9239bf74bda096ecb5b7bdb95f800c7fa30b9937f9aba6ddf98d754cbfa3d" -def get_organoidnet_data(path, split, download): +def get_organoidnet_data(path: Union[os.PathLike, str], split: str, download: bool = False) -> str: + """Download the OrganoIDNet dataset. + + Args: + path: Filepath to the folder where the downloaded data will be saved. + split: The data split to use. + download: Whether to download the data if it is not present. + + Returns: + The filepath where the data is downloaded. + """ splits = ["Training", "Validation", "Test"] assert split in splits @@ -51,7 +70,20 @@ def get_organoidnet_data(path, split, download): return data_dir -def _get_data_paths(path, split, download): +def get_organoidnet_paths( + path: Union[os.PathLike, str], split: str, download: bool = False +) -> Tuple[List[str], List[str]]: + """Get paths to the OrganoIDNet data. + + Args: + path: Filepath to the folder where the downloaded data will be saved. + split: The data split to use. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the image data. + List of filepaths for the label data. + """ data_dir = get_organoidnet_data(path=path, split=split, download=download) image_paths = sorted(glob(os.path.join(data_dir, "Images", "*.tif"))) @@ -61,18 +93,21 @@ def _get_data_paths(path, split, download): def get_organoidnet_dataset( - path: Union[os.PathLike, str], - split: str, - patch_shape: Tuple[int, int], - download: bool = False, - **kwargs -): - """Dataset for the segmentation of panceratic organoids. - - This dataset is from the publication https://doi.org/10.1007/s13402-024-00958-2. - Please cite it if you use this dataset for a publication. + path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], download: bool = False, **kwargs +) -> Dataset: + """Get the OrganoIDNet dataset for organoid segmentation in microscopy images. 
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to use.
+        patch_shape: The patch shape to use for training.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
     """
-    image_paths, label_paths = _get_data_paths(path=path, split=split, download=download)
+    image_paths, label_paths = get_organoidnet_paths(path, split, download)
 
     return torch_em.default_segmentation_dataset(
         raw_paths=image_paths,
@@ -92,17 +127,22 @@ def get_organoidnet_loader(
     batch_size: int,
     download: bool = False,
     **kwargs
-):
-    """Dataloader for the segmentation of pancreatic organoids in brightfield images.
-    See `get_organoidnet_dataset` for details.
+) -> DataLoader:
+    """Get the OrganoIDNet dataloader for organoid segmentation in microscopy images.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to use.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
     """
     ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
     dataset = get_organoidnet_dataset(
-        path=path,
-        split=split,
-        patch_shape=patch_shape,
-        download=download,
-        **ds_kwargs
+        path=path, split=split, patch_shape=patch_shape, download=download, **ds_kwargs
     )
-    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
-    return loader
+    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/light_microscopy/orgasegment.py b/torch_em/data/datasets/light_microscopy/orgasegment.py
index 1f6b8edd..f6b3f2d7 100644
--- a/torch_em/data/datasets/light_microscopy/orgasegment.py
+++ b/torch_em/data/datasets/light_microscopy/orgasegment.py
@@ -8,7 +8,7 @@
 import os
 import shutil
 from glob import glob
-from typing import Tuple, Union, Literal
+from typing import Tuple, Union, Literal, List
 
 from torch.utils.data import Dataset, DataLoader
 
@@ -54,7 +54,22 @@ def get_orgasegment_data(
     return data_dir
 
 
-def _get_data_paths(path, split, download):
+def get_orgasegment_paths(
+    path: Union[os.PathLike, str],
+    split: Literal["train", "val", "eval"],
+    download: bool = False
+) -> Tuple[List[str], List[str]]:
+    """Get paths to the OrgaSegment data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The split to download. Either 'train', 'val' or 'eval'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths to the image data.
+        List of filepaths to the label data.
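+
+    Example:
+        A minimal usage sketch; the filepath below is an illustrative placeholder:
+
+        >>> image_paths, label_paths = get_orgasegment_paths(path="./orgasegment", split="train")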
+    """
     data_dir = get_orgasegment_data(path=path, split=split, download=download)
 
     image_paths = sorted(glob(os.path.join(data_dir, "*_img.jpg")))
@@ -88,11 +103,10 @@ def get_orgasegment_dataset(
     """
     assert split in ["train", "val", "eval"]
 
-    image_paths, label_paths = _get_data_paths(path=path, split=split, download=download)
+    image_paths, label_paths = get_orgasegment_paths(path=path, split=split, download=download)
+
+    kwargs, _ = util.add_instance_label_transform(kwargs, add_binary_target=True, binary=binary, boundaries=boundaries)
 
-    kwargs, _ = util.add_instance_label_transform(
-        kwargs, add_binary_target=True, binary=binary, boundaries=boundaries,
-    )
     return torch_em.default_segmentation_dataset(
         raw_paths=image_paths,
         raw_key=None,
@@ -138,5 +152,4 @@ def get_orgasegment_loader(
         download=download,
         **ds_kwargs
     )
-    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
-    return loader
+    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/light_microscopy/plantseg.py b/torch_em/data/datasets/light_microscopy/plantseg.py
index 4adf84c0..2db4f603 100644
--- a/torch_em/data/datasets/light_microscopy/plantseg.py
+++ b/torch_em/data/datasets/light_microscopy/plantseg.py
@@ -113,14 +113,14 @@ def _fix_inconsistent_volumes(data_path, name, split):
             labels[...] = resized_labels
 
 
-def get_plantseg_data(path: Union[os.PathLike, str], download: bool, name: str, split: str) -> str:
+def get_plantseg_data(path: Union[os.PathLike, str], name: str, split: str, download: bool = False) -> str:
     """Download the PlantSeg training data.
 
     Args:
         path: Filepath to a folder where the downloaded data will be saved.
-        download: Whether to download the data if it is not present.
         name: The name of the data to load. Either 'root', 'nuclei' or 'ovules'.
         split: The split to download. Either 'train', 'val' or 'test'.
+        download: Whether to download the data if it is not present.
 
     Returns:
         The filepath to the training data.
@@ -138,6 +138,28 @@ def get_plantseg_data(path: Union[os.PathLike, str], download: bool, name: str,
     return out_path
 
 
+def get_plantseg_paths(
+    path: Union[os.PathLike, str],
+    name: str,
+    split: str,
+    download: bool = False
+) -> List[str]:
+    """Get paths to the PlantSeg data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        name: The name of the data to load. Either 'root', 'nuclei' or 'ovules'.
+        split: The split to download. Either 'train', 'val' or 'test'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the data.
+    """
+    data_path = get_plantseg_data(path, name, split, download)
+    file_paths = sorted(glob(os.path.join(data_path, "*.h5")))
+    return file_paths
+
+
 def get_plantseg_dataset(
     path: Union[os.PathLike, str],
     name: str,
     split: str,
@@ -166,18 +188,22 @@
         The segmentation dataset.
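+
+    Example:
+        A minimal usage sketch; the filepath and patch shape below are illustrative placeholders:
+
+        >>> dataset = get_plantseg_dataset(path="./plantseg", name="root", split="train", patch_shape=(32, 256, 256))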
""" assert len(patch_shape) == 3 - data_path = get_plantseg_data(path, download, name, split) - file_paths = glob(os.path.join(data_path, "*.h5")) - file_paths.sort() + file_paths = get_plantseg_paths(path, name, split, download) kwargs, _ = util.add_instance_label_transform( kwargs, add_binary_target=binary, binary=binary, boundaries=boundaries, offsets=offsets, binary_is_exclusive=False ) - raw_key, label_key = "raw", "label" - return torch_em.default_segmentation_dataset(file_paths, raw_key, file_paths, label_key, patch_shape, **kwargs) + return torch_em.default_segmentation_dataset( + raw_paths=file_paths, + raw_key="raw", + label_paths=file_paths, + label_key="label", + patch_shape=patch_shape, + **kwargs + ) # TODO add support for ignore label, key: "/label_with_ignore" @@ -210,13 +236,9 @@ def get_plantseg_loader( Returns: The DataLoader. """ - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_plantseg_dataset( - path, name, split, patch_shape, - download=download, offsets=offsets, boundaries=boundaries, binary=binary, - **ds_kwargs + path, name, split, patch_shape, download=download, offsets=offsets, + boundaries=boundaries, binary=binary, **ds_kwargs ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/tissuenet.py b/torch_em/data/datasets/light_microscopy/tissuenet.py index f8f9af5a..52e88c88 100644 --- a/torch_em/data/datasets/light_microscopy/tissuenet.py +++ b/torch_em/data/datasets/light_microscopy/tissuenet.py @@ -10,7 +10,7 @@ import os from glob import glob from tqdm import tqdm -from typing import Tuple, Union +from typing import Tuple, Union, List import numpy as np import pandas as pd @@ -60,11 +60,7 @@ def _create_dataset(path, zip_path): _create_split(path, split) -def get_tissuenet_data( - path: Union[os.PathLike, str], - split: str, - download: bool = False, -): +def get_tissuenet_data(path: Union[os.PathLike, str], split: str, download: bool = False) -> str: """Download the TissueNet dataset. NOTE: Automatic download is not supported for TissueNet datset. @@ -97,6 +93,25 @@ def get_tissuenet_data( return split_folder +def get_tissuenet_paths(path: Union[os.PathLike, str], split: str, download: bool = False) -> List[str]: + """Get paths to the TissueNet data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The data split to use. Either 'train', 'val' or 'test'. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the data. 
+    """
+    split_folder = get_tissuenet_data(path, split, download)
+    assert os.path.exists(split_folder)
+    data_paths = glob(os.path.join(split_folder, "*.zarr"))
+    assert len(data_paths) > 0
+
+    return data_paths
+
+
 def get_tissuenet_dataset(
     path: Union[os.PathLike, str],
     split: str,
@@ -123,19 +138,21 @@
     assert raw_channel in ("nucleus", "cell", "rgb")
     assert label_channel in ("nucleus", "cell")
 
-    split_folder = get_tissuenet_data(path, split, download)
-    assert os.path.exists(split_folder)
-    data_path = glob(os.path.join(split_folder, "*.zarr"))
-    assert len(data_path) > 0
-
-    raw_key, label_key = f"raw/{raw_channel}", f"labels/{label_channel}"
+    data_paths = get_tissuenet_paths(path, split, download)
 
     with_channels = True if raw_channel == "rgb" else False
     kwargs = util.update_kwargs(kwargs, "with_channels", with_channels)
     kwargs = util.update_kwargs(kwargs, "is_seg_dataset", True)
     kwargs = util.update_kwargs(kwargs, "ndim", 2)
 
-    return torch_em.default_segmentation_dataset(data_path, raw_key, data_path, label_key, patch_shape, **kwargs)
+    return torch_em.default_segmentation_dataset(
+        raw_paths=data_paths,
+        raw_key=f"raw/{raw_channel}",
+        label_paths=data_paths,
+        label_key=f"labels/{label_channel}",
+        patch_shape=patch_shape,
+        **kwargs
+    )
 
 
 # TODO enable loading specific tissue types etc. (from the 'meta' attributes)
@@ -168,5 +185,4 @@ def get_tissuenet_loader(
     dataset = get_tissuenet_dataset(
         path, split, patch_shape, raw_channel, label_channel, download, **ds_kwargs
     )
-    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
-    return loader
+    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/light_microscopy/vgg_hela.py b/torch_em/data/datasets/light_microscopy/vgg_hela.py
index 332d61e8..75500241 100644
--- a/torch_em/data/datasets/light_microscopy/vgg_hela.py
+++ b/torch_em/data/datasets/light_microscopy/vgg_hela.py
@@ -82,12 +82,28 @@ def get_vgg_hela_data(path: Union[os.PathLike, str], download: bool) -> str:
     return path
 
 
+def get_vgg_hela_paths(path: Union[os.PathLike, str], split: str, download: bool = False) -> Tuple[str, str]:
+    """Get paths to the HeLA VGG data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The split to use for the dataset. Either 'train' or 'test'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        Filepath to the folder where image data is stored.
+        Filepath to the folder where label data is stored.
+    """
+    get_vgg_hela_data(path, download)
+
+    image_path = os.path.join(path, split, "images")
+    label_path = os.path.join(path, split, "labels")
+
+    return image_path, label_path
+
+
 def get_vgg_hela_dataset(
-    path: Union[os.PathLike, str],
-    split: str,
-    patch_shape: Tuple[int, int],
-    download: bool = False,
-    **kwargs
+    path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], download: bool = False, **kwargs
 ) -> Dataset:
     """Get the HeLA VGG dataset for cell counting.
 
@@ -102,15 +118,19 @@
         The segmentation dataset.
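+
+    Example:
+        A minimal usage sketch; the filepath and patch shape below are illustrative placeholders:
+
+        >>> dataset = get_vgg_hela_dataset(path="./hela_vgg", split="train", patch_shape=(512, 512))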
""" assert split in ("test", "train"), split - get_vgg_hela_data(path, download) - image_path = os.path.join(path, split, "images") - label_path = os.path.join(path, split, "labels") + image_path, label_path = get_vgg_hela_paths(path, split, download) kwargs = util.update_kwargs(kwargs, "ndim", 2) kwargs = util.update_kwargs(kwargs, "is_seg_dataset", True) + return torch_em.default_segmentation_dataset( - image_path, "*.tif", label_path, "*.tif", patch_shape, **kwargs + raw_paths=image_path, + raw_key="*.tif", + label_paths=label_path, + label_key="*.tif", + patch_shape=patch_shape, + **kwargs ) @@ -135,11 +155,6 @@ def get_vgg_hela_loader( Returns: The DataLoader. """ - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - dataset = get_vgg_hela_dataset( - path, split, patch_shape, download=download, **ds_kwargs, - ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_vgg_hela_dataset(path, split, patch_shape, download=download, **ds_kwargs) + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)