MultiMNIST models and dataset #8

Open · wants to merge 1 commit into master
21 changes: 20 additions & 1 deletion args.py
@@ -19,6 +19,24 @@ def parse_arguments():
    parser.add_argument(
        "-a", "--arch", metavar="ARCH", default="ResNet18", help="model architecture"
    )
    parser.add_argument(
        "--num-train-examples",
        type=int,
        default=None,
        help="Number of training examples to use for MultiMNIST",
    )
    parser.add_argument(
        "--num-val-examples",
        type=int,
        default=None,
        help="Number of validation examples to use for MultiMNIST",
    )
    parser.add_argument(
        "--num-concat",
        help="Number of MNIST digits to concatenate per MultiMNIST example",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--config", help="Config file to use (see configs dir)", default=None
    )
@@ -244,7 +262,8 @@ def parse_arguments():

     args = parser.parse_args()
 
-    get_config(args)
+    if args.config is not None:
+        get_config(args)
 
     return args

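With the config file now optional, a MultiMNIST run can be launched from flags alone, e.g. `python main.py --arch LeNet5 --num-concat 2 --num-train-examples 100000 --num-val-examples 10000` (the `main.py` entry-point name is an assumption; it does not appear in this diff).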
35 changes: 35 additions & 0 deletions configs/smallscale/multimnist/lenet5.yaml
@@ -0,0 +1,35 @@
# Architecture
arch: LeNet5

# ===== Dataset ===== #
data: /usr/data
set: MultiMNIST
name: baseline
num_train_examples: 5000000
num_val_examples: 50000
num_concat: 5
num_classes: 100000


# ===== Learning Rate Policy ======== #
optimizer: sgd
lr: 0.1
lr_policy: cosine_lr
warmup_length: 5

# ===== Network training config ===== #
epochs: 100
weight_decay: 0.0001
momentum: 0.9
batch_size: 256


# ===== Sparsity =========== #
conv_type: DenseConv
bn_type: LearnedBatchNorm
init: kaiming_normal
mode: fan_in
nonlinearity: relu

# ===== Hardware setup ===== #
workers: 12
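
A quick sanity check on the values above (a sketch that only restates the YAML): each label is the base-10 reading of num_concat concatenated digits, so the class count is forced by num_concat.

# Restating the config: 5 concatenated digits form a 5-digit base-10 label.
num_concat = 5
assert 10 ** num_concat == 100000  # num_classes above
# num_train_examples (5000000) merely caps the 60000 ** 5 possible digit tuples.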
3 changes: 2 additions & 1 deletion data/__init__.py
@@ -1,2 +1,3 @@
from data.imagenet import ImageNet
from data.imagenet import TinyImageNet
from data.mnist import MultiMNIST
102 changes: 102 additions & 0 deletions data/mnist.py
@@ -0,0 +1,102 @@
from typing import Any, Callable, Optional, Tuple
from torchvision import datasets, transforms
from PIL import Image
from args import args
import os
import torch
import torchvision
import numpy as np


class MultiMNIST:
    def __init__(self, args):
        super(MultiMNIST, self).__init__()

        data_root = os.path.join(args.data, "mnist")

        use_cuda = torch.cuda.is_available()

        # Data loading code
        kwargs = {"num_workers": args.workers, "pin_memory": True} if use_cuda else {}
        self.train_loader = torch.utils.data.DataLoader(
            MultiMNISTDataset(
                data_root,
                train=True,
                download=True,
                num_concat=args.num_concat,
                transform=transforms.Compose(
                    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
                ),
            ),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs
        )
        self.val_loader = torch.utils.data.DataLoader(
            MultiMNISTDataset(
                data_root,
                num_concat=args.num_concat,
                train=False,
                transform=transforms.Compose(
                    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
                ),
            ),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs
        )


class MultiMNISTDataset(datasets.MNIST):
    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
        num_concat: int = 1,
    ) -> None:
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

        # Nominal size is the full product space (len(MNIST) ** num_concat),
        # optionally capped by the --num-train/val-examples flags.
        self.length = int(super().__len__() ** num_concat)
        if self.train:
            self.length = args.num_train_examples or self.length
        else:
            self.length = args.num_val_examples or self.length

        self.num_concat = num_concat

    def __len__(self):
        return self.length

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        # Seed an RNG deterministically per index; train uses even seeds and
        # val uses odd seeds, so the two splits never share a sample.
        if self.train:
            rng = np.random.RandomState(index * 2)
        else:
            rng = np.random.RandomState(index * 2 + 1)

        # Draw num_concat random MNIST digits, concatenate them horizontally,
        # and read the label as the resulting base-10 number.
        indices = rng.randint(0, super().__len__(), (self.num_concat,))
        img, target = self.data[indices], self.targets[indices]
        base = 10 ** torch.arange(self.num_concat - 1, -1, -1)

        img = torch.cat([img[i] for i in range(self.num_concat)], dim=-1)
        target = (base * target).sum()

        img = Image.fromarray(img.numpy(), mode="L")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
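
For illustration, a minimal sketch that exercises the dataset directly (the data path and num_concat value are assumptions; importing data.mnist pulls in the module-level args, so this is meant to run as a standalone script):

from torchvision import transforms
from data.mnist import MultiMNISTDataset

# Two concatenated digits: 28x56 images with labels in [0, 99].
ds = MultiMNISTDataset(
    "/usr/data/mnist",
    train=True,
    download=True,
    num_concat=2,
    transform=transforms.ToTensor(),
)
img, target = ds[0]
print(img.shape)    # torch.Size([1, 28, 56])
print(int(target))  # the two-digit label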
1 change: 1 addition & 0 deletions models/__init__.py
@@ -1,5 +1,6 @@
from models.resnet import ResNet18, ResNet50
from models.mobilenetv1 import MobileNetV1
from models.lenet import LeNet5

__all__ = [
"ResNet18",
39 changes: 39 additions & 0 deletions models/lenet.py
@@ -0,0 +1,39 @@
"""
LeNet-5 implementation from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/model.py
"""

from args import args
from torch.nn import Module
from torch import nn


class LeNet5(Module):
    def __init__(self):
        super(LeNet5, self).__init__()

        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        # A 28 x (28 * num_concat) input shrinks to 16 channels of
        # 4 x (7 * num_concat - 3) after the two conv/pool stages:
        # 16*4*4 = 256 features for one digit, plus 16*4*7 = 448 for
        # every additional concatenated digit.
        self.fc1 = nn.Linear(256 + 448 * (args.num_concat - 1), 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        return y
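
A hedged shape check for the fully connected sizing (the flag values here are illustrative; setting them on the imported args object stands in for a parsed command line):

import torch
from args import args

args.num_concat = 2   # two digits -> 28x56 inputs
args.num_classes = 100

from models.lenet import LeNet5

# 28x56 -> conv5 -> 24x52 -> pool2 -> 12x26 -> conv5 -> 8x22 -> pool2 -> 4x11,
# so the flattened size is 16 * 4 * 11 = 704 = 256 + 448 * (2 - 1).
model = LeNet5()
out = model(torch.randn(4, 1, 28, 56))
print(out.shape)  # torch.Size([4, 100])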
31 changes: 31 additions & 0 deletions tests/multimnist.py
@@ -0,0 +1,31 @@
import sys, os
sys.path.append(os.path.abspath('.'))

from args import args
from data import MultiMNIST
from collections import defaultdict

import seaborn as sns
import matplotlib.pyplot as plt
import tqdm

args.data = "/usr/data"
args.num_train_examples = 200000
args.num_val_examples = 50000
args.num_concat = 4
args.workers = 16

mnist = MultiMNIST(args)

label_counts = defaultdict(int)
for i in tqdm.tqdm(range(len(mnist.train_loader.dataset)), ascii=True):
    _, label = mnist.train_loader.dataset[i]
    label_counts[label.item()] += 1

fig, (ax1, ax2) = plt.subplots(2)

# Density of the per-label counts (should be tight if labels are uniform).
sns.kdeplot(list(label_counts.values()), ax=ax1)

# Raw count per label, sorted by label value.
ax2.plot(*zip(*sorted(label_counts.items())))
plt.savefig("tests/images/multimnist.pdf", bbox_inches="tight")
2 changes: 2 additions & 0 deletions trainer.py
@@ -43,8 +43,10 @@ def train(train_loader, model, criterion, optimizer, epoch, args, writer):
 
         loss = criterion(output, target.view(-1))
 
+
         # measure accuracy and record loss
         acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
         losses.update(loss.item(), images.size(0))
         top1.update(acc1.item(), images.size(0))
         top5.update(acc5.item(), images.size(0))
Expand Down
2 changes: 1 addition & 1 deletion utils/eval_utils.py
@@ -13,6 +13,6 @@ def accuracy(output, target, topk=(1,)):

     res = []
     for k in topk:
-        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
         res.append(correct_k.mul_(100.0 / batch_size))
     return res
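
For context on this one-line fix: correct[:k].view(-1) can fail with "view size is not compatible with input tensor's size and stride" when correct is non-contiguous (typically because the predictions were transposed upstream), whereas reshape copies when it must. A minimal standalone repro of the underlying behavior (not part of the PR):

import torch

t = torch.arange(12).reshape(3, 4).t()  # a transposed view: non-contiguous
print(t.is_contiguous())                # False
print(t[:2].reshape(-1))                # fine: reshape copies as needed
# t[:2].view(-1) raises RuntimeError: view size is not compatible with
# input tensor's size and stride.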