Skip to content

Latest commit

 

History

History
109 lines (75 loc) · 2.37 KB

customize_vid_model.md

File metadata and controls

109 lines (75 loc) · 2.37 KB

自定义视频目标检测模型

我们通常将模型组件分为3类:

  • 检测器:通常是从一张图片中检出物体的检测组件,例如:Faster R-CNN。
  • 运动估计器:计算两张图片之间的运动信息的组件,例如:FlowNetSimple。
  • 聚合器:聚合多张图片特征的组件,例如:EmbedAggregator。

增加一个新的检测器

请参考MMDetection教程来开发新检测器

增加一个新的运动估计器

1. 定义一个运动估计模型(例如:MyFlowNet)

新建一个文件 mmtrack/models/motion/my_flownet.py

from mmcv.runner import BaseModule

from ..builder import MOTION

@MOTION.register_module()
class MyFlowNet(BaseModule):

    def __init__(self,
                arg1,
                arg2):
        pass

    def forward(self, inputs):
        # implementation is ignored
        pass

2. 引入模块

你可以在 mmtrack/models/motion/__init__.py 中加入下面一行。

from .my_flownet import MyFlowNet

或者,为了避免更改原始代码,你还可以在 config 文件中增加以下几行来实现:

custom_imports = dict(
    imports=['mmtrack.models.motion.my_flownet.py'],
    allow_failed_imports=False)

3. 更改原始 config 文件

motion=dict(
    type='MyFlowNet',
    arg1=xxx,
    arg2=xxx)

增加一个新的聚合器

1. 定义一个聚合器

创建一个新文件 mmtrack/models/aggregators/my_aggregator.py

from mmcv.runner import BaseModule

from ..builder import AGGREGATORS

@AGGREGATORS.register_module()
class MyAggregator(BaseModule):

    def __init__(self,
                arg1,
                arg2):
        pass

    def forward(self, inputs):
        # implementation is ignored
        pass

2. 引入模块

你可以在 mmtrack/models/aggregators/__init__.py 中加入下面一行。

from .my_aggregator import MyAggregator

或者,为了避免更改原始代码,你还可以在 config 文件中增加以下几行来实现:

custom_imports = dict(
    imports=['mmtrack.models.aggregators.my_aggregator.py'],
    allow_failed_imports=False)

3. 更改原始 config 文件

aggregator=dict(
    type='MyAggregator',
    arg1=xxx,
    arg2=xxx)