# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Sequence

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.cnn.bricks.transformer import AdaptivePadding
from mmengine.model import BaseModule

from .helpers import to_2tuple

def resize_pos_embed(pos_embed,
                     src_shape,
                     dst_shape,
                     mode='bicubic',
                     num_extra_tokens=1):
    """Resize pos_embed weights.

    Args:
        pos_embed (torch.Tensor): Position embedding weights with shape
            [1, L, C].
        src_shape (tuple): The resolution of the downsampled origin training
            image, in format (H, W).
        dst_shape (tuple): The resolution of the downsampled new training
            image, in format (H, W).
        mode (str): Algorithm used for upsampling. Choose one from 'nearest',
            'linear', 'bilinear', 'bicubic' and 'trilinear'.
            Defaults to 'bicubic'.
        num_extra_tokens (int): The number of extra tokens, such as cls_token.
            Defaults to 1.

    Returns:
        torch.Tensor: The resized pos_embed of shape [1, L_new, C].
    """
    if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
        return pos_embed
    assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
    _, L, C = pos_embed.shape
    src_h, src_w = src_shape
    assert L == src_h * src_w + num_extra_tokens, \
        f"The length of `pos_embed` ({L}) doesn't match the expected " \
        f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the ' \
        '`img_size` argument.'
    extra_tokens = pos_embed[:, :num_extra_tokens]

    src_weight = pos_embed[:, num_extra_tokens:]
    src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)

    dst_weight = F.interpolate(
        src_weight, size=dst_shape, align_corners=False, mode=mode)
    dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)

    return torch.cat((extra_tokens, dst_weight), dim=1)
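
# A minimal usage sketch of ``resize_pos_embed`` (illustrative only; the
# shapes below are assumptions): resize a ViT-style position embedding with
# one cls token from a 14x14 patch grid to a 16x16 patch grid.
#
#   pos_embed = torch.zeros(1, 14 * 14 + 1, 768)
#   resized = resize_pos_embed(
#       pos_embed, src_shape=(14, 14), dst_shape=(16, 16),
#       mode='bicubic', num_extra_tokens=1)
#   assert resized.shape == (1, 16 * 16 + 1, 768)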

def resize_relative_position_bias_table(src_shape, dst_shape, table, num_head):
    """Resize relative position bias table.

    Args:
        src_shape (int): The side length of the (square) relative position
            bias table of the pretrained model.
        dst_shape (int): The side length of the (square) relative position
            bias table of the new model.
        table (torch.Tensor): The relative position bias table of the
            pretrained model, with shape (src_shape**2, num_head).
        num_head (int): Number of attention heads.

    Returns:
        torch.Tensor: The resized relative position bias table.
    """
    from scipy import interpolate

    def geometric_progression(a, r, n):
        return a * (1.0 - r**n) / (1.0 - r)

    left, right = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        gp = geometric_progression(1, q, src_shape // 2)
        if gp > dst_shape // 2:
            right = q
        else:
            left = q

    dis = []
    cur = 1
    for i in range(src_shape // 2):
        dis.append(cur)
        cur += q**(i + 1)

    r_ids = [-_ for _ in reversed(dis)]

    x = r_ids + [0] + dis
    y = r_ids + [0] + dis

    t = dst_shape // 2.0
    dx = np.arange(-t, t + 0.1, 1.0)
    dy = np.arange(-t, t + 0.1, 1.0)

    all_rel_pos_bias = []

    for i in range(num_head):
        z = table[:, i].view(src_shape, src_shape).float().numpy()
        f_cubic = interpolate.interp2d(x, y, z, kind='cubic')
        all_rel_pos_bias.append(
            torch.Tensor(f_cubic(dx, dy)).contiguous().view(-1, 1).to(
                table.device))
    new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
    return new_rel_pos_bias
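
# A minimal usage sketch (illustrative only; the window sizes below are
# assumptions, following the Swin/BEiT convention that the table has
# (2 * window_size - 1) ** 2 rows): resize a bias table trained with
# window_size=7 for use with window_size=12.
#
#   src_size, dst_size, num_heads = 2 * 7 - 1, 2 * 12 - 1, 4
#   table = torch.zeros(src_size * src_size, num_heads)
#   new_table = resize_relative_position_bias_table(
#       src_size, dst_size, table, num_heads)
#   assert new_table.shape == (dst_size * dst_size, num_heads)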

class PatchEmbed(BaseModule):
    """Image to Patch Embedding.

    We use a conv layer to implement PatchEmbed.

    Args:
        img_size (int | tuple): The size of input image. Default: 224
        in_channels (int): The num of input channels. Default: 3
        embed_dims (int): The dimensions of embedding. Default: 768
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None
        conv_cfg (dict, optional): The config dict for conv layers.
            Default: None
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None
    """

    def __init__(self,
                 img_size=224,
                 in_channels=3,
                 embed_dims=768,
                 norm_cfg=None,
                 conv_cfg=None,
                 init_cfg=None):
        super(PatchEmbed, self).__init__(init_cfg)
        warnings.warn('The `PatchEmbed` in mmcls will be deprecated. '
                      'Please use `mmcv.cnn.bricks.transformer.PatchEmbed`. '
                      "It's more general and supports dynamic input shape.")

        if isinstance(img_size, int):
            img_size = to_2tuple(img_size)
        elif isinstance(img_size, tuple):
            if len(img_size) == 1:
                img_size = to_2tuple(img_size[0])
            assert len(img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(img_size)}'

        self.img_size = img_size
        self.embed_dims = embed_dims

        # Use a conv layer to embed.
        conv_cfg = conv_cfg or dict()
        _conv_cfg = dict(
            type='Conv2d', kernel_size=16, stride=16, padding=0, dilation=1)
        _conv_cfg.update(conv_cfg)
        self.projection = build_conv_layer(_conv_cfg, in_channels, embed_dims)

        # Calculate how many patches an input image is split into.
        h_out, w_out = [(self.img_size[i] + 2 * self.projection.padding[i] -
                         self.projection.dilation[i] *
                         (self.projection.kernel_size[i] - 1) - 1) //
                        self.projection.stride[i] + 1 for i in range(2)]

        self.patches_resolution = (h_out, w_out)
        self.num_patches = h_out * w_out

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't " \
            f'match model ({self.img_size[0]}*{self.img_size[1]}).'
        # The output size is (B, N, D), where N=H*W/P/P and D is embed_dims.
        x = self.projection(x).flatten(2).transpose(1, 2)

        if self.norm is not None:
            x = self.norm(x)

        return x
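
# A minimal usage sketch (illustrative only, assuming the default 16x16
# patches on a 224x224 input):
#
#   patch_embed = PatchEmbed(img_size=224, in_channels=3, embed_dims=768)
#   tokens = patch_embed(torch.randn(1, 3, 224, 224))
#   assert tokens.shape == (1, 14 * 14, 768)  # (B, num_patches, embed_dims)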

# Modified from pytorch-image-models
class HybridEmbed(BaseModule):
    """CNN Feature Map Embedding.

    Extract feature map from CNN, flatten,
    project to embedding dim.

    Args:
        backbone (nn.Module): CNN backbone
        img_size (int | tuple): The size of input image. Default: 224
        feature_size (int | tuple, optional): Size of feature map extracted by
            CNN backbone. Default: None
        in_channels (int): The num of input channels. Default: 3
        embed_dims (int): The dimensions of embedding. Default: 768
        conv_cfg (dict, optional): The config dict for conv layers.
            Default: None.
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 backbone,
                 img_size=224,
                 feature_size=None,
                 in_channels=3,
                 embed_dims=768,
                 conv_cfg=None,
                 init_cfg=None):
        super(HybridEmbed, self).__init__(init_cfg)
        assert isinstance(backbone, nn.Module)
        if isinstance(img_size, int):
            img_size = to_2tuple(img_size)
        elif isinstance(img_size, tuple):
            if len(img_size) == 1:
                img_size = to_2tuple(img_size[0])
            assert len(img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(img_size)}'

        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # FIXME this is hacky, but the most reliable way of
                # determining the exact dim of the output feature map for
                # all networks. The feature metadata has reliable channel
                # and stride info, but using stride to calc feature dim
                # requires info about the padding of each stage that isn't
                # captured.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(
                    torch.zeros(1, in_channels, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    # last feature if backbone outputs list/tuple of features
                    o = o[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        self.num_patches = feature_size[0] * feature_size[1]

        # Use a conv layer to embed.
        conv_cfg = conv_cfg or dict()
        _conv_cfg = dict(
            type='Conv2d', kernel_size=1, stride=1, padding=0, dilation=1)
        _conv_cfg.update(conv_cfg)
        self.projection = build_conv_layer(_conv_cfg, feature_dim, embed_dims)

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            # last feature if backbone outputs list/tuple of features
            x = x[-1]
        x = self.projection(x).flatten(2).transpose(1, 2)
        return x
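
# A minimal usage sketch (illustrative only; the tiny convolutional stem is
# an assumed stand-in for a real CNN backbone such as a ResNet):
#
#   backbone = nn.Sequential(nn.Conv2d(3, 64, 7, stride=4, padding=3))
#   hybrid_embed = HybridEmbed(backbone, img_size=224, embed_dims=768)
#   tokens = hybrid_embed(torch.randn(1, 3, 224, 224))
#   assert tokens.shape == (1, 56 * 56, 768)  # (B, num_patches, embed_dims)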

class PatchMerging(BaseModule):
    """Merge patch feature map.

    Modified from mmcv, and this module supports specifying whether to use
    post-norm.

    This layer groups feature map by kernel_size, and applies norm and linear
    layers to the grouped feature map (used in Swin Transformer). Our
    implementation uses :class:`torch.nn.Unfold` to merge patches, which is
    about 25% faster than the original implementation. However, we need to
    modify pretrained models for compatibility.

    Args:
        in_channels (int): The num of input channels. The input feature map
            should be fully covered by the kernel and stride you specify.
        out_channels (int): The num of output channels.
        kernel_size (int | tuple, optional): the kernel size in the unfold
            layer. Defaults to 2.
        stride (int | tuple, optional): the stride of the sliding blocks in
            the unfold layer. Defaults to None, which means it will be set to
            ``kernel_size``.
        padding (int | tuple | str): The padding length of the embedding
            conv. When it is a string, it means the mode of adaptive padding,
            support "same" and "corner" now. Defaults to "corner".
        dilation (int | tuple, optional): dilation parameter in the unfold
            layer. Defaults to 1.
        bias (bool, optional): Whether to add bias in the linear layer or not.
            Defaults to False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        use_post_norm (bool): Whether to use post normalization here.
            Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=2,
                 stride=None,
                 padding='corner',
                 dilation=1,
                 bias=False,
                 norm_cfg=dict(type='LN'),
                 use_post_norm=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_post_norm = use_post_norm

        if stride:
            stride = stride
        else:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adaptive_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of unfold
            padding = 0
        else:
            self.adaptive_padding = None

        padding = to_2tuple(padding)
        self.sampler = nn.Unfold(
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride)

        sample_dim = kernel_size[0] * kernel_size[1] * in_channels

        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)

        if norm_cfg is not None:
            # build the pre-norm or post-norm layer with the matching number
            # of channels
            if self.use_post_norm:
                self.norm = build_norm_layer(norm_cfg, out_channels)[1]
            else:
                self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
        else:
            self.norm = None

    def forward(self, x, input_size):
        """
        Args:
            x (Tensor): Has shape (B, H*W, C_in).
            input_size (tuple[int]): The spatial shape of x, arranged as
                (H, W).

        Returns:
            tuple: Contains merged results and its spatial shape.

            - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
            - out_size (tuple[int]): Spatial shape of x, arranged as
              (Merged_H, Merged_W).
        """
        B, L, C = x.shape
        assert isinstance(input_size, Sequence), \
            f'Expect input_size to be a `Sequence`, but got {input_size}'

        H, W = input_size
        assert L == H * W, 'input feature has wrong size'

        x = x.view(B, H, W, C).permute([0, 3, 1, 2])  # B, C, H, W

        if self.adaptive_padding:
            x = self.adaptive_padding(x)
            H, W = x.shape[-2:]

        # Use nn.Unfold to merge patches. About 25% faster than the original
        # method, but the pretrained model needs to be modified for
        # compatibility. If kernel_size=2 and stride=2, x should have shape
        # (B, 4*C, H/2*W/2).
        x = self.sampler(x)

        out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
                 (self.sampler.kernel_size[0] - 1) -
                 1) // self.sampler.stride[0] + 1
        out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
                 (self.sampler.kernel_size[1] - 1) -
                 1) // self.sampler.stride[1] + 1

        output_size = (out_h, out_w)
        x = x.transpose(1, 2)  # B, H/2*W/2, 4*C

        if self.use_post_norm:
            # post-norm: reduce first, then normalize
            x = self.reduction(x)
            x = self.norm(x) if self.norm else x
        else:
            # pre-norm: normalize first, then reduce
            x = self.norm(x) if self.norm else x
            x = self.reduction(x)

        return x, output_size
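
# A minimal usage sketch (illustrative only, with assumed Swin-like sizes):
# merge a 56x56 map of 96-dim tokens into a 28x28 map of 192-dim tokens.
#
#   merge = PatchMerging(in_channels=96, out_channels=192)
#   x = torch.randn(1, 56 * 56, 96)
#   out, out_size = merge(x, input_size=(56, 56))
#   assert out.shape == (1, 28 * 28, 192) and out_size == (28, 28)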