-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a814c8c
commit 2a54fbf
Showing
9 changed files
with
808 additions
and
343 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# reference from official zeroQ | ||
# https://github.com/amirgholami/ZeroQ/blob/ba37f793dbcb9f966b58f6b8d1e9de3c34a11b8c/classification/utils/data_utils.py | ||
# @file Different utility functions | ||
# Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami | ||
# All rights reserved. | ||
# This file is part of ZeroQ repository. | ||
# | ||
# ZeroQ is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# (at your option) any later version. | ||
# | ||
# ZeroQ is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU General Public License | ||
# along with ZeroQ repository. If not, see <http://www.gnu.org/licenses/>. | ||
#* | ||
|
||
from torch.utils.data import Dataset, DataLoader | ||
from torchvision import datasets, transforms | ||
import torch | ||
|
||
|
||
class UniformDataset(Dataset): | ||
""" | ||
get random uniform samples with mean 0 and variance 1 | ||
""" | ||
def __init__(self, length, size, transform): | ||
self.length = length | ||
self.transform = transform | ||
self.size = size | ||
|
||
def __len__(self): | ||
return self.length | ||
|
||
def __getitem__(self, idx): | ||
# var[U(-128, 127)] = (127 - (-128))**2 / 12 = 5418.75 | ||
sample = (torch.randint(high=255, size=self.size).float() - | ||
127.5) / 5418.75 | ||
return sample | ||
|
||
|
||
def getRandomData(dataset='cifar10', batch_size=512, for_inception=False): | ||
""" | ||
get random sample dataloader | ||
dataset: name of the dataset | ||
batch_size: the batch size of random data | ||
for_inception: whether the data is for Inception because inception has input size 299 rather than 224 | ||
""" | ||
if dataset == 'cifar10': | ||
size = (3, 32, 32) | ||
num_data = 10000 | ||
elif dataset == 'imagenet': | ||
num_data = 10000 | ||
if not for_inception: | ||
size = (3, 224, 224) | ||
else: | ||
size = (3, 299, 299) | ||
else: | ||
raise NotImplementedError | ||
dataset = UniformDataset(length=10000, size=size, transform=None) | ||
data_loader = DataLoader(dataset, | ||
batch_size=batch_size, | ||
shuffle=False, | ||
num_workers=32) | ||
return data_loader | ||
|
||
|
||
def getTestData(dataset='imagenet', | ||
batch_size=1024, | ||
path='data/imagenet', | ||
for_inception=False): | ||
""" | ||
Get dataloader of testset | ||
dataset: name of the dataset | ||
batch_size: the batch size of random data | ||
path: the path to the data | ||
for_inception: whether the data is for Inception because inception has input size 299 rather than 224 | ||
""" | ||
if dataset == 'imagenet': | ||
input_size = 299 if for_inception else 224 | ||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], | ||
std=[0.229, 0.224, 0.225]) | ||
test_dataset = datasets.ImageFolder( | ||
path + 'val', | ||
transforms.Compose([ | ||
transforms.Resize(int(input_size / 0.875)), | ||
transforms.CenterCrop(input_size), | ||
transforms.ToTensor(), | ||
normalize, | ||
])) | ||
test_loader = DataLoader(test_dataset, | ||
batch_size=batch_size, | ||
shuffle=False, | ||
num_workers=32) | ||
return test_loader | ||
elif dataset == 'cifar10': | ||
data_dir = '/rscratch/yaohuic/data/' | ||
normalize = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), | ||
std=(0.2023, 0.1994, 0.2010)) | ||
transform_test = transforms.Compose([transforms.ToTensor(), normalize]) | ||
|
||
test_dataset = datasets.CIFAR10(root=data_dir, | ||
train=False, | ||
transform=transform_test) | ||
test_loader = DataLoader(test_dataset, | ||
batch_size=batch_size, | ||
shuffle=False, | ||
num_workers=32) | ||
return test_loader |
Oops, something went wrong.