Skip to content

Commit

Permalink
Add ResNeSt Segmentation Code (#1254)
Browse files Browse the repository at this point in the history
* add resnest segmentation code

* add warmup for segmentation

* fix typo

* add pretrained and fix dilation

* pretrained as args

* update sha1

* model zoo page
  • Loading branch information
zhanghang1989 authored Apr 16, 2020
1 parent f40e731 commit cb07ee2
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 37 deletions.
37 changes: 22 additions & 15 deletions docs/model_zoo/segmentation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,25 @@ Table of pre-trained models for semantic segmentation and their performance.
ADE20K Dataset
--------------

+-----------------------+-----------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| Name | Method | pixAcc | mIoU | Command | log |
+=======================+=================+===========+===========+==============================================================================================================================+=====================================================================================================================+
| fcn_resnet50_ade | FCN [2]_ | 79.0 | 39.5 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/fcn_resnet50_ade.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/fcn_resnet50_ade.log>`_ |
+-----------------------+-----------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| fcn_resnet101_ade | FCN [2]_ | 80.6 | 41.6 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/fcn_resnet101_ade.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/fcn_resnet101_ade.log>`_ |
+-----------------------+-----------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| psp_resnet50_ade | PSP [3]_ | 80.1 | 41.5 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet50_ade.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet50_ade.log>`_ |
+-----------------------+-----------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| psp_resnet101_ade | PSP [3]_ | 80.8 | 43.3 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet101_ade.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet101_ade.log>`_ |
+-----------------------+-----------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| deeplab_resnet50_ade | DeepLabV3 [4]_ | 80.5 | 42.5 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_resnet50_ade.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_resnet50_ade.log>`_ |
+-----------------------+-----------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| deeplab_resnet101_ade | DeepLabV3 [4]_ | 81.1 | 44.1 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_resnet101_ade.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_resnet101_ade.log>`_ |
+-----------------------+-----------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
+-----------------------+--------------------------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| Name | Method | pixAcc | mIoU | Command | log |
+=======================+================================+===========+===========+==============================================================================================================================+=====================================================================================================================+
| fcn_resnet50_ade | FCN [2]_ | 79.0 | 39.5 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/fcn_resnet50_ade.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/fcn_resnet50_ade.log>`_ |
+-----------------------+--------------------------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| fcn_resnet101_ade | FCN [2]_ | 80.6 | 41.6 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/fcn_resnet101_ade.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/fcn_resnet101_ade.log>`_ |
+-----------------------+--------------------------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| psp_resnet50_ade | PSP [3]_ | 80.1 | 41.5 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet50_ade.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet50_ade.log>`_ |
+-----------------------+--------------------------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| psp_resnet101_ade | PSP [3]_ | 80.8 | 43.3 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet101_ade.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet101_ade.log>`_ |
+-----------------------+--------------------------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| deeplab_resnet50_ade | DeepLabV3 [4]_ | 80.5 | 42.5 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_resnet50_ade.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_resnet50_ade.log>`_ |
+-----------------------+--------------------------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| deeplab_resnet101_ade | DeepLabV3 [4]_ | 81.1 | 44.1 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_resnet101_ade.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_resnet101_ade.log>`_ |
+-----------------------+--------------------------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| deeplab_resnest50_ade | DeepLabV3 + ResNeSt [4]_ [7]_ | 81.2 | 45.1 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_resnest50_ade.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_resnest50_ade.log>`_ |
+-----------------------+--------------------------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
| deeplab_resnest101_ade| DeepLabV3 + ResNeSt [4]_ [7]_ | 82.1 | 46.9 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_resnest101_ade.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_resnest101_ade.log>`_|
+-----------------------+--------------------------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+

MS-COCO Dataset Pretrain
------------------------
Expand Down Expand Up @@ -170,3 +174,6 @@ MS COCO
ECCV 2018.
.. [6] Zhu, Yi, et al. "Improving Semantic Segmentation via Video Propagation and Label Relaxation." \
CVPR 2019.
.. [7] Hang Zhang, Chongruo Wu, Zhongyue Zhang, Yi Zhu, Zhi Zhang, Haibin Lin, Yue Sun, Tong He, Jonas Muller, R. Manmatha, Mu Li and Alex Smola \
"ResNeSt: Split-Attention Network" \
arXiv preprint (2020).
79 changes: 79 additions & 0 deletions gluoncv/model_zoo/deeplabv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

__all__ = ['DeepLabV3', 'get_deeplab', 'get_deeplab_resnet101_coco',
'get_deeplab_resnet101_voc', 'get_deeplab_resnet50_ade', 'get_deeplab_resnet101_ade',
'get_deeplab_resnest50_ade', 'get_deeplab_resnest101_ade',
'get_deeplab_resnest200_ade', 'get_deeplab_resnest269_ade',
'get_deeplab_resnet152_coco', 'get_deeplab_resnet152_voc', 'get_deeplab_resnet50_citys',
'get_deeplab_resnet101_citys']

Expand Down Expand Up @@ -327,6 +329,83 @@ def get_deeplab_resnet101_ade(**kwargs):
"""
return get_deeplab('ade20k', 'resnet101', **kwargs)

def get_deeplab_resnest50_ade(**kwargs):
r"""DeepLabV3
Parameters
----------
pretrained : bool or str
Boolean value controls whether to load the default pretrained weights for model.
String value represents the hashtag for a certain version of pretrained weights.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_deeplab_resnest50_ade(pretrained=True)
>>> print(model)
"""
return get_deeplab('ade20k', 'resnest50', **kwargs)


def get_deeplab_resnest101_ade(**kwargs):
r"""DeepLabV3
Parameters
----------
pretrained : bool or str
Boolean value controls whether to load the default pretrained weights for model.
String value represents the hashtag for a certain version of pretrained weights.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_deeplab_resnest101_ade(pretrained=True)
>>> print(model)
"""
return get_deeplab('ade20k', 'resnest101', **kwargs)

def get_deeplab_resnest200_ade(**kwargs):
r"""DeepLabV3
Parameters
----------
pretrained : bool or str
Boolean value controls whether to load the default pretrained weights for model.
String value represents the hashtag for a certain version of pretrained weights.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_deeplab_resnest200_ade(pretrained=True)
>>> print(model)
"""
return get_deeplab('ade20k', 'resnest200', **kwargs)

def get_deeplab_resnest269_ade(**kwargs):
r"""DeepLabV3
Parameters
----------
pretrained : bool or str
Boolean value controls whether to load the default pretrained weights for model.
String value represents the hashtag for a certain version of pretrained weights.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_deeplab_resnest269_ade(pretrained=True)
>>> print(model)
"""
return get_deeplab('ade20k', 'resnest269', **kwargs)

def get_deeplab_resnet50_citys(**kwargs):
r"""DeepLabV3
Parameters
Expand Down
2 changes: 2 additions & 0 deletions gluoncv/model_zoo/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@
('d35bea8817935d1ab310ef1e6dd06bb18c2d5f0d', 'deeplab_resnet152_voc'),
('c7789b237adc7253405bee57c84d53b15db45942', 'deeplab_resnet50_ade'),
('bf1584dfcec12063eff3075ee643e181c0f6d443', 'deeplab_resnet101_ade'),
('a8312db6e30a464151580f2bda83479786455724', 'deeplab_resnest50_ade'),
('6d05c630fb7acb38615f7f4d360fb90f47b25042', 'deeplab_resnest101_ade'),
('09f89cad0e107cb2bffdb1b07706ba31798096f2', 'psp_resnet101_coco'),
('2c2f4e1c2b11461b52598a4b2038bccbcfc166eb', 'psp_resnet101_voc'),
('3f220f537400dfa607c3d041ed3b172db39b0b01', 'psp_resnet50_ade'),
Expand Down
4 changes: 4 additions & 0 deletions gluoncv/model_zoo/model_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@
'deeplab_resnet152_voc': get_deeplab_resnet152_voc,
'deeplab_resnet50_ade': get_deeplab_resnet50_ade,
'deeplab_resnet101_ade': get_deeplab_resnet101_ade,
'deeplab_resnest50_ade': get_deeplab_resnest50_ade,
'deeplab_resnest101_ade': get_deeplab_resnest101_ade,
'deeplab_resnest200_ade': get_deeplab_resnest200_ade,
'deeplab_resnest269_ade': get_deeplab_resnest269_ade,
'deeplab_resnet50_citys': get_deeplab_resnet50_citys,
'deeplab_resnet101_citys': get_deeplab_resnet101_citys,
'deeplab_v3b_plus_wideresnet_citys': get_deeplab_v3b_plus_wideresnet_citys,
Expand Down
2 changes: 1 addition & 1 deletion gluoncv/model_zoo/resnest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, channels, cardinality=1, bottleneck_width=64, strides=1, dila
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
self.dropblock_prob = dropblock_prob
self.use_splat = use_splat
self.avd = avd and (strides > 1 or dilation > 1)
self.avd = avd and (strides > 1 or previous_dilation != dilation)
self.avd_first = avd_first
if self.dropblock_prob > 0:
self.dropblock1 = DropBlock(dropblock_prob, 3, group_width, *input_size)
Expand Down
32 changes: 20 additions & 12 deletions gluoncv/model_zoo/segbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from mxnet.gluon.nn import HybridBlock
from ..utils.parallel import parallel_apply
from .resnetv1b import resnet18_v1b, resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s
from .resnest import resnest50, resnest101, resnest200, resnest269
from ..utils.parallel import tuple_map
# pylint: disable=wildcard-import,abstract-method,arguments-differ,dangerous-default-value,missing-docstring

Expand All @@ -29,6 +30,24 @@ def get_segmentation_model(model, **kwargs):
}
return models[model](**kwargs)

