EasyCV/easycv/models/selfsup/necks.py

# Copyright (c) Alibaba, Inc. and its affiliates.
from functools import partial
import torch
import torch.nn as nn
from packaging import version
from timm.models.vision_transformer import Block
from easycv.models.utils import get_2d_sincos_pos_embed
from ..registry import NECKS
from ..utils import _init_weights, build_norm_layer, trunc_normal_
@NECKS.register_module
class DINONeck(nn.Module):
def __init__(self,
in_dim,
out_dim,
use_bn=False,
norm_last_layer=True,
nlayers=3,
hidden_dim=2048,
bottleneck_dim=256):
super().__init__()
nlayers = max(nlayers, 1)
if nlayers == 1:
self.mlp = nn.Linear(in_dim, bottleneck_dim)
else:
layers = [nn.Linear(in_dim, hidden_dim)]
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
for _ in range(nlayers - 2):
layers.append(nn.Linear(hidden_dim, hidden_dim))
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
layers.append(nn.Linear(hidden_dim, bottleneck_dim))
self.mlp = nn.Sequential(*layers)
self.apply(self._init_weights)
self.last_layer = nn.utils.weight_norm(
nn.Linear(bottleneck_dim, out_dim, bias=False))
self.last_layer.weight_g.data.fill_(1)
if norm_last_layer:
self.last_layer.weight_g.requires_grad = False
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.mlp(x)
x = nn.functional.normalize(x, dim=-1, p=2)
x = self.last_layer(x)
return x
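# Usage sketch for DINONeck (illustrative only; the dimensions below are
# assumptions, not values taken from any EasyCV config):
#
#     neck = DINONeck(in_dim=384, out_dim=65536)
#     feats = torch.randn(8, 384)    # (batch, in_dim) pooled backbone features
#     logits = neck(feats)           # (8, 65536) scores over the prototypes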
@NECKS.register_module
class MoBYMLP(nn.Module):
def __init__(self,
in_channels=256,
hid_channels=4096,
out_channels=256,
num_layers=2,
with_avg_pool=True):
super(MoBYMLP, self).__init__()
# hidden layers
linear_hidden = [nn.Identity()]
for i in range(num_layers - 1):
linear_hidden.append(
nn.Linear(in_channels if i == 0 else hid_channels,
hid_channels))
linear_hidden.append(nn.BatchNorm1d(hid_channels))
linear_hidden.append(nn.ReLU(inplace=True))
self.linear_hidden = nn.Sequential(*linear_hidden)
self.linear_out = nn.Linear(
in_channels if num_layers == 1 else hid_channels,
out_channels) if num_layers >= 1 else nn.Identity()
        self.with_avg_pool = with_avg_pool
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = x[0]
if self.with_avg_pool and len(x.shape) == 4:
bs = x.shape[0]
x = self.avg_pool(x).view([bs, -1])
x = self.linear_hidden(x)
x = self.linear_out(x)
return [x]
def init_weights(self, init_linear='normal'):
_init_weights(self, init_linear)
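# Usage sketch for MoBYMLP (illustrative; shapes are assumptions, not values
# from any MoBY config). The neck takes a list of features and returns a
# list; 4D feature maps are globally average-pooled first:
#
#     neck = MoBYMLP(in_channels=256, hid_channels=4096, out_channels=256)
#     feat_map = torch.randn(8, 256, 7, 7)    # (B, C, H, W) backbone output
#     out = neck([feat_map])[0]               # (8, 256) projected embedding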
@NECKS.register_module
class NonLinearNeckSwav(nn.Module):
    '''The non-linear neck in SwAV: fc-syncbn-relu-fc
    '''
def __init__(self,
in_channels,
hid_channels,
out_channels,
with_avg_pool=True,
export=False):
super(NonLinearNeckSwav, self).__init__()
if version.parse(torch.__version__) < version.parse('1.4.0'):
self.expand_for_syncbn = True
else:
self.expand_for_syncbn = False
self.with_avg_pool = with_avg_pool
if with_avg_pool:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.export = export
if not self.export:
_, self.bn0 = build_norm_layer(dict(type='SyncBN'), hid_channels)
else:
_, self.bn0 = build_norm_layer(dict(type='BN'), hid_channels)
self.fc0 = nn.Linear(in_channels, hid_channels)
self.relu = nn.ReLU(inplace=True)
self.fc1 = nn.Linear(hid_channels, out_channels)
def _forward_syncbn(self, module, x):
assert x.dim() == 2
        # SyncBN in torch<1.4.0 and the BN used for export both expect 4D
        # input, so temporarily expand (N, C) to (N, C, 1, 1)
if self.expand_for_syncbn or self.export:
x = module(x.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)
else:
x = module(x)
return x
def init_weights(self, init_linear='normal'):
_init_weights(self, init_linear)
def forward(self, x):
        assert len(x) == 1 or len(x) == 2, 'Got: {}'.format(
            len(x))  # ViT backbones may return two feature tensors
x = x[0]
if self.with_avg_pool:
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc0(x)
x = self._forward_syncbn(self.bn0, x)
x = self.relu(x)
x = self.fc1(x)
return [x]
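# Usage sketch for NonLinearNeckSwav (illustrative; the channel sizes are
# assumptions). export=True swaps SyncBN for plain BN, so the sketch also
# runs outside a distributed process group:
#
#     neck = NonLinearNeckSwav(
#         in_channels=2048, hid_channels=2048, out_channels=128, export=True)
#     out = neck([torch.randn(4, 2048, 7, 7)])[0]    # (4, 128)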
@NECKS.register_module
class NonLinearNeckV0(nn.Module):
    '''The non-linear neck in ODC: fc-bn-relu-dropout-fc-relu
    '''
def __init__(self,
in_channels,
hid_channels,
out_channels,
sync_bn=False,
with_avg_pool=True):
super(NonLinearNeckV0, self).__init__()
self.with_avg_pool = with_avg_pool
if with_avg_pool:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
if version.parse(torch.__version__) < version.parse('1.4.0'):
self.expand_for_syncbn = True
else:
self.expand_for_syncbn = False
self.fc0 = nn.Linear(in_channels, hid_channels)
if sync_bn:
_, self.bn0 = build_norm_layer(
dict(type='SyncBN', momentum=0.001, affine=False),
hid_channels)
else:
self.bn0 = nn.BatchNorm1d(
hid_channels, momentum=0.001, affine=False)
self.fc1 = nn.Linear(hid_channels, out_channels)
self.relu = nn.ReLU(inplace=True)
self.drop = nn.Dropout()
self.sync_bn = sync_bn
def init_weights(self, init_linear='normal'):
_init_weights(self, init_linear)
def _forward_syncbn(self, module, x):
assert x.dim() == 2
if self.expand_for_syncbn:
x = module(x.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)
else:
x = module(x)
return x
def forward(self, x):
assert len(x) == 1 or len(x) == 2 # to fit vit model
x = x[0]
if self.with_avg_pool:
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc0(x)
if self.sync_bn:
x = self._forward_syncbn(self.bn0, x)
else:
x = self.bn0(x)
x = self.relu(x)
x = self.drop(x)
x = self.fc1(x)
x = self.relu(x)
return [x]
@NECKS.register_module
class NonLinearNeckV1(nn.Module):
    '''The non-linear neck in MoCo v2: fc-relu-fc
    '''
def __init__(self,
in_channels,
hid_channels,
out_channels,
with_avg_pool=True):
super(NonLinearNeckV1, self).__init__()
self.with_avg_pool = with_avg_pool
if with_avg_pool:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.mlp = nn.Sequential(
nn.Linear(in_channels, hid_channels), nn.ReLU(inplace=True),
nn.Linear(hid_channels, out_channels))
def init_weights(self, init_linear='normal'):
_init_weights(self, init_linear)
def forward(self, x):
        # ViT models extract two features and only the first is used, so the
        # length assertion is omitted here.
x = x[0]
if self.with_avg_pool:
x = self.avgpool(x)
return [self.mlp(x.view(x.size(0), -1))]
@NECKS.register_module
class NonLinearNeckV2(nn.Module):
'''The non-linear neck in byol: fc-bn-relu-fc
'''
def __init__(self,
in_channels,
hid_channels,
out_channels,
with_avg_pool=True):
super(NonLinearNeckV2, self).__init__()
self.with_avg_pool = with_avg_pool
if with_avg_pool:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.mlp = nn.Sequential(
nn.Linear(in_channels, hid_channels), nn.BatchNorm1d(hid_channels),
nn.ReLU(inplace=True), nn.Linear(hid_channels, out_channels))
def init_weights(self, init_linear='normal'):
_init_weights(self, init_linear)
def forward(self, x):
assert len(x) == 1 or len(x) == 2, 'Got: {}'.format(
len(x)) # to fit vit model
x = x[0]
if self.with_avg_pool:
x = self.avgpool(x)
return [self.mlp(x.view(x.size(0), -1))]
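# The three NonLinearNeckV* variants above share the same call convention:
# a list of backbone features in, a single-element list out. A minimal
# sketch with NonLinearNeckV2 (channel sizes are illustrative assumptions):
#
#     neck = NonLinearNeckV2(
#         in_channels=2048, hid_channels=4096, out_channels=256)
#     proj = neck([torch.randn(4, 2048, 7, 7)])[0]    # (4, 256)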
@NECKS.register_module
class NonLinearNeckSimCLR(nn.Module):
    '''SimCLR non-linear neck.

    Structure: fc(no_bias)-bn(has_bias)-[relu-fc(no_bias)-bn(no_bias)].
    The substructures in [] can be repeated. For the default SimCLR setting,
    the repeat time is 1.

    However, PyTorch does not support specifying (weight=True, bias=False).
    It only supports "affine", which covers both the weight and the bias.
    Hence, the second BatchNorm has a bias in this implementation, which
    differs from the official SimCLR implementation.

    Since SyncBatchNorm in pytorch<1.4.0 does not support 2D input, the
    input is expanded to 4D with shape (N, C, 1, 1); see the standalone
    sketch after this class. It is not guaranteed that this workaround is
    bug-free. See the pull request here:
    https://github.com/pytorch/pytorch/pull/29626

    Args:
        in_channels: input channel number
        hid_channels: hidden channel number
        out_channels: output channel number
        num_layers (int): number of fc layers; it is 2 in the default
            SimCLR setting.
        with_avg_pool: output with average pooling
    '''
def __init__(self,
in_channels,
hid_channels,
out_channels,
num_layers=2,
with_avg_pool=True):
super(NonLinearNeckSimCLR, self).__init__()
self.with_avg_pool = with_avg_pool
if with_avg_pool:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
if version.parse(torch.__version__) < version.parse('1.4.0'):
self.expand_for_syncbn = True
else:
self.expand_for_syncbn = False
self.relu = nn.ReLU(inplace=True)
self.fc0 = nn.Linear(in_channels, hid_channels, bias=False)
_, self.bn0 = build_norm_layer(dict(type='SyncBN'), hid_channels)
self.fc_names = []
self.bn_names = []
for i in range(1, num_layers):
this_channels = out_channels if i == num_layers - 1 \
else hid_channels
self.add_module('fc{}'.format(i),
nn.Linear(hid_channels, this_channels, bias=False))
self.add_module(
'bn{}'.format(i),
build_norm_layer(dict(type='SyncBN'), this_channels)[1])
self.fc_names.append('fc{}'.format(i))
self.bn_names.append('bn{}'.format(i))
def init_weights(self, init_linear='normal'):
_init_weights(self, init_linear)
def _forward_syncbn(self, module, x):
assert x.dim() == 2
if self.expand_for_syncbn:
x = module(x.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)
else:
x = module(x)
return x
def forward(self, x):
assert len(x) == 1 or len(x) == 2 # to fit vit model
x = x[0]
if self.with_avg_pool:
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc0(x)
x = self._forward_syncbn(self.bn0, x)
for fc_name, bn_name in zip(self.fc_names, self.bn_names):
fc = getattr(self, fc_name)
bn = getattr(self, bn_name)
x = self.relu(x)
x = fc(x)
x = self._forward_syncbn(bn, x)
return [x]
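# A standalone sketch of the 2D-to-4D expansion workaround described in the
# NonLinearNeckSimCLR docstring (plain BatchNorm2d stands in for SyncBN here
# so the snippet runs without a distributed process group):
#
#     bn = nn.BatchNorm2d(128)
#     feats = torch.randn(32, 128)                  # 2D (N, C) input
#     out = bn(feats.unsqueeze(-1).unsqueeze(-1)    # expand to (N, C, 1, 1)
#              ).squeeze(-1).squeeze(-1)            # back to (N, C)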
@NECKS.register_module
class RelativeLocNeck(nn.Module):
'''Relative patch location neck: fc-bn-relu-dropout
'''
def __init__(self,
in_channels,
out_channels,
sync_bn=False,
with_avg_pool=True):
super(RelativeLocNeck, self).__init__()
self.with_avg_pool = with_avg_pool
if with_avg_pool:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
if version.parse(torch.__version__) < version.parse('1.4.0'):
self.expand_for_syncbn = True
else:
self.expand_for_syncbn = False
self.fc = nn.Linear(in_channels * 2, out_channels)
if sync_bn:
_, self.bn = build_norm_layer(
dict(type='SyncBN', momentum=0.003), out_channels)
else:
self.bn = nn.BatchNorm1d(out_channels, momentum=0.003)
self.relu = nn.ReLU(inplace=True)
self.drop = nn.Dropout()
self.sync_bn = sync_bn
def init_weights(self, init_linear='normal'):
_init_weights(self, init_linear, std=0.005, bias=0.1)
def _forward_syncbn(self, module, x):
assert x.dim() == 2
if self.expand_for_syncbn:
x = module(x.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)
else:
x = module(x)
return x
def forward(self, x):
assert len(x) == 1 or len(x) == 2 # to fit vit model
x = x[0]
if self.with_avg_pool:
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
if self.sync_bn:
x = self._forward_syncbn(self.bn, x)
else:
x = self.bn(x)
x = self.relu(x)
x = self.drop(x)
return [x]
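# Usage sketch for RelativeLocNeck (illustrative shapes). The fc layer
# expects in_channels * 2 because the features of the two patches are
# concatenated along the channel dimension before entering the neck:
#
#     neck = RelativeLocNeck(in_channels=2048, out_channels=4096)
#     f1 = torch.randn(4, 2048, 7, 7)       # features of patch 1
#     f2 = torch.randn(4, 2048, 7, 7)       # features of patch 2
#     out = neck([torch.cat([f1, f2], dim=1)])[0]    # (4, 4096)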
@NECKS.register_module
class MAENeck(nn.Module):
"""MAE decoder
Args:
num_patches(int): number of patches from encoder
embed_dim(int): encoder embedding dimension
patch_size(int): encoder patch size
in_chans(int): input image channels
decoder_embed_dim(int): decoder embedding dimension
decoder_depth(int): number of decoder layers
decoder_num_heads(int): Parallel attention heads
mlp_ratio(float): mlp ratio
norm_layer: type of normalization layer
"""
def __init__(self,
num_patches,
embed_dim=768,
patch_size=16,
in_chans=3,
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=16,
mlp_ratio=4.,
norm_layer=partial(nn.LayerNorm, eps=1e-6)):
super().__init__()
self.num_patches = num_patches
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.decoder_pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, decoder_embed_dim),
requires_grad=False)
self.decoder_blocks = nn.ModuleList([
Block(
decoder_embed_dim,
decoder_num_heads,
mlp_ratio,
qkv_bias=True,
qk_scale=None,
norm_layer=norm_layer) for _ in range(decoder_depth)
])
self.decoder_norm = norm_layer(decoder_embed_dim)
self.decoder_pred = nn.Linear(
decoder_embed_dim, patch_size**2 * in_chans, bias=True)
def init_weights(self):
torch.nn.init.normal_(self.mask_token, std=.02)
decoder_pos_embed = get_2d_sincos_pos_embed(
self.decoder_pos_embed.shape[-1],
int(self.num_patches**.5),
cls_token=True)
self.decoder_pos_embed.data.copy_(
torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x, ids_restore):
# embed tokens
x = self.decoder_embed(x)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
x_ = torch.gather(
x_,
dim=1,
index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
x = torch.cat([x[:, :1, :], x_], dim=1)
# add pos embed
x = x + self.decoder_pos_embed
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
# remove cls token
x = x[:, 1:, :]
return x
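# Usage sketch for MAENeck (illustrative; assumes a 224px image with 16px
# patches, i.e. a 14x14 = 196 patch grid, and a 75% mask ratio leaving 49
# visible tokens plus a cls token):
#
#     neck = MAENeck(num_patches=196)
#     neck.init_weights()
#     vis = torch.randn(2, 1 + 49, 768)     # cls token + visible tokens
#     ids_restore = torch.stack([torch.randperm(196) for _ in range(2)])
#     pred = neck(vis, ids_restore)         # (2, 196, 16 * 16 * 3)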
@NECKS.register_module
class FastConvMAENeck(MAENeck):
"""Fast ConvMAE decoder, refer to: https://github.com/Alpha-VL/FastConvMAE
Args:
num_patches (int): number of patches from encoder
embed_dim (int): encoder embedding dimension
patch_size (int): encoder patch size
in_channels (int): input image channels
decoder_embed_dim (int): decoder embedding dimension
decoder_depth (int): number of decoder layers
decoder_num_heads (int): Parallel attention heads
mlp_ratio (float): mlp ratio
norm_layer: type of normalization layer
"""
def __init__(self,
num_patches,
embed_dim=768,
patch_size=16,
in_channels=3,
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=16,
mlp_ratio=4.,
norm_layer=partial(nn.LayerNorm, eps=1e-6)):
super().__init__(
num_patches=num_patches,
embed_dim=embed_dim,
patch_size=patch_size,
in_chans=in_channels,
decoder_embed_dim=decoder_embed_dim,
decoder_depth=decoder_depth,
decoder_num_heads=decoder_num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer)
self.decoder_pos_embed = nn.Parameter(
torch.zeros(1, num_patches, decoder_embed_dim),
requires_grad=False)
def init_weights(self):
decoder_pos_embed = get_2d_sincos_pos_embed(
self.decoder_pos_embed.shape[-1],
int(self.num_patches**.5),
cls_token=False)
self.decoder_pos_embed.data.copy_(
torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
torch.nn.init.normal_(self.mask_token, std=.02)
# initialize nn.Linear and nn.LayerNorm
self.apply(super()._init_weights)
def forward(self, x, ids_restore):
# embed tokens
x = self.decoder_embed(x)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(x.shape[0],
ids_restore.shape[1] - x.shape[1],
1)
x_ = torch.cat([x, mask_tokens], dim=1) # no cls token
        B, L, C = x_.shape
        # The batch holds four complementary mask groups (FastConvMAE masks
        # each image four times with disjoint 25% visible sets). Rolling each
        # quarter of the batch by a multiple of 49 tokens (one quarter of a
        # 14x14 = 196 patch grid) moves every group's visible tokens into
        # their own slice of the sequence before the gather unshuffles them.
        x_split1 = x_[:B // 4, :, :]
        x_split2 = torch.roll(x_[B // 4:B // 4 * 2, :, :], 49, 1)
        x_split3 = torch.roll(x_[B // 4 * 2:B // 4 * 3, :, :], 49 * 2, 1)
        x_split4 = torch.roll(x_[B // 4 * 3:, :, :], 49 * 3, 1)
        x_ = torch.cat([x_split1, x_split2, x_split3, x_split4])
ids_restore = torch.cat(
[ids_restore, ids_restore, ids_restore, ids_restore])
x = torch.gather(
x_,
dim=1,
index=ids_restore.unsqueeze(-1).repeat(1, 1,
x.shape[2])) # unshuffle
# add pos embed
x = x + self.decoder_pos_embed
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
return x
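# Usage sketch for FastConvMAENeck (illustrative; assumes a 196-patch grid
# with four complementary masks of 49 visible tokens each, so the input batch
# is a 4x replication while ids_restore covers the underlying batch only):
#
#     neck = FastConvMAENeck(num_patches=196)
#     neck.init_weights()
#     vis = torch.randn(4 * 2, 49, 768)     # 4 mask groups x batch of 2
#     ids_restore = torch.stack([torch.randperm(196) for _ in range(2)])
#     pred = neck(vis, ids_restore)         # (8, 196, 16 * 16 * 3)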