-
Notifications
You must be signed in to change notification settings - Fork 2
/
dataloader.py
144 lines (131 loc) · 6.25 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch
import yaml
import numpy as np
import pickle
from PIL import Image, ImageDraw
from torch.utils.data import DataLoader
from torchvision.transforms import transforms as T
from tqdm import tqdm
from yacs.config import CfgNode
from maskrcnn_benchmark.config import cfg as maskrcnn_cfg
from utils.tupperware import tupperware
from utils.build_transforms import build_transforms
from data.benchmark_mir import CLDataLoader, get_permuted_mnist, get_split_mnist, get_miniimagenet, get_rotated_mnist, \
get_split_cifar10, get_split_cifar100, IIDDataset, FuzzyCLDataLoader
from utils.utils import DotDict, get_config_attr
import random
# Module-level dict, initialised empty.
# NOTE(review): nothing in the visible part of this file reads or fills it --
# confirm whether it is still needed.
_dataset = {
}
# Process-wide cache of the (train, val, test) split-MNIST loaders.
_smnist_loaders = None


def get_split_mnist_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
    """Return the split-MNIST loader entry selected by ``filter_obj[0]``.

    The (train, val, test) loaders are built on the first call and cached in
    the module-level ``_smnist_loaders``. Later calls reuse that cache, so the
    ``cfg`` and ``batch_size`` given on the first call stay in effect.

    Args:
        cfg: Config passed to ``get_split_mnist`` and queried for
            ``EXTERNAL.OCL.FUZZY`` (default 0); a truthy value selects
            ``FuzzyCLDataLoader`` instead of ``CLDataLoader``.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        filter_obj: Indexable whose first element is used to index the chosen
            loader (presumably a task id -- confirm against callers).
        batch_size: Batch size handed to the loader class when the cache is
            first built.
        *args: Accepted for interface compatibility; ignored.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        ``loader[filter_obj[0]]`` for the loader matching ``split``.

    Raises:
        ValueError: If ``split`` is not a known split name.
        TypeError: If ``filter_obj`` is ``None``.
    """
    global _smnist_loaders
    split_names = ('train', 'val', 'test')

    # Fail fast on bad arguments: previously an unknown split silently
    # returned None, and a missing filter_obj only failed (with an opaque
    # message) after the whole dataset had been loaded.
    if split not in split_names:
        raise ValueError(
            'unknown split {!r}; expected one of {}'.format(split, split_names))
    if filter_obj is None:
        raise TypeError(
            'filter_obj must be an indexable whose first element selects the loader entry')

    if not _smnist_loaders:
        fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
        loader_cls = FuzzyCLDataLoader if fuzzy else CLDataLoader
        data = get_split_mnist(DotDict(), cfg)
        # data is consumed in (train, val, test) order; only the first is
        # built with train=True.
        _smnist_loaders = tuple(
            loader_cls(elem, batch_size, train=is_train)
            for elem, is_train in zip(data, [True, False, False])
        )

    loaders = dict(zip(split_names, _smnist_loaders))
    return loaders[split][filter_obj[0]]
# Process-wide cache of the (train, val, test) rotated-MNIST loaders.
_rmnist_loaders = None


def get_rotated_mnist_dataloader(cfg, split='train', filter_obj=None, batch_size=128, task_num=10, *args, **kwargs):
    """Return the rotated-MNIST loader entry selected by ``filter_obj[0]``.

    The (train, val, test) loaders are built on the first call and cached in
    the module-level ``_rmnist_loaders``. Later calls reuse that cache, so the
    ``cfg`` and ``batch_size`` given on the first call stay in effect.

    Args:
        cfg: Config queried for ``EXTERNAL.OCL.FUZZY`` (default 0); a truthy
            value selects ``FuzzyCLDataLoader`` instead of ``CLDataLoader``.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        filter_obj: Indexable whose first element is used to index the chosen
            loader (presumably a task id -- confirm against callers).
        batch_size: Batch size handed to the loader class when the cache is
            first built.
        task_num: Accepted for interface compatibility; not used here.
        *args: Accepted for interface compatibility; ignored.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        ``loader[filter_obj[0]]`` for the loader matching ``split``.

    Raises:
        ValueError: If ``split`` is not a known split name.
        TypeError: If ``filter_obj`` is ``None``.
    """
    global _rmnist_loaders
    split_names = ('train', 'val', 'test')

    # Fail fast on bad arguments: previously an unknown split silently
    # returned None, and a missing filter_obj only failed (with an opaque
    # message) after the whole dataset had been loaded.
    if split not in split_names:
        raise ValueError(
            'unknown split {!r}; expected one of {}'.format(split, split_names))
    if filter_obj is None:
        raise TypeError(
            'filter_obj must be an indexable whose first element selects the loader entry')

    if not _rmnist_loaders:
        fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
        loader_cls = FuzzyCLDataLoader if fuzzy else CLDataLoader
        data = get_rotated_mnist(DotDict())
        # data is consumed in (train, val, test) order; only the first is
        # built with train=True.
        _rmnist_loaders = tuple(
            loader_cls(elem, batch_size, train=is_train)
            for elem, is_train in zip(data, [True, False, False])
        )

    loaders = dict(zip(split_names, _rmnist_loaders))
    return loaders[split][filter_obj[0]]
