Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for dataloader samplers #713

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/eva/core/data/dataloaders/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,20 @@ class DataLoader:
prefetch_factor: int | None = 2
"""Number of batches loaded in advance by each worker."""

def __call__(self, dataset: datasets.TorchDataset) -> dataloader.DataLoader:
def __call__(
self, dataset: datasets.TorchDataset, sampler: samplers.Sampler | None = None
) -> dataloader.DataLoader:
"""Returns the dataloader on the provided dataset.

Args:
dataset: dataset from which to load the data.
sampler: defines the strategy to draw samples from the dataset.
"""
return dataloader.DataLoader(
dataset=dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
sampler=self.sampler,
sampler=sampler or self.sampler,
batch_sampler=self.batch_sampler,
num_workers=self.num_workers or multiprocessing.cpu_count(),
collate_fn=self.collate_fn,
Expand Down
47 changes: 42 additions & 5 deletions src/eva/core/data/datamodules/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from eva.core.data import dataloaders as dataloaders_lib
from eva.core.data import datasets as datasets_lib
from eva.core.data import samplers as samplers_lib
from eva.core.data.datamodules import call, schemas


def __init__(
    self,
    datasets: schemas.DatasetsSchema | None = None,
    dataloaders: schemas.DataloadersSchema | None = None,
    samplers: schemas.SamplersSchema | None = None,
) -> None:
    """Initializes the datamodule.

    Args:
        datasets: The desired datasets. Falls back to `default_datasets`
            when `None`.
        dataloaders: The desired dataloaders. Falls back to
            `default_dataloaders` when `None`.
        samplers: The desired samplers for the dataloaders. Falls back to
            `default_samplers` when `None`.
    """
    super().__init__()

    # `or` treats an explicitly passed empty schema as "use default";
    # callers always pass either `None` or a fully-formed schema here.
    self.datasets = datasets or self.default_datasets
    self.dataloaders = dataloaders or self.default_dataloaders
    self.samplers = samplers or self.default_samplers

@property
def default_datasets(self) -> schemas.DatasetsSchema:
Expand All @@ -46,6 +50,11 @@ def default_dataloaders(self) -> schemas.DataloadersSchema:
"""Returns the default dataloader schema."""
return schemas.DataloadersSchema()

@property
def default_samplers(self) -> schemas.SamplersSchema:
    """The samplers schema used when none is provided at initialization."""
    return schemas.SamplersSchema()

@override
def prepare_data(self) -> None:
    # Forward `prepare_data` to every configured dataset that defines it.
    call.call_method_if_exists(self.datasets.tolist(), "prepare_data")
Expand All @@ -64,45 +73,73 @@ def train_dataloader(self) -> TRAIN_DATALOADERS:
raise ValueError(
"Train dataloader can not be initialized as `self.datasets.train` is `None`."
)
return self.dataloaders.train(self.datasets.train)
if isinstance(self.datasets.train, list) and len(self.datasets.train) > 1:
raise ValueError("Train dataloader can not be initialized with multiple datasets.")

return self._initialize_dataloaders(
self.dataloaders.train, self.datasets.train, self.samplers.train
)[0]

@override
def val_dataloader(self) -> EVAL_DATALOADERS:
    """Returns the dataloader(s) of the validation dataset(s).

    Raises:
        ValueError: If no validation dataset is configured.
    """
    if self.datasets.val is None:
        raise ValueError(
            "Validation dataloader can not be initialized as `self.datasets.val` is `None`."
        )
    return self._initialize_dataloaders(
        self.dataloaders.val, self.datasets.val, self.samplers.val
    )

@override
def test_dataloader(self) -> EVAL_DATALOADERS:
    """Returns the dataloader(s) of the test dataset(s).

    Raises:
        ValueError: If no test dataset is configured.
    """
    if self.datasets.test is None:
        raise ValueError(
            "Test dataloader can not be initialized as `self.datasets.test` is `None`."
        )
    return self._initialize_dataloaders(
        self.dataloaders.test, self.datasets.test, self.samplers.test
    )

@override
def predict_dataloader(self) -> EVAL_DATALOADERS:
    """Returns the dataloader(s) of the predict dataset(s).

    When multiple predict datasets are configured, the sampler is applied
    only to the first one (which by convention corresponds to the train
    split); the remaining datasets are loaded without a sampler.

    Raises:
        ValueError: If no predict dataset is configured.
    """
    if self.datasets.predict is None:
        raise ValueError(
            "Predict dataloader can not be initialized as `self.datasets.predict` is `None`."
        )
    if isinstance(self.datasets.predict, list) and len(self.datasets.predict) > 1:
        # Only apply sampler to the first predict dataset (should correspond to train split)
        sampled_dataloader = self._initialize_dataloaders(
            self.dataloaders.predict, self.datasets.predict[0], self.samplers.predict
        )
        return sampled_dataloader + self._initialize_dataloaders(
            self.dataloaders.predict, self.datasets.predict[1:]
        )

    return self._initialize_dataloaders(
        self.dataloaders.predict, self.datasets.predict, self.samplers.predict
    )

def _initialize_dataloaders(
    self,
    dataloader: dataloaders_lib.DataLoader,
    datasets: datasets_lib.TorchDataset | List[datasets_lib.TorchDataset],
    sampler: samplers_lib.Sampler | None = None,
) -> EVAL_DATALOADERS:
    """Initializes dataloaders from a given set of datasets.

    Args:
        dataloader: The dataloader to apply to the provided datasets.
        datasets: The desired dataset(s) to allocate dataloader(s).
        sampler: The sampler to use for the dataloader(s). If it is a
            `SamplerWithDataSource`, its data source is bound to each
            dataset before the corresponding dataloader is built.

    Returns:
        A list with the dataloaders of the provided dataset(s).
    """
    dataset_list = datasets if isinstance(datasets, list) else [datasets]

    dataloaders = []
    for dataset in dataset_list:
        # NOTE(review): the same sampler instance is shared across all
        # datasets; with multiple datasets each `set_dataset` call rebinds
        # it, so all dataloaders end up referencing the sampler bound to
        # the last dataset — confirm callers pass a sampler only together
        # with a single dataset.
        if isinstance(sampler, samplers_lib.SamplerWithDataSource):
            # `isinstance` already excludes `None`, so no separate check needed.
            sampler.set_dataset(dataset)  # type: ignore
        dataloaders.append(dataloader(dataset, sampler=sampler))
    return dataloaders
19 changes: 18 additions & 1 deletion src/eva/core/data/datamodules/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import dataclasses
from typing import List

from eva.core.data import dataloaders, datasets
from eva.core.data import dataloaders, datasets, samplers

TRAIN_DATASET = datasets.TorchDataset | None
"""Train dataset."""
Expand Down Expand Up @@ -60,3 +60,20 @@ class DataloadersSchema:

predict: dataloaders.DataLoader = dataclasses.field(default_factory=dataloaders.DataLoader)
"""Predict dataloader."""


@dataclasses.dataclass(frozen=True)
class SamplersSchema:
    """Samplers schema used in DataModule.

    Each field holds the optional sampler for the corresponding dataloader
    split; `None` means no split-specific sampler is applied.
    """

    train: samplers.Sampler | None = None
    """Train sampler."""

    val: samplers.Sampler | None = None
    """Validation sampler."""

    test: samplers.Sampler | None = None
    """Test sampler."""

    predict: samplers.Sampler | None = None
    """Predict sampler."""
5 changes: 4 additions & 1 deletion src/eva/core/data/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
"""Datasets API."""

from eva.core.data.datasets.base import Dataset
from eva.core.data.datasets.base import Dataset, MapDataset
from eva.core.data.datasets.classification import (
EmbeddingsClassificationDataset,
MultiEmbeddingsClassificationDataset,
)
from eva.core.data.datasets.dataset import TorchDataset
from eva.core.data.datasets.typings import DataSample

__all__ = [
"Dataset",
"MapDataset",
"EmbeddingsClassificationDataset",
"MultiEmbeddingsClassificationDataset",
"TorchDataset",
"DataSample",
]
27 changes: 27 additions & 0 deletions src/eva/core/data/datasets/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Base dataset class."""

import abc

from eva.core.data.datasets import dataset


Expand Down Expand Up @@ -51,3 +53,28 @@ def teardown(self) -> None:
of fit (train + validate), validate, test, or predict and it will be
called from every process (i.e. GPU) across all the nodes in DDP.
"""


class MapDataset(Dataset):
    """Abstract base class for all map-style datasets.

    Subclasses provide random access to samples via `__getitem__` and a
    known size via `__len__`, which allows samplers to index them directly.
    """

    @abc.abstractmethod
    def __getitem__(self, index: int):
        """Retrieves the item at the given index.

        Args:
            index (int): Index of the sample to fetch.

        Returns:
            Any: The data at the given index
        """
        raise NotImplementedError

    @abc.abstractmethod
    def __len__(self) -> int:
        """Returns the length of the dataset.

        Returns:
            int: Length of the dataset
        """
        raise NotImplementedError
18 changes: 18 additions & 0 deletions src/eva/core/data/datasets/typings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Typing definitions for the datasets module."""

from typing import Any, Dict, NamedTuple

import torch


class DataSample(NamedTuple):
"""The default input batch data scheme."""

data: torch.Tensor
"""The data batch."""

targets: torch.Tensor | None = None
"""The target batch."""

metadata: Dict[str, Any] | None = None
"""The associated metadata."""
6 changes: 4 additions & 2 deletions src/eva/core/data/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Data samplers API."""

from eva.core.data.samplers.sampler import Sampler
from eva.core.data.samplers.classification.balanced import BalancedSampler
from eva.core.data.samplers.random import RandomSampler
from eva.core.data.samplers.sampler import Sampler, SamplerWithDataSource

__all__ = ["Sampler"]
__all__ = ["Sampler", "SamplerWithDataSource", "RandomSampler", "BalancedSampler"]
5 changes: 5 additions & 0 deletions src/eva/core/data/samplers/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Classification data samplers API."""

from eva.core.data.samplers.classification.balanced import BalancedSampler

__all__ = ["BalancedSampler"]
96 changes: 96 additions & 0 deletions src/eva/core/data/samplers/classification/balanced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Random class sampler for data loading."""

from collections import defaultdict
from typing import Dict, Iterator, List

import numpy as np
from typing_extensions import override

from eva.core.data import datasets
from eva.core.data.datasets.typings import DataSample
from eva.core.data.samplers.sampler import SamplerWithDataSource
from eva.core.utils.progress_bar import tqdm


class BalancedSampler(SamplerWithDataSource[int]):
    """Balanced class sampler for data loading.

    The sampler ensures that:
    1. Each class has the same number of samples
    2. Samples within each class are randomly selected
    3. Samples of different classes appear in random order
    """

    def __init__(self, num_samples: int, replacement: bool = False, seed: int | None = 42):
        """Initializes the balanced sampler.

        Args:
            num_samples: The number of samples to draw per class.
            replacement: samples are drawn on-demand with replacement if ``True``,
                default=``False``
            seed: Random seed for reproducibility.

        Raises:
            ValueError: If `num_samples` is not a positive integer.
        """
        if num_samples <= 0:
            raise ValueError(f"`num_samples` must be a positive integer, got {num_samples}.")
        self._num_samples = num_samples
        self._replacement = replacement
        self._class_indices: Dict[int, List[int]] = defaultdict(list)
        self._random_generator = np.random.default_rng(seed)

    def __len__(self) -> int:
        """Returns the total number of samples (zero until a dataset is set)."""
        return self._num_samples * len(self._class_indices)

    def __iter__(self) -> Iterator[int]:
        """Creates an iterator that yields indices in a class balanced way.

        Returns:
            Iterator yielding dataset indices.
        """
        indices = []

        # Draw `num_samples` indices per class, then shuffle globally so
        # classes appear in random order within the epoch.
        for class_idx in self._class_indices:
            class_indices = self._class_indices[class_idx]
            sampled_indices = self._random_generator.choice(
                class_indices, size=self._num_samples, replace=self._replacement
            ).tolist()
            indices.extend(sampled_indices)

        self._random_generator.shuffle(indices)

        return iter(indices)

    @override
    def set_dataset(self, data_source: datasets.MapDataset):
        """Sets the dataset and builds class indices.

        Args:
            data_source: The dataset to sample from.

        Raises:
            ValueError: If the dataset doesn't have targets or if any class has
                fewer samples than `num_samples` and `replacement` is `False`.
        """
        super().set_dataset(data_source)
        self._make_indices()

    def _make_indices(self):
        """Builds indices for each class in the dataset."""
        self._class_indices.clear()

        for idx in tqdm(
            range(len(self.data_source)), desc="Fetching class indices for balanced sampler"
        ):
            # Each dataset item is expected to unpack as (data, target, metadata);
            # the target must be a scalar tensor identifying the class.
            _, target, _ = DataSample(*self.data_source[idx])
            if target is None:
                raise ValueError("The dataset must return non-empty targets.")
            if target.numel() != 1:
                raise ValueError("The dataset must return a single & scalar target.")

            class_idx = int(target.item())
            self._class_indices[class_idx].append(idx)

        if not self._replacement:
            # Without replacement every class must be able to supply
            # `num_samples` distinct indices.
            for class_idx, indices in self._class_indices.items():
                if len(indices) < self._num_samples:
                    raise ValueError(
                        f"Class {class_idx} has only {len(indices)} samples, "
                        f"which is less than the required {self._num_samples} samples."
                    )
Loading