mirror of https://github.com/alibaba/EasyCV.git
374 lines
14 KiB
Python
374 lines
14 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import logging
|
|
import math
|
|
import os
|
|
from functools import partial
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from einops import rearrange
|
|
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
|
|
from timm.models.layers import variance_scaling_
|
|
|
|
from easycv.file import io, is_oss_path
|
|
from easycv.utils.constant import CACHE_DIR
|
|
from ..registry import BACKBONES
|
|
from .conv_mae_vit import FastConvMAEViT
|
|
|
|
|
|
def trunc_normal_(tensor, mean=0.0, std=1.0):
    """Initialise ``tensor`` with timm's truncated-normal initialiser.

    The truncation bounds are tied to ``std``: values are cut off at
    ``-std`` and ``+std`` (the bounds are not shifted by ``mean``).

    Args:
        tensor: Tensor handed to the underlying timm initialiser.
        mean (float): Mean forwarded to the initialiser.
        std (float): Standard deviation; also defines the cut-off bounds.
    """
    lower_bound, upper_bound = -std, std
    __call_trunc_normal_(
        tensor, mean=mean, std=std, a=lower_bound, b=upper_bound)
|
|
|
|
|
|
def lecun_normal_(tensor):
    """Initialise ``tensor`` via timm's ``variance_scaling_``.

    Uses fan-in scaling with a truncated normal distribution.

    Args:
        tensor: Tensor handed to the underlying timm initialiser.
    """
    scaling_options = dict(mode='fan_in', distribution='truncated_normal')
    variance_scaling_(tensor, **scaling_options)
|
|
|
|
|
|
class RelativePositionBias(nn.Module):
    """Learnable attention bias indexed by the relative offset of two tokens.

    A table holds one bias per head for each of the
    ``(2*Wh-1) * (2*Ww-1)`` possible (row, column) offsets inside a
    ``Wh x Ww`` window. A fixed index buffer maps every ordered token pair
    to its table row.

    Args:
        window_size (tuple | list): ``(Wh, Ww)`` window height and width.
        num_heads (int): Number of attention heads (columns of the table).
    """

    def __init__(self, window_size, num_heads):
        super().__init__()
        win_h, win_w = window_size[0], window_size[1]
        # Number of distinct offsets along each axis.
        span_h, span_w = 2 * win_h - 1, 2 * win_w - 1

        self.window_size = window_size
        self.num_relative_distance = span_h * span_w
        # Shape: (2*Wh-1) * (2*Ww-1), nH
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))

        # (row, col) coordinate of every token in the window: 2, Wh*Ww
        axes = [torch.arange(win_h), torch.arange(win_w)]
        token_coords = torch.flatten(torch.stack(torch.meshgrid(axes)), 1)
        # Offset between every ordered pair of tokens: 2, Wh*Ww, Wh*Ww
        pair_offsets = token_coords[:, :, None] - token_coords[:, None, :]
        # Shift both axes so that offsets start from 0.
        row_offset = pair_offsets[0] + (win_h - 1)
        col_offset = pair_offsets[1] + (win_w - 1)
        # Row-major flattening of (row_offset, col_offset): Wh*Ww, Wh*Ww
        relative_position_index = row_offset * span_w + col_offset

        self.register_buffer('relative_position_index',
                             relative_position_index)

    def forward(self):
        """Return the bias for all token pairs, shaped nH, Wh*Ww, Wh*Ww."""
        tokens_per_window = self.window_size[0] * self.window_size[1]
        flat_index = self.relative_position_index.view(-1)
        gathered = self.relative_position_bias_table[flat_index]
        # Wh*Ww, Wh*Ww, nH
        per_pair = gathered.view(tokens_per_window, tokens_per_window, -1)
        # Heads first: nH, Wh*Ww, Wh*Ww
        return per_pair.permute(2, 0, 1).contiguous()
