diff --git a/datasets.py b/datasets.py index 5c989c4..2db8b55 100644 --- a/datasets.py +++ b/datasets.py @@ -11,6 +11,17 @@ def __init__(self, args): self.test_loader = torch.utils.data.DataLoader( datasets.MNIST('data/mnist', train=False, transform=transforms.ToTensor()), batch_size=args.batch_size, shuffle=True, **kwargs) + +class Celldata(object): + def __init__(self, args): + kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} + self.train_loader = torch.utils.data.DataLoader( + datasets.Celldata('data/mnist', train=True, download=True, + transform=transforms.ToTensor()), + batch_size=args.batch_size, shuffle=True, **kwargs) + self.test_loader = torch.utils.data.DataLoader( + datasets.Celldata('data/mnist', train=False, transform=transforms.ToTensor()), + batch_size=args.batch_size, shuffle=True, **kwargs) class EMNIST(object): def __init__(self, args):