-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add BN before softmax in attention. More options in advnet and softattnet
- Loading branch information
Showing
5 changed files
with
543 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,292 @@ | ||
''' | ||
Training script for CIFAR-10/100 | ||
Copyright (c) Wei YANG, 2017 | ||
cifar10: | ||
cifar100: | ||
classes = ('beaver', 'dolphin', 'otter', 'seal', 'whale', 'aquarium fish', | ||
'flatfish', 'ray', 'shark', 'trout', 'orchids', 'poppies', | ||
'roses', 'sunflowers', 'tulips', 'bottles', 'bowls', 'cans', | ||
'cups', 'plates', 'apples', 'mushrooms', 'oranges', 'pears', | ||
'sweet peppers', 'clock', 'computer keyboard', 'lamp', 'telephone', | ||
'television', 'bed', 'chair', 'couch', 'table', 'wardrobe', 'bee', | ||
'beetle', 'butterfly', 'caterpillar', 'cockroach', 'bear', 'leopard', | ||
'lion', 'tiger', 'wolf', 'bridge', 'castle', 'house', 'road', 'skyscraper', | ||
'cloud', 'forest', 'mountain', 'plain', 'sea', 'camel', 'cattle', 'chimpanzee', | ||
'elephant', 'kangaroo', 'fox', 'porcupine', 'possum', 'raccoon', 'skunk', 'crab', | ||
'lobster', 'snail', 'spider', 'worm', 'baby', 'boy', 'girl', 'man', 'woman', | ||
'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle', 'hamster', 'mouse', 'rabbit', | ||
'shrew', 'squirrel', 'maple', 'oak', 'palm', 'pine', 'willow', 'bicycle', 'bus', | ||
'motorcycle', 'pickup truck', 'train', 'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor') | ||
''' | ||
from __future__ import print_function | ||
|
||
import argparse | ||
import os | ||
import shutil | ||
import time | ||
import random | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.parallel | ||
import torch.backends.cudnn as cudnn | ||
import torch.optim as optim | ||
import torch.utils.data as data | ||
import torchvision.transforms as transforms | ||
import torchvision.datasets as datasets | ||
import matplotlib.pyplot as plt | ||
import models | ||
|
||
from utils import * | ||
|
||
|
||
# Enumerate the model constructors exported by the local `models` package:
# lowercase, non-dunder callables (e.g. resnet20, vgg19, ...).
model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch CIFAR10/100 Training')
# NOTE(review): the help text says "(default: resnet18)" but the actual
# default below is 'resnet20' — the help string should be corrected.
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet20',
    choices=model_names,
    help='model architecture: ' +
        ' | '.join(model_names) +
        ' (default: resnet18)')
parser.add_argument('-d', '--dataset', default='cifar10', type=str)
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=164, type=int, metavar='N',
    help='number of total epochs to run')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
    help='manual epoch number (useful on restarts)')
parser.add_argument('--train_batch', default=128, type=int, metavar='N',
    help='train batchsize')
parser.add_argument('--test_batch', default=100, type=int, metavar='N',
    help='test batchsize')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
    help='path to save checkpoint (default: checkpoint)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
    help='evaluate model on validation set')
parser.add_argument('--manualSeed', type=int, help='manual seed')

# Parsed once at import time; the rest of the module reads this global.
args = parser.parse_args()

# Validate dataset
assert args.dataset == 'cifar10' or args.dataset == 'cifar100', 'Dataset can only be cifar10 or cifar100.'

# Use CUDA
use_cuda = torch.cuda.is_available()

# Random seed — pick one at random when not supplied, then seed every RNG
# in play (python `random`, torch CPU, and all CUDA devices) so runs are
# reproducible given the printed/logged seed.
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)

best_acc = 0  # best test accuracy seen so far; updated inside main()

# CIFAR-10 class names, used by test() when printing ground-truth labels.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
|
||
def main():
    """Top-level driver: build CIFAR data loaders, model, and optimizer;
    optionally resume from a checkpoint; then either evaluate once or run
    the full train/test loop, logging metrics and checkpointing each epoch.
    """
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    # Ensure the checkpoint directory exists (mkdir_p comes from utils).
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Data: standard CIFAR augmentation (pad-crop + horizontal flip) for
    # training; plain normalization only for test.
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100

    trainset = dataloader(root='./data', train=True, download=True, transform=transform_train)
    trainloader = data.DataLoader(trainset, batch_size=args.train_batch, shuffle=True, num_workers=args.workers)

    # NOTE(review): download=False assumes the test split was fetched by the
    # training-set download above (both splits ship in one archive).
    testset = dataloader(root='./data', train=False, download=False, transform=transform_test)
    testloader = data.DataLoader(testset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers)

    # Model: for alexnet/vgg only the convolutional `features` are wrapped in
    # DataParallel (the torchvision ImageNet-example convention); every other
    # architecture is wrapped whole.
    print("=> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch](num_classes=num_classes)
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # Resume: restore weights, optimizer state, and best accuracy, and append
    # to the existing log file; otherwise start a fresh Logger (from utils).
    title = 'cifar-10-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

    # Evaluation-only mode: one pass over the test set, then exit.
    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(testloader, model, criterion, start_epoch, use_cuda)
        print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc*100))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' % (epoch, args.epochs, lr))

        train_loss, train_acc = train(trainloader, model, criterion, optimizer, epoch, use_cuda)
        test_loss, test_acc = test(testloader, model, criterion, epoch, use_cuda)

        print(' Train Loss: %.8f, Train Acc: %.2f' % (train_loss, train_acc*100))
        print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc*100))

        # append logger file
        logger.append([train_loss, test_loss, train_acc, test_acc])

        # save model (every epoch; best-so-far additionally copied aside)
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'acc': test_acc,
            'best_acc': best_acc,
            'optimizer' : optimizer.state_dict(),
        }, is_best, checkpoint=args.checkpoint)

    logger.close()
    logger.plot()
    # savefig comes from utils (wraps matplotlib figure saving).
    savefig(os.path.join(args.checkpoint, 'log.eps'))

    print('Best acc:')
    print(best_acc)
