# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class CircleSoftmax(nn.Module):
    """Classification head that emits Circle-loss-style logits.

    Input features and per-class weight vectors are both L2-normalized, so
    the raw scores are cosine similarities. The target-class score is then
    re-weighted/shifted with the "positive" rule and every other class with
    the "negative" rule, both controlled by a scale ``s`` and margin ``m``.

    Args:
        cfg: config object; reads ``cfg.MODEL.HEADS.SCALE`` (scale ``s``)
            and ``cfg.MODEL.HEADS.MARGIN`` (margin ``m``).
        in_feat: dimensionality of the input features.
        num_classes: number of output classes.
    """

    def __init__(self, cfg, in_feat, num_classes):
        super().__init__()
        self.in_feat = in_feat
        self._num_classes = num_classes
        self._s = cfg.MODEL.HEADS.SCALE
        self._m = cfg.MODEL.HEADS.MARGIN

        # One weight row per class. torch.empty does not initialize memory,
        # so reset_parameters() fills it before the module is ever used.
        self.weight = Parameter(torch.empty(num_classes, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize ``weight`` with Kaiming-uniform (``a=sqrt(5)``).

        This matches the default weight init of ``torch.nn.Linear``. A caller
        may still apply its own initializer afterwards; this only guarantees
        the module never starts from uninitialized memory.
        """
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, features, targets):
        """Compute Circle-loss-style logits for a batch.

        Args:
            features: float tensor, expected 2-D ``(batch, in_feat)``; it is
                L2-normalized along dim 1.
            targets: int64 class indices of shape ``(batch,)`` with values in
                ``[0, num_classes)`` (required by ``F.one_hot``).

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding the adjusted
            logits.
        """
        # Cosine similarity between every sample and every class weight row.
        sim_mat = F.linear(F.normalize(features), F.normalize(self.weight))

        # Re-weighting factors; detached so no gradient flows through them.
        alpha_p = F.relu(-sim_mat.detach() + 1 + self._m)
        alpha_n = F.relu(sim_mat.detach() + self._m)
        delta_p = 1 - self._m
        delta_n = self._m

        s_p = self._s * alpha_p * (sim_mat - delta_p)
        s_n = self._s * alpha_n * (sim_mat - delta_n)

        # Take s_p on each sample's target column and s_n everywhere else.
        one_hot = F.one_hot(targets, num_classes=self._num_classes)
        pred_class_logits = one_hot * s_p + (1.0 - one_hot) * s_n
        return pred_class_logits

    def extra_repr(self):
        return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
            self.in_feat, self._num_classes, self._s, self._m
        )