133 lines
4.5 KiB
Python
133 lines
4.5 KiB
Python
|
import torch.nn as nn
|
||
|
from mmcv.cnn import kaiming_init, normal_init
|
||
|
|
||
|
from .registry import NECKS
|
||
|
|
||
|
|
||
|
@NECKS.register_module
class LinearNeck(nn.Module):
    """Neck made of an optional global average pool and one linear layer.

    Args:
        in_channels: Input feature size of the linear layer.
        out_channels: Output feature size of the linear layer.
        with_avg_pool: If True, ``nn.AdaptiveAvgPool2d((1, 1))`` is applied
            to the input before it is flattened. Defaults to True.
    """

    def __init__(self, in_channels, out_channels, with_avg_pool=True):
        super(LinearNeck, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_channels, out_channels)

    def init_weights(self, init_linear='normal'):
        """Initialise the sub-modules of this neck.

        Args:
            init_linear: ``'normal'`` calls ``normal_init(m, std=0.01)`` on
                every ``nn.Linear``; ``'kaiming'`` calls
                ``kaiming_init(m, mode='fan_in', nonlinearity='relu')``.

        Raises:
            AssertionError: If ``init_linear`` is neither of the two values.
        """
        assert init_linear in ['normal', 'kaiming'], \
            "Undefined init_linear: {}".format(init_linear)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                if init_linear == 'normal':
                    normal_init(m, std=0.01)
                else:
                    kaiming_init(m, mode='fan_in', nonlinearity='relu')
            elif isinstance(
                    m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
                # __init__ creates no normalisation layers; this branch only
                # matters if such a layer is attached to the module later.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Apply the neck to a one-element sequence of features.

        Args:
            x: Sequence holding exactly one tensor. When ``with_avg_pool`` is
                True the tensor is fed to ``nn.AdaptiveAvgPool2d`` first
                (presumably an (N, C, H, W) feature map -- confirm against
                the backbone that feeds this neck).

        Returns:
            A one-element list with the output of the linear layer.

        Raises:
            AssertionError: If ``x`` does not contain exactly one element.
        """
        assert len(x) == 1
        # Unwrap unconditionally: previously this only happened inside the
        # pooling branch, so with_avg_pool=False called .view() on the
        # sequence itself and crashed.
        feat = x[0]
        if self.with_avg_pool:
            feat = self.avgpool(feat)
        # Keep the first dimension and flatten everything after it.
        return [self.fc(feat.view(feat.size(0), -1))]
|
||
|
|
||
|
|
||
|
@NECKS.register_module
class NonLinearNeckV0(nn.Module):
    """Neck built as optional average pool + fc-bn-relu-dropout-fc-relu.

    Args:
        in_channels: Input feature size of the first linear layer.
        hid_channels: Output size of the first linear layer; also the feature
            size of the batch norm and the input size of the second linear
            layer.
        out_channels: Output feature size of the second linear layer.
        with_avg_pool: If True, ``nn.AdaptiveAvgPool2d((1, 1))`` is applied
            to the input before it is flattened. Defaults to True.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 with_avg_pool=True):
        super(NonLinearNeckV0, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hid_channels),
            nn.BatchNorm1d(hid_channels, momentum=0.001, affine=False),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hid_channels, out_channels),
            nn.ReLU(inplace=True))

    def init_weights(self, init_linear='normal'):
        """Initialise the sub-modules of this neck.

        Args:
            init_linear: ``'normal'`` calls ``normal_init(m, std=0.01)`` on
                every ``nn.Linear``; ``'kaiming'`` calls
                ``kaiming_init(m, mode='fan_in', nonlinearity='relu')``.

        Raises:
            AssertionError: If ``init_linear`` is neither of the two values.
        """
        assert init_linear in ['normal', 'kaiming'], \
            "Undefined init_linear: {}".format(init_linear)
        norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm,
                      nn.SyncBatchNorm)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                if init_linear == 'normal':
                    normal_init(m, std=0.01)
                else:
                    kaiming_init(m, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, norm_types):
                # The BatchNorm1d created in __init__ uses affine=False, so
                # its weight and bias are None; the guards below are what
                # keep this branch from failing on it.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Apply the neck to a one-element sequence of features.

        Args:
            x: Sequence holding exactly one tensor. When ``with_avg_pool`` is
                True the tensor is fed to ``nn.AdaptiveAvgPool2d`` first
                (presumably an (N, C, H, W) feature map -- confirm against
                the backbone that feeds this neck).

        Returns:
            A one-element list with the output of the MLP.

        Raises:
            AssertionError: If ``x`` does not contain exactly one element.
        """
        assert len(x) == 1
        # Unwrap unconditionally: previously this only happened inside the
        # pooling branch, so with_avg_pool=False called .view() on the
        # sequence itself and crashed.
        feat = x[0]
        if self.with_avg_pool:
            feat = self.avgpool(feat)
        # Keep the first dimension and flatten everything after it.
        return [self.mlp(feat.view(feat.size(0), -1))]
|
||
|
|
||
|
|
||
|
@NECKS.register_module
class NonLinearNeckV1(nn.Module):
    """Neck built as optional average pool + fc-relu-fc.

    Args:
        in_channels: Input feature size of the first linear layer.
        hid_channels: Output size of the first linear layer and input size of
            the second one.
        out_channels: Output feature size of the second linear layer.
        with_avg_pool: If True, ``nn.AdaptiveAvgPool2d((1, 1))`` is applied
            to the input before it is flattened. Defaults to True.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 with_avg_pool=True):
        super(NonLinearNeckV1, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hid_channels),
            nn.ReLU(inplace=True),
            nn.Linear(hid_channels, out_channels))

    def init_weights(self, init_linear='normal'):
        """Initialise the sub-modules of this neck.

        Args:
            init_linear: ``'normal'`` calls ``normal_init(m, std=0.01)`` on
                every ``nn.Linear``; ``'kaiming'`` calls
                ``kaiming_init(m, mode='fan_in', nonlinearity='relu')``.

        Raises:
            AssertionError: If ``init_linear`` is neither of the two values.
        """
        assert init_linear in ['normal', 'kaiming'], \
            "Undefined init_linear: {}".format(init_linear)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                if init_linear == 'normal':
                    normal_init(m, std=0.01)
                else:
                    kaiming_init(m, mode='fan_in', nonlinearity='relu')
            elif isinstance(
                    m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
                # __init__ creates no normalisation layers; this branch only
                # matters if such a layer is attached to the module later.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Apply the neck to a one-element sequence of features.

        Args:
            x: Sequence holding exactly one tensor. When ``with_avg_pool`` is
                True the tensor is fed to ``nn.AdaptiveAvgPool2d`` first
                (presumably an (N, C, H, W) feature map -- confirm against
                the backbone that feeds this neck).

        Returns:
            A one-element list with the output of the MLP.

        Raises:
            AssertionError: If ``x`` does not contain exactly one element.
        """
        assert len(x) == 1
        # Unwrap unconditionally: previously this only happened inside the
        # pooling branch, so with_avg_pool=False called .view() on the
        # sequence itself and crashed.
        feat = x[0]
        if self.with_avg_pool:
            feat = self.avgpool(feat)
        # Keep the first dimension and flatten everything after it.
        return [self.mlp(feat.view(feat.size(0), -1))]
|
||
|
|
||
|
|
||
|
@NECKS.register_module
class AvgPoolNeck(nn.Module):
    """Neck that only applies ``nn.AdaptiveAvgPool2d((1, 1))`` to its input."""

    def __init__(self):
        super(AvgPoolNeck, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def init_weights(self, **kwargs):
        """Do nothing; any keyword arguments are accepted and ignored."""

    def forward(self, x):
        """Pool the single element of ``x``.

        Args:
            x: Sequence holding exactly one tensor.

        Returns:
            A one-element list with the pooled tensor; the result is not
            flattened here.

        Raises:
            AssertionError: If ``x`` does not contain exactly one element.
        """
        assert len(x) == 1
        feature_map = x[0]
        pooled = self.avg_pool(feature_map)
        return [pooled]
|