Skip to content

Commit

Permalink
Add BN before softmax in attention. More options in advnet and softattnet
Browse files Browse the repository at this point in the history
  • Loading branch information
bearpaw committed May 15, 2017
1 parent f7fe8d6 commit c869f55
Show file tree
Hide file tree
Showing 5 changed files with 543 additions and 64 deletions.
292 changes: 292 additions & 0 deletions cifar-att.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
'''
Training script for CIFAR-10/100
Copyright (c) Wei YANG, 2017
cifar10:
cifar100:
classes = ('beaver', 'dolphin', 'otter', 'seal', 'whale', 'aquarium fish',
'flatfish', 'ray', 'shark', 'trout', 'orchids', 'poppies',
'roses', 'sunflowers', 'tulips', 'bottles', 'bowls', 'cans',
'cups', 'plates', 'apples', 'mushrooms', 'oranges', 'pears',
'sweet peppers', 'clock', 'computer keyboard', 'lamp', 'telephone',
'television', 'bed', 'chair', 'couch', 'table', 'wardrobe', 'bee',
'beetle', 'butterfly', 'caterpillar', 'cockroach', 'bear', 'leopard',
'lion', 'tiger', 'wolf', 'bridge', 'castle', 'house', 'road', 'skyscraper',
'cloud', 'forest', 'mountain', 'plain', 'sea', 'camel', 'cattle', 'chimpanzee',
'elephant', 'kangaroo', 'fox', 'porcupine', 'possum', 'raccoon', 'skunk', 'crab',
'lobster', 'snail', 'spider', 'worm', 'baby', 'boy', 'girl', 'man', 'woman',
'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle', 'hamster', 'mouse', 'rabbit',
'shrew', 'squirrel', 'maple', 'oak', 'palm', 'pine', 'willow', 'bicycle', 'bus',
'motorcycle', 'pickup truck', 'train', 'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor')
'''
from __future__ import print_function

import argparse
import os
import shutil
import time
import random

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import models

from utils import *


# Expose every lowercase, callable factory in the models package as a CLI choice.
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch CIFAR10/100 Training')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet20',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet20)')  # BUGFIX: help text previously claimed resnet18
parser.add_argument('-d', '--dataset', default='cifar10', type=str)
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=164, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--train_batch', default=128, type=int, metavar='N',
                    help='train batchsize')
parser.add_argument('--test_batch', default=100, type=int, metavar='N',
                    help='test batchsize')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
                    help='path to save checkpoint (default: checkpoint)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--manualSeed', type=int, help='manual seed')

args = parser.parse_args()

# Validate dataset
assert args.dataset == 'cifar10' or args.dataset == 'cifar100', 'Dataset can only be cifar10 or cifar100.'

# Use CUDA
use_cuda = torch.cuda.is_available()

# Random seed: pick one at random if not supplied so a run can still be
# reproduced from its log.
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)

best_acc = 0  # best test accuracy

# CIFAR-10 label names (used by test() when printing ground truth).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def main():
    """Build data loaders, model, and optimizer, then train or evaluate.

    Driven entirely by the module-level ``args``; writes checkpoints and a
    text log into ``args.checkpoint``.
    """
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)



    # Data: standard CIFAR augmentation (pad+crop, horizontal flip) plus
    # per-channel normalization with CIFAR-10 statistics.
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100


    trainset = dataloader(root='./data', train=True, download=True, transform=transform_train)
    trainloader = data.DataLoader(trainset, batch_size=args.train_batch, shuffle=True, num_workers=args.workers)

    testset = dataloader(root='./data', train=False, download=False, transform=transform_test)
    testloader = data.DataLoader(testset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers)

    # Model: alexnet/vgg parallelize only the conv features; every other
    # architecture is wrapped in DataParallel as a whole.
    print("=> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch](num_classes=num_classes)
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # Resume: restore best accuracy, epoch, model and optimizer state, and
    # keep appending to the existing log file in the checkpoint directory.
    title = 'cifar-10-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])


    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(testloader, model, criterion, start_epoch, use_cuda)
        print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc*100))
        return

    # Train and val: one train + one test pass per epoch, logging both and
    # checkpointing after every epoch (best model mirrored separately).
    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' % (epoch, args.epochs, lr))

        train_loss, train_acc = train(trainloader, model, criterion, optimizer, epoch, use_cuda)
        test_loss, test_acc = test(testloader, model, criterion, epoch, use_cuda)

        print(' Train Loss: %.8f, Train Acc: %.2f' % (train_loss, train_acc*100))
        print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc*100))

        # append logger file
        logger.append([train_loss, test_loss, train_acc, test_acc])

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'acc': test_acc,
            'best_acc': best_acc,
            'optimizer' : optimizer.state_dict(),
        }, is_best, checkpoint=args.checkpoint)

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, 'log.eps'))

    print('Best acc:')
    print(best_acc)

def train(trainloader, model, criterion, optimizer, epoch, use_cuda):
    """Run one training epoch.

    Returns a tuple ``(summed batch-mean loss / num samples, accuracy)``.
    """
    model.train()  # enable dropout / batch-norm updates
    running_loss = 0
    num_correct = 0
    num_seen = 0
    for step, (images, labels) in enumerate(trainloader):
        if use_cuda:
            images, labels = images.cuda(), labels.cuda()
        optimizer.zero_grad()
        images = torch.autograd.Variable(images)
        labels = torch.autograd.Variable(labels)
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Accumulate running statistics for the progress display.
        running_loss += batch_loss.data[0]
        _, guesses = torch.max(logits.data, 1)
        num_seen += labels.size(0)
        num_correct += guesses.eq(labels.data).cpu().sum()

        progress_bar(step, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (running_loss/(step+1), 100.*num_correct/num_seen, num_correct, num_seen))
    return (running_loss/num_seen, num_correct*1.0/num_seen)

