
Commit

Add SMOT tracker (#1573)
* add smot

* lint

* lint

* lint

* tutorial

* tutorial change

* fix comments
bryanyzhu authored Dec 19, 2020
1 parent b8a3135 commit 0f0b226
Showing 23 changed files with 3,007 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
@@ -46,7 +46,7 @@ Check the HD video at [Youtube](https://www.youtube.com/watch?v=nfpouVAzXt0) or
| [Semantic Segmentation:](https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation) <br/>associate each pixel of an image <br/> with a categorical label. | <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation"><img src="docs/_static/semantic-segmentation.png" alt="semantic" height="200"/></a> | <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation">FCN</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation">PSP</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation">ICNet</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation">DeepLab-v3</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation">DeepLab-v3+</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation">DANet</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation">FastSCNN</a> |
| [Instance Segmentation:](https://gluon-cv.mxnet.io/model_zoo/segmentation.html#instance-segmentation) <br/>detect objects and associate <br/> each pixel inside object area with an <br/> instance label. | <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#instance-segmentation"><img src="docs/_static/instance-segmentation.png" alt="instance" height="200"/></a> | <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#instance-segmentation">Mask RCNN</a>|
| [Pose Estimation:](https://gluon-cv.mxnet.io/model_zoo/pose.html) <br/>detect human pose <br/> from images. | <a href="https://gluon-cv.mxnet.io/model_zoo/pose.html"><img src="docs/_static/pose-estimation.svg" alt="pose" height="200"/></a> | <a href="https://gluon-cv.mxnet.io/model_zoo/pose.html#simple-pose-with-resnet">Simple Pose</a>|
| [Video Action Recognition:](https://gluon-cv.mxnet.io/model_zoo/action_recognition.html) <br/>recognize human actions <br/> in a video. | <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html"><img src="docs/_static/action-recognition.png" alt="action_recognition" height="200"/></a> | MXNet: <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">TSN</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">C3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">I3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">I3D_slow</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">P3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">R3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">R2+1D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">Non-local</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">SlowFast</a> <br/> PyTorch: <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">TSN</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">I3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">I3D_slow</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">R2+1D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">Non-local</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">CSN</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/SlowFast.html">TSN</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">TPN</a> |
| [Video Action Recognition:](https://gluon-cv.mxnet.io/model_zoo/action_recognition.html) <br/>recognize human actions <br/> in a video. | <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html"><img src="docs/_static/action-recognition.png" alt="action_recognition" height="200"/></a> | MXNet: <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">TSN</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">C3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">I3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">I3D_slow</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">P3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">R3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">R2+1D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">Non-local</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">SlowFast</a> <br/> PyTorch: <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">TSN</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">I3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">I3D_slow</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">R2+1D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">Non-local</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">CSN</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">SlowFast</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">TPN</a> |
| [Depth Prediction:](https://gluon-cv.mxnet.io/model_zoo/depth.html) <br/>predict depth map <br/> from images. | <a href="https://gluon-cv.mxnet.io/model_zoo/depth.html"><img src="docs/_static/depth.png" alt="depth" height="200"/></a> | <a href="https://gluon-cv.mxnet.io/model_zoo/depth.html#kitti-dataset">Monodepth2</a>|
| [GAN:](https://github.com/dmlc/gluon-cv/tree/master/scripts/gan) <br/>generate visually deceptive images | <a href="https://github.com/dmlc/gluon-cv/tree/master/scripts/gan"><img src="https://github.com/dmlc/gluon-cv/raw/master/scripts/gan/wgan/fake_samples_400000.png" alt="lsun" height="200"/></a> | <a href="https://github.com/dmlc/gluon-cv/tree/master/scripts/gan/wgan">WGAN</a>, <a href="https://github.com/dmlc/gluon-cv/tree/master/scripts/gan/cycle_gan">CycleGAN</a>, <a href="https://github.com/dmlc/gluon-cv/tree/master/scripts/gan/stylegan">StyleGAN</a>|
| [Person Re-ID:](https://github.com/dmlc/gluon-cv/tree/master/scripts/re-id/baseline) <br/>re-identify pedestrians across scenes | <a href="https://github.com/dmlc/gluon-cv/tree/master/scripts/re-id/baseline"><img src="https://user-images.githubusercontent.com/3307514/46702937-f4311800-cbd9-11e8-8eeb-c945ec5643fb.png" alt="re-id" height="160"/></a> |<a href="https://github.com/dmlc/gluon-cv/tree/master/scripts/re-id/baseline">Market1501 baseline</a> |
Binary file added docs/_static/smot_demo.gif
6 changes: 6 additions & 0 deletions docs/tutorials/index.rst
@@ -284,6 +284,12 @@ Object Tracking

SiamRPN training on VID, DET, COCO, Youtube_bb and testing on OTB2015

.. card::
:title: Pre-trained SMOT Models
:link: ../build/examples_tracking/demo_smot.html

Perform Multi-Object Tracking in real-world video with pre-trained SMOT models.


Depth Prediction
---------------------
46 changes: 46 additions & 0 deletions docs/tutorials/tracking/demo_smot.py
@@ -0,0 +1,46 @@
"""03. Multiple object tracking with pre-trained SMOT models
=============================================================
In this tutorial, we present a method
called `Single-Shot Multi Object Tracking (SMOT) <https://arxiv.org/abs/2010.16031>`_ for multi-object tracking.
SMOT is a new tracking framework that converts any single-shot detector (SSD) model into an online multiple object tracker,
emphasizing simultaneous detection and tracking of object paths.
In the example below, we directly use an SSD-MobileNet object detector pretrained on COCO from :ref:`gluoncv-model-zoo`
and perform multiple object tracking on an arbitrary video.
We want to point out that SMOT is very efficient: its runtime is close to that of the chosen detector.
"""

######################################################################
# Predict with a SMOT model
# ----------------------------
#
# First, we download a video from the MOT Challenge website,

from gluoncv import utils
video_path = 'https://motchallenge.net/sequenceVideos/MOT17-02-FRCNN-raw.webm'
im_video = utils.download(video_path)

################################################################
# Then you can simply use the provided script under `/scripts/tracking/smot/demo.py` to obtain the multi-object tracking results.
#
# ::
#
# python demo.py MOT17-02-FRCNN-raw.webm
#
#
################################################################
# You can see the tracking results below. Here, we only track persons,
# but you can track other objects as long as your detector is trained on that category.
#
# .. raw:: html
#
# <div align="center">
# <img src="../../_static/smot_demo.gif">
# </div>
#     <br>

################################################################
# Our model is able to track multiple persons even when they are partially occluded.
# Try it on your own video and see the results!
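As a quick end-to-end sketch, the download and the demo invocation from this tutorial can be combined into one small script. This is only an illustration: it assumes a GluonCV checkout with the demo script at `scripts/tracking/smot/demo.py` and that the script takes the video path as its positional argument, as shown above.

    import subprocess
    from gluoncv import utils

    # Download the MOT17 sample sequence used in the tutorial.
    video_url = 'https://motchallenge.net/sequenceVideos/MOT17-02-FRCNN-raw.webm'
    video_file = utils.download(video_url)

    # Run the provided SMOT demo script on the downloaded video.
    # Adjust the script path to match your GluonCV checkout.
    subprocess.run(['python', 'scripts/tracking/smot/demo.py', video_file], check=True)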
1 change: 1 addition & 0 deletions gluoncv/model_zoo/__init__.py
@@ -39,3 +39,4 @@
from .siamrpn import *
from .fastscnn import *
from .monodepthv2 import *
from .smot import *
1 change: 1 addition & 0 deletions gluoncv/model_zoo/model_store.py
@@ -229,6 +229,7 @@
('661ee2e1bf824f4f4549b3488c59dec0b0078c38', 'monodepth2_resnet18_posenet_kitti_mono_640x192'),
('c14979bb016ed4f555fa09004ddc7616dd60b8b9', 'monodepth2_resnet18_posenet_kitti_mono_stereo_640x192'),
('299b1d9d8a2bcf7c122acd0d23606af4fdfbe7e1', 'i3d_slow_resnet101_f16s4_kinetics700'),
('d6758fc8cddfaaa8d0f7ff2e21adf5b0180f6b4b', 'smot_ssd_bifpn_mobilenet'),
]}

apache_repo_url = 'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'
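The new checksum entry above is what lets GluonCV resolve the pretrained SMOT weights by name. Below is a minimal sketch of fetching the weight file through the model store; it assumes the standard `get_model_file` helper with its default cache directory under `~/.mxnet/models`.

    from gluoncv.model_zoo.model_store import get_model_file

    # Downloads (or reuses a cached copy of) the pretrained SMOT SSD weights
    # registered by the checksum entry above and returns the local file path.
    weight_path = get_model_file('smot_ssd_bifpn_mobilenet')
    print(weight_path)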
8 changes: 8 additions & 0 deletions gluoncv/model_zoo/smot/__init__.py
@@ -0,0 +1,8 @@
# pylint: disable=wildcard-import
"""
SMOT: Single-Shot Multi Object Tracking
https://arxiv.org/abs/2010.16031
"""
from __future__ import absolute_import
from .smot_tracker import *
from .tracktors import *
79 changes: 79 additions & 0 deletions gluoncv/model_zoo/smot/anchor.py
@@ -0,0 +1,79 @@
# pylint: disable=unused-import
"""Anchor box generator for SSD detector."""
from __future__ import absolute_import

import numpy as np
from mxnet import gluon


class SSDAnchorGenerator(gluon.HybridBlock):
"""Bounding box anchor generator for Single-shot Object Detection.
Parameters
----------
index : int
Index of this generator in SSD models, this is required for naming.
sizes : iterable of floats
Sizes of anchor boxes.
ratios : iterable of floats
Aspect ratios of anchor boxes.
step : int or float
Step size of anchor boxes.
alloc_size : tuple of int
Allocate size for the anchor boxes as (H, W).
Usually we generate enough anchors for large feature map, e.g. 128x128.
Later in inference we can have variable input sizes,
at which time we can crop corresponding anchors from this large
anchor map so we can skip re-generating anchors for each input.
offsets : tuple of float
Center offsets of anchor boxes as (h, w) in range(0, 1).
"""
def __init__(self, index, im_size, sizes, ratios, step, alloc_size=(128, 128),
offsets=(0.5, 0.5), clip=False, **kwargs):
super(SSDAnchorGenerator, self).__init__(**kwargs)
assert len(im_size) == 2
self._im_size = im_size
self._clip = clip
self._sizes = sizes
self._ratios = ratios
anchors = self._generate_anchors(self._sizes, self._ratios, step, alloc_size, offsets)
self._num_anchors = np.size(anchors) / 4
self.anchors = self.params.get_constant('anchor_%d'%(index), anchors)

def _generate_anchors(self, sizes, ratios, step, alloc_size, offsets):
# pylint: disable=unused-argument,too-many-function-args
"""Generate anchors for once. Anchors are stored with (center_x, center_y, w, h) format."""
anchors = []
for i in range(alloc_size[0]):
for j in range(alloc_size[1]):
cy = (i + offsets[0]) * step
cx = (j + offsets[1]) * step

for sz in self._sizes:
for r in ratios:
sr = np.sqrt(r)
w = sz * sr
h = sz / sr
anchors.append([cx, cy, w, h])
return np.array(anchors).reshape(1, 1, alloc_size[0], alloc_size[1], -1)

@property
def num_depth(self):
"""Number of anchors at each pixel."""
return len(self._sizes) * len(self._ratios)

@property
def num_anchors(self):
"""Number of anchors at each pixel."""
return self._num_anchors

# pylint: disable=arguments-differ
def hybrid_forward(self, F, x, anchors):
a = F.slice_like(anchors, x * 0, axes=(2, 3))
a = a.reshape((1, -1, 4))
if self._clip:
cx, cy, cw, ch = a.split(axis=-1, num_outputs=4)
H, W = self._im_size
a = F.concat(*[cx.clip(0, W), cy.clip(0, H), cw.clip(0, W), ch.clip(0, H)], dim=-1)
return a.reshape((1, -1, 4))
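Since `SSDAnchorGenerator` is a self-contained `HybridBlock`, it can be exercised on its own with a dummy feature map. The sketch below is based on the constructor and `hybrid_forward` shown above; the concrete sizes, ratios and step are illustrative, not the values used by the SMOT detector.

    import mxnet as mx
    from gluoncv.model_zoo.smot.anchor import SSDAnchorGenerator

    # One generator for a 32x32 feature map of a 512x512 input at stride 16.
    gen = SSDAnchorGenerator(index=0, im_size=(512, 512),
                             sizes=[51.2, 102.4], ratios=[1, 2, 0.5],
                             step=16, clip=True)
    gen.initialize()

    # Only the spatial shape of the input matters: anchors are cropped from the
    # pre-allocated 128x128 grid via slice_like.
    feat = mx.nd.zeros((1, 256, 32, 32))
    anchors = gen(feat)
    # 32 * 32 positions, len(sizes) * len(ratios) = 6 anchors per position,
    # each box in (cx, cy, w, h) format.
    print(anchors.shape)  # (1, 6144, 4)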
106 changes: 106 additions & 0 deletions gluoncv/model_zoo/smot/decoders.py
@@ -0,0 +1,106 @@
"""
MXNet implementation of SMOT: Single-Shot Multi Object Tracking
https://arxiv.org/abs/2010.16031
"""
from mxnet import gluon
from gluoncv.nn.bbox import BBoxCenterToCorner


class NormalizedLandmarkCenterDecoder(gluon.HybridBlock):
"""
    Decode landmark training targets with normalized center offsets.
    Each anchor carries five (x, y) landmark points encoded relative to its
    anchor box. This decoder must cooperate with the corresponding landmark
    encoder of the same `stds` in order to get properly reconstructed landmarks.
    The returned tensor concatenates the absolute coordinates
    `x_0, y_0, x_1, y_1, ..., x_4, y_4`.

    Parameters
    ----------
    stds : array-like of size 4
        Std values used to rescale the encoded offsets; only the first two
        entries are applied to the x and y offsets. Default is (0.1, 0.1, 0.2, 0.2).
    means : array-like of size 4
        Mean values added to the rescaled offsets; only the first two entries
        are used. Default is (0., 0., 0., 0.).
    convert_anchor : bool, default is True
        Whether anchors are given in center format (cx, cy, w, h) and need to
        be converted to corner format before decoding.
"""

def __init__(self, stds=(0.1, 0.1, 0.2, 0.2), means=(0., 0., 0., 0.),
convert_anchor=True):
super(NormalizedLandmarkCenterDecoder, self).__init__()
        assert len(stds) == 4, "Landmark decoder requires 4 std values."
self._stds = stds
self._means = means
if convert_anchor:
self.center_to_conner = BBoxCenterToCorner(split=True)
else:
self.center_to_conner = None

def hybrid_forward(self, F, x, anchors):
"""center decoder forward"""
if self.center_to_conner is not None:
a = self.center_to_conner(anchors)
else:
a = anchors.split(axis=-1, num_outputs=4)
ld = F.split(x, axis=-1, num_outputs=10)

x0 = F.broadcast_add(F.broadcast_mul(ld[0] * self._stds[0] + self._means[0], a[2] - a[0]), a[0])
y0 = F.broadcast_add(F.broadcast_mul(ld[1] * self._stds[1] + self._means[1], a[3] - a[1]), a[1])
x1 = F.broadcast_add(F.broadcast_mul(ld[2] * self._stds[0] + self._means[0], a[2] - a[0]), a[0])
y1 = F.broadcast_add(F.broadcast_mul(ld[3] * self._stds[1] + self._means[1], a[3] - a[1]), a[1])
x2 = F.broadcast_add(F.broadcast_mul(ld[4] * self._stds[0] + self._means[0], a[2] - a[0]), a[0])
y2 = F.broadcast_add(F.broadcast_mul(ld[5] * self._stds[1] + self._means[1], a[3] - a[1]), a[1])
x3 = F.broadcast_add(F.broadcast_mul(ld[6] * self._stds[0] + self._means[0], a[2] - a[0]), a[0])
y3 = F.broadcast_add(F.broadcast_mul(ld[7] * self._stds[1] + self._means[1], a[3] - a[1]), a[1])
x4 = F.broadcast_add(F.broadcast_mul(ld[8] * self._stds[0] + self._means[0], a[2] - a[0]), a[0])
y4 = F.broadcast_add(F.broadcast_mul(ld[9] * self._stds[1] + self._means[1], a[3] - a[1]), a[1])

return F.concat(x0, y0, x1, y1, x2, y2, x3, y3, x4, y4, dim=-1)


class GeneralNormalizedKeyPointsDecoder(gluon.HybridBlock):
"""
    Decode keypoint training targets with normalized center offsets.
    Each anchor carries `num_points` (x, y) keypoints encoded relative to its
    anchor box. This decoder must cooperate with the corresponding keypoint
    encoder of the same `stds` in order to get properly reconstructed keypoints.
    The returned tensor concatenates the absolute coordinates
    `x_0, y_0, ..., x_{num_points-1}, y_{num_points-1}`.

    Parameters
    ----------
    num_points : int
        Number of keypoints encoded per anchor.
    stds : array-like of size 2
        Std values used to rescale the encoded (x, y) offsets, default is (0.2, 0.2).
    means : array-like of size 2
        Mean values added to the rescaled (x, y) offsets, default is (0.5, 0.2).
    convert_anchor : bool, default is True
        Whether anchors are given in center format (cx, cy, w, h) and need to
        be converted to corner format before decoding.
"""

def __init__(self, num_points, stds=(0.2, 0.2), means=(0.5, 0.2),
convert_anchor=True):
super(GeneralNormalizedKeyPointsDecoder, self).__init__()
        assert len(stds) == 2, "Keypoint decoder requires 2 std values."
self._stds = stds
self._means = means
self._size = num_points * 2
if convert_anchor:
self.center_to_conner = BBoxCenterToCorner(split=True)
else:
self.center_to_conner = None

def hybrid_forward(self, F, x, anchors):
"""key point decoder forward"""
if self.center_to_conner is not None:
a = self.center_to_conner(anchors)
else:
a = anchors.split(axis=-1, num_outputs=4)
ld = F.split(x, axis=-1, num_outputs=self._size)

outputs = []
for i in range(0, self._size, 2):
x = F.broadcast_add(F.broadcast_mul(ld[i] * self._stds[0] + self._means[0], a[2] - a[0]), a[0])
y = F.broadcast_add(F.broadcast_mul(ld[i+1] * self._stds[1] + self._means[1], a[3] - a[1]), a[1])
outputs.extend([x, y])

return F.concat(*outputs, dim=-1)
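Both decoders map per-anchor offsets back to absolute coordinates as `x = (pred_x * std_x + mean_x) * (x_max - x_min) + x_min`, and analogously for `y`. Below is a small usage sketch of the keypoint decoder with dummy tensors, assuming the module path introduced in this commit.

    import mxnet as mx
    from gluoncv.model_zoo.smot.decoders import GeneralNormalizedKeyPointsDecoder

    decoder = GeneralNormalizedKeyPointsDecoder(num_points=5)
    decoder.initialize()

    preds = mx.nd.zeros((1, 100, 10))                  # 5 encoded (x, y) offsets per anchor
    anchors = mx.nd.random.uniform(shape=(1, 100, 4))  # anchors in (cx, cy, w, h) format
    keypoints = decoder(preds, anchors)                # absolute (x, y) pairs per anchor
    print(keypoints.shape)                             # (1, 100, 10)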
