254 lines
9.4 KiB
Python
254 lines
9.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmcv.cnn import build_conv_layer, build_norm_layer
|
|
from mmcv.runner.base_module import BaseModule
|
|
|
|
from .helpers import to_2tuple
|
|
|
|
|
|
class PatchEmbed(BaseModule):
    """Image to Patch Embedding.

    We use a conv layer to implement PatchEmbed. The conv layer is built from
    ``dict(type='Conv2d', kernel_size=16, stride=16, padding=0, dilation=1)``
    with any keys of ``conv_cfg`` merged over these defaults.

    Args:
        img_size (int | tuple | list): The size of input image. An int or a
            length-1 sequence is expanded to a square (size, size).
            Default: 224
        in_channels (int): The num of input channels. Default: 3
        embed_dims (int): The dimensions of embedding. Default: 768
        norm_cfg (dict, optional): Config dict for the normalization layer
            applied to the token sequence. Default: None (no normalization).
        conv_cfg (dict, optional): The config dict for conv layers; its keys
            override the Conv2d defaults listed above. Default: None
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None
    """

    def __init__(self,
                 img_size=224,
                 in_channels=3,
                 embed_dims=768,
                 norm_cfg=None,
                 conv_cfg=None,
                 init_cfg=None):
        super(PatchEmbed, self).__init__(init_cfg)

        if isinstance(img_size, int):
            img_size = to_2tuple(img_size)
        elif isinstance(img_size, (tuple, list)):
            # Lists get the same normalization/validation as tuples; before,
            # a list bypassed the length check and was stored as-is.
            if len(img_size) == 1:
                img_size = to_2tuple(img_size[0])
            assert len(img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(img_size)}'
            img_size = tuple(img_size)

        self.img_size = img_size
        self.embed_dims = embed_dims

        # Use conv layer to embed; user-supplied keys override the defaults.
        _conv_cfg = dict(
            type='Conv2d', kernel_size=16, stride=16, padding=0, dilation=1)
        _conv_cfg.update(conv_cfg or dict())
        self.projection = build_conv_layer(_conv_cfg, in_channels, embed_dims)

        # Calculate how many patches an input image is split into, using the
        # conv output-size formula per spatial dim:
        #   out = (in + 2 * pad - dil * (k - 1) - 1) // stride + 1
        h_out, w_out = [(self.img_size[i] + 2 * self.projection.padding[i] -
                         self.projection.dilation[i] *
                         (self.projection.kernel_size[i] - 1) - 1) //
                        self.projection.stride[i] + 1 for i in range(2)]

        self.patches_resolution = (h_out, w_out)
        self.num_patches = h_out * w_out

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

    def forward(self, x):
        """Embed an image batch into a sequence of patch tokens.

        Args:
            x (torch.Tensor): Input of shape (B, C, H, W); (H, W) must equal
                ``self.img_size``.

        Returns:
            torch.Tensor: Tokens of shape (B, N, D), where N is the number of
            conv output positions (``self.num_patches``) and D is
            ``self.embed_dims``.
        """
        _, _, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't " \
            f'match model ({self.img_size[0]}*{self.img_size[1]}).'
        # (B, D, h, w) -> (B, D, h*w) -> (B, h*w, D)
        x = self.projection(x).flatten(2).transpose(1, 2)

        if self.norm is not None:
            x = self.norm(x)

        return x
