Skip to content

Commit

Permalink
Add deeplabv3+b model for semantic segmentation (#924)
Browse files Browse the repository at this point in the history
* new seg model for citys

* pylint

* add model store

* remove inplace add

* remove copy

* add init

* fix pylint

* fix pylint

* fix shape

* add unit test

* add init

* model zoo
  • Loading branch information
bryanyzhu authored Sep 3, 2019
1 parent 9ea81b4 commit 384f30e
Show file tree
Hide file tree
Showing 9 changed files with 634 additions and 7 deletions.
21 changes: 16 additions & 5 deletions docs/model_zoo/segmentation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,23 @@ Pascal VOC Dataset
+-----------------------+-----------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| deeplab_resnet152_voc | DeepLabV3 [4]_ | N/A | 86.7_ | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_resnet152_voc.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_resnet152_voc.log>`_ |
+-----------------------+-----------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| psp_resnet101_citys | PSP [3]_ | N/A | 77.1 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet101_city.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet101_city.log>`_ |
+-----------------------+-----------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+

.. _83.6: http://host.robots.ox.ac.uk:8080/anonymous/YB1AN5.html
.. _85.1: http://host.robots.ox.ac.uk:8080/anonymous/9RTTZC.html
.. _86.2: http://host.robots.ox.ac.uk:8080/anonymous/ZPN6II.html
.. _86.7: http://host.robots.ox.ac.uk:8080/anonymous/XZEXL2.html

Cityscapes Dataset
------------------

+-------------------------------------+-----------------+-----------+-----------+---------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------+
| Name | Method | pixAcc | mIoU | Command | log |
+=====================================+=================+===========+===========+=============================================================================================================================================+====================================================================================================================================+
| psp_resnet101_citys | PSP [3]_ | N/A | 77.1 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet101_city.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet101_city.log>`_ |
+-------------------------------------+-----------------+-----------+-----------+---------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------+
| deeplab_v3b_plus_wideresnet_citys | VPLR [5]_ | N/A | 83.5 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_v3b_plus_wideresnet_citys.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_v3b_plus_wideresnet_citys.log>`_ |
+-------------------------------------+-----------------+-----------+-----------+---------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------+


Instance Segmentation
~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -125,12 +134,14 @@ MS COCO
+------------------------------------+---------------------------+--------------------------+------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------+

.. [1] He, Kaming, Georgia Gkioxari, Piotr Dollár and Ross Girshick. \
"Mask R-CNN." \
In IEEE International Conference on Computer Vision (ICCV), 2017.
"Mask R-CNN." \
In IEEE International Conference on Computer Vision (ICCV), 2017.
.. [2] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. \
"Fully convolutional networks for semantic segmentation." \
Proceedings of the IEEE conference on computer vision and pattern recognition. 2015.
.. [3] Zhao, Hengshuang, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, and Jiaya Jia. \
"Pyramid scene parsing network." *CVPR*, 2017
"Pyramid scene parsing network." *CVPR*, 2017.
.. [4] Chen, Liang-Chieh, et al. "Rethinking atrous convolution for semantic image segmentation." \
arXiv preprint arXiv:1706.05587 (2017).
.. [5] Zhu, Yi, et al. "Improving Semantic Segmentation via Video Propagation and Label Relaxation." \
CVPR 2019.
2 changes: 2 additions & 0 deletions gluoncv/model_zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
from .pspnet import *
from .deeplabv3 import *
from .deeplabv3_plus import *
from .deeplabv3b_plus import *
from . import segbase
from .resnetv1b import *
from .se_resnet import *
from .nasnet import *
from .simple_pose.simple_pose_resnet import *
from .action_recognition import *
from .wideresnet import *

from .alexnet import *
from .densenet import *
Expand Down
266 changes: 266 additions & 0 deletions gluoncv/model_zoo/deeplabv3b_plus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
"""DeepLabV3+ with wideresnet backbone for semantic segmentation"""
# pylint: disable=missing-docstring,arguments-differ,unused-argument
from mxnet.gluon import nn
from mxnet.context import cpu
from mxnet.gluon.nn import HybridBlock
from .wideresnet import wider_resnet38_a2

__all__ = ['DeepLabWV3Plus', 'get_deeplabv3b_plus', 'get_deeplab_v3b_plus_wideresnet_citys']

class DeepLabWV3Plus(HybridBlock):
r"""DeepLabWV3Plus
Parameters
----------
nclass : int
Number of categories for the training dataset.
backbone : string
Pre-trained dilated backbone network type (default:'wideresnet').
norm_layer : object
Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
for Synchronized Cross-GPU BachNormalization).
aux : bool
Auxiliary loss.
Reference:
Chen, Liang-Chieh, et al. "Encoder-Decoder with Atrous Separable Convolution for Semantic
Image Segmentation.", https://arxiv.org/abs/1802.02611, ECCV 2018
"""
def __init__(self, nclass, backbone='wideresnet', aux=False, ctx=cpu(), pretrained_base=True,
height=None, width=None, base_size=520, crop_size=480, dilated=True, **kwargs):
super(DeepLabWV3Plus, self).__init__()

height = height if height is not None else crop_size
width = width if width is not None else crop_size
self._up_kwargs = {'height': height, 'width': width}
self.base_size = base_size
self.crop_size = crop_size
print('self.crop_size', self.crop_size)

with self.name_scope():
pretrained = wider_resnet38_a2(classes=1000, dilation=True)
pretrained.initialize(ctx=ctx)
self.mod1 = pretrained.mod1
self.mod2 = pretrained.mod2
self.mod3 = pretrained.mod3
self.mod4 = pretrained.mod4
self.mod5 = pretrained.mod5
self.mod6 = pretrained.mod6
self.mod7 = pretrained.mod7
self.pool2 = pretrained.pool2
self.pool3 = pretrained.pool3
del pretrained
self.head = _DeepLabHead(nclass, height=height//2, width=width//2, **kwargs)
self.head.initialize(ctx=ctx)

def hybrid_forward(self, F, x):
outputs = []
x = self.mod1(x)
m2 = self.mod2(self.pool2(x))
x = self.mod3(self.pool3(m2))
x = self.mod4(x)
x = self.mod5(x)
x = self.mod6(x)
x = self.mod7(x)
x = self.head(x, m2)
x = F.contrib.BilinearResize2D(x, **self._up_kwargs)
outputs.append(x)
return tuple(outputs)

def demo(self, x):
return self.predict(x)

def predict(self, x):
h, w = x.shape[2:]
self._up_kwargs['height'] = h
self._up_kwargs['width'] = w
x = self.mod1(x)
m2 = self.mod2(self.pool2(x))
x = self.mod3(self.pool3(m2))
x = self.mod4(x)
x = self.mod5(x)
x = self.mod6(x)
x = self.mod7(x)
x = self.head.demo(x, m2)
import mxnet.ndarray as F
x = F.contrib.BilinearResize2D(x, **self._up_kwargs)
return x

class _DeepLabHead(HybridBlock):
def __init__(self, nclass, c1_channels=128, norm_layer=nn.BatchNorm, norm_kwargs=None,
height=240, width=240, **kwargs):
super(_DeepLabHead, self).__init__()
self._up_kwargs = {'height': height, 'width': width}
with self.name_scope():
self.aspp = _ASPP(in_channels=4096, atrous_rates=[12, 24, 36], norm_layer=norm_layer,
norm_kwargs=norm_kwargs, height=height//4, width=width//4, **kwargs)

self.c1_block = nn.HybridSequential(prefix='bot_fine_')
self.c1_block.add(nn.Conv2D(in_channels=c1_channels, channels=48,
kernel_size=1, use_bias=False))

self.block = nn.HybridSequential(prefix='final_')
self.block.add(nn.Conv2D(in_channels=304, channels=256,
kernel_size=3, padding=1, use_bias=False))
self.block.add(norm_layer(in_channels=256,
**({} if norm_kwargs is None else norm_kwargs)))
self.block.add(nn.Activation('relu'))
self.block.add(nn.Conv2D(in_channels=256, channels=256,
kernel_size=3, padding=1, use_bias=False))
self.block.add(norm_layer(in_channels=256,
**({} if norm_kwargs is None else norm_kwargs)))
self.block.add(nn.Activation('relu'))
self.block.add(nn.Conv2D(in_channels=256, channels=nclass,
kernel_size=1, use_bias=False))

def hybrid_forward(self, F, x, c1):
c1 = self.c1_block(c1)
x = self.aspp(x)
x = F.contrib.BilinearResize2D(x, **self._up_kwargs)
return self.block(F.concat(c1, x, dim=1))

def demo(self, x, c1):
h, w = c1.shape[2:]
self._up_kwargs['height'] = h
self._up_kwargs['width'] = w
c1 = self.c1_block(c1)
x = self.aspp.demo(x)
import mxnet.ndarray as F
x = F.contrib.BilinearResize2D(x, **self._up_kwargs)
return self.block(F.concat(c1, x, dim=1))

def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer, norm_kwargs):
block = nn.HybridSequential()
with block.name_scope():
block.add(nn.Conv2D(in_channels=in_channels, channels=out_channels,
kernel_size=3, padding=atrous_rate,
dilation=atrous_rate, use_bias=False))
block.add(norm_layer(in_channels=out_channels,
**({} if norm_kwargs is None else norm_kwargs)))
block.add(nn.Activation('relu'))
return block

class _AsppPooling(nn.HybridBlock):
def __init__(self, in_channels, out_channels, norm_layer, norm_kwargs,
height=60, width=60, **kwargs):
super(_AsppPooling, self).__init__()
self.gap = nn.HybridSequential()
self._up_kwargs = {'height': height, 'width': width}
with self.gap.name_scope():
self.gap.add(nn.GlobalAvgPool2D())
self.gap.add(nn.Conv2D(in_channels=in_channels, channels=out_channels,
kernel_size=1, use_bias=False))
self.gap.add(norm_layer(in_channels=out_channels,
**({} if norm_kwargs is None else norm_kwargs)))
self.gap.add(nn.Activation("relu"))

def hybrid_forward(self, F, x):
pool = self.gap(x)
return F.contrib.BilinearResize2D(pool, **self._up_kwargs)

def demo(self, x):
h, w = x.shape[2:]
self._up_kwargs['height'] = h
self._up_kwargs['width'] = w
pool = self.gap(x)
import mxnet.ndarray as F
return F.contrib.BilinearResize2D(pool, **self._up_kwargs)

class _ASPP(nn.HybridBlock):
def __init__(self, in_channels, atrous_rates, norm_layer, norm_kwargs,
height=60, width=60):
super(_ASPP, self).__init__()
out_channels = 256
self.b0 = nn.HybridSequential()
self.b0.add(nn.Conv2D(in_channels=in_channels, channels=out_channels,
kernel_size=1, use_bias=False))
self.b0.add(norm_layer(in_channels=out_channels,
**({} if norm_kwargs is None else norm_kwargs)))
self.b0.add(nn.Activation("relu"))

rate1, rate2, rate3 = tuple(atrous_rates)
self.b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer, norm_kwargs)
self.b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer, norm_kwargs)
self.b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer, norm_kwargs)
self.b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer,
norm_kwargs=norm_kwargs, height=height, width=width)