|
|
|
|
|
|
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
    """Rescale a grid of position embeddings to a new grid size.

    Used when loading from a state_dict. Adapted from
    https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224

    The first ``num_tokens`` entries along dim 1 are copied unchanged; the
    remaining entries are treated as a square grid, resized with bicubic
    interpolation and appended again.

    Args:
        posemb (Tensor): Source embedding, indexed as (1, tokens, channels).
        posemb_new (Tensor): Target embedding; only its shape is read.
        num_tokens (int | None): Number of leading non-grid tokens. Any
            falsy value means there are none.
        gs_new (tuple | list): Target grid ``(height, width)``. When empty,
            a square grid is derived from the target token count.

    Returns:
        Tensor: Resized embedding of shape
        (1, leading tokens + gs_new[0] * gs_new[1], channels).
    """
    log = logging.getLogger(__name__)
    log.info('Resized position embedding: %s to %s', posemb.shape,
             posemb_new.shape)

    # A falsy ``num_tokens`` (0 or None) means "no leading tokens".
    prefix_len = num_tokens if num_tokens else 0
    prefix_tokens = posemb[:, :prefix_len]
    grid_tokens = posemb[0, prefix_len:]
    target_grid_len = posemb_new.shape[1] - prefix_len

    # The source grid is assumed to be square.
    old_side = int(math.sqrt(len(grid_tokens)))
    if len(gs_new) == 0:  # backwards compatibility
        gs_new = [int(math.sqrt(target_grid_len))] * 2
    assert len(gs_new) >= 2
    log.info('Position embedding grid-size from %s to %s',
             [old_side, old_side], gs_new)

    # tokens, C -> 1, C, H, W so that F.interpolate sees an image.
    grid_image = grid_tokens.reshape(1, old_side, old_side, -1)
    grid_image = grid_image.permute(0, 3, 1, 2)
    grid_image = F.interpolate(
        grid_image, size=gs_new, mode='bicubic', align_corners=False)
    # 1, C, H, W -> 1, H*W, C
    resized_tokens = grid_image.permute(0, 2, 3, 1)
    resized_tokens = resized_tokens.reshape(1, gs_new[0] * gs_new[1], -1)

    return torch.cat([prefix_tokens, resized_tokens], dim=1)