|
|
|
|
|
|
# Modified from pytorch-image-models
class HybridEmbed(BaseModule):
    """CNN Feature Map Embedding.

    Extract feature map from CNN, flatten,
    project to embedding dim.

    Args:
        backbone (nn.Module): CNN backbone. If it returns a list/tuple of
            feature maps, the last one is used.
        img_size (int | tuple | list): The size of input image. Default: 224
        feature_size (int | tuple, optional): Size of feature map extracted by
            CNN backbone. If None, it is measured by running a dummy input
            through the backbone. Default: None
        in_channels (int): The num of input channels. Default: 3
        embed_dims (int): The dimensions of embedding. Default: 768
        conv_cfg (dict, optional): The config dict for conv layers; its keys
            override ``dict(type='Conv2d', kernel_size=1, stride=1,
            padding=0, dilation=1)``. Default: None.
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 backbone,
                 img_size=224,
                 feature_size=None,
                 in_channels=3,
                 embed_dims=768,
                 conv_cfg=None,
                 init_cfg=None):
        super(HybridEmbed, self).__init__(init_cfg)
        assert isinstance(backbone, nn.Module)
        if isinstance(img_size, int):
            img_size = to_2tuple(img_size)
        elif isinstance(img_size, (tuple, list)):
            # Lists get the same normalization/validation as tuples.
            if len(img_size) == 1:
                img_size = to_2tuple(img_size[0])
            assert len(img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(img_size)}'
            img_size = tuple(img_size)

        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            feature_size, feature_dim = self._probe_backbone(in_channels)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        self.num_patches = feature_size[0] * feature_size[1]

        # Use conv layer to embed; user-supplied keys override the defaults.
        _conv_cfg = dict(
            type='Conv2d', kernel_size=1, stride=1, padding=0, dilation=1)
        _conv_cfg.update(conv_cfg or dict())
        self.projection = build_conv_layer(_conv_cfg, feature_dim, embed_dims)

    def _probe_backbone(self, in_channels):
        """Return ``(feature_size, feature_dim)`` of the backbone's last
        output, measured with a dummy forward pass in eval mode."""
        # FIXME this is hacky, but most reliable way of determining the exact
        # dim of the output feature map for all networks; the feature metadata
        # has reliable channel and stride info, but using stride to calc
        # feature dim requires info about padding of each stage that isn't
        # captured.
        backbone = self.backbone
        # Build the dummy input where the backbone lives so probing also works
        # for a backbone already moved to another device / floating dtype.
        ref = next(backbone.parameters(), None)
        factory_kwargs = {}
        if ref is not None:
            factory_kwargs['device'] = ref.device
            if ref.is_floating_point():
                factory_kwargs['dtype'] = ref.dtype
        dummy = torch.zeros(1, in_channels, self.img_size[0],
                            self.img_size[1], **factory_kwargs)

        training = backbone.training
        try:
            with torch.no_grad():
                if training:
                    backbone.eval()
                o = backbone(dummy)
        finally:
            # Restore the mode even if the probe forward raises.
            backbone.train(training)

        if isinstance(o, (list, tuple)):
            # last feature if backbone outputs list/tuple of features
            o = o[-1]
        return o.shape[-2:], o.shape[1]

    def forward(self, x):
        """Run the backbone and project its last feature map into tokens.

        Returns:
            torch.Tensor: Tokens of shape (B, h*w, embed_dims), where (h, w)
            is the spatial size of the backbone's (last) output.
        """
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            # last feature if backbone outputs list/tuple of features
            x = x[-1]
        # (B, D, h, w) -> (B, D, h*w) -> (B, h*w, D)
        x = self.projection(x).flatten(2).transpose(1, 2)
        return x
|
|
|
|
|
|
class PatchMerging(BaseModule):
    """Merge patch feature map.

    This layer uses nn.Unfold to group the feature map by kernel_size, and
    then uses a norm layer and a linear layer to embed the grouped feature
    map.

    Args:
        input_resolution (tuple): The size of input patch resolution (H, W).
        in_channels (int): The num of input channels.
        expansion_ratio (Number): Expansion ratio of output channels. The num
            of output channels is equal to int(expansion_ratio * in_channels).
        kernel_size (int | tuple, optional): the kernel size in the unfold
            layer. Defaults to 2.
        stride (int | tuple, optional): the stride of the sliding blocks in the
            unfold layer. Defaults to be equal with kernel_size.
        padding (int | tuple, optional): zero padding width in the unfold
            layer. Defaults to 0.
        dilation (int | tuple, optional): dilation parameter in the unfold
            layer. Defaults to 1.
        bias (bool, optional): Whether to add bias in linear layer or not.
            Defaults to False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Pass None to disable normalization. Defaults to dict(type='LN').
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 input_resolution,
                 in_channels,
                 expansion_ratio,
                 kernel_size=2,
                 stride=None,
                 padding=0,
                 dilation=1,
                 bias=False,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg)
        H, W = input_resolution
        self.input_resolution = input_resolution
        self.in_channels = in_channels
        self.out_channels = int(expansion_ratio * in_channels)

        if stride is None:
            stride = kernel_size
        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        dilation = to_2tuple(dilation)
        # nn.Unfold positional order: kernel_size, dilation, padding, stride
        self.sampler = nn.Unfold(kernel_size, dilation, padding, stride)

        # Each unfolded column holds one kernel window across all channels.
        sample_dim = kernel_size[0] * kernel_size[1] * in_channels

        # NOTE(review): norm_cfg's default is a shared dict; this relies on
        # build_norm_layer not mutating its cfg argument — TODO confirm.
        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
        else:
            self.norm = None

        self.reduction = nn.Linear(sample_dim, self.out_channels, bias=bias)

        # See https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        H_out = (H + 2 * padding[0] - dilation[0] *
                 (kernel_size[0] - 1) - 1) // stride[0] + 1
        W_out = (W + 2 * padding[1] - dilation[1] *
                 (kernel_size[1] - 1) - 1) // stride[1] + 1
        self.output_resolution = (H_out, W_out)

    def forward(self, x):
        """Merge neighbouring patches.

        Args:
            x (torch.Tensor): Tokens of shape (B, H*W, C), where (H, W) is
                ``self.input_resolution``.

        Returns:
            torch.Tensor: Tokens of shape (B, H_out*W_out, out_channels),
            where (H_out, W_out) is ``self.output_resolution``.
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, 'input feature has wrong size'

        # reshape (not view) so non-contiguous token tensors are accepted.
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)  # B, C, H, W

        # Use nn.Unfold to merge patch. About 25% faster than original method,
        # but need to modify pretrained model for compatibility
        x = self.sampler(x)  # B, C*kh*kw, H_out*W_out
        x = x.transpose(1, 2)  # B, H_out*W_out, C*kh*kw

        # Explicit None check: Module truthiness can be False for modules
        # that define __len__ (e.g. an empty nn.Sequential).
        if self.norm is not None:
            x = self.norm(x)
        x = self.reduction(x)

        return x
|