import torch
import torch.nn as nn
from packaging import version
from mmcv.cnn import kaiming_init, normal_init

from .registry import NECKS
from .utils import build_norm_layer


def _init_weights(module, init_linear='normal', std=0.01, bias=0.):
    """Initialize the Linear and norm layers of ``module``.

    Linear layers are initialized from a normal distribution or with Kaiming
    fan-in; affine norm layers are reset to weight 1 and bias 0.
    """
    assert init_linear in ['normal', 'kaiming'], \
        "Undefined init_linear: {}".format(init_linear)
    for m in module.modules():
        if isinstance(m, nn.Linear):
            if init_linear == 'normal':
                normal_init(m, std=std, bias=bias)
            else:
                kaiming_init(m, mode='fan_in', nonlinearity='relu')
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d,
                            nn.GroupNorm, nn.SyncBatchNorm)):
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
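

# Illustrative sketch (values are assumptions, not from a config): every neck
# below exposes ``init_weights``, which delegates to _init_weights, e.g.
#
#   neck = LinearNeck(in_channels=2048, out_channels=128)
#   neck.init_weights(init_linear='kaiming')  # Kaiming fan-in for nn.Linear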


@NECKS.register_module
class LinearNeck(nn.Module):
    """Linear neck: fc only.
    """

    def __init__(self, in_channels, out_channels, with_avg_pool=True):
        super(LinearNeck, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_channels, out_channels)

    def init_weights(self, init_linear='normal'):
        _init_weights(self, init_linear)

    def forward(self, x):
        assert len(x) == 1
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        return [self.fc(x.view(x.size(0), -1))]
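

# Illustrative usage (assumes a backbone whose last stage outputs a single
# (N, 2048, 7, 7) feature map, e.g. ResNet-50; sizes are not from a config):
#
#   neck = LinearNeck(in_channels=2048, out_channels=128)
#   neck.init_weights()
#   out = neck([torch.randn(4, 2048, 7, 7)])  # -> [tensor of shape (4, 128)]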


@NECKS.register_module
class RelativeLocNeck(nn.Module):
    """Relative patch location neck: fc-bn-relu-dropout.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 sync_bn=False,
                 with_avg_pool=True):
        super(RelativeLocNeck, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # SyncBatchNorm in pytorch<1.4.0 does not accept 2D input, so 2D
        # activations are expanded to (N, C, 1, 1) in _forward_syncbn.
        if version.parse(torch.__version__) < version.parse("1.4.0"):
            self.expand_for_syncbn = True
        else:
            self.expand_for_syncbn = False

        self.fc = nn.Linear(in_channels * 2, out_channels)
        if sync_bn:
            _, self.bn = build_norm_layer(
                dict(type='SyncBN', momentum=0.003),
                out_channels)
        else:
            self.bn = nn.BatchNorm1d(
                out_channels, momentum=0.003)
        self.relu = nn.ReLU(inplace=True)
        self.drop = nn.Dropout()
        self.sync_bn = sync_bn

    def init_weights(self, init_linear='normal'):
        _init_weights(self, init_linear, std=0.005, bias=0.1)

    def _forward_syncbn(self, module, x):
        assert x.dim() == 2
        if self.expand_for_syncbn:
            x = module(x.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)
        else:
            x = module(x)
        return x

    def forward(self, x):
        assert len(x) == 1
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        if self.sync_bn:
            x = self._forward_syncbn(self.bn, x)
        else:
            x = self.bn(x)
        x = self.relu(x)
        x = self.drop(x)
        return [x]
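

# Illustrative usage: the fc above takes ``in_channels * 2`` features because
# the relative patch location model concatenates the two patch features
# channel-wise before the neck (shapes below are assumptions, not a config):
#
#   neck = RelativeLocNeck(in_channels=2048, out_channels=4096)
#   neck.init_weights()
#   pair = torch.cat([torch.randn(4, 2048, 1, 1)] * 2, dim=1)  # (4, 4096, 1, 1)
#   out = neck([pair])  # -> [tensor of shape (4, 4096)]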


@NECKS.register_module
class NonLinearNeckV0(nn.Module):
    """The non-linear neck in ODC: fc-bn-relu-dropout-fc-relu.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 sync_bn=False,
                 with_avg_pool=True):
        super(NonLinearNeckV0, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        if version.parse(torch.__version__) < version.parse("1.4.0"):
            self.expand_for_syncbn = True
        else:
            self.expand_for_syncbn = False

        self.fc0 = nn.Linear(in_channels, hid_channels)
        if sync_bn:
            _, self.bn0 = build_norm_layer(
                dict(type='SyncBN', momentum=0.001, affine=False),
                hid_channels)
        else:
            self.bn0 = nn.BatchNorm1d(
                hid_channels, momentum=0.001, affine=False)

        self.fc1 = nn.Linear(hid_channels, out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.drop = nn.Dropout()
        self.sync_bn = sync_bn

    def init_weights(self, init_linear='normal'):
        _init_weights(self, init_linear)

    def _forward_syncbn(self, module, x):
        assert x.dim() == 2
        if self.expand_for_syncbn:
            x = module(x.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)
        else:
            x = module(x)
        return x

    def forward(self, x):
        assert len(x) == 1
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc0(x)
        if self.sync_bn:
            x = self._forward_syncbn(self.bn0, x)
        else:
            x = self.bn0(x)
        x = self.relu(x)
        x = self.drop(x)
        x = self.fc1(x)
        x = self.relu(x)
        return [x]
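

# Illustrative usage (sizes are assumptions, not an ODC config). Note the
# trailing relu: unlike the MoCo v2 and BYOL necks below, the output here is
# non-negative.
#
#   neck = NonLinearNeckV0(in_channels=2048, hid_channels=512,
#                          out_channels=256)
#   neck.init_weights()
#   out = neck([torch.randn(16, 2048, 7, 7)])  # -> [tensor of shape (16, 256)]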


@NECKS.register_module
class NonLinearNeckV1(nn.Module):
    """The non-linear neck in MoCo v2: fc-relu-fc.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 with_avg_pool=True):
        super(NonLinearNeckV1, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hid_channels), nn.ReLU(inplace=True),
            nn.Linear(hid_channels, out_channels))

    def init_weights(self, init_linear='normal'):
        _init_weights(self, init_linear)

    def forward(self, x):
        assert len(x) == 1
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        return [self.mlp(x.view(x.size(0), -1))]
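

# Illustrative usage; for a ResNet-50 backbone, MoCo v2 uses a 2048-d hidden
# layer and a 128-d output:
#
#   neck = NonLinearNeckV1(in_channels=2048, hid_channels=2048,
#                          out_channels=128)
#   neck.init_weights()
#   out = neck([torch.randn(32, 2048, 7, 7)])  # -> [tensor of shape (32, 128)]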


@NECKS.register_module
class NonLinearNeckV2(nn.Module):
    """The non-linear neck in BYOL: fc-bn-relu-fc.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 with_avg_pool=True):
        super(NonLinearNeckV2, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hid_channels),
            nn.BatchNorm1d(hid_channels),
            nn.ReLU(inplace=True),
            nn.Linear(hid_channels, out_channels))

    def init_weights(self, init_linear='normal'):
        _init_weights(self, init_linear)

    def forward(self, x):
        assert len(x) == 1, "Got: {}".format(len(x))
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        return [self.mlp(x.view(x.size(0), -1))]
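

# Illustrative usage: BYOL uses this block both as the projector (after the
# backbone) and as the predictor (on top of the projector). The 4096/256
# sizes follow the BYOL paper; with_avg_pool=False lets the predictor take
# the projector's 2D output directly:
#
#   projector = NonLinearNeckV2(in_channels=2048, hid_channels=4096,
#                               out_channels=256)
#   predictor = NonLinearNeckV2(in_channels=256, hid_channels=4096,
#                               out_channels=256, with_avg_pool=False)
#   z = projector([torch.randn(8, 2048, 7, 7)])[0]  # (8, 256)
#   p = predictor([z])[0]                           # (8, 256)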


@NECKS.register_module
class NonLinearNeckSimCLR(nn.Module):
    """SimCLR non-linear neck.

    Structure: fc(no_bias)-bn(has_bias)-[relu-fc(no_bias)-bn(no_bias)].
    The substructures in [] can be repeated. For the SimCLR default setting,
    the repeat time is 1.

    However, PyTorch does not support specifying (weight=True, bias=False);
    it only supports the "affine" flag, which controls the weight and bias
    together. Hence, the second BatchNorm has a bias in this implementation,
    which differs from the official SimCLR implementation.

    Since SyncBatchNorm in pytorch<1.4.0 does not support 2D input, the input
    is expanded to 4D with shape (N, C, 1, 1). It is not certain that this
    workaround is bug-free. See the pull request here:
    https://github.com/pytorch/pytorch/pull/29626.

    Args:
        num_layers (int): Number of fc layers. It is 2 in the SimCLR default
            setting.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 num_layers=2,
                 sync_bn=True,
                 with_bias=False,
                 with_last_bn=True,
                 with_avg_pool=True):
        super(NonLinearNeckSimCLR, self).__init__()
        self.sync_bn = sync_bn
        self.with_last_bn = with_last_bn
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        if version.parse(torch.__version__) < version.parse("1.4.0"):
            self.expand_for_syncbn = True
        else:
            self.expand_for_syncbn = False

        self.relu = nn.ReLU(inplace=True)
        self.fc0 = nn.Linear(in_channels, hid_channels, bias=with_bias)
        if sync_bn:
            _, self.bn0 = build_norm_layer(
                dict(type='SyncBN'), hid_channels)
        else:
            self.bn0 = nn.BatchNorm1d(hid_channels)

        self.fc_names = []
        self.bn_names = []
        for i in range(1, num_layers):
            this_channels = out_channels if i == num_layers - 1 \
                else hid_channels
            self.add_module(
                "fc{}".format(i),
                nn.Linear(hid_channels, this_channels, bias=with_bias))
            self.fc_names.append("fc{}".format(i))
            if i != num_layers - 1 or self.with_last_bn:
                if sync_bn:
                    self.add_module(
                        "bn{}".format(i),
                        build_norm_layer(dict(type='SyncBN'), this_channels)[1])
                else:
                    self.add_module(
                        "bn{}".format(i),
                        nn.BatchNorm1d(this_channels))
                self.bn_names.append("bn{}".format(i))
            else:
                self.bn_names.append(None)

    def init_weights(self, init_linear='normal'):
        _init_weights(self, init_linear)

    def _forward_syncbn(self, module, x):
        assert x.dim() == 2
        if self.expand_for_syncbn:
            x = module(x.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)
        else:
            x = module(x)
        return x

    def forward(self, x):
        assert len(x) == 1
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc0(x)
        if self.sync_bn:
            x = self._forward_syncbn(self.bn0, x)
        else:
            x = self.bn0(x)
        for fc_name, bn_name in zip(self.fc_names, self.bn_names):
            fc = getattr(self, fc_name)
            x = self.relu(x)
            x = fc(x)
            if bn_name is not None:
                bn = getattr(self, bn_name)
                if self.sync_bn:
                    x = self._forward_syncbn(bn, x)
                else:
                    x = bn(x)
        return [x]
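

# Illustrative usage (sync_bn=False here because SyncBN requires an
# initialized distributed process group; sizes are assumptions following the
# common ResNet-50 setup with a 128-d output):
#
#   neck = NonLinearNeckSimCLR(in_channels=2048, hid_channels=2048,
#                              out_channels=128, num_layers=2, sync_bn=False)
#   neck.init_weights()
#   out = neck([torch.randn(8, 2048, 7, 7)])  # -> [tensor of shape (8, 128)]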


@NECKS.register_module
class AvgPoolNeck(nn.Module):
    """Average pooling neck.
    """

    def __init__(self):
        super(AvgPoolNeck, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def init_weights(self, **kwargs):
        pass

    def forward(self, x):
        assert len(x) == 1
        return [self.avg_pool(x[0])]
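

# Illustrative usage: this neck only global-average-pools a single feature
# map; note the output keeps the (N, C, 1, 1) shape:
#
#   neck = AvgPoolNeck()
#   out = neck([torch.randn(2, 2048, 7, 7)])  # -> [(2, 2048, 1, 1) tensor]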