# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.fileio import list_from_file
from mmengine.runner import autocast
from mmengine.utils import is_seq_of

from mmcls.models.losses import convert_to_one_hot
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample

from .cls_head import ClsHead

class NormProduct(nn.Linear):
    """An enhanced linear layer with ``k`` clustering centers per class,
    which computes the product between the normalized input and the
    normalized linear weight.

    Args:
        in_features (int): Size of each input sample.
        out_features (int): Size of each output sample.
        k (int): The number of clustering centers. Defaults to 1.
        bias (bool): Whether there is bias. If set to ``False``, the
            layer will not learn an additive bias. Defaults to ``False``.
        feature_norm (bool): Whether to normalize the input feature.
            Defaults to ``True``.
        weight_norm (bool): Whether to normalize the weight.
            Defaults to ``True``.
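
    Example:
        A minimal shape check (illustrative; with ``k`` sub-centers the
        layer keeps only the best-matching center per class):

        >>> layer = NormProduct(8, 10, k=3)
        >>> layer(torch.randn(4, 8)).shape
        torch.Size([4, 10])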
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 k: int = 1,
                 bias: bool = False,
                 feature_norm: bool = True,
                 weight_norm: bool = True):
        super().__init__(in_features, out_features * k, bias=bias)
        self.weight_norm = weight_norm
        self.feature_norm = feature_norm
        self.out_features = out_features
        self.k = k

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.feature_norm:
            input = F.normalize(input)
        if self.weight_norm:
            weight = F.normalize(self.weight)
        else:
            weight = self.weight
        cosine_all = F.linear(input, weight, self.bias)

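        # Sub-center ArcFace: with k sub-centers per class, keep only the
        # largest cosine among the k centers as the class similarity.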
        if self.k == 1:
            return cosine_all
        else:
            cosine_all = cosine_all.view(-1, self.out_features, self.k)
            cosine, _ = torch.max(cosine_all, dim=2)
            return cosine


@MODELS.register_module()
class ArcFaceClsHead(ClsHead):
    """ArcFace classifier head.

    A PyTorch implementation of the papers `ArcFace: Additive Angular Margin
    Loss for Deep Face Recognition <https://arxiv.org/abs/1801.07698>`_ and
    `Sub-center ArcFace: Boosting Face Recognition by Large-Scale Noisy Web
    Faces <https://link.springer.com/chapter/10.1007/978-3-030-58621-8_43>`_.

    Example:
        To use ArcFace in config files.

        1. use vanilla ArcFace

        .. code:: python

            model = dict(
                backbone=xxx,
                neck=xxxx,
                head=dict(
                    type='ArcFaceClsHead',
                    num_classes=5000,
                    in_channels=1024,
                    loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
                    init_cfg=None),
            )

        2. use SubCenterArcFace with 3 sub-centers

        .. code:: python

            model = dict(
                backbone=xxx,
                neck=xxxx,
                head=dict(
                    type='ArcFaceClsHead',
                    num_classes=5000,
                    in_channels=1024,
                    num_subcenters=3,
                    loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
                    init_cfg=None),
            )

        3. use SubCenterArcFace with CountPowerAdaptiveMargins

        .. code:: python

            model = dict(
                backbone=xxx,
                neck=xxxx,
                head=dict(
                    type='ArcFaceClsHead',
                    num_classes=5000,
                    in_channels=1024,
                    num_subcenters=3,
                    loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
                    init_cfg=None),
            )

            custom_hooks = [dict(type='SetAdaptiveMarginsHook')]

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        num_subcenters (int): Number of sub-centers. Defaults to 1.
        scale (float): Scale factor of the output logit. Defaults to 64.0.
        margins (float | Sequence[float] | str): The penalty margin. Could
            be one of the following formats:

            - float: The margin, which will be the same for all categories.
            - Sequence[float]: The category-based margins list.
            - str: A '.txt' file path which contains a list. Each line
              represents the margin of a category, and the number in the
              i-th row indicates the margin of the i-th class.

            Defaults to 0.5.
        easy_margin (bool): Avoid ``theta + m >= pi``. Defaults to False.
        loss (dict): Config of classification loss. Defaults to
            ``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to None.
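
    Example:
        Use category-based margins from a text file (an illustrative
        sketch; ``data/margins.txt`` is a hypothetical path whose i-th
        line holds the margin of the i-th class):

        .. code:: python

            head = dict(
                type='ArcFaceClsHead',
                num_classes=5000,
                in_channels=1024,
                margins='data/margins.txt',
                loss=dict(type='CrossEntropyLoss', loss_weight=1.0))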
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 num_subcenters: int = 1,
                 scale: float = 64.,
                 margins: Optional[Union[float, Sequence[float], str]] = 0.50,
                 easy_margin: bool = False,
                 loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0),
                 init_cfg: Optional[dict] = None):
        super(ArcFaceClsHead, self).__init__(init_cfg=init_cfg)
        self.loss_module = MODELS.build(loss)

        assert num_subcenters >= 1 and num_classes >= 0
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_subcenters = num_subcenters
        self.scale = scale
        self.easy_margin = easy_margin

        self.norm_product = NormProduct(in_channels, num_classes,
                                        num_subcenters)

        if isinstance(margins, float):
            margins = [margins] * num_classes
        elif isinstance(margins, str) and margins.endswith('.txt'):
            margins = [float(item) for item in list_from_file(margins)]
        else:
            assert is_seq_of(list(margins), (float, int)), (
                'the attribute `margins` in ``ArcFaceClsHead`` should be a '
                'float, a Sequence of float, or a ".txt" file path.')

        assert len(margins) == num_classes, \
            'The length of margins must be equal to num_classes.'

        self.register_buffer(
            'margins', torch.tensor(margins).float(), persistent=False)
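
        # phi = cos(theta + m) is decreasing only while theta + m <= pi,
        # i.e. while cosine > cos(pi - m); beyond that point the linear
        # fallback branch cosine - m * sin(m) is used instead (note that
        # sin(pi - m) * m == sin(m) * m).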
        # To make `phi` monotonically decreasing, refer to
        # https://github.com/deepinsight/insightface/issues/108
        sinm_m = torch.sin(math.pi - self.margins) * self.margins
        threshold = torch.cos(math.pi - self.margins)
        self.register_buffer('sinm_m', sinm_m, persistent=False)
        self.register_buffer('threshold', threshold, persistent=False)

    def set_margins(self, margins: Union[Sequence[float], float]) -> None:
        """Set the margins of the ArcFace head.

        Args:
            margins (Union[Sequence[float], float]): The margins.
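
        Example:
            A minimal call sketch (illustrative):

            >>> head.set_margins(0.5)  # one margin shared by all classes
            >>> head.set_margins([0.5] * head.num_classes)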
        """
        if isinstance(margins, float):
            margins = [margins] * self.num_classes
        assert is_seq_of(
            list(margins), float) and (len(margins) == self.num_classes), (
                f'margins must be a Sequence[float] with num_classes '
                f'elements, got {margins}')

        self.margins = torch.tensor(
            margins, device=self.margins.device, dtype=torch.float32)
        self.sinm_m = torch.sin(self.margins) * self.margins
        self.threshold = -torch.cos(self.margins)

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The process before the final classification head.

        The input ``feats`` is a tuple of tensors, and each tensor is the
        feature of a backbone stage. In ``ArcFaceClsHead``, we just obtain
        the feature of the last stage.
        """
        # The ArcFaceClsHead doesn't have any other module, just return the
        # last-stage feature after unpacking.
        return feats[-1]

    def _get_logit_with_margin(self, pre_logits, target):
        """Add the arc margin to the cosine at the target index.

        The target must be in index format.
        """
        assert target.dim() == 1 or (
            target.dim() == 2 and target.shape[1] == 1), \
            'The target must be in index format.'
        cosine = self.norm_product(pre_logits)
        phi = torch.cos(torch.acos(cosine) + self.margins)

        if self.easy_margin:
            # when cosine > 0, choose phi
            # when cosine <= 0, choose cosine
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # when cosine > threshold, choose phi
            # when cosine <= threshold, choose cosine - sinm_m
            phi = torch.where(cosine > self.threshold, phi,
                              cosine - self.sinm_m)

        target = convert_to_one_hot(target, self.num_classes)
        output = target * phi + (1 - target) * cosine
        return output

    def forward(self,
                feats: Tuple[torch.Tensor],
                target: Optional[torch.Tensor] = None) -> torch.Tensor:
        """The forward process.
        """
        # Disable AMP; the acos/cos ops need full fp32 precision.
        with autocast(enabled=False):
            pre_logits = self.pre_logits(feats)

            if target is None:
                # during evaluation, the logit is the cosine between W and
                # pre_logits: cos(theta_yj) = (x / ||x||) @ (W / ||W||)
                logit = self.norm_product(pre_logits)
            else:
                # during training, add a margin to the angle at the target
                # index; the logit is the cosine of the widened angle
                logit = self._get_logit_with_margin(pre_logits, target)

        return self.scale * logit

    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[ClsDataSample], **kwargs) -> dict:
        """Calculate losses from the classification score.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, in_channels)``.
            data_samples (List[ClsDataSample]): The annotation data of
                every sample.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
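
        Example:
            A minimal sketch (illustrative; assumes the loss registry is
            importable):

            >>> head = ArcFaceClsHead(num_classes=10, in_channels=8)
            >>> samples = [ClsDataSample().set_gt_label(i) for i in (0, 1)]
            >>> head.loss((torch.randn(2, 8), ), samples)['loss'].ndim
            0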
        """
        # Unpack data samples and pack targets
        label_target = torch.cat([i.gt_label.label for i in data_samples])
        if 'score' in data_samples[0].gt_label:
            # Batch augmentation may convert labels to one-hot format scores.
            target = torch.stack([i.gt_label.score for i in data_samples])
        else:
            # otherwise, use the index-format labels directly
            target = label_target

        # the margin can only be applied with an index-format target
        cls_score = self(feats, label_target)

        # compute loss
        losses = dict()
        loss = self.loss_module(
            cls_score, target, avg_factor=cls_score.size(0), **kwargs)
        losses['loss'] = loss

        return losses