|
||
def train(trainloader, model, criterion, optimizer, epoch, use_cuda):
    """Run one training epoch.

    Args:
        trainloader: iterable of (inputs, targets) batches.
        model: network to optimize (already on the GPU when use_cuda).
        criterion: loss function (CrossEntropyLoss in this script).
        optimizer: optimizer stepped once per batch.
        epoch: current epoch index (unused here; kept for the caller's API).
        use_cuda: move each batch to the GPU when True.

    Returns:
        Tuple of (summed batch losses / number of samples, accuracy in [0, 1]).
    """
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        # The pre-0.4 torch.autograd.Variable wrapping was dropped: Variables
        # and Tensors merged in PyTorch 0.4, so it was a deprecated no-op.
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Bug fix: `loss.data[0]` is the pre-0.4 scalar accessor and raises
        # IndexError on 0-dim tensors in current PyTorch; use .item().
        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        # .item() keeps `correct` a plain Python int instead of a 0-dim tensor.
        correct += predicted.eq(targets.data).cpu().sum().item()

        # progress_bar comes from utils.
        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    # NOTE(review): the returned loss divides by sample count while the
    # progress bar divides by batch count — kept as-is for log continuity.
    return (train_loss/total, correct*1.0/total)
|
||
def test(testloader, model, criterion, epoch, use_cuda):
    """Evaluate `model` on `testloader` while visualizing attention masks.

    Returns (summed batch losses / number of samples, accuracy in [0, 1]).
    Besides computing metrics, each batch pulls attention masks off the
    model and displays them with matplotlib, so this is an interactive /
    debugging evaluation loop (plt.show() blocks per batch).
    """
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        # volatile=True disables gradient tracking (pre-0.4 API; the modern
        # equivalent is wrapping the loop in torch.no_grad()).
        inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets)
        outputs = model(inputs)

        # attention: get_mask() is provided by the project's attention models,
        # reached through the DataParallel 'module' wrapper — assumes the
        # model was wrapped in DataParallel (see main); TODO confirm.
        att_mask_list = model._modules['module'].get_mask()
        num_att = len(att_mask_list)  # NOTE(review): computed but never used

        # show_mask(inputs.data.cpu(), att_mask_list, Mean=(0.4914, 0.4822, 0.4465), Std=(0.2023, 0.1994, 0.2010))

        print('GroundTruth: ', ' '.join('%5s' % classes[target] for target in targets.data))

        # Render each attention mask over the (de-normalized) input batch;
        # show_mask_single comes from utils.
        for att_mask in att_mask_list:
            # print('Attention: Max %f | min %f' % (att.max(), att.min()))
            # print(inputs.data.size())
            show_mask_single(inputs.data.cpu(), att_mask.data.cpu(), Mean=(0.4914, 0.4822, 0.4465), Std=(0.2023, 0.1994, 0.2010))
        plt.show()

        # for att_mask in att_mask_list:
        #     for att_ in att_mask:
        #         att = att_.data
        #         print('Attention: Max %f | min %f' % (att.max(), att.min()))
        #         show_mask(inputs.data.cpu(), att.cpu(), Mean=(0.4914, 0.4822, 0.4465), Std=(0.2023, 0.1994, 0.2010))

        loss = criterion(outputs, targets)

        # NOTE(review): loss.data[0] is the pre-0.4 scalar accessor; on
        # current PyTorch this raises and should become loss.item().
        test_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()
        # print('Predicted: ', ' '.join('%5s' % classes[pred] for pred in predicted))

        progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
    return (test_loss/total, correct*1.0/total)
|
||
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Serialize `state` into the checkpoint directory.

    Always writes `<checkpoint>/<filename>`; when `is_best` is set, the same
    snapshot is additionally copied to `model_best.pth.tar` alongside it.
    """
    target = os.path.join(checkpoint, filename)
    torch.save(state, target)
    if not is_best:
        return
    shutil.copyfile(target, os.path.join(checkpoint, 'model_best.pth.tar'))
|
||
def adjust_learning_rate(optimizer, epoch, base_lr=None):
    """Step-decay schedule: divide the learning rate by 10 at epoch 81 and
    by 100 at epoch 122 (the classic CIFAR ResNet schedule).

    Args:
        optimizer: optimizer whose param groups receive the new rate.
        epoch: current (0-based) epoch index.
        base_lr: initial learning rate; defaults to the command-line
            ``args.lr`` so existing callers are unchanged.

    Returns:
        The learning rate that was set on every param group.
    """
    if base_lr is None:
        base_lr = args.lr
    # Bug fix: the original tested `epoch >= 81` before `elif epoch >= 122`,
    # so the second branch was unreachable and the decay stayed at one step
    # forever. Check the larger threshold first. (Also renamed the
    # misspelled `deday` to `decay`.)
    if epoch >= 122:
        decay = 2
    elif epoch >= 81:
        decay = 1
    else:
        decay = 0
    lr = base_lr * (0.1 ** decay)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
|
||
# Script entry point: run the training driver only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.