def test(testloader, model, criterion, epoch, use_cuda):
    """Evaluate on the test set while visualizing attention masks.

    Returns ``(summed batch-mean loss / num samples, accuracy)``.
    NOTE(review): opens a matplotlib window per attention mask via
    plt.show() — intended for interactive inspection, not headless runs.
    """
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        # volatile=True: inference-only graph, no gradients (pre-0.4 torch API)
        inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets)
        outputs = model(inputs)

        # attention: masks recorded by the model during the forward pass;
        # _modules['module'] unwraps the DataParallel wrapper added in main().
        att_mask_list = model._modules['module'].get_mask()
        num_att = len(att_mask_list)  # NOTE(review): currently unused

        # show_mask(inputs.data.cpu(), att_mask_list, Mean=(0.4914, 0.4822, 0.4465), Std=(0.2023, 0.1994, 0.2010))

        # NOTE(review): 'classes' holds CIFAR-10 names; printed labels will be
        # wrong when the dataset is cifar100 — confirm intent.
        print('GroundTruth: ', ' '.join('%5s' % classes[target] for target in targets.data))

        for att_mask in att_mask_list:
            # print('Attention: Max %f | min %f' % (att.max(), att.min()))
            # print(inputs.data.size())
            show_mask_single(inputs.data.cpu(), att_mask.data.cpu(), Mean=(0.4914, 0.4822, 0.4465), Std=(0.2023, 0.1994, 0.2010))
            plt.show()


        # for att_mask in att_mask_list:
        #     for att_ in att_mask:
        #         att = att_.data
        #         print('Attention: Max %f | min %f' % (att.max(), att.min()))
        #         show_mask(inputs.data.cpu(), att.cpu(), Mean=(0.4914, 0.4822, 0.4465), Std=(0.2023, 0.1994, 0.2010))

        loss = criterion(outputs, targets)

        test_loss += loss.data[0]  # sum of per-batch mean losses
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()
        # print('Predicted: ', ' '.join('%5s' % classes[pred] for pred in predicted))

        progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
    return (test_loss/total, correct*1.0/total)

def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Serialize ``state`` under ``checkpoint`` and, when this is the best
    model so far, mirror it to ``model_best.pth.tar``."""
    target = os.path.join(checkpoint, filename)
    torch.save(state, target)
    if not is_best:
        return
    shutil.copyfile(target, os.path.join(checkpoint, 'model_best.pth.tar'))

def adjust_learning_rate(optimizer, epoch):
    """Apply the step-decay schedule lr = args.lr * 0.1**decay.

    Decays by 10x at epoch 81 and by 100x at epoch 122 (the standard CIFAR
    ResNet schedule for 164 epochs). Writes the new rate into every param
    group and returns it.

    BUGFIX: the original tested ``epoch >= 81`` before ``epoch >= 122``, so
    the second branch was unreachable and the LR never dropped to 0.001.
    Also renamed the misspelled local ``deday`` -> ``decay``.
    """
    if epoch >= 122:
        decay = 2
    elif epoch >= 81:
        decay = 1
    else:
        decay = 0
    lr = args.lr * (0.1 ** decay)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# Run training only when executed as a script (not on import).
if __name__ == '__main__':
    main()
13 changes: 10 additions & 3 deletions models/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@

class SoftmaxAttention(nn.Module):
# implementation of Wang et al. "Residual Attention Network for Image Classification". CVPR, 2017.
def __init__(self, planes, residual=True):
def __init__(self, planes, residual=True, normalize=False):
super(SoftmaxAttention, self).__init__()
self.residual = residual
self.normalize = normalize
self.bn1 = nn.BatchNorm2d(planes)
self.conv1 = nn.Conv2d(planes, planes, kernel_size=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, 1, kernel_size=1, bias=False)
if self.normalize == True:
self.bn3 = nn.BatchNorm2d(1)
self.softmax = nn.Softmax()
self.mask = None

Expand All @@ -34,6 +37,10 @@ def forward(self, x):
mask = self.bn2(mask)
mask = self.relu(mask)
mask = self.conv2(mask)
# print('min: %.4f | max: %.4f' % (mask.data.min(), mask.data.max()))
if self.normalize == True:
mask = self.bn3(mask)
# print('min: %.4f | max: %.4f' % (mask.data.min(), mask.data.max()))
mask = mask.view(mask.size(0), -1)
mask = self.softmax(mask)
mask = mask.view(mask.size(0), 1, x.size(2), x.size(3))
Expand All @@ -43,15 +50,15 @@ def forward(self, x):


out = x * mask.expand_as(x)
if self.residual:
if self.residual == True:
out += x

return out


class SigmoidAttention(nn.Module):
# implementation of Wang et al. "Residual Attention Network for Image Classification". CVPR, 2017.
def __init__(self, planes, residual=True):
def __init__(self, planes, residual=True, normalize=False):
super(SigmoidAttention, self).__init__()
self.residual = residual
self.bn1 = nn.BatchNorm2d(planes)
Expand Down
Loading

0 comments on commit c869f55

Please sign in to comment.