mirror of https://github.com/alibaba/EasyCV.git
471 lines
17 KiB
Python
471 lines
17 KiB
Python
import math
|
|
from functools import partial
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.utils.checkpoint as checkpoint
|
|
from timm.models.layers import DropPath, trunc_normal_
|
|
|
|
from easycv.models.utils import Mlp
|
|
from easycv.utils.checkpoint import load_checkpoint
|
|
from easycv.utils.logger import get_root_logger
|
|
from ..registry import BACKBONES
|
|
|
|
|
|
def window_partition(x, window_size):
    """
    Split a token grid into non-overlapping square windows.

    The bottom and right edges are zero-padded so that both spatial sizes
    become multiples of ``window_size`` before the split.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    batch, height, width, channels = x.shape

    # Amount of padding that rounds each spatial size up to a multiple of window_size.
    pad_bottom = (window_size - height % window_size) % window_size
    pad_right = (window_size - width % window_size) % window_size
    if pad_bottom or pad_right:
        # F.pad pairs run from the last dim backwards: (C, C, W, W, H, H).
        x = F.pad(x, (0, 0, 0, pad_right, 0, pad_bottom))
    padded_h = height + pad_bottom
    padded_w = width + pad_right

    rows = padded_h // window_size
    cols = padded_w // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-index axes together, ahead of the intra-window axes.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = grid.view(-1, window_size, window_size, channels)
    return windows, (padded_h, padded_w)
|
|
|
|
|
|
def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Window unpartition into original sequences and removing padding.

    This reverses the layout produced by ``window_partition``. Windows are
    stitched back into a padded [B, Hp, Wp, C] grid, then cropped to (H, W).

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    # Each image contributes (Hp / window_size) * (Wp / window_size) windows.
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size,
                     window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    # Crop away the bottom/right padding added before partitioning.
    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x
|
|
|
|
|
|
def get_rel_pos(q_size, k_size, rel_pos):
    """
    Look up relative positional embeddings for every (query, key) pair on one axis.

    The embedding table needs one row per relative offset, i.e.
    ``2 * max(q_size, k_size) - 1`` rows. If ``rel_pos`` has a different
    length, it is linearly resampled along its first dimension.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Tensor of shape (q_size, k_size, C) holding the embedding for each
        query/key offset.
    """
    span = int(2 * max(q_size, k_size) - 1)

    if rel_pos.shape[0] == span:
        table = rel_pos
    else:
        # (L, C) -> (1, C, L), resample L to `span`, then back to (span, C).
        resampled = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=span,
            mode='linear',
        )
        table = resampled.reshape(-1, span).permute(1, 0)

    # When q and k differ in length, stretch the shorter axis' coordinates.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_scale
    k_coords = torch.arange(k_size)[None, :] * k_scale
    # Shift so the most negative offset maps to row 0 of the table.
    offsets = (q_coords - k_coords) + (k_size - 1) * k_scale

    return table[offsets.long()]
|
|
|
|
|
|
def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950

    The height and width biases are computed separately from ``q`` and summed
    into the attention logits by broadcasting over the other key axis.

    Args:
        attn (Tensor): attention map; must be viewable as (B, q_h, q_w, k_h, k_w).
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
            When called from ``Attention``, B is batch times number of heads.
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings,
            shaped (B, q_h * q_w, k_h * k_w).
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    table_h = get_rel_pos(q_h, k_h, rel_pos_h)
    table_w = get_rel_pos(q_w, k_w, rel_pos_w)

    batch, _, dim = q.shape
    q_grid = q.reshape(batch, q_h, q_w, dim)
    # Per-axis bias: dot product of each query with the embedding of its offset.
    bias_h = torch.einsum('bhwc,hkc->bhwk', q_grid, table_h)
    bias_w = torch.einsum('bhwc,wkc->bhwk', q_grid, table_w)

    attn = attn.view(batch, q_h, q_w, k_h, k_w)
    attn = attn + bias_h[:, :, :, :, None] + bias_w[:, :, :, None, :]
    return attn.view(batch, q_h * q_w, k_h * k_w)
|
|
|
|
|
|
def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
    dimension for the original embeddings.

    The stored embeddings are assumed to form a square grid. If that grid
    does not match ``hw``, it is resized with bicubic interpolation.

    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): (H, W) size of the input image token grid.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C)
    """
    h, w = hw
    # The leading cls-token embedding has no spatial position; drop it.
    grid = abs_pos[:, 1:] if has_cls_token else abs_pos

    num_tokens = grid.shape[1]
    side = int(math.sqrt(num_tokens))
    assert side * side == num_tokens

    if side == h and side == w:
        return grid.reshape(1, h, w, -1)

    # (1, side*side, C) -> (1, C, side, side) for 2-D interpolation.
    resized = F.interpolate(
        grid.reshape(1, side, side, -1).permute(0, 3, 1, 2),
        size=(h, w),
        mode='bicubic',
        align_corners=False,
    )
    return resized.permute(0, 2, 3, 1)
