fast-reid/fastreid/modeling/meta_arch/mgn.py

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import copy
import torch
from torch import nn
from fastreid.modeling.backbones import ResNet, Bottleneck
from fastreid.modeling.model_utils import BN_no_bias, weights_init_classifier


class MGN(nn.Module):
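    """Multiple Granularity Network (MGN, Wang et al. 2018): a global branch
    (p1) plus two part branches that split their feature maps into two (p2)
    and three (p3) horizontal stripes, with one identity classifier per
    global feature and per stripe.
    """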
    in_planes = 2048
    feats = 256

    def __init__(self,
                 backbone,
                 num_classes,
                 last_stride,
                 with_ibn,
                 gcb,
                 stage_with_gcb,
                 pretrain=True,
                 model_path=''):
        super().__init__()
        try:
            base_module = ResNet.from_name(backbone, last_stride, with_ibn, gcb, stage_with_gcb)
        except Exception as e:
            raise KeyError(f'unsupported backbone: {backbone}') from e
        if pretrain:
            base_module.load_pretrain(model_path)
        self.num_classes = num_classes
        # Shared trunk: stem, layer1, layer2 and the first block of layer3.
        self.backbone = nn.Sequential(
            base_module.conv1,
            base_module.bn1,
            base_module.relu,
            base_module.maxpool,
            base_module.layer1,
            base_module.layer2,
            base_module.layer3[0]
        )
        # Remaining layer3 blocks, deep-copied into every branch below.
        res_conv4 = nn.Sequential(*base_module.layer3[1:])
        # The global branch keeps the backbone's layer4; the part branches use
        # a stride-1 copy so the feature map stays tall enough to split.
        res_g_conv5 = base_module.layer4
        res_p_conv5 = nn.Sequential(
            Bottleneck(1024, 512, downsample=nn.Sequential(nn.Conv2d(1024, 2048, 1, bias=False),
                                                           nn.BatchNorm2d(2048))),
            Bottleneck(2048, 512),
            Bottleneck(2048, 512)
        )
        res_p_conv5.load_state_dict(base_module.layer4.state_dict())
        self.p1 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_g_conv5))
        self.p2 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))
        self.p3 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Stripe pooling: two horizontal parts from p2, three from p3
        # (kernel sizes assume a 24x9 part-branch feature map).
        self.maxpool_zp2 = nn.MaxPool2d((12, 9))
        self.maxpool_zp3 = nn.MaxPool2d((8, 9))
        # A shared 1x1 conv reduces each 2048-d pooled feature to 256-d,
        # followed by a shared bias-free BN neck.
        self.reduction = nn.Conv2d(2048, self.feats, 1, bias=False)
        self.bn_neck = BN_no_bias(self.feats)
        # One identity classifier per global feature and per part stripe.
        self.fc_id_2048_0 = nn.Linear(self.feats, self.num_classes, bias=False)
        self.fc_id_2048_1 = nn.Linear(self.feats, self.num_classes, bias=False)
        self.fc_id_2048_2 = nn.Linear(self.feats, self.num_classes, bias=False)
        self.fc_id_256_1_0 = nn.Linear(self.feats, self.num_classes, bias=False)
        self.fc_id_256_1_1 = nn.Linear(self.feats, self.num_classes, bias=False)
        self.fc_id_256_2_0 = nn.Linear(self.feats, self.num_classes, bias=False)
        self.fc_id_256_2_1 = nn.Linear(self.feats, self.num_classes, bias=False)
        self.fc_id_256_2_2 = nn.Linear(self.feats, self.num_classes, bias=False)

        self.fc_id_2048_0.apply(weights_init_classifier)
        self.fc_id_2048_1.apply(weights_init_classifier)
        self.fc_id_2048_2.apply(weights_init_classifier)
        self.fc_id_256_1_0.apply(weights_init_classifier)
        self.fc_id_256_1_1.apply(weights_init_classifier)
        self.fc_id_256_2_0.apply(weights_init_classifier)
        self.fc_id_256_2_1.apply(weights_init_classifier)
        self.fc_id_256_2_2.apply(weights_init_classifier)
    def forward(self, x, label=None):
        global_feat = self.backbone(x)

        # Branch feature maps; the pooling below assumes a 24x9 map,
        # e.g. a 384x144 input with last_stride=1.
        p1 = self.p1(global_feat)  # (bs, 2048, 24, 9)
        p2 = self.p2(global_feat)  # (bs, 2048, 24, 9)
        p3 = self.p3(global_feat)  # (bs, 2048, 24, 9)

        zg_p1 = self.avgpool(p1)  # (bs, 2048, 1, 1)
        zg_p2 = self.avgpool(p2)  # (bs, 2048, 1, 1)
        zg_p3 = self.avgpool(p3)  # (bs, 2048, 1, 1)

        zp2 = self.maxpool_zp2(p2)  # (bs, 2048, 2, 1)
        z0_p2 = zp2[:, :, 0:1, :]
        z1_p2 = zp2[:, :, 1:2, :]
        zp3 = self.maxpool_zp3(p3)  # (bs, 2048, 3, 1)
        z0_p3 = zp3[:, :, 0:1, :]
        z1_p3 = zp3[:, :, 1:2, :]
        z2_p3 = zp3[:, :, 2:3, :]
        # Global features: raw 2048-d (for triplet loss) and reduced,
        # BN-normalized 256-d (for ID loss and inference).
        g_p1 = zg_p1.squeeze(3).squeeze(2)  # (bs, 2048)
        fg_p1 = self.reduction(zg_p1).squeeze(3).squeeze(2)  # (bs, 256)
        bn_fg_p1 = self.bn_neck(fg_p1)
        g_p2 = zg_p2.squeeze(3).squeeze(2)
        fg_p2 = self.reduction(zg_p2).squeeze(3).squeeze(2)
        bn_fg_p2 = self.bn_neck(fg_p2)
        g_p3 = zg_p3.squeeze(3).squeeze(2)
        fg_p3 = self.reduction(zg_p3).squeeze(3).squeeze(2)
        bn_fg_p3 = self.bn_neck(fg_p3)

        # Part features: reduced, BN-normalized 256-d per stripe.
        f0_p2 = self.bn_neck(self.reduction(z0_p2).squeeze(3).squeeze(2))
        f1_p2 = self.bn_neck(self.reduction(z1_p2).squeeze(3).squeeze(2))
        f0_p3 = self.bn_neck(self.reduction(z0_p3).squeeze(3).squeeze(2))
        f1_p3 = self.bn_neck(self.reduction(z1_p3).squeeze(3).squeeze(2))
        f2_p3 = self.bn_neck(self.reduction(z2_p3).squeeze(3).squeeze(2))
        if self.training:
            # ID logits for the cross-entropy losses; the raw global features
            # g_p1..g_p3 feed the triplet loss.
            l_p1 = self.fc_id_2048_0(bn_fg_p1)
            l_p2 = self.fc_id_2048_1(bn_fg_p2)
            l_p3 = self.fc_id_2048_2(bn_fg_p3)
            l0_p2 = self.fc_id_256_1_0(f0_p2)
            l1_p2 = self.fc_id_256_1_1(f1_p2)
            l0_p3 = self.fc_id_256_2_0(f0_p3)
            l1_p3 = self.fc_id_256_2_1(f1_p3)
            l2_p3 = self.fc_id_256_2_2(f2_p3)
            return g_p1, g_p2, g_p3, l_p1, l_p2, l_p3, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3
        else:
            # Inference: concatenate all eight 256-d features -> (bs, 2048).
            return torch.cat([bn_fg_p1, bn_fg_p2, bn_fg_p3, f0_p2, f1_p2, f0_p3, f1_p3, f2_p3], dim=1)
    def load_params_wo_fc(self, state_dict):
        # Load pretrained weights while skipping the identity classifiers,
        # whose shapes depend on num_classes.
        res = self.load_state_dict(state_dict, strict=False)
        assert all(key.startswith('fc_id') for key in res.missing_keys), \
            'issue loading pretrained weights'
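

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original file. Assumed values:
    # a 'resnet50' backbone name, Market-1501's 751 identities, gcb=None with
    # no GCB stages, and a 384x144 input so the (12, 9) / (8, 9) stripe-pooling
    # kernels tile the 24x9 part-branch feature map exactly.
    model = MGN(backbone='resnet50', num_classes=751, last_stride=1,
                with_ibn=False, gcb=None, stage_with_gcb=[False] * 4,
                pretrain=False)
    dummy = torch.randn(2, 3, 384, 144)

    model.train()
    outputs = model(dummy)
    print(len(outputs))  # 11: three 2048-d global features + eight ID logits

    model.eval()
    with torch.no_grad():
        feat = model(dummy)
    print(feat.shape)  # torch.Size([2, 2048]): eight 256-d features concatenated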