mirror of https://github.com/alibaba/EasyCV.git
459 lines
17 KiB
Python
459 lines
17 KiB
Python
import copy
|
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
from torch.cuda.amp import autocast
|
|
from torch.nn import functional as F
|
|
from torch.nn.init import constant_, normal_, uniform_, xavier_uniform_
|
|
|
|
from .transformer_decoder import PositionEmbeddingSine, _get_activation_fn
|
|
|
|
try:
|
|
from thirdparty.deformable_transformer.modules import MSDeformAttn
|
|
except:
|
|
pass
|
|
|
|
|
|
def _get_clones(module, N):
|
|
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
|
|
|
|
|
def c2_xavier_fill(module: nn.Module):
    """
    Initialize ``module.weight`` with Caffe2's "XavierFill" and zero the bias.

    Caffe2's XavierFill is, in PyTorch terms, ``kaiming_uniform_`` with
    ``a=1``. The bias (when present, i.e. not ``None``) is set to 0.

    Args:
        module (torch.nn.Module): module with ``weight`` and ``bias``
            attributes to initialize in place.
    """
    # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Module, Tensor]`.
    nn.init.kaiming_uniform_(module.weight, a=1)
    bias = module.bias
    if bias is None:
        return
    # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module,
    # torch.Tensor]`.
    nn.init.constant_(bias, 0)
|
|
|
|
|
|
class Conv2d(torch.nn.Conv2d):
    """
    A :class:`torch.nn.Conv2d` wrapper that supports empty inputs and an
    optional norm layer and activation applied after the convolution
    (norm first, then activation).
    """

    def __init__(self, *args, **kwargs):
        """
        Accepts every argument of ``torch.nn.Conv2d`` plus two extra keywords:

        Args:
            norm (nn.Module, optional): normalization layer applied to the
                convolution output.
            activation (callable(Tensor) -> Tensor, optional): activation
                applied after ``norm``.
        """
        # Strip the extra keywords before the base class sees them.
        norm_layer = kwargs.pop('norm', None)
        act_fn = kwargs.pop('activation', None)
        super().__init__(*args, **kwargs)

        self.norm = norm_layer
        self.activation = act_fn

    def forward(self, x):
        # TorchScript does not support SyncBatchNorm yet
        # (https://github.com/pytorch/pytorch/issues/40507), so this check is
        # skipped when scripting: scripting is only supported in eval mode,
        # and the PyTorch versions able to script this module (1.6+) already
        # handle empty inputs in conv2d.
        if not torch.jit.is_scripting():
            if x.numel() == 0 and self.training:
                # https://github.com/pytorch/pytorch/issues/12013
                assert not isinstance(
                    self.norm, torch.nn.SyncBatchNorm
                ), 'SyncBatchNorm does not support empty inputs!'

        out = F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
                       self.dilation, self.groups)
        if self.norm is not None:
            out = self.norm(out)
        if self.activation is not None:
            out = self.activation(out)
        return out