# Process-wide cache of the (train, val, test) permuted-MNIST loaders.
_pmnist_loaders = None


def get_permute_mnist_dataloader(cfg, split='train', filter_obj=None, batch_size=128, task_num=10, *args, **kwargs):
    """Return the permuted-MNIST loader entry selected by ``filter_obj[0]``.

    The (train, val, test) loaders are built on the first call and cached in
    the module-level ``_pmnist_loaders``. Later calls reuse that cache, so the
    ``cfg`` and ``batch_size`` given on the first call stay in effect.

    Args:
        cfg: Config queried for ``EXTERNAL.OCL.FUZZY`` (default 0); a truthy
            value selects ``FuzzyCLDataLoader`` instead of ``CLDataLoader``.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        filter_obj: Indexable whose first element is used to index the chosen
            loader (presumably a task id -- confirm against callers).
        batch_size: Batch size handed to the loader class when the cache is
            first built.
        task_num: Accepted for interface compatibility; not used here.
        *args: Accepted for interface compatibility; ignored.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        ``loader[filter_obj[0]]`` for the loader matching ``split``.

    Raises:
        ValueError: If ``split`` is not a known split name.
        TypeError: If ``filter_obj`` is ``None``.
    """
    global _pmnist_loaders
    split_names = ('train', 'val', 'test')

    # Fail fast on bad arguments: previously an unknown split silently
    # returned None, and a missing filter_obj only failed (with an opaque
    # message) after the whole dataset had been loaded.
    if split not in split_names:
        raise ValueError(
            'unknown split {!r}; expected one of {}'.format(split, split_names))
    if filter_obj is None:
        raise TypeError(
            'filter_obj must be an indexable whose first element selects the loader entry')

    if not _pmnist_loaders:
        fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
        loader_cls = FuzzyCLDataLoader if fuzzy else CLDataLoader
        data = get_permuted_mnist(DotDict())
        # data is consumed in (train, val, test) order; only the first is
        # built with train=True.
        _pmnist_loaders = tuple(
            loader_cls(elem, batch_size, train=is_train)
            for elem, is_train in zip(data, [True, False, False])
        )

    loaders = dict(zip(split_names, _pmnist_loaders))
    return loaders[split][filter_obj[0]]
# Process-wide cache of the (train, val, test) split-CIFAR10 loaders.
_cache_cifar = None


def get_split_cifar_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
    """Return the split-CIFAR10 loader entry selected by ``filter_obj[0]``.

    The (train, val, test) loaders are built on the first call and cached in
    the module-level ``_cache_cifar``. Later calls reuse that cache, so the
    ``cfg`` and ``batch_size`` given on the first call stay in effect.

    Args:
        cfg: Config passed to ``get_split_cifar10`` and queried for
            ``EXTERNAL.OCL.FUZZY`` (default 0); a truthy value selects
            ``FuzzyCLDataLoader`` instead of ``CLDataLoader``.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        filter_obj: Indexable whose first element is used to index the chosen
            loader (presumably a task id -- confirm against callers).
        batch_size: Batch size handed to the loader class when the cache is
            first built.
        *args: Accepted for interface compatibility; ignored.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        ``loader[filter_obj[0]]`` for the loader matching ``split``.

    Raises:
        ValueError: If ``split`` is not a known split name.
        TypeError: If ``filter_obj`` is ``None``.
    """
    global _cache_cifar
    split_names = ('train', 'val', 'test')

    # Fail fast on bad arguments: previously an unknown split silently
    # returned None, and a missing filter_obj only failed (with an opaque
    # message) after the whole dataset had been loaded.
    if split not in split_names:
        raise ValueError(
            'unknown split {!r}; expected one of {}'.format(split, split_names))
    if filter_obj is None:
        raise TypeError(
            'filter_obj must be an indexable whose first element selects the loader entry')

    if not _cache_cifar:
        fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
        loader_cls = FuzzyCLDataLoader if fuzzy else CLDataLoader
        data = get_split_cifar10(DotDict(), cfg)
        # data is consumed in (train, val, test) order; only the first is
        # built with train=True.
        _cache_cifar = tuple(
            loader_cls(elem, batch_size, train=is_train)
            for elem, is_train in zip(data, [True, False, False])
        )

    loaders = dict(zip(split_names, _cache_cifar))
    return loaders[split][filter_obj[0]]
