""" Author: Guan'an Wang Contact: guan.wang0706@gmail.com """ import torch from torch import nn from collections import OrderedDict import logging from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message from fastreid.layers import get_norm from fastreid.modeling.backbones import BACKBONE_REGISTRY logger = logging.getLogger(__name__) class ShuffleV2Block(nn.Module): """ Reference: https://github.com/megvii-model/ShuffleNet-Series/tree/master/ShuffleNetV2 """ def __init__(self, bn_norm, inp, oup, mid_channels, *, ksize, stride): super(ShuffleV2Block, self).__init__() self.stride = stride assert stride in [1, 2] self.mid_channels = mid_channels self.ksize = ksize pad = ksize // 2 self.pad = pad self.inp = inp outputs = oup - inp branch_main = [ # pw nn.Conv2d(inp, mid_channels, 1, 1, 0, bias=False), get_norm(bn_norm, mid_channels), nn.ReLU(inplace=True), # dw nn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=False), get_norm(bn_norm, mid_channels), # pw-linear nn.Conv2d(mid_channels, outputs, 1, 1, 0, bias=False), get_norm(bn_norm, outputs), nn.ReLU(inplace=True), ] self.branch_main = nn.Sequential(*branch_main) if stride == 2: branch_proj = [ # dw nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=False), get_norm(bn_norm, inp), # pw-linear nn.Conv2d(inp, inp, 1, 1, 0, bias=False), get_norm(bn_norm, inp), nn.ReLU(inplace=True), ] self.branch_proj = nn.Sequential(*branch_proj) else: self.branch_proj = None def forward(self, old_x): if self.stride == 1: x_proj, x = self.channel_shuffle(old_x) return torch.cat((x_proj, self.branch_main(x)), 1) elif self.stride == 2: x_proj = old_x x = old_x return torch.cat((self.branch_proj(x_proj), self.branch_main(x)), 1) def channel_shuffle(self, x): batchsize, num_channels, height, width = x.data.size() assert (num_channels % 4 == 0) x = x.reshape(batchsize * num_channels // 2, 2, height * width) x = x.permute(1, 0, 2) x = x.reshape(2, -1, num_channels // 2, height, width) return x[0], x[1] class ShuffleNetV2(nn.Module): """ Reference: https://github.com/megvii-model/ShuffleNet-Series/tree/master/ShuffleNetV2 """ def __init__(self, bn_norm, model_size='1.5x'): super(ShuffleNetV2, self).__init__() self.stage_repeats = [4, 8, 4] self.model_size = model_size if model_size == '0.5x': self.stage_out_channels = [-1, 24, 48, 96, 192, 1024] elif model_size == '1.0x': self.stage_out_channels = [-1, 24, 116, 232, 464, 1024] elif model_size == '1.5x': self.stage_out_channels = [-1, 24, 176, 352, 704, 1024] elif model_size == '2.0x': self.stage_out_channels = [-1, 24, 244, 488, 976, 2048] else: raise NotImplementedError # building first layer input_channel = self.stage_out_channels[1] self.first_conv = nn.Sequential( nn.Conv2d(3, input_channel, 3, 2, 1, bias=False), get_norm(bn_norm, input_channel), nn.ReLU(inplace=True), ) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.features = [] for idxstage in range(len(self.stage_repeats)): numrepeat = self.stage_repeats[idxstage] output_channel = self.stage_out_channels[idxstage + 2] for i in range(numrepeat): if i == 0: self.features.append(ShuffleV2Block(bn_norm, input_channel, output_channel, mid_channels=output_channel // 2, ksize=3, stride=2)) else: self.features.append(ShuffleV2Block(bn_norm, input_channel // 2, output_channel, mid_channels=output_channel // 2, ksize=3, stride=1)) input_channel = output_channel self.features = nn.Sequential(*self.features) self.conv_last = nn.Sequential( nn.Conv2d(input_channel, self.stage_out_channels[-1], 1, 1, 0, bias=False), get_norm(bn_norm, self.stage_out_channels[-1]), nn.ReLU(inplace=True) ) self._initialize_weights() def forward(self, x): x = self.first_conv(x) x = self.maxpool(x) x = self.features(x) x = self.conv_last(x) return x def _initialize_weights(self): for name, m in self.named_modules(): if isinstance(m, nn.Conv2d): if 'first' in name: nn.init.normal_(m.weight, 0, 0.01) else: nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1]) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) if m.bias is not None: nn.init.constant_(m.bias, 0.0001) nn.init.constant_(m.running_mean, 0) elif isinstance(m, nn.BatchNorm1d): nn.init.constant_(m.weight, 1) if m.bias is not None: nn.init.constant_(m.bias, 0.0001) nn.init.constant_(m.running_mean, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) if m.bias is not None: nn.init.constant_(m.bias, 0) @BACKBONE_REGISTRY.register() def build_shufflenetv2_backbone(cfg): # fmt: off pretrain = cfg.MODEL.BACKBONE.PRETRAIN pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH bn_norm = cfg.MODEL.BACKBONE.NORM model_size = cfg.MODEL.BACKBONE.DEPTH # fmt: on model = ShuffleNetV2(bn_norm, model_size=model_size) if pretrain: new_state_dict = OrderedDict() state_dict = torch.load(pretrain_path)["state_dict"] for k, v in state_dict.items(): if k[:7] == 'module.': k = k[7:] new_state_dict[k] = v incompatible = model.load_state_dict(new_state_dict, strict=False) if incompatible.missing_keys: logger.info( get_missing_parameters_message(incompatible.missing_keys) ) if incompatible.unexpected_keys: logger.info( get_unexpected_parameters_message(incompatible.unexpected_keys) ) return model