Skip to content

Commit

Permalink
Add function in datasets to get input paths (#367)
Browse files Browse the repository at this point in the history
* Add function in datasets to get input paths

* Add _paths functionality to all microscopy datasets

* Make download optional for livecell functions
  • Loading branch information
anwai98 authored Oct 9, 2024
1 parent dd3519a commit d2b4b62
Show file tree
Hide file tree
Showing 35 changed files with 1,347 additions and 644 deletions.
47 changes: 28 additions & 19 deletions torch_em/data/datasets/electron_microscopy/asem.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""ASEM is a dataset for segmentation of cellular structures in FIB-SEM.
The dataset was published in https://doi.org/10.1083/jcb.202208005.
Please cite this publication if you use the dataset in your research.
"""
Expand All @@ -8,6 +9,8 @@

import numpy as np

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util
Expand Down Expand Up @@ -54,11 +57,7 @@
}


def get_asem_data(
path: Union[os.PathLike, str],
volume_ids: List[str],
download: bool = False
):
def get_asem_data(path: Union[os.PathLike, str], volume_ids: List[str], download: bool = False):
"""Download the ASEM dataset.
The dataset is located at https://open.quiltdata.com/b/asem-project.
Expand All @@ -67,19 +66,14 @@ def get_asem_data(
path: Filepath to a folder where the downloaded data will be saved.
volume_ids: List of volumes to download.
download: Whether to download the data if it is not present.
Returns:
List of paths for all volume ids.
"""
if download and not have_quilt:
raise ModuleNotFoundError("Please install quilt3: 'pip install quilt3'.")

b = q3.Bucket("s3://asem-project")

volume_paths = []
for volume_id in volume_ids:
volume_path = os.path.join(path, VOLUMES[volume_id])
volume_paths.append(volume_path)
if os.path.exists(volume_path):
continue

Expand All @@ -100,6 +94,20 @@ def get_asem_data(
b.fetch(key=f"datasets/{VOLUMES[volume_id]}/.zgroup", path=f"{volume_path}/")
b.fetch(key=f"datasets/{VOLUMES[volume_id]}/volumes/.zgroup", path=f"{volume_path}/volumes/")


def get_asem_paths(path: Union[os.PathLike, str], volume_ids: List[str], download: bool = False) -> List[str]:
    """Return the local locations of the requested ASEM volumes.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        volume_ids: List of volumes to download.
        download: Whether to download the data if it is not present.

    Returns:
        List of paths for all volume ids, in the same order as `volume_ids`.
    """
    # Delegate to the download helper first; it skips volumes whose folder already exists.
    get_asem_data(path, volume_ids, download)

    # Each volume id maps to its on-disk name through the module-level VOLUMES table.
    resolved = []
    for name in volume_ids:
        resolved.append(os.path.join(path, VOLUMES[name]))
    return resolved


Expand Down Expand Up @@ -170,14 +178,14 @@ def get_asem_dataset(
organelles: Optional[Union[List[str], str]] = None,
volume_ids: Optional[Union[List[str], str]] = None,
**kwargs
):
) -> Dataset:
"""Get dataset for segmentation of organelles in FIB-SEM cells.
Args:
path: Filepath to a folder where the downloaded data will be saved.
patch_shape: The patch shape to use for training.
download: Whether to download the data if it is not present.
orgnalles: The choice of organelles.
organelles: The choice of organelles.
volume_ids: The choice of volumes.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Expand All @@ -200,16 +208,17 @@ def get_asem_dataset(
assert volume_id in ORGANELLES[organelle], \
f"The chosen volume and organelle combination does not match: '{volume_id}' & '{organelle}'"

volume_paths = get_asem_data(path, volume_ids, download)
volume_paths = get_asem_paths(path, volume_ids, download)

for volume_path in volume_paths:
have_volumes_inconsistent = _make_volumes_consistent(volume_path, organelle)

raw_key = f"volumes/raw_{organelle}" if have_volumes_inconsistent else "volumes/raw"
dataset = torch_em.default_segmentation_dataset(
volume_path, raw_key,
volume_path, f"volumes/labels/{organelle}",
patch_shape,
raw_paths=volume_path,
raw_key=f"volumes/raw_{organelle}" if have_volumes_inconsistent else "volumes/raw",
label_paths=volume_path,
label_key=f"volumes/labels/{organelle}",
patch_shape=patch_shape,
is_seg_dataset=True,
**kwargs
)
Expand All @@ -227,15 +236,15 @@ def get_asem_loader(
organelles: Optional[Union[List[str], str]] = None,
volume_ids: Optional[Union[List[str], str]] = None,
**kwargs
):
) -> DataLoader:
"""Get dataloader for the segmentation of organelles in FIB-SEM cells.
Args:
path: Filepath to a folder where the downloaded data will be saved.
patch_shape: The patch shape to use for training.
batch_size: The batch size for training.
download: Whether to download the data if it is not present.
orgnalles: The choice of organelles.
organelles: The choice of organelles.
volume_ids: The choice of volumes.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Expand Down
69 changes: 50 additions & 19 deletions torch_em/data/datasets/electron_microscopy/axondeepseg.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""AxonDeepSeg is a dataset for the segmentation of myelinated axons in EM.
It contains two different data types: TEM and SEM.
The dataset was published in https://doi.org/10.1038/s41598-018-22181-4.
Please cite this publication if you use the dataset in your research.
"""

import os
from glob import glob
from shutil import rmtree
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, Literal, List

import imageio
import numpy as np
Expand Down Expand Up @@ -116,8 +117,8 @@ def _preprocess_tem_data(out_path):
rmtree(data_root)


def get_axondeepseg_data(path: Union[str, os.PathLike], name: str, download: bool) -> str:
"""Download the axondeepseg data.
def get_axondeepseg_data(path: Union[str, os.PathLike], name: Literal["sem", "tem"], download: bool = False) -> str:
"""Download the AxonDeepSeg data.
Args:
path: Filepath to a folder where the downloaded data will be saved.
Expand Down Expand Up @@ -149,14 +150,47 @@ def get_axondeepseg_data(path: Union[str, os.PathLike], name: str, download: boo
return out_path


def get_axondeepseg_paths(
    path: Union[str, os.PathLike],
    name: Union[Literal["sem", "tem"], List[Literal["sem", "tem"]]],
    download: bool = False,
    val_fraction: Optional[float] = None,
    split: Optional[str] = None,
) -> List[str]:
    """Get paths to the AxonDeepSeg data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        name: The name of the dataset to download. Can be either 'sem' or 'tem',
            or a list / tuple containing these names.
        download: Whether to download the data if it is not present.
        val_fraction: The fraction of the data to use for validation.
        split: The data split. Either 'train' or 'val'.
            Must be given whenever `val_fraction` is given.

    Returns:
        List of paths for all the data.
    """
    # A bare string has to be wrapped before looping: iterating over "sem" directly
    # would yield the characters 's', 'e', 'm' instead of the dataset name.
    names = [name] if isinstance(name, str) else name

    all_paths = []
    for nn in names:
        data_root = get_axondeepseg_data(path, nn, download)
        # Sort so that the train / val partition below is reproducible across runs.
        paths = sorted(glob(os.path.join(data_root, "*.h5")))
        if val_fraction is not None:
            assert split is not None, "'split' must be specified when 'val_fraction' is given."
            # The first (1 - val_fraction) share of the files is 'train', the remainder is 'val'.
            n_samples = int(len(paths) * (1 - val_fraction))
            paths = paths[:n_samples] if split == "train" else paths[n_samples:]
        all_paths.extend(paths)

    return all_paths


def get_axondeepseg_dataset(
path: Union[str, os.PathLike],
name: str,
name: Literal["sem", "tem"],
patch_shape: Tuple[int, int],
download: bool = False,
one_hot_encoding: bool = False,
val_fraction: Optional[float] = None,
split: Optional[str] = None,
split: Optional[Literal['train', 'val']] = None,
**kwargs,
) -> Dataset:
"""Get dataset for segmentation of myelinated axons.
Expand All @@ -178,16 +212,7 @@ def get_axondeepseg_dataset(
name = [name]
assert isinstance(name, (tuple, list))

all_paths = []
for nn in name:
data_root = get_axondeepseg_data(path, nn, download)
paths = glob(os.path.join(data_root, "*.h5"))
paths.sort()
if val_fraction is not None:
assert split is not None
n_samples = int(len(paths) * (1 - val_fraction))
paths = paths[:n_samples] if split == "train" else paths[n_samples:]
all_paths.extend(paths)
all_paths = get_axondeepseg_paths(path, name, download, val_fraction, split)

if one_hot_encoding:
if isinstance(one_hot_encoding, bool):
Expand All @@ -205,19 +230,25 @@ def get_axondeepseg_dataset(
msg = "'one_hot' is set to True, but 'label_transform' is in the kwargs. It will be over-ridden."
kwargs = util.update_kwargs(kwargs, "label_transform", label_transform, msg=msg)

raw_key, label_key = "raw", "labels"
return torch_em.default_segmentation_dataset(all_paths, raw_key, all_paths, label_key, patch_shape, **kwargs)
return torch_em.default_segmentation_dataset(
raw_paths=all_paths,
raw_key="raw",
label_paths=all_paths,
label_key="labels",
patch_shape=patch_shape,
**kwargs
)


def get_axondeepseg_loader(
path: Union[str, os.PathLike],
name: str,
name: Literal["sem", "tem"],
patch_shape: Tuple[int, int],
batch_size: int,
download: bool = False,
one_hot_encoding: bool = False,
val_fraction: Optional[float] = None,
split: Optional[str] = None,
split: Optional[Literal["train", "val"]] = None,
**kwargs
) -> DataLoader:
"""Get dataloader for the segmentation of myelinated axons.
Expand Down
Loading

0 comments on commit d2b4b62

Please sign in to comment.