mmDetection源码分析(三):Faster R-CNN模块解读(一)

Faster R-CNN模块解读(一)— 检测器的构建

根据之前的介绍,config文件中的 model 中的 type 指定了检测器是一个Faster R-CNN检测器。

我们知道要调用一个检测模型要用到函数model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)。接下来,我们看看其是如何调用到Faster R-CNN模块的。

mmdetection/configs/faster_rcnn_r50_fpn_1x.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Top-level detector config. 'type' is the name looked up in the detector
# registry; the remaining keys become constructor keyword arguments.
model = dict(
    type='FasterRCNN',
    pretrained='torchvision://resnet50',
    # Backbone config (handed to builder.build_backbone by the detector).
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch',
    ),
    # Neck config: one in_channels entry per backbone output index.
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5,
    ),
    # RPN head config: five anchor strides, matching the neck's num_outs.
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64],
        target_means=[0.0, 0.0, 0.0, 0.0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=1.0,
        ),
        loss_bbox=dict(
            type='SmoothL1Loss',
            beta=1.0 / 9.0,
            loss_weight=1.0,
        ),
    ),
    # RoI feature extractor config for the bbox branch; note it lists four
    # strides, one fewer than the RPN head above.
    bbox_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32],
    ),
    # Bbox head config; roi_feat_size equals the RoI layer's out_size.
    bbox_head=dict(
        type='SharedFCBBoxHead',
        num_fcs=2,
        in_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        num_classes=81,
        target_means=[0.0, 0.0, 0.0, 0.0],
        target_stds=[0.1, 0.1, 0.2, 0.2],
        reg_class_agnostic=False,
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=1.0,
        ),
        loss_bbox=dict(
            type='SmoothL1Loss',
            beta=1.0,
            loss_weight=1.0,
        ),
    ),
)

mmdetection/mmdet/models/registry.py

1
2
3
4
5
6
7
# One Registry instance per kind of model component, each created with a
# name string.
# NOTE(review): Registry itself is defined elsewhere; from its use in this
# article it exposes at least `.name`, `.get(key)` and a `register_module`
# decorator — confirm against its definition.
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
# Detector classes add themselves here via @DETECTORS.register_module.
DETECTORS = Registry('detector')

mmdetection/mmdet/models/builder.py

1
2
3
4
5
6
7
8
9
10
11
def build(cfg, registry, default_args=None):
    """Build one module, or a sequence of modules, from config.

    Args:
        cfg (dict | list[dict]): A single config, or a list of configs.
        registry: Registry passed through to ``build_from_cfg``.
        default_args (dict, optional): Passed through to ``build_from_cfg``
            for every config that is built.

    Returns:
        The result of ``build_from_cfg`` for a non-list ``cfg``; for a list,
        an ``nn.Sequential`` wrapping the modules built from each entry, in
        list order.
    """
    # Anything that is not a list is treated as a single config.
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)

    built = []
    for single_cfg in cfg:
        built.append(build_from_cfg(single_cfg, registry, default_args))
    return nn.Sequential(*built)

