From 3088bd74b14cf2a669aca67cd4f853c9a1be343e Mon Sep 17 00:00:00 2001 From: Piotr Czarnik Date: Fri, 31 May 2024 17:23:18 +0200 Subject: [PATCH] Refactor each type of datasets into more abstract classes --- mnists/_emnist.py | 153 ++------------------------ mnists/_mnist.py | 103 +---------------- mnists/dataset.py | 273 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 285 insertions(+), 244 deletions(-) create mode 100644 mnists/dataset.py diff --git a/mnists/_emnist.py b/mnists/_emnist.py index aa8ed12..c3e1ff5 100644 --- a/mnists/_emnist.py +++ b/mnists/_emnist.py @@ -1,13 +1,9 @@ -import os from typing import Optional -import numpy as np +from .dataset import SplitDataset, ZippedDataset -from ._mnist import MNIST, TEMPORARY_DIR -from .utils import check_file_integrity, extract_from_zip - -class EMNIST(MNIST): +class EMNIST(SplitDataset): """ EMNIST Dataset https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist @@ -51,21 +47,7 @@ def __init__( download: bool = True, force_download: bool = False, ) -> None: - """ - Parameters - ---------- - target_dir : str, default='/tmp/emnist/' - Directory where zip exists or will be downloaded to (if `download` is True). - download : bool, default=True - If True and zip doesn't exist in `target_dir`, downloads zip to `target_dir`. - force_download : bool, default=False - If True, downloads zip to `target_dir`, even if it exists there. - """ - self.target_dir = ( - os.path.join(TEMPORARY_DIR, type(self).__name__) - if target_dir is None - else target_dir - ) + super().__init__(target_dir, download, force_download) self.Balanced = self._create_split(Balanced) self.ByClass = self._create_split(ByClass) @@ -73,127 +55,8 @@ def __init__( self.Digits = self._create_split(Digits) self.Letters = self._create_split(Letters) - if download or force_download: - self.download(force_download) - - def _not_implemented(self): - raise NotImplementedError( - "Method is not implemented because EMNIST is a parent class and " - "is used only for downloading zip file. For accessing datasets " - "use any of the child classes: emnist.Balanced, emnist.ByClass, " - "emnist.ByMerge, emnist.Digits or emnist.Leters." - ) - - train_images = _not_implemented - train_labels = _not_implemented - test_images = _not_implemented - test_labels = _not_implemented - - def _create_split(self, split_cls: type["_Split"]) -> type["_Split"]: - split_cls.default_base_dir = self.target_dir - split_cls.default_zip_filepath = os.path.join(self.target_dir, "gzip.zip") - split_cls.zip_md5 = self.resources["gzip"][1] - return split_cls - -class _Split(MNIST): - default_base_dir = os.path.join(TEMPORARY_DIR, "emnist") - default_zip_filepath = os.path.join(default_base_dir, "gzip.zip") - zip_md5 = None - - def __init__( - self, - target_dir: Optional[str] = None, - zip_filepath: Optional[str] = None, - unzip: bool = True, - force_unzip: bool = False, - load: bool = True, - transpose: bool = True, - ) -> None: - """ - Parameters - ---------- - target_dir : str, default='/tmp/emnist//' - Directory where all files exist or will be unzipped to (if `unzip` is True). - zip_filepath : str, default='/tmp/emnist/gzip.zip' - Filepath to zip file containing all EMNIST split files. - unzip : bool, default=True - If True and files don't exist in `target_dir`, unzips all files to `target_dir`. - force_unzip : bool, default=False - If True, unzips all files to `target_dir`, even if they exist there. - load : bool, default=True - If True, loads data from files in `target_dir`. - transpose : bool, default=True - If True, transposes train and test images. - """ - - self.target_dir = ( - os.path.join(self.default_base_dir, type(self).__name__) - if target_dir is None - else target_dir - ) - - self.zip_filepath = ( - self.default_zip_filepath if zip_filepath is None else zip_filepath - ) - - self._train_images: Optional[np.ndarray] = None - self._train_labels: Optional[np.ndarray] = None - self._test_images: Optional[np.ndarray] = None - self._test_labels: Optional[np.ndarray] = None - - if unzip or force_unzip: - self.unzip_files(force_unzip) - - if load: - self.load(transpose) - - def unzip_files(self, force: bool = False) -> None: - """ - Unzip files from `zip_filepath` to `target_dir`. - - Parameters - ---------- - force : bool=False - If True, unzips all files even if they exist. - """ - - os.makedirs(self.target_dir, exist_ok=True) - if not check_file_integrity(self.zip_filepath, self.zip_md5): - raise RuntimeError( - f"Zip file '{self.zip_filepath}' doesn't exists or its MD5" - "checksum is not valid. " - "Use EMNIST(download=True) or emnist.download() to download it" - ) - - for filename, md5 in self.resources.values(): - filepath = os.path.join(self.target_dir, filename) - - if not force and check_file_integrity(filepath, md5): - continue - - extract_from_zip(self.zip_filepath, filename, self.target_dir) - - def load(self, transpose: bool = True) -> None: - """ - Load data from files in `target_dir` and transpose images (by default). - - Parameters - ---------- - transpose : bool=True - If True, transposes train and test images. - """ - - super().load() - if transpose: - self._transpose_images() - - def _transpose_images(self) -> None: - self._train_images = np.moveaxis(self._train_images, -2, -1) - self._test_images = np.moveaxis(self._test_images, -2, -1) - - -class Balanced(_Split): +class Balanced(ZippedDataset): """ EMNIST Balanced https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist @@ -245,7 +108,7 @@ class Balanced(_Split): } -class ByClass(_Split): +class ByClass(ZippedDataset): """ EMNIST ByClass https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist @@ -297,7 +160,7 @@ class ByClass(_Split): } -class ByMerge(_Split): +class ByMerge(ZippedDataset): """ EMNIST ByMerge https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist @@ -349,7 +212,7 @@ class ByMerge(_Split): } -class Digits(_Split): +class Digits(ZippedDataset): """ EMNIST Digits https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist @@ -401,7 +264,7 @@ class Digits(_Split): } -class Letters(_Split): +class Letters(ZippedDataset): """ EMNIST Letters https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist diff --git a/mnists/_mnist.py b/mnists/_mnist.py index bc2b291..bd6ff67 100644 --- a/mnists/_mnist.py +++ b/mnists/_mnist.py @@ -1,15 +1,7 @@ -import os -import tempfile -from typing import Optional +from .dataset import IdxDataset -import numpy as np -from .utils import check_file_integrity, download_file, read_idx_file - -TEMPORARY_DIR = os.path.join(tempfile.gettempdir(), "mnists") - - -class MNIST: +class MNIST(IdxDataset): """ MNIST Dataset http://yann.lecun.com/exdb/mnist @@ -98,95 +90,8 @@ class MNIST: ), } - def __init__( - self, - target_dir: Optional[str] = None, - download: bool = True, - force_download: bool = False, - load: bool = True, - ) -> None: - """ - Parameters - ---------- - target_dir : str, default='/tmp//' - Directory where all files exist or will be downloaded to (if `download` is True). - download : bool, default=True - If True and files don't exist in `target_dir`, downloads all files to `target_dir`. - force_download : bool, default=False - If True, downloads all files to `target_dir`, even if they exist there. - load : bool, default=True - If True, loads data from files in `target_dir`. - """ - - self.target_dir = ( - os.path.join(TEMPORARY_DIR, type(self).__name__) - if target_dir is None - else target_dir - ) - - self._train_images: Optional[np.ndarray] = None - self._train_labels: Optional[np.ndarray] = None - self._test_images: Optional[np.ndarray] = None - self._test_labels: Optional[np.ndarray] = None - - if download or force_download: - self.download(force_download) - - if load: - self.load() - - def train_images(self) -> Optional[np.ndarray]: - return self._train_images - - def train_labels(self) -> Optional[np.ndarray]: - return self._train_labels - - def test_images(self) -> Optional[np.ndarray]: - return self._test_images - - def test_labels(self) -> Optional[np.ndarray]: - return self._test_labels - - def download(self, force: bool = False) -> None: - """ - Download files from mirrors and save to `target_dir`. - - Parameters - ---------- - force : bool=False - If True, downloads all files even if they exist. - """ - - os.makedirs(self.target_dir, exist_ok=True) - - for filename, md5 in self.resources.values(): - filepath = os.path.join(self.target_dir, filename) - - if not force and check_file_integrity(filepath, md5): - continue - - download_file(self.mirrors, filename, filepath) - - def load(self) -> None: - """ - Load data from files in `target_dir`. - """ - - for key, (filename, md5) in self.resources.items(): - filepath = os.path.join(self.target_dir, filename) - - if not check_file_integrity(filepath, md5): - raise RuntimeError( - f"Dataset '{key}' not found in '{filepath}' or MD5 " - "checksum is not valid. " - "Use download=True or .download() to download it" - ) - - data = read_idx_file(filepath) - setattr(self, f"_{key}", data) - -class FashionMNIST(MNIST): +class FashionMNIST(IdxDataset): """ Fashion-MNIST Dataset https://github.com/zalandoresearch/fashion-mnist @@ -260,7 +165,7 @@ class FashionMNIST(MNIST): } -class KMNIST(MNIST): +class KMNIST(IdxDataset): """ Kuzushiji-MNIST Dataset https://github.com/rois-codh/kmnist diff --git a/mnists/dataset.py b/mnists/dataset.py new file mode 100644 index 0000000..46a5c48 --- /dev/null +++ b/mnists/dataset.py @@ -0,0 +1,273 @@ +import os +import tempfile +from typing import Optional + +import numpy as np + +from .utils import check_file_integrity, download_file, extract_from_zip, read_idx_file + +TEMPORARY_DIR = os.path.join(tempfile.gettempdir(), "mnists") + + +class Dataset: + mirrors = [] + resources = {} + + def __init__( + self, + target_dir: Optional[str] = None, + download: bool = True, + force_download: bool = False, + ) -> None: + """ + Parameters + ---------- + target_dir : str, default='/tmp//' + Directory where all files exist or will be downloaded to (if `download` is True). + download : bool, default=True + If True and files don't exist in `target_dir`, downloads all files to `target_dir`. + force_download : bool, default=False + If True, downloads all files to `target_dir`, even if they exist there. + """ + + self.target_dir = ( + os.path.join(TEMPORARY_DIR, type(self).__name__) + if target_dir is None + else target_dir + ) + + if download or force_download: + self.download(force_download) + + def download(self, force: bool = False) -> None: + """ + Download files from mirrors and save to `target_dir`. + + Parameters + ---------- + force : bool=False + If True, downloads all files even if they exist. + """ + + os.makedirs(self.target_dir, exist_ok=True) + + for filename, md5 in self.resources.values(): + filepath = os.path.join(self.target_dir, filename) + + if not force and check_file_integrity(filepath, md5): + continue + + download_file(self.mirrors, filename, filepath) + + +class IdxDataset(Dataset): + def __init__( + self, + target_dir: Optional[str] = None, + download: bool = True, + force_download: bool = False, + load: bool = True, + ) -> None: + """ + Parameters + ---------- + target_dir : str, default='/tmp//' + Directory where all files exist or will be downloaded to (if `download` is True). + download : bool, default=True + If True and files don't exist in `target_dir`, downloads all files to `target_dir`. + force_download : bool, default=False + If True, downloads all files to `target_dir`, even if they exist there. + load : bool, default=True + If True, loads data from files in `target_dir`. + """ + + self.target_dir = ( + os.path.join(TEMPORARY_DIR, type(self).__name__) + if target_dir is None + else target_dir + ) + + self._train_images: Optional[np.ndarray] = None + self._train_labels: Optional[np.ndarray] = None + self._test_images: Optional[np.ndarray] = None + self._test_labels: Optional[np.ndarray] = None + + if download or force_download: + self.download(force_download) + + if load: + self.load() + + def train_images(self) -> np.ndarray: + """ + Return train_images numpy array. + + Returns + ------- + np.ndarray + """ + if self._train_images is None: + self._raise_dataset_not_loaded() + return self._train_images + + def train_labels(self) -> np.ndarray: + """ + Return train_labels numpy array. + + Returns + ------- + np.ndarray + """ + if self._train_labels is None: + self._raise_dataset_not_loaded() + return self._train_labels + + def test_images(self) -> np.ndarray: + """ + Return test_images numpy array. + + Returns + ------- + np.ndarray + """ + if self._test_images is None: + self._raise_dataset_not_loaded() + return self._test_images + + def test_labels(self) -> np.ndarray: + """ + Return test_labels numpy array. + + Returns + ------- + np.ndarray + """ + if self._test_labels is None: + self._raise_dataset_not_loaded() + return self._test_labels + + def _raise_dataset_not_loaded(self): + raise RuntimeError( + "Dataset wasn't loaded. You need to run .load() or create new " + "object with load=True" + ) + + def load(self, transpose=False) -> None: + """ + Load data from files in `target_dir`. + + Parameters + ---------- + transpose : bool=False + If True, transposes train and test images. + """ + + for key, (filename, md5) in self.resources.items(): + filepath = os.path.join(self.target_dir, filename) + + if not check_file_integrity(filepath, md5): + raise RuntimeError( + f"Dataset '{key}' not found in '{filepath}' or MD5 " + "checksum is not valid. " + "Use download=True or .download() to download it" + ) + + data = read_idx_file(filepath) + setattr(self, f"_{key}", data) + + if transpose: + self._transpose_images() + + def _transpose_images(self) -> None: + self._train_images = np.moveaxis(self._train_images, -2, -1) + self._test_images = np.moveaxis(self._test_images, -2, -1) + + +class SplitDataset(Dataset): + resources = {"gzip": (None, None)} + + def _create_split(self, split_cls: type["ZippedDataset"]) -> type["ZippedDataset"]: + split_cls.default_base_dir = self.target_dir + file, md5 = self.resources["gzip"] + split_cls.default_zip_filepath = os.path.join(self.target_dir, file) + split_cls.zip_md5 = md5 + return split_cls + + +class ZippedDataset(IdxDataset): + default_base_dir = None + default_zip_filepath = None + zip_md5 = None + + def __init__( + self, + target_dir: Optional[str] = None, + zip_filepath: Optional[str] = None, + unzip: bool = True, + force_unzip: bool = False, + load: bool = True, + transpose: bool = True, + ) -> None: + """ + Parameters + ---------- + target_dir : str, default='/tmp/emnist//' + Directory where all files exist or will be unzipped to (if `unzip` is True). + zip_filepath : str, default='/tmp/emnist/gzip.zip' + Filepath to zip file containing all EMNIST split files. + unzip : bool, default=True + If True and files don't exist in `target_dir`, unzips all files to `target_dir`. + force_unzip : bool, default=False + If True, unzips all files to `target_dir`, even if they exist there. + load : bool, default=True + If True, loads data from files in `target_dir`. + transpose : bool, default=True + If True, transposes train and test images. + """ + + self.target_dir = ( + os.path.join(self.default_base_dir, type(self).__name__) + if target_dir is None + else target_dir + ) + + self.zip_filepath = ( + self.default_zip_filepath if zip_filepath is None else zip_filepath + ) + + self._train_images: Optional[np.ndarray] = None + self._train_labels: Optional[np.ndarray] = None + self._test_images: Optional[np.ndarray] = None + self._test_labels: Optional[np.ndarray] = None + + if unzip or force_unzip: + self.unzip_files(force_unzip) + + if load: + self.load(transpose) + + def unzip_files(self, force: bool = False) -> None: + """ + Unzip files from `zip_filepath` to `target_dir`. + + Parameters + ---------- + force : bool=False + If True, unzips all files even if they exist. + """ + + os.makedirs(self.target_dir, exist_ok=True) + if not check_file_integrity(self.zip_filepath, self.zip_md5): + raise RuntimeError( + f"Zip file '{self.zip_filepath}' doesn't exists or its MD5" + "checksum is not valid. " + "Use EMNIST(download=True) or emnist.download() to download it" + ) + + for filename, md5 in self.resources.values(): + filepath = os.path.join(self.target_dir, filename) + + if not force and check_file_integrity(filepath, md5): + continue + + extract_from_zip(self.zip_filepath, filename, self.target_dir)