import copy
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.cuda.amp import autocast
from torch.nn import functional as F
from torch.nn.init import constant_, normal_, uniform_, xavier_uniform_

from .transformer_decoder import PositionEmbeddingSine, _get_activation_fn

# MSDeformAttn is an optional dependency: keep this module importable when it
# is missing, but remember why so a clear error can be raised at the point
# where a deformable-attention layer is actually built.
try:
    from thirdparty.deformable_transformer.modules import MSDeformAttn
except ImportError as exc:
    MSDeformAttn = None
    _MSDEFORMATTN_IMPORT_ERROR = exc
else:
    _MSDEFORMATTN_IMPORT_ERROR = None


def _get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def c2_xavier_fill(module: nn.Module):
    """
    Initialize `module.weight` using the "XavierFill" implemented in Caffe2.
    Also initializes `module.bias` to 0.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    # Caffe2 implementation of XavierFill in fact
    # corresponds to kaiming_uniform_ in PyTorch
    # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Module, Tensor]`.
    nn.init.kaiming_uniform_(module.weight, a=1)
    if module.bias is not None:
        # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module,
        #  torch.Tensor]`.
        nn.init.constant_(module.bias, 0)


class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that norm layer is used before activation.
        """
        norm = kwargs.pop('norm', None)
        activation = kwargs.pop('activation', None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        """Apply the convolution, then ``self.norm`` and ``self.activation`` if set."""
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip these codes in torchscript since:
        # 1. currently we only support torchscript in evaluation mode
        # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
        # later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
        if not torch.jit.is_scripting():
            if x.numel() == 0 and self.training:
                # https://github.com/pytorch/pytorch/issues/12013
                assert not isinstance(
                    self.norm, torch.nn.SyncBatchNorm
                ), 'SyncBatchNorm does not support empty inputs!'

        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
                     self.dilation, self.groups)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


# Modified from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/pixel_decoder/msdeformattn.py
class MSDeformAttnTransformerEncoderOnly(nn.Module):
    """Multi-scale deformable-attention transformer consisting of an encoder only."""

    def __init__(
        self,
        d_model=256,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=1024,
        dropout=0.1,
        activation='relu',
        num_feature_levels=4,
        enc_n_points=4,
    ):
        """
        Args:
            d_model: channel width of the encoder tokens.
            nhead: number of attention heads passed to the encoder layer.
            num_encoder_layers: number of stacked encoder layers.
            dim_feedforward: hidden width of each layer's feed-forward network.
            dropout: dropout probability used inside each encoder layer.
            activation: activation name resolved by `_get_activation_fn`.
            num_feature_levels: number of feature levels (one learned level
                embedding is allocated per level).
            enc_n_points: number of sampling points passed to `MSDeformAttn`.
        """
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead

        encoder_layer = MSDeformAttnTransformerEncoderLayer(
            d_model, dim_feedforward, dropout, activation,
            num_feature_levels, nhead, enc_n_points)
        self.encoder = MSDeformAttnTransformerEncoder(encoder_layer,
                                                      num_encoder_layers)

        # Allocated uninitialised here; filled in by `_reset_parameters`.
        self.level_embed = nn.Parameter(
            torch.Tensor(num_feature_levels, d_model))

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise weights: Xavier for matrices, then module-specific resets."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # MSDeformAttn provides its own initialisation, which must override
        # the generic Xavier pass above.
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        normal_(self.level_embed)

    def get_valid_ratio(self, mask):
        """Return the per-sample fraction of non-masked width and height.

        Args:
            mask: boolean tensor of shape (B, H, W) where True marks padding.

        Returns:
            Float tensor of shape (B, 2) ordered as (width ratio, height ratio).
        """
        _, H, W = mask.shape
        # Count valid entries along the first column / first row.
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def forward(self, srcs, pos_embeds):
        """Flatten all feature levels into one token sequence and encode it.

        Args:
            srcs: list of 4-D feature tensors (B, C, H, W), one per level.
            pos_embeds: list of positional embeddings, one per level, flattened
                the same way as the corresponding entry of ``srcs``.

        Returns:
            Tuple ``(memory, spatial_shapes, level_start_index)`` where
            ``memory`` is the encoder output over the concatenated tokens,
            ``spatial_shapes`` is a long tensor holding (H, W) per level and
            ``level_start_index`` holds each level's offset in the sequence.
        """
        # No padding information is available here, so every position is
        # treated as valid (all-False masks).
        masks = [
            torch.zeros((x.size(0), x.size(2), x.size(3)),
                        device=x.device,
                        dtype=torch.bool) for x in srcs
        ]

        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(
                zip(srcs, masks, pos_embeds)):
            _, _, h, w = src.shape
            spatial_shapes.append((h, w))
            # (B, C, H, W) -> (B, H*W, C)
            src = src.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            # Add a learned per-level embedding so tokens carry their level.
            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=src_flatten.device)
        # Offsets of each level: 0 followed by the running sum of H*W.
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks],
                                   1)

        # encoder
        memory = self.encoder(src_flatten, spatial_shapes, level_start_index,
                              valid_ratios, lvl_pos_embed_flatten,
                              mask_flatten)

        return memory, spatial_shapes, level_start_index