self.project = nn.HybridSequential(prefix='bot_aspp_')
self.project.add(nn.Conv2D(in_channels=5*out_channels, channels=out_channels,
kernel_size=1, use_bias=False))

def hybrid_forward(self, F, x):
feat1 = self.b0(x)
feat2 = self.b1(x)
feat3 = self.b2(x)
feat4 = self.b3(x)
x = self.b4(x)
x = F.concat(x, feat1, feat2, feat3, feat4, dim=1)
return self.project(x)

def demo(self, x):
feat1 = self.b0(x)
feat2 = self.b1(x)
feat3 = self.b2(x)
feat4 = self.b3(x)
x = self.b4.demo(x)
import mxnet.ndarray as F
x = F.concat(x, feat1, feat2, feat3, feat4, dim=1)
return self.project(x)

def get_deeplabv3b_plus(dataset='citys', backbone='wideresnet', pretrained=False,
root='~/.mxnet/models', ctx=cpu(0), **kwargs):
r"""DeepLabWV3Plus
Parameters
----------
dataset : str, default pascal_voc
The dataset that model pretrained on. (pascal_voc, ade20k, citys)
pretrained : bool or str
Boolean value controls whether to load the default pretrained weights for model.
String value represents the hashtag for a certain version of pretrained weights.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_deeplabv3b_plus(dataset='citys', backbone='wideresnet', pretrained=False)
>>> print(model)
"""
acronyms = {
'pascal_voc': 'voc',
'pascal_aug': 'voc',
'ade20k': 'ade',
'coco': 'coco',
'citys': 'citys',
}
from ..data import datasets
# infer number of classes
model = DeepLabWV3Plus(datasets[dataset].NUM_CLASS, backbone=backbone, ctx=ctx, **kwargs)
model.classes = datasets[dataset].classes
if pretrained:
from .model_store import get_model_file
model.load_parameters(get_model_file('deeplab_v3b_plus_%s_%s'%(backbone, acronyms[dataset]),
tag=pretrained, root=root), ctx=ctx)
return model

