deep-person-reid/torchreid/losses/ring_loss.py

23 lines
582 B
Python

from __future__ import absolute_import
from __future__ import division
import warnings
import torch
import torch.nn as nn
class RingLoss(nn.Module):
    """Ring loss.

    Penalizes the squared deviation of each feature vector's L2 norm from a
    single learnable radius ``R``::

        loss = mean_i (||x_i||_2 - R) ** 2

    The norm is taken over dimension 1 of the input. Any loss weighting
    (the paper's lambda) is expected to be applied by the caller.

    Reference:
        Zheng et al. Ring loss: Convex Feature Normalization for Face Recognition. CVPR 2018.

    Args:
        init_radius (float, optional): initial value of the learnable radius
            ``R``. Defaults to 1.0, which matches the previous hard-coded
            initialization.
    """

    def __init__(self, init_radius=1.0):
        super(RingLoss, self).__init__()
        # stacklevel=2 makes the warning point at the code that instantiates
        # this loss rather than at this line.
        warnings.warn('This method is deprecated', stacklevel=2)
        # Single-element learnable radius, updated by the optimizer with the
        # rest of the model's parameters.
        self.radius = nn.Parameter(
            torch.tensor([float(init_radius)], dtype=torch.float)
        )

    def forward(self, x):
        """Compute the ring loss.

        Args:
            x (torch.Tensor): features whose L2 norm is taken over dim 1;
                presumably shaped (batch_size, feat_dim) -- confirm at caller.

        Returns:
            torch.Tensor: scalar mean of squared (norm - radius) differences.
        """
        norms = x.norm(p=2, dim=1)
        return ((norms - self.radius) ** 2).mean()