Skip to content

Commit

Permalink
Update utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sjtudyq authored Jul 28, 2022
1 parent 66a5999 commit 5c9d271
Showing 1 changed file with 43 additions and 1 deletion.
44 changes: 43 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ def partition_data(dataset, datadir, logdir, partition, n_parties, beta=0.4):
X_train, y_train, X_test, y_test = load_celeba_data(datadir)
elif dataset == 'femnist':
X_train, y_train, u_train, X_test, y_test, u_test = load_femnist_data(datadir)
elif dataset == 'cifar100':
X_train, y_train, X_test, y_test = load_cifar100_data(datadir)
elif dataset == 'tinyimagenet':
X_train, y_train, X_test, y_test = load_tinyimagenet_data(datadir)
elif dataset == 'generated':
X_train, y_train = [], []
for loc in range(4):
Expand Down Expand Up @@ -278,6 +282,10 @@ def partition_data(dataset, datadir, logdir, partition, n_parties, beta=0.4):
if dataset in ('celeba', 'covtype', 'a9a', 'rcv1', 'SUSY'):
K = 2
# min_require_size = 100
if dataset == 'cifar100':
K = 100
elif dataset == 'tinyimagenet':
K = 200

N = y_train.shape[0]
#np.random.seed(2020)
Expand Down Expand Up @@ -628,7 +636,7 @@ def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None, noise_level=0, net_id=None, total=0):
if dataset in ('mnist', 'femnist', 'fmnist', 'cifar10', 'svhn', 'generated', 'covtype', 'a9a', 'rcv1', 'SUSY'):
if dataset in ('mnist', 'femnist', 'fmnist', 'cifar10', 'svhn', 'generated', 'covtype', 'a9a', 'rcv1', 'SUSY', 'cifar100', 'tinyimagenet'):
if dataset == 'mnist':
dl_obj = MNIST_truncated

Expand Down Expand Up @@ -686,6 +694,40 @@ def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None, noise_lev
transform_test = transforms.Compose([
transforms.ToTensor(),
AddGaussianNoise(0., noise_level, net_id, total)])

elif dataset == 'cifar100':
dl_obj = CIFAR100_truncated

normalize = transforms.Normalize(mean=[0.5070751592371323, 0.48654887331495095, 0.4409178433670343],
std=[0.2673342858792401, 0.2564384629170883, 0.27615047132568404])
# transform_train = transforms.Compose([
# transforms.RandomCrop(32),
# transforms.RandomHorizontalFlip(),
# transforms.ToTensor(),
# normalize
# ])
transform_train = transforms.Compose([
# transforms.ToPILImage(),
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ToTensor(),
normalize
])
# data prep for test set
transform_test = transforms.Compose([
transforms.ToTensor(),
normalize])
elif dataset == 'tinyimagenet':
dl_obj = ImageFolder_custom
transform_train = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

else:
dl_obj = Generated
Expand Down

0 comments on commit 5c9d271

Please sign in to comment.