mirror of https://github.com/hero-y/BHRL
100 lines
3.8 KiB
Python
100 lines
3.8 KiB
Python
import torch.nn as nn
|
|
from mmcv.cnn import ConvModule
|
|
from mmcv.runner import BaseModule
|
|
|
|
from ..builder import NECKS
|
|
|
|
|
|
@NECKS.register_module()
class ChannelMapper(BaseModule):
    r"""Neck that projects each backbone feature map to a common channel count.

    One ``ConvModule`` is built per input level and applied to the matching
    feature map. When ``num_outs`` exceeds the number of input levels,
    additional 3x3, stride-2 ``ConvModule`` layers are stacked to produce the
    extra outputs: the first one consumes the last *input* feature map, each
    following one consumes the previously produced extra output.

    Args:
        in_channels (List[int]): Number of input channels per scale. Must be
            a ``list``.
        out_channels (int): Number of output channels (used at each scale).
        kernel_size (int, optional): Kernel size of the per-level mapping
            convolution (used at each scale); its padding is
            ``(kernel_size - 1) // 2``. Default: 3.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU').
        num_outs (int, optional): Number of output feature maps. Extra
            convolutions are created when it is larger than
            ``len(in_channels)``. Default: None, which means
            ``len(in_channels)``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: Xavier-uniform initialization of ``Conv2d`` layers.

    Example:
        >>> import torch
        >>> in_channels = [2, 3, 5, 7]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = ChannelMapper(in_channels, 11, 3).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 11, 340, 340])
        outputs[1].shape = torch.Size([1, 11, 170, 170])
        outputs[2].shape = torch.Size([1, 11, 84, 84])
        outputs[3].shape = torch.Size([1, 11, 43, 43])
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 num_outs=None,
                 init_cfg=dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super().__init__(init_cfg)
        assert isinstance(in_channels, list)
        num_ins = len(in_channels)

        # Stays None unless more outputs than inputs are requested.
        self.extra_convs = None
        if num_outs is None:
            num_outs = num_ins

        # Layer configs shared by every ConvModule built below.
        shared_cfg = dict(
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # One channel-mapping conv per input level, in level order.
        self.convs = nn.ModuleList()
        for level_channels in in_channels:
            mapper = ConvModule(
                level_channels,
                out_channels,
                kernel_size,
                padding=(kernel_size - 1) // 2,
                **shared_cfg)
            self.convs.append(mapper)

        if num_outs > num_ins:
            self.extra_convs = nn.ModuleList()
            for extra_idx in range(num_outs - num_ins):
                # The first extra conv reads the last raw input level; later
                # ones read the output of the previous extra conv.
                if extra_idx == 0:
                    src_channels = in_channels[-1]
                else:
                    src_channels = out_channels
                downsample = ConvModule(
                    src_channels,
                    out_channels,
                    3,
                    stride=2,
                    padding=1,
                    **shared_cfg)
                self.extra_convs.append(downsample)

    def forward(self, inputs):
        """Map every input level and append any extra downsampled levels.

        Args:
            inputs (Sequence): Feature maps, one per entry of ``in_channels``.

        Returns:
            tuple: The mapped feature maps followed by the outputs of the
            extra convolutions (if any).
        """
        assert len(inputs) == len(self.convs)
        outs = [self.convs[idx](feat) for idx, feat in enumerate(inputs)]
        if self.extra_convs:
            for idx, extra_conv in enumerate(self.extra_convs):
                # Level 0 of the extras starts from the last *unmapped* input.
                source = inputs[-1] if idx == 0 else outs[-1]
                outs.append(extra_conv(source))
        return tuple(outs)
|