# Copyright (c) Alibaba, Inc. and its affiliates.
import math
from functools import partial, reduce

import torch
import torch.nn as nn
from mmcv.cnn import kaiming_init, normal_init
from packaging import version

from easycv.models.utils import GeMPooling, ResLayer
from ..backbones.hrnet import Bottleneck, get_expansion
from ..registry import NECKS
from ..utils import (ConvModule, _init_weights, build_conv_layer,
                     build_norm_layer)


@NECKS.register_module
class LinearNeck(nn.Module):
    '''Linear neck: a single fc layer with optional global average pooling
       and optional l2 normalization of the output feature.
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_avg_pool=True,
                 with_norm=False):
        super(LinearNeck, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_channels, out_channels)
        self.with_norm = with_norm

    def init_weights(self, init_linear='normal'):
        _init_weights(self, init_linear)

    def forward(self, x):
        assert len(x) == 1 or len(x) == 2  # to fit vit model
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        x = self.fc(x.view(x.size(0), -1))
        if self.with_norm:
            x = nn.functional.normalize(x, p=2, dim=1)
        return [x]


@NECKS.register_module
class RetrivalNeck(nn.Module):
    '''RetrivalNeck: Combination of Multiple Global Descriptors for Image Retrieval,
       https://arxiv.org/pdf/1903.10663.pdf

       CGD feature : configurable combination of GeM pooling, max pooling and
           mean pooling; each branch runs pool -> fc -> l2-norm, and the
           branches are concatenated.
       Avg feature : BNNeck-style branch, avg pool -> bn -> fc.

       if len(cdg_config) > 0: forward returns [CGD, Avg]
       if len(cdg_config) == 0: forward returns [Avg]
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_avg_pool=True,
                 cdg_config=['G', 'M']):
        """ Init RetrivalNeck.

        Args:
            in_channels: int, input feature map channels
            out_channels: int, output feature channels
            with_avg_pool: bool, whether to avg pool the input for the BNNeck branch
            cdg_config: list containing any of 'G', 'M', 'S', configuring the CGD
                branches ('G' = GeM pooling, 'M' = max pooling, 'S' = mean pooling);
                if len(cdg_config) > 0, forward returns [CGD, Avg],
                if len(cdg_config) == 0, forward returns [Avg]
        """
        super(RetrivalNeck, self).__init__()
        self.with_avg_pool = with_avg_pool
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Linear(in_channels, out_channels, bias=False)
        self.dropout = nn.Dropout(p=0.3)
        _, self.bn_output = build_norm_layer(dict(type='BN'), in_channels)
        # alternatively: build_norm_layer(dict(type='SyncBN'), in_channels)

        self.cdg_config = cdg_config
        cgd_length = len(cdg_config)
        if cgd_length > 0:
            assert out_channels % cgd_length == 0

            if 'M' in cdg_config:
                self.mpool = nn.AdaptiveMaxPool2d((1, 1))
                self.fc_mx = nn.Linear(
                    in_channels, out_channels // cgd_length, bias=False)

            if 'S' in cdg_config:
                self.spool = nn.AdaptiveAvgPool2d((1, 1))
                self.fc_sx = nn.Linear(
                    in_channels, out_channels // cgd_length, bias=False)

            if 'G' in cdg_config:
                self.gpool = GeMPooling()
                self.fc_gx = nn.Linear(
                    in_channels, out_channels // cgd_length, bias=False)

    def init_weights(self, init_linear='normal'):
        _init_weights(self, init_linear)

    def forward(self, x):
        assert len(x) == 1 or len(x) == 2  # to fit vit model
        x = x[0]

        # BNNeck branch with avg pool
        if self.with_avg_pool:
            ax = self.avgpool(x)
        else:
            ax = x
        cls_x = self.bn_output(ax)
        cls_x = self.fc(cls_x.view(x.size(0), -1))
        cls_x = self.dropout(cls_x)

        if len(self.cdg_config) > 0:
            concat_list = []
            if 'S' in self.cdg_config:
                sx = self.spool(x).view(x.size(0), -1)
                sx = self.fc_sx(sx)
                sx = nn.functional.normalize(sx, p=2, dim=1)
                concat_list.append(sx)

            if 'G' in self.cdg_config:
                gx = self.gpool(x).view(x.size(0), -1)
                gx = self.fc_gx(gx)
                gx = nn.functional.normalize(gx, p=2, dim=1)
                concat_list.append(gx)

            if 'M' in self.cdg_config:
                mx = self.mpool(x).view(x.size(0), -1)
                mx = self.fc_mx(mx)
                mx = nn.functional.normalize(mx, p=2, dim=1)
                concat_list.append(mx)

            concatx = torch.cat(concat_list, dim=1)
            concatx = concatx.view(concatx.size(0), -1)
            # concatx = nn.functional.normalize(concatx, p=2, dim=1)

            return [concatx, cls_x]
        else:
            return [cls_x]
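# A minimal usage sketch for RetrivalNeck, assuming a ResNet-50 style backbone
# output (2048 channels at 7x7). The shapes, batch size and the helper name
# below are illustrative assumptions, not part of the registered API; the
# function is never called at import time.
def _demo_retrival_neck():
    neck = RetrivalNeck(in_channels=2048, out_channels=512, cdg_config=['G', 'M'])
    neck.init_weights()
    feats = [torch.randn(4, 2048, 7, 7)]  # backbones return a list of feature maps
    cgd_feat, avg_feat = neck(feats)
    # CGD concatenates the l2-normalized GeM and max descriptors (256 + 256),
    # while the second output is the BNNeck-style classification feature.
    assert cgd_feat.shape == (4, 512) and avg_feat.shape == (4, 512)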
@NECKS.register_module
class FaceIDNeck(nn.Module):
    '''FaceID neck: bn -> dropout -> flatten -> fc -> bn.
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 map_shape=1,
                 dropout_ratio=0.4,
                 with_norm=False,
                 bn_type='SyncBN'):
        """ Init FaceIDNeck. The neck does not pool the input feature map, so it
        does not support dynamic input sizes.

        Args:
            in_channels: int, input feature map channels
            out_channels: int, output feature channels
            map_shape: int or list(int, ...), spatial shape of the input feature map,
                (w, h) as a list or a single int w when w == h
            dropout_ratio: float, dropout ratio
            with_norm: bool, whether to l2-normalize the output feature
            bn_type: 'SyncBN' or 'BN'
        """
        super(FaceIDNeck, self).__init__()

        if version.parse(torch.__version__) < version.parse('1.4.0'):
            self.expand_for_syncbn = True
        else:
            self.expand_for_syncbn = False

        # self.bn_input = nn.BatchNorm2d(in_channels)
        _, self.bn_input = build_norm_layer(dict(type=bn_type), in_channels)
        self.dropout = nn.Dropout(p=dropout_ratio)
        if isinstance(map_shape, list):
            in_ = int(reduce(lambda x, y: x * y, map_shape) * in_channels)
        else:
            assert isinstance(map_shape, int)
            in_ = in_channels * map_shape * map_shape

        self.fc = nn.Linear(in_, out_channels)
        self.with_norm = with_norm
        self.syncbn = bn_type == 'SyncBN'

        if self.syncbn:
            _, self.bn_output = build_norm_layer(
                dict(type=bn_type), out_channels)
        else:
            self.bn_output = nn.BatchNorm1d(out_channels)

    def _forward_syncbn(self, module, x):
        assert x.dim() == 2
        if self.expand_for_syncbn:
            # old SyncBN only accepts 4D input, so expand and squeeze back
            x = module(x.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)
        else:
            x = module(x)
        return x

    def init_weights(self, init_linear='normal'):
        _init_weights(self, init_linear)

    def forward(self, x):
        assert len(x) == 1 or len(x) == 2  # to fit vit model
        x = x[0]
        x = self.bn_input(x)
        x = self.dropout(x)
        x = self.fc(x.view(x.size(0), -1))
        # always route through _forward_syncbn; it is a plain call to the
        # output bn when no 4D expansion is required
        x = self._forward_syncbn(self.bn_output, x)
        if self.with_norm:
            x = nn.functional.normalize(x, p=2, dim=1)
        return [x]


@NECKS.register_module
class MultiLinearNeck(nn.Module):
    '''MultiLinearNeck: one or more stacked fc layers.
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_layers=1,
                 with_avg_pool=True):
        """
        Args:
            in_channels: int, or list[int] with one entry per fc layer
            out_channels: int, or list[int] with one entry per fc layer
            num_layers: int, total number of fc layers
            with_avg_pool: bool, the input is avg-pooled first if True

        Raises:
            AssertionError: if num_layers > 1 and len(in_channels) != len(out_channels)
                or len(in_channels) != num_layers
        """
        super(MultiLinearNeck, self).__init__()
        self.with_avg_pool = with_avg_pool
        self.num_layers = num_layers
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        if num_layers == 1:
            self.fc = nn.Linear(in_channels, out_channels)
        else:
            assert len(in_channels) == len(out_channels)
            assert len(in_channels) == num_layers
            self.fc = nn.ModuleList(
                [nn.Linear(i, j) for i, j in zip(in_channels, out_channels)])

    def init_weights(self, init_linear='normal'):
        _init_weights(self, init_linear)

    def forward(self, x):
        assert len(x) == 1 or len(x) == 2  # to fit vit model
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        if self.num_layers == 1:
            x = self.fc(x)
        else:
            # nn.ModuleList is not callable, apply the stacked fc layers in order
            for fc in self.fc:
                x = fc(x)
        return [x]
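# A minimal usage sketch for FaceIDNeck and MultiLinearNeck. bn_type='BN' is
# used so the example does not require a distributed (SyncBN) setup; the
# channel counts, spatial sizes and the helper name are illustrative
# assumptions. The two-layer case relies on the sequential fc application
# implemented in MultiLinearNeck.forward above. Never called at import time.
def _demo_faceid_and_multilinear_necks():
    face_neck = FaceIDNeck(
        in_channels=512, out_channels=256, map_shape=7, bn_type='BN')
    face_neck.init_weights()
    emb = face_neck([torch.randn(4, 512, 7, 7)])[0]
    assert emb.shape == (4, 256)

    multi_neck = MultiLinearNeck(
        in_channels=[2048, 1024], out_channels=[1024, 512], num_layers=2)
    multi_neck.init_weights()
    out = multi_neck([torch.randn(4, 2048, 7, 7)])[0]
    assert out.shape == (4, 512)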
@NECKS.register_module()
class HRFuseScales(nn.Module):
    """Fuse feature maps of multiple scales in HRNet.

    Args:
        in_channels (list[int]): The input channels of all scales.
        out_channels (int): The channels of the fused feature map.
            Defaults to 2048.
        norm_cfg (dict): Dictionary to construct norm layers.
            Defaults to ``dict(type='BN', momentum=0.1)``.
    """

    def __init__(self,
                 in_channels,
                 out_channels=2048,
                 norm_cfg=dict(type='BN', momentum=0.1)):
        super(HRFuseScales, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg

        block_type = Bottleneck
        out_channels = [128, 256, 512, 1024]

        # Increase the channels on each resolution
        # from C, 2C, 4C, 8C to 128, 256, 512, 1024.
        increase_layers = []
        for i in range(len(in_channels)):
            increase_layers.append(
                ResLayer(
                    block_type,
                    in_channels=in_channels[i],
                    out_channels=out_channels[i],
                    num_blocks=1,
                    stride=1,
                ))
        self.increase_layers = nn.ModuleList(increase_layers)

        # Downsample feature maps in each scale.
        downsample_layers = []
        for i in range(len(in_channels) - 1):
            downsample_layers.append(
                ConvModule(
                    in_channels=out_channels[i],
                    out_channels=out_channels[i + 1],
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    bias=False,
                ))
        self.downsample_layers = nn.ModuleList(downsample_layers)

        # The final conv block before the classifier linear layer.
        self.final_layer = ConvModule(
            in_channels=out_channels[3],
            out_channels=self.out_channels,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            bias=False,
        )

    def init_weights(self, init_linear='normal'):
        _init_weights(self, init_linear)

    def forward(self, x):
        assert len(x) == len(self.in_channels)

        feat = self.increase_layers[0](x[0])
        for i in range(len(self.downsample_layers)):
            feat = self.downsample_layers[i](feat) + \
                self.increase_layers[i + 1](x[i + 1])

        return [self.final_layer(feat)]
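# A minimal usage sketch for HRFuseScales, assuming HRNet-W18 style multi-scale
# outputs (channels 18/36/72/144 at strides 4/8/16/32 for a 224x224 input).
# The channel and spatial numbers, and the helper name, are illustrative
# assumptions; the function is never called at import time.
def _demo_hr_fuse_scales():
    neck = HRFuseScales(in_channels=[18, 36, 72, 144], out_channels=2048)
    neck.init_weights()
    feats = [
        torch.randn(2, 18, 56, 56),
        torch.randn(2, 36, 28, 28),
        torch.randn(2, 72, 14, 14),
        torch.randn(2, 144, 7, 7),
    ]
    # each scale is widened, progressively downsampled and summed into the
    # lowest-resolution map, then projected to out_channels by a 1x1 conv
    fused = neck(feats)[0]
    assert fused.shape == (2, 2048, 7, 7)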