import torch.nn.functional as F

from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule
from torch.nn.modules.utils import _pair as to_2tuple


# Modified from Pytorch-Image-Models
class PatchEmbed(BaseModule):
    """Image to Patch Embedding V2.

    We use a conv layer to implement PatchEmbed.

    Args:
        in_channels (int): The number of input channels. Default: 3.
        embed_dims (int): The dimensions of embedding. Default: 768.
        conv_type (dict, optional): The config dict for conv layer type
            selection. Default: None (falls back to ``dict(type='Conv2d')``).
        kernel_size (int): The kernel size of the embedding conv. Default: 16.
        stride (int): The stride of the embedding conv. Default: 16.
            If set to None, it falls back to ``kernel_size``.
        padding (int): The padding length of the embedding conv. Default: 0.
        dilation (int): The dilation rate of the embedding conv. Default: 1.
        norm_cfg (dict, optional): Config dict for the normalization layer
            applied to the flattened patch tokens. Default: None.
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=768,
                 conv_type=None,
                 kernel_size=16,
                 stride=16,
                 padding=0,
                 dilation=1,
                 norm_cfg=None,
                 init_cfg=None):
        super(PatchEmbed, self).__init__()

        self.embed_dims = embed_dims
        self.init_cfg = init_cfg

        if stride is None:
            stride = kernel_size

        # The default setting of patch size is equal to kernel size.
        patch_size = kernel_size
        if isinstance(patch_size, int):
            patch_size = to_2tuple(patch_size)
        elif isinstance(patch_size, tuple):
            if len(patch_size) == 1:
                patch_size = to_2tuple(patch_size[0])
            assert len(patch_size) == 2, \
                f'The size of patch should have length 1 or 2, ' \
                f'but got {len(patch_size)}'

        self.patch_size = patch_size

        # Use a conv layer to embed. ``conv_type`` is already a config dict
        # (e.g. dict(type='Conv2d')), so pass it to build_conv_layer as-is
        # instead of wrapping it in another dict.
        conv_type = conv_type or dict(type='Conv2d')
        self.projection = build_conv_layer(
            conv_type,
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

    def forward(self, x):
        H, W = x.shape[2], x.shape[3]

        # Pad the bottom/right edges so that H and W become multiples of the
        # patch size. For a 4-D input, F.pad takes the padding as
        # (W_left, W_right, H_top, H_bottom).
        if H % self.patch_size[0] != 0:
            x = F.pad(
                x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
        if W % self.patch_size[1] != 0:
            x = F.pad(
                x, (0, self.patch_size[1] - W % self.patch_size[1], 0, 0))

        x = self.projection(x)
        # Record the spatial size of the patch grid for downstream use.
        self.DH, self.DW = x.shape[2], x.shape[3]
        # (B, C, DH, DW) -> (B, DH * DW, C)
        x = x.flatten(2).transpose(1, 2)

        if self.norm is not None:
            x = self.norm(x)

        return x
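

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module; the
# shapes below assume the defaults above and require torch + mmcv installed).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch

    patch_embed = PatchEmbed(
        in_channels=3,
        embed_dims=768,
        kernel_size=16,
        stride=16,
        norm_cfg=dict(type='LN'))

    # 225 is not divisible by 16, so forward() pads the input to 240x240
    # before projecting, yielding a 15x15 patch grid of 225 tokens.
    x = torch.randn(2, 3, 225, 225)
    out = patch_embed(x)
    print(out.shape)                       # torch.Size([2, 225, 768])
    print(patch_embed.DH, patch_embed.DW)  # 15 15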