23 lines
582 B
Python
23 lines
582 B
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
|
|
import warnings
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class RingLoss(nn.Module):
    """Ring loss.

    Penalizes the squared deviation of each sample's L2 feature norm from a
    single learnable target radius ``R``, averaged over the batch:
    ``mean((||x_i||_2 - R) ** 2)``. No loss weight (the paper's lambda / 2)
    is applied here; scale the returned value externally if needed.

    Reference:
        Zheng et al. Ring loss: Convex Feature Normalization for Face Recognition. CVPR 2018.

    Args:
        radius (float, optional): initial value of the learnable target
            norm ``R``. Default is 1.0, matching the previous hard-coded value.
    """

    def __init__(self, radius=1.0):
        super(RingLoss, self).__init__()
        # stacklevel=2 attributes the warning to the code constructing the
        # loss rather than to this line.
        warnings.warn('This method is deprecated', stacklevel=2)
        # Shape (1,) so it broadcasts against the per-sample norms in forward.
        self.radius = nn.Parameter(
            torch.full((1,), float(radius), dtype=torch.float)
        )

    def forward(self, x):
        """Compute the ring loss for a batch of features.

        Args:
            x (torch.Tensor): feature tensor; the L2 norm is taken over
                dim 1, so this presumably expects shape (N, D) — confirm
                against callers.

        Returns:
            torch.Tensor: scalar tensor, the mean squared difference between
            the per-sample norms and the learnable radius.
        """
        norms = x.norm(p=2, dim=1)
        loss = ((norms - self.radius) ** 2).mean()
        return loss