add ResNeSt to Faster R-CNN (#1259)
* add ResNeSt to Faster R-CNN
add custom voc dataset

* fix style

* fix style

* trigger ci

* add docs

* docs update

* add custom voc dataset

* small fix
Jerryzcn authored Apr 17, 2020
1 parent cb07ee2 commit 18f8ab5
Showing 8 changed files with 283 additions and 33 deletions.
35 changes: 21 additions & 14 deletions docs/model_zoo/detection.rst
@@ -190,26 +190,30 @@ Checkout SSD demo tutorial here: :ref:`sphx_glr_build_examples_detection_demo_ss
Faster-RCNN
-----------

-Faster-RCNN models of VOC dataset are evaluated with native resolutions with ``shorter side >= 800`` but ``longer side <= 1300`` without changing aspect ratios.
+Faster-RCNN models of VOC dataset are evaluated with native resolutions with ``shorter side >= 800`` but ``longer side <= 1333`` without changing aspect ratios.

Checkout Faster-RCNN demo tutorial here: :ref:`sphx_glr_build_examples_detection_demo_faster_rcnn.py`
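
The evaluation resolution rule stated in the changed doc line can be sketched as a size computation. This is a minimal illustration under one common reading of the rule (scale the shorter side to 800, then shrink if the longer side would exceed 1333); the function name `target_size` and its parameters are hypothetical and are not taken from GluonCV.

```python
def target_size(width, height, short=800, max_size=1333):
    """Return (new_width, new_height) with the aspect ratio preserved.

    Assumption: the shorter side is scaled to `short` unless that would push
    the longer side past `max_size`, in which case the longer side is capped.
    """
    im_short, im_long = min(width, height), max(width, height)
    scale = short / im_short
    if round(scale * im_long) > max_size:
        # Capping the longer side takes priority over reaching `short`.
        scale = max_size / im_long
    return int(round(width * scale)), int(round(height * scale))
```

For a typical 500x375 VOC image the shorter side reaches 800; for a very wide image the 1333 cap wins and the shorter side ends up below 800.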

.. table::
:widths: 50 5 25 20

+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| Model | Box AP | Training Command | Training Log |
+===========================================+=================+=========================================================================================================================================+=======================================================================================================================================+
| faster_rcnn_resnet50_v1b_coco [2]_ | 37.0/57.8/39.6 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet50_v1b_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet50_v1b_coco_train.log>`_ |
+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_resnet101_v1d_coco [2]_ | 40.1/60.9/43.3 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet101_v1d_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet101_v1d_coco_train.log>`_ |
+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_fpn_resnet50_v1b_coco [4]_ | 38.4/60.2/41.6 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_resnet50_v1b_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_resnet50_v1b_coco_train.log>`_ |
+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_fpn_resnet101_v1d_coco [4]_ | 40.8/62.4/44.7 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_resnet101_v1d_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_resnet101_v1d_coco_train.log>`_ |
+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_fpn_bn_resnet50_v1b_coco [5]_ | 39.3/61.3/42.9 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_bn_resnet50_v1b_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_bn_resnet50_v1b_coco_train.log>`_ |
+-------------------------------------------+-----------------+-----------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
+---------------------------------------------+-----------------+-------------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------+
| Model | Box AP | Training Command | Training Log |
+=============================================+=================+===========================================================================================================================================+=========================================================================================================================================+
| faster_rcnn_resnet50_v1b_coco [2]_ | 37.0/57.8/39.6 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet50_v1b_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet50_v1b_coco_train.log>`_ |
+---------------------------------------------+-----------------+-------------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_resnet101_v1d_coco [2]_ | 40.1/60.9/43.3 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet101_v1d_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_resnet101_v1d_coco_train.log>`_ |
+---------------------------------------------+-----------------+-------------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_fpn_resnet50_v1b_coco [4]_ | 38.4/60.2/41.6 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_resnet50_v1b_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_resnet50_v1b_coco_train.log>`_ |
+---------------------------------------------+-----------------+-------------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_fpn_resnet101_v1d_coco [4]_ | 40.8/62.4/44.7 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_resnet101_v1d_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_resnet101_v1d_coco_train.log>`_ |
+---------------------------------------------+-----------------+-------------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_fpn_bn_resnet50_v1b_coco [5]_ | 39.3/61.3/42.9 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_bn_resnet50_v1b_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_bn_resnet50_v1b_coco_train.log>`_ |
+---------------------------------------------+-----------------+-------------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_fpn_syncbn_resnest50_coco [7]_ | 42.7/64.1/46.4 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_syncbn_resnest50_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_syncbn_resnest50_coco_train.log>`_ |
+---------------------------------------------+-----------------+-------------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------+
| faster_rcnn_fpn_syncbn_resnest101_coco [7]_ | 44.9/66.4/48.9 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_syncbn_resnest101_coco.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/detection/faster_rcnn_fpn_syncbn_resnest101_coco_train.log>`_ |
+---------------------------------------------+-----------------+-------------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------+

YOLO-v3
-------
@@ -284,3 +288,6 @@ Note that ``dcnv2`` indicate that models include Modulated Deformable Convolutio
 .. [6] Zhou, Xingyi, Dequan Wang, and Philipp Krähenbühl. \
        "Objects as Points." \
        arXiv preprint arXiv:1904.07850 (2019).
+.. [7] Hang Zhang, Chongruo Wu, Zhongyue Zhang, Yi Zhu, Zhi Zhang, Haibin Lin, Yue Sun, Tong He, Jonas Muller, R. Manmatha, Mu Li and Alex Smola \
+       "ResNeSt: Split-Attention Network" \
+       arXiv preprint (2020).
2 changes: 1 addition & 1 deletion gluoncv/data/__init__.py
@@ -5,7 +5,7 @@
 from . import batchify
 from .imagenet.classification import ImageNet, ImageNet1kAttr
 from .dataloader import DetectionDataLoader, RandomTransformDataLoader
-from .pascal_voc.detection import VOCDetection
+from .pascal_voc.detection import VOCDetection, CustomVOCDetection
 from .mscoco.detection import COCODetection
 from .mscoco.detection import COCODetectionDALI
 from .mscoco.instance import COCOInstance
42 changes: 37 additions & 5 deletions gluoncv/data/pascal_voc/detection.py
@@ -1,10 +1,14 @@
 """Pascal VOC object detection dataset."""
 from __future__ import absolute_import
 from __future__ import division
-import os
+
+import glob
 import logging
+import os
 import warnings
+
 import numpy as np
+
 try:
     import xml.etree.cElementTree as ET
 except ImportError:
@@ -87,8 +91,9 @@ def __getitem__(self, idx):
     def _load_items(self, splits):
         """Load individual image indices from splits."""
         ids = []
-        for year, name in splits:
-            root = os.path.join(self._root, 'VOC' + str(year))
+        for subfolder, name in splits:
+            root = os.path.join(
+                self._root, ('VOC' + str(subfolder)) if isinstance(subfolder, int) else subfolder)
             lf = os.path.join(root, 'ImageSets', 'Main', name + '.txt')
             with open(lf, 'r') as f:
                 ids += [(root, line.strip()) for line in f.readlines()]
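
The path rule introduced in this hunk can be isolated as a small standalone function. This is only a sketch of the expression shown above; the name `split_root` is hypothetical and does not exist in the commit.

```python
import os


def split_root(dataset_root, subfolder):
    """Resolve the folder for one split entry.

    Integers are treated as VOC years ('VOC2007'); any other value is used
    as a folder name unchanged, which is what lets custom layouts work.
    """
    name = ('VOC' + str(subfolder)) if isinstance(subfolder, int) else subfolder
    return os.path.join(dataset_root, name)
```

With this rule a split such as `(2007, 'trainval')` still resolves to the `VOC2007` folder, while a split such as `('my_dataset', 'train')` points at `my_dataset` directly.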
@@ -122,9 +127,9 @@ def _load_label(self, idx):
             ymax = (float(xml_box.find('ymax').text) - 1)
             try:
                 self._validate_label(xmin, ymin, xmax, ymax, width, height)
+                label.append([xmin, ymin, xmax, ymax, cls_id, difficult])
             except AssertionError as e:
-                raise RuntimeError("Invalid label at {}, {}".format(anno_path, e))
-            label.append([xmin, ymin, xmax, ymax, cls_id, difficult])
+                logging.warning("Invalid label at %s, %s", anno_path, e)
         return np.array(label)
 
     def _validate_label(self, xmin, ymin, xmax, ymax, width, height):
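
The behavioural change in this hunk (an invalid box is logged and skipped instead of aborting with `RuntimeError`) can be sketched in isolation. The helper names `validate_box` and `collect_valid_boxes` are hypothetical, and the bounds checks are an assumed stand-in for `_validate_label`, whose body is not shown in this diff.

```python
import logging


def validate_box(xmin, ymin, xmax, ymax, width, height):
    """Assumed stand-in for _validate_label: the box must lie inside the image."""
    assert 0 <= xmin < width, "xmin out of range: {}".format(xmin)
    assert 0 <= ymin < height, "ymin out of range: {}".format(ymin)
    assert xmin < xmax <= width, "xmax out of range: {}".format(xmax)
    assert ymin < ymax <= height, "ymax out of range: {}".format(ymax)


def collect_valid_boxes(boxes, width, height):
    """Keep boxes that pass validation; log and skip the rest."""
    kept = []
    for box in boxes:
        try:
            validate_box(*box, width=width, height=height)
            kept.append(list(box))  # appended only when validation succeeded
        except AssertionError as e:
            logging.warning("Invalid label skipped: %s", e)
    return kept
```

One consequence of this design is that an annotation file whose boxes are all invalid yields an empty label list rather than an exception.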
@@ -145,3 +150,30 @@ def _preload_labels(self):
         """Preload all labels into memory."""
         logging.debug("Preloading %s labels into memory...", str(self))
         return [self._load_label(idx) for idx in range(len(self))]
+
+
+class CustomVOCDetection(VOCDetection):
+    """Custom Pascal VOC detection dataset.
+    Classes can be generated from the dataset annotations.
+    generate_classes : bool, default False
+        If True, generate class labels based on the annotations instead of the default class labels.
+    """
+
+    def __init__(self, generate_classes=False, **kwargs):
+        super(CustomVOCDetection, self).__init__(**kwargs)
+        if generate_classes:
+            self.CLASSES = self._generate_classes()
+
+    def _generate_classes(self):
+        classes = set()
+        all_xml = glob.glob(os.path.join(self._root, 'Annotations', '*.xml'))
+        for each_xml_file in all_xml:
+            tree = ET.parse(each_xml_file)
+            root = tree.getroot()
+            for child in root:
+                if child.tag == 'object':
+                    for item in child:
+                        if item.tag == 'name':
+                            classes.add(item.text)
+        classes = sorted(list(classes))
+        return classes
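
The class discovery performed by `_generate_classes` can be tried outside GluonCV with the standard library alone. The sketch below mirrors the same traversal (direct `object` children of the root, then their `name` child); the function name `classes_from_annotations` is hypothetical, and the guard against empty `name` elements is an addition not present in the commit.

```python
import glob
import os
import xml.etree.ElementTree as ET


def classes_from_annotations(annotation_dir):
    """Collect the sorted set of object names found in VOC-style XML files."""
    classes = set()
    for xml_path in glob.glob(os.path.join(annotation_dir, '*.xml')):
        root = ET.parse(xml_path).getroot()
        # findall() only matches direct children, like the loop in the commit.
        for obj in root.findall('object'):
            name = obj.find('name')
            if name is not None and name.text:  # added guard, not in the commit
                classes.add(name.text)
    return sorted(classes)
```

Sorting makes the class order deterministic, so the index assigned to each class does not depend on the order in which files are listed.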
2 changes: 2 additions & 0 deletions gluoncv/model_zoo/model_store.py
@@ -59,9 +59,11 @@
     ('da9756faa5b9b4e34dedcf83ee0733d5895796ad', 'ssd_512_mobilenet1.0_coco'),
     ('447328d89d70ae1e2ca49226b8d834e5a5456df3', 'faster_rcnn_resnet50_v1b_voc'),
     ('5b4690fb7c5b62c44fb36c67d0642b633697f1bb', 'faster_rcnn_resnet50_v1b_coco'),
+    ('6df46961827647d418b11ffaf616a6a60d9dd16e', 'faster_rcnn_fpn_syncbn_resnest50_coco'),
     ('a465eca35e78aba6ebdf99bf52031a447e501063', 'faster_rcnn_resnet101_v1d_coco'),
     ('233572743bc537291590f4edf8a0c17c14b234bb', 'faster_rcnn_fpn_resnet50_v1b_coco'),
     ('1194ab4ec6e06386aadd55820add312c8ef59c74', 'faster_rcnn_fpn_resnet101_v1d_coco'),
+    ('baebfa1b7d7f56dd33a7687efea4b014736bd791', 'faster_rcnn_fpn_syncbn_resnest101_coco'),
     ('e071cf1550bc0331c218a9072b59e9550595d1e7', 'mask_rcnn_resnet18_v1b_coco'),
     ('a3527fdc2cee5b1f32a61e5fd7cda8fb673e86e5', 'mask_rcnn_resnet50_v1b_coco'),
     ('4a3249c584f81c2a9b5d852b742637cd692ebdcb', 'mask_rcnn_resnet101_v1d_coco'),
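
Each entry in this list pairs a 40-character hex string with a model name. Assuming those strings are SHA-1 digests of the weight files (the diff itself does not say so), a downloaded file could be checked against its entry along these lines. The function names are hypothetical and this is not GluonCV's own verification code.

```python
import hashlib


def sha1_of_file(path, chunk_size=1 << 20):
    """Return the SHA-1 hex digest of a file, reading in chunks to bound memory."""
    sha1 = hashlib.sha1()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            sha1.update(chunk)
    return sha1.hexdigest()


def matches_checksum(path, expected_sha1):
    """True when the file digest equals the recorded hex string (case-insensitive)."""
    return sha1_of_file(path) == expected_sha1.lower()
```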
3 changes: 3 additions & 0 deletions gluoncv/model_zoo/model_zoo.py
@@ -127,12 +127,15 @@
     'faster_rcnn_resnet50_v1b_coco': faster_rcnn_resnet50_v1b_coco,
     'faster_rcnn_fpn_resnet50_v1b_coco': faster_rcnn_fpn_resnet50_v1b_coco,
     'faster_rcnn_fpn_syncbn_resnet50_v1b_coco': faster_rcnn_fpn_syncbn_resnet50_v1b_coco,
+    'faster_rcnn_fpn_syncbn_resnest50_coco': faster_rcnn_fpn_syncbn_resnest50_coco,
     'faster_rcnn_resnet50_v1b_custom': faster_rcnn_resnet50_v1b_custom,
     'faster_rcnn_resnet101_v1d_voc': faster_rcnn_resnet101_v1d_voc,
     'faster_rcnn_resnet101_v1d_coco': faster_rcnn_resnet101_v1d_coco,
     'faster_rcnn_fpn_resnet101_v1d_coco': faster_rcnn_fpn_resnet101_v1d_coco,
     'faster_rcnn_fpn_syncbn_resnet101_v1d_coco': faster_rcnn_fpn_syncbn_resnet101_v1d_coco,
+    'faster_rcnn_fpn_syncbn_resnest101_coco': faster_rcnn_fpn_syncbn_resnest101_coco,
     'faster_rcnn_resnet101_v1d_custom': faster_rcnn_resnet101_v1d_custom,
+    'faster_rcnn_fpn_syncbn_resnest269_coco': faster_rcnn_fpn_syncbn_resnest269_coco,
     'custom_faster_rcnn_fpn': custom_faster_rcnn_fpn,
     'mask_rcnn_resnet50_v1b_coco': mask_rcnn_resnet50_v1b_coco,
     'mask_rcnn_fpn_resnet50_v1b_coco': mask_rcnn_fpn_resnet50_v1b_coco,
2 changes: 1 addition & 1 deletion gluoncv/model_zoo/rcnn/faster_rcnn/faster_rcnn.py
@@ -585,7 +585,7 @@ def custom_faster_rcnn_fpn(classes, transfer=None, dataset='custom', pretrained_
             module_list.append('bn')
         net = get_model(
             '_'.join(['faster_rcnn'] + module_list + [base_network_name, str(transfer)]),
-            pretrained=True)
+            pretrained=True, per_device_batch_size=kwargs['per_device_batch_size'])
         reuse_classes = [x for x in classes if x in net.classes]
         net.reset_class(classes, reuse_weights=reuse_classes)
     return net
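
The `reuse_classes` line in the context of this hunk selects which of the new dataset's classes already exist in the pretrained network, so their weights can be kept when the class list is reset. A standalone sketch of that selection follows; `reusable_classes` is a hypothetical name and plain lists stand in for `net.classes`.

```python
def reusable_classes(new_classes, pretrained_classes):
    """Return the new classes that the pretrained model already knows.

    The order of `new_classes` is preserved, matching the list comprehension
    in the hunk above.
    """
    known = set(pretrained_classes)  # set lookup avoids repeated list scans
    return [c for c in new_classes if c in known]
```

Note that the changed line reads `kwargs['per_device_batch_size']` by direct indexing, so in plain Python terms a caller that omits that keyword would hit a `KeyError` on this code path.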