diff --git a/cifar-att.py b/cifar-att.py
new file mode 100644
index 0000000..3894b48
--- /dev/null
+++ b/cifar-att.py
@@ -0,0 +1,292 @@
+'''
+Training script for CIFAR-10/100
+Copyright (c) Wei YANG, 2017
+
+cifar10:
+cifar100:
+    classes = ('beaver', 'dolphin', 'otter', 'seal', 'whale', 'aquarium fish',
+               'flatfish', 'ray', 'shark', 'trout', 'orchids', 'poppies',
+               'roses', 'sunflowers', 'tulips', 'bottles', 'bowls', 'cans',
+               'cups', 'plates', 'apples', 'mushrooms', 'oranges', 'pears',
+               'sweet peppers', 'clock', 'computer keyboard', 'lamp', 'telephone',
+               'television', 'bed', 'chair', 'couch', 'table', 'wardrobe', 'bee',
+               'beetle', 'butterfly', 'caterpillar', 'cockroach', 'bear', 'leopard',
+               'lion', 'tiger', 'wolf', 'bridge', 'castle', 'house', 'road', 'skyscraper',
+               'cloud', 'forest', 'mountain', 'plain', 'sea', 'camel', 'cattle', 'chimpanzee',
+               'elephant', 'kangaroo', 'fox', 'porcupine', 'possum', 'raccoon', 'skunk', 'crab',
+               'lobster', 'snail', 'spider', 'worm', 'baby', 'boy', 'girl', 'man', 'woman',
+               'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle', 'hamster', 'mouse', 'rabbit',
+               'shrew', 'squirrel', 'maple', 'oak', 'palm', 'pine', 'willow', 'bicycle', 'bus',
+               'motorcycle', 'pickup truck', 'train', 'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor')
+'''
+from __future__ import print_function
+
+import argparse
+import os
+import shutil
+import time
+import random
+
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim as optim
+import torch.utils.data as data
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+import matplotlib.pyplot as plt
+import models
+
+from utils import *
+
+
+model_names = sorted(name for name in models.__dict__
+    if name.islower() and not name.startswith("__")
+    and callable(models.__dict__[name]))
+
+parser = argparse.ArgumentParser(description='PyTorch CIFAR10/100 Training')
+parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet20',
+                    choices=model_names,
+                    help='model architecture: ' +
+                        ' | '.join(model_names) +
+                        ' (default: resnet20)')
+parser.add_argument('-d', '--dataset', default='cifar10', type=str)
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                    help='number of data loading workers (default: 4)')
+parser.add_argument('--epochs', default=164, type=int, metavar='N',
+                    help='number of total epochs to run')
+parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+                    help='manual epoch number (useful on restarts)')
+parser.add_argument('--train_batch', default=128, type=int, metavar='N',
+                    help='train batchsize')
+parser.add_argument('--test_batch', default=100, type=int, metavar='N',
+                    help='test batchsize')
+parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+                    metavar='LR', help='initial learning rate')
+parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                    help='momentum')
+parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
+                    metavar='W', help='weight decay (default: 1e-4)')
+parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
+                    help='path to save checkpoint (default: checkpoint)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='path to latest checkpoint (default: none)')
+parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+                    help='evaluate model on validation set')
+parser.add_argument('--manualSeed', type=int,
+                    help='manual seed')
+
+args = parser.parse_args()
+
+# Validate dataset
+assert args.dataset in ('cifar10', 'cifar100'), 'Dataset can only be cifar10 or cifar100.'
+
+# Use CUDA
+use_cuda = torch.cuda.is_available()
+
+# Random seed
+if args.manualSeed is None:
+    args.manualSeed = random.randint(1, 10000)
+random.seed(args.manualSeed)
+torch.manual_seed(args.manualSeed)
+if use_cuda:
+    torch.cuda.manual_seed_all(args.manualSeed)
+
+best_acc = 0  # best test accuracy
+
+classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
+
+def main():
+    global best_acc
+    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch
+
+    if not os.path.isdir(args.checkpoint):
+        mkdir_p(args.checkpoint)
+
+    # Data
+    print('==> Preparing dataset %s' % args.dataset)
+    transform_train = transforms.Compose([
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ])
+
+    transform_test = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ])
+    if args.dataset == 'cifar10':
+        dataloader = datasets.CIFAR10
+        num_classes = 10
+    else:
+        dataloader = datasets.CIFAR100
+        num_classes = 100
+
+    trainset = dataloader(root='./data', train=True, download=True, transform=transform_train)
+    trainloader = data.DataLoader(trainset, batch_size=args.train_batch, shuffle=True, num_workers=args.workers)
+
+    testset = dataloader(root='./data', train=False, download=False, transform=transform_test)
+    testloader = data.DataLoader(testset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers)
+
+    # Model
+    print("=> creating model '{}'".format(args.arch))
+    model = models.__dict__[args.arch](num_classes=num_classes)
+    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
+        model.features = torch.nn.DataParallel(model.features)
+        model.cuda()
+    else:
+        model = torch.nn.DataParallel(model).cuda()
+    cudnn.benchmark = True
+
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+
+    # Resume
+    title = args.dataset + '-' + args.arch
+    if args.resume:
+        # Load checkpoint.
+        print('==> Resuming from checkpoint..')
+        assert os.path.isfile(args.resume), 'Error: no checkpoint found!'
+        args.checkpoint = os.path.dirname(args.resume)
+        checkpoint = torch.load(args.resume)
+        best_acc = checkpoint['best_acc']
+        start_epoch = checkpoint['epoch']
+        model.load_state_dict(checkpoint['state_dict'])
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
+    else:
+        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
+        logger.set_names(['Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])
+
+    if args.evaluate:
+        print('\nEvaluation only')
+        test_loss, test_acc = test(testloader, model, criterion, start_epoch, use_cuda)
+        print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc*100))
+        return
+
+    # Train and val
+    for epoch in range(start_epoch, args.epochs):
+        lr = adjust_learning_rate(optimizer, epoch)
+
+        print('\nEpoch: [%d | %d] LR: %f' % (epoch, args.epochs, lr))
+
+        train_loss, train_acc = train(trainloader, model, criterion, optimizer, epoch, use_cuda)
+        test_loss, test_acc = test(testloader, model, criterion, epoch, use_cuda)
+
+        print(' Train Loss: %.8f, Train Acc: %.2f' % (train_loss, train_acc*100))
+        print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc*100))
+
+        # append logger file
+        logger.append([train_loss, test_loss, train_acc, test_acc])
+
+        # save model
+        is_best = test_acc > best_acc
+        best_acc = max(test_acc, best_acc)
+        save_checkpoint({
+            'epoch': epoch + 1,
+            'state_dict': model.state_dict(),
+            'acc': test_acc,
+            'best_acc': best_acc,
+            'optimizer' : optimizer.state_dict(),
+        }, is_best, checkpoint=args.checkpoint)
+
+    logger.close()
+    logger.plot()
+    savefig(os.path.join(args.checkpoint, 'log.eps'))
+
+    print('Best acc:')
+    print(best_acc)
+
+def train(trainloader, model, criterion, optimizer, epoch, use_cuda):
+    model.train()
+    train_loss = 0
+    correct = 0
+    total = 0
+    for batch_idx, (inputs, targets) in enumerate(trainloader):
+        if use_cuda:
+            inputs, targets = inputs.cuda(), targets.cuda()
+        optimizer.zero_grad()
+        inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)
+        outputs = model(inputs)
+        loss = criterion(outputs, targets)
+        loss.backward()
+        optimizer.step()
+
+        train_loss += loss.data[0]
+        _, predicted = torch.max(outputs.data, 1)
+        total += targets.size(0)
+        correct += predicted.eq(targets.data).cpu().sum()
+
+        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
+            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
+    # report the mean per-batch loss so it matches the progress bar readout
+    return (train_loss/len(trainloader), correct*1.0/total)
+
+def test(testloader, model, criterion, epoch, use_cuda):
+    global best_acc
+    model.eval()
+    test_loss = 0
+    correct = 0
+    total = 0
+    for batch_idx, (inputs, targets) in enumerate(testloader):
+        if use_cuda:
+            inputs, targets = inputs.cuda(), targets.cuda()
+        inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets)
+        outputs = model(inputs)
+
+        # attention
+        att_mask_list = model._modules['module'].get_mask()
+        num_att = len(att_mask_list)
+
+        # show_mask(inputs.data.cpu(), att_mask_list, Mean=(0.4914, 0.4822, 0.4465), Std=(0.2023, 0.1994, 0.2010))
+
+        print('GroundTruth: ', ' '.join('%5s' % classes[target] for target in targets.data))
+
+        for att_mask in att_mask_list:
+            # print('Attention: Max %f | min %f' % (att.max(), att.min()))
+            # print(inputs.data.size())
+            show_mask_single(inputs.data.cpu(), att_mask.data.cpu(), Mean=(0.4914, 0.4822, 0.4465), Std=(0.2023, 0.1994, 0.2010))
+            plt.show()
+
+        # for att_mask in att_mask_list:
+        #     for att_ in att_mask:
+        #         att = att_.data
+        #         print('Attention: Max %f | min %f' % (att.max(), att.min()))
+        #         show_mask(inputs.data.cpu(), att.cpu(), Mean=(0.4914, 0.4822, 0.4465), Std=(0.2023, 0.1994, 0.2010))
+
+        loss = criterion(outputs, targets)
+
+        test_loss += loss.data[0]
+        _, predicted = torch.max(outputs.data, 1)
+        total += targets.size(0)
+        correct += predicted.eq(targets.data).cpu().sum()
+        # print('Predicted: ', ' '.join('%5s' % classes[pred] for pred in predicted))
+
+        progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
+            % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
+    # report the mean per-batch loss so it matches the progress bar readout
+    return (test_loss/len(testloader), correct*1.0/total)
+
+def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
+    filepath = os.path.join(checkpoint, filename)
+    torch.save(state, filepath)
+    if is_best:
+        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))
+
+def adjust_learning_rate(optimizer, epoch):
+    # decay by 10x at epochs 81 and 122; the later milestone must be checked
+    # first, otherwise the `epoch >= 122` branch is unreachable
+    decay = 0
+    if epoch >= 122:
+        decay = 2
+    elif epoch >= 81:
+        decay = 1
+    lr = args.lr * (0.1 ** decay)
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+    return lr
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
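Note on the schedule in adjust_learning_rate: this is the standard 164-epoch CIFAR recipe, with the learning rate cut by 10x at epochs 81 and 122. A minimal standalone sketch of the same milestone logic; the function name and the milestone tuple are illustrative, not part of the patch:

    def lr_at_epoch(base_lr, epoch, milestones=(81, 122), gamma=0.1):
        # each milestone already passed scales the learning rate by gamma
        decay = sum(1 for m in milestones if epoch >= m)
        return base_lr * (gamma ** decay)

    assert lr_at_epoch(0.1, 80) == 0.1               # before the first cut
    assert abs(lr_at_epoch(0.1, 81) - 1e-2) < 1e-12  # after the first cut
    assert abs(lr_at_epoch(0.1, 122) - 1e-3) < 1e-12 # after the second cut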
diff --git a/models/modules.py b/models/modules.py
index c4b5fc3..f81a0a3 100644
--- a/models/modules.py
+++ b/models/modules.py
@@ -12,14 +12,17 @@ class SoftmaxAttention(nn.Module):
     # implementation of Wang et al. "Residual Attention Network for Image Classification". CVPR, 2017.
 
-    def __init__(self, planes, residual=True):
+    def __init__(self, planes, residual=True, normalize=False):
         super(SoftmaxAttention, self).__init__()
         self.residual = residual
+        self.normalize = normalize
         self.bn1 = nn.BatchNorm2d(planes)
         self.conv1 = nn.Conv2d(planes, planes, kernel_size=1, bias=False)
         self.bn2 = nn.BatchNorm2d(planes)
         self.relu = nn.ReLU(inplace=True)
         self.conv2 = nn.Conv2d(planes, 1, kernel_size=1, bias=False)
+        if self.normalize:
+            self.bn3 = nn.BatchNorm2d(1)
         self.softmax = nn.Softmax()
         self.mask = None
 
@@ -34,6 +37,10 @@ def forward(self, x):
         mask = self.bn2(mask)
         mask = self.relu(mask)
         mask = self.conv2(mask)
+        # print('min: %.4f | max: %.4f' % (mask.data.min(), mask.data.max()))
+        if self.normalize:
+            mask = self.bn3(mask)
+        # print('min: %.4f | max: %.4f' % (mask.data.min(), mask.data.max()))
         mask = mask.view(mask.size(0), -1)
         mask = self.softmax(mask)
         mask = mask.view(mask.size(0), 1, x.size(2), x.size(3))
 
@@ -51,7 +58,7 @@ class SigmoidAttention(nn.Module):
     # implementation of Wang et al. "Residual Attention Network for Image Classification". CVPR, 2017.
 
-    def __init__(self, planes, residual=True):
+    def __init__(self, planes, residual=True, normalize=False):
         super(SigmoidAttention, self).__init__()
         self.residual = residual
         self.bn1 = nn.BatchNorm2d(planes)
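With normalize=True, the new bn3 whitens the single-channel mask logits before the spatial softmax, which keeps the softmax out of its saturated regime. A rough sketch of the module's contract under this patch; the batch and feature sizes are illustrative, and the import assumes the repo root is on the path:

    import torch
    from torch.autograd import Variable  # PyTorch 0.x style, matching the repo
    from models.modules import SoftmaxAttention

    att = SoftmaxAttention(planes=16, residual=False, normalize=True)
    x = Variable(torch.randn(2, 16, 8, 8))
    out = att(x)        # same shape as x: the mask is softmaxed over the 8x8
    print(out.size())   # grid (sums to 1 per sample) and multiplies x

With residual=True the module adds x back after masking, so attention can only reweight features, never zero them out.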
diff --git a/models/resadvnet.py b/models/resadvnet.py
index 0fec9dd..c71b89a 100644
--- a/models/resadvnet.py
+++ b/models/resadvnet.py
@@ -9,51 +9,49 @@
 from .hourglass import Hourglass
 from .modules import SoftmaxAttention
 
-__all__ = ['Attention', 'ResAdvNet', 'resadvnet20', 'resadvnet32', 'resadvnet44', 'resadvnet56',
-           'resadvnet110', 'resadvnet1202']
-
-class Attention(nn.Module):
-    # implementation of Wang et al. "Residual Attention Network for Image Classification". CVPR, 2017.
-    def __init__(self, block, p, t, r, planes, depth):
-        super(Attention, self).__init__()
-        self.p = p
-        self.t = t
-        out_planes = planes*block.expansion
-        self.residual = block(out_planes, planes)
-        self.hourglass = Hourglass(block, r, planes, depth)
-        self.fc1 = nn.Conv2d(out_planes, out_planes, kernel_size=1, bias=False)
-        self.fc2 = nn.Conv2d(out_planes, 1, kernel_size=1, bias=False)
-
-    def get_mask(self):
-        return self.mx
-
-    def forward(self, x):
-        # preprocessing
-        for i in range(0, self.p):
-            x = self.residual(x)
-
-        # trunk branch
-        tx = x
-        for i in range(0, self.p):
-            tx = self.residual(tx)
-
-        # mask branch
-        self.mx = F.sigmoid(self.fc2(self.fc1(self.hourglass(x))))
-
-        # residual attented feature
-        out = tx + tx*self.mx.expand_as(tx)
-
-        return out
+# __all__ = ['Attention', 'ResAdvNet', 'resadvnet20', 'resadvnet32', 'resadvnet44', 'resadvnet56',
+#            'resadvnet110', 'resadvnet1202']
+
+# class Attention(nn.Module):
+#     # implementation of Wang et al. "Residual Attention Network for Image Classification". CVPR, 2017.
+#     def __init__(self, block, p, t, r, planes, depth):
+#         super(Attention, self).__init__()
+#         self.p = p
+#         self.t = t
+#         out_planes = planes*block.expansion
+#         self.residual = block(out_planes, planes)
+#         self.hourglass = Hourglass(block, r, planes, depth)
+#         self.fc1 = nn.Conv2d(out_planes, out_planes, kernel_size=1, bias=False)
+#         self.fc2 = nn.Conv2d(out_planes, 1, kernel_size=1, bias=False)
+
+#     def get_mask(self):
+#         return self.mx
+
+#     def forward(self, x):
+#         # preprocessing
+#         for i in range(0, self.p):
+#             x = self.residual(x)
+
+#         # trunk branch
+#         tx = x
+#         for i in range(0, self.p):
+#             tx = self.residual(tx)
+
+#         # mask branch
+#         self.mx = F.sigmoid(self.fc2(self.fc1(self.hourglass(x))))
+
+#         # residual attended feature
+#         out = tx + tx*self.mx.expand_as(tx)
+
+#         return out
 
 
 class StackedAdversary(nn.Module):
-    def __init__(self, block, p, t, r, planes, num_stacks=3, depth=3):
+    def __init__(self, block, planes, num_stacks=3, residual=False, normalize=False):
         super(StackedAdversary, self).__init__()
         self.num_stacks = num_stacks
-        # self.attention = Attention(block, p, t, r, planes, depth)
-        # self.attention = SoftmaxAttention(planes)
         attentions = []
         for s in range(0, self.num_stacks):
-            attentions.append(SoftmaxAttention(planes, residual=False))
+            attentions.append(SoftmaxAttention(planes, normalize=normalize, residual=residual))
         self.attention = nn.ModuleList(attentions)
 
     def get_mask(self):
@@ -77,20 +75,20 @@ def forward(self, x):
 
 class ResAdvNet(nn.Module):
 
-    def __init__(self, block, layers, num_classes=1000):
+    def __init__(self, block, layers, residual=False, normalize=False, num_classes=1000):
         self.inplanes = 16
         super(ResAdvNet, self).__init__()
         self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
         self.layer1 = self._make_layer(block, 16, layers[0])
-        self.adv1 = StackedAdversary(block, 1, 2, 1, 16 * block.expansion,
-                                     num_stacks=5, depth=3)
+        self.adv1 = StackedAdversary(block, 16 * block.expansion,
+                                     num_stacks=5, residual=residual, normalize=normalize)
         self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
-        self.adv2 = StackedAdversary(block, 1, 2, 1, 32 * block.expansion,
-                                     num_stacks=5, depth=3)
+        self.adv2 = StackedAdversary(block, 32 * block.expansion,
+                                     num_stacks=5, residual=residual, normalize=normalize)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
-        self.adv3 = StackedAdversary(block, 1, 2, 1, 64 * block.expansion,
-                                     num_stacks=5, depth=3)
+        self.adv3 = StackedAdversary(block, 64 * block.expansion,
+                                     num_stacks=5, residual=residual, normalize=normalize)
         self.bn = nn.BatchNorm2d(64 * block.expansion)
         self.relu = nn.ReLU(inplace=True)
         self.avgpool = nn.AvgPool2d(8)
@@ -183,4 +181,88 @@ def resadvnet1202(**kwargs):
     """Constructs a ResAdvNet-1202 model.
     """
     model = ResAdvNet(Bottleneck, [200, 200, 200], **kwargs)
+    return model
+
+# ----------------------------
+
+def resadvbn20(**kwargs):
+    """Constructs a ResAdvNet-20 model with batch-normalized attention masks.
+    """
+    model = ResAdvNet(BasicBlock, [3, 3, 3], normalize=True, **kwargs)
+    return model
+
+
+def resadvbn32(**kwargs):
+    """Constructs a ResAdvNet-32 model with batch-normalized attention masks.
+    """
+    model = ResAdvNet(BasicBlock, [5, 5, 5], normalize=True, **kwargs)
+    return model
+
+
+def resadvbn44(**kwargs):
+    """Constructs a ResAdvNet-44 model with batch-normalized attention masks.
+    """
+    model = ResAdvNet(Bottleneck, [7, 7, 7], normalize=True, **kwargs)
+    return model
+
+
+def resadvbn56(**kwargs):
+    """Constructs a ResAdvNet-56 model with batch-normalized attention masks.
+    """
+    model = ResAdvNet(Bottleneck, [9, 9, 9], normalize=True, **kwargs)
+    return model
+
+
+def resadvbn110(**kwargs):
+    """Constructs a ResAdvNet-110 model with batch-normalized attention masks.
+    """
+    model = ResAdvNet(Bottleneck, [18, 18, 18], normalize=True, **kwargs)
+    return model
+
+def resadvbn1202(**kwargs):
+    """Constructs a ResAdvNet-1202 model with batch-normalized attention masks.
+    """
+    model = ResAdvNet(Bottleneck, [200, 200, 200], normalize=True, **kwargs)
+    return model
+
+# -------------------------------------
+
+def resadvbnres20(**kwargs):
+    """Constructs a ResAdvNet-20 model with batch-normalized, residual attention masks.
+    """
+    model = ResAdvNet(BasicBlock, [3, 3, 3], normalize=True, residual=True, **kwargs)
+    return model
+
+
+def resadvbnres32(**kwargs):
+    """Constructs a ResAdvNet-32 model with batch-normalized, residual attention masks.
+    """
+    model = ResAdvNet(BasicBlock, [5, 5, 5], normalize=True, residual=True, **kwargs)
+    return model
+
+
+def resadvbnres44(**kwargs):
+    """Constructs a ResAdvNet-44 model with batch-normalized, residual attention masks.
+    """
+    model = ResAdvNet(Bottleneck, [7, 7, 7], normalize=True, residual=True, **kwargs)
+    return model
+
+
+def resadvbnres56(**kwargs):
+    """Constructs a ResAdvNet-56 model with batch-normalized, residual attention masks.
+    """
+    model = ResAdvNet(Bottleneck, [9, 9, 9], normalize=True, residual=True, **kwargs)
+    return model
+
+
+def resadvbnres110(**kwargs):
+    """Constructs a ResAdvNet-110 model with batch-normalized, residual attention masks.
+    """
+    model = ResAdvNet(Bottleneck, [18, 18, 18], normalize=True, residual=True, **kwargs)
+    return model
+
+def resadvbnres1202(**kwargs):
+    """Constructs a ResAdvNet-1202 model with batch-normalized, residual attention masks.
+    """
+    model = ResAdvNet(Bottleneck, [200, 200, 200], normalize=True, residual=True, **kwargs)
     return model
\ No newline at end of file
+ """ + model = ResSoftAttNet(Bottleneck, [9, 9, 9], normalize=True, residual=True, **kwargs) + return model + + +def ressoftattbnres110(**kwargs): + """Constructs a ResSoftAttNet-110 model. + """ + model = ResSoftAttNet(Bottleneck, [18, 18, 18], normalize=True, residual=True, **kwargs) + return model + +def ressoftattbnres1202(**kwargs): + """Constructs a ResSoftAttNet-1202 model. + """ + model = ResSoftAttNet(Bottleneck, [200, 200, 200], normalize=True, residual=True, **kwargs) return model \ No newline at end of file diff --git a/utils/visualize.py b/utils/visualize.py index 3bf0892..51abeed 100644 --- a/utils/visualize.py +++ b/utils/visualize.py @@ -6,7 +6,7 @@ import numpy as np from .misc import * -__all__ = ['make_image', 'show_batch', 'show_mask'] +__all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] # functions to show an image def make_image(img, mean=(0,0,0), std=(1,1,1)): @@ -41,17 +41,34 @@ def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): plt.show() -# def show_mask(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): -# plt.figure(1) -# images = imshow(torchvision.utils.make_grid(images), Mean, Std) -# plt.subplot(211) -# plt.imshow(images) +def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): + im_size = images.size(2) -# masks = imshow(torchvision.utils.make_grid(mask)) -# plt.subplot(212) -# plt.imshow(masks) + # save for adding mask + im_data = images.clone() + for i in range(0, 3): + im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize -# plt.show() + images = make_image(torchvision.utils.make_grid(images), Mean, Std) + plt.subplot(2, 1, 1) + plt.imshow(images) + plt.axis('off') + + # for b in range(mask.size(0)): + # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) + mask_size = mask.size(2) + # print('Max %f Min %f' % (mask.max(), mask.min())) + mask = (upsampling(mask, scale_factor=im_size/mask_size)) + # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) + # for c in range(3): + # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] + + # print(mask.size()) + mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) + # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) + plt.subplot(2, 1, 2) + plt.imshow(mask) + plt.axis('off') def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): im_size = images.size(2) @@ -68,7 +85,6 @@ def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): for i in range(len(masklist)): mask = masklist[i].data.cpu() - # print(mask.size()) # for b in range(mask.size(0)): # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) mask_size = mask.size(2)