mirror of https://github.com/hero-y/BHRL
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops.merge_cells import GlobalPoolingCell, SumCell
from mmcv.runner import BaseModule, ModuleList

from ..builder import NECKS

@NECKS.register_module()
class NASFPN(BaseModule):
    """NAS-FPN.

    Implementation of `NAS-FPN: Learning Scalable Feature Pyramid Architecture
    for Object Detection <https://arxiv.org/abs/1904.07392>`_.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        num_outs (int): Number of output scales.
        stack_times (int): Number of times the pyramid architecture is
            stacked.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last
            level.
        add_extra_convs (bool): Whether to add conv layers on top of the
            original feature maps. Default: False. If True, its actual mode
            is specified by `extra_convs_on_inputs`.
        norm_cfg (dict, optional): Config dict for the normalization layers
            used in the lateral convs and merge cells. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

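    # Illustrative config snippet (the channel and repeat values below are
    # assumptions, not taken from this repo): when built through the NECKS
    # registry, the neck is typically specified in a detector config as, e.g.,
    #   neck=dict(
    #       type='NASFPN',
    #       in_channels=[512, 1024, 2048],
    #       out_channels=256,
    #       num_outs=5,
    #       stack_times=7)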
    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 stack_times,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 norm_cfg=None,
                 init_cfg=dict(type='Caffe2Xavier', layer='Conv2d')):
        super(NASFPN, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)  # num of input feature levels
        self.num_outs = num_outs  # num of output feature levels
        self.stack_times = stack_times
        self.norm_cfg = norm_cfg

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < inputs, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs

        # add lateral connections
        self.lateral_convs = nn.ModuleList()
        for i in range(self.start_level, self.backbone_end_level):
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                norm_cfg=norm_cfg,
                act_cfg=None)
            self.lateral_convs.append(l_conv)

        # add extra downsample layers (stride-2 pooling or conv)
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        self.extra_downsamples = nn.ModuleList()
        for i in range(extra_levels):
            extra_conv = ConvModule(
                out_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
            self.extra_downsamples.append(
                nn.Sequential(extra_conv, nn.MaxPool2d(2, 2)))

        # add NAS FPN connections
        self.fpn_stages = ModuleList()
        for _ in range(self.stack_times):
            stage = nn.ModuleDict()
            # gp(p6, p4) -> p4_1
            stage['gp_64_4'] = GlobalPoolingCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # sum(p4_1, p4) -> p4_2
            stage['sum_44_4'] = SumCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # sum(p4_2, p3) -> p3_out
            stage['sum_43_3'] = SumCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # sum(p3_out, p4_2) -> p4_out
            stage['sum_34_4'] = SumCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # sum(p5, gp(p4_out, p3_out)) -> p5_out
            stage['gp_43_5'] = GlobalPoolingCell(with_out_conv=False)
            stage['sum_55_5'] = SumCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # sum(p7, gp(p5_out, p4_2)) -> p7_out
            stage['gp_54_7'] = GlobalPoolingCell(with_out_conv=False)
            stage['sum_77_7'] = SumCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # gp(p7_out, p5_out) -> p6_out
            stage['gp_75_6'] = GlobalPoolingCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            self.fpn_stages.append(stage)

    def forward(self, inputs):
        """Forward function."""
        # build P3-P5
        feats = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
        # build P6-P7 on top of P5
        for downsample in self.extra_downsamples:
            feats.append(downsample(feats[-1]))

        p3, p4, p5, p6, p7 = feats

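        # Each merge cell resizes both of its inputs to `out_size` before
        # fusing them: a SumCell fuses by element-wise addition, while a
        # GlobalPoolingCell additionally applies a global-average-pooling
        # attention weight (see mmcv.ops.merge_cells); cells created with
        # with_out_conv=False skip the trailing output conv.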
        for stage in self.fpn_stages:
            # gp(p6, p4) -> p4_1
            p4_1 = stage['gp_64_4'](p6, p4, out_size=p4.shape[-2:])
            # sum(p4_1, p4) -> p4_2
            p4_2 = stage['sum_44_4'](p4_1, p4, out_size=p4.shape[-2:])
            # sum(p4_2, p3) -> p3_out
            p3 = stage['sum_43_3'](p4_2, p3, out_size=p3.shape[-2:])
            # sum(p3_out, p4_2) -> p4_out
            p4 = stage['sum_34_4'](p3, p4_2, out_size=p4.shape[-2:])
            # sum(p5, gp(p4_out, p3_out)) -> p5_out
            p5_tmp = stage['gp_43_5'](p4, p3, out_size=p5.shape[-2:])
            p5 = stage['sum_55_5'](p5, p5_tmp, out_size=p5.shape[-2:])
            # sum(p7, gp(p5_out, p4_2)) -> p7_out
            p7_tmp = stage['gp_54_7'](p5, p4_2, out_size=p7.shape[-2:])
            p7 = stage['sum_77_7'](p7, p7_tmp, out_size=p7.shape[-2:])
            # gp(p7_out, p5_out) -> p6_out
            p6 = stage['gp_75_6'](p7, p5, out_size=p6.shape[-2:])

        return p3, p4, p5, p6, p7
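

# Illustrative usage sketch (hypothetical values, not from this repo). With a
# backbone exposing three levels of 512/1024/2048 channels, the neck below
# would emit five 256-channel maps (P3-P7), with P6/P7 produced by the extra
# stride-2 max-pool branches:
#
#     from mmdet.models.necks import NASFPN  # assumed mmdet package layout
#     import torch
#
#     neck = NASFPN(
#         in_channels=[512, 1024, 2048],
#         out_channels=256,
#         num_outs=5,
#         stack_times=7,
#         norm_cfg=dict(type='BN', requires_grad=True))
#     feats = [
#         torch.rand(1, 512, 32, 32),    # C3
#         torch.rand(1, 1024, 16, 16),   # C4
#         torch.rand(1, 2048, 8, 8),     # C5
#     ]
#     p3, p4, p5, p6, p7 = neck(feats)   # five 256-channel pyramid levels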