def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build a detector from ``cfg`` using the ``DETECTORS`` registry.

    ``train_cfg`` and ``test_cfg`` are supplied to ``build`` as default
    arguments rather than merged into ``cfg`` here.
    """
    extra_args = {'train_cfg': train_cfg, 'test_cfg': test_cfg}
    return build(cfg, DETECTORS, extra_args)

mmdetection/mmdet/utils/registry.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
            The value of "type" is either a name registered in ``registry``
            or a class; every other key is passed to the class as a keyword
            argument.
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.
            An entry is only used when ``cfg`` does not define the same key.

    Returns:
        obj: The constructed object.

    Raises:
        KeyError: If "type" is a string that is not found in ``registry``.
        TypeError: If "type" is neither a string nor a class.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None

    # Shallow copy so that popping "type" does not modify the caller's dict.
    args = cfg.copy()
    obj_type = args.pop('type')

    if isinstance(obj_type, str):
        # Keep the looked-up class in its own variable: the requested name
        # must still be available for the error message when lookup fails.
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))

    if default_args is not None:
        for name, value in default_args.items():
            # Values given explicitly in cfg win over the defaults.
            args.setdefault(name, value)
    return obj_cls(**args)

build_from_cfg根据cfg.model.type(此处为'FasterRCNN')在注册器DETECTORS中查到对应的类,并以其余配置项作为参数将其实例化,因此返回的是该类的对象,而不是类本身。该类定义在mmdetection/mmdet/models/detectors中,查看文件faster_rcnn.py(此处涉及框架设计层面的内容,比如注册器的使用,在本系列中暂时不分析)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

@DETECTORS.register_module
class FasterRCNN(TwoStageDetector):
    """Faster R-CNN detector, registered in ``DETECTORS``.

    This subclass adds no behaviour of its own: its constructor only
    forwards the arguments to ``TwoStageDetector.__init__``. Compared with
    the parent signature, ``rpn_head``, ``bbox_roi_extractor``, ``bbox_head``,
    ``train_cfg`` and ``test_cfg`` are required here, and no mask-related
    argument is accepted or forwarded.
    """

    def __init__(self,
                 backbone,
                 rpn_head,
                 bbox_roi_extractor,
                 bbox_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 shared_head=None,
                 pretrained=None):
        # Everything is passed by keyword, so the order here is irrelevant.
        parent_kwargs = dict(
            backbone=backbone,
            neck=neck,
            shared_head=shared_head,
            rpn_head=rpn_head,
            bbox_roi_extractor=bbox_roi_extractor,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
        )
        super(FasterRCNN, self).__init__(**parent_kwargs)

可以看到FasterRCNN继承于TwoStageDetector

mmdetection/mmdet/models/detectors/two_stage.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69

@DETECTORS.register_module
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                       MaskTestMixin):
    """Detector assembled from sub-modules that are built from configs.

    Only ``backbone`` is mandatory. Every other component is built and
    attached as an attribute only when its config is not ``None``.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 shared_head=None,
                 rpn_head=None,
                 bbox_roi_extractor=None,
                 bbox_head=None,
                 mask_roi_extractor=None,
                 mask_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        """Build all configured sub-modules, then initialise their weights.

        Args:
            backbone: Config passed to ``builder.build_backbone``.
            neck: Optional config passed to ``builder.build_neck``.
            shared_head: Optional config passed to
                ``builder.build_shared_head``.
            rpn_head: Optional config passed to ``builder.build_head``.
            bbox_roi_extractor: Config passed to
                ``builder.build_roi_extractor``; only used when ``bbox_head``
                is given.
            bbox_head: Optional config passed to ``builder.build_head``.
            mask_roi_extractor: Optional config for a dedicated mask RoI
                extractor; only used when ``mask_head`` is given.
            mask_head: Optional config passed to ``builder.build_head``.
            train_cfg: Stored unchanged as ``self.train_cfg``.
            test_cfg: Stored unchanged as ``self.test_cfg``.
            pretrained: Forwarded to ``init_weights``.
        """
        super(TwoStageDetector, self).__init__()

        # The order in which sub-modules are attached is kept as-is.
        self.backbone = builder.build_backbone(backbone)
        if neck is not None:
            self.neck = builder.build_neck(neck)
        if shared_head is not None:
            self.shared_head = builder.build_shared_head(shared_head)
        if rpn_head is not None:
            self.rpn_head = builder.build_head(rpn_head)

        # The bbox RoI extractor is only built together with a bbox head.
        if bbox_head is not None:
            self.bbox_roi_extractor = builder.build_roi_extractor(
                bbox_roi_extractor)
            self.bbox_head = builder.build_head(bbox_head)

        if mask_head is not None:
            # Without a dedicated mask extractor config, the mask branch
            # reuses the bbox extractor object.
            reuse_bbox_extractor = mask_roi_extractor is None
            self.share_roi_extractor = reuse_bbox_extractor
            if reuse_bbox_extractor:
                self.mask_roi_extractor = self.bbox_roi_extractor
            else:
                self.mask_roi_extractor = builder.build_roi_extractor(
                    mask_roi_extractor)
            self.mask_head = builder.build_head(mask_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.init_weights(pretrained=pretrained)

    def init_weights(self, pretrained=None):
        """Initialise the weights of every sub-module that is present.

        ``pretrained`` is forwarded to the parent class, the backbone and
        the shared head; the other sub-modules are initialised without it.

        NOTE(review): the ``with_neck`` / ``with_shared_head`` / ``with_rpn``
        / ``with_bbox`` / ``with_mask`` flags are not defined in this class;
        presumably they come from a base class — confirm.
        """
        super(TwoStageDetector, self).init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)

        if self.with_neck:
            # A neck may be an nn.Sequential of several modules; treat a
            # single neck as a one-element sequence.
            if isinstance(self.neck, nn.Sequential):
                neck_modules = self.neck
            else:
                neck_modules = [self.neck]
            for neck_module in neck_modules:
                neck_module.init_weights()

        if self.with_shared_head:
            self.shared_head.init_weights(pretrained=pretrained)

        if self.with_rpn:
            self.rpn_head.init_weights()

        if self.with_bbox:
            self.bbox_roi_extractor.init_weights()
            self.bbox_head.init_weights()

        if self.with_mask:
            self.mask_head.init_weights()
            # A shared extractor was already initialised in the bbox branch.
            if not self.share_roi_extractor:
                self.mask_roi_extractor.init_weights()

可以看到,TwoStageDetector通过builder将检测器的backbone、neck、shared_head、rpn_head、bbox_head、mask_head等子模块全部创建(对应配置为None的子模块会被跳过),并通过init_weights初始化权重。该类后续还有以下方法:

extract_feat()
forward_train()
simple_test && aug_test

下一篇blog将介绍Backbone的构建。