[Feature] Add arcface head. (#926)

* feat: add arcface head

* feat: add arcface head

* update arcface

* refactor arcface head

* update arcface head

* update arcface head

* use torch.cat instead of torch.hstack to fix ci

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
zzc98 2022-11-02 17:45:33 +08:00 committed by GitHub
parent 693596bc2f
commit 9506241f73
3 changed files with 217 additions and 1 deletion

mmcls/models/heads/__init__.py

@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .arcface_head import ArcFaceClsHead
from .cls_head import ClsHead
from .conformer_head import ConformerHead
from .deit_head import DeiTClsHead
@@ -12,5 +13,5 @@ from .vision_transformer_head import VisionTransformerClsHead
__all__ = [
    'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
    'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead',
    'ConformerHead', 'EfficientFormerClsHead', 'ArcFaceClsHead'
]
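
Once exported here, the head can be referenced by name from a config through the registry. A minimal, hypothetical config sketch (the ``ImageClassifier``/``ResNet`` choices and channel sizes are illustrative, not part of this change):

# Hypothetical config sketch; the backbone and sizes are illustrative only.
model = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet', depth=50, out_indices=(3, )),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='ArcFaceClsHead',
        num_classes=1000,
        in_channels=2048,
        s=30.0,
        m=0.5,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0)))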

mmcls/models/heads/arcface_head.py

@@ -0,0 +1,176 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample

from .cls_head import ClsHead

class NormLinear(nn.Linear):
    """An enhanced linear layer, which can normalize the input and the
    linear weight.

    Args:
        in_features (int): Size of each input sample.
        out_features (int): Size of each output sample.
        bias (bool): Whether there is bias. If set to ``False``, the
            layer will not learn an additive bias. Defaults to ``False``.
        feature_norm (bool): Whether to normalize the input feature.
            Defaults to ``True``.
        weight_norm (bool): Whether to normalize the weight.
            Defaults to ``True``.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = False,
                 feature_norm: bool = True,
                 weight_norm: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        self.weight_norm = weight_norm
        self.feature_norm = feature_norm

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.feature_norm:
            input = F.normalize(input)
        if self.weight_norm:
            weight = F.normalize(self.weight)
        else:
            weight = self.weight
        return F.linear(input, weight, self.bias)

@MODELS.register_module()
class ArcFaceClsHead(ClsHead):
    """ArcFace classifier head.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        s (float): Norm of the input feature. Defaults to 30.0.
        m (float): Margin. Defaults to 0.5.
        easy_margin (bool): Avoid theta + m >= pi. Defaults to False.
        ls_eps (float): Label smoothing. Defaults to 0.
        bias (bool): Whether to use bias in the norm layer.
            Defaults to False.
        loss (dict): Config of the classification loss. Defaults to
            ``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to None.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 s: float = 30.0,
                 m: float = 0.50,
                 easy_margin: bool = False,
                 ls_eps: float = 0.0,
                 bias: bool = False,
                 loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0),
                 init_cfg: Optional[dict] = None):
        super(ArcFaceClsHead, self).__init__(init_cfg=init_cfg)
        self.loss_module = MODELS.build(loss)

        self.in_channels = in_channels
        self.num_classes = num_classes

        if self.num_classes <= 0:
            raise ValueError(
                f'num_classes={num_classes} must be a positive integer')

        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.norm_linear = NormLinear(in_channels, num_classes, bias=bias)

        self.easy_margin = easy_margin
        # cos(pi - m): the threshold beyond which theta + m exceeds pi
        self.th = math.cos(math.pi - m)
        # sin(pi - m) * m: the linear penalty applied past the threshold
        self.mm = math.sin(math.pi - m) * m

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The process before the final classification head.

        The input ``feats`` is a tuple of tensors, and each tensor is the
        feature of a backbone stage. In ``ArcFaceClsHead``, we just obtain
        the feature of the last stage.
        """
        # The ArcFaceClsHead doesn't have any other module,
        # just return after unpacking.
        return feats[-1]

    def forward(self,
                feats: Tuple[torch.Tensor],
                target: Optional[torch.Tensor] = None) -> torch.Tensor:
        """The forward process."""
        pre_logits = self.pre_logits(feats)

        # cos(theta) = (a . b) / (||a|| * ||b||)
        cosine = self.norm_linear(pre_logits)

        if target is None:
            return self.s * cosine

        phi = torch.cos(torch.acos(cosine) + self.m)

        if self.easy_margin:
            # when cosine > 0, choose phi
            # when cosine <= 0, choose cosine
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # when cosine > th, choose phi
            # when cosine <= th, choose cosine - mm
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        one_hot = torch.zeros(cosine.size(), device=pre_logits.device)
        one_hot.scatter_(1, target.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 -
                       self.ls_eps) * one_hot + self.ls_eps / self.num_classes
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.s

    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[ClsDataSample], **kwargs) -> dict:
        """Calculate losses from the classification score.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, in_channels)``.
            data_samples (List[ClsDataSample]): The annotation data of
                every sample.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        if 'score' in data_samples[0].gt_label:
            # Batch augmentation may convert labels to one-hot format scores.
            target = torch.stack([i.gt_label.score for i in data_samples])
        else:
            target = torch.cat([i.gt_label.label for i in data_samples])

        # This part can be traced by torch.fx.
        cls_score = self(feats, target)

        # compute loss
        losses = dict()
        loss = self.loss_module(
            cls_score, target, avg_factor=cls_score.size(0), **kwargs)
        losses['loss'] = loss

        return losses
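
For intuition, the piecewise margin in ``forward`` can be checked standalone. A minimal sketch in plain PyTorch (values are illustrative): ``cos(theta + m)`` is used while ``theta + m`` stays below ``pi``; once ``cosine`` drops below ``th = cos(pi - m)``, the head switches to the linear penalty ``cosine - mm`` so the logit keeps decreasing in ``theta``.

import math

import torch

m = 0.5
th = math.cos(math.pi - m)      # ~ -0.8776
mm = math.sin(math.pi - m) * m  # ~ 0.2397

cosine = torch.tensor([0.9, -0.95])      # an easy and a very hard sample
phi = torch.cos(torch.acos(cosine) + m)  # cos(theta + m)
out = torch.where(cosine > th, phi, cosine - mm)
# out[0] == cos(theta + m); out[1] == cosine - mm, since theta + m > pi there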

tests/test_models/test_heads.py

@@ -480,3 +480,42 @@ class TestMultiLabelLinearClsHead(TestMultiLabelClsHead):
        # return the last item (same as pre_logits)
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        head(feats)

class TestArcFaceClsHead(TestCase):
    DEFAULT_ARGS = dict(type='ArcFaceClsHead', in_channels=10, num_classes=5)

    def test_initialize(self):
        with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
            MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})

    def test_pre_logits(self):
        head = MODELS.build(self.DEFAULT_ARGS)

        # return the last item
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        pre_logits = head.pre_logits(feats)
        self.assertIs(pre_logits, feats[-1])

    def test_forward(self):
        head = MODELS.build(self.DEFAULT_ARGS)

        # target is not None
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        target = torch.zeros(4)
        outs = head(feats, target)
        self.assertEqual(outs.shape, (4, 5))

        # target is None
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        outs = head(feats)
        self.assertEqual(outs.shape, (4, 5))

    def test_loss(self):
        feats = (torch.rand(4, 10), )
        data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]

        # test loss
        head = MODELS.build(self.DEFAULT_ARGS)
        losses = head.loss(feats, data_samples)
        self.assertEqual(losses.keys(), {'loss'})
        self.assertGreater(losses['loss'].item(), 0)
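
A quick usage sketch mirroring the tests above (hypothetical; assumes ``mmcls`` is installed so the registry can resolve ``ArcFaceClsHead``, and ``ls_eps=0.1`` is enabled just for illustration):

import torch
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample

head = MODELS.build(
    dict(type='ArcFaceClsHead', in_channels=10, num_classes=5, ls_eps=0.1))
feats = (torch.rand(4, 10), )
data_samples = [ClsDataSample().set_gt_label(i % 5) for i in range(4)]
losses = head.loss(feats, data_samples)
print(losses['loss'])  # a positive scalar tensor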