# mmclassification/mmpretrain/models/utils/position_encoding.py
#
# 248 lines
# 8.6 KiB
# Python
# Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
import math
from functools import partial
from typing import Optional, Sequence, Union
import torch
import torch.nn as nn
# 2022-07-12 16:10:59 +08:00
from mmengine.model import BaseModule
from mmengine.utils import digit_version
from ..utils import to_2tuple
# From PyTorch 1.10 on, calling ``torch.meshgrid`` without ``indexing`` emits
# an extra warning (see https://github.com/pytorch/pytorch/issues/50276).
# Pin 'ij' explicitly where the argument exists; older releases do not accept
# the keyword, so fall back to the plain function there.
if digit_version(torch.__version__) < digit_version('1.10.0'):
    torch_meshgrid = torch.meshgrid
else:
    torch_meshgrid = partial(torch.meshgrid, indexing='ij')
class ConditionalPositionEncoding(BaseModule):
    """The Conditional Position Encoding (CPE) module.

    The CPE is the implementation of `Conditional Positional Encodings
    for Vision Transformers <https://arxiv.org/abs/2102.10882>`_.

    Args:
        in_channels (int): Number of input channels.
        embed_dims (int): The feature dimension. Default: 768.
        stride (int): Stride of conv layer. Default: 1.
        init_cfg (dict, optional): The config dict for initializing the
            module. Default: None.
    """

    def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # 3x3 grouped conv (groups=embed_dims) with padding 1.
        self.proj = nn.Conv2d(
            in_channels,
            embed_dims,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=embed_dims,
            bias=True)
        self.stride = stride

    def forward(self, x, hw_shape):
        """Add the conditional position encoding to a token sequence.

        Args:
            x (torch.Tensor): Tokens unpacked as ``(B, N, C)``; ``N`` must
                equal ``H * W`` for the reshape below to succeed.
            hw_shape (Sequence[int]): Spatial size ``(H, W)`` of the tokens.

        Returns:
            torch.Tensor: Tokens of shape ``(B, H' * W', embed_dims)`` where
            ``(H', W')`` is the conv output size (equal to ``(H, W)`` when
            ``stride == 1``).
        """
        batch_size, _, channels = x.shape
        height, width = hw_shape

        # (B, N, C) -> (B, C, N) -> (B, C, H, W)
        feat_map = x.transpose(1, 2).view(batch_size, channels, height,
                                          width).contiguous()

        out = self.proj(feat_map)
        # The identity shortcut is only added when the spatial size is kept;
        # it also requires the conv output to match the input channels.
        if self.stride == 1:
            out = out + feat_map

        # (B, C', H', W') -> (B, H' * W', C')
        return out.flatten(2).transpose(1, 2)
