Skip to content

Commit

Permalink
Add check if dataset exists (#327)
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig authored Mar 21, 2024
1 parent 7375768 commit 7195c23
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 7 deletions.
15 changes: 15 additions & 0 deletions src/eva/vision/data/datasets/_validators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Dataset validation related functions."""

import os

from typing_extensions import List, Tuple

from eva.vision.data.datasets import vision
Expand Down Expand Up @@ -42,3 +44,16 @@ def check_dataset_integrity(
f"({(dataset_classes[0], dataset_classes[-1])}) does not match the expected "
f"ones ({first_and_last_labels}). {_SUFFIX_ERROR_MESSAGE}"
)


def check_dataset_exists(dataset_dir: str, download_available: bool) -> None:
    """Verifies that the dataset folder exists.

    Args:
        dataset_dir: Path of the directory that is expected to hold the dataset.
        download_available: Whether the dataset supports automatic download.
            When `True`, the error message additionally suggests setting
            `download=True`.

    Raises:
        FileNotFoundError: If the dataset folder does not exist.
    """
    # Guard clause: nothing to report when the directory is present.
    if os.path.isdir(dataset_dir):
        return

    error_message = f"Dataset not found at '{dataset_dir}'."
    if download_available:
        # Only hint at the download option for datasets that actually offer it.
        error_message += " You can set `download=True` to download the dataset automatically."
    raise FileNotFoundError(error_message)
9 changes: 5 additions & 4 deletions src/eva/vision/data/datasets/classification/bach.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,24 +96,25 @@ def class_to_idx(self) -> Dict[str, int]:
return {"Benign": 0, "InSitu": 1, "Invasive": 2, "Normal": 3}

@property
def dataset_path(self) -> str:
def _dataset_path(self) -> str:
"""Returns the path of the image data of the dataset."""
return os.path.join(self._root, "ICIAR2018_BACH_Challenge", "Photos")

@override
def filename(self, index: int) -> str:
image_path, _ = self._samples[self._indices[index]]
return os.path.relpath(image_path, self.dataset_path)
return os.path.relpath(image_path, self._dataset_path)

@override
def prepare_data(self) -> None:
    """Downloads the dataset if requested and checks that the root folder exists."""
    if self._download:
        self._download_dataset()
    # Download is supported here, so the error message may suggest `download=True`.
    _validators.check_dataset_exists(dataset_dir=self._root, download_available=True)

@override
def configure(self) -> None:
self._samples = folder.make_dataset(
directory=self.dataset_path,
directory=self._dataset_path,
class_to_idx=self.class_to_idx,
extensions=(".tif"),
)
Expand Down Expand Up @@ -145,7 +146,7 @@ def __len__(self) -> int:
def _download_dataset(self) -> None:
"""Downloads the dataset."""
for resource in self._resources:
if os.path.isdir(self.dataset_path):
if os.path.isdir(self._dataset_path):
continue

self._print_license()
Expand Down
7 changes: 4 additions & 3 deletions src/eva/vision/data/datasets/classification/crc.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,13 @@ def class_to_idx(self) -> Dict[str, int]:
@override
def filename(self, index: int) -> str:
image_path, *_ = self._samples[index]
return os.path.relpath(image_path, self._dataset_dir)
return os.path.relpath(image_path, self._dataset_path)

@override
def prepare_data(self) -> None:
    """Downloads the dataset if requested and checks that the root folder exists."""
    if self._download:
        self._download_dataset()
    # Download is supported here, so the error message may suggest `download=True`.
    _validators.check_dataset_exists(dataset_dir=self._root, download_available=True)

@override
def configure(self) -> None:
Expand Down Expand Up @@ -135,7 +136,7 @@ def __len__(self) -> int:
return len(self._samples)

@property
def _dataset_dir(self) -> str:
def _dataset_path(self) -> str:
"""Returns the full path of dataset directory."""
dataset_dirs = {
"train": os.path.join(self._root, "NCT-CRC-HE-100K"),
Expand All @@ -150,7 +151,7 @@ def _dataset_dir(self) -> str:
def _make_dataset(self) -> List[Tuple[str, int]]:
"""Builds the dataset for the specified split."""
dataset = folder.make_dataset(
directory=self._dataset_dir,
directory=self._dataset_path,
class_to_idx=self.class_to_idx,
extensions=(".tif"),
)
Expand Down
4 changes: 4 additions & 0 deletions src/eva/vision/data/datasets/classification/mhist.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def filename(self, index: int) -> str:
image_filename, _ = self._samples[index]
return image_filename

@override
def prepare_data(self) -> None:
    """Checks that the dataset root folder exists."""
    # No download hint is requested for this dataset (`download_available=False`).
    _validators.check_dataset_exists(dataset_dir=self._root, download_available=False)

@override
def configure(self) -> None:
self._samples = self._make_dataset()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def filename(self, index: int) -> str:
def prepare_data(self) -> None:
    """Downloads the dataset if requested and checks that the root folder exists."""
    if self._download:
        self._download_dataset()
    # Download is supported here, so the error message may suggest `download=True`.
    _validators.check_dataset_exists(dataset_dir=self._root, download_available=True)

@override
def validate(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def filename(self, index: int) -> str:
def prepare_data(self) -> None:
    """Downloads the dataset if requested and checks that the root folder exists."""
    if self._download:
        self._download_dataset()
    # Download is supported here, so the error message may suggest `download=True`.
    _validators.check_dataset_exists(dataset_dir=self._root, download_available=True)

@override
def configure(self) -> None:
Expand Down

0 comments on commit 7195c23

Please sign in to comment.