# encoding: utf-8

import logging
import math

import torch
from torch import nn

from fastreid.modeling.backbones import BACKBONE_REGISTRY
from fastreid.modeling.backbones.resnet import Bottleneck, model_zoo, model_urls
from .non_local_layer import Non_local


class ResNetNL(nn.Module):
    """ResNet-style backbone with ``Non_local`` modules inserted into its stages.

    Each of the four stages is a sequence of ``block`` modules. For stage
    ``k``, ``non_layers[k]`` ``Non_local`` modules are created and applied
    after the last ``non_layers[k]`` blocks of that stage. ``forward``
    returns the output of the fourth stage; there is no pooling or
    classifier head in this module.

    Args:
        last_stride: stride of the first block of the fourth stage.
        with_ibn: forwarded to the blocks of the first three stages
            (``_make_layer`` always disables it when ``planes == 512``).
        block: block class used to build every stage; must expose an
            ``expansion`` attribute.
        layers: number of blocks in each of the four stages.
        non_layers: number of ``Non_local`` modules in each of the four stages.
    """

    def __init__(self, last_stride, with_ibn, block=Bottleneck,
                 layers=(3, 4, 6, 3), non_layers=(0, 2, 3, 0)):
        super().__init__()
        # Running input-channel count; updated by every _make_layer call.
        self.inplanes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], with_ibn=with_ibn)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       with_ibn=with_ibn)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       with_ibn=with_ibn)
        self.layer4 = self._make_layer(block, 512, layers[3],
                                       stride=last_stride)

        # NOTE(review): the Non_local channel counts below are hard-coded;
        # presumably they equal planes * block.expansion for the default
        # Bottleneck block -- confirm before using a different block type.
        self.NL_1 = nn.ModuleList(
            [Non_local(256) for _ in range(non_layers[0])])
        self.NL_1_idx = self._nl_indices(layers[0], non_layers[0])
        self.NL_2 = nn.ModuleList(
            [Non_local(512) for _ in range(non_layers[1])])
        self.NL_2_idx = self._nl_indices(layers[1], non_layers[1])
        self.NL_3 = nn.ModuleList(
            [Non_local(1024) for _ in range(non_layers[2])])
        self.NL_3_idx = self._nl_indices(layers[2], non_layers[2])
        self.NL_4 = nn.ModuleList(
            [Non_local(2048) for _ in range(non_layers[3])])
        self.NL_4_idx = self._nl_indices(layers[3], non_layers[3])

    @staticmethod
    def _nl_indices(num_blocks, num_nl):
        """Return the ascending indices of the last ``num_nl`` blocks of a stage."""
        return list(range(num_blocks - num_nl, num_blocks))

    def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False):
        """Build one stage made of ``blocks`` instances of ``block``.

        Only the first block receives ``stride`` and the optional
        ``downsample`` projection. ``self.inplanes`` is updated to
        ``planes * block.expansion`` so the next stage sees the right
        input width.

        Returns:
            nn.Sequential: the stacked blocks of the stage.
        """
        out_channels = planes * block.expansion

        downsample = None
        # A 1x1 projection is needed on the shortcut whenever the first
        # block changes the spatial resolution or the channel count.
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        # IBN is never used in the 512-plane stage.
        if planes == 512:
            with_ibn = False

        layers = [block(self.inplanes, planes, with_ibn, stride=stride,
                        downsample=downsample)]
        self.inplanes = out_channels
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, with_ibn))

        return nn.Sequential(*layers)

    @staticmethod
    def _forward_stage(x, stage, nl_modules, nl_idx):
        """Run ``x`` through ``stage``, applying Non_local modules at ``nl_idx``.

        ``nl_idx`` is an ascending list of block indices; after the block at
        ``nl_idx[j]`` has run, ``nl_modules[j]`` is applied to its output.
        An empty ``nl_idx`` means the stage runs without any Non_local module.
        """
        nl_counter = 0
        for i, blk in enumerate(stage):
            x = blk(x)
            # The bounds check replaces the former "[-1]" sentinel so that
            # forward() no longer has to mutate module attributes.
            if nl_counter < len(nl_idx) and i == nl_idx[nl_counter]:
                x = nl_modules[nl_counter](x)
                nl_counter += 1
        return x

    def forward(self, x):
        """Compute the backbone output for input ``x``.

        The stem (conv, batch norm, ReLU, max-pool) is followed by the four
        stages, each interleaved with its Non_local modules.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self._forward_stage(x, self.layer1, self.NL_1, self.NL_1_idx)
        x = self._forward_stage(x, self.layer2, self.NL_2, self.NL_2_idx)
        x = self._forward_stage(x, self.layer3, self.NL_3, self.NL_3_idx)
        x = self._forward_stage(x, self.layer4, self.NL_4, self.NL_4_idx)

        return x

    def load_param(self, model_path):
        """Copy parameters from the checkpoint at ``model_path`` into this model.

        Entries whose name contains ``'fc'`` are skipped. Every other entry
        must exist in this model's ``state_dict``; a missing name raises
        ``KeyError``.
        """
        # Load onto CPU so checkpoints saved from a GPU can be read on
        # machines without one; copy_ moves data to the parameter's device.
        param_dict = torch.load(model_path, map_location='cpu')
        # state_dict() tensors share storage with the live parameters, so
        # one snapshot is enough and in-place copies update the model.
        own_state = self.state_dict()
        for name, value in param_dict.items():
            if 'fc' in name:
                continue
            own_state[name].copy_(value)

    def random_init(self):
        """Re-initialise every Conv2d and BatchNorm2d module in the network.

        Conv weights are drawn from N(0, sqrt(2 / n)) with
        ``n = kernel_h * kernel_w * out_channels``; batch-norm weights are
        set to 1 and biases to 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


