diff --git a/src/eva/core/utils/io/__init__.py b/src/eva/core/utils/io/__init__.py
index 232ec98d..adf91e9c 100644
--- a/src/eva/core/utils/io/__init__.py
+++ b/src/eva/core/utils/io/__init__.py
@@ -1,5 +1,6 @@
 """Core I/O utilities."""
 
 from eva.core.utils.io.dataframe import read_dataframe
+from eva.core.utils.io.gz import gunzip_file
 
-__all__ = ["read_dataframe"]
+__all__ = ["read_dataframe", "gunzip_file"]
diff --git a/src/eva/core/utils/io/gz.py b/src/eva/core/utils/io/gz.py
new file mode 100644
index 00000000..e352b18e
--- /dev/null
+++ b/src/eva/core/utils/io/gz.py
@@ -0,0 +1,28 @@
+"""Utils for .gz files."""
+
+import gzip
+import os
+
+
+def gunzip_file(path: str, unpack_dir: str | None = None, keep: bool = True) -> str:
+    """Unpacks a .gz file to the provided directory.
+
+    Args:
+        path: Path to the .gz file to extract.
+        unpack_dir: Directory to extract the file to. If `None`, it will use the
+            same directory as the compressed file.
+        keep: Whether to keep the compressed .gz file.
+
+    Returns:
+        The path to the extracted file.
+    """
+    unpack_dir = unpack_dir or os.path.dirname(path)
+    os.makedirs(unpack_dir, exist_ok=True)
+    save_path = os.path.join(unpack_dir, os.path.basename(path).replace(".gz", ""))
+    if not os.path.isfile(save_path):
+        with gzip.open(path, "rb") as f_in:
+            with open(save_path, "wb") as f_out:
+                f_out.write(f_in.read())
+    if not keep:
+        os.remove(path)
+    return save_path
diff --git a/src/eva/core/utils/multiprocessing.py b/src/eva/core/utils/multiprocessing.py
index 0c31f5dd..b0989749 100644
--- a/src/eva/core/utils/multiprocessing.py
+++ b/src/eva/core/utils/multiprocessing.py
@@ -3,7 +3,10 @@
 import multiprocessing
 import sys
 import traceback
-from typing import Any
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar
+
+from eva.core.utils.progress_bar import tqdm
 
 
 class Process(multiprocessing.Process):
@@ -42,3 +45,45 @@ def check_exceptions(self) -> None:
         error, traceback = self.exception
         sys.stderr.write(traceback + "\n")
         raise error
+
+
+R = TypeVar("R")
+
+
+def run_with_threads(
+    func: Callable[..., R],
+    items: Iterable[Tuple[Any, ...]],
+    kwargs: Dict[str, Any] | None = None,
+    num_workers: int = 8,
+    progress_desc: Optional[str] = None,
+    show_progress: bool = True,
+    return_results: bool = True,
+) -> List[R] | None:
+    """Processes items with multiple threads using a ThreadPoolExecutor.
+
+    Args:
+        func: Function to execute for each item.
+        items: Iterable of items to process. Each item should be a tuple of
+            positional arguments to pass to `func`.
+        kwargs: Additional keyword arguments to pass to `func` on every call.
+        num_workers: Number of worker threads.
+        progress_desc: Description for the progress bar.
+        show_progress: Whether to show a progress bar.
+        return_results: Whether to return the results. If False, the function
+            will return None.
+
+    Returns:
+        A list of results if `return_results` is True, otherwise None. Results
+        are collected in completion order, not input order.
+    """
+    results: List[Any] = []
+
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        futures = [executor.submit(func, *args, **(kwargs or {})) for args in items]
+        pbar = tqdm(total=len(futures), desc=progress_desc, disable=not show_progress, leave=False)
+        for future in as_completed(futures):
+            if return_results:
+                results.append(future.result())
+            pbar.update(1)
+        pbar.close()
+
+    return results if return_results else None
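Note for reviewers: the two new helpers are designed to compose. A minimal sketch of the intended pattern (the archive paths and worker count below are made up for illustration, not part of the patch):

```python
from eva.core.utils import multiprocessing
from eva.core.utils.io import gunzip_file

# Hypothetical archives; each work item is a tuple of positional args for `func`.
archives = [("/data/a.nii.gz",), ("/data/b.nii.gz",)]
paths = multiprocessing.run_with_threads(
    gunzip_file,
    archives,
    kwargs={"keep": False},  # delete each .gz after extraction
    num_workers=4,
    progress_desc="Decompressing",
)
# `paths` holds the extracted file paths, in completion order.
```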
diff --git a/src/eva/vision/callbacks/loggers/batch/segmentation.py b/src/eva/vision/callbacks/loggers/batch/segmentation.py
index b01a50ea..0c878d1f 100644
--- a/src/eva/vision/callbacks/loggers/batch/segmentation.py
+++ b/src/eva/vision/callbacks/loggers/batch/segmentation.py
@@ -128,7 +128,7 @@ def _draw_semantic_mask(tensor: torch.Tensor) -> torch.Tensor:
     integer values which represent the pixel class id.
 
     Args:
-        tensor: An image tensor of range [0., 1.].
+        tensor: A mask tensor with integer class ids in the range [0, N_CLASSES].
 
     Returns:
         The image as a tensor of range [0., 255.].
@@ -136,9 +136,11 @@ def _draw_semantic_mask(tensor: torch.Tensor) -> torch.Tensor:
     tensor = torch.squeeze(tensor)
     height, width = tensor.shape[-2], tensor.shape[-1]
     red, green, blue = torch.zeros((3, height, width), dtype=torch.uint8)
-    for class_id, color in colormap.COLORMAP.items():
+    class_ids = torch.unique(tensor)
+    colors = colormap.get_colors(int(max(class_ids)) + 1)
+    for class_id in class_ids:
         indices = tensor == class_id
-        red[indices], green[indices], blue[indices] = color
+        red[indices], green[indices], blue[indices] = colors[int(class_id)]
     return torch.stack([red, green, blue])
 
 
@@ -157,8 +159,9 @@ def _overlay_mask(image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
         from the predefined colormap.
     """
     binary_masks = functional.one_hot(mask).permute(2, 0, 1).to(dtype=torch.bool)
+    colors = colormap.get_colors(binary_masks.shape[0] + 1)
     return torchvision.utils.draw_segmentation_masks(
-        image, binary_masks[1:], alpha=0.65, colors=colormap.COLORS[1:]  # type: ignore
+        image, binary_masks[1:], alpha=0.65, colors=colors[1:]  # type: ignore
     )
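Note on `_draw_semantic_mask`: `get_colors` is asked for `max(class_ids) + 1` entries because the palette is indexed by class id, so id `N` needs a list of length at least `N + 1`. A behavioural sketch (hedged: `_draw_semantic_mask` is module-private, so this is illustrative rather than suggested usage):

```python
import torch

# A 2x2 mask holding class ids 0..2: the helper paints each id with its
# palette color and returns a (3, H, W) uint8 RGB image.
mask = torch.tensor([[0, 1], [2, 2]])
rgb = _draw_semantic_mask(mask)
assert rgb.shape == (3, 2, 2) and rgb.dtype == torch.uint8
```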
diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py
index 9b45b4a7..fefba203 100644
--- a/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py
+++ b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py
@@ -3,6 +3,7 @@
 import functools
 import os
 from glob import glob
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Literal, Tuple
 
 import numpy as np
@@ -12,7 +13,8 @@
 from torchvision.datasets import utils
 from typing_extensions import override
 
-from eva.core.utils.progress_bar import tqdm
+from eva.core.utils import io as core_io
+from eva.core.utils import multiprocessing
 from eva.vision.data.datasets import _validators, structs
 from eva.vision.data.datasets.segmentation import base
 from eva.vision.utils import io
@@ -65,6 +67,8 @@ def __init__(
         download: bool = False,
         classes: List[str] | None = None,
         optimize_mask_loading: bool = True,
+        decompress: bool = True,
+        num_workers: int = 10,
         transforms: Callable | None = None,
     ) -> None:
         """Initialize dataset.
@@ -85,8 +89,15 @@ def __init__(
                 in order to optimize the loading time. In the `setup` method, it
                 will reformat the binary one-hot masks to a semantic mask and store
                 it on disk.
+            decompress: Whether to decompress the ct.nii.gz files when preparing
+                the data. The label masks won't be decompressed, but when
+                `optimize_mask_loading` is enabled, the semantic label masks will
+                be exported to a single file in uncompressed .nii format.
+            num_workers: The number of workers to use for optimizing the masks
+                and decompressing the .gz files.
             transforms: A function/transforms that takes in an image and a target
                 mask and returns the transformed versions of both.
+
         """
         super().__init__(transforms=transforms)
 
@@ -96,6 +107,8 @@ def __init__(
         self._download = download
         self._classes = classes
         self._optimize_mask_loading = optimize_mask_loading
+        self._decompress = decompress
+        self._num_workers = num_workers
 
         if self._optimize_mask_loading and self._classes is not None:
             raise ValueError(
@@ -128,23 +141,29 @@ def get_filename(path: str) -> str:
     def class_to_idx(self) -> Dict[str, int]:
         return {label: index for index, label in enumerate(self.classes)}
 
+    @property
+    def _file_suffix(self) -> str:
+        return "nii" if self._decompress else "nii.gz"
+
     @override
-    def filename(self, index: int, segmented: bool = True) -> str:
+    def filename(self, index: int) -> str:
         sample_idx, _ = self._indices[index]
         sample_dir = self._samples_dirs[sample_idx]
-        return os.path.join(sample_dir, "ct.nii.gz")
+        return os.path.join(sample_dir, f"ct.{self._file_suffix}")
 
     @override
     def prepare_data(self) -> None:
         if self._download:
             self._download_dataset()
+        if self._decompress:
+            self._decompress_files()
+        self._samples_dirs = self._fetch_samples_dirs()
+        if self._optimize_mask_loading:
+            self._export_semantic_label_masks()
 
     @override
     def configure(self) -> None:
-        self._samples_dirs = self._fetch_samples_dirs()
         self._indices = self._create_indices()
-        if self._optimize_mask_loading:
-            self._export_semantic_label_masks()
 
     @override
     def validate(self) -> None:
@@ -186,16 +205,15 @@ def load_metadata(self, index: int) -> Dict[str, Any]:
         return {"slice_index": slice_index}
 
     def _load_mask(self, index: int) -> tv_tensors.Mask:
-        """Loads and builds the segmentation mask from NifTi files."""
         sample_index, slice_index = self._indices[index]
         semantic_labels = self._load_masks_as_semantic_label(sample_index, slice_index)
-        return tv_tensors.Mask(semantic_labels, dtype=torch.int64)  # type: ignore[reportCallIssue]
+        return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
 
     def _load_semantic_label_mask(self, index: int) -> tv_tensors.Mask:
         """Loads the segmentation mask from a semantic label NifTi file."""
         sample_index, slice_index = self._indices[index]
         masks_dir = self._get_masks_dir(sample_index)
-        filename = os.path.join(masks_dir, "semantic_labels", "masks.nii.gz")
+        filename = os.path.join(masks_dir, "semantic_labels", "masks.nii")
         semantic_labels = io.read_nifti(filename, slice_index)
         return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
 
@@ -209,7 +227,7 @@ def _load_masks_as_semantic_label(
             slice_index: Whether to return only a specific slice.
         """
         masks_dir = self._get_masks_dir(sample_index)
-        mask_paths = [os.path.join(masks_dir, label + ".nii.gz") for label in self.classes]
+        mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in self.classes]
         binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]
         background_mask = np.zeros_like(binary_masks[0])
         return np.argmax([background_mask] + binary_masks, axis=0)
 
@@ -219,24 +237,28 @@ def _export_semantic_label_masks(self) -> None:
         total_samples = len(self._samples_dirs)
         masks_dirs = map(self._get_masks_dir, range(total_samples))
         semantic_labels = [
-            (index, os.path.join(directory, "semantic_labels", "masks.nii.gz"))
+            (index, os.path.join(directory, "semantic_labels", "masks.nii"))
             for index, directory in enumerate(masks_dirs)
         ]
         to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels)
 
-        for sample_index, filename in tqdm(
-            list(to_export),
-            desc=">> Exporting optimized semantic masks",
-            leave=False,
-        ):
+        def _process_mask(sample_index: Any, filename: str) -> None:
             semantic_labels = self._load_masks_as_semantic_label(sample_index)
             os.makedirs(os.path.dirname(filename), exist_ok=True)
             io.save_array_as_nifti(semantic_labels, filename)
 
+        multiprocessing.run_with_threads(
+            _process_mask,
+            list(to_export),
+            num_workers=self._num_workers,
+            progress_desc=">> Exporting optimized semantic masks",
+            return_results=False,
+        )
+
     def _get_image_path(self, sample_index: int) -> str:
         """Returns the corresponding image path."""
         sample_dir = self._samples_dirs[sample_index]
-        return os.path.join(self._root, sample_dir, "ct.nii.gz")
+        return os.path.join(self._root, sample_dir, f"ct.{self._file_suffix}")
 
     def _get_masks_dir(self, sample_index: int) -> str:
         """Returns the directory of the corresponding masks."""
@@ -246,7 +268,7 @@ def _get_masks_dir(self, sample_index: int) -> str:
     def _get_semantic_labels_filename(self, sample_index: int) -> str:
         """Returns the semantic label filename."""
         masks_dir = self._get_masks_dir(sample_index)
-        return os.path.join(masks_dir, "semantic_labels", "masks.nii.gz")
+        return os.path.join(masks_dir, "semantic_labels", "masks.nii")
 
     def _get_number_of_slices_per_sample(self, sample_index: int) -> int:
         """Returns the total amount of slices of a sample."""
@@ -320,6 +342,16 @@ def _download_dataset(self) -> None:
             remove_finished=True,
         )
 
+    def _decompress_files(self) -> None:
+        compressed_paths = Path(self._root).rglob("*/ct.nii.gz")
+        multiprocessing.run_with_threads(
+            core_io.gunzip_file,
+            [(str(path),) for path in compressed_paths],
+            num_workers=self._num_workers,
+            progress_desc=">> Decompressing .gz files",
+            return_results=False,
+        )
+
     def _print_license(self) -> None:
         """Prints the dataset license."""
         print(f"Dataset license: {self._license}")
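With this change, the expensive one-time work (decompression and semantic-mask export) moves from `configure` into `prepare_data`, parallelized over `num_workers` threads. A hedged usage sketch (the root path is hypothetical):

```python
from eva.vision.data import datasets

dataset = datasets.TotalSegmentator2D(
    root="/data/total_segmentator",  # hypothetical path
    split="train",
    decompress=True,             # gunzip the ct.nii.gz files during prepare_data
    optimize_mask_loading=True,  # export per-sample semantic masks.nii files
    num_workers=10,              # threads for both steps
)
dataset.prepare_data()  # one-time setup: download / decompress / export masks
dataset.configure()     # builds the slice indices
```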
""" masks_dir = self._get_masks_dir(sample_index) - mask_paths = [os.path.join(masks_dir, label + ".nii.gz") for label in self.classes] + mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in self.classes] binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths] background_mask = np.zeros_like(binary_masks[0]) return np.argmax([background_mask] + binary_masks, axis=0) @@ -219,24 +237,28 @@ def _export_semantic_label_masks(self) -> None: total_samples = len(self._samples_dirs) masks_dirs = map(self._get_masks_dir, range(total_samples)) semantic_labels = [ - (index, os.path.join(directory, "semantic_labels", "masks.nii.gz")) + (index, os.path.join(directory, "semantic_labels", "masks.nii")) for index, directory in enumerate(masks_dirs) ] to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels) - for sample_index, filename in tqdm( - list(to_export), - desc=">> Exporting optimized semantic masks", - leave=False, - ): + def _process_mask(sample_index: Any, filename: str) -> None: semantic_labels = self._load_masks_as_semantic_label(sample_index) os.makedirs(os.path.dirname(filename), exist_ok=True) io.save_array_as_nifti(semantic_labels, filename) + multiprocessing.run_with_threads( + _process_mask, + list(to_export), + num_workers=self._num_workers, + progress_desc=">> Exporting optimized semantic mask", + return_results=False, + ) + def _get_image_path(self, sample_index: int) -> str: """Returns the corresponding image path.""" sample_dir = self._samples_dirs[sample_index] - return os.path.join(self._root, sample_dir, "ct.nii.gz") + return os.path.join(self._root, sample_dir, f"ct.{self._file_suffix}") def _get_masks_dir(self, sample_index: int) -> str: """Returns the directory of the corresponding masks.""" @@ -246,7 +268,7 @@ def _get_masks_dir(self, sample_index: int) -> str: def _get_semantic_labels_filename(self, sample_index: int) -> str: """Returns the semantic label filename.""" masks_dir = self._get_masks_dir(sample_index) - return os.path.join(masks_dir, "semantic_labels", "masks.nii.gz") + return os.path.join(masks_dir, "semantic_labels", "masks.nii") def _get_number_of_slices_per_sample(self, sample_index: int) -> int: """Returns the total amount of slices of a sample.""" @@ -320,6 +342,16 @@ def _download_dataset(self) -> None: remove_finished=True, ) + def _decompress_files(self) -> None: + compressed_paths = Path(self._root).rglob("*/ct.nii.gz") + multiprocessing.run_with_threads( + core_io.gunzip_file, + [(str(path),) for path in compressed_paths], + num_workers=self._num_workers, + progress_desc=">> Decompressing .gz files", + return_results=False, + ) + def _print_license(self) -> None: """Prints the dataset license.""" print(f"Dataset license: {self._license}") diff --git a/src/eva/vision/models/modules/semantic_segmentation.py b/src/eva/vision/models/modules/semantic_segmentation.py index b2b044a9..83eb337d 100644 --- a/src/eva/vision/models/modules/semantic_segmentation.py +++ b/src/eva/vision/models/modules/semantic_segmentation.py @@ -103,7 +103,7 @@ def forward( "decoder should map the embeddings (`inputs`) to." 
diff --git a/tests/eva/assets/core/archives/test.txt.gz b/tests/eva/assets/core/archives/test.txt.gz
new file mode 100644
index 00000000..b1aaafa4
Binary files /dev/null and b/tests/eva/assets/core/archives/test.txt.gz differ
diff --git a/tests/eva/core/utils/io/__init__.py b/tests/eva/core/utils/io/__init__.py
new file mode 100644
index 00000000..cdc5a43f
--- /dev/null
+++ b/tests/eva/core/utils/io/__init__.py
@@ -0,0 +1 @@
+"""Tests the core io utilities."""
diff --git a/tests/eva/core/utils/io/test_gz.py b/tests/eva/core/utils/io/test_gz.py
new file mode 100644
index 00000000..2f6cad7c
--- /dev/null
+++ b/tests/eva/core/utils/io/test_gz.py
@@ -0,0 +1,40 @@
+"""Tests for .gz file utilities."""
+
+import os
+import shutil
+
+import pytest
+
+from eva.core.utils.io import gz
+
+
+@pytest.mark.parametrize(
+    "subdir, keep",
+    [
+        (None, True),
+        ("test_subdir", True),
+        (None, False),
+    ],
+)
+def test_gunzip(tmp_path: str, gzip_file: str, subdir: str | None, keep: bool) -> None:
+    """Verifies proper extraction of gzip file contents."""
+    unpack_dir = os.path.join(tmp_path, subdir) if subdir else tmp_path
+    tmp_gzip_path = os.path.join(tmp_path, os.path.basename(gzip_file))
+    shutil.copy(gzip_file, tmp_gzip_path)
+    gz.gunzip_file(tmp_gzip_path, unpack_dir=unpack_dir, keep=keep)
+
+    uncompressed_path = os.path.join(unpack_dir, "test.txt")
+    assert os.path.isfile(uncompressed_path)
+    with open(uncompressed_path, "r") as f:
+        assert f.read() == "gz file test"
+
+    if keep:
+        assert os.path.isfile(tmp_gzip_path)
+    else:
+        assert not os.path.isfile(tmp_gzip_path)
+
+
+@pytest.fixture()
+def gzip_file(assets_path: str) -> str:
+    """Provides the path to the test gzip file asset."""
+    return os.path.join(assets_path, "core/archives/test.txt.gz")
Literal["train", "val"] | None, assets_path: str + tmp_path: str, split: Literal["train", "val"] | None, assets_path: str ) -> datasets.TotalSegmentator2D: """TotalSegmentator2D dataset fixture.""" + dataset_dir = os.path.join( + assets_path, + "vision", + "datasets", + "total_segmentator", + "Totalsegmentator_dataset_v201", + ) + shutil.copytree(dataset_dir, tmp_path, dirs_exist_ok=True) dataset = datasets.TotalSegmentator2D( - root=os.path.join( - assets_path, - "vision", - "datasets", - "total_segmentator", - "Totalsegmentator_dataset_v201", - ), + root=tmp_path, split=split, version=None, )