Skip to content

Commit

Permalink
Refactor each type of datasets into more abstract classes
Browse files Browse the repository at this point in the history
  • Loading branch information
pczarnik committed May 31, 2024
1 parent 87101c9 commit 3088bd7
Show file tree
Hide file tree
Showing 3 changed files with 285 additions and 244 deletions.
153 changes: 8 additions & 145 deletions mnists/_emnist.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import os
from typing import Optional

import numpy as np
from .dataset import SplitDataset, ZippedDataset

from ._mnist import MNIST, TEMPORARY_DIR
from .utils import check_file_integrity, extract_from_zip


class EMNIST(MNIST):
class EMNIST(SplitDataset):
"""
EMNIST Dataset
https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
Expand Down Expand Up @@ -51,149 +47,16 @@ def __init__(
download: bool = True,
force_download: bool = False,
) -> None:
"""
Parameters
----------
target_dir : str, default='/tmp/emnist/'
Directory where zip exists or will be downloaded to (if `download` is True).
download : bool, default=True
If True and zip doesn't exist in `target_dir`, downloads zip to `target_dir`.
force_download : bool, default=False
If True, downloads zip to `target_dir`, even if it exists there.
"""
self.target_dir = (
os.path.join(TEMPORARY_DIR, type(self).__name__)
if target_dir is None
else target_dir
)
super().__init__(target_dir, download, force_download)

self.Balanced = self._create_split(Balanced)
self.ByClass = self._create_split(ByClass)
self.ByMerge = self._create_split(ByMerge)
self.Digits = self._create_split(Digits)
self.Letters = self._create_split(Letters)

if download or force_download:
self.download(force_download)

def _not_implemented(self):
raise NotImplementedError(
"Method is not implemented because EMNIST is a parent class and "
"is used only for downloading zip file. For accessing datasets "
"use any of the child classes: emnist.Balanced, emnist.ByClass, "
"emnist.ByMerge, emnist.Digits or emnist.Leters."
)

train_images = _not_implemented
train_labels = _not_implemented
test_images = _not_implemented
test_labels = _not_implemented

def _create_split(self, split_cls: type["_Split"]) -> type["_Split"]:
split_cls.default_base_dir = self.target_dir
split_cls.default_zip_filepath = os.path.join(self.target_dir, "gzip.zip")
split_cls.zip_md5 = self.resources["gzip"][1]
return split_cls


class _Split(MNIST):
default_base_dir = os.path.join(TEMPORARY_DIR, "emnist")
default_zip_filepath = os.path.join(default_base_dir, "gzip.zip")
zip_md5 = None

def __init__(
self,
target_dir: Optional[str] = None,
zip_filepath: Optional[str] = None,
unzip: bool = True,
force_unzip: bool = False,
load: bool = True,
transpose: bool = True,
) -> None:
"""
Parameters
----------
target_dir : str, default='/tmp/emnist/<split_name>/'
Directory where all files exist or will be unzipped to (if `unzip` is True).
zip_filepath : str, default='/tmp/emnist/gzip.zip'
Filepath to zip file containing all EMNIST split files.
unzip : bool, default=True
If True and files don't exist in `target_dir`, unzips all files to `target_dir`.
force_unzip : bool, default=False
If True, unzips all files to `target_dir`, even if they exist there.
load : bool, default=True
If True, loads data from files in `target_dir`.
transpose : bool, default=True
If True, transposes train and test images.
"""

self.target_dir = (
os.path.join(self.default_base_dir, type(self).__name__)
if target_dir is None
else target_dir
)

self.zip_filepath = (
self.default_zip_filepath if zip_filepath is None else zip_filepath
)

self._train_images: Optional[np.ndarray] = None
self._train_labels: Optional[np.ndarray] = None
self._test_images: Optional[np.ndarray] = None
self._test_labels: Optional[np.ndarray] = None

if unzip or force_unzip:
self.unzip_files(force_unzip)

if load:
self.load(transpose)

def unzip_files(self, force: bool = False) -> None:
"""
Unzip files from `zip_filepath` to `target_dir`.
Parameters
----------
force : bool=False
If True, unzips all files even if they exist.
"""

os.makedirs(self.target_dir, exist_ok=True)
if not check_file_integrity(self.zip_filepath, self.zip_md5):
raise RuntimeError(
f"Zip file '{self.zip_filepath}' doesn't exists or its MD5"
"checksum is not valid. "
"Use EMNIST(download=True) or emnist.download() to download it"
)

for filename, md5 in self.resources.values():
filepath = os.path.join(self.target_dir, filename)

if not force and check_file_integrity(filepath, md5):
continue

extract_from_zip(self.zip_filepath, filename, self.target_dir)

def load(self, transpose: bool = True) -> None:
"""
Load data from files in `target_dir` and transpose images (by default).
Parameters
----------
transpose : bool=True
If True, transposes train and test images.
"""

super().load()
if transpose:
self._transpose_images()

def _transpose_images(self) -> None:
self._train_images = np.moveaxis(self._train_images, -2, -1)
self._test_images = np.moveaxis(self._test_images, -2, -1)


class Balanced(_Split):
class Balanced(ZippedDataset):
"""
EMNIST Balanced
https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
Expand Down Expand Up @@ -245,7 +108,7 @@ class Balanced(_Split):
}


