forked from cure-lab/SmoothNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path__init__.py
59 lines (47 loc) · 2.06 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import importlib
from abc import ABC, abstractmethod
import torch.utils.data as data
class BaseDataset(data.Dataset, ABC):
    """Abstract base class (ABC) for the datasets of this package.

    To create a concrete dataset, subclass this class and implement:
        -- <__init__>:    initialize the dataset; first call BaseDataset.__init__(self, cfg).
        -- <__len__>:     return the size of the dataset.
        -- <__getitem__>: return one data point.

    The class cannot be instantiated directly because <__len__> and
    <__getitem__> are abstract.
    """

    def __init__(self, cfg):
        """Store the configuration object on the instance.

        Parameters:
            cfg -- experiment configuration; it is kept unchanged as ``self.cfg``
                   and is not inspected by this base class.
        """
        self.cfg = cfg

    @abstractmethod
    def __len__(self) -> int:
        """Return the total number of data points in the dataset."""
        # Placeholder value only; concrete subclasses must override this method.
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- integer index of the requested data point

        Returns:
            Subclasses are expected to return a dictionary of data with their
            names; it usually contains the data itself and its metadata
            information. The base implementation returns nothing.
        """
        pass
def find_dataset_using_name(dataset_name):
    """Import "lib/dataset/[dataset_name]_dataset.py" and return its dataset class.

    The module ``lib.dataset.<dataset_name>_dataset`` is imported and its
    namespace is searched for a subclass of BaseDataset whose name, compared
    case-insensitively, equals ``<dataset_name without underscores>dataset``.

    Parameters:
        dataset_name (str) -- short dataset identifier used to build both the
                              module name and the expected class name.

    Returns:
        The matching class object (not an instance). If several names in the
        module match, the last one encountered is returned.

    Raises:
        ImportError: propagated from importlib when the module cannot be imported.
        NotImplementedError: when the module holds no matching BaseDataset subclass.
    """
    dataset_filename = "lib.dataset." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    # Lower-case the target once instead of on every loop iteration.
    wanted = target_dataset_name.lower()

    dataset = None
    for name, cls in datasetlib.__dict__.items():
        if name.lower() != wanted:
            continue
        # issubclass() raises TypeError when its first argument is not a class,
        # so a same-named function or constant must be filtered out first.
        if isinstance(cls, type) and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError(
            "In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase."
            % (dataset_filename, target_dataset_name))

    return dataset