|
|
|
|
|
|
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.

    A single Conv2d projects the image into patch tokens. The result is
    returned channels-last.
    """

    def __init__(self,
                 kernel_size=(16, 16),
                 stride=(16, 16),
                 padding=(0, 0),
                 in_chans=3,
                 embed_dim=768):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)

    def forward(self, x):
        """
        Args:
            x (Tensor): images with shape (B, in_chans, H, W).

        Returns:
            Tensor: patch tokens with shape (B, H', W', embed_dim). This is a
            permuted view and is not made contiguous.
        """
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        use_rel_pos=False,
        rel_pos_zero_init=True,
        input_size=None,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (Tuple[int, int] or None): (H, W) input resolution used to size the
                relative positional parameters. Required when ``use_rel_pos`` is True.

        Raises:
            ValueError: if ``use_rel_pos`` is True and ``input_size`` is None.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            if input_size is None:
                raise ValueError(
                    'input_size must be provided when use_rel_pos is True.')
            # initialize relative positional embeddings: one row per relative
            # offset in [-(S - 1), S - 1] along each axis, i.e. 2 * S - 1 rows.
            self.rel_pos_h = nn.Parameter(
                torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(
                torch.zeros(2 * input_size[1] - 1, head_dim))

            if not rel_pos_zero_init:
                trunc_normal_(self.rel_pos_h, std=0.02)
                trunc_normal_(self.rel_pos_w, std=0.02)

    def forward(self, x):
        """
        Args:
            x (Tensor): tokens with shape (B, H, W, dim).

        Returns:
            Tensor: attended tokens with shape (B, H, W, dim).
        """
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads,
                                  -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h,
                                          self.rel_pos_w, (H, W), (H, W))

        attn = attn.softmax(dim=-1)
        # (B * nHead, HW, C) -> (B, H, W, nHead * C), heads concatenated per token.
        x = (attn @ v).view(B, self.num_heads, H, W,
                            -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        x = self.proj(x)

        return x
|
|
|
|
|
|
class _ChannelsFirstLayerNorm(nn.Module):
    """LayerNorm over the channel dim (dim 1) of an (N, C, H, W) tensor."""

    def __init__(self, num_channels, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]


class _ResBottleneckBlock(nn.Module):
    """Residual 1x1 -> 3x3 -> 1x1 bottleneck conv block on (N, C, H, W) input.

    Follows the ViTDet reference design (detectron2 ``ResBottleneckBlock``,
    norm="LN"), with the same submodule names. The last norm is
    zero-initialised, so the block is an identity mapping at initialisation.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bottleneck_channels,
                 act_layer=nn.GELU):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, bottleneck_channels, 1, bias=False)
        self.norm1 = _ChannelsFirstLayerNorm(bottleneck_channels)
        self.act1 = act_layer()
        self.conv2 = nn.Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            3,
            padding=1,
            bias=False)
        self.norm2 = _ChannelsFirstLayerNorm(bottleneck_channels)
        self.act2 = act_layer()
        self.conv3 = nn.Conv2d(
            bottleneck_channels, out_channels, 1, bias=False)
        self.norm3 = _ChannelsFirstLayerNorm(out_channels)

        for conv in (self.conv1, self.conv2, self.conv3):
            nn.init.kaiming_normal_(
                conv.weight, mode='fan_out', nonlinearity='relu')
        # Zero-init the last norm so the residual branch starts as a no-op.
        nn.init.zeros_(self.norm3.weight)
        nn.init.zeros_(self.norm3.bias)

    def forward(self, x):
        out = self.act1(self.norm1(self.conv1(x)))
        out = self.act2(self.norm2(self.conv2(out)))
        out = self.norm3(self.conv3(out))
        return x + out