|
|
|
|
|
|
@BACKBONES.register_module()
class ConvViTDet(FastConvMAEViT):
    """Reference: https://github.com/Alpha-VL/FastConvMAE

    Backbone that returns four feature maps. The first two come from the
    ``blocks1`` and ``blocks2`` stages, the third from ``blocks3`` (run with
    alternating global and windowed attention) and the fourth is a 2x
    max-pooled copy of the third.

    ``forward`` indexes ``blocks3[0]`` .. ``blocks3[10]`` explicitly, so the
    third entry of ``depth`` has to provide at least 11 blocks.

    Args:
        window_size (int): The height and width of the window.
        in_channels (int): The num of input channels. Default: 3
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        img_size (list | tuple): Input image size for three stages.
        embed_dim (list | tuple): The dimensions of embedding for three stages.
        patch_size (list | tuple): The patch size for three stages.
        depth (list | tuple): depth for three stages.
        mlp_ratio (list | tuple): Mlp expansion ratio.
        num_heads (int): Parallel attention heads
        norm_layer (nn.Module): normalization layer
        init_pos_embed_by_sincos: initialize pos_embed by sincos strategy
        pretrained (str): pretrained path
    """

    # Layout of the third stage: each entry is (index of a global-attention
    # block, slice of the windowed-attention blocks that follow it, or None
    # when the global block is the last one).
    _STAGE3_SCHEDULE = (
        (0, slice(1, 3)),
        (3, slice(4, 6)),
        (6, slice(7, 10)),
        (10, None),
    )

    # NOTE: the list defaults below are never mutated in this class; they are
    # forwarded unchanged to the parent constructor.
    def __init__(
        self,
        window_size,
        in_channels=3,
        drop_rate=0.0,
        drop_path_rate=0.0,
        img_size=[1024, 256, 128],
        embed_dim=[256, 384, 768],
        patch_size=[4, 2, 2],
        depth=[2, 2, 11],
        mlp_ratio=[4, 4, 4],
        num_heads=12,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_pos_embed_by_sincos=False,
        pretrained=None,
    ):
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            init_pos_embed_by_sincos=init_pos_embed_by_sincos,
            with_fuse=False,
            global_pool=False,
        )

        self.pretrained = pretrained
        self.num_heads = num_heads
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim[-1]
        # None is passed on to resize_pos_embed, which treats a falsy value
        # as "no leading non-grid tokens".
        self.num_tokens = None
        self.norm = None

        self.window_size = window_size
        # Only the last entry is used by forward().
        self.ms_adaptor = nn.ModuleList(
            [nn.Identity(),
             nn.Identity(),
             nn.Identity(),
             nn.MaxPool2d(2)])
        self.ms_adaptor.apply(self.init_adaptor)

        self.windowed_rel_pos_bias = RelativePositionBias(
            window_size=(self.window_size, self.window_size),
            num_heads=self.num_heads)
        self.global_rel_pos_bias = RelativePositionBias(
            window_size=self.patch_embed3.grid_size, num_heads=self.num_heads)

        # Notes kept from the original implementation. The last line does not
        # match forward(), which applies global attention at block indices
        # 0, 3, 6 and 10.
        # self._out_features = ["s0", "s1", "s2", "s3"]
        # self._out_feature_channels = {"s0": 256, "s1": 384, "s2": 768, "s3": 768}
        # self._out_feature_strides = {"s0": 4, "s1": 8, "s2": 16, "s3": 32}
        # self.global_attn_index = [0, 4, 8, 10]

    def init_weights(self):
        """Run the parent initialisation, then load ``self.pretrained`` if set."""
        super().init_weights()
        if self.pretrained is not None:
            logging.info(f'Load pretrained model from {self.pretrained}...')
            self.load_pretrained(self.pretrained)

    def _init_weights(self, m):
        """Initialise a single sub-module according to its type."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            lecun_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def init_adaptor(self, m):
        """Initialise a sub-module of ``ms_adaptor`` according to its type."""
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            lecun_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.GroupNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _partition_windows(self, x):
        """Split a ``b (h w) c`` token sequence into per-window sequences."""
        grid_size = self.patch_embed3.grid_size
        x = rearrange(
            x,
            'b (h w) c -> b h w c',
            h=grid_size[0],
            w=grid_size[1],
        )
        return rearrange(
            x,
            'b (h h1) (w w1) c -> (b h w) (h1 w1) c',
            h1=self.window_size,
            w1=self.window_size,
        )

    def _merge_windows(self, x):
        """Undo ``_partition_windows`` and return a ``b (h w) c`` sequence."""
        grid_size = self.patch_embed3.grid_size
        return rearrange(
            x,
            '(b h w) (h1 w1) c -> b (h h1 w w1) c',
            h=grid_size[0] // self.window_size,
            w=grid_size[1] // self.window_size,
            h1=self.window_size,
            w1=self.window_size,
        )

    def forward(self, images):
        """Compute the four output feature maps.

        Args:
            images (Tensor): Input batch fed to ``patch_embed1``.

        Returns:
            list[Tensor]: ``[s0, s1, s2, s3]`` where ``s0`` / ``s1`` are the
            outputs of the first two stages, ``s2`` is the third-stage output
            rearranged to ``b c h w`` and ``s3`` is ``s2`` after
            ``nn.MaxPool2d(2)``.
        """
        x = self.patch_embed1(images)
        x = self.pos_drop(x)
        for blk in self.blocks1:
            x = blk(x)
        s0 = x

        x = self.patch_embed2(x)
        for blk in self.blocks2:
            x = blk(x)
        s1 = x

        x = self.patch_embed3(x)
        x = x.flatten(2).permute(0, 2, 1)
        x = self.patch_embed4(x)
        x = x + self.pos_embed

        # Alternate one global-attention block with a run of windowed blocks.
        for global_index, windowed_blocks in self._STAGE3_SCHEDULE:
            x = self.blocks3[global_index](
                x, rel_pos_bias=self.global_rel_pos_bias())
            if windowed_blocks is None:
                continue
            x = self._partition_windows(x)
            for blk in self.blocks3[windowed_blocks]:
                x = blk(x, rel_pos_bias=self.windowed_rel_pos_bias())
            x = self._merge_windows(x)

        grid_size = self.patch_embed3.grid_size
        s2 = rearrange(
            x,
            'b (h w) c -> b c h w',
            h=grid_size[0],
            w=grid_size[1],
        )
        s3 = self.ms_adaptor[-1](s2)

        return [s0, s1, s2, s3]

    @staticmethod
    def _load_rel_pos_bias(bias_module, source_table):
        """Resize ``source_table`` to fit ``bias_module`` and load it."""
        window = bias_module.window_size
        # The table is laid out row-major over (2*Wh-1, 2*Ww-1) offsets, so
        # pass that grid explicitly instead of assuming a square one.
        target_grid = (2 * window[0] - 1, 2 * window[1] - 1)
        resized = resize_pos_embed(
            source_table,
            bias_module.relative_position_bias_table[None],
            0,
            target_grid,
        )
        bias_module.load_state_dict(
            {'relative_position_bias_table': resized[0]}, strict=False)

    def load_pretrained(self, pretrained):
        """Load weights from a checkpoint file, resizing position tables.

        OSS paths are first copied to ``CACHE_DIR``. ``pos_embed`` is resized
        to this model's grid before loading. If the checkpoint carries a
        ``rel_pos_bias.relative_position_bias_table`` that the model does not
        own, it is resized into both the windowed and the global bias tables.

        Args:
            pretrained (str): Local path or OSS path of the checkpoint.
        """
        from mmcv.runner.checkpoint import _load_checkpoint

        if is_oss_path(pretrained):
            _, fname = os.path.split(pretrained)
            cache_file = os.path.join(CACHE_DIR, fname)
            if not os.path.exists(cache_file):
                print(f'download checkpoint from {pretrained} to {cache_file}')
                io.copy(pretrained, cache_file)
            # Every rank has to reach the barrier, including those that found
            # the file already cached.
            if (torch.distributed.is_available()
                    and torch.distributed.is_initialized()):
                torch.distributed.barrier()
            pretrained = cache_file

        checkpoint = _load_checkpoint(pretrained, map_location='cpu')
        if 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']

        # Build the model's state dict once; it is rebuilt on every call.
        own_state = self.state_dict()
        missing_keys = [k for k in own_state if k not in checkpoint]
        unexpected_keys = [k for k in checkpoint if k not in own_state]

        if 'pos_embed' in checkpoint:
            checkpoint['pos_embed'] = resize_pos_embed(
                checkpoint['pos_embed'],
                self.pos_embed,
                self.num_tokens,
                self.patch_embed3.grid_size,
            )
        self.load_state_dict(checkpoint, strict=False)
        print(f'Loading ViT pretrained weights from {pretrained}.')
        print(f'missing keys: {missing_keys}')
        print(f'unexpected keys: {unexpected_keys}')

        rel_pos_key = 'rel_pos_bias.relative_position_bias_table'
        if rel_pos_key in unexpected_keys:
            # The last three rows are dropped before resizing; presumably
            # they are non-grid (class-token) entries -- TODO confirm against
            # the checkpoint format.
            source_table = checkpoint[rel_pos_key][None, :-3]
            self._load_rel_pos_bias(self.windowed_rel_pos_bias, source_table)
            self._load_rel_pos_bias(self.global_rel_pos_bias, source_table)
            print('Load positional bias table from pretrained.')
|