Skip to content

Commit

Permalink
added missing docstrings & formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig committed Nov 22, 2024
1 parent bd4767f commit c8545a7
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 30 deletions.
5 changes: 4 additions & 1 deletion src/eva/core/data/dataloaders/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,14 @@ class DataLoader:
prefetch_factor: int | None = 2
"""Number of batches loaded in advance by each worker."""

def __call__(self, dataset: datasets.TorchDataset, sampler: samplers.Sampler | None = None) -> dataloader.DataLoader:
def __call__(
self, dataset: datasets.TorchDataset, sampler: samplers.Sampler | None = None
) -> dataloader.DataLoader:
"""Returns the dataloader on the provided dataset.
Args:
dataset: dataset from which to load the data.
sampler: defines the strategy to draw samples from the dataset.
"""
return dataloader.DataLoader(
dataset=dataset,
Expand Down
24 changes: 17 additions & 7 deletions src/eva/core/data/datamodules/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
Args:
datasets: The desired datasets.
dataloaders: The desired dataloaders.
samplers: The desired samplers for the dataloaders.
"""
super().__init__()

Expand Down Expand Up @@ -74,44 +75,53 @@ def train_dataloader(self) -> TRAIN_DATALOADERS:
)
if isinstance(self.datasets.train, list) and len(self.datasets.train) > 1:
raise ValueError("Train dataloader can not be initialized with multiple datasets.")

return self._initialize_dataloaders(self.dataloaders.train, self.datasets.train, self.samplers.train)[0]

return self._initialize_dataloaders(
self.dataloaders.train, self.datasets.train, self.samplers.train
)[0]

@override
def val_dataloader(self) -> EVAL_DATALOADERS:
if self.datasets.val is None:
raise ValueError(
"Validation dataloader can not be initialized as `self.datasets.val` is `None`."
)
return self._initialize_dataloaders(self.dataloaders.val, self.datasets.val, self.samplers.val)
return self._initialize_dataloaders(
self.dataloaders.val, self.datasets.val, self.samplers.val
)

@override
def test_dataloader(self) -> EVAL_DATALOADERS:
if self.datasets.test is None:
raise ValueError(
"Test dataloader can not be initialized as `self.datasets.test` is `None`."
)
return self._initialize_dataloaders(self.dataloaders.test, self.datasets.test, self.samplers.test)
return self._initialize_dataloaders(
self.dataloaders.test, self.datasets.test, self.samplers.test
)

@override
def predict_dataloader(self) -> EVAL_DATALOADERS:
if self.datasets.predict is None:
raise ValueError(
"Predict dataloader can not be initialized as `self.datasets.predict` is `None`."
)
return self._initialize_dataloaders(self.dataloaders.predict, self.datasets.predict, self.samplers.predict)
return self._initialize_dataloaders(
self.dataloaders.predict, self.datasets.predict, self.samplers.predict
)

def _initialize_dataloaders(
self,
dataloader: dataloaders_lib.DataLoader,
datasets: datasets_lib.TorchDataset | List[datasets_lib.TorchDataset],
sampler: samplers_lib.Sampler | None = None
sampler: samplers_lib.Sampler | None = None,
) -> EVAL_DATALOADERS:
"""Initializes dataloaders from a given set of dataset.
Args:
dataloader: The dataloader to apply to the provided datasets.
datasets: The desired dataset(s) to allocate dataloader(s).
sampler: The sampler to use for the dataloader.
Returns:
A list with the dataloaders of the provided dataset(s).
Expand All @@ -121,6 +131,6 @@ def _initialize_dataloaders(
dataloaders = []
for dataset in datasets:
if sampler and isinstance(sampler, samplers_lib.SamplerWithDataSource):
sampler.set_dataset(dataset)
sampler.set_dataset(dataset) # type: ignore
dataloaders.append(dataloader(dataset, sampler=sampler))
return dataloaders
1 change: 1 addition & 0 deletions src/eva/core/data/datamodules/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class DataloadersSchema:
predict: dataloaders.DataLoader = dataclasses.field(default_factory=dataloaders.DataLoader)
"""Predict dataloader."""


@dataclasses.dataclass(frozen=True)
class SamplersSchema:
"""Samplers schema used in DataModule."""
Expand Down
6 changes: 3 additions & 3 deletions src/eva/core/data/datasets/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Base dataset class."""

from eva.core.data.datasets import dataset
import abc

from eva.core.data.datasets import dataset


class Dataset(dataset.TorchDataset):
"""Base dataset class."""
Expand Down Expand Up @@ -53,6 +54,7 @@ def teardown(self) -> None:
called from every process (i.e. GPU) across all the nodes in DDP.
"""


class MapDataset(Dataset):
"""Abstract base class for all map-style datasets."""

Expand All @@ -76,5 +78,3 @@ def __len__(self) -> int:
int: Length of the dataset
"""
raise NotImplementedError


4 changes: 2 additions & 2 deletions src/eva/core/data/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Data samplers API."""

from eva.core.data.samplers.sampler import Sampler, SamplerWithDataSource
from eva.core.data.samplers.random import RandomSampler
from eva.core.data.samplers.sampler import Sampler, SamplerWithDataSource

__all__ = ["Sampler", "RandomSampler"]
__all__ = ["Sampler", "SamplerWithDataSource", "RandomSampler"]
37 changes: 23 additions & 14 deletions src/eva/core/data/samplers/random.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,40 @@
"""Random sampler for data loading."""

from eva.core.data.samplers.sampler import SamplerWithDataSource
from eva.core.data import datasets
from typing import Optional

from torch.utils import data
from typing_extensions import override

from eva.core.data import datasets
from eva.core.data.samplers.sampler import SamplerWithDataSource

class RandomSampler(data.RandomSampler, SamplerWithDataSource[int]):
r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
If with replacement, then user can specify :attr:`num_samples` to draw.

Args:
data_source (Dataset): dataset to sample from
replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
num_samples (int): number of samples to draw, default=`len(dataset)`.
generator (Generator): Generator used in sampling.
"""
class RandomSampler(data.RandomSampler, SamplerWithDataSource[int]):
"""Samples elements randomly."""

data_source: datasets.MapDataset # type: ignore
replacement: bool

def __init__(self, replacement: bool = False, num_samples: Optional[int] = None, generator=None) -> None:
def __init__(
    self, replacement: bool = False, num_samples: Optional[int] = None, generator=None
) -> None:
    """Initializes the random sampler.

    The dataset to sample from is not an argument of this constructor; it
    is provided afterwards through `set_dataset`.

    Args:
        replacement: samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples: number of samples to draw, default=`len(dataset)`.
        generator: Generator used in sampling.
    """
    # Only the sampling configuration is stored here; the values are
    # forwarded to the parent initializer when `set_dataset` is called.
    self.replacement = replacement
    self._num_samples = num_samples
    self.generator = generator

@override
def set_dataset(self, data_source: datasets.MapDataset) -> None:
    """Binds the sampler to the dataset it should draw indices from.

    Runs the `torch.utils.data.RandomSampler` initializer with the given
    dataset and the settings stored by `__init__`.

    Args:
        data_source: The dataset to sample from.
    """
    # Forward the stored `_num_samples` instead of the `num_samples`
    # property: the property falls back to `len(self.data_source)` when the
    # value is `None`, which raises `AttributeError` before any dataset has
    # been bound and would otherwise pin the length of a previously bound
    # dataset onto the new one.
    super().__init__(
        data_source,
        replacement=self.replacement,
        num_samples=self._num_samples,
        generator=self.generator,
    )
10 changes: 8 additions & 2 deletions src/eva/core/data/samplers/sampler.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""Core data sampler."""

from typing import TypeVar, Generic
from typing import Generic, TypeVar

from torch.utils import data

from eva.core.data import datasets

Sampler = data.Sampler
"""Core abstract data sampler class."""

T_co = TypeVar('T_co', covariant=True)
T_co = TypeVar("T_co", covariant=True)


class SamplerWithDataSource(Sampler, Generic[T_co]):
"""A sampler base class that enables to specify the data source after initialization."""
Expand All @@ -17,6 +20,9 @@ class SamplerWithDataSource(Sampler, Generic[T_co]):
def set_dataset(self, data_source: datasets.MapDataset) -> None:
"""Sets the dataset to sample from.
This is not done in the constructor because the dataset might not be
available at that time.
Args:
data_source: The dataset to sample from.
"""
Expand Down
2 changes: 1 addition & 1 deletion src/eva/vision/data/datasets/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ def filename(self, index: int) -> str:
Returns:
The filename of the `index`'th data sample.
"""
"""

0 comments on commit c8545a7

Please sign in to comment.