# Process-wide cache of the (train, val, test) split-CIFAR100 loaders.
_cache_cifar100 = None


def get_split_cifar100_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
    """Return the split-CIFAR100 loader entry selected by ``filter_obj[0]``.

    The (train, val, test) loaders are built on the first call and cached in
    the module-level ``_cache_cifar100``. Later calls reuse that cache, so the
    ``cfg`` and ``batch_size`` given on the first call stay in effect.

    Args:
        cfg: Config passed to ``get_split_cifar100`` and queried for
            ``EXTERNAL.OCL.FUZZY`` (default 0); a truthy value selects
            ``FuzzyCLDataLoader`` instead of ``CLDataLoader``.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        filter_obj: Indexable whose first element is used to index the chosen
            loader (presumably a task id -- confirm against callers).
        batch_size: Batch size handed to the loader class when the cache is
            first built.
        *args: Accepted for interface compatibility; ignored.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        ``loader[filter_obj[0]]`` for the loader matching ``split``.

    Raises:
        ValueError: If ``split`` is not a known split name.
        TypeError: If ``filter_obj`` is ``None``.
    """
    global _cache_cifar100
    split_names = ('train', 'val', 'test')

    # Fail fast on bad arguments: previously an unknown split silently
    # returned None, and a missing filter_obj only failed (with an opaque
    # message) after the whole dataset had been loaded.
    if split not in split_names:
        raise ValueError(
            'unknown split {!r}; expected one of {}'.format(split, split_names))
    if filter_obj is None:
        raise TypeError(
            'filter_obj must be an indexable whose first element selects the loader entry')

    if not _cache_cifar100:
        fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
        loader_cls = FuzzyCLDataLoader if fuzzy else CLDataLoader
        data = get_split_cifar100(DotDict(), cfg)
        # data is consumed in (train, val, test) order; only the first is
        # built with train=True.
        _cache_cifar100 = tuple(
            loader_cls(elem, batch_size, train=is_train)
            for elem, is_train in zip(data, [True, False, False])
        )

    loaders = dict(zip(split_names, _cache_cifar100))
    return loaders[split][filter_obj[0]]
# Process-wide cache of the (train, val, test) mini-ImageNet loaders.
_cache_mini_imagenet = None


def get_split_mini_imagenet_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
    """Return the mini-ImageNet loader entry selected by ``filter_obj[0]``.

    The (train, val, test) loaders are built on the first call and cached in
    the module-level ``_cache_mini_imagenet``. Later calls reuse that cache,
    so the ``cfg`` and ``batch_size`` given on the first call stay in effect.

    Args:
        cfg: Config queried for ``EXTERNAL.OCL.FUZZY`` (default 0); a truthy
            value selects ``FuzzyCLDataLoader`` instead of ``CLDataLoader``.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        filter_obj: Indexable whose first element is used to index the chosen
            loader (presumably a task id -- confirm against callers).
        batch_size: Batch size handed to the loader class when the cache is
            first built.
        *args: Accepted for interface compatibility; ignored.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        ``loader[filter_obj[0]]`` for the loader matching ``split``.

    Raises:
        ValueError: If ``split`` is not a known split name.
        TypeError: If ``filter_obj`` is ``None``.
    """
    global _cache_mini_imagenet
    split_names = ('train', 'val', 'test')

    # Fail fast on bad arguments: previously an unknown split silently
    # returned None, and a missing filter_obj only failed (with an opaque
    # message) after the whole dataset had been loaded.
    if split not in split_names:
        raise ValueError(
            'unknown split {!r}; expected one of {}'.format(split, split_names))
    if filter_obj is None:
        raise TypeError(
            'filter_obj must be an indexable whose first element selects the loader entry')

    if not _cache_mini_imagenet:
        fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
        loader_cls = FuzzyCLDataLoader if fuzzy else CLDataLoader
        data = get_miniimagenet(DotDict())
        # data is consumed in (train, val, test) order; only the first is
        # built with train=True.
        _cache_mini_imagenet = tuple(
            loader_cls(elem, batch_size, train=is_train)
            for elem, is_train in zip(data, [True, False, False])
        )

    loaders = dict(zip(split_names, _cache_mini_imagenet))
    return loaders[split][filter_obj[0]]