def get_backbone(name, **kwargs):
models = {
'resnet18': resnet18_v1b,
'resnet34': resnet34_v1b,
'resnet50': resnet50_v1s,
'resnet101': resnet101_v1s,
'resnet152': resnet152_v1s,
'resnest50': resnest50,
'resnest101': resnest101,
'resnest200': resnest200,
'resnest269': resnest269,
}
name = name.lower()
if name not in models:
raise ValueError('%s\n\t%s' % (str(name), '\n\t'.join(sorted(models.keys()))))
net = models[name](**kwargs)
return net

class SegBaseModel(HybridBlock):
r"""Base Model for Semantic Segmentation
Expand All @@ -48,18 +67,7 @@ def __init__(self, nclass, aux, backbone='resnet50', height=None, width=None,
self.aux = aux
self.nclass = nclass
with self.name_scope():
if backbone == 'resnet18':
pretrained = resnet18_v1b(pretrained=pretrained_base, dilated=True, **kwargs)
elif backbone == 'resnet34':
pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=True, **kwargs)
elif backbone == 'resnet50':
pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=True, **kwargs)
elif backbone == 'resnet101':
pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=True, **kwargs)
elif backbone == 'resnet152':
pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=True, **kwargs)
else:
raise RuntimeError('unknown backbone: {}'.format(backbone))
pretrained = get_backbone(backbone, pretrained=pretrained_base, dilated=True, **kwargs)
self.conv1 = pretrained.conv1
self.bn1 = pretrained.bn1
self.relu = pretrained.relu
Expand Down
9 changes: 8 additions & 1 deletion scripts/segmentation/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

def parse_args():
parser = argparse.ArgumentParser(description='Validation on Semantic Segmentation model')
parser.add_argument('--model-zoo', type=str, default=None,
help='evaluating on model zoo model')
parser.add_argument('--model', type=str, default='fcn',
help='model name (default: fcn)')
parser.add_argument('--backbone', type=str, default='resnet101',
Expand Down Expand Up @@ -257,7 +259,12 @@ def benchmarking(model, args):
if args.calibration:
args.pretrained = True
# create network
if args.pretrained:
if args.model_zoo is not None:
model = get_model(args.model_zoo, norm_layer=args.norm_layer,
norm_kwargs=args.norm_kwargs, aux=args.aux,
base_size=args.base_size, crop_size=args.crop_size,
ctx=args.ctx, pretrained=True)
elif args.pretrained:
if 'icnet' in model_prefix:
model = get_model(model_prefix, pretrained=True, height=args.height, width=args.width)
else:
Expand Down
Loading

0 comments on commit cb07ee2

Please sign in to comment.