|
|
|
|
|
|
# Modified from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/pixel_decoder/msdeformattn.py
|
|
class MSDeformAttnTransformerEncoderOnly(nn.Module):
    """Deformable-attention transformer encoder over a list of feature maps.

    Each level is flattened to a token sequence, a learned per-level
    embedding is added to its positional encoding, and the concatenated
    tokens of all levels are passed through the stacked encoder.
    """

    def __init__(
        self,
        d_model=256,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=1024,
        dropout=0.1,
        activation='relu',
        num_feature_levels=4,
        enc_n_points=4,
    ):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead

        # A single prototype layer; the encoder clones it num_encoder_layers
        # times.
        prototype = MSDeformAttnTransformerEncoderLayer(
            d_model, dim_feedforward, dropout, activation, num_feature_levels,
            nhead, enc_n_points)
        self.encoder = MSDeformAttnTransformerEncoder(prototype,
                                                      num_encoder_layers)

        # Learned per-level embedding, allocated uninitialised here and
        # filled with a normal distribution in _reset_parameters.
        self.level_embed = nn.Parameter(
            torch.Tensor(num_feature_levels, d_model))

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-init all matrices, then let each MSDeformAttn re-init
        itself, then draw ``level_embed`` from a normal distribution."""
        for param in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(param)
        for module in self.modules():
            if isinstance(module, MSDeformAttn):
                module._reset_parameters()
        normal_(self.level_embed)

    def get_valid_ratio(self, mask):
        """Return the fraction of un-masked columns and rows per sample.

        Args:
            mask: bool tensor of shape (bs, H, W); True marks padding.

        Returns:
            float tensor of shape (bs, 2) holding ``(ratio_w, ratio_h)``,
            measured along the first row and first column respectively.
        """
        _, height, width = mask.shape
        keep = ~mask
        ratio_h = keep[:, :, 0].sum(1).float() / height
        ratio_w = keep[:, 0, :].sum(1).float() / width
        return torch.stack([ratio_w, ratio_h], -1)

    def forward(self, srcs, pos_embeds):
        """Encode multi-level features.

        Args:
            srcs: list of per-level feature maps, each unpacked as
                (bs, c, h, w).
            pos_embeds: list of per-level positional encodings, one per
                entry of ``srcs`` (presumably the same (bs, c, h, w) layout
                with c == d_model — TODO confirm against callers).

        Returns:
            tuple: ``(memory, spatial_shapes, level_start_index)`` where
            ``memory`` is the encoder output over the concatenated tokens,
            ``spatial_shapes`` is a long tensor of per-level ``(h, w)``, and
            ``level_start_index`` is the offset of each level's first token.
        """
        # No padding is modelled: every mask is all-False (all valid).
        masks = [
            torch.zeros((src.size(0), src.size(2), src.size(3)),
                        device=src.device,
                        dtype=torch.bool) for src in srcs
        ]

        flat_srcs = []
        flat_masks = []
        flat_pos = []
        shapes = []
        for lvl, (src, mask,
                  pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            _, _, height, width = src.shape
            shapes.append((height, width))
            # (bs, c, h, w) -> (bs, h*w, c)
            flat_srcs.append(src.flatten(2).transpose(1, 2))
            flat_masks.append(mask.flatten(1))
            level_bias = self.level_embed[lvl].view(1, 1, -1)
            flat_pos.append(pos_embed.flatten(2).transpose(1, 2) + level_bias)

        src_flatten = torch.cat(flat_srcs, 1)
        mask_flatten = torch.cat(flat_masks, 1)
        lvl_pos_embed_flatten = torch.cat(flat_pos, 1)
        spatial_shapes = torch.as_tensor(
            shapes, dtype=torch.long, device=src_flatten.device)
        # Start offset of each level: exclusive cumulative sum of h*w.
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks],
                                   1)

        memory = self.encoder(src_flatten, spatial_shapes, level_start_index,
                              valid_ratios, lvl_pos_embed_flatten,
                              mask_flatten)
        return memory, spatial_shapes, level_start_index
|
|
|
|
|
|
class MSDeformAttnTransformerEncoderLayer(nn.Module):
    """One post-norm encoder layer: multi-scale deformable self-attention
    followed by a two-layer feed-forward network, each wrapped in a
    residual connection, dropout and LayerNorm."""

    def __init__(self,
                 d_model=256,
                 d_ffn=1024,
                 dropout=0.1,
                 activation='relu',
                 n_levels=4,
                 n_heads=8,
                 n_points=4):
        super().__init__()

        # self attention
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # feed-forward network
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Add ``pos`` to ``tensor``; return ``tensor`` unchanged if ``pos``
        is None."""
        return tensor + pos if pos is not None else tensor

    def forward_ffn(self, src):
        """Residual FFN block: ``norm2(src + dropout(FFN(src)))``."""
        hidden = self.activation(self.linear1(src))
        hidden = self.linear2(self.dropout2(hidden))
        return self.norm2(src + self.dropout3(hidden))

    def forward(self,
                src,
                pos,
                reference_points,
                spatial_shapes,
                level_start_index,
                padding_mask=None):
        # Positional encoding goes into the query only; values use raw src.
        query = self.with_pos_embed(src, pos)
        attn_out = self.self_attn(query, reference_points, src,
                                  spatial_shapes, level_start_index,
                                  padding_mask)
        src = self.norm1(src + self.dropout1(attn_out))
        return self.forward_ffn(src)
