我们通常将模型组件分为3类:
- 检测器:通常是从一张图片中检出物体的检测组件,例如:Faster R-CNN。
- 运动估计器:计算两张图片之间的运动信息的组件,例如:FlowNetSimple。
- 聚合器:聚合多张图片特征的组件,例如:EmbedAggregator。
请参考MMDetection教程来开发新检测器
新建一个文件 mmtrack/models/motion/my_flownet.py
。
from mmcv.runner import BaseModule
from ..builder import MOTION
@MOTION.register_module()
class MyFlowNet(BaseModule):
def __init__(self,
arg1,
arg2):
pass
def forward(self, inputs):
# implementation is ignored
pass
你可以在 mmtrack/models/motion/__init__.py
中加入下面一行。
from .my_flownet import MyFlowNet
或者,为了避免更改原始代码,你还可以在 config 文件中增加以下几行来实现:
custom_imports = dict(
imports=['mmtrack.models.motion.my_flownet.py'],
allow_failed_imports=False)
motion=dict(
type='MyFlowNet',
arg1=xxx,
arg2=xxx)
创建一个新文件 mmtrack/models/aggregators/my_aggregator.py
。
from mmcv.runner import BaseModule
from ..builder import AGGREGATORS
@AGGREGATORS.register_module()
class MyAggregator(BaseModule):
def __init__(self,
arg1,
arg2):
pass
def forward(self, inputs):
# implementation is ignored
pass
你可以在 mmtrack/models/aggregators/__init__.py
中加入下面一行。
from .my_aggregator import MyAggregator
或者,为了避免更改原始代码,你还可以在 config 文件中增加以下几行来实现:
custom_imports = dict(
imports=['mmtrack.models.aggregators.my_aggregator.py'],
allow_failed_imports=False)
aggregator=dict(
type='MyAggregator',
arg1=xxx,
arg2=xxx)