# fast-reid/fastreid/modeling/meta_arch/mask_model.py

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
import torch.nn.functional as F
from fastreid.modeling.backbones import *
from fastreid.modeling.model_utils import *
from fastreid.modeling.heads import *
from fastreid.layers import bn_no_bias, GeM


class MaskUnit(nn.Module):
    """Attention pooling over the spatial locations of the backbone feature map."""

    def __init__(self, in_planes=2048):
        super().__init__()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=4, stride=2)
        self.mask = nn.Linear(in_planes, 1, bias=False)

    def forward(self, x):
        x1 = self.maxpool1(x)
        x2 = self.maxpool2(x)
        xx = x.view(x.size(0), x.size(1), -1)     # (bs, 2048, 192)
        x1 = x1.view(x1.size(0), x1.size(1), -1)  # (bs, 2048, 48)
        x2 = x2.view(x2.size(0), x2.size(1), -1)  # (bs, 2048, 33)
        feat = torch.cat((xx, x1, x2), dim=2)     # (bs, 2048, 273)
        feat = feat.transpose(1, 2)               # (bs, 273, 2048)
        mask_scores = self.mask(feat)             # (bs, 273, 1)
        # L1-normalize the scores of the 192 original locations into pooling weights.
        scores = F.normalize(mask_scores[:, :192], p=1, dim=1)  # (bs, 192, 1)
        mask_feat = torch.bmm(xx, scores)         # (bs, 2048, 1)
        return mask_feat.squeeze(2), mask_scores.squeeze(2)
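

def _maskunit_shape_demo():
    # Illustrative sketch, not part of the original file: traces the tensor shapes
    # through MaskUnit for a 24x8 ResNet feature map (a 256x128 input with
    # last_stride=1), which is the case the shape comments above assume.
    unit = MaskUnit(in_planes=2048)
    x = torch.randn(2, 2048, 24, 8)
    mask_feat, mask_scores = unit(x)
    assert mask_feat.shape == (2, 2048)    # attention-pooled global feature
    assert mask_scores.shape == (2, 273)   # scores for the 192 + 48 + 33 positions
    return mask_feat, mask_scores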


class Maskmodel(nn.Module):
    def __init__(self,
                 backbone,
                 num_classes,
                 last_stride,
                 with_ibn=False,
                 with_se=False,
                 gcb=None,
                 stage_with_gcb=[False, False, False, False],
                 pretrain=True,
                 model_path=''):
        super().__init__()
        if 'resnet' in backbone:
            self.base = ResNet.from_name(backbone, pretrain, last_stride, with_ibn, with_se, gcb,
                                         stage_with_gcb, model_path=model_path)
            self.in_planes = 2048
        elif 'osnet' in backbone:
            if with_ibn:
                self.base = osnet_ibn_x1_0(pretrained=pretrain)
            else:
                self.base = osnet_x1_0(pretrained=pretrain)
            self.in_planes = 512
        else:
            raise ValueError(f'Unsupported backbone: {backbone}')

        self.num_classes = num_classes
        # self.gap = GeM()
        self.gap = nn.AdaptiveAvgPool2d(1)
        # self.res_part = Bottleneck(2048, 512)

        # Global branch: reduction -> BNNeck -> classifier.
        self.global_reduction = nn.Sequential(
            nn.Linear(self.in_planes, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.1)
        )
        self.global_bnneck = bn_no_bias(1024)
        self.global_bnneck.apply(weights_init_kaiming)
        self.global_fc = nn.Linear(1024, self.num_classes, bias=False)
        self.global_fc.apply(weights_init_classifier)

        # Mask branch: attention pooling -> reduction -> BNNeck -> classifier.
        self.mask_layer = MaskUnit(self.in_planes)
        self.mask_reduction = nn.Sequential(
            nn.Linear(self.in_planes, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.1)
        )
        self.mask_bnneck = bn_no_bias(1024)
        self.mask_bnneck.apply(weights_init_kaiming)
        self.mask_fc = nn.Linear(1024, self.num_classes, bias=False)
        self.mask_fc.apply(weights_init_classifier)

    def forward(self, x, label=None, pose=None):
        global_feat = self.base(x)                       # (bs, 2048, 24, 8)
        pool_feat = self.gap(global_feat)                # (bs, 2048, 1, 1)
        pool_feat = pool_feat.view(-1, self.in_planes)   # (bs, 2048)
        re_feat = self.global_reduction(pool_feat)       # (bs, 1024)
        bn_re_feat = self.global_bnneck(re_feat)         # normalize for angular softmax
        # global_feat = global_feat.view(global_feat.size(0), global_feat.size(1), -1)
        # pose = pose.unsqueeze(2)
        # pose_feat = torch.bmm(global_feat, pose).squeeze(2)  # (bs, 2048)
        # fused_feat = pool_feat + pose_feat
        # bn_feat = self.bottleneck(fused_feat)
        # mask_feat = self.res_part(global_feat)
        mask_feat, mask_scores = self.mask_layer(global_feat)
        mask_re_feat = self.mask_reduction(mask_feat)
        bn_mask_feat = self.mask_bnneck(mask_re_feat)
        if self.training:
            cls_out = self.global_fc(bn_re_feat)
            mask_cls_out = self.mask_fc(bn_mask_feat)
            # am_out = self.amsoftmax(feat, label)
            return cls_out, mask_cls_out, pool_feat, mask_feat, mask_scores
        else:
            return torch.cat((bn_re_feat, bn_mask_feat), dim=1), bn_mask_feat

    def getLoss(self, outputs, labels, mask_labels, **kwargs):
        cls_out, mask_cls_out, feat, mask_feat, mask_scores = outputs
        # cls_out, feat = outputs
        tri_loss = (TripletLoss(margin=-1)(feat, labels, normalize_feature=False)[0] +
                    TripletLoss(margin=-1)(mask_feat, labels, normalize_feature=False)[0]) / 2
        # mask_feat_tri_loss = TripletLoss(margin=-1)(mask_feat, labels, normalize_feature=False)[0]
        softmax_loss = (F.cross_entropy(cls_out, labels) + F.cross_entropy(mask_cls_out, labels)) / 2
        mask_loss = F.mse_loss(mask_scores, mask_labels) * 0.16
        self.loss = softmax_loss + tri_loss + mask_loss
        return {
            'softmax': softmax_loss,
            'tri': tri_loss,
            'mask': mask_loss,
        }

    def load_params_wo_fc(self, state_dict):
        state_dict.pop('global_fc.weight', None)
        state_dict.pop('mask_fc.weight', None)
        if 'classifier.weight' in state_dict:
            state_dict.pop('classifier.weight')
        if 'amsoftmax.weight' in state_dict:
            state_dict.pop('amsoftmax.weight')
        res = self.load_state_dict(state_dict, strict=False)
        print(f'missing keys {res.missing_keys}')
        print(f'unexpected keys {res.unexpected_keys}')
        # assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'

    def unfreeze_all_layers(self):
        self.train()
        for p in self.parameters():
            p.requires_grad_()

    def unfreeze_specific_layer(self, names):
        if isinstance(names, str):
            names = [names]
        for name, module in self.named_children():
            if name in names:
                module.train()
                for p in module.parameters():
                    p.requires_grad_()
            else:
                module.eval()
                for p in module.parameters():
                    p.requires_grad_(False)
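

if __name__ == '__main__':
    # Illustrative smoke test, not part of the original file. It assumes a
    # 256x128 input (so the backbone yields a 24x8 map and MaskUnit emits 273
    # scores), 751 identities as in Market-1501, that ResNet.from_name accepts
    # pretrain=False so no pretrained weights are needed, and that TripletLoss
    # is provided by the wildcard imports above.
    model = Maskmodel('resnet50', num_classes=751, last_stride=1, pretrain=False)
    model.train()
    images = torch.randn(4, 3, 256, 128)
    labels = torch.randint(0, 751, (4,))
    mask_labels = torch.rand(4, 273)  # supervision target for the 273 mask scores
    outputs = model(images, labels)
    loss_dict = model.getLoss(outputs, labels, mask_labels)
    print({k: v.item() for k, v in loss_dict.items()})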