# EasyCV/easycv/models/segmentation/heads/pixel_decoder.py

import copy
from typing import Callable, List, Optional, Union

import numpy as np
import torch
from torch import nn
from torch.cuda.amp import autocast
from torch.nn import functional as F
from torch.nn.init import normal_

from .transformer_decoder import PositionEmbeddingSine, _get_activation_fn
try:
    from thirdparty.deformable_transformer.modules import MSDeformAttn
except ImportError:
    # MSDeformAttn is a compiled CUDA extension; keep this module importable
    # even when the op has not been built.
    pass


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def c2_xavier_fill(module: nn.Module):
    """Initialize `module.weight` using the "XavierFill" implemented in Caffe2.

    Also initializes `module.bias` to 0.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    # Caffe2's XavierFill in fact corresponds to kaiming_uniform_ in PyTorch.
    nn.init.kaiming_uniform_(module.weight, a=1)
    if module.bias is not None:
        nn.init.constant_(module.bias, 0)


class Conv2d(torch.nn.Conv2d):
    """A wrapper around :class:`torch.nn.Conv2d` that supports empty inputs
    and an optional fused normalization/activation."""

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes the norm layer is applied before the activation.
        """
        norm = kwargs.pop('norm', None)
        activation = kwargs.pop('activation', None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip this check when scripting because:
        # 1. currently we only support torchscript in evaluation mode
        # 2. the features needed to export a module to torchscript were added
        #    in PyTorch 1.6+, and `Conv2d` in those versions already supports
        #    empty inputs.
        if not torch.jit.is_scripting():
            if x.numel() == 0 and self.training:
                # https://github.com/pytorch/pytorch/issues/12013
                assert not isinstance(
                    self.norm, torch.nn.SyncBatchNorm
                ), 'SyncBatchNorm does not support empty inputs!'

        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
                     self.dilation, self.groups)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
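
# Usage sketch for `Conv2d` (illustrative, not part of the original file): the
# extra `norm`/`activation` kwargs let one call do conv -> GN -> ReLU, while
# init helpers such as `c2_xavier_fill` still only touch the conv weights:
#
#   conv = Conv2d(64, 256, kernel_size=3, padding=1, bias=False,
#                 norm=nn.GroupNorm(32, 256), activation=F.relu)
#   y = conv(torch.randn(2, 64, 32, 32))  # -> (2, 256, 32, 32)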


# Modified from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/pixel_decoder/msdeformattn.py
class MSDeformAttnTransformerEncoderOnly(nn.Module):

    def __init__(
        self,
        d_model=256,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=1024,
        dropout=0.1,
        activation='relu',
        num_feature_levels=4,
        enc_n_points=4,
    ):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead

        encoder_layer = MSDeformAttnTransformerEncoderLayer(
            d_model, dim_feedforward, dropout, activation, num_feature_levels,
            nhead, enc_n_points)
        self.encoder = MSDeformAttnTransformerEncoder(encoder_layer,
                                                      num_encoder_layers)
        # One learnable embedding per feature level, added to the positional
        # encoding so attention can tell the levels apart.
        self.level_embed = nn.Parameter(
            torch.Tensor(num_feature_levels, d_model))

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        normal_(self.level_embed)

    def get_valid_ratio(self, mask):
        # Fraction of each (possibly padded) feature map that holds real
        # content, per batch element; returned as (bs, 2) = (w_ratio, h_ratio).
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio
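
    # Example (illustrative, not from the original file): for a 100x100 mask
    # whose bottom 20 rows are padding (True), valid_ratio = (1.0, 0.8).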

    def forward(self, srcs, pos_embeds):
        # No real padding is used here, so every position is marked valid.
        masks = [
            torch.zeros((x.size(0), x.size(2), x.size(3)),
                        device=x.device,
                        dtype=torch.bool) for x in srcs
        ]

        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask,
                  pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            src = src.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=src_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

        # encoder
        memory = self.encoder(src_flatten, spatial_shapes, level_start_index,
                              valid_ratios, lvl_pos_embed_flatten,
                              mask_flatten)

        return memory, spatial_shapes, level_start_index
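
# Shape sketch for the flattened encoder input (illustrative, not from the
# original file): with three levels of sizes (32, 32), (16, 16) and (8, 8),
# the sequence holds 32*32 + 16*16 + 8*8 = 1344 tokens and
# level_start_index = [0, 1024, 1280], so level i occupies
# memory[:, level_start_index[i]:level_start_index[i + 1]].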


class MSDeformAttnTransformerEncoderLayer(nn.Module):

    def __init__(self,
                 d_model=256,
                 d_ffn=1024,
                 dropout=0.1,
                 activation='relu',
                 n_levels=4,
                 n_heads=8,
                 n_points=4):
        super().__init__()

        # self attention
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src

    def forward(self,
                src,
                pos,
                reference_points,
                spatial_shapes,
                level_start_index,
                padding_mask=None):
        # self attention
        src2 = self.self_attn(
            self.with_pos_embed(src, pos), reference_points, src,
            spatial_shapes, level_start_index, padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # ffn
        src = self.forward_ffn(src)

        return src
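
# Layer sketch (illustrative, not from the original file): each encoder layer
# is post-norm residual,
#   src = LayerNorm(src + Dropout(MSDeformAttn(src + pos, ...)))
#   src = LayerNorm(src + Dropout(Linear2(Dropout(act(Linear1(src))))))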


class MSDeformAttnTransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        # Normalized (x, y) centers of every location at every level, rescaled
        # by the per-level valid ratios so coordinates ignore padded regions.
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(
                    0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                torch.linspace(
                    0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
            ref_y = ref_y.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(self,
                src,
                spatial_shapes,
                level_start_index,
                valid_ratios,
                pos=None,
                padding_mask=None):
        output = src
        reference_points = self.get_reference_points(
            spatial_shapes, valid_ratios, device=src.device)
        for layer in self.layers:
            output = layer(output, pos, reference_points, spatial_shapes,
                           level_start_index, padding_mask)
        return output
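
# Reference-point sketch (illustrative, not from the original file): on a
# fully valid 4x4 level the normalized x centers are (0.5, 1.5, 2.5, 3.5) / 4
# = (0.125, 0.375, 0.625, 0.875), so every token attends around its own cell
# center, at every level simultaneously.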


class MSDeformAttnPixelDecoder(nn.Module):

    def __init__(
        self,
        input_stride,
        input_channel,
        *,
        transformer_dropout: float = 0.0,
        transformer_nheads: int = 8,
        transformer_dim_feedforward: int = 1024,
        transformer_enc_layers: int = 6,
        conv_dim: int = 256,
        mask_dim: int = 256,
        norm: Optional[Union[str, Callable]] = 'GN',
        # deformable transformer encoder args
        transformer_in_features: List[int] = [1, 2, 3],
        common_stride: int = 4,
    ):
        """
        Args:
            input_stride: strides of the input features
            input_channel: channels of the input features
            transformer_dropout: dropout probability in the transformer
            transformer_nheads: number of attention heads in the transformer
            transformer_dim_feedforward: dimension of the feedforward network
            transformer_enc_layers: number of transformer encoder layers
            conv_dim: number of output channels for the intermediate conv layers
            mask_dim: number of output channels for the final conv layer
            norm (str or callable): normalization for all conv layers
            transformer_in_features: indices of the input features fed to the
                transformer encoder (e.g. res3-res5 of a ResNet backbone)
            common_stride: finest stride kept in the output feature pyramid
        """
        super().__init__()
        self.in_features = list(range(len(input_stride)))
        self.feature_strides = input_stride
        self.feature_channels = input_channel

        # The transformer encoder may consume fewer levels than the pixel
        # decoder outputs.
        self.transformer_in_features = transformer_in_features
        transformer_in_channels = [
            input_channel[i] for i in transformer_in_features
        ]
        self.transformer_feature_strides = [
            input_stride[i] for i in transformer_in_features
        ]  # to decide extra FPN layers
        self.transformer_num_feature_levels = len(transformer_in_features)

        if self.transformer_num_feature_levels > 1:
            input_proj_list = []
            # from low resolution to high resolution (res5 -> res2)
            for in_channels in transformer_in_channels[::-1]:
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, conv_dim, kernel_size=1),
                        nn.GroupNorm(32, conv_dim),
                    ))
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(
                        transformer_in_channels[-1], conv_dim, kernel_size=1),
                    nn.GroupNorm(32, conv_dim),
                )
            ])

        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

        self.transformer = MSDeformAttnTransformerEncoderOnly(
            d_model=conv_dim,
            dropout=transformer_dropout,
            nhead=transformer_nheads,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_enc_layers,
            num_feature_levels=self.transformer_num_feature_levels,
        )
        N_steps = conv_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        self.mask_dim = mask_dim
        # use 1x1 conv instead
        self.mask_features = Conv2d(
            conv_dim,
            mask_dim,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        c2_xavier_fill(self.mask_features)

        self.maskformer_num_feature_levels = 3  # always use 3 scales
        self.common_stride = common_stride

        # extra fpn levels
        stride = min(self.transformer_feature_strides)
        self.num_fpn_levels = int(
            np.log2(stride) - np.log2(self.common_stride))

        lateral_convs = []
        output_convs = []

        # GroupNorm is hard-coded below, so the convs never need a bias.
        use_bias = False
        for idx, in_channels in enumerate(
                self.feature_channels[:self.num_fpn_levels]):
            lateral_norm = torch.nn.GroupNorm(32, conv_dim)
            output_norm = torch.nn.GroupNorm(32, conv_dim)

            lateral_conv = Conv2d(
                in_channels,
                conv_dim,
                kernel_size=1,
                bias=use_bias,
                norm=lateral_norm)
            output_conv = Conv2d(
                conv_dim,
                conv_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=use_bias,
                norm=output_norm,
                activation=F.relu,
            )
            c2_xavier_fill(lateral_conv)
            c2_xavier_fill(output_conv)
            self.add_module('adapter_{}'.format(idx + 1), lateral_conv)
            self.add_module('layer_{}'.format(idx + 1), output_conv)

            lateral_convs.append(lateral_conv)
            output_convs.append(output_conv)
        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]
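
    # FPN-level arithmetic (illustrative, not from the original file): with
    # strides [4, 8, 16, 32] and transformer_in_features=[1, 2, 3], the
    # encoder sees strides [8, 16, 32], so num_fpn_levels = log2(8) - log2(4)
    # = 1 and a single lateral/output conv pair brings the pyramid back down
    # to the common stride of 4.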

    @autocast(enabled=False)
    def forward_features(self, features):
        srcs = []
        pos = []
        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.transformer_in_features[::-1]):
            x = features[f].float(
            )  # deformable detr does not support half precision
            srcs.append(self.input_proj[idx](x))
            pos.append(self.pe_layer(x))

        y, spatial_shapes, level_start_index = self.transformer(srcs, pos)
        bs = y.shape[0]

        # Split the flattened encoder output back into one chunk per level.
        split_size_or_sections = [None] * self.transformer_num_feature_levels
        for i in range(self.transformer_num_feature_levels):
            if i < self.transformer_num_feature_levels - 1:
                split_size_or_sections[i] = level_start_index[
                    i + 1] - level_start_index[i]
            else:
                split_size_or_sections[i] = y.shape[1] - level_start_index[i]
        y = torch.split(y, split_size_or_sections, dim=1)

        out = []
        multi_scale_features = []
        num_cur_levels = 0
        for i, z in enumerate(y):
            out.append(
                z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0],
                                       spatial_shapes[i][1]))

        # append `out` with extra FPN levels
        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]):
            x = features[f].float()
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            cur_fpn = lateral_conv(x)
            # Unlike the nearest upsampling of the classic FPN, bilinear
            # upsampling is used here.
            y = cur_fpn + F.interpolate(
                out[-1],
                size=cur_fpn.shape[-2:],
                mode='bilinear',
                align_corners=False)
            y = output_conv(y)
            out.append(y)

        # Keep the first `maskformer_num_feature_levels` (coarsest) maps for
        # the transformer decoder.
        for o in out:
            if num_cur_levels < self.maskformer_num_feature_levels:
                multi_scale_features.append(o)
                num_cur_levels += 1

        return self.mask_features(out[-1]), out[0], multi_scale_features
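

if __name__ == '__main__':
    # Minimal smoke test (a hedged sketch, not part of the original file). It
    # assumes the MSDeformAttn CUDA extension has been built and a GPU is
    # available; strides/channels below mimic a ResNet-style backbone
    # (res2-res5) and are illustrative only.
    strides = [4, 8, 16, 32]
    channels = [256, 512, 1024, 2048]
    decoder = MSDeformAttnPixelDecoder(strides, channels).cuda()
    feats = [
        torch.randn(2, c, 256 // s, 256 // s, device='cuda')
        for s, c in zip(strides, channels)
    ]
    mask_feat, encoder_feat, ms_feats = decoder.forward_features(feats)
    # mask_feat: stride-4 mask features; encoder_feat: the coarsest map;
    # ms_feats: the 3 coarsest maps that feed the transformer decoder.
    print(mask_feat.shape, [f.shape for f in ms_feats])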