forked from OpenMMLab / mmtracking 
Tao Gong committed on 2021-12-17 16:35: [Docs] Refactor documentations (#373)

Customize SOT Models

We categorize model components into four types.

  • backbone: usually a fully convolutional network (FCN) that extracts feature maps, e.g., ResNet, MobileNet.
  • neck: the component between the backbone and the head, e.g., ChannelMapper, FPN.
  • head: the component for a specific task, e.g., tracking bbox prediction.
  • loss: the component in the head that calculates losses, e.g., FocalLoss, L1Loss.

Add a new backbone

Here we show how to develop a new component, using MobileNet as an example.

1. Define a new backbone (e.g. MobileNet)

Create a new file mmtrack/models/backbones/mobilenet.py.

import torch.nn as nn
from mmcv.runner import BaseModule

from mmdet.models.builder import BACKBONES


@BACKBONES.register_module()
class MobileNet(BaseModule):

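    def __init__(self, arg1, arg2, *args, **kwargs):
        # Forward the remaining arguments (e.g. init_cfg) to BaseModule.
        super().__init__(*args, **kwargs)
        # Build the layers here using arg1, arg2, ...

    def forward(self, x):
        # Should return a tuple of feature maps.
        pass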

2. Import the module

You can either add the following line to mmtrack/models/backbones/__init__.py

from .mobilenet import MobileNet

or alternatively add

custom_imports = dict(
    imports=['mmtrack.models.backbones.mobilenet'],
    allow_failed_imports=False)

to the config file to avoid modifying the original code.
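Conceptually, `custom_imports` imports each listed module by its dotted name before the model is built, so that the `@BACKBONES.register_module()` decorator in that module gets executed. The sketch below mimics this behaviour with the standard `importlib` module; the function name `import_modules_sketch` is hypothetical and only approximates what MMCV does internally. It also shows why the entries must be dotted module paths without a `.py` suffix.

```python
import importlib
import warnings


def import_modules_sketch(imports, allow_failed_imports=False):
    """Import every dotted module path in `imports` (hypothetical helper).

    Importing a module runs its top-level code, which is where registration
    decorators such as @BACKBONES.register_module() are executed.
    """
    if isinstance(imports, str):
        imports = [imports]
    imported = []
    for name in imports:
        try:
            imported.append(importlib.import_module(name))
        except ImportError:
            if allow_failed_imports:
                warnings.warn(f'{name} failed to import and is ignored.')
                imported.append(None)
            else:
                raise
    return imported


# A dotted module path works ...
mods = import_modules_sketch(['json.decoder'])
print(mods[0].__name__)  # json.decoder

# ... but a file-style name with a ".py" suffix does not.
try:
    import_modules_sketch(['json.decoder.py'])
except ImportError as err:
    print(type(err).__name__)  # ModuleNotFoundError
```

With `allow_failed_imports=False`, a typo in the module path fails loudly instead of silently leaving the class unregistered.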

3. Use the backbone in your config file

model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...
)
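The `type` key is what ties the config back to the decorated class: registering stores the class under its name, and building looks that name up and passes the remaining keys as constructor arguments. The sketch below is a deliberately simplified stand-in for MMCV's `Registry`; the class `ToyRegistry`, its `build` method and the `arg1`/`arg2` values are hypothetical.

```python
class ToyRegistry:
    """A minimal name -> class registry (simplified, not MMCV's Registry)."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        # Returns a decorator, mirroring the @BACKBONES.register_module() usage.
        def _register(cls):
            key = cls.__name__
            if key in self._module_dict:
                raise KeyError(f'{key} is already registered in {self.name}')
            self._module_dict[key] = cls
            return cls
        return _register

    def build(self, cfg):
        # Copy first so that pop() does not mutate the caller's config.
        args = dict(cfg)
        obj_type = args.pop('type')
        if obj_type not in self._module_dict:
            raise KeyError(f'{obj_type} is not registered in {self.name}')
        # Every remaining key becomes a keyword argument of the constructor.
        return self._module_dict[obj_type](**args)


BACKBONES = ToyRegistry('backbone')


@BACKBONES.register_module()
class MobileNet:
    def __init__(self, arg1, arg2):
        self.arg1 = arg1
        self.arg2 = arg2


backbone_cfg = dict(type='MobileNet', arg1=1, arg2=2)
backbone = BACKBONES.build(backbone_cfg)
print(type(backbone).__name__, backbone.arg1, backbone.arg2)  # MobileNet 1 2
```

This is also why the module must be imported (step 2) before building: until the decorator has run, the name is simply not in the registry.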

Add a new neck

1. Define a neck (e.g. MyFPN)

Create a new file mmtrack/models/necks/my_fpn.py.

from mmcv.runner import BaseModule

from mmdet.models.builder import NECKS

@NECKS.register_module()
class MyFPN(BaseModule):

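    def __init__(self, arg1, arg2, *args, **kwargs):
        # Forward the remaining arguments (e.g. init_cfg) to BaseModule.
        super().__init__(*args, **kwargs)
        # Build the layers here using arg1, arg2, ...

    def forward(self, inputs):
        # `inputs` is the tuple of feature maps from the backbone;
        # the implementation is omitted here.
        pass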
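A neck typically receives the tuple of multi-level feature maps produced by the backbone and returns another tuple. As an illustration only, the pure-Python sketch below does what a channel-mapping neck does in spirit: it applies a per-level 1x1 projection that maps `C_in` channels to a common `C_out`. The function names and the nested-list `[C][H][W]` layout are assumptions chosen to keep the example dependency-free; a real neck would use layers such as `torch.nn.Conv2d` inside `forward`.

```python
def conv1x1(feat, weight):
    """Apply a bias-free 1x1 convolution to one feature map.

    feat:   nested list with shape [C_in][H][W]
    weight: nested list with shape [C_out][C_in]
    returns nested list with shape [C_out][H][W]
    """
    c_in, h, w = len(feat), len(feat[0]), len(feat[0][0])
    return [[[sum(weight[o][i] * feat[i][y][x] for i in range(c_in))
              for x in range(w)]
             for y in range(h)]
            for o in range(len(weight))]


def toy_neck_forward(inputs, weights):
    """Map every level of the backbone's output tuple to the same channels."""
    assert len(inputs) == len(weights), 'one projection per level'
    return tuple(conv1x1(f, w) for f, w in zip(inputs, weights))


# Two levels: 2 channels at 2x2 and 3 channels at 1x1; map both to 1 channel.
level0 = [[[1, 2], [3, 4]], [[10, 20], [30, 40]]]
level1 = [[[1]], [[2]], [[3]]]
outs = toy_neck_forward((level0, level1), ([[1, 1]], [[1, 0, -1]]))
print(outs)  # ([[[11, 22], [33, 44]]], [[[-2]]])
```

Keeping the tuple-in, tuple-out contract lets the neck be swapped in the config without touching the backbone or the head.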

2. Import the module

You can either add the following line to mmtrack/models/necks/__init__.py,

from .my_fpn import MyFPN

or alternatively add

custom_imports = dict(
    imports=['mmtrack.models.necks.my_fpn'],
    allow_failed_imports=False)

to the config file to avoid modifying the original code.

3. Modify the config file

neck=dict(
    type='MyFPN',
    arg1=xxx,
    arg2=xxx),

Add a new head

1. Define a head (e.g. MyHead)

Create a new file mmtrack/models/track_heads/my_head.py.

from mmcv.runner import BaseModule

from mmdet.models import HEADS

@HEADS.register_module()
class MyHead(BaseModule):

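    def __init__(self, arg1, arg2, *args, **kwargs):
        # Forward the remaining arguments (e.g. init_cfg) to BaseModule.
        super().__init__(*args, **kwargs)
        # Build the layers here using arg1, arg2, ...

    def forward(self, inputs):
        # The implementation is omitted here.
        pass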
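In Siamese-style single object trackers, the head usually compares features of the target template with features of the search region; a common way to do this is to slide the template over the search map and take a dot product at every offset (cross-correlation). The sketch below shows that operation on single-channel 2D lists in pure Python; `cross_correlate` and `best_offset` are hypothetical helpers, not part of MMTracking, and a real head would operate on batched multi-channel tensors.

```python
def cross_correlate(search, template):
    """Valid 2D cross-correlation of `template` over `search`.

    Both inputs are single-channel 2D lists. The output has shape
    (Hs - Ht + 1, Ws - Wt + 1); larger values mean a better match.
    """
    hs, ws = len(search), len(search[0])
    ht, wt = len(template), len(template[0])
    out = []
    for y in range(hs - ht + 1):
        row = []
        for x in range(ws - wt + 1):
            score = sum(template[i][j] * search[y + i][x + j]
                        for i in range(ht) for j in range(wt))
            row.append(score)
        out.append(row)
    return out


def best_offset(score_map):
    """Return (row, col) of the highest response in the score map."""
    return max((v, (r, c)) for r, row in enumerate(score_map)
               for c, v in enumerate(row))[1]


search = [
    [0, 0, 0, 0],
    [0, 1, 2, 0],
    [0, 3, 4, 0],
    [0, 0, 0, 0],
]
template = [[1, 2], [3, 4]]
scores = cross_correlate(search, template)
print(scores)               # [[4, 11, 6], [14, 30, 14], [6, 11, 4]]
print(best_offset(scores))  # (1, 1)
```

The peak of the response map marks where the template best matches the search region, which a tracking head then refines into a bounding box.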

2. Import the module

You can either add the following line to mmtrack/models/track_heads/__init__.py,

from .my_head import MyHead

or alternatively add

custom_imports = dict(
    imports=['mmtrack.models.track_heads.my_head'],
    allow_failed_imports=False)

to the config file to avoid modifying the original code.

3. Modify the config file

track_head=dict(
    type='MyHead',
    arg1=xxx,
    arg2=xxx)

Add a new loss

Please refer to Add a new loss for developing a new loss.
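As a reminder of what such a component computes, here is a dependency-free sketch of an L1 loss with the `reduction` and `loss_weight` options that MMDetection-style losses commonly expose. The function name `l1_loss_sketch` is hypothetical; this is not the library implementation.

```python
def l1_loss_sketch(pred, target, reduction='mean', loss_weight=1.0):
    """Element-wise |pred - target|, reduced and scaled (hypothetical helper)."""
    if not pred or len(pred) != len(target):
        raise ValueError('pred and target must be non-empty and equally long')
    losses = [abs(p - t) for p, t in zip(pred, target)]
    if reduction == 'none':
        return [loss_weight * v for v in losses]
    if reduction == 'sum':
        return loss_weight * sum(losses)
    if reduction == 'mean':
        return loss_weight * sum(losses) / len(losses)
    raise ValueError(f'unsupported reduction: {reduction}')


print(l1_loss_sketch([1.0, 2.0, 4.0], [1.5, 2.0, 3.0]))  # 0.5
```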