class ByClass(_Split):
class ByClass(ZippedDataset):
"""
EMNIST ByClass
https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
Expand Down Expand Up @@ -297,7 +160,7 @@ class ByClass(_Split):
}


class ByMerge(_Split):
class ByMerge(ZippedDataset):
"""
EMNIST ByMerge
https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
Expand Down Expand Up @@ -349,7 +212,7 @@ class ByMerge(_Split):
}


class Digits(_Split):
class Digits(ZippedDataset):
"""
EMNIST Digits
https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
Expand Down Expand Up @@ -401,7 +264,7 @@ class Digits(_Split):
}


class Letters(_Split):
class Letters(ZippedDataset):
"""
EMNIST Letters
https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
Expand Down
103 changes: 4 additions & 99 deletions mnists/_mnist.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
import os
import tempfile
from typing import Optional
from .dataset import IdxDataset

import numpy as np

from .utils import check_file_integrity, download_file, read_idx_file

TEMPORARY_DIR = os.path.join(tempfile.gettempdir(), "mnists")


class MNIST:
class MNIST(IdxDataset):
"""
MNIST Dataset
http://yann.lecun.com/exdb/mnist
Expand Down Expand Up @@ -98,95 +90,8 @@ class MNIST:
),
}

def __init__(
self,
target_dir: Optional[str] = None,
download: bool = True,
force_download: bool = False,
load: bool = True,
) -> None:
"""
Parameters
----------
target_dir : str, default='/tmp/<dataset_name>/'
Directory where all files exist or will be downloaded to (if `download` is True).
download : bool, default=True
If True and files don't exist in `target_dir`, downloads all files to `target_dir`.
force_download : bool, default=False
If True, downloads all files to `target_dir`, even if they exist there.
load : bool, default=True
If True, loads data from files in `target_dir`.
"""

self.target_dir = (
os.path.join(TEMPORARY_DIR, type(self).__name__)
if target_dir is None
else target_dir
)

self._train_images: Optional[np.ndarray] = None
self._train_labels: Optional[np.ndarray] = None
self._test_images: Optional[np.ndarray] = None
self._test_labels: Optional[np.ndarray] = None

if download or force_download:
self.download(force_download)

if load:
self.load()

def train_images(self) -> Optional[np.ndarray]:
return self._train_images

def train_labels(self) -> Optional[np.ndarray]:
return self._train_labels

def test_images(self) -> Optional[np.ndarray]:
return self._test_images

def test_labels(self) -> Optional[np.ndarray]:
return self._test_labels

def download(self, force: bool = False) -> None:
"""
Download files from mirrors and save to `target_dir`.
Parameters
----------
force : bool=False
If True, downloads all files even if they exist.
"""

os.makedirs(self.target_dir, exist_ok=True)

for filename, md5 in self.resources.values():
filepath = os.path.join(self.target_dir, filename)

if not force and check_file_integrity(filepath, md5):
continue

download_file(self.mirrors, filename, filepath)

def load(self) -> None:
"""
Load data from files in `target_dir`.
"""

for key, (filename, md5) in self.resources.items():
filepath = os.path.join(self.target_dir, filename)

if not check_file_integrity(filepath, md5):
raise RuntimeError(
f"Dataset '{key}' not found in '{filepath}' or MD5 "
"checksum is not valid. "
"Use download=True or .download() to download it"
)

data = read_idx_file(filepath)
setattr(self, f"_{key}", data)


class FashionMNIST(MNIST):
class FashionMNIST(IdxDataset):
"""
Fashion-MNIST Dataset
https://github.com/zalandoresearch/fashion-mnist
Expand Down Expand Up @@ -260,7 +165,7 @@ class FashionMNIST(MNIST):
}


class KMNIST(MNIST):
class KMNIST(IdxDataset):
"""
Kuzushiji-MNIST Dataset
https://github.com/rois-codh/kmnist
Expand Down
Loading

0 comments on commit 3088bd7

Please sign in to comment.