class Block(nn.Module):
    """Transformer blocks with support of window attention and residual propagation blocks"""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        use_rel_pos=False,
        rel_pos_zero_init=True,
        window_size=0,
        use_residual_block=False,
        input_size=None,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then not
                use window attention.
            use_residual_block (bool): If True, apply a residual bottleneck conv block
                (bottleneck width dim // 2) after the MLP block.
            input_size (Tuple[int, int] or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            # Windowed blocks attend within a window, so rel-pos tables are
            # sized by the window rather than the full token grid.
            input_size=input_size if window_size == 0 else
            (window_size, window_size),
        )

        self.drop_path = DropPath(
            drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer)

        self.window_size = window_size

        self.use_residual_block = use_residual_block
        if use_residual_block:
            # Previously forward() referenced self.residual without it ever
            # being created, so use_residual_block=True raised AttributeError.
            self.residual = _ResBottleneckBlock(
                in_channels=dim,
                out_channels=dim,
                bottleneck_channels=dim // 2,
                act_layer=act_layer,
            )

    def forward(self, x):
        """
        Args:
            x (Tensor): tokens with shape (B, H, W, dim).

        Returns:
            Tensor: tokens with shape (B, H, W, dim).
        """
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        if self.use_residual_block:
            # Conv block expects channels-first: B H W C -> B C H W and back.
            x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        return x
|
|
|
|
|
|
@BACKBONES.register_module()
class ViTDet(nn.Module):
    """
    This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
    "Exploring Plain Vision Transformer Backbones for Object Detection",
    https://arxiv.org/abs/2203.16527
    """

    def __init__(
        self,
        img_size=1024,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_path_rate=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        act_layer=nn.GELU,
        use_abs_pos=True,
        use_rel_pos=False,
        rel_pos_zero_init=True,
        window_size=0,
        window_block_indexes=(),
        residual_block_indexes=(),
        use_act_checkpoint=False,
        pretrain_img_size=224,
        pretrain_use_cls_token=True,
        pretrained=None,
    ):
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path_rate (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            window_block_indexes (Sequence[int]): Indexes for blocks using window attention.
            residual_block_indexes (Sequence[int]): Indexes for blocks using conv propagation.
            use_act_checkpoint (bool): If True, use activation checkpointing.
            pretrain_img_size (int): input image size for pretraining models.
            pretrain_use_cls_token (bool): If True, pretraining models use class token.
            pretrained (str or None): Checkpoint passed to ``load_checkpoint``
                (non-strict) by ``init_weights``. If it is not a ``str``, nothing is loaded.
        """
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token
        self.use_act_checkpoint = use_act_checkpoint

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            # It is resized to the actual token grid at forward time.
            num_patches = (pretrain_img_size // patch_size) * (
                pretrain_img_size // patch_size)
            num_positions = (num_patches +
                             1) if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_positions, embed_dim))
        else:
            self.pos_embed = None

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i in window_block_indexes else 0,
                use_residual_block=i in residual_block_indexes,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=0.02)

        self.apply(self._init_weights)
        self.pretrained = pretrained

    def _init_weights(self, m):
        """Default init: truncated-normal Linear weights, zero biases, unit LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # weight/bias are None for non-affine (or bias-free) LayerNorm.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def init_weights(self):
        """Load ``self.pretrained`` non-strictly when it is a string; otherwise do nothing."""
        if isinstance(self.pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, self.pretrained, strict=False, logger=logger)

    def forward(self, x):
        """
        Args:
            x (Tensor): images with shape (B, in_chans, H, W).

        Returns:
            list[Tensor]: a single feature map with shape
            (B, embed_dim, H // patch_size, W // patch_size).
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + get_abs_pos(self.pos_embed, self.pretrain_use_cls_token,
                                (x.shape[1], x.shape[2]))

        for blk in self.blocks:
            if self.use_act_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        # B H W C -> B C H W for downstream detection necks.
        outputs = [x.permute(0, 3, 1, 2)]
        return outputs
|