# Copyright (c) 2022 OpenGVLab
# Copyright (c) OpenMMLab. All rights reserved.
# modified from
# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/intern_image.py
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn.bricks import DropPath, build_activation_layer
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model.weight_init import trunc_normal_
from ops_dcnv3 import modules as opsm

from mmpretrain.models.backbones.base_backbone import BaseBackbone
from mmpretrain.models.utils import CrossMultiheadAttention
from mmpretrain.registry import MODELS


class to_channels_first(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 3, 1, 2)


class to_channels_last(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 2, 3, 1)


def build_norm_layer(dim,
                     norm_layer,
                     in_format='channels_last',
                     out_format='channels_last',
                     eps=1e-6):
    layers = []
    if norm_layer == 'BN':
        if in_format == 'channels_last':
            layers.append(to_channels_first())
        layers.append(nn.BatchNorm2d(dim))
        if out_format == 'channels_last':
            layers.append(to_channels_last())
    elif norm_layer == 'LN':
        if in_format == 'channels_first':
            layers.append(to_channels_last())
        layers.append(nn.LayerNorm(dim, eps=eps))
        if out_format == 'channels_first':
            layers.append(to_channels_first())
    else:
        raise NotImplementedError(
            f'build_norm_layer does not support {norm_layer}')
    return nn.Sequential(*layers)


class AttentiveBlock(nn.Module):
    """Attentive Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: False.
        qk_scale (float, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop (float, optional): Dropout rate. Default: 0.0.
        attn_drop (float, optional): Attention dropout rate. Default: 0.0.
        drop_path (float, optional): Stochastic depth rate. Default: 0.0.
        norm_cfg (dict, optional): Normalization layer.
            Default: dict(type='LN')
        out_dim (int, optional): Dimension of output. Default: None.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_cfg=dict(type='LN'),
                 out_dim=None):
        super().__init__()

        norm_layer = norm_cfg['type']
        self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6)
        self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6)
        self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6)
        self.cross_dcn = CrossMultiheadAttention(
            embed_dims=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        if out_dim and out_dim != dim:
            self.cross_dcn.proj = nn.Linear(dim, out_dim)

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x_q, x_kv, pos_q, pos_k):
        x_q = self.norm1_q(x_q + pos_q)
        x_k = self.norm1_k(x_kv + pos_k)
        x_v = self.norm1_v(x_kv)
        x = self.cross_dcn(x_q, k=x_k, v=x_v)
        return x


class AttentionPoolingBlock(AttentiveBlock):

    def forward(self, x):
        x_q = x.mean(1, keepdim=True)
        x_kv = x
        pos_q, pos_k = 0, 0
        x = super().forward(x_q, x_kv, pos_q, pos_k)
        x = x.squeeze(1)
        return x
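

# A minimal usage sketch for the helpers above (not part of the original
# code; shapes and hyper-parameters are illustrative):
#
#   norm = build_norm_layer(64, 'LN', in_format='channels_first',
#                           out_format='channels_last')
#   y = norm(torch.randn(2, 64, 56, 56))   # NCHW in -> (2, 56, 56, 64) NHWC
#
#   pool = AttentionPoolingBlock(
#       dim=64, num_heads=4, qkv_bias=True, qk_scale=None, drop=0.,
#       attn_drop=0., norm_cfg=dict(type='LN'), out_dim=128)
#   z = pool(torch.randn(2, 196, 64))      # (B, L, C) -> (2, 128)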


class DownsampleLayer(nn.Module):
    """Downsample layer of InternImage.

    Args:
        channels (int): number of input channels
        norm_layer (str): normalization layer
    """

    def __init__(self, channels, norm_layer='LN'):
        super().__init__()
        self.conv = nn.Conv2d(
            channels,
            2 * channels,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)
        self.norm = build_norm_layer(2 * channels, norm_layer,
                                     'channels_first', 'channels_last')

    def forward(self, x):
        x = self.conv(x.permute(0, 3, 1, 2))
        x = self.norm(x)
        return x


class InternImageLayer(nn.Module):
    """Basic layer of InternImage.

    Args:
        core_op (nn.Module): core operation of InternImage
        channels (int): number of input channels
        groups (int): number of groups used by the DCNv3 core op
        mlp_ratio (float): ratio of mlp hidden features to input channels
        drop (float): dropout rate
        drop_path (float): drop path rate
        act_cfg (dict): activation layer
        norm_cfg (dict): normalization layer
        post_norm (bool): whether to use post normalization
        layer_scale (float, optional): init value of layer scale
        offset_scale (float): offset scale of DCNv3
        with_cp (bool): whether to use checkpoint
        dw_kernel_size (int, optional): size of the depth-wise convolution
            inside DCNv3
        res_post_norm (bool): whether to add a norm after each residual
            branch
        center_feature_scale (bool): whether to use center feature scale in
            DCNv3
        remove_center (bool): whether to remove the center sampling point of
            the DCNv3 kernel
    """

    def __init__(
        self,
        core_op,
        channels,
        groups,
        mlp_ratio=4.,
        drop=0.,
        drop_path=0.,
        act_cfg=dict(type='GELU'),
        norm_cfg=dict(type='LN'),
        post_norm=False,
        layer_scale=None,
        offset_scale=1.0,
        with_cp=False,
        dw_kernel_size=None,
        res_post_norm=False,
        center_feature_scale=False,
        remove_center=False,
    ):
        super().__init__()
        self.channels = channels
        self.groups = groups
        self.mlp_ratio = mlp_ratio
        self.with_cp = with_cp

        self.norm1 = build_norm_layer(channels, 'LN')
        self.post_norm = post_norm
        self.dcn = core_op(
            channels=channels,
            kernel_size=3,
            stride=1,
            pad=1,
            dilation=1,
            group=groups,
            offset_scale=offset_scale,
            act_layer=act_cfg['type'],
            norm_layer=norm_cfg['type'],
            dw_kernel_size=dw_kernel_size,
            center_feature_scale=center_feature_scale,
            remove_center=remove_center,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.norm2 = build_norm_layer(channels, 'LN')
        self.mlp = FFN(
            embed_dims=channels,
            feedforward_channels=int(channels * mlp_ratio),
            act_cfg=act_cfg,
            ffn_drop=drop,
            add_identity=False)
        self.layer_scale = layer_scale is not None
        if self.layer_scale:
            self.gamma1 = nn.Parameter(
                layer_scale * torch.ones(channels), requires_grad=True)
            self.gamma2 = nn.Parameter(
                layer_scale * torch.ones(channels), requires_grad=True)
        self.res_post_norm = res_post_norm
        if res_post_norm:
            self.res_post_norm1 = build_norm_layer(channels, 'LN')
            self.res_post_norm2 = build_norm_layer(channels, 'LN')

    def forward(self, x):

        def _inner_forward(x):
            if not self.layer_scale:
                if self.post_norm:
                    x = x + self.drop_path(self.norm1(self.dcn(x)))
                    x = x + self.drop_path(self.norm2(self.mlp(x)))
                elif self.res_post_norm:
                    x = x + self.drop_path(
                        self.res_post_norm1(self.dcn(self.norm1(x))))
                    x = x + self.drop_path(
                        self.res_post_norm2(self.mlp(self.norm2(x))))
                else:
                    x = x + self.drop_path(self.dcn(self.norm1(x)))
                    x = x + self.drop_path(self.mlp(self.norm2(x)))
                return x
            if self.post_norm:
                x = x + self.drop_path(self.gamma1 * self.norm1(self.dcn(x)))
                x = x + self.drop_path(self.gamma2 * self.norm2(self.mlp(x)))
            else:
                x = x + self.drop_path(self.gamma1 * self.dcn(self.norm1(x)))
                x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
            return x

        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)
        return x
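

# Shape sketch for the two modules above (not part of the original code;
# NHWC layout throughout, and the DCNv3 core op from `ops_dcnv3` is assumed
# to be compiled and importable):
#
#   down = DownsampleLayer(channels=64)     # 3x3 stride-2 conv + LN
#   y = down(torch.randn(2, 56, 56, 64))    # -> (2, 28, 28, 128)
#
#   layer = InternImageLayer(
#       core_op=getattr(opsm, 'DCNv3'), channels=64, groups=4)
#   z = layer(torch.randn(2, 56, 56, 64))   # residual DCNv3 + FFN, same shape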


class InternImageBlock(nn.Module):
    """Block of InternImage.

    Args:
        core_op (nn.Module): core operation of InternImage
        channels (int): number of input channels
        depth (int): number of InternImageLayers in the block
        groups (int): number of groups used by the DCNv3 core op
        downsample (bool): whether to append a DownsampleLayer
        mlp_ratio (float): ratio of mlp hidden features to input channels
        drop (float): dropout rate
        drop_path (float): drop path rate
        act_cfg (dict): activation layer
        norm_cfg (dict): normalization layer
        post_norm (bool): whether to use post normalization
        layer_scale (float, optional): init value of layer scale
        offset_scale (float): offset scale of DCNv3
        with_cp (bool): whether to use checkpoint
        post_norm_block_ids (list, optional): indexes of the blocks that get
            an extra post norm
        res_post_norm (bool): whether to add a norm after each residual
            branch
    """

    def __init__(
        self,
        core_op,
        channels,
        depth,
        groups,
        downsample=True,
        mlp_ratio=4.,
        drop=0.,
        drop_path=0.,
        act_cfg=dict(type='GELU'),
        norm_cfg=dict(type='LN'),
        post_norm=False,
        offset_scale=1.0,
        layer_scale=None,
        with_cp=False,
        dw_kernel_size=None,
        post_norm_block_ids=None,
        res_post_norm=False,
        center_feature_scale=False,
        remove_center=False,
    ):
        super().__init__()
        self.channels = channels
        self.depth = depth
        self.post_norm = post_norm
        self.center_feature_scale = center_feature_scale

        self.blocks = nn.ModuleList([
            InternImageLayer(
                core_op=core_op,
                channels=channels,
                groups=groups,
                mlp_ratio=mlp_ratio,
                drop=drop,
                drop_path=drop_path[i]
                if isinstance(drop_path, list) else drop_path,
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                post_norm=post_norm,
                layer_scale=layer_scale,
                offset_scale=offset_scale,
                with_cp=with_cp,
                dw_kernel_size=dw_kernel_size,
                res_post_norm=res_post_norm,
                center_feature_scale=center_feature_scale,
                remove_center=remove_center,
            ) for i in range(depth)
        ])
        if not self.post_norm or center_feature_scale:
            self.norm = build_norm_layer(channels, 'LN')
        self.post_norm_block_ids = post_norm_block_ids
        if post_norm_block_ids is not None:
            self.post_norms = nn.ModuleList([
                build_norm_layer(channels, 'LN', eps=1e-6)
                for _ in post_norm_block_ids
            ])
        self.downsample = DownsampleLayer(
            channels=channels,
            norm_layer=norm_cfg['type']) if downsample else None

    def forward(self, x, return_wo_downsample=False):
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (self.post_norm_block_ids is not None) and (
                    i in self.post_norm_block_ids):
                index = self.post_norm_block_ids.index(i)
                x = self.post_norms[index](x)
        if not self.post_norm or self.center_feature_scale:
            x = self.norm(x)
        if return_wo_downsample:
            x_ = x
        if self.downsample is not None:
            x = self.downsample(x)

        if return_wo_downsample:
            return x, x_
        return x
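

# Stage layout sketch for the backbone below (not part of the original
# code): channels double after every stage via DownsampleLayer, so
# stem_channels=64 with four stages gives per-stage widths
# (64, 128, 256, 512) and num_features = 512. Per-block drop-path rates
# rise linearly across all sum(stage_blocks) blocks unless
# drop_path_type='uniform', e.g. with stage_blocks=[3, 4, 18, 5]:
#
#   dpr = [x.item() for x in torch.linspace(0, 0.2, 30)]
#   # stage i uses dpr[sum(stage_blocks[:i]):sum(stage_blocks[:i + 1])]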


@MODELS.register_module()
class InternImage(BaseBackbone):
    """InternImage backbone.

    A PyTorch implementation of: `InternImage: Exploring Large-Scale Vision
    Foundation Models with Deformable Convolutions
    <https://arxiv.org/abs/2211.05778>`_

    Args:
        stem_channels (int): Number of channels in the first stage.
            Default: 64
        stage_blocks (list): Number of blocks in each stage.
            Default: [3, 4, 18, 5]
        groups (list): Number of DCNv3 groups in each stage.
            Default: [3, 6, 12, 24]
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2
        drop_path_type (str): Decay rule of the drop path rates, 'linear'
            or 'uniform'. Default: 'linear'
        act_cfg (dict): Activation layer. Default: dict(type='GELU')
        norm_cfg (dict): Normalization layer. Default: dict(type='LN')
        layer_scale (float, optional): Init value of layer scale.
            Default: None
        offset_scale (float): Offset scale of DCNv3. Default: 1.0
        post_norm (bool): Whether to use post normalization. Default: False
        cls_scale (float): Channel expansion ratio of the conv head.
            Default: 1.5
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False
        dw_kernel_size (int, optional): Size of the dwconv. Default: None
        use_clip_projector (bool): Whether to use clip projector.
            Default: False
        level2_post_norm (bool): Whether to use level2 post norm.
            Default: False
        level2_post_norm_block_ids (list, optional): Indexes of post norm
            blocks. Default: None
        res_post_norm (bool): Whether to use res post norm. Default: False
        center_feature_scale (bool): Whether to use center feature scale.
            Default: False
        remove_center (bool): Whether to remove the center sampling point
            of the DCNv3 kernel. Default: False
        init_cfg (dict, optional): Initialization config dict.
            Default: None
    """  # noqa: E501

    def __init__(self,
                 stem_channels=64,
                 stage_blocks=[3, 4, 18, 5],
                 groups=[3, 6, 12, 24],
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.2,
                 drop_path_type='linear',
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 layer_scale=None,
                 offset_scale=1.0,
                 post_norm=False,
                 cls_scale=1.5,
                 with_cp=False,
                 dw_kernel_size=None,
                 use_clip_projector=False,
                 level2_post_norm=False,
                 level2_post_norm_block_ids=None,
                 res_post_norm=False,
                 center_feature_scale=False,
                 remove_center=False,
                 init_cfg=None):
        super(InternImage, self).__init__(init_cfg)
        self.core_op = 'DCNv3'
        self.num_stages = len(stage_blocks)
        self.num_features = int(stem_channels * 2**(self.num_stages - 1))
        self.post_norm = post_norm
        self.mlp_ratio = mlp_ratio
        self.use_clip_projector = use_clip_projector
        self.level2_post_norm_block_ids = level2_post_norm_block_ids
        self.remove_center = remove_center
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg

        # stem layer
        self._make_stem_layer(in_channels=3, stem_channels=stem_channels)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        total_depth = sum(stage_blocks)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]
        if drop_path_type == 'uniform':
            for i in range(len(dpr)):
                dpr[i] = drop_path_rate

        # InternImage Layers
        self.layers = nn.ModuleList()
        for i in range(self.num_stages):
            if level2_post_norm and i == 2:
                post_norm_block_ids = level2_post_norm_block_ids
            else:
                post_norm_block_ids = None
            layer = InternImageBlock(
                core_op=getattr(opsm, self.core_op),
                channels=int(stem_channels * 2**i),
                depth=stage_blocks[i],
                groups=groups[i],
                mlp_ratio=self.mlp_ratio,
                drop=drop_rate,
                drop_path=dpr[sum(stage_blocks[:i]):sum(stage_blocks[:i +
                                                                     1])],
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                post_norm=post_norm,
                downsample=(i < self.num_stages - 1),
                layer_scale=layer_scale,
                offset_scale=offset_scale,
                with_cp=with_cp,
                dw_kernel_size=dw_kernel_size,
                post_norm_block_ids=post_norm_block_ids,
                res_post_norm=res_post_norm,
                center_feature_scale=center_feature_scale,
                remove_center=remove_center,
            )
            self.layers.append(layer)

        # Conv Head
        if not use_clip_projector:
            self.conv_head = nn.Sequential(
                nn.Conv2d(
                    self.num_features,
                    int(self.num_features * cls_scale),
                    kernel_size=1,
                    bias=False),
                build_norm_layer(
                    int(self.num_features * cls_scale), 'BN',
                    'channels_first', 'channels_first'),
                build_activation_layer(act_cfg))
        else:
            pretrain_embed_dim, _stride, attnpool_num_heads, clip_embed_dim \
                = 1024, 2, 16, 768
            self.dcnv3_head_x4 = nn.Sequential(
                nn.Conv2d(
                    in_channels=self.num_features,
                    out_channels=pretrain_embed_dim * (_stride**2),
                    kernel_size=1),
                nn.PixelShuffle(_stride))
            self.dcnv3_head_x3 = nn.Conv2d(
                in_channels=self.num_features // 2,
                out_channels=pretrain_embed_dim,
                kernel_size=1)
            self.clip_projector = AttentionPoolingBlock(
                dim=pretrain_embed_dim,
                num_heads=attnpool_num_heads,
                qkv_bias=True,
                qk_scale=None,
                drop=0.,
                attn_drop=0.,
                norm_cfg=norm_cfg,
                out_dim=clip_embed_dim)
            norm_layer = norm_cfg['type']
            self.fc_norm = build_norm_layer(
                clip_embed_dim, norm_layer, eps=1e-6)

    def init_weights(self):
        super(InternImage, self).init_weights()

        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
            elif isinstance(m, getattr(opsm, self.core_op)):
                m._reset_parameters()
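
    # Stem sketch (illustrative, not from the original code): the two
    # stride-2 convs below reduce resolution by 4x, so a (B, 3, 224, 224)
    # image becomes a (B, 56, 56, stem_channels) NHWC tensor before stage 0.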
    def _make_stem_layer(self, in_channels, stem_channels):
        norm_layer = self.norm_cfg['type']
        self.patch_embed = nn.Sequential(
            nn.Conv2d(
                in_channels,
                stem_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1),
            build_norm_layer(stem_channels // 2, norm_layer,
                             'channels_first', 'channels_first'),
            build_activation_layer(self.act_cfg),
            nn.Conv2d(
                stem_channels // 2,
                stem_channels,
                kernel_size=3,
                stride=2,
                padding=1),
            build_norm_layer(stem_channels, norm_layer, 'channels_first',
                             'channels_last'),
        )

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.conv_head(x.permute(0, 3, 1, 2))
        return (x, )

    def forward_features_seq_out(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x)

        seq_out = []
        for layer in self.layers:
            x, x_ = layer(x, return_wo_downsample=True)
            seq_out.append(x_)
        return seq_out

    def forward_clip_projector(self, x):
        # for InternImage-H/G
        xs = self.forward_features_seq_out(x)
        x1, x2, x3, x4 = xs

        x1 = x1.permute(0, 3, 1, 2)  # NHWC -> NCHW
        x2 = x2.permute(0, 3, 1, 2)  # NHWC -> NCHW
        x3 = x3.permute(0, 3, 1, 2)  # NHWC -> NCHW
        x4 = x4.permute(0, 3, 1, 2)  # NHWC -> NCHW

        # upsample the 1/32 feature to 1/16 and fuse it with the 1/16 one
        x4 = self.dcnv3_head_x4(x4)
        x = x4
        x3 = self.dcnv3_head_x3(x3)
        x = x + x3

        # (B, C, H, W) -> (B, H*W, C), then attention-pool to a single token
        x = x.flatten(-2).transpose(1, 2).contiguous()
        x = self.clip_projector(x)
        x = self.fc_norm(x)
        return (x, )

    def forward(self, x):
        if not self.use_clip_projector:
            # for InternImage-T/S/B/L/XL
            return self.forward_features(x)
        else:
            # for InternImage-H/G
            return self.forward_clip_projector(x)

    @staticmethod
    def _checkpoint_filter(state_dict, prefix, local_metadata, strict,
                           missing_keys, unexpected_keys, error_msgs):

        def internimage_to_mmpretrain():
            for k, v in state_dict['model'].items():
                if 'head.' in k and 'conv_head' not in k:
                    # classifier head, e.g. 'head.weight' -> 'head.fc.weight'
                    if 'weight' in k:
                        new_k = 'head.fc.weight'
                    else:
                        new_k = 'head.fc.bias'
                elif 'patch_embed' in k:
                    map_fun = {
                        'conv1': '0',
                        'norm1': '1',
                        'conv2': '3',
                        'norm2': '4'
                    }
                    new_k = k
                    for old, new in map_fun.items():
                        new_k = new_k.replace(old, new)
                    new_k = 'backbone.' + new_k
                elif 'levels' in k:
                    new_k = k.replace('levels', 'layers')
                    if 'mlp' in new_k:
                        new_k = new_k.replace('fc1', 'layers.0.0')
                        new_k = new_k.replace('fc2', 'layers.1')
                    new_k = 'backbone.' + new_k
                elif 'clip_projector.cross_dcn.k_bias' in k:
                    # this implementation has no k_bias, skip it
                    continue
                else:
                    new_k = 'backbone.' + k
                state_dict[new_k] = state_dict['model'][k]
            del state_dict['model']

        # The original weights need to be converted to mmpretrain format.
        # Some modules in the original weights start with 'levels',
        # and in this implementation they are replaced with 'layers'.
        if 'model' in state_dict and 'levels.0.blocks.0.norm1.0.weight' \
                in state_dict['model']:
            internimage_to_mmpretrain()
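

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original file. It assumes
    # the compiled DCNv3 extension from `ops_dcnv3` is importable and that a
    # CUDA device is available (the reference op has no CPU kernel). An
    # explicit config in the style of InternImage-T is used because DCNv3
    # requires each stage's channel count to be divisible by its group
    # count, which the defaults above do not satisfy.
    model = InternImage(
        stem_channels=64,
        stage_blocks=[4, 4, 18, 4],
        groups=[4, 8, 16, 32]).cuda().eval()
    dummy = torch.randn(1, 3, 224, 224).cuda()
    with torch.no_grad():
        feats = model(dummy)
    # forward_features returns a 1-tuple with the conv head output in NCHW.
    print([f.shape for f in feats])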