mirror of https://github.com/JDAI-CV/fast-reid.git
81 lines · 2.2 KiB · Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
# Public API of this module: the plain and margin-based softmax heads below.
__all__ = [
    "Linear",
    "ArcSoftmax",
    "CosSoftmax",
    "CircleSoftmax"
]
|
|
|
|
|
|
class Linear(nn.Module):
    """Plain (no-margin) classification head: scales logits by ``scale``.

    Also the common base class for the margin variants below; they reuse
    ``self.s`` (scale) and ``self.m`` (margin) set up here and override
    ``forward`` to inject their margin before scaling.
    """

    def __init__(self, num_classes, scale, margin):
        super().__init__()
        # Stored for extra_repr / subclasses; this base forward only uses s.
        self.num_classes = num_classes
        self.s = scale
        self.m = margin

    def forward(self, logits, targets):
        # No margin here — just an in-place scale (targets is ignored).
        scaled = logits.mul_(self.s)
        return scaled

    def extra_repr(self):
        return f"num_classes={self.num_classes}, scale={self.s}, margin={self.m}"
|
|
|
|
|
|
class CosSoftmax(Linear):
    r"""Large-margin cosine head (CosFace style): subtract a fixed margin
    ``self.m`` from the target-class logit, then scale everything by
    ``self.s``. Modifies ``logits`` in place and returns it.
    """

    def forward(self, logits, targets):
        # Rows whose target is -1 get no margin (per the sibling head's
        # note, -1 marks targets outside this rank's class centers).
        valid = torch.where(targets != -1)[0]
        margin = torch.zeros(
            valid.size()[0], logits.size()[1],
            device=logits.device, dtype=logits.dtype,
        )
        # One-hot at the target column of each valid row, valued self.m.
        margin.scatter_(1, targets[valid, None], self.m)
        logits[valid] -= margin
        logits.mul_(self.s)
        return logits
|
|
|
|
|
|
class ArcSoftmax(Linear):
    """Additive angular margin head (ArcFace style): map logits to angles,
    add the margin ``self.m`` at the target positions, map back with cosine
    and scale by ``self.s``. Operates on ``logits`` in place.

    NOTE(review): ``acos_`` assumes the incoming logits are cosine
    similarities in [-1, 1]; out-of-range values yield NaN — confirm the
    caller normalizes features and weights.
    """

    def forward(self, logits, targets):
        # Targets of -1 are skipped (outside this rank's class centers).
        valid = torch.where(targets != -1)[0]
        margin = torch.zeros(
            valid.size()[0], logits.size()[1],
            device=logits.device, dtype=logits.dtype,
        )
        margin.scatter_(1, targets[valid, None], self.m)
        logits.acos_()               # cosine -> angle, in place
        logits[valid] += margin      # additive margin in angle space
        logits.cos_().mul_(self.s)   # back to cosine, then scale
        return logits
|
|
|
|
|
|
class CircleSoftmax(Linear):
    """Circle-loss head: reweights each logit with a detached, clamped
    coefficient so positives are pulled toward ``1 - m`` and negatives
    toward ``m``, then scales by ``self.s``. Mutates ``logits`` in place.
    """

    def forward(self, logits, targets):
        # Adaptive weights; detach() keeps them constant w.r.t. gradients.
        alpha_p = torch.clamp_min(-logits.detach() + 1 + self.m, min=0.)
        alpha_n = torch.clamp_min(logits.detach() + self.m, min=0.)
        delta_p = 1 - self.m
        delta_n = self.m

        # When use model parallel, there are some targets not in class
        # centers of local rank; those rows carry target == -1.
        pos_rows = torch.where(targets != -1)[0]
        one_hot = torch.zeros(
            pos_rows.size()[0], logits.size()[1],
            device=logits.device, dtype=logits.dtype,
        )
        one_hot.scatter_(1, targets[pos_rows, None], 1)

        logits_p = alpha_p * (logits - delta_p)
        logits_n = alpha_n * (logits - delta_n)

        # Valid rows: positive form at the target column, negative elsewhere.
        logits[pos_rows] = logits_p[pos_rows] * one_hot + logits_n[pos_rows] * (1 - one_hot)

        # Rows with no local target use the negative form everywhere.
        neg_rows = torch.where(targets == -1)[0]
        logits[neg_rows] = logits_n[neg_rows]

        logits.mul_(self.s)
        return logits
|