class PositionEncodingFourier(BaseModule):
    """The Position Encoding Fourier (PEF) module.

    The PEF is adopted from `EdgeNeXt <https://arxiv.org/abs/2206.10589>`_.
    ``forward`` builds a sine/cosine map over an ``(H, W)`` grid whose
    coordinates are rescaled so that the last row/column maps to just under
    ``2 * pi``, then projects it to ``embed_dims`` channels with a 1x1 conv.

    Args:
        in_channels (int): Number of sine/cosine channels generated per
            spatial axis; the projection receives ``in_channels * 2``
            channels. The sin/cos halves are stacked together, so this is
            expected to be even. Default: 32
        embed_dims (int): The feature dimension.
            Default: 768.
        temperature (int): Temperature.
            Default: 10000.
        dtype (torch.dtype): The data type.
            Default: torch.float32.
        init_cfg (dict): The config dict for initializing the module.
            Default: None.
    """

    def __init__(self,
                 in_channels=32,
                 embed_dims=768,
                 temperature=10000,
                 dtype=torch.float32,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.proj = nn.Conv2d(in_channels * 2, embed_dims, kernel_size=1)
        self.scale = 2 * math.pi
        self.in_channels = in_channels
        self.embed_dims = embed_dims
        self.dtype = dtype

        # ``torch.div(..., rounding_mode=...)`` is only available from 1.8.
        if digit_version(torch.__version__) >= digit_version('1.8.0'):
            floor_div = partial(torch.div, rounding_mode='floor')
        else:
            floor_div = torch.floor_divide

        # Channels 2k and 2k+1 share the exponent 2k / in_channels.
        channel_idx = torch.arange(in_channels, dtype=self.dtype)
        exponent = 2 * floor_div(channel_idx, 2) / in_channels
        self.dim_t = temperature**exponent

    def forward(self, bhw_shape):
        """Generate the projected positional map.

        Args:
            bhw_shape (Sequence[int]): ``(B, H, W)`` of the target map.

        Returns:
            torch.Tensor: Tensor of shape ``(B, embed_dims, H, W)``.
        """
        B, H, W = bhw_shape
        device = self.proj.weight.device
        eps = 1e-6

        # Every position is valid, so the cumulative sums are 1..H / 1..W.
        not_mask = torch.ones(B, H, W, dtype=torch.bool, device=device)
        y_embed = not_mask.cumsum(1, dtype=self.dtype)
        x_embed = not_mask.cumsum(2, dtype=self.dtype)

        # Normalise by the last index along each axis, then scale by 2*pi.
        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = self.dim_t.to(device)
        pos_y = self._sin_cos_interleave(y_embed[:, :, :, None] / dim_t)
        pos_x = self._sin_cos_interleave(x_embed[:, :, :, None] / dim_t)

        # (B, H, W, 2 * in_channels) -> (B, 2 * in_channels, H, W)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return self.proj(pos)

    @staticmethod
    def _sin_cos_interleave(angles):
        """Apply sin to even and cos to odd last-axis channels, interleaved."""
        return torch.stack(
            (angles[:, :, :, 0::2].sin(), angles[:, :, :, 1::2].cos()),
            dim=4).flatten(3)
def build_2d_sincos_position_embedding(
        patches_resolution: Union[int, Sequence[int]],
        embed_dims: int,
        temperature: float = 10000.,
        cls_token: bool = False) -> torch.Tensor:
    """Build a fixed 2D sine-cosine position embedding for image patches.

    The channel axis is split into four equal chunks
    ``[sin(w * omega), cos(w * omega), sin(h * omega), cos(h * omega)]``,
    where ``w`` / ``h`` are integer grid coordinates and ``omega`` holds the
    ``embed_dims // 4`` frequencies ``temperature ** (-k / (embed_dims // 4))``.

    Args:
        patches_resolution (Union[int, Sequence[int]]): Size of the patch
            grid as ``(h, w)``. An int is used for both sides.
        embed_dims (int): The dimension of the embedding vector. Must be
            divisible by 4.
        temperature (float): The temperature parameter. Defaults to
            10000.
        cls_token (bool): Whether to prepend an all-zero embedding for a
            class token. Defaults to False.

    Returns:
        torch.Tensor: A float32 tensor of shape ``(1, h * w, embed_dims)``,
        or ``(1, h * w + 1, embed_dims)`` when ``cls_token`` is True.
    """
    # Validate before doing any work so a bad config fails fast.
    assert embed_dims % 4 == 0, \
        'Embed dimension must be divisible by 4.'

    if isinstance(patches_resolution, int):
        patches_resolution = (patches_resolution, patches_resolution)
    h, w = patches_resolution

    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    # 'ij' indexing on (grid_w, grid_h) yields (w, h)-shaped grids, so the
    # flattened sequence enumerates w in the outer loop and h in the inner.
    # NOTE(review): patch tokens are often flattened row-major (h outer);
    # for non-square grids confirm this ordering matches the token order.
    grid_w, grid_h = torch_meshgrid(grid_w, grid_h)

    pos_dim = embed_dims // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature**omega)

    # Outer products: (num_patches, 1) * (1, pos_dim) -> (num_patches, pos_dim)
    out_w = grid_w.flatten()[:, None] * omega[None, :]
    out_h = grid_h.flatten()[:, None] * omega[None, :]

    pos_emb = torch.cat(
        [
            torch.sin(out_w),
            torch.cos(out_w),
            torch.sin(out_h),
            torch.cos(out_h)
        ],
        dim=1,
    )[None, :, :]

    if cls_token:
        cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32)
        pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)

    return pos_emb
