mirror of https://github.com/JDAI-CV/fast-reid.git
55 lines
1.9 KiB
Python
55 lines
1.9 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import math
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.nn import Parameter
|
|
|
|
|
|
class ArcSoftmax(nn.Module):
    """Angular-margin classifier with curriculum-style hard-negative reweighting.

    Logits are cosines between L2-normalized features and L2-normalized class
    weights, multiplied by a scale ``s``:

    * The target-class cosine ``cos(theta)`` is replaced by ``cos(theta + m)``.
      Once ``theta + m`` would exceed pi, a linear fallback
      ``cos(theta) - sin(pi - m) * m`` is used instead.
    * Every cosine larger than that margin-adjusted target cosine is modulated
      as ``cos * (t + cos)``. Here ``t`` is an exponential moving average of
      the mean target cosine. This is the CurricularFace modulation.

    Args:
        cfg: config object. ``cfg.MODEL.HEADS.SCALE`` is the logit scale ``s``.
            ``cfg.MODEL.HEADS.MARGIN`` is the angular margin ``m`` in radians.
        in_feat: dimensionality of the input features.
        num_classes: number of output classes.
    """

    # Weight of the current batch when updating the running value ``t``.
    _T_MOMENTUM = 0.01
    # Lower bound on ``1 - cos^2`` before the square root.
    # d/dx sqrt(x) is infinite at x == 0, which produces NaN gradients
    # whenever a target cosine is exactly +/-1.
    _SIN_EPS = 1e-7

    def __init__(self, cfg, in_feat, num_classes):
        super().__init__()
        self.in_feat = in_feat
        self._num_classes = num_classes
        self._s = cfg.MODEL.HEADS.SCALE
        self._m = cfg.MODEL.HEADS.MARGIN

        self.cos_m = math.cos(self._m)
        self.sin_m = math.sin(self._m)
        # cos(theta) > cos(pi - m)  <=>  theta + m < pi.
        # Beyond that point cos(theta + m) would grow again as theta grows,
        # so forward() switches to a linear penalty of size ``mm``.
        self.threshold = math.cos(math.pi - self._m)
        self.mm = math.sin(math.pi - self._m) * self._m

        self.weight = Parameter(torch.empty(num_classes, in_feat))
        # Running estimate of the mean target cosine (``t`` in the modulation).
        self.register_buffer('t', torch.zeros(1))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the class weights.

        ``torch.empty`` leaves the memory uninitialized, so the weights must
        be set explicitly.
        """
        nn.init.normal_(self.weight, std=0.01)

    def forward(self, features, targets):
        """Compute margin-adjusted, scaled class logits.

        Args:
            features: (B, in_feat) batch of feature vectors.
            targets: (B,) integer class index per sample.

        Returns:
            (B, num_classes) logits, already multiplied by the scale ``s``.

        Side effect: updates the buffer ``t`` on every call.
        """
        # cos(theta) between each normalized feature and each normalized class weight.
        cos_theta = F.linear(F.normalize(features), F.normalize(self.weight))
        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability

        target_idx = targets.view(-1).long()
        rows = torch.arange(cos_theta.size(0), device=cos_theta.device)
        # Advanced indexing (not gather) is deliberate.
        # Its backward needs only sizes, so the in-place writes to cos_theta
        # below stay autograd-safe.
        target_logit = cos_theta[rows, target_idx].view(-1, 1)

        sin_theta = torch.sqrt((1.0 - torch.pow(target_logit, 2)).clamp(min=self._SIN_EPS))
        cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m  # cos(target+margin)
        final_target_logit = torch.where(
            target_logit > self.threshold, cos_theta_m, target_logit - self.mm
        )

        # "Hard" entries: cosine above the margin-adjusted target cosine (broadcast over classes).
        mask = cos_theta > cos_theta_m
        hard_example = cos_theta[mask]
        with torch.no_grad():
            # In-place EMA keeps the registered buffer object instead of rebinding it.
            self.t.mul_(1 - self._T_MOMENTUM).add_(target_logit.mean() * self._T_MOMENTUM)
        cos_theta[mask] = hard_example * (self.t + hard_example)
        # The target column receives the margin logit.
        # This overwrites any modulation applied to it above.
        cos_theta.scatter_(1, target_idx.view(-1, 1), final_target_logit)
        pred_class_logits = cos_theta * self._s
        return pred_class_logits

    def extra_repr(self):
        return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
            self.in_feat, self._num_classes, self._s, self._m
        )
|