def get_deeplab_v3b_plus_wideresnet_citys(**kwargs):
r"""DeepLabWV3Plus
Parameters
----------
pretrained : bool or str
Boolean value controls whether to load the default pretrained weights for model.
String value represents the hashtag for a certain version of pretrained weights.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_deeplab_v3b_plus_wideresnet_citys(pretrained=True)
>>> print(model)
"""
return get_deeplabv3b_plus('citys', 'wideresnet', **kwargs)
1 change: 1 addition & 0 deletions gluoncv/model_zoo/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
('3f220f537400dfa607c3d041ed3b172db39b0b01', 'psp_resnet50_ade'),
('240a4758b506447faf7c55cd7a7837d66f5039a6', 'psp_resnet101_ade'),
('0f49fb59180c4d91305b858380a4fd6eaf068b6c', 'psp_resnet101_citys'),
('ef2bb40ad8f8f59f451969b2fabe4e548394e80a', 'deeplab_v3b_plus_wideresnet_citys'),
('f5ece5ce1422eeca3ce2908004e469ffdf91fd41', 'yolo3_darknet53_voc'),
('3b47835ac3dd80f29576633949aa58aee3094353', 'yolo3_mobilenet1.0_voc'),
('66dbbae67be8f1e3cd3c995ce626a2bdc89769c6', 'yolo3_mobilenet1.0_coco'),
Expand Down
2 changes: 2 additions & 0 deletions gluoncv/model_zoo/model_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .cifarresnext import *
from .cifarwideresnet import *
from .deeplabv3 import *
from .deeplabv3b_plus import *
from .densenet import *
from .faster_rcnn import *
from .fcn import *
Expand Down Expand Up @@ -146,6 +147,7 @@
'deeplab_resnet152_voc': get_deeplab_resnet152_voc,
'deeplab_resnet50_ade': get_deeplab_resnet50_ade,
'deeplab_resnet101_ade': get_deeplab_resnet101_ade,
'deeplab_v3b_plus_wideresnet_citys': get_deeplab_v3b_plus_wideresnet_citys,
'resnet18_v1b': resnet18_v1b,
'resnet34_v1b': resnet34_v1b,
'resnet50_v1b': resnet50_v1b,
Expand Down
2 changes: 1 addition & 1 deletion gluoncv/model_zoo/pspnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def hybrid_forward(self, F, x):
return tuple(outputs)

def demo(self, x):
self.predict(x)
return self.predict(x)

def predict(self, x):
h, w = x.shape[2:]
Expand Down
2 changes: 2 additions & 0 deletions gluoncv/model_zoo/segbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ def get_segmentation_model(model, **kwargs):
from .pspnet import get_psp
from .deeplabv3 import get_deeplab
from .deeplabv3_plus import get_deeplab_plus
from .deeplabv3b_plus import get_deeplabv3b_plus
models = {
'fcn': get_fcn,
'psp': get_psp,
'deeplab': get_deeplab,
'deeplabplus': get_deeplab_plus,
'deeplabplusv3b': get_deeplabv3b_plus,
}
return models[model](**kwargs)

Expand Down
Loading

0 comments on commit 384f30e

Please sign in to comment.