# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter

from ..layers import *
from ..losses.loss_utils import one_hot
from ..model_utils import weights_init_kaiming

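
# QAMHead pools the backbone feature map into a single vector, passes it through a
# bias-free BN neck, and (during training) scores it against L2-normalized class
# weights with a quadratic angular margin; at inference it simply returns the
# bn-necked embedding.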
class QAMHead(nn.Module):
    def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
        super().__init__()
        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES

        self.pool_layer = nn.Sequential(
            pool_layer,
            Flatten()
        )
        # bnneck
        self.bnneck = NoBiasBatchNorm1d(in_feat)
        self.bnneck.apply(weights_init_kaiming)

        # classifier
        # self.adaptive_s = False
        self._s = 6.0
        self._m = 0.50

        self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, features, targets=None):
        global_feat = self.pool_layer(features)
        bn_feat = self.bnneck(global_feat)
        if not self.training:
            return bn_feat

        # get cos(theta)
        cosine = F.linear(F.normalize(bn_feat), F.normalize(self.weight))

        # add margin
        theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7))  # for numerical stability

        # --------------------------- convert label to one-hot ---------------------------
        targets = one_hot(targets, self._num_classes)

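        # Quadratic angular margin: theta lies in [0, pi], so (2*pi - theta) is
        # positive and shrinks as the angle grows; squaring it makes the logit
        # largest when the feature aligns with its class weight. Adding the
        # margin m to theta for the target class only lowers its logit, which
        # pushes the network to pull features closer to their class weight.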
        phi = (2 * np.pi - (theta + self._m)) ** 2
        others = (2 * np.pi - theta) ** 2

        pred_class_logits = targets * phi + (1.0 - targets) * others

        # logits re-scale
        pred_class_logits *= self._s

        return pred_class_logits, global_feat
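

# A minimal usage sketch, assuming a fastreid-style config with
# `cfg.MODEL.HEADS.NUM_CLASSES` set and a backbone that emits (N, in_feat, H, W)
# feature maps; `in_feat=2048` is only an illustrative value:
#
#   head = QAMHead(cfg, in_feat=2048)
#
#   head.train()
#   logits, global_feat = head(features, targets)  # scaled QAM logits, pooled features
#
#   head.eval()
#   embedding = head(features)                     # bn-necked embedding for retrieval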