Add decompression preprocessing step to TotalSegmentator2D for more efficient slice loading (#705)
nkaenzig authored Nov 12, 2024
1 parent 020eefa commit 50c90a3
Showing 11 changed files with 206 additions and 33 deletions.
3 changes: 2 additions & 1 deletion src/eva/core/utils/io/__init__.py
@@ -1,5 +1,6 @@
"""Core I/O utilities."""

 from eva.core.utils.io.dataframe import read_dataframe
+from eva.core.utils.io.gz import gunzip_file
 
-__all__ = ["read_dataframe"]
+__all__ = ["read_dataframe", "gunzip_file"]
28 changes: 28 additions & 0 deletions src/eva/core/utils/io/gz.py
@@ -0,0 +1,28 @@
"""Utils for .gz files."""

import gzip
import os


def gunzip_file(path: str, unpack_dir: str | None = None, keep: bool = True) -> str:
"""Unpacks a .gz file to the provided directory.
Args:
path: Path to the .gz file to extract.
unpack_dir: Directory to extract the file to. If `None`, it will use the
same directory as the compressed file.
keep: Whether to keep the compressed .gz file.
Returns:
The path to the extracted file.
"""
unpack_dir = unpack_dir or os.path.dirname(path)
os.makedirs(unpack_dir, exist_ok=True)
save_path = os.path.join(unpack_dir, os.path.basename(path).replace(".gz", ""))
if not os.path.isfile(save_path):
with gzip.open(path, "rb") as f_in:
with open(save_path, "wb") as f_out:
f_out.write(f_in.read())
if not keep:
os.remove(path)
return save_path
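
For reference, a minimal usage sketch of the new helper (the file paths below are made up for illustration):

from eva.core.utils.io import gunzip_file

# Unpack next to the compressed file; the .gz source is kept by default.
nii_path = gunzip_file("/data/totalsegmentator/s0001/ct.nii.gz")
print(nii_path)  # /data/totalsegmentator/s0001/ct.nii

# Unpack into a separate directory and remove the compressed source.
gunzip_file("/data/totalsegmentator/s0002/ct.nii.gz", unpack_dir="/tmp/ct", keep=False)

Note that the helper skips extraction when the target file already exists, which makes it safe to call repeatedly.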
47 changes: 46 additions & 1 deletion src/eva/core/utils/multiprocessing.py
@@ -3,7 +3,10 @@
import multiprocessing
import sys
import traceback
-from typing import Any
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar
+
+from eva.core.utils.progress_bar import tqdm


class Process(multiprocessing.Process):
@@ -42,3 +45,45 @@ def check_exceptions(self) -> None:
            error, traceback = self.exception
            sys.stderr.write(traceback + "\n")
            raise error


+R = TypeVar("R")
+
+
+def run_with_threads(
+    func: Callable[..., R],
+    items: Iterable[Tuple[Any, ...]],
+    kwargs: Dict[str, Any] | None = None,
+    num_workers: int = 8,
+    progress_desc: Optional[str] = None,
+    show_progress: bool = True,
+    return_results: bool = True,
+) -> List[R] | None:
+    """Process items with multiple threads using ThreadPoolExecutor.
+
+    Args:
+        func: Function to execute for each item.
+        items: Iterable of items to process. Each item should be a tuple of
+            arguments to pass to func.
+        kwargs: Additional keyword arguments to pass to func.
+        num_workers: Number of worker threads.
+        progress_desc: Description for the progress bar.
+        show_progress: Whether to show a progress bar.
+        return_results: Whether to return the results. If False, the function
+            will return None.
+
+    Returns:
+        List of results if return_results is True, otherwise None.
+    """
+    results: List[Any] = []
+
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        futures = [executor.submit(func, *args, **(kwargs or {})) for args in items]
+        pbar = tqdm(total=len(futures), desc=progress_desc, disable=not show_progress, leave=False)
+        for future in as_completed(futures):
+            if return_results:
+                results.append(future.result())
+            pbar.update(1)
+        pbar.close()
+
+    return results if return_results else None
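
A short sketch of the intended call pattern (the worker function here is a stand-in):

from eva.core.utils import multiprocessing

def convert(path: str, suffix: str = ".nii") -> str:
    """Stand-in for an I/O-bound task such as decompressing a file."""
    return path + suffix

# One tuple of positional arguments per call to `convert`.
items = [("s0001/ct",), ("s0002/ct",), ("s0003/ct",)]
results = multiprocessing.run_with_threads(
    convert,
    items,
    kwargs={"suffix": ".nii"},
    num_workers=4,
    progress_desc="Converting",
)
print(results)  # e.g. ['s0001/ct.nii', 's0002/ct.nii', 's0003/ct.nii']

Since futures are consumed with as_completed, the order of the returned results follows task completion, not the order of the input items.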
11 changes: 7 additions & 4 deletions src/eva/vision/callbacks/loggers/batch/segmentation.py
@@ -128,17 +128,19 @@ def _draw_semantic_mask(tensor: torch.Tensor) -> torch.Tensor:
    integer values which represent the pixel class id.

    Args:
-        tensor: An image tensor of range [0., 1.].
+        tensor: An image tensor of range [0., N_CLASSES].

    Returns:
        The image as a tensor of range [0., 255.].
    """
    tensor = torch.squeeze(tensor)
    height, width = tensor.shape[-2], tensor.shape[-1]
    red, green, blue = torch.zeros((3, height, width), dtype=torch.uint8)
-    for class_id, color in colormap.COLORMAP.items():
+    class_ids = torch.unique(tensor)
+    colors = colormap.get_colors(int(max(class_ids)) + 1)
+    for class_id in class_ids:
        indices = tensor == class_id
-        red[indices], green[indices], blue[indices] = color
+        red[indices], green[indices], blue[indices] = colors[int(class_id)]
    return torch.stack([red, green, blue])


@@ -157,8 +159,9 @@ def _overlay_mask(image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        from the predefined colormap.
    """
    binary_masks = functional.one_hot(mask).permute(2, 0, 1).to(dtype=torch.bool)
+    colors = colormap.get_colors(binary_masks.shape[0] + 1)
    return torchvision.utils.draw_segmentation_masks(
-        image, binary_masks[1:], alpha=0.65, colors=colormap.COLORS[1:]  # type: ignore
+        image, binary_masks[1:], alpha=0.65, colors=colors[1:]  # type: ignore
    )


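The following self-contained sketch mirrors the updated _draw_semantic_mask logic, showing how per-class colors are now looked up dynamically via colormap.get_colors instead of the fixed COLORMAP:

import torch

from eva.vision.utils import colormap

# Toy semantic mask with class ids 0 (background), 1 and 5.
mask = torch.tensor([[0, 1, 1], [5, 5, 0]])
class_ids = torch.unique(mask)
colors = colormap.get_colors(int(max(class_ids)) + 1)

height, width = mask.shape
red, green, blue = torch.zeros((3, height, width), dtype=torch.uint8)
for class_id in class_ids:
    indices = mask == class_id
    red[indices], green[indices], blue[indices] = colors[int(class_id)]
rgb = torch.stack([red, green, blue])  # uint8 RGB image of shape (3, 2, 3)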
68 changes: 50 additions & 18 deletions src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py
@@ -3,6 +3,7 @@
import functools
import os
from glob import glob
+from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Tuple

import numpy as np
@@ -12,7 +13,8 @@
from torchvision.datasets import utils
from typing_extensions import override

-from eva.core.utils.progress_bar import tqdm
+from eva.core.utils import io as core_io
+from eva.core.utils import multiprocessing
from eva.vision.data.datasets import _validators, structs
from eva.vision.data.datasets.segmentation import base
from eva.vision.utils import io
@@ -65,6 +67,8 @@ def __init__(
        download: bool = False,
        classes: List[str] | None = None,
        optimize_mask_loading: bool = True,
+        decompress: bool = True,
+        num_workers: int = 10,
        transforms: Callable | None = None,
    ) -> None:
        """Initialize dataset.
@@ -85,8 +89,15 @@
                in order to optimize the loading time. In the `setup` method, it
                will reformat the binary one-hot masks to a semantic mask and store
                it on disk.
+            decompress: Whether to decompress the ct.nii.gz files when preparing the data.
+                The label masks won't be decompressed, but when enabling optimize_mask_loading
+                it will export the semantic label masks to a single file in uncompressed .nii
+                format.
+            num_workers: The number of workers to use for optimizing the masks &
+                decompressing the .gz files.
            transforms: A function/transforms that takes in an image and a target
                mask and returns the transformed versions of both.
        """
        super().__init__(transforms=transforms)

@@ -96,6 +107,8 @@
        self._download = download
        self._classes = classes
        self._optimize_mask_loading = optimize_mask_loading
+        self._decompress = decompress
+        self._num_workers = num_workers

        if self._optimize_mask_loading and self._classes is not None:
            raise ValueError(
@@ -128,23 +141,29 @@ def get_filename(path: str) -> str:
    def class_to_idx(self) -> Dict[str, int]:
        return {label: index for index, label in enumerate(self.classes)}

+    @property
+    def _file_suffix(self) -> str:
+        return "nii" if self._decompress else "nii.gz"
+
    @override
-    def filename(self, index: int, segmented: bool = True) -> str:
+    def filename(self, index: int) -> str:
        sample_idx, _ = self._indices[index]
        sample_dir = self._samples_dirs[sample_idx]
-        return os.path.join(sample_dir, "ct.nii.gz")
+        return os.path.join(sample_dir, f"ct.{self._file_suffix}")

    @override
    def prepare_data(self) -> None:
        if self._download:
            self._download_dataset()
+        if self._decompress:
+            self._decompress_files()
+        self._samples_dirs = self._fetch_samples_dirs()
+        if self._optimize_mask_loading:
+            self._export_semantic_label_masks()

    @override
    def configure(self) -> None:
        self._samples_dirs = self._fetch_samples_dirs()
        self._indices = self._create_indices()
-        if self._optimize_mask_loading:
-            self._export_semantic_label_masks()

    @override
    def validate(self) -> None:
@@ -186,16 +205,15 @@ def load_metadata(self, index: int) -> Dict[str, Any]:
return {"slice_index": slice_index}

def _load_mask(self, index: int) -> tv_tensors.Mask:
"""Loads and builds the segmentation mask from NifTi files."""
sample_index, slice_index = self._indices[index]
semantic_labels = self._load_masks_as_semantic_label(sample_index, slice_index)
return tv_tensors.Mask(semantic_labels, dtype=torch.int64) # type: ignore[reportCallIssue]
return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue]

def _load_semantic_label_mask(self, index: int) -> tv_tensors.Mask:
"""Loads the segmentation mask from a semantic label NifTi file."""
sample_index, slice_index = self._indices[index]
masks_dir = self._get_masks_dir(sample_index)
filename = os.path.join(masks_dir, "semantic_labels", "masks.nii.gz")
filename = os.path.join(masks_dir, "semantic_labels", "masks.nii")
semantic_labels = io.read_nifti(filename, slice_index)
return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue]

@@ -209,7 +227,7 @@ def _load_masks_as_semantic_label(
            slice_index: Whether to return only a specific slice.
        """
        masks_dir = self._get_masks_dir(sample_index)
-        mask_paths = [os.path.join(masks_dir, label + ".nii.gz") for label in self.classes]
+        mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in self.classes]
        binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]
        background_mask = np.zeros_like(binary_masks[0])
        return np.argmax([background_mask] + binary_masks, axis=0)
@@ -219,24 +237,28 @@ def _export_semantic_label_masks(self) -> None:
        total_samples = len(self._samples_dirs)
        masks_dirs = map(self._get_masks_dir, range(total_samples))
        semantic_labels = [
-            (index, os.path.join(directory, "semantic_labels", "masks.nii.gz"))
+            (index, os.path.join(directory, "semantic_labels", "masks.nii"))
            for index, directory in enumerate(masks_dirs)
        ]
        to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels)

-        for sample_index, filename in tqdm(
-            list(to_export),
-            desc=">> Exporting optimized semantic masks",
-            leave=False,
-        ):
+        def _process_mask(sample_index: Any, filename: str) -> None:
            semantic_labels = self._load_masks_as_semantic_label(sample_index)
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            io.save_array_as_nifti(semantic_labels, filename)

+        multiprocessing.run_with_threads(
+            _process_mask,
+            list(to_export),
+            num_workers=self._num_workers,
+            progress_desc=">> Exporting optimized semantic mask",
+            return_results=False,
+        )

    def _get_image_path(self, sample_index: int) -> str:
        """Returns the corresponding image path."""
        sample_dir = self._samples_dirs[sample_index]
-        return os.path.join(self._root, sample_dir, "ct.nii.gz")
+        return os.path.join(self._root, sample_dir, f"ct.{self._file_suffix}")

    def _get_masks_dir(self, sample_index: int) -> str:
        """Returns the directory of the corresponding masks."""
@@ -246,7 +268,7 @@ def _get_masks_dir(self, sample_index: int) -> str:
    def _get_semantic_labels_filename(self, sample_index: int) -> str:
        """Returns the semantic label filename."""
        masks_dir = self._get_masks_dir(sample_index)
-        return os.path.join(masks_dir, "semantic_labels", "masks.nii.gz")
+        return os.path.join(masks_dir, "semantic_labels", "masks.nii")

    def _get_number_of_slices_per_sample(self, sample_index: int) -> int:
        """Returns the total amount of slices of a sample."""
@@ -320,6 +342,16 @@ def _download_dataset(self) -> None:
            remove_finished=True,
        )

+    def _decompress_files(self) -> None:
+        compressed_paths = Path(self._root).rglob("*/ct.nii.gz")
+        multiprocessing.run_with_threads(
+            core_io.gunzip_file,
+            [(str(path),) for path in compressed_paths],
+            num_workers=self._num_workers,
+            progress_desc=">> Decompressing .gz files",
+            return_results=False,
+        )

    def _print_license(self) -> None:
        """Prints the dataset license."""
        print(f"Dataset license: {self._license}")
2 changes: 1 addition & 1 deletion src/eva/vision/models/modules/semantic_segmentation.py
@@ -103,7 +103,7 @@ def forward(
"decoder should map the embeddings (`inputs`) to."
)
features = self.encoder(inputs) if self.encoder else inputs
decoder_inputs = DecoderInputs(features, inputs.shape[-2:], inputs) # type: ignore
decoder_inputs = DecoderInputs(features, to_size or inputs.shape[-2:], inputs) # type: ignore
return self.decoder(decoder_inputs)

    @override
20 changes: 20 additions & 0 deletions src/eva/vision/utils/colormap.py
@@ -1,5 +1,7 @@
"""Color mapping constants."""

+from typing import List, Tuple

COLORS = [
    (0, 0, 0),
    (255, 0, 0),  # Red
@@ -75,3 +77,21 @@

COLORMAP = dict(enumerate(COLORS)) | {255: (255, 255, 255)}
"""Class id to RGB color mapping."""


+def get_colors(num_colors: int) -> List[Tuple[int, int, int]]:
+    """Get a list of RGB colors.
+
+    If the number of colors is greater than the number of predefined colors,
+    it will repeat the colors until it reaches the requested number.
+
+    Args:
+        num_colors: The number of colors to return.
+
+    Returns:
+        A list of RGB colors.
+    """
+    colors = COLORS
+    while len(colors) < num_colors:
+        colors = colors + COLORS[1:]
+    return colors
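
A quick check of the cycling behavior (the first entries come from the COLORS table above):

from eva.vision.utils import colormap

colors = colormap.get_colors(200)  # request more colors than the table defines
assert len(colors) >= 200
assert colors[0] == (0, 0, 0)    # the background color is never recycled
assert colors[1] == (255, 0, 0)  # first foreground color: red
# Foreground colors repeat cyclically once the predefined table is exhausted.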
Binary file added tests/eva/assets/core/archives/test.txt.gz
1 change: 1 addition & 0 deletions tests/eva/core/utils/io/__init__.py
@@ -0,0 +1 @@
"""Tests the core io utilities."""
40 changes: 40 additions & 0 deletions tests/eva/core/utils/io/test_gz.py
@@ -0,0 +1,40 @@
"""Tests for .gz file utilities."""

import os
import shutil

import pytest

from eva.core.utils.io import gz


@pytest.mark.parametrize(
"subdir, keep",
[
(None, True),
("test_subdir", True),
(None, False),
],
)
def test_gunzip(tmp_path: str, gzip_file: str, subdir: str | None, keep: bool) -> None:
"""Verifies proper extraction of gzip file contents."""
unpack_dir = os.path.join(tmp_path, subdir) if subdir else tmp_path
tmp_gzip_path = os.path.join(tmp_path, os.path.basename(gzip_file))
shutil.copy(gzip_file, tmp_gzip_path)
gz.gunzip_file(tmp_gzip_path, unpack_dir=unpack_dir, keep=keep)

uncompressed_path = os.path.join(unpack_dir, "test.txt")
assert os.path.isfile(uncompressed_path)
with open(uncompressed_path, "r") as f:
assert f.read() == "gz file test"

if keep:
assert os.path.isfile(tmp_gzip_path)
else:
assert not os.path.isfile(tmp_gzip_path)


@pytest.fixture()
def gzip_file(assets_path: str) -> str:
"""Provides the path to the test gzip file asset."""
return os.path.join(assets_path, "core/archives/test.txt.gz")