class RotaryEmbeddingFast(BaseModule):
    r"""Implements 2D rotary embedding (RoPE) for image tokens.

    Position encoding is implemented with sin and cos functions,

    .. math::
        Pos_{cos} = cos(\frac{t}{\theta^{\frac{2i}{d}}}) \\
        Pos_{sin} = sin(\frac{t}{\theta^{\frac{2i}{d}}})

    ``forward`` rotates the last dimension in adjacent pairs
    ``(x[2i], x[2i + 1])``; both members of a pair use the same angle. The
    first half of the channels encodes the row position and the second half
    the column position.

    Args:
        embed_dims (int): The feature dimension for each head.
        patch_resolution (int | tuple): The resolution of the
            image, in format (H, W).
        theta (float): The hyperparameter for position coding.
            Defaults to 10000.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 patch_resolution,
                 theta=10000.,
                 init_cfg=None):
        super(RotaryEmbeddingFast, self).__init__(init_cfg=init_cfg)

        self.half_dim = embed_dims // 2
        # Stored as a tuple so later comparisons in ``forward`` are stable.
        self.patch_resolution = tuple(to_2tuple(patch_resolution))
        self.theta = theta

        freqs_cos, freqs_sin = self.compute_position_embedding()
        self.register_buffer('freqs_cos', freqs_cos)
        self.register_buffer('freqs_sin', freqs_sin)

    def compute_position_embedding(self):
        """Build the cos/sin tables for ``self.patch_resolution``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(freqs_cos, freqs_sin)``,
            each of shape ``(H * W, 2 * half_dim)`` for an even ``half_dim``.
            Channels ``[:half_dim]`` depend on the row index and channels
            ``[half_dim:]`` on the column index.
        """
        frequency = self.theta**(
            torch.arange(0, self.half_dim, 2).float() / self.half_dim)
        frequency = 1. / frequency

        h, w = self.patch_resolution
        th = torch.arange(h) / h * self.half_dim
        tw = torch.arange(w) / w * self.half_dim

        # Duplicate each angle in place (a0, a0, a1, a1, ...) so the two
        # members of every adjacent pair rotated in ``forward`` share one
        # frequency. Tiling with ``repeat(1, 2)`` (a0, a1, ..., a0, a1, ...)
        # pairs different angles, which is not a rotation.
        position_h = (th[:, None] @ frequency[None, :]).repeat_interleave(
            2, dim=1)
        position_w = (tw[:, None] @ frequency[None, :]).repeat_interleave(
            2, dim=1)

        height = position_h[:, None, :].expand(h, w, self.half_dim)
        width = position_w[None, :, :].expand(h, w, self.half_dim)
        position = torch.cat((height, width), dim=-1)

        freqs_cos = position.cos().view(-1, position.shape[-1])
        freqs_sin = position.sin().view(-1, position.shape[-1])

        return freqs_cos, freqs_sin

    def forward(self, x, patch_resolution):
        """Apply the rotary embedding to ``x``.

        Args:
            x (torch.Tensor): Tensor unpacked as
                ``(batch, num_heads, num_patches, dim)``; its last two axes
                must broadcast against the cached tables of shape
                ``(H * W, 2 * half_dim)``.
            patch_resolution (int | tuple): Current patch grid ``(H, W)``.
                The tables are rebuilt when it differs from the cached one.

        Returns:
            torch.Tensor: The rotated tensor, same shape as ``x``.
        """
        # Normalise to a tuple so a list input compares equal to the cached
        # tuple instead of forcing a rebuild on every call.
        patch_resolution = tuple(to_2tuple(patch_resolution))
        if patch_resolution != self.patch_resolution:
            self.patch_resolution = patch_resolution
            freqs_cos, freqs_sin = self.compute_position_embedding()
            # Replace the cached tables under the same buffer names.
            self.register_buffer('freqs_cos', freqs_cos.to(x.device))
            self.register_buffer('freqs_sin', freqs_sin.to(x.device))

        batch, num_heads, num_patches, dim = x.shape

        # Rotate adjacent pairs by 90 degrees: (x0, x1) -> (-x1, x0).
        pairs = x.reshape(batch, num_heads, num_patches, -1, 2)
        x1, x2 = pairs.unbind(dim=-1)
        rotated = torch.stack((-x2, x1), dim=-1).reshape(
            batch, num_heads, num_patches, dim)

        return x * self.freqs_cos + rotated * self.freqs_sin