# Mirror of https://github.com/JDAI-CV/fast-reid.git (154 lines, 6.0 KiB, Python)
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
import copy
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from fastreid.modeling.backbones import ResNet, Bottleneck
|
|
from fastreid.modeling.model_utils import *
|
|
|
|
|
|
class MGN(nn.Module):
    """Multi-branch re-ID network (MGN) with one global and two part-based branches.

    A shared ResNet trunk (stem, ``layer1``, ``layer2`` and the first block of
    ``layer3``) feeds three branches, each with its own copy of the remaining
    ``layer3`` blocks:

    * ``p1``: followed by the backbone's ``layer4``; global pooling only.
    * ``p2``: followed by a rebuilt ``layer4`` whose 1x1 downsample conv has no
      stride; global pooling plus 2 horizontal stripes.
    * ``p3``: same tail as ``p2``; global pooling plus 3 horizontal stripes.

    Every pooled 2048-d vector is projected to ``feats`` dims by one shared
    1x1 conv (``reduction``) and normalized by one shared neck (``bn_neck``).
    """

    in_planes = 2048
    feats = 256

    # Identity classifiers, in creation order: 3 for the global embeddings of
    # p1/p2/p3, 2 for the p2 stripes and 3 for the p3 stripes. The attribute
    # names are state_dict keys, so they must not change.
    _fc_names = (
        'fc_id_2048_0', 'fc_id_2048_1', 'fc_id_2048_2',
        'fc_id_256_1_0', 'fc_id_256_1_1',
        'fc_id_256_2_0', 'fc_id_256_2_1', 'fc_id_256_2_2',
    )

    def __init__(self,
                 backbone,
                 num_classes,
                 last_stride,
                 with_ibn,
                 gcb,
                 stage_with_gcb,
                 pretrain=True,
                 model_path=''):
        """Build the network.

        Args:
            backbone: ResNet variant name passed to ``ResNet.from_name``.
            num_classes: number of outputs of each ``fc_id_*`` classifier.
            last_stride, with_ibn, gcb, stage_with_gcb: forwarded unchanged to
                ``ResNet.from_name``.
            pretrain: if true, call ``load_pretrain(model_path)`` on the backbone.
            model_path: checkpoint path handed to ``load_pretrain``.

        Raises:
            ValueError: if ``ResNet.from_name`` fails for ``backbone``.
        """
        super().__init__()
        try:
            base_module = ResNet.from_name(backbone, last_stride, with_ibn, gcb, stage_with_gcb)
        except Exception as err:
            # Previously this only printed and carried on, which then crashed
            # with an UnboundLocalError on `base_module`; surface the real cause.
            raise ValueError(f'not support {backbone} backbone') from err

        if pretrain:
            base_module.load_pretrain(model_path)

        self.num_classes = num_classes

        # Shared trunk: stem, layer1, layer2 and the first block of layer3.
        self.backbone = nn.Sequential(
            base_module.conv1,
            base_module.bn1,
            base_module.relu,
            base_module.maxpool,
            base_module.layer1,
            base_module.layer2,
            base_module.layer3[0],
        )

        # Remaining layer3 blocks; every branch receives its own deep copy.
        res_conv4 = nn.Sequential(*base_module.layer3[1:])

        # Global-branch tail: the backbone's own layer4.
        res_g_conv5 = base_module.layer4

        # Part-branch tail: layer4 rebuilt with a stride-free 1x1 downsample
        # conv, initialised from the backbone's layer4 weights.
        res_p_conv5 = nn.Sequential(
            Bottleneck(1024, 512, downsample=nn.Sequential(
                nn.Conv2d(1024, 2048, 1, bias=False),
                nn.BatchNorm2d(2048))),
            Bottleneck(2048, 512),
            Bottleneck(2048, 512),
        )
        res_p_conv5.load_state_dict(base_module.layer4.state_dict())

        self.p1 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_g_conv5))
        self.p2 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))
        self.p3 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Horizontal stripe pooling: 2 stripes for p2, 3 for p3. On a 24x9 map
        # these match the former MaxPool2d((12, 9)) / MaxPool2d((8, 9)) exactly,
        # and they still yield the right number of stripes at other resolutions.
        self.maxpool_zp2 = nn.AdaptiveMaxPool2d((2, 1))
        self.maxpool_zp3 = nn.AdaptiveMaxPool2d((3, 1))

        # One reduction conv and one neck are shared by all eight embeddings.
        self.reduction = nn.Conv2d(2048, self.feats, 1, bias=False)
        self.bn_neck = BN_no_bias(self.feats)

        for name in self._fc_names:
            setattr(self, name, nn.Linear(self.feats, self.num_classes, bias=False))
        # Initialise only after every head exists, preserving the original
        # order of random-number draws.
        for name in self._fc_names:
            getattr(self, name).apply(weights_init_classifier)

    def _embed(self, pooled):
        """Map a pooled ``(bs, 2048, 1, 1)`` tensor to a ``(bs, feats)`` embedding.

        Applies the shared ``reduction`` conv, flattens, then the shared neck.
        """
        return self.bn_neck(self.reduction(pooled).flatten(1))

    def forward(self, x, label=None):
        """Compute MGN features.

        Args:
            x: input batch for the backbone (presumably ``(bs, 3, H, W)``
                images -- confirm against the data pipeline).
            label: unused; kept for interface compatibility.

        Returns:
            In training mode, an 11-tuple: the three raw 2048-d global features
            ``(g_p1, g_p2, g_p3)`` followed by the eight classifier logits
            ``(l_p1, l_p2, l_p3, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3)``.
            In eval mode, the eight neck-normalised embeddings concatenated
            along dim 1, i.e. a ``(bs, 8 * feats)`` tensor.
        """
        global_feat = self.backbone(x)

        p1 = self.p1(global_feat)  # (bs, 2048, h, w); h, w depend on input size
        p2 = self.p2(global_feat)
        p3 = self.p3(global_feat)

        # Global average pooling: one 2048-d vector per branch.
        zg_p1 = self.avgpool(p1)  # (bs, 2048, 1, 1)
        zg_p2 = self.avgpool(p2)
        zg_p3 = self.avgpool(p3)

        # Stripe max pooling.
        zp2 = self.maxpool_zp2(p2)  # (bs, 2048, 2, 1)
        zp3 = self.maxpool_zp3(p3)  # (bs, 2048, 3, 1)

        g_p1 = zg_p1.flatten(1)  # (bs, 2048)
        g_p2 = zg_p2.flatten(1)
        g_p3 = zg_p3.flatten(1)

        # Reduced + normalised embeddings, (bs, feats) each. Kept in the
        # original call order: if `bn_neck` tracks running statistics (as a
        # BatchNorm would), each training-mode call updates them.
        bn_fg_p1 = self._embed(zg_p1)
        bn_fg_p2 = self._embed(zg_p2)
        bn_fg_p3 = self._embed(zg_p3)
        f0_p2 = self._embed(zp2[:, :, 0:1, :])
        f1_p2 = self._embed(zp2[:, :, 1:2, :])
        f0_p3 = self._embed(zp3[:, :, 0:1, :])
        f1_p3 = self._embed(zp3[:, :, 1:2, :])
        f2_p3 = self._embed(zp3[:, :, 2:3, :])

        if not self.training:
            return torch.cat([bn_fg_p1, bn_fg_p2, bn_fg_p3,
                              f0_p2, f1_p2, f0_p3, f1_p3, f2_p3], dim=1)

        l_p1 = self.fc_id_2048_0(bn_fg_p1)
        l_p2 = self.fc_id_2048_1(bn_fg_p2)
        l_p3 = self.fc_id_2048_2(bn_fg_p3)

        l0_p2 = self.fc_id_256_1_0(f0_p2)
        l1_p2 = self.fc_id_256_1_1(f1_p2)
        l0_p3 = self.fc_id_256_2_0(f0_p3)
        l1_p3 = self.fc_id_256_2_1(f1_p3)
        l2_p3 = self.fc_id_256_2_2(f2_p3)

        return g_p1, g_p2, g_p3, l_p1, l_p2, l_p3, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3

    def load_params_wo_fc(self, state_dict):
        """Load pretrained weights while skipping the ``fc_id_*`` classifiers.

        Classifier entries are removed from a shallow copy of ``state_dict``
        before loading, so a checkpoint trained with a different number of
        classes can be used. A size mismatch would otherwise make
        ``load_state_dict`` fail even with ``strict=False``. The caller's
        mapping is not modified.

        Args:
            state_dict: mapping of parameter/buffer names to tensors.

        Raises:
            RuntimeError: if any non-classifier parameter/buffer of the model
                is absent from ``state_dict``.
        """
        fc_names = set(self._fc_names)

        def is_fc_key(key):
            return key.split('.', 1)[0] in fc_names

        # copy.copy keeps any `_metadata` attribute that load_state_dict uses.
        filtered = copy.copy(state_dict)
        for key in [k for k in filtered if is_fc_key(k)]:
            del filtered[key]

        res = self.load_state_dict(filtered, strict=False)
        unexpected_missing = [k for k in res.missing_keys if not is_fc_key(k)]
        if unexpected_missing:
            raise RuntimeError(
                f'issue loading pretrained weights: missing {unexpected_missing}')
|