# mmpretrain/mmcls/models/utils/embed.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule
from .helpers import to_2tuple
class PatchEmbed(BaseModule):
"""Image to Patch Embedding.
We use a conv layer to implement PatchEmbed.
Args:
img_size (int | tuple): The size of input image. Default: 224
in_channels (int): The num of input channels. Default: 3
embed_dims (int): The dimensions of embedding. Default: 768
norm_cfg (dict, optional): Config dict for normalization layer.
Default: None
conv_cfg (dict, optional): The config dict for conv layers.
Default: None
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
Default: None
"""
def __init__(self,
img_size=224,
in_channels=3,
embed_dims=768,
norm_cfg=None,
conv_cfg=None,
init_cfg=None):
super(PatchEmbed, self).__init__(init_cfg)
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
elif isinstance(img_size, tuple):
if len(img_size) == 1:
img_size = to_2tuple(img_size[0])
assert len(img_size) == 2, \
                f'The size of the image should have length 1 or 2, ' \
                f'but got {len(img_size)}'
self.img_size = img_size
self.embed_dims = embed_dims
# Use conv layer to embed
conv_cfg = conv_cfg or dict()
_conv_cfg = dict(
type='Conv2d', kernel_size=16, stride=16, padding=0, dilation=1)
_conv_cfg.update(conv_cfg)
self.projection = build_conv_layer(_conv_cfg, in_channels, embed_dims)
        # Calculate how many patches an input image is split into.
h_out, w_out = [(self.img_size[i] + 2 * self.projection.padding[i] -
self.projection.dilation[i] *
(self.projection.kernel_size[i] - 1) - 1) //
self.projection.stride[i] + 1 for i in range(2)]
self.patches_resolution = (h_out, w_out)
self.num_patches = h_out * w_out
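        # E.g. with the defaults (img_size=224, kernel_size=16, stride=16,
        # padding=0, dilation=1): (224 + 0 - 15 - 1) // 16 + 1 = 14 per side,
        # i.e. 14 * 14 = 196 patches, as in the standard ViT-B/16 setting.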
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't " \
f'match model ({self.img_size[0]}*{self.img_size[1]}).'
        # The output size is (B, N, D), where N = H * W / (P * P) and
        # D = embed_dims.
x = self.projection(x).flatten(2).transpose(1, 2)
if self.norm is not None:
x = self.norm(x)
return x
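
# A minimal usage sketch for ``PatchEmbed`` (the variable names below are
# illustrative only and assume the default 224x224 input with a 16x16 Conv2d
# projection):
#
#     >>> patch_embed = PatchEmbed(img_size=224, in_channels=3, embed_dims=768)
#     >>> imgs = torch.randn(2, 3, 224, 224)
#     >>> patch_embed(imgs).shape  # (224 // 16) ** 2 = 196 patch tokens
#     torch.Size([2, 196, 768])
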
# Modified from pytorch-image-models
class HybridEmbed(BaseModule):
"""CNN Feature Map Embedding.
Extract feature map from CNN, flatten,
project to embedding dim.
Args:
backbone (nn.Module): CNN backbone
img_size (int | tuple): The size of input image. Default: 224
feature_size (int | tuple, optional): Size of feature map extracted by
CNN backbone. Default: None
in_channels (int): The num of input channels. Default: 3
embed_dims (int): The dimensions of embedding. Default: 768
conv_cfg (dict, optional): The config dict for conv layers.
Default: None.
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
backbone,
img_size=224,
feature_size=None,
in_channels=3,
embed_dims=768,
conv_cfg=None,
init_cfg=None):
super(HybridEmbed, self).__init__(init_cfg)
assert isinstance(backbone, nn.Module)
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
elif isinstance(img_size, tuple):
if len(img_size) == 1:
img_size = to_2tuple(img_size[0])
assert len(img_size) == 2, \
                f'The size of the image should have length 1 or 2, ' \
                f'but got {len(img_size)}'
self.img_size = img_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
                # FIXME this is hacky, but it is the most reliable way of
                # determining the exact shape of the output feature map for
                # all networks. The feature metadata has reliable channel and
                # stride info, but using the stride to calculate the feature
                # map size requires padding info for each stage, which is not
                # captured.
training = backbone.training
if training:
backbone.eval()
o = self.backbone(
torch.zeros(1, in_channels, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)):
# last feature if backbone outputs list/tuple of features
o = o[-1]
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
if hasattr(self.backbone, 'feature_info'):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
self.num_patches = feature_size[0] * feature_size[1]
# Use conv layer to embed
conv_cfg = conv_cfg or dict()
_conv_cfg = dict(
type='Conv2d', kernel_size=1, stride=1, padding=0, dilation=1)
_conv_cfg.update(conv_cfg)
self.projection = build_conv_layer(_conv_cfg, feature_dim, embed_dims)
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
# last feature if backbone outputs list/tuple of features
x = x[-1]
x = self.projection(x).flatten(2).transpose(1, 2)
return x
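
# A minimal usage sketch for ``HybridEmbed``. The tiny ``nn.Sequential`` below
# is a stand-in backbone invented purely for illustration; any CNN that
# returns a single feature map (or a list/tuple of feature maps) works:
#
#     >>> cnn = nn.Sequential(
#     ...     nn.Conv2d(3, 64, 7, stride=4, padding=3), nn.ReLU(),
#     ...     nn.Conv2d(64, 256, 3, stride=4, padding=1))
#     >>> hybrid = HybridEmbed(cnn, img_size=224, embed_dims=768)
#     >>> hybrid(torch.randn(2, 3, 224, 224)).shape  # 14x14 feature map
#     torch.Size([2, 196, 768])
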
class PatchMerging(BaseModule):
"""Merge patch feature map.
    This layer uses nn.Unfold to group the feature map by ``kernel_size``, and
    uses a norm layer and a linear layer to embed the grouped feature map.
Args:
input_resolution (tuple): The size of input patch resolution.
in_channels (int): The num of input channels.
expansion_ratio (Number): Expansion ratio of output channels. The num
of output channels is equal to int(expansion_ratio * in_channels).
kernel_size (int | tuple, optional): the kernel size in the unfold
layer. Defaults to 2.
        stride (int | tuple, optional): the stride of the sliding blocks in
            the unfold layer. Defaults to ``kernel_size``.
padding (int | tuple, optional): zero padding width in the unfold
layer. Defaults to 0.
dilation (int | tuple, optional): dilation parameter in the unfold
layer. Defaults to 1.
bias (bool, optional): Whether to add bias in linear layer or not.
Defaults to False.
norm_cfg (dict, optional): Config dict for normalization layer.
Defaults to dict(type='LN').
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
input_resolution,
in_channels,
expansion_ratio,
kernel_size=2,
stride=None,
padding=0,
dilation=1,
bias=False,
norm_cfg=dict(type='LN'),
init_cfg=None):
super().__init__(init_cfg)
H, W = input_resolution
self.input_resolution = input_resolution
self.in_channels = in_channels
self.out_channels = int(expansion_ratio * in_channels)
if stride is None:
stride = kernel_size
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
padding = to_2tuple(padding)
dilation = to_2tuple(dilation)
self.sampler = nn.Unfold(kernel_size, dilation, padding, stride)
sample_dim = kernel_size[0] * kernel_size[1] * in_channels
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
else:
self.norm = None
self.reduction = nn.Linear(sample_dim, self.out_channels, bias=bias)
# See https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
H_out = (H + 2 * padding[0] - dilation[0] *
(kernel_size[0] - 1) - 1) // stride[0] + 1
W_out = (W + 2 * padding[1] - dilation[1] *
(kernel_size[1] - 1) - 1) // stride[1] + 1
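        # E.g. with the Swin-style defaults (kernel_size=stride=2, padding=0,
        # dilation=1) and a 56x56 input: (56 + 0 - 1 - 1) // 2 + 1 = 28, i.e.
        # the resolution is halved and the token count drops by a factor of 4.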
self.output_resolution = (H_out, W_out)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, 'input feature has wrong size'
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
        # Use nn.Unfold to merge patches. This is about 25% faster than the
        # original method, but pretrained models need to be modified for
        # compatibility.
x = self.sampler(x) # B, 4*C, H/2*W/2
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
x = self.norm(x) if self.norm else x
x = self.reduction(x)
return x
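
# A minimal usage sketch for ``PatchMerging``, mirroring the first Swin-style
# downsampling stage (the concrete numbers are illustrative, not prescribed by
# this module):
#
#     >>> merge = PatchMerging(input_resolution=(56, 56), in_channels=96,
#     ...                      expansion_ratio=2)
#     >>> x = torch.randn(2, 56 * 56, 96)
#     >>> merge(x).shape  # 28 * 28 = 784 tokens, int(2 * 96) = 192 channels
#     torch.Size([2, 784, 192])
#     >>> merge.output_resolution
#     (28, 28)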