|
|
|
|
|
|
class MSDeformAttnTransformerEncoder(nn.Module):
    """Stack of ``num_layers`` cloned encoder layers that all share one set
    of reference points computed from the level shapes."""

    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Build normalised (x, y) pixel-centre reference points.

        Args:
            spatial_shapes: iterable of per-level ``(H, W)``.
            valid_ratios: tensor indexed as (bs, n_levels, 2) holding
                ``(ratio_w, ratio_h)`` per level.
            device: device for the generated grids.

        Returns:
            tensor of shape (bs, sum(H*W), n_levels, 2).
        """
        per_level = []
        for lvl, (height, width) in enumerate(spatial_shapes):
            grid_y, grid_x = torch.meshgrid(
                torch.linspace(
                    0.5,
                    height - 0.5,
                    height,
                    dtype=torch.float32,
                    device=device),
                torch.linspace(
                    0.5,
                    width - 0.5,
                    width,
                    dtype=torch.float32,
                    device=device))
            # Normalise centres by the valid (un-padded) extent of the level.
            norm_y = grid_y.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 1] * height)
            norm_x = grid_x.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 0] * width)
            per_level.append(torch.stack((norm_x, norm_y), -1))
        points = torch.cat(per_level, 1)
        # Rescale every point into each level's valid region.
        return points[:, :, None] * valid_ratios[:, None]

    def forward(self,
                src,
                spatial_shapes,
                level_start_index,
                valid_ratios,
                pos=None,
                padding_mask=None):
        reference_points = self.get_reference_points(
            spatial_shapes, valid_ratios, device=src.device)
        output = src
        for layer in self.layers:
            output = layer(output, pos, reference_points, spatial_shapes,
                           level_start_index, padding_mask)
        return output
|
|
|
|
|
|
class MSDeformAttnPixelDecoder(nn.Module):
    """Mask2Former-style pixel decoder built on a multi-scale deformable
    attention transformer encoder.

    The levels listed in ``transformer_in_features`` are projected to
    ``conv_dim`` channels and encoded jointly by the deformable transformer.
    Extra higher-resolution levels are then added with an FPN-style
    top-down path of lateral / output convolutions.
    """

    def __init__(
        self,
        input_stride,
        input_channel,
        *,
        transformer_dropout: float = 0.0,
        transformer_nheads: int = 8,
        transformer_dim_feedforward: int = 1024,
        transformer_enc_layers: int = 6,
        conv_dim: int = 256,
        mask_dim: int = 256,
        norm: Optional[Union[str, Callable]] = 'GN',
        # deformable transformer encoder args
        transformer_in_features: Optional[List[int]] = None,
        common_stride: int = 4,
    ):
        """
        Args:
            input_stride: per-level strides of the input features, indexed by
                level (the same index used to look up ``features`` in
                :meth:`forward_features`).
            input_channel: per-level channel counts of the input features.
            transformer_dropout: dropout probability in the transformer.
            transformer_nheads: number of attention heads in the transformer.
            transformer_dim_feedforward: hidden size of the transformer FFN.
            transformer_enc_layers: number of transformer encoder layers.
            conv_dim: number of output channels for the intermediate conv
                layers.
            mask_dim: number of output channels for the final conv layer.
            norm (str or callable): accepted for API compatibility but
                currently ignored; every norm layer built here is
                ``GroupNorm(32, conv_dim)``.
            transformer_in_features: indices of the input levels fed to the
                transformer encoder. Defaults to ``[1, 2, 3]``.
            common_stride: target stride of the highest-resolution output;
                ``log2(min transformer stride) - log2(common_stride)`` extra
                FPN levels are built.

        Raises:
            ValueError: if ``transformer_in_features`` is empty.
        """
        super().__init__()
        # A fresh list per instance (avoids a shared mutable default).
        if transformer_in_features is None:
            transformer_in_features = [1, 2, 3]
        transformer_in_features = list(transformer_in_features)
        if not transformer_in_features:
            raise ValueError(
                'transformer_in_features must contain at least one level')

        self.in_features = list(range(len(input_stride)))
        self.feature_strides = input_stride
        self.feature_channels = input_channel

        # Levels consumed by the transformer encoder (may be fewer than all
        # levels used by the pixel decoder).
        self.transformer_in_features = transformer_in_features
        transformer_in_channels = [
            input_channel[i] for i in transformer_in_features
        ]
        # Their strides decide how many extra FPN levels are needed.
        self.transformer_feature_strides = [
            input_stride[i] for i in transformer_in_features
        ]
        self.transformer_num_feature_levels = len(transformer_in_features)

        # 1x1 input projections, ordered from low to high resolution
        # (res5 -> res2) to match the reversed iteration in forward_features.
        self.input_proj = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, conv_dim, kernel_size=1),
                nn.GroupNorm(32, conv_dim),
            ) for in_channels in transformer_in_channels[::-1]
        ])
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

        self.transformer = MSDeformAttnTransformerEncoderOnly(
            d_model=conv_dim,
            dropout=transformer_dropout,
            nhead=transformer_nheads,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_enc_layers,
            num_feature_levels=self.transformer_num_feature_levels,
        )
        N_steps = conv_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        self.mask_dim = mask_dim
        # use 1x1 conv instead
        self.mask_features = Conv2d(
            conv_dim,
            mask_dim,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        c2_xavier_fill(self.mask_features)

        self.maskformer_num_feature_levels = 3  # always use 3 scales
        self.common_stride = common_stride

        # extra fpn levels
        stride = min(self.transformer_feature_strides)
        self.num_fpn_levels = int(
            np.log2(stride) - np.log2(self.common_stride))

        lateral_convs = []
        output_convs = []

        # Every conv below is followed by a GroupNorm, so bias is disabled
        # (the commented-out upstream rule was `use_bias = norm == ""`).
        use_bias = False
        for idx, in_channels in enumerate(
                self.feature_channels[:self.num_fpn_levels]):
            lateral_norm = torch.nn.GroupNorm(32, conv_dim)
            output_norm = torch.nn.GroupNorm(32, conv_dim)

            lateral_conv = Conv2d(
                in_channels,
                conv_dim,
                kernel_size=1,
                bias=use_bias,
                norm=lateral_norm)
            output_conv = Conv2d(
                conv_dim,
                conv_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=use_bias,
                norm=output_norm,
                activation=F.relu,
            )
            c2_xavier_fill(lateral_conv)
            c2_xavier_fill(output_conv)
            # Registered explicitly because the lists below are plain lists.
            self.add_module('adapter_{}'.format(idx + 1), lateral_conv)
            self.add_module('layer_{}'.format(idx + 1), output_conv)

            lateral_convs.append(lateral_conv)
            output_convs.append(output_conv)
        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]

    @autocast(enabled=False)
    def forward_features(self, features):
        """Run the pixel decoder in float32.

        Args:
            features: indexable collection of feature maps where
                ``features[i]`` is level ``i`` with ``input_channel[i]``
                channels.

        Returns:
            tuple: ``(mask_features, first_output, multi_scale_features)``:
            ``self.mask_features`` applied to the last (highest-resolution)
            output, ``out[0]`` (the first transformer level, lowest
            resolution per the ordering comments), and the first
            ``maskformer_num_feature_levels`` outputs in low-to-high
            resolution order.
        """
        srcs = []
        pos = []
        # Reverse feature maps into top-down order (from low to high
        # resolution).
        for idx, f in enumerate(self.transformer_in_features[::-1]):
            # deformable detr does not support half precision
            x = features[f].float()
            srcs.append(self.input_proj[idx](x))
            pos.append(self.pe_layer(x))

        y, spatial_shapes, level_start_index = self.transformer(srcs, pos)
        bs = y.shape[0]

        # Split the concatenated token sequence back into one chunk per level.
        num_levels = self.transformer_num_feature_levels
        split_size_or_sections = [
            level_start_index[i + 1] - level_start_index[i]
            for i in range(num_levels - 1)
        ]
        split_size_or_sections.append(y.shape[1] -
                                      level_start_index[num_levels - 1])
        y = torch.split(y, split_size_or_sections, dim=1)

        out = []
        for i, z in enumerate(y):
            # (bs, h*w, C) -> (bs, C, h, w)
            out.append(
                z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0],
                                       spatial_shapes[i][1]))

        # Append extra FPN levels to `out`, iterating from low to high
        # resolution.
        for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]):
            x = features[f].float()
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            cur_fpn = lateral_conv(x)
            # Upsample the previous (coarser) output to this level's size.
            # Note: bilinear interpolation, not FPN's nearest upsampling.
            fused = cur_fpn + F.interpolate(
                out[-1],
                size=cur_fpn.shape[-2:],
                mode='bilinear',
                align_corners=False)
            out.append(output_conv(fused))

        multi_scale_features = out[:self.maskformer_num_feature_levels]

        return self.mask_features(out[-1]), out[0], multi_scale_features
|