Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function in datasets to get input paths #367

Merged
merged 3 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 28 additions & 19 deletions torch_em/data/datasets/electron_microscopy/asem.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""ASEM is a dataset for segmentation of cellular structures in FIB-SEM.

The dataset was publised in https://doi.org/10.1083/jcb.202208005.
Please cite this publication if you use the dataset in your research.
"""
Expand All @@ -8,6 +9,8 @@

import numpy as np

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util
Expand Down Expand Up @@ -54,11 +57,7 @@
}


def get_asem_data(
path: Union[os.PathLike, str],
volume_ids: List[str],
download: bool = False
):
def get_asem_data(path: Union[os.PathLike, str], volume_ids: List[str], download: bool = False):
"""Download the ASEM dataset.

The dataset is located at https://open.quiltdata.com/b/asem-project.
Expand All @@ -67,19 +66,14 @@ def get_asem_data(
path: Filepath to a folder where the downloaded data will be saved.
volume_ids: List of volumes to download.
download: Whether to download the data if it is not present.

Returns:
List of paths for all volume ids.
"""
if download and not have_quilt:
raise ModuleNotFoundError("Please install quilt3: 'pip install quilt3'.")

b = q3.Bucket("s3://asem-project")

volume_paths = []
for volume_id in volume_ids:
volume_path = os.path.join(path, VOLUMES[volume_id])
volume_paths.append(volume_path)
if os.path.exists(volume_path):
continue

Expand All @@ -100,6 +94,20 @@ def get_asem_data(
b.fetch(key=f"datasets/{VOLUMES[volume_id]}/.zgroup", path=f"{volume_path}/")
b.fetch(key=f"datasets/{VOLUMES[volume_id]}/volumes/.zgroup", path=f"{volume_path}/volumes/")


def get_asem_paths(path: Union[os.PathLike, str], volume_ids: List[str], download: bool = False) -> List[str]:
"""Get paths to the ASEM data.

Args:
path: Filepath to a folder where the downloaded data will be saved.
volume_ids: List of volumes to download.
download: Whether to download the data if it is not present.

Returns:
List of paths for all volume ids.
"""
get_asem_data(path, volume_ids, download)
volume_paths = [os.path.join(path, VOLUMES[vol_id]) for vol_id in volume_ids]
return volume_paths


Expand Down Expand Up @@ -170,14 +178,14 @@ def get_asem_dataset(
organelles: Optional[Union[List[str], str]] = None,
volume_ids: Optional[Union[List[str], str]] = None,
**kwargs
):
) -> Dataset:
"""Get dataset for segmentation of organelles in FIB-SEM cells.

Args:
path: Filepath to a folder where the downloaded data will be saved.
patch_shape: The patch shape to use for training.
download: Whether to download the data if it is not present.
orgnalles: The choice of organelles.
organelles: The choice of organelles.
volume_ids: The choice of volumes.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

Expand All @@ -200,16 +208,17 @@ def get_asem_dataset(
assert volume_id in ORGANELLES[organelle], \
f"The chosen volume and organelle combination does not match: '{volume_id}' & '{organelle}'"

volume_paths = get_asem_data(path, volume_ids, download)
volume_paths = get_asem_paths(path, volume_ids, download)

for volume_path in volume_paths:
have_volumes_inconsistent = _make_volumes_consistent(volume_path, organelle)

raw_key = f"volumes/raw_{organelle}" if have_volumes_inconsistent else "volumes/raw"
dataset = torch_em.default_segmentation_dataset(
volume_path, raw_key,
volume_path, f"volumes/labels/{organelle}",
patch_shape,
raw_paths=volume_path,
raw_key=f"volumes/raw_{organelle}" if have_volumes_inconsistent else "volumes/raw",
label_paths=volume_path,
label_key=f"volumes/labels/{organelle}",
patch_shape=patch_shape,
is_seg_dataset=True,
**kwargs
)
Expand All @@ -227,15 +236,15 @@ def get_asem_loader(
organelles: Optional[Union[List[str], str]] = None,
volume_ids: Optional[Union[List[str], str]] = None,
**kwargs
):
) -> DataLoader:
"""Get dataloader for the segmentation of organelles in FIB-SEM cells.

Args:
path: Filepath to a folder where the downloaded data will be saved.
patch_shape: The patch shape to use for training.
batch_size: The batch size for training.
download: Whether to download the data if it is not present.
orgnalles: The choice of organelles.
organelles: The choice of organelles.
volume_ids: The choice of volumes.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

Expand Down
69 changes: 50 additions & 19 deletions torch_em/data/datasets/electron_microscopy/axondeepseg.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""AxonDeepSeg is a dataset for the segmentation of myelinated axons in EM.
It contains two different data types: TEM and SEM.

The dataset was published in https://doi.org/10.1038/s41598-018-22181-4.
Please cite this publication if you use the dataset in your research.
"""

import os
from glob import glob
from shutil import rmtree
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, Literal, List

import imageio
import numpy as np
Expand Down Expand Up @@ -116,8 +117,8 @@ def _preprocess_tem_data(out_path):
rmtree(data_root)


def get_axondeepseg_data(path: Union[str, os.PathLike], name: str, download: bool) -> str:
"""Download the axondeepseg data.
def get_axondeepseg_data(path: Union[str, os.PathLike], name: Literal["sem", "tem"], download: bool = False) -> str:
"""Download the AxonDeepSeg data.

Args:
path: Filepath to a folder where the downloaded data will be saved.
Expand Down Expand Up @@ -149,14 +150,47 @@ def get_axondeepseg_data(path: Union[str, os.PathLike], name: str, download: boo
return out_path


def get_axondeepseg_paths(
path: Union[str, os.PathLike],
name: Literal["sem", "tem"],
download: bool = False,
val_fraction: Optional[float] = None,
split: Optional[str] = None,
) -> List[str]:
"""Get paths to the AxonDeepSeg data.

Args:
path: Filepath to a folder where the downloaded data will be saved.
name: The name of the dataset to download. Can be either 'sem' or 'tem'.
download: Whether to download the data if it is not present.
val_fraction: The fraction of the data to use for validation.
split: The data split. Either 'train' or 'val'.

Returns:
List of paths for all the data.
"""
all_paths = []
for nn in name:
data_root = get_axondeepseg_data(path, nn, download)
paths = glob(os.path.join(data_root, "*.h5"))
paths.sort()
if val_fraction is not None:
assert split is not None
n_samples = int(len(paths) * (1 - val_fraction))
paths = paths[:n_samples] if split == "train" else paths[n_samples:]
all_paths.extend(paths)

return all_paths


def get_axondeepseg_dataset(
path: Union[str, os.PathLike],
name: str,
name: Literal["sem", "tem"],
patch_shape: Tuple[int, int],
download: bool = False,
one_hot_encoding: bool = False,
val_fraction: Optional[float] = None,
split: Optional[str] = None,
split: Optional[Literal['train', 'val']] = None,
**kwargs,
) -> Dataset:
"""Get dataset for segmentation of myelinated axons.
Expand All @@ -178,16 +212,7 @@ def get_axondeepseg_dataset(
name = [name]
assert isinstance(name, (tuple, list))

all_paths = []
for nn in name:
data_root = get_axondeepseg_data(path, nn, download)
paths = glob(os.path.join(data_root, "*.h5"))
paths.sort()
if val_fraction is not None:
assert split is not None
n_samples = int(len(paths) * (1 - val_fraction))
paths = paths[:n_samples] if split == "train" else paths[n_samples:]
all_paths.extend(paths)
all_paths = get_axondeepseg_paths(path, name, download, val_fraction, split)

if one_hot_encoding:
if isinstance(one_hot_encoding, bool):
Expand All @@ -205,19 +230,25 @@ def get_axondeepseg_dataset(
msg = "'one_hot' is set to True, but 'label_transform' is in the kwargs. It will be over-ridden."
kwargs = util.update_kwargs(kwargs, "label_transform", label_transform, msg=msg)

raw_key, label_key = "raw", "labels"
return torch_em.default_segmentation_dataset(all_paths, raw_key, all_paths, label_key, patch_shape, **kwargs)
return torch_em.default_segmentation_dataset(
raw_paths=all_paths,
raw_key="raw",
label_paths=all_paths,
label_key="labels",
patch_shape=patch_shape,
**kwargs
)


def get_axondeepseg_loader(
path: Union[str, os.PathLike],
name: str,
name: Literal["sem", "tem"],
patch_shape: Tuple[int, int],
batch_size: int,
download: bool = False,
one_hot_encoding: bool = False,
val_fraction: Optional[float] = None,
split: Optional[str] = None,
split: Optional[Literal["train", "val"]] = None,
**kwargs
) -> DataLoader:
"""Get dataloader for the segmentation of myelinated axons.
Expand Down
Loading