fast-reid/fastreid/modeling/layers/qam.py

68 lines
1.9 KiB
Python

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from ..layers import *
from ..losses.loss_utils import one_hot
from ..model_utils import weights_init_kaiming
class QAMHead(nn.Module):
def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
super().__init__()
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.pool_layer = nn.Sequential(
pool_layer,
Flatten()
)
# bnneck
self.bnneck = NoBiasBatchNorm1d(in_feat)
self.bnneck.apply(weights_init_kaiming)
# classifier
# self.adaptive_s = False
self._s = 6.0
self._m = 0.50
self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, features, targets=None):
global_feat = self.pool_layer(features)
bn_feat = self.bnneck(global_feat)
if not self.training:
return bn_feat
# get cos(theta)
cosine = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
# add margin
theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7)) # for numerical stability
# --------------------------- convert label to one-hot ---------------------------
targets = one_hot(targets, self._num_classes)
phi = (2 * np.pi - (theta + self._m)) ** 2
others = (2 * np.pi - theta) ** 2
pred_class_logits = targets * phi + (1.0 - targets) * others
# logits re-scale
pred_class_logits *= self._s
return pred_class_logits, global_feat