# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn as nn

__all__ = [
    "Linear",
    "ArcSoftmax",
    "CosSoftmax",
    "CircleSoftmax"
]


class Linear(nn.Module):
    def __init__(self, num_classes, scale, margin):
        super().__init__()
        self.num_classes = num_classes
        self.s = scale
        self.m = margin

    def forward(self, logits, targets):
        # Plain softmax head: no margin, only scale the logits.
        return logits.mul_(self.s)

    def extra_repr(self):
        return f"num_classes={self.num_classes}, scale={self.s}, margin={self.m}"


class CosSoftmax(Linear):
    r"""Implementation of the large margin cosine distance (CosFace):
    subtract an additive margin m from the target-class cosine logit,
    then scale all logits by s.
    """

    def forward(self, logits, targets):
        index = torch.where(targets != -1)[0]
        m_hot = torch.zeros(index.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype)
        m_hot.scatter_(1, targets[index, None], self.m)
        logits[index] -= m_hot
        logits.mul_(self.s)
        return logits


class ArcSoftmax(Linear):
    r"""Additive angular margin (ArcFace): add the margin m to the target-class
    angle (arccos of the cosine logit), map back with cosine, then scale by s.
    """

    def forward(self, logits, targets):
        index = torch.where(targets != -1)[0]
        m_hot = torch.zeros(index.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype)
        m_hot.scatter_(1, targets[index, None], self.m)
        logits.acos_()
        logits[index] += m_hot
        logits.cos_().mul_(self.s)
        return logits


class CircleSoftmax(Linear):
    r"""Circle-loss style margin: re-weight positive and negative cosine logits
    with self-paced coefficients alpha_p / alpha_n around the optima
    delta_p = 1 - m and delta_n = m, then scale by s.
    """

    def forward(self, logits, targets):
        alpha_p = torch.clamp_min(-logits.detach() + 1 + self.m, min=0.)
        alpha_n = torch.clamp_min(logits.detach() + self.m, min=0.)
        delta_p = 1 - self.m
        delta_n = self.m

        # When using model parallel, some targets are not among the class centers of the local rank
        index = torch.where(targets != -1)[0]
        m_hot = torch.zeros(index.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype)
        m_hot.scatter_(1, targets[index, None], 1)

        logits_p = alpha_p * (logits - delta_p)
        logits_n = alpha_n * (logits - delta_n)

        logits[index] = logits_p[index] * m_hot + logits_n[index] * (1 - m_hot)

        neg_index = torch.where(targets == -1)[0]
        logits[neg_index] = logits_n[neg_index]

        logits.mul_(self.s)

        return logits
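

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module):
# these heads expect cosine-similarity logits, i.e. the dot product between
# L2-normalized features and L2-normalized class weights, and they modify the
# logits in place. The feature dimension, batch size, and scale/margin values
# below are illustrative assumptions, not values mandated by this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn.functional as F

    num_classes, feat_dim, batch = 10, 512, 4  # assumed sizes for the demo

    # Class-center weight matrix; normalizing both sides keeps logits in [-1, 1].
    weight = nn.Parameter(torch.randn(num_classes, feat_dim))
    features = torch.randn(batch, feat_dim)
    targets = torch.randint(0, num_classes, (batch,))

    cosine = F.linear(F.normalize(features), F.normalize(weight))

    # Apply a margin head; scale=30, margin=0.25 are assumed typical settings.
    cls_layer = CircleSoftmax(num_classes, scale=30, margin=0.25)
    # Clone first because the head mutates its input in place.
    margined_logits = cls_layer(cosine.clone(), targets)

    # Cross-entropy is then computed on the margin-adjusted, scaled logits.
    loss = F.cross_entropy(margined_logits, targets)
    print(margined_logits.shape, loss.item())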