# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
import torch.nn.functional as F

from fastreid.modeling.model_utils import *
from fastreid.modeling.layers import NoBiasBatchNorm1d


class MaskUnit(nn.Module):
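    """Spatial attention unit over the backbone feature map.

    The feature map is flattened together with two max-pooled copies, and a
    shared linear layer scores every spatial location. The scores of the
    original-resolution locations are L1-normalized and used as attention
    weights to aggregate a single masked feature vector. Both the aggregated
    feature and the raw per-location scores are returned; the raw scores can
    be supervised with an external mask target (see Maskmodel.getLoss). The
    slicing of the first 192 scores assumes a 24x8 backbone feature map, as
    noted in the shape comments below.
    """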
    def __init__(self, in_planes=2048):
        super().__init__()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=4, stride=2)

        self.mask = nn.Linear(in_planes, 1, bias=False)

    def forward(self, x):
        x1 = self.maxpool1(x)
        x2 = self.maxpool2(x)
        xx = x.view(x.size(0), x.size(1), -1)  # (bs, 2048, 192)
        x1 = x1.view(x1.size(0), x1.size(1), -1)  # (bs, 2048, 48)
        x2 = x2.view(x2.size(0), x2.size(1), -1)  # (bs, 2048, 33)
        feat = torch.cat((xx, x1, x2), dim=2)  # (bs, 2048, 273)
        feat = feat.transpose(1, 2)  # (bs, 273, 2048)
        mask_scores = self.mask(feat)  # (bs, 273, 1)
        scores = F.normalize(mask_scores[:, :192], p=1, dim=1)  # (bs, 192, 1)
        mask_feat = torch.bmm(xx, scores)  # (bs, 2048, 1)
        return mask_feat.squeeze(2), mask_scores.squeeze(2)


class Maskmodel(nn.Module):
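    """Two-branch re-identification model.

    A backbone (ResNet or OSNet) produces a feature map that feeds two
    branches: a global branch (average pooling -> linear reduction to 1024 ->
    BNNeck -> classifier) and a mask branch (MaskUnit attention pooling ->
    linear reduction -> BNNeck -> classifier). During training, forward()
    returns (cls_out, mask_cls_out, pool_feat, mask_feat, mask_scores), which
    getLoss() consumes; at inference it returns the concatenated BNNeck
    features of the two branches together with the mask-branch feature.
    """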
    def __init__(self,
                 backbone,
                 num_classes,
                 last_stride,
                 with_ibn=False,
                 with_se=False,
                 gcb=None,
                 stage_with_gcb=[False, False, False, False],
                 pretrain=True,
                 model_path=''):
        super().__init__()
        if 'resnet' in backbone:
            self.base = ResNet.from_name(backbone, pretrain, last_stride, with_ibn, with_se, gcb,
                                         stage_with_gcb, model_path=model_path)
            self.in_planes = 2048
        elif 'osnet' in backbone:
            if with_ibn:
                self.base = osnet_ibn_x1_0(pretrained=pretrain)
            else:
                self.base = osnet_x1_0(pretrained=pretrain)
            self.in_planes = 512
        else:
            raise ValueError(f'unsupported backbone {backbone}')
        self.num_classes = num_classes
        # self.gap = GeM()
        self.gap = nn.AdaptiveAvgPool2d(1)
        # self.res_part = Bottleneck(2048, 512)

        self.global_reduction = nn.Sequential(
            nn.Linear(self.in_planes, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.1)
        )
        self.global_bnneck = NoBiasBatchNorm1d(1024)
        self.global_bnneck.apply(weights_init_kaiming)
        self.global_fc = nn.Linear(1024, self.num_classes, bias=False)
        self.global_fc.apply(weights_init_classifier)

        self.mask_layer = MaskUnit(self.in_planes)
        self.mask_reduction = nn.Sequential(
            nn.Linear(self.in_planes, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.1)
        )
        self.mask_bnneck = NoBiasBatchNorm1d(1024)
        self.mask_bnneck.apply(weights_init_kaiming)

        self.mask_fc = nn.Linear(1024, self.num_classes, bias=False)
        self.mask_fc.apply(weights_init_classifier)
    def forward(self, x, label=None, pose=None):
        global_feat = self.base(x)  # (bs, 2048, 24, 8)
        pool_feat = self.gap(global_feat)  # (bs, 2048, 1, 1)
        pool_feat = pool_feat.view(-1, self.in_planes)  # (bs, 2048)
        re_feat = self.global_reduction(pool_feat)  # (bs, 1024)
        bn_re_feat = self.global_bnneck(re_feat)  # normalize for angular softmax

        # global_feat = global_feat.view(global_feat.size(0), global_feat.size(1), -1)
        # pose = pose.unsqueeze(2)
        # pose_feat = torch.bmm(global_feat, pose).squeeze(2)  # (bs, 2048)
        # fused_feat = pool_feat + pose_feat
        # bn_feat = self.bottleneck(fused_feat)
        # mask_feat = self.res_part(global_feat)
        mask_feat, mask_scores = self.mask_layer(global_feat)
        mask_re_feat = self.mask_reduction(mask_feat)
        bn_mask_feat = self.mask_bnneck(mask_re_feat)
        if self.training:
            cls_out = self.global_fc(bn_re_feat)
            mask_cls_out = self.mask_fc(bn_mask_feat)
            # am_out = self.amsoftmax(feat, label)
            return cls_out, mask_cls_out, pool_feat, mask_feat, mask_scores
        else:
            return torch.cat((bn_re_feat, bn_mask_feat), dim=1), bn_mask_feat
    def getLoss(self, outputs, labels, mask_labels, **kwargs):
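        """Compute the training losses from the forward() outputs.

        Returns a dict with the averaged cross-entropy over the two
        classifiers ('softmax'), the averaged triplet loss over the pooled and
        masked features ('tri'), and an MSE loss between the predicted mask
        scores and the given mask targets, scaled by 0.16 ('mask').
        """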
        cls_out, mask_cls_out, feat, mask_feat, mask_scores = outputs
        # cls_out, feat = outputs

        tri_loss = (TripletLoss(margin=-1)(feat, labels, normalize_feature=False)[0] +
                    TripletLoss(margin=-1)(mask_feat, labels, normalize_feature=False)[0]) / 2
        # mask_feat_tri_loss = TripletLoss(margin=-1)(mask_feat, labels, normalize_feature=False)[0]
        softmax_loss = (F.cross_entropy(cls_out, labels) + F.cross_entropy(mask_cls_out, labels)) / 2
        mask_loss = F.mse_loss(mask_scores, mask_labels) * 0.16

        self.loss = softmax_loss + tri_loss + mask_loss
        return {
            'softmax': softmax_loss,
            'tri': tri_loss,
            'mask': mask_loss,
        }
    def load_params_wo_fc(self, state_dict):
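        """Load a state dict while dropping the classifier weights.

        The classifier heads depend on the number of identities in the
        training set, so their weights are removed before calling
        load_state_dict with strict=False; the remaining missing/unexpected
        keys are printed for inspection.
        """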
        state_dict.pop('global_fc.weight', None)
        state_dict.pop('mask_fc.weight', None)
        state_dict.pop('classifier.weight', None)
        state_dict.pop('amsoftmax.weight', None)
        res = self.load_state_dict(state_dict, strict=False)
        print(f'missing keys {res.missing_keys}')
        print(f'unexpected keys {res.unexpected_keys}')
        # assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'
    def unfreeze_all_layers(self):
        self.train()
        for p in self.parameters():
            p.requires_grad_()
    def unfreeze_specific_layer(self, names):
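        """Set the named top-level child modules to train mode with gradients
        enabled, and freeze (eval mode, no gradients) every other child."""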
        if isinstance(names, str):
            names = [names]

        for name, module in self.named_children():
            if name in names:
                module.train()
                for p in module.parameters():
                    p.requires_grad_()
            else:
                module.eval()
                for p in module.parameters():
                    p.requires_grad_(False)
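

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original pipeline). The values
    # below (resnet50 backbone, num_classes=751, 384x128 input, last_stride=1)
    # are assumptions chosen to match the shape comments above, not settings
    # taken from a repo config; adjust them to your own setup.
    model = Maskmodel(backbone='resnet50', num_classes=751, last_stride=1,
                      pretrain=False)
    model.train()
    images = torch.randn(4, 3, 384, 128)  # dummy batch
    labels = torch.tensor([0, 0, 1, 1])   # two identities, two images each
    mask_labels = torch.rand(4, 273)      # dummy per-location mask targets
    outputs = model(images, labels)
    losses = model.getLoss(outputs, labels, mask_labels)
    print({k: v.item() for k, v in losses.items()})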