@BACKBONE_REGISTRY.register()
def build_resnetNL_backbone(cfg):
    """
    Create a ResNet Non-local instance from config.

    Reads ``PRETRAIN``, ``PRETRAIN_PATH``, ``LAST_STRIDE``, ``WITH_IBN``,
    ``WITH_SE`` and ``DEPTH`` from ``cfg.MODEL.BACKBONE``. ``DEPTH`` must be
    one of 50, 101 or 152; any other value raises ``KeyError``.

    When ``PRETRAIN`` is set, weights are loaded non-strictly: from
    ``model_urls[depth]`` if ``WITH_IBN`` is false, otherwise from the
    checkpoint at ``PRETRAIN_PATH``. The missing and unexpected keys are
    logged.

    Returns:
        ResNetNL: a :class:`ResNetNL` instance.
    """
    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
    with_se       = cfg.MODEL.BACKBONE.WITH_SE  # read but not used by this builder
    depth         = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    num_blocks_per_stage = {50: [3, 4, 6, 3],
                            101: [3, 4, 23, 3],
                            152: [3, 8, 36, 3]}[depth]
    nl_layers_per_stage = [0, 2, 3, 0]
    model = ResNetNL(last_stride, with_ibn, Bottleneck,
                     num_blocks_per_stage, nl_layers_per_stage)

    if pretrain:
        if not with_ibn:
            # original resnet
            state_dict = model_zoo.load_url(model_urls[depth])
        else:
            # ibn resnet
            state_dict = torch.load(pretrain_path,
                                    map_location='cpu')['state_dict']
            # Drop the first dotted component of every key and keep only
            # entries that exist in this model with a matching shape. The
            # membership test matters: keys this model does not define
            # would otherwise raise KeyError here.
            model_state = model.state_dict()
            new_state_dict = {}
            for k, v in state_dict.items():
                new_k = '.'.join(k.split('.')[1:])
                if new_k in model_state and model_state[new_k].shape == v.shape:
                    new_state_dict[new_k] = v
            state_dict = new_state_dict

        res = model.load_state_dict(state_dict, strict=False)
        logger = logging.getLogger('fastreid.'+__name__)
        logger.info('missing keys is {}'.format(res.missing_keys))
        logger.info('unexpected keys is {}'.format(res.unexpected_keys))

    return model