Add PSPNet (#179)
* psp
zhanghang1989 authored Jun 26, 2018
1 parent e915c0b commit 5921740
Showing 15 changed files with 495 additions and 204 deletions.
18 changes: 8 additions & 10 deletions docs/api/model_zoo.rst
@@ -119,18 +119,10 @@ Object Detection
faster_rcnn_resnet50_v2a_coco


.. currentmodule:: gluoncv.model_zoo

Semantic Segmentation
^^^^^^^^^^^^^^^^^^^^^

.. :hidden:`BaseModel`
.. ~~~~~~~~~~~~~~~~~~~
.. .. autosummary::
.. :nosignatures:
.. segbase.SegBaseModel
.. currentmodule:: gluoncv.model_zoo

:hidden:`FCN`
~~~~~~~~~~~~~
@@ -148,11 +140,17 @@ Semantic Segmentation

get_fcn_ade_resnet50

:hidden:`PSPNet`
~~~~~~~~~~~~~~~~

.. autosummary::
    :nosignatures:

    PSPNet

    get_psp

    get_psp_ade_resnet50


API Reference
17 changes: 2 additions & 15 deletions docs/model_zoo/index.rst
@@ -231,25 +231,12 @@ Table of pre-trained models for semantic segmentation and their performance.
+-------------------+--------------+-----------+-----------+-----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| fcn_resnet50_ade | FCN [6]_ | 78.6 | 38.7 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/fcn_resnet50_ade.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/fcn_resnet50_ade.log>`_ |
+-------------------+--------------+-----------+-----------+-----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| psp_resnet50_ade | PSP [9]_ | 78.4 | 41.1 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet50_ade.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet50_ade.log>`_ |
+-------------------+--------------+-----------+-----------+-----------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+

.. _69.4: http://host.robots.ox.ac.uk:8080/anonymous/TC12D2.html
.. _70.9: http://host.robots.ox.ac.uk:8080/anonymous/FTIQXJ.html

.. raw:: html

<code xml:space="preserve" id="cmd_fcn_50" style="display: none; text-align: left; white-space: pre-wrap">
# First training on augmented set
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset pascal_aug --model fcn --backbone resnet50 --lr 0.001 --syncbn --checkname mycheckpoint
# Finetuning on original set
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset pascal_voc --model fcn --backbone resnet50 --lr 0.0001 --syncbn --checkname mycheckpoint --resume runs/pascal_aug/fcn/mycheckpoint/checkpoint.params
</code>

<code xml:space="preserve" id="cmd_fcn_101" style="display: none; text-align: left; white-space: pre-wrap">
# First training on augmented set
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset pascal_aug --model fcn --backbone resnet101 --lr 0.001 --syncbn --checkname mycheckpoint
# Finetuning on original set
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset pascal_voc --model fcn --backbone resnet101 --lr 0.0001 --syncbn --checkname mycheckpoint --resume runs/pascal_aug/fcn/mycheckpoint/checkpoint.params
</code>

.. [1] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. \
"Deep residual learning for image recognition." \
2 changes: 1 addition & 1 deletion docs/tutorials/segmentation/demo_fcn.py
@@ -1,7 +1,7 @@
"""1. Getting Started with FCN Pre-trained Models
==============================================
This is a quick demo of using GluonCV FCN model.
This is a quick demo of using GluonCV FCN model on PASCAL VOC dataset.
Please follow the `installation guide <../index.html>`_ to install MXNet and GluonCV if you have not done so yet.
"""
import mxnet as mx
65 changes: 65 additions & 0 deletions docs/tutorials/segmentation/demo_psp.py
@@ -0,0 +1,65 @@
"""2. Test with PSPNet Pre-trained Models
======================================
This is a quick demo of using the GluonCV PSPNet model on the ADE20K dataset.
Please follow the `installation guide <../index.html>`_ to install MXNet and GluonCV if you have not done so yet.
"""
import mxnet as mx
from mxnet import image
from mxnet.gluon.data.vision import transforms
import gluoncv
# using cpu
ctx = mx.cpu(0)


##############################################################################
# Prepare the image
# -----------------
#
# download the example image
url = 'https://github.com/zhanghang1989/image-data/blob/master/encoding/' + \
'segmentation/ade20k/ADE_val_00001142.jpg?raw=true'
filename = 'ade20k_example.jpg'
gluoncv.utils.download(url, filename)

##############################################################################
# load the image
img = image.imread(filename)

from matplotlib import pyplot as plt
plt.imshow(img.asnumpy())
plt.show()

##############################################################################
# normalize the image using dataset mean
transform_fn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225])
])
img = transform_fn(img)
img = img.expand_dims(0).as_in_context(ctx)
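
##############################################################################
# The ``ToTensor`` and ``Normalize`` pair above first scales pixel values to
# the [0, 1] range and then standardizes each channel with the given mean and
# standard deviation. The following is a minimal pure-Python sketch of that
# arithmetic for a single RGB pixel. The helper names ``normalize_pixel`` and
# ``denormalize_pixel`` are hypothetical and are not part of GluonCV.

```python
# Per-channel statistics used in the tutorial (RGB order).
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


def denormalize_pixel(normalized):
    """Invert normalize_pixel and round back to 8-bit values."""
    return tuple(round((v * s + m) * 255.0) for v, m, s in zip(normalized, MEAN, STD))


pixel = (124, 116, 104)
normalized = normalize_pixel(pixel)
print(normalized)
print(denormalize_pixel(normalized))
```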

##############################################################################
# Load the pre-trained model and make prediction
# ----------------------------------------------
#
# get pre-trained model
model = gluoncv.model_zoo.get_model('psp_resnet50_ade', pretrained=True)

##############################################################################
# make prediction using single scale
output = model.demo(img)
predict = mx.nd.squeeze(mx.nd.argmax(output, 1)).asnumpy()
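
##############################################################################
# The prediction step above takes an argmax over the class axis, so every
# pixel receives the index of its highest-scoring class. A minimal pure-Python
# sketch of that reduction on a toy 3-class, 2x2 score volume follows. The
# function name ``argmax_over_classes`` and the toy scores are illustrative
# assumptions only.

```python
def argmax_over_classes(scores):
    """Map scores[k][y][x] to a label map labels[y][x] of best class indices."""
    num_classes = len(scores)
    height, width = len(scores[0]), len(scores[0][0])
    return [
        [max(range(num_classes), key=lambda k: scores[k][y][x]) for x in range(width)]
        for y in range(height)
    ]


# toy_scores[class][row][col]
toy_scores = [
    [[0.1, 0.7], [0.2, 0.2]],
    [[0.8, 0.1], [0.3, 0.1]],
    [[0.1, 0.2], [0.5, 0.7]],
]
label_map = argmax_over_classes(toy_scores)
print(label_map)
```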

##############################################################################
# Add color palette for visualization
from gluoncv.utils.viz import get_color_pallete
import matplotlib.image as mpimg
mask = get_color_pallete(predict, 'ade20k')
mask.save('output.png')

##############################################################################
# show the predicted mask
mmask = mpimg.imread('output.png')
plt.imshow(mmask)
plt.show()
6 changes: 3 additions & 3 deletions docs/tutorials/segmentation/train_fcn.py
@@ -1,4 +1,4 @@
"""2. Train FCN on Pascal VOC Dataset
"""3. Train FCN on Pascal VOC Dataset
=====================================
This is a semantic segmentation tutorial using Gluon Vision, a step-by-step example.
@@ -133,8 +133,8 @@
# For example, we can easily get the Pascal VOC 2012 dataset:
trainset = gluoncv.data.VOCSegmentation(split='train', transform=input_transform)
print('Training images:', len(trainset))
# set batch_size = 4 for toy example
batch_size = 4
# set batch_size = 2 for toy example
batch_size = 2
# Create Training Loader
train_data = gluon.data.DataLoader(
trainset, batch_size, shuffle=True, last_batch='rollover',
231 changes: 231 additions & 0 deletions docs/tutorials/segmentation/train_psp.py
@@ -0,0 +1,231 @@
"""3. Train PSPNet on ADE20K Dataset
=================================

This tutorial shows how to train PSPNet on the ADE20K dataset using Gluon Vision.
Readers should have basic knowledge of deep learning and be familiar with the Gluon API.
New users may first go through `A 60-minute Gluon Crash Course <http://gluon-crash-course.mxnet.io/>`_.
You can `Start Training Now`_ or `Dive into Deep`_.

Start Training Now
~~~~~~~~~~~~~~~~~~

.. note::

    Training PSPNet relies on Synchronized Batch Normalization, which will be available shortly.

.. hint::

    Feel free to skip the tutorial because the training script is self-contained and ready to launch.

    :download:`Download Full Python Script: train.py<../../../scripts/segmentation/train.py>`

    Example training command::

        CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ade20k --model psp --backbone resnet50 --lr 0.001 --checkname mycheckpoint

    For more training command options, please run ``python train.py -h``.
    Please check out the `model_zoo <../model_zoo/index.html#semantic-segmentation>`_ for the training commands that reproduce the pretrained models.

Dive into Deep
~~~~~~~~~~~~~~
"""
import numpy as np
import mxnet as mx
from mxnet import gluon, autograd
import gluoncv

##############################################################################
# Pyramid Scene Parsing Network
# -----------------------------
#
# .. image:: https://hszhao.github.io/projects/pspnet/figures/pspnet.png
#     :width: 80%
#     :align: center
#
# (figure credit to `Zhao et al. <https://arxiv.org/pdf/1612.01105.pdf>`_ )
#
# Pyramid Scene Parsing Network (PSPNet) [Zhao17]_ exploits
# global context information by aggregating context from different
# regions through the pyramid pooling module.
#


##############################################################################
# PSPNet Model
# ------------
#
# A Pyramid Pooling Module is built on top of FCN and combines multi-scale
# features with different receptive field sizes. It pools the feature maps
# to several sizes, upsamples the results, and concatenates them together.
#
# The Pyramid Pooling Module is defined as::
#
#     class _PyramidPooling(HybridBlock):
#         def __init__(self, in_channels, **kwargs):
#             super(_PyramidPooling, self).__init__()
#             out_channels = int(in_channels/4)
#             with self.name_scope():
#                 self.conv1 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
#                 self.conv2 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
#                 self.conv3 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
#                 self.conv4 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
#
#         def pool(self, F, x, size):
#             return F.contrib.AdaptiveAvgPooling2D(x, output_size=size)
#
#         def upsample(self, F, x, h, w):
#             return F.contrib.BilinearResize2D(x, height=h, width=w)
#
#         def hybrid_forward(self, F, x):
#             _, _, h, w = x.shape
#             feat1 = self.upsample(F, self.conv1(self.pool(F, x, 1)), h, w)
#             feat2 = self.upsample(F, self.conv2(self.pool(F, x, 2)), h, w)
#             feat3 = self.upsample(F, self.conv3(self.pool(F, x, 3)), h, w)
#             feat4 = self.upsample(F, self.conv4(self.pool(F, x, 4)), h, w)
#             return F.concat(x, feat1, feat2, feat3, feat4, dim=1)
#
# PSPNet model is provided in :class:`gluoncv.model_zoo.PSPNet`. To get
# PSP model using ResNet50 base network for ADE20K dataset:
model = gluoncv.model_zoo.get_psp(dataset='ade20k', backbone='resnet50', pretrained=False)
print(model)
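
##############################################################################
# To make the pooling step of the Pyramid Pooling Module more concrete, here
# is a minimal pure-Python sketch of adaptive average pooling on a single
# 2-D feature map, together with the channel count produced by concatenating
# the input with four reduced branches. The bin boundaries (floor of the start,
# ceiling of the end) follow one common convention and are an assumption here,
# not a statement about how ``F.contrib.AdaptiveAvgPooling2D`` is implemented.
# Both function names are hypothetical.

```python
import math


def adaptive_avg_pool2d(grid, out_size):
    """Average-pool a 2-D list of numbers into out_size x out_size bins."""
    h, w = len(grid), len(grid[0])
    pooled = []
    for i in range(out_size):
        r0, r1 = (i * h) // out_size, math.ceil((i + 1) * h / out_size)
        row = []
        for j in range(out_size):
            c0, c1 = (j * w) // out_size, math.ceil((j + 1) * w / out_size)
            cells = [grid[r][c] for r in range(r0, r1) for c in range(c0, c1)]
            row.append(sum(cells) / len(cells))
        pooled.append(row)
    return pooled


def pyramid_output_channels(in_channels, num_branches=4):
    """Channels after concatenating the input with the reduced branches."""
    # Each branch reduces the input to in_channels // 4 channels.
    return in_channels + num_branches * (in_channels // 4)


# A 4x4 feature map holding the values 0..15.
feature_map = [[float(r * 4 + c) for c in range(4)] for r in range(4)]
for size in (1, 2, 3, 4):
    print(size, adaptive_avg_pool2d(feature_map, size))
print(pyramid_output_channels(2048))
```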

##############################################################################
# Dataset and Data Augmentation
# -----------------------------
#
# image transform for color normalization
from mxnet.gluon.data.vision import transforms
input_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
])

##############################################################################
# We provide semantic segmentation datasets in :class:`gluoncv.data`.
# For example, we can easily get the ADE20K dataset:
trainset = gluoncv.data.ADE20KSegmentation(split='train', transform=input_transform)
print('Training images:', len(trainset))
# set batch_size = 2 for toy example
batch_size = 2
# Create Training Loader
train_data = gluon.data.DataLoader(
    trainset, batch_size, shuffle=True, last_batch='rollover',
    num_workers=batch_size)

##############################################################################
# For data augmentation,
# we follow the standard data augmentation routine to transform the input image
# and the ground truth label map synchronously. (*Note that "nearest"
# mode upsampling is applied to the label maps to avoid corrupting the boundaries.*)
# We first randomly scale the input image by a factor of 0.5 to 2.0, then rotate
# the image by -10 to 10 degrees, and crop the image with padding if needed.
# Finally, a random Gaussian blur is applied.
#
# Randomly pick one example for visualization:
import random
# seed from OS entropy; randrange excludes len(trainset), so idx is always valid
random.seed()
idx = random.randrange(len(trainset))
img, mask = trainset[idx]
from gluoncv.utils.viz import get_color_pallete, DeNormalize
# get the color palette to visualize the mask
mask = get_color_pallete(mask.asnumpy(), dataset='ade20k')
mask.save('mask.png')
# denormalize the image
img = DeNormalize([.485, .456, .406], [.229, .224, .225])(img)
img = np.transpose((img.asnumpy()*255).astype(np.uint8), (1, 2, 0))
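
##############################################################################
# The note above about "nearest" mode for label maps matters because label
# values are class indices, not intensities, so interpolating between them
# would invent classes that are not present. The following is a minimal
# pure-Python sketch of nearest-neighbour resizing. The function name
# ``resize_nearest`` and the toy label map are illustrative assumptions and
# do not reproduce the GluonCV augmentation code.

```python
def resize_nearest(label_map, out_h, out_w):
    """Resize a 2-D label map by copying the nearest source label."""
    in_h, in_w = len(label_map), len(label_map[0])
    return [
        [label_map[(y * in_h) // out_h][(x * in_w) // out_w] for x in range(out_w)]
        for y in range(out_h)
    ]


labels = [[3, 7], [7, 12]]
upsampled = resize_nearest(labels, 4, 4)
for row in upsampled:
    print(row)

# Every output value is one of the original class indices.
original_classes = {v for row in labels for v in row}
print({v for row in upsampled for v in row} <= original_classes)
```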

##############################################################################
# Plot the image and mask
from matplotlib import pyplot as plt
import matplotlib.image as mpimg
# subplot 1 for img
fig = plt.figure()
fig.add_subplot(1,2,1)

plt.imshow(img)
# subplot 2 for the mask
mmask = mpimg.imread('mask.png')
fig.add_subplot(1,2,2)
plt.imshow(mmask)
# display
plt.show()

##############################################################################
# Training Details
# ----------------
#
# - Training Losses:
#
# We apply a standard per-pixel Softmax Cross Entropy Loss to train PSPNet.
# Additionally, an Auxiliary Loss as in PSPNet [Zhao17]_ at Stage 3 can be enabled when
# training with command ``--aux``. This will create an additional FCN "head" after Stage 3.
#
from gluoncv.model_zoo.segbase import SoftmaxCrossEntropyLossWithAux
criterion = SoftmaxCrossEntropyLossWithAux(aux=True)
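
##############################################################################
# As a rough illustration of how a main loss and an auxiliary loss can be
# combined, the sketch below computes a softmax cross entropy for one pixel
# from each head and adds them with a weight. The function names and the
# ``aux_weight`` value of 0.4 are assumptions chosen for illustration. They
# are not taken from ``SoftmaxCrossEntropyLossWithAux``.

```python
import math


def softmax_cross_entropy(logits, target):
    """Cross entropy of one pixel: log-sum-exp of the logits minus the target logit."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_sum_exp - logits[target]


def loss_with_aux(main_logits, aux_logits, target, aux_weight=0.4):
    """Main-head loss plus a down-weighted auxiliary-head loss."""
    main_loss = softmax_cross_entropy(main_logits, target)
    aux_loss = softmax_cross_entropy(aux_logits, target)
    return main_loss + aux_weight * aux_loss


# One pixel, three classes, ground-truth class 2.
print(loss_with_aux([0.5, 1.0, 3.0], [0.2, 0.1, 1.5], target=2))
```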

##############################################################################
# - Learning Rate and Scheduling:
#
# We use different learning rates for the PSP "head" and the base network. For the PSP "head",
# we use :math:`10\times` the base learning rate, because those layers are learned from scratch.
# We use a poly-like learning rate scheduler for PSPNet training, provided in :class:`gluoncv.utils.LRScheduler`.
# The learning rate is given by :math:`lr = baselr \times (1 - iter / maxiter)^{power}`,
# where :math:`iter` is the current iteration and :math:`maxiter` is the total number of training iterations.
#
lr_scheduler = gluoncv.utils.LRScheduler(mode='poly', baselr=0.001, niters=len(train_data),
                                         nepochs=50)
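
##############################################################################
# For intuition about the poly schedule, here is a minimal pure-Python sketch
# that evaluates the decay at a few points of a toy run. The function
# ``poly_lr``, the default ``power`` of 0.9, and the iteration counts are
# assumptions for illustration. The sketch normalizes the iteration by the
# total number of iterations so that the rate reaches zero at the end.

```python
def poly_lr(base_lr, iteration, total_iterations, power=0.9):
    """Polynomial decay from base_lr at iteration 0 to 0 at the last iteration."""
    progress = iteration / total_iterations
    return base_lr * (1.0 - progress) ** power


total = 1000  # hypothetical total number of training iterations
for step in (0, 250, 500, 750, 1000):
    print(step, poly_lr(0.001, step, total))
```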

##############################################################################
# - Dataparallel for multi-gpu training, using cpu for demo only
from gluoncv.utils.parallel import *
ctx_list = [mx.cpu(0)]
model = DataParallelModel(model, ctx_list)
criterion = DataParallelCriterion(criterion, ctx_list)

##############################################################################
# - Create SGD solver
kv = mx.kv.create('local')
optimizer = gluon.Trainer(model.module.collect_params(), 'sgd',
                          {'lr_scheduler': lr_scheduler,
                           'wd':0.0001,
                           'momentum': 0.9,
                           'multi_precision': True},
                          kvstore = kv)

##############################################################################
# The training loop
# -----------------
#
train_loss = 0.0
epoch = 0
for i, (data, target) in enumerate(train_data):
    lr_scheduler.update(i, epoch)
    with autograd.record(True):
        outputs = model(data)
        losses = criterion(outputs, target)
        mx.nd.waitall()
        autograd.backward(losses)
    optimizer.step(batch_size)
    for loss in losses:
        train_loss += loss.asnumpy()[0] / len(losses)
    print('Epoch %d, batch %d, training loss %.3f'%(epoch, i, train_loss/(i+1)))
    # just demo for 2 iters
    if i > 1:
        print('Terminated for this demo...')
        break


##############################################################################
# You can `Start Training Now`_.
#
# References
# ----------
#
# .. [Long15] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. \
# "Fully convolutional networks for semantic segmentation." \
# Proceedings of the IEEE conference on computer vision and pattern recognition. 2015.
#
# .. [Zhao17] Zhao, Hengshuang, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, and Jiaya Jia. \
# "Pyramid scene parsing network." IEEE Conf. on Computer Vision and Pattern Recognition (CVPR). 2017.
#
