# mirror of https://github.com/hero-y/BHRL
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import VGG
from mmcv.runner import BaseModule, Sequential

from ..builder import BACKBONES


@BACKBONES.register_module()
class SSDVGG(VGG, BaseModule):
    """VGG Backbone network for single-shot-detection.

    Args:
        input_size (int): Width and height of the input image, one of
            {300, 512}.
        depth (int): Depth of the VGG variant, one of {11, 13, 16, 19}.
        out_indices (Sequence[int]): Stages from which feature maps are
            output.
        pretrained (str, optional): Path to the pretrained model.
            Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> self = SSDVGG(input_size=300, depth=11)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 300, 300)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 1024, 19, 19)
        (1, 512, 10, 10)
        (1, 256, 5, 5)
        (1, 256, 3, 3)
        (1, 256, 1, 1)
    """
    extra_setting = {
        300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256),
        512: (256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128),
    }

    def __init__(self,
                 input_size,
                 depth,
                 with_last_pool=False,
                 ceil_mode=True,
                 out_indices=(3, 4),
                 out_feature_indices=(22, 34),
                 l2_norm_scale=20.,
                 pretrained=None,
                 init_cfg=None):
        # TODO: in_channels for mmcv.VGG
        super(SSDVGG, self).__init__(
            depth,
            with_last_pool=with_last_pool,
            ceil_mode=ceil_mode,
            out_indices=out_indices)
        assert input_size in (300, 512)
        self.input_size = input_size
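
        # Following the SSD paper, replace VGG's pool5 with a stride-1
        # pooling and convert the fc6/fc7 layers into a dilated 3x3 conv
        # and a 1x1 conv, preserving spatial resolution.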
        self.features.add_module(
            str(len(self.features)),
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
        self.features.add_module(
            str(len(self.features)),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6))
        self.features.add_module(
            str(len(self.features)), nn.ReLU(inplace=True))
        self.features.add_module(
            str(len(self.features)), nn.Conv2d(1024, 1024, kernel_size=1))
        self.features.add_module(
            str(len(self.features)), nn.ReLU(inplace=True))
        self.out_feature_indices = out_feature_indices

        self.inplanes = 1024
        self.extra = self._make_extra_layers(self.extra_setting[input_size])
        self.l2_norm = L2Norm(
            self.features[out_feature_indices[0] - 1].out_channels,
            l2_norm_scale)
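
        # ``pretrained`` is kept only for backward compatibility: it is
        # mapped to a ``Pretrained`` entry of ``init_cfg``; new code should
        # pass ``init_cfg`` directly.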
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be specified at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = [dict(type='Pretrained', checkpoint=pretrained)]
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(type='Constant', val=1, layer='BatchNorm2d'),
                    dict(type='Normal', std=0.01, layer='Linear'),
                ]
        else:
            raise TypeError('pretrained must be a str or None')
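
        # The extra layers and the L2Norm scale are not part of a pretrained
        # VGG checkpoint, so they receive dedicated initializers.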
        if init_cfg is None:
            self.init_cfg += [
                dict(
                    type='Xavier',
                    distribution='uniform',
                    override=dict(name='extra')),
                dict(
                    type='Constant',
                    val=self.l2_norm.scale,
                    override=dict(name='l2_norm'))
            ]

    def init_weights(self, pretrained=None):
        # Skip VGG in the MRO so that BaseModule.init_weights is called,
        # which applies ``self.init_cfg``.
        super(VGG, self).init_weights()

    def forward(self, x):
        """Forward function."""
        outs = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in self.out_feature_indices:
                outs.append(x)
        for i, layer in enumerate(self.extra):
            x = F.relu(layer(x), inplace=True)
            # Extra layers come in (1x1 conv, 3x3 conv) pairs; only the
            # output of each 3x3 conv is kept as a detection feature map.
            if i % 2 == 1:
                outs.append(x)
        outs[0] = self.l2_norm(outs[0])
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)
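
    # Build the extra feature layers, alternating 1x1 and 3x3 kernels; an
    # 'S' entry in ``outplanes`` makes the next conv downsample with
    # stride 2 (see ``extra_setting``).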
    def _make_extra_layers(self, outplanes):
        layers = []
        kernel_sizes = (1, 3)
        num_layers = 0
        outplane = None
        for i in range(len(outplanes)):
            if self.inplanes == 'S':
                self.inplanes = outplane
                continue
            k = kernel_sizes[num_layers % 2]
            if outplanes[i] == 'S':
                outplane = outplanes[i + 1]
                conv = nn.Conv2d(
                    self.inplanes, outplane, k, stride=2, padding=1)
            else:
                outplane = outplanes[i]
                conv = nn.Conv2d(
                    self.inplanes, outplane, k, stride=1, padding=0)
            layers.append(conv)
            self.inplanes = outplanes[i]
            num_layers += 1
        if self.input_size == 512:
            layers.append(nn.Conv2d(self.inplanes, 256, 4, padding=1))

        return Sequential(*layers)
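

# SSD applies channel-wise L2 normalization (introduced in ParseNet) to the
# shallow conv4_3 feature map, whose activations have a much larger scale
# than deeper levels; the per-channel weight is learnable.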
class L2Norm(nn.Module):

    def __init__(self, n_dims, scale=20., eps=1e-10):
        """L2 normalization layer.

        Args:
            n_dims (int): Number of channels to normalize over.
            scale (float, optional): Initial value of the learnable
                per-channel scale. Defaults to 20.
            eps (float, optional): Used to avoid division by zero.
                Defaults to 1e-10.
        """
        super(L2Norm, self).__init__()
        self.n_dims = n_dims
        self.weight = nn.Parameter(torch.Tensor(self.n_dims))
        self.eps = eps
        self.scale = scale

    def forward(self, x):
        """Forward function."""
        # Normalization layers are computed in FP32 during FP16 training.
        x_float = x.float()
        norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
        return (self.weight[None, :, None, None].float().expand_as(x_float) *
                x_float / norm).type_as(x)
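

# Usage sketch, assuming this module is imported inside mmdet so that the
# relative ``..builder`` import resolves; depth=16 with the default
# out_feature_indices=(22, 34) selects conv4_3 and the fc7 replacement,
# yielding the six SSD300 feature levels:
#
#     import torch
#     backbone = SSDVGG(input_size=300, depth=16)
#     backbone.init_weights()
#     backbone.eval()
#     feats = backbone(torch.rand(1, 3, 300, 300))
#     print([tuple(f.shape) for f in feats])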