mmselfsup/openselfsup/models/necks.py

133 lines
4.5 KiB
Python
Raw Normal View History

2020-06-16 00:05:18 +08:00
import torch.nn as nn
from mmcv.cnn import kaiming_init, normal_init
from .registry import NECKS
@NECKS.register_module
class LinearNeck(nn.Module):
    """Neck consisting of an optional global average pool and one FC layer.

    Args:
        in_channels (int): Input feature size of the linear layer.
        out_channels (int): Output feature size of the linear layer.
        with_avg_pool (bool): If True, apply ``nn.AdaptiveAvgPool2d((1, 1))``
            to the input before it is flattened. Default: True.
    """

    def __init__(self, in_channels, out_channels, with_avg_pool=True):
        super(LinearNeck, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_channels, out_channels)

    def init_weights(self, init_linear='normal'):
        """Initialize the parameters of every sub-module.

        Args:
            init_linear (str): Scheme used for ``nn.Linear`` layers, either
                ``'normal'`` (``normal_init`` with std=0.01) or ``'kaiming'``
                (``kaiming_init`` with fan_in / relu). Default: 'normal'.

        Raises:
            AssertionError: If ``init_linear`` is not one of the two
                supported values.
        """
        assert init_linear in ['normal', 'kaiming'], \
            "Undefined init_linear: {}".format(init_linear)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                if init_linear == 'normal':
                    normal_init(m, std=0.01)
                else:
                    kaiming_init(m, mode='fan_in', nonlinearity='relu')
            elif isinstance(m,
                            (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
                # Norm layers built with affine=False have no weight / bias.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Project the single input feature through the linear layer.

        Args:
            x (list | tuple): Sequence holding exactly one tensor.

        Returns:
            list: One-element list with the output of ``self.fc``.
        """
        assert len(x) == 1
        # Unwrap the tensor regardless of pooling. Previously this only
        # happened inside the pooling branch, so with_avg_pool=False tried
        # to call ``.view`` on the list itself and crashed.
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        # Flatten everything but the leading (batch) dimension.
        return [self.fc(x.view(x.size(0), -1))]
@NECKS.register_module
class NonLinearNeckV0(nn.Module):
    """Neck built as: optional avg pool -> FC -> BN -> ReLU -> Dropout -> FC -> ReLU.

    Args:
        in_channels (int): Input feature size of the first linear layer.
        hid_channels (int): Hidden feature size between the two linear layers.
        out_channels (int): Output feature size of the second linear layer.
        with_avg_pool (bool): If True, apply ``nn.AdaptiveAvgPool2d((1, 1))``
            to the input before it is flattened. Default: True.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 with_avg_pool=True):
        super(NonLinearNeckV0, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hid_channels),
            nn.BatchNorm1d(hid_channels, momentum=0.001, affine=False),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hid_channels, out_channels),
            nn.ReLU(inplace=True))

    def init_weights(self, init_linear='normal'):
        """Initialize the parameters of every sub-module.

        Args:
            init_linear (str): Scheme used for ``nn.Linear`` layers, either
                ``'normal'`` (``normal_init`` with std=0.01) or ``'kaiming'``
                (``kaiming_init`` with fan_in / relu). Default: 'normal'.

        Raises:
            AssertionError: If ``init_linear`` is not one of the two
                supported values.
        """
        assert init_linear in ['normal', 'kaiming'], \
            "Undefined init_linear: {}".format(init_linear)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                if init_linear == 'normal':
                    normal_init(m, std=0.01)
                else:
                    kaiming_init(m, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm,
                                nn.SyncBatchNorm)):
                # The BatchNorm1d in self.mlp uses affine=False, so its
                # weight / bias are None and must be skipped.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the single input feature through the MLP.

        Args:
            x (list | tuple): Sequence holding exactly one tensor.

        Returns:
            list: One-element list with the output of ``self.mlp``.
        """
        assert len(x) == 1
        # Unwrap the tensor regardless of pooling. Previously this only
        # happened inside the pooling branch, so with_avg_pool=False tried
        # to call ``.view`` on the list itself and crashed.
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        # Flatten everything but the leading (batch) dimension.
        return [self.mlp(x.view(x.size(0), -1))]
@NECKS.register_module
class NonLinearNeckV1(nn.Module):
    """Neck built as: optional avg pool -> FC -> ReLU -> FC.

    Args:
        in_channels (int): Input feature size of the first linear layer.
        hid_channels (int): Hidden feature size between the two linear layers.
        out_channels (int): Output feature size of the second linear layer.
        with_avg_pool (bool): If True, apply ``nn.AdaptiveAvgPool2d((1, 1))``
            to the input before it is flattened. Default: True.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 with_avg_pool=True):
        super(NonLinearNeckV1, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hid_channels),
            nn.ReLU(inplace=True),
            nn.Linear(hid_channels, out_channels))

    def init_weights(self, init_linear='normal'):
        """Initialize the parameters of every sub-module.

        Args:
            init_linear (str): Scheme used for ``nn.Linear`` layers, either
                ``'normal'`` (``normal_init`` with std=0.01) or ``'kaiming'``
                (``kaiming_init`` with fan_in / relu). Default: 'normal'.

        Raises:
            AssertionError: If ``init_linear`` is not one of the two
                supported values.
        """
        assert init_linear in ['normal', 'kaiming'], \
            "Undefined init_linear: {}".format(init_linear)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                if init_linear == 'normal':
                    normal_init(m, std=0.01)
                else:
                    kaiming_init(m, mode='fan_in', nonlinearity='relu')
            elif isinstance(m,
                            (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
                # Norm layers built with affine=False have no weight / bias.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the single input feature through the MLP.

        Args:
            x (list | tuple): Sequence holding exactly one tensor.

        Returns:
            list: One-element list with the output of ``self.mlp``.
        """
        assert len(x) == 1
        # Unwrap the tensor regardless of pooling. Previously this only
        # happened inside the pooling branch, so with_avg_pool=False tried
        # to call ``.view`` on the list itself and crashed.
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        # Flatten everything but the leading (batch) dimension.
        return [self.mlp(x.view(x.size(0), -1))]
@NECKS.register_module
class AvgPoolNeck(nn.Module):
    """Parameter-free neck that applies ``nn.AdaptiveAvgPool2d((1, 1))``."""

    def __init__(self):
        super(AvgPoolNeck, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def init_weights(self, **kwargs):
        """Do nothing; accepts and ignores any keyword arguments."""
        return None

    def forward(self, x):
        """Pool the single input feature.

        Args:
            x (list | tuple): Sequence holding exactly one tensor.

        Returns:
            list: One-element list with the pooled tensor (not flattened).
        """
        assert len(x) == 1
        feature = x[0]
        pooled = self.avg_pool(feature)
        return [pooled]