class MSDeformAttnTransformerEncoderLayer(nn.Module):
    """One encoder layer: deformable self-attention followed by a feed-forward net."""

    def __init__(self,
                 d_model=256,
                 d_ffn=1024,
                 dropout=0.1,
                 activation='relu',
                 n_levels=4,
                 n_heads=8,
                 n_points=4):
        """
        Args:
            d_model: channel width of the tokens.
            d_ffn: hidden width of the feed-forward network.
            dropout: dropout probability for all three dropout modules.
            activation: activation name resolved by `_get_activation_fn`.
            n_levels, n_heads, n_points: forwarded to `MSDeformAttn`.

        Raises:
            ImportError: if `MSDeformAttn` could not be imported.
        """
        super().__init__()

        if MSDeformAttn is None:
            raise ImportError(
                'MSDeformAttn could not be imported from '
                'thirdparty.deformable_transformer.modules; it is required '
                'to build MSDeformAttnTransformerEncoderLayer.'
            ) from _MSDEFORMATTN_IMPORT_ERROR

        # self attention
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, src):
        """Apply the feed-forward sub-layer with residual connection and LayerNorm."""
        src2 = self.linear2(
            self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src

    def forward(self,
                src,
                pos,
                reference_points,
                spatial_shapes,
                level_start_index,
                padding_mask=None):
        """Run self-attention and the feed-forward net, each with residual + norm."""
        # self attention: positions are added to the query only; the
        # un-embedded `src` is passed as the attention input.
        src2 = self.self_attn(
            self.with_pos_embed(src, pos), reference_points, src,
            spatial_shapes, level_start_index, padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # ffn
        src = self.forward_ffn(src)

        return src


class MSDeformAttnTransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep-copied encoder layers."""

    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Build normalised reference points for every token of every level.

        Args:
            spatial_shapes: iterable of (H, W) pairs, one per level.
            valid_ratios: tensor indexed as ``[batch, level, (w, h)]``.
            device: device on which the grids are created.

        Returns:
            Tensor obtained by concatenating the per-level (x, y) grids along
            the token dimension and scaling them by ``valid_ratios`` for every
            level (an extra level axis is inserted at dim 2).
        """
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            # Pixel centres at 0.5, 1.5, ...; meshgrid is used with its
            # default indexing so the first output varies along H.
            # NOTE(review): recent PyTorch versions warn when `indexing` is
            # omitted; it is left out here to stay compatible with older
            # releases - confirm the minimum supported version before adding.
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(
                    0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                torch.linspace(
                    0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
            ref_y = ref_y.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(self,
                src,
                spatial_shapes,
                level_start_index,
                valid_ratios,
                pos=None,
                padding_mask=None):
        """Pass ``src`` through every layer, sharing one set of reference points."""
        output = src
        reference_points = self.get_reference_points(
            spatial_shapes, valid_ratios, device=src.device)
        for layer in self.layers:
            output = layer(output, pos, reference_points, spatial_shapes,
                           level_start_index, padding_mask)

        return output


class MSDeformAttnPixelDecoder(nn.Module):
    """Pixel decoder: deformable-attention encoder plus extra FPN levels."""

    def __init__(
        self,
        input_stride,
        input_channel,
        *,
        transformer_dropout: float = 0.0,
        transformer_nheads: int = 8,
        transformer_dim_feedforward: int = 1024,
        transformer_enc_layers: int = 6,
        conv_dim: int = 256,
        mask_dim: int = 256,
        norm: Optional[Union[str, Callable]] = 'GN',
        # deformable transformer encoder args
        transformer_in_features: List[int] = [1, 2, 3],
        common_stride: int = 4,
    ):
        """
        Args:
            input_stride: stride of the input features
            input_channel: channels of the input features
            transformer_dropout: dropout probability in transformer
            transformer_nheads: number of heads in transformer
            transformer_dim_feedforward: dimension of feedforward network
            transformer_enc_layers: number of transformer encoder layers
            conv_dim: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            norm (str or callable): normalization for all conv layers.
                NOTE: this argument is currently not read; every conv in this
                module uses ``GroupNorm(32, conv_dim)``.
            transformer_in_features: indices into ``input_stride`` /
                ``input_channel`` (and into the ``features`` passed to
                `forward_features`) selecting the levels fed to the
                transformer encoder.
            common_stride: stride that the finest output level should reach;
                together with the smallest transformer stride it determines
                how many extra FPN levels are built.
        """
        super().__init__()
        self.in_features = list(range(len(input_stride)))
        self.feature_strides = input_stride
        self.feature_channels = input_channel

        # This is the input of the transformer encoder (it may use fewer
        # features than the pixel decoder). Store a copy so instances never
        # alias the shared default list or a caller-owned list.
        self.transformer_in_features = list(
            transformer_in_features)  # starting from "res2" to "res5"
        transformer_in_channels = [
            input_channel[i] for i in self.transformer_in_features
        ]
        self.transformer_feature_strides = [
            input_stride[i] for i in self.transformer_in_features
        ]  # to decide extra FPN layers

        self.transformer_num_feature_levels = len(
            self.transformer_in_features)
        if self.transformer_num_feature_levels > 1:
            input_proj_list = []
            # from low resolution to high resolution (res5 -> res2)
            for in_channels in transformer_in_channels[::-1]:
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, conv_dim, kernel_size=1),
                        nn.GroupNorm(32, conv_dim),
                    ))
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(
                        transformer_in_channels[-1], conv_dim, kernel_size=1),
                    nn.GroupNorm(32, conv_dim),
                )
            ])

        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

        self.transformer = MSDeformAttnTransformerEncoderOnly(
            d_model=conv_dim,
            dropout=transformer_dropout,
            nhead=transformer_nheads,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_enc_layers,
            num_feature_levels=self.transformer_num_feature_levels,
        )
        N_steps = conv_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        self.mask_dim = mask_dim
        # use 1x1 conv instead
        self.mask_features = Conv2d(
            conv_dim,
            mask_dim,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        c2_xavier_fill(self.mask_features)

        self.maskformer_num_feature_levels = 3  # always use 3 scales
        self.common_stride = common_stride

        # extra fpn levels: one per factor of two between the finest
        # transformer level and `common_stride`.
        stride = min(self.transformer_feature_strides)
        self.num_fpn_levels = int(
            np.log2(stride) - np.log2(self.common_stride))

        lateral_convs = []
        output_convs = []

        # Every conv below is followed by a GroupNorm, so no conv bias is used.
        use_bias = False
        for idx, in_channels in enumerate(
                self.feature_channels[:self.num_fpn_levels]):
            lateral_norm = torch.nn.GroupNorm(32, conv_dim)
            output_norm = torch.nn.GroupNorm(32, conv_dim)

            lateral_conv = Conv2d(
                in_channels,
                conv_dim,
                kernel_size=1,
                bias=use_bias,
                norm=lateral_norm)
            output_conv = Conv2d(
                conv_dim,
                conv_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=use_bias,
                norm=output_norm,
                activation=F.relu,
            )
            c2_xavier_fill(lateral_conv)
            c2_xavier_fill(output_conv)
            # Registered by name because the lists below are plain Python
            # lists, which nn.Module does not track.
            self.add_module('adapter_{}'.format(idx + 1), lateral_conv)
            self.add_module('layer_{}'.format(idx + 1), output_conv)

            lateral_convs.append(lateral_conv)
            output_convs.append(output_conv)
        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]

    @autocast(enabled=False)
    def forward_features(self, features):
        """Encode the selected levels and extend them with FPN levels.

        Args:
            features: indexable collection of 4-D feature tensors, addressed
                by the integer indices in ``self.transformer_in_features`` and
                ``self.in_features``.

        Returns:
            Tuple of three items:
              * ``self.mask_features`` applied to the last (finest) output,
              * the first encoder output, i.e. the level produced from the
                last entry of ``self.transformer_in_features``,
              * a list with the first ``self.maskformer_num_feature_levels``
                outputs.
        """
        srcs = []
        pos = []
        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.transformer_in_features[::-1]):
            x = features[f].float(
            )  # deformable detr does not support half precision
            srcs.append(self.input_proj[idx](x))
            pos.append(self.pe_layer(x))

        y, spatial_shapes, level_start_index = self.transformer(srcs, pos)
        bs = y.shape[0]

        # Token count of each level: difference of consecutive start offsets,
        # with the last level running to the end of the sequence.
        split_size_or_sections = [None] * self.transformer_num_feature_levels
        for i in range(self.transformer_num_feature_levels):
            if i < self.transformer_num_feature_levels - 1:
                split_size_or_sections[i] = level_start_index[
                    i + 1] - level_start_index[i]
            else:
                split_size_or_sections[i] = y.shape[1] - level_start_index[i]
        per_level_tokens = torch.split(y, split_size_or_sections, dim=1)

        # (B, H*W, C) -> (B, C, H, W) for every level.
        out = [
            z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0],
                                   spatial_shapes[i][1])
            for i, z in enumerate(per_level_tokens)
        ]

        # append `out` with extra FPN levels
        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]):
            x = features[f].float()
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            cur_fpn = lateral_conv(x)
            # Upsample the previous (coarser) output to the lateral map's
            # size with bilinear interpolation before fusing the two.
            fused = cur_fpn + F.interpolate(
                out[-1],
                size=cur_fpn.shape[-2:],
                mode='bilinear',
                align_corners=False)
            out.append(output_conv(fused))

        # Keep only the first `maskformer_num_feature_levels` outputs.
        multi_scale_features = out[:self.maskformer_num_feature_levels]

        return self.mask_features(out[-1]), out[0], multi_scale_features