[Enhance] Enhance ArcFaceClsHead. (#1181)
* update arcface
* fix unit tests
* add adv-margins
* rebase
* update doc and fix ut
* update code
* use label data
* update set-margins
* Modify ArcFace related method names.

Co-authored-by: mzr1996 <mzr1996@163.com>

parent 4fb44f8770
commit b0007812d6
@@ -31,7 +31,8 @@ Hooks

   ClassNumCheckHook
   PreciseBNHook
   VisualizationHook
   SwitchRecipeHook
   PrepareProtoBeforeValLoopHook
+  SetAdaptiveMarginsHook

.. module:: mmcls.engine.optimizers
@@ -140,6 +140,7 @@ Heads

   EfficientFormerClsHead
   DeiTClsHead
   ConformerHead
+  ArcFaceClsHead
   MultiLabelClsHead
   MultiLabelLinearClsHead
   CSRAClsHead
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .class_num_check_hook import ClassNumCheckHook
+from .margin_head_hooks import SetAdaptiveMarginsHook
 from .precise_bn_hook import PreciseBNHook
 from .retriever_hooks import PrepareProtoBeforeValLoopHook
 from .switch_recipe_hook import SwitchRecipeHook
@@ -7,5 +8,6 @@ from .visualization_hook import VisualizationHook

 __all__ = [
     'ClassNumCheckHook', 'PreciseBNHook', 'VisualizationHook',
-    'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook'
+    'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook',
+    'SetAdaptiveMarginsHook'
 ]
@@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmcls.models.heads import ArcFaceClsHead
from mmcls.registry import HOOKS


@HOOKS.register_module()
class SetAdaptiveMarginsHook(Hook):
    r"""Set adaptive margins in ArcFaceClsHead based on a power of the
    category-wise sample count.

    A PyTorch implementation of paper `Google Landmark Recognition 2020
    Competition Third Place Solution <https://arxiv.org/abs/2010.05350>`_.
    The margins are computed as
    :math:`\text{f}(n) = (m_{\max} - m_{\min}) \cdot \text{norm}(n^p) + m_{\min}`,
    where :math:`n` is the number of occurrences of a category.

    Args:
        margin_min (float): Lower bound of margins. Defaults to 0.05.
        margin_max (float): Upper bound of margins. Defaults to 0.5.
        power (float): The power of the category frequency. Defaults to -0.25.
    """

    def __init__(self, margin_min=0.05, margin_max=0.5, power=-0.25) -> None:
        self.margin_min = margin_min
        self.margin_max = margin_max
        self.margin_range = margin_max - margin_min
        self.p = power

    def before_train(self, runner):
        """Change the margins in ArcFaceClsHead.

        Args:
            runner (obj:`Runner`): Runner.
        """
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        if (hasattr(model, 'head')
                and not isinstance(model.head, ArcFaceClsHead)):
            raise ValueError(
                'Hook ``SetAdaptiveMarginsHook`` could only be used '
                f'for ``ArcFaceClsHead``, but got {type(model.head)}')

        # Generate margins based on the dataset statistics.
        gt_labels = runner.train_dataloader.dataset.get_gt_labels()
        label_count = np.bincount(gt_labels)
        label_count[label_count == 0] = 1  # At least one occurrence
        pow_freq = np.power(label_count, self.p)

        min_f, max_f = pow_freq.min(), pow_freq.max()
        normalized_pow_freq = (pow_freq - min_f) / (max_f - min_f)
        margins = normalized_pow_freq * self.margin_range + self.margin_min

        assert len(margins) == model.head.num_classes

        model.head.set_margins(margins)
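To see the mapping concretely, the margin computation above can be run standalone. A minimal numpy sketch using the same toy label set as the unit test further down; the sketch itself is illustrative and not part of the commit:

    import numpy as np

    margin_min, margin_max, power = 0.05, 0.5, -0.25
    gt_labels = np.array([0, 1, 2, 4, 0, 4, 1, 2, 2, 1])  # toy dataset labels

    label_count = np.bincount(gt_labels)
    label_count[label_count == 0] = 1  # classes never seen get count 1
    pow_freq = label_count.astype(float) ** power

    # min-max normalize, then rescale into [margin_min, margin_max]
    norm = (pow_freq - pow_freq.min()) / (pow_freq.max() - pow_freq.min())
    margins = norm * (margin_max - margin_min) + margin_min
    print(margins)  # the rarest class (class 3) gets the largest margin, 0.5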
@@ -250,13 +250,16 @@ class HorNetBlock(nn.Module):

 @MODELS.register_module()
 class HorNet(BaseBackbone):
-    """HorNet
-    A PyTorch impl of : `HorNet: Efficient High-Order Spatial Interactions
-    with Recursive Gated Convolutions`
-    Inspiration from
-    https://github.com/raoyongming/HorNet
+    """HorNet.
+
+    A PyTorch implementation of paper `HorNet: Efficient High-Order Spatial
+    Interactions with Recursive Gated Convolutions
+    <https://arxiv.org/abs/2207.14284>`_ .
+    Inspiration from https://github.com/raoyongming/HorNet

     Args:
         arch (str | dict): HorNet architecture.

             If use string, choose from 'tiny', 'small', 'base' and 'large'.
             If use dict, it should have the below keys:

             - **base_dim** (int): The base dimensions of embedding.
@@ -264,6 +267,7 @@ class HorNet(BaseBackbone):
             - **orders** (List[int]): The order of gnConv in each stage.
             - **dw_cfg** (List[dict]): The config for the dw conv.

             Defaults to 'tiny'.
         in_channels (int): Number of input image channels. Defaults to 3.
         drop_path_rate (float): Stochastic depth rate. Defaults to 0.
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .arcface_head import ArcFaceClsHead
 from .cls_head import ClsHead
 from .conformer_head import ConformerHead
 from .deit_head import DeiTClsHead
 from .efficientformer_head import EfficientFormerClsHead
 from .linear_head import LinearClsHead
+from .margin_head import ArcFaceClsHead
 from .multi_label_cls_head import MultiLabelClsHead
 from .multi_label_csra_head import CSRAClsHead
 from .multi_label_linear_head import MultiLabelLinearClsHead
@@ -1,176 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
from .cls_head import ClsHead


class NormLinear(nn.Linear):
    """An enhanced linear layer, which could normalize the input and the
    linear weight.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool): Whether there is bias. If set to ``False``, the
            layer will not learn an additive bias. Defaults to ``True``.
        feature_norm (bool): Whether to normalize the input feature.
            Defaults to ``True``.
        weight_norm (bool): Whether to normalize the weight.
            Defaults to ``True``.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = False,
                 feature_norm: bool = True,
                 weight_norm: bool = True):

        super().__init__(in_features, out_features, bias=bias)
        self.weight_norm = weight_norm
        self.feature_norm = feature_norm

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.feature_norm:
            input = F.normalize(input)
        if self.weight_norm:
            weight = F.normalize(self.weight)
        else:
            weight = self.weight
        return F.linear(input, weight, self.bias)


@MODELS.register_module()
class ArcFaceClsHead(ClsHead):
    """ArcFace classifier head.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        s (float): Norm of input feature. Defaults to 30.0.
        m (float): Margin. Defaults to 0.5.
        easy_margin (bool): Avoid theta + m >= PI. Defaults to False.
        ls_eps (float): Label smoothing. Defaults to 0.
        bias (bool): Whether to use bias in norm layer. Defaults to False.
        loss (dict): Config of classification loss. Defaults to
            ``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
        init_cfg (dict, optional): the config to control the initialization.
            Defaults to None.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 s: float = 30.0,
                 m: float = 0.50,
                 easy_margin: bool = False,
                 ls_eps: float = 0.0,
                 bias: bool = False,
                 loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0),
                 init_cfg: Optional[dict] = None):

        super(ArcFaceClsHead, self).__init__(init_cfg=init_cfg)
        self.loss_module = MODELS.build(loss)

        self.in_channels = in_channels
        self.num_classes = num_classes

        if self.num_classes <= 0:
            raise ValueError(
                f'num_classes={num_classes} must be a positive integer')

        self.s = s
        self.m = m
        self.ls_eps = ls_eps

        self.norm_linear = NormLinear(in_channels, num_classes, bias=bias)

        self.easy_margin = easy_margin
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The process before the final classification head.

        The input ``feats`` is a tuple of tensor, and each tensor is the
        feature of a backbone stage. In ``ArcFaceHead``, we just obtain the
        feature of the last stage.
        """
        # The ArcFaceHead doesn't have other module, just return after
        # unpacking.
        return feats[-1]

    def forward(self,
                feats: Tuple[torch.Tensor],
                target: Optional[torch.Tensor] = None) -> torch.Tensor:
        """The forward process."""

        pre_logits = self.pre_logits(feats)

        # cos = (a * b) / (||a|| * ||b||)
        cosine = self.norm_linear(pre_logits)

        if target is None:
            return self.s * cosine

        phi = torch.cos(torch.acos(cosine) + self.m)

        if self.easy_margin:
            # when cosine > 0, choose phi
            # when cosine <= 0, choose cosine
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # when cos > th, choose phi
            # when cos <= th, choose cosine - mm
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        one_hot = torch.zeros(cosine.size(), device=pre_logits.device)
        one_hot.scatter_(1, target.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 -
                       self.ls_eps) * one_hot + self.ls_eps / self.num_classes

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.s

    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[ClsDataSample], **kwargs) -> dict:
        """Calculate losses from the classification score.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, num_classes)``.
            data_samples (List[ClsDataSample]): The annotation data of
                every samples.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """

        if 'score' in data_samples[0].gt_label:
            # Batch augmentation may convert labels to one-hot format scores.
            target = torch.stack([i.gt_label.score for i in data_samples])
        else:
            target = torch.cat([i.gt_label.label for i in data_samples])

        # The part can be traced by torch.fx
        cls_score = self(feats, target)

        # compute loss
        losses = dict()
        loss = self.loss_module(
            cls_score, target, avg_factor=cls_score.size(0), **kwargs)
        losses['loss'] = loss

        return losses
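Both the deleted head above and its replacement below inject the margin as cos(theta + m) and fall back to a linear penalty (cosine - mm) once theta + m would pass pi, where cosine is no longer monotonic in theta (see the insightface issue referenced in the new file). A standalone sketch of that behaviour, not part of the commit:

    import math
    import torch

    m = 0.5
    th = math.cos(math.pi - m)      # threshold: cos(theta) at theta = pi - m
    mm = math.sin(math.pi - m) * m  # linear penalty used past the threshold

    theta = torch.linspace(0, math.pi, 7)
    cosine = torch.cos(theta)
    phi = torch.cos(theta + m)      # margin-penalized logit
    # past the threshold, fall back to a linear penalty so the logit
    # keeps decreasing monotonically in theta
    penalized = torch.where(cosine > th, phi, cosine - mm)
    print(torch.stack([cosine, penalized]))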
@@ -0,0 +1,299 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.fileio import list_from_file
from mmengine.runner import autocast
from mmengine.utils import is_seq_of

from mmcls.models.losses import convert_to_one_hot
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
from .cls_head import ClsHead


class NormProduct(nn.Linear):
    """An enhanced linear layer with k clustering centers to calculate the
    product between the normalized input and the linear weight.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        k (int): The number of clustering centers. Defaults to 1.
        bias (bool): Whether there is bias. If set to ``False``, the
            layer will not learn an additive bias. Defaults to ``True``.
        feature_norm (bool): Whether to normalize the input feature.
            Defaults to ``True``.
        weight_norm (bool): Whether to normalize the weight.
            Defaults to ``True``.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 k: int = 1,
                 bias: bool = False,
                 feature_norm: bool = True,
                 weight_norm: bool = True):

        super().__init__(in_features, out_features * k, bias=bias)
        self.weight_norm = weight_norm
        self.feature_norm = feature_norm
        self.out_features = out_features
        self.k = k

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.feature_norm:
            input = F.normalize(input)
        if self.weight_norm:
            weight = F.normalize(self.weight)
        else:
            weight = self.weight
        cosine_all = F.linear(input, weight, self.bias)

        if self.k == 1:
            return cosine_all
        else:
            cosine_all = cosine_all.view(-1, self.out_features, self.k)
            cosine, _ = torch.max(cosine_all, dim=2)
            return cosine


@MODELS.register_module()
class ArcFaceClsHead(ClsHead):
    """ArcFace classifier head.

    A PyTorch implementation of paper `ArcFace: Additive Angular Margin Loss
    for Deep Face Recognition <https://arxiv.org/abs/1801.07698>`_ and
    `Sub-center ArcFace: Boosting Face Recognition by Large-Scale Noisy Web
    Faces <https://link.springer.com/chapter/10.1007/978-3-030-58621-8_43>`_.

    Example:
        To use ArcFace in config files.

        1. use vanilla ArcFace

        .. code:: python

            model = dict(
                backbone=xxx,
                neck=xxxx,
                head=dict(
                    type='ArcFaceClsHead',
                    num_classes=5000,
                    in_channels=1024,
                    loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
                    init_cfg=None),
            )

        2. use SubCenterArcFace with 3 sub-centers

        .. code:: python

            model = dict(
                backbone=xxx,
                neck=xxxx,
                head=dict(
                    type='ArcFaceClsHead',
                    num_classes=5000,
                    in_channels=1024,
                    num_subcenters=3,
                    loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
                    init_cfg=None),
            )

        3. use SubCenterArcFace with CountPowerAdaptiveMargins

        .. code:: python

            model = dict(
                backbone=xxx,
                neck=xxxx,
                head=dict(
                    type='ArcFaceClsHead',
                    num_classes=5000,
                    in_channels=1024,
                    num_subcenters=3,
                    loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
                    init_cfg=None),
            )

            custom_hooks = [dict(type='SetAdaptiveMarginsHook')]

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        num_subcenters (int): Number of sub-centers. Defaults to 1.
        scale (float): Scale factor of output logit. Defaults to 64.0.
        margins (float | Sequence[float] | str): The penalty margin. Could be
            in the following formats:

            - float: The margin, would be the same for all categories.
            - Sequence[float]: The category-based margins list.
            - str: A '.txt' file path which contains a list. Each line
              represents the margin of a category, and the number in the
              i-th row indicates the margin of the i-th class.

            Defaults to 0.5.
        easy_margin (bool): Avoid theta + m >= PI. Defaults to False.
        loss (dict): Config of classification loss. Defaults to
            ``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
        init_cfg (dict, optional): the config to control the initialization.
            Defaults to None.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 num_subcenters: int = 1,
                 scale: float = 64.,
                 margins: Optional[Union[float, Sequence[float], str]] = 0.50,
                 easy_margin: bool = False,
                 loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0),
                 init_cfg: Optional[dict] = None):

        super(ArcFaceClsHead, self).__init__(init_cfg=init_cfg)
        self.loss_module = MODELS.build(loss)

        assert num_subcenters >= 1 and num_classes >= 0
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_subcenters = num_subcenters
        self.scale = scale
        self.easy_margin = easy_margin

        self.norm_product = NormProduct(in_channels, num_classes,
                                        num_subcenters)

        if isinstance(margins, float):
            margins = [margins] * num_classes
        elif isinstance(margins, str) and margins.endswith('.txt'):
            margins = [float(item) for item in list_from_file(margins)]
        else:
            assert is_seq_of(list(margins), (float, int)), (
                'the attribute `margins` in ``ArcFaceClsHead`` should be a '
                'float, a Sequence of float, or a ".txt" file path.')

        assert len(margins) == num_classes, \
            'The length of margins must be equal with num_classes.'

        self.register_buffer(
            'margins', torch.tensor(margins).float(), persistent=False)
        # To make `phi` monotonic decreasing, refers to
        # https://github.com/deepinsight/insightface/issues/108
        sinm_m = torch.sin(math.pi - self.margins) * self.margins
        threshold = torch.cos(math.pi - self.margins)
        self.register_buffer('sinm_m', sinm_m, persistent=False)
        self.register_buffer('threshold', threshold, persistent=False)

    def set_margins(self, margins: Union[Sequence[float], float]) -> None:
        """Set the margins of the ArcFace head.

        Args:
            margins (Union[Sequence[float], float]): The margins.
        """
        if isinstance(margins, float):
            margins = [margins] * self.num_classes
        assert is_seq_of(
            list(margins), float) and (len(margins) == self.num_classes), (
                f'margins must be Sequence[float], but got {margins}')

        self.margins = torch.tensor(
            margins, device=self.margins.device, dtype=torch.float32)
        self.sinm_m = torch.sin(self.margins) * self.margins
        self.threshold = -torch.cos(self.margins)

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The process before the final classification head.

        The input ``feats`` is a tuple of tensor, and each tensor is the
        feature of a backbone stage. In ``ArcFaceHead``, we just obtain the
        feature of the last stage.
        """
        # The ArcFaceHead doesn't have other module, just return after
        # unpacking.
        return feats[-1]

    def _get_logit_with_margin(self, pre_logits, target):
        """Add the arc margin to the cosine at the target index.

        The target must be in index format.
        """
        assert target.dim() == 1 or (
            target.dim() == 2 and target.shape[1] == 1), \
            'The target must be in index format.'
        cosine = self.norm_product(pre_logits)
        phi = torch.cos(torch.acos(cosine) + self.margins)

        if self.easy_margin:
            # when cosine > 0, choose phi
            # when cosine <= 0, choose cosine
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # when cos > threshold, choose phi
            # when cos <= threshold, choose cosine - sinm_m
            phi = torch.where(cosine > self.threshold, phi,
                              cosine - self.sinm_m)

        target = convert_to_one_hot(target, self.num_classes)
        output = target * phi + (1 - target) * cosine
        return output

    def forward(self,
                feats: Tuple[torch.Tensor],
                target: Optional[torch.Tensor] = None) -> torch.Tensor:
        """The forward process."""
        # Disable AMP
        with autocast(enabled=False):
            pre_logits = self.pre_logits(feats)

            if target is None:
                # when eval, logit is the cosine between W and pre_logits;
                # cos(theta_yj) = (x / ||x||) * (W / ||W||)
                logit = self.norm_product(pre_logits)
            else:
                # when training, add a margin to the pre_logits where target
                # is True, then logit is the cosine between W and the new
                # pre_logits
                logit = self._get_logit_with_margin(pre_logits, target)

        return self.scale * logit

    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[ClsDataSample], **kwargs) -> dict:
        """Calculate losses from the classification score.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, num_classes)``.
            data_samples (List[ClsDataSample]): The annotation data of
                every samples.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        # Unpack data samples and pack targets
        label_target = torch.cat([i.gt_label.label for i in data_samples])
        if 'score' in data_samples[0].gt_label:
            # Batch augmentation may convert labels to one-hot format scores.
            target = torch.stack([i.gt_label.score for i in data_samples])
        else:
            # use the label indices directly as the loss target
            target = label_target

        # the index-format target is needed to place the margins
        cls_score = self(feats, label_target)

        # compute loss
        losses = dict()
        loss = self.loss_module(
            cls_score, target, avg_factor=cls_score.size(0), **kwargs)
        losses['loss'] = loss

        return losses
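The sub-center extension is entirely contained in NormProduct: with k > 1, each class owns k weight vectors and the class cosine is the maximum over its sub-centers, which lets noisy or multi-modal classes route samples to different centers. A standalone sketch of that reduction (shapes are illustrative, not part of the commit):

    import torch
    import torch.nn.functional as F

    batch, in_features, num_classes, k = 4, 10, 5, 3
    x = F.normalize(torch.rand(batch, in_features))
    weight = F.normalize(torch.rand(num_classes * k, in_features))

    cosine_all = x @ weight.t()                       # (batch, num_classes * k)
    cosine_all = cosine_all.view(-1, num_classes, k)  # one column per sub-center
    cosine, _ = cosine_all.max(dim=2)                 # best sub-center per class
    print(cosine.shape)                               # torch.Size([4, 5])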
@@ -0,0 +1,102 @@
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
from unittest import TestCase

import numpy as np
import torch
from mmengine.runner import Runner
from torch.utils.data import DataLoader, Dataset


class ExampleDataset(Dataset):

    def __init__(self):
        self.index = 0
        self.metainfo = None

    def __getitem__(self, idx):
        results = dict(imgs=torch.rand((224, 224, 3)).float())
        return results

    def get_gt_labels(self):
        gt_labels = np.array([0, 1, 2, 4, 0, 4, 1, 2, 2, 1])
        return gt_labels

    def __len__(self):
        return 10


class TestSetAdaptiveMarginsHook(TestCase):
    DEFAULT_HOOK_CFG = dict(type='SetAdaptiveMarginsHook')
    DEFAULT_MODEL = dict(
        type='ImageClassifier',
        backbone=dict(
            type='ResNet',
            depth=34,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(type='ArcFaceClsHead', in_channels=512, num_classes=5))

    def test_before_train(self):
        default_hooks = dict(
            timer=dict(type='IterTimerHook'),
            logger=None,
            param_scheduler=dict(type='ParamSchedulerHook'),
            checkpoint=dict(type='CheckpointHook', interval=1),
            sampler_seed=dict(type='DistSamplerSeedHook'),
            visualization=dict(type='VisualizationHook', enable=False),
        )
        tmpdir = tempfile.TemporaryDirectory()
        loader = DataLoader(ExampleDataset(), batch_size=2)
        self.runner = Runner(
            model=self.DEFAULT_MODEL,
            work_dir=tmpdir.name,
            train_dataloader=loader,
            train_cfg=dict(by_epoch=True, max_epochs=1),
            log_level='WARNING',
            optim_wrapper=dict(
                optimizer=dict(type='SGD', lr=0.1, momentum=0.9)),
            param_scheduler=dict(
                type='MultiStepLR', milestones=[1, 2], gamma=0.1),
            default_scope='mmcls',
            default_hooks=default_hooks,
            experiment_name='test_construct_with_arcface',
            custom_hooks=[self.DEFAULT_HOOK_CFG])

        default_margins = torch.tensor([0.5] * 5)
        self.assertTrue(
            torch.allclose(self.runner.model.head.margins.cpu(),
                           default_margins))
        self.runner.call_hook('before_train')
        # counts = [2, 3, 3, 0, 2] -> [2, 3, 3, 1, 2], at least occur once
        # frequency**-0.25 = [0.84089642, 0.75983569, 0.75983569, 1., 0.84089642]
        # normalized = [0.33752196, 0., 0., 1., 0.33752196]
        # margins = [0.20188488, 0.05, 0.05, 0.5, 0.20188488]
        expected_margins = torch.tensor(
            [0.20188488, 0.05, 0.05, 0.5, 0.20188488])
        self.assertTrue(
            torch.allclose(self.runner.model.head.margins.cpu(),
                           expected_margins))

        model_cfg = {**self.DEFAULT_MODEL}
        model_cfg['head'] = dict(
            type='LinearClsHead',
            num_classes=1000,
            in_channels=512,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
            topk=(1, 5),
        )
        self.runner = Runner(
            model=model_cfg,
            work_dir=tmpdir.name,
            train_dataloader=loader,
            train_cfg=dict(by_epoch=True, max_epochs=1),
            log_level='WARNING',
            optim_wrapper=dict(
                optimizer=dict(type='SGD', lr=0.1, momentum=0.9)),
            param_scheduler=dict(
                type='MultiStepLR', milestones=[1, 2], gamma=0.1),
            default_scope='mmcls',
            default_hooks=default_hooks,
            experiment_name='test_construct_wo_arcface',
            custom_hooks=[self.DEFAULT_HOOK_CFG])
        with self.assertRaises(ValueError):
            self.runner.call_hook('before_train')
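Besides the hook, per-class margins can also be fixed from a plain text file, one margin per line, as exercised by test_initialize in the next file. A standalone sketch of that path (file location and values are illustrative, not part of the commit):

    import os
    import tempfile

    import torch
    from mmcls.registry import MODELS

    margins = [0.1, 0.2, 0.3, 0.4, 0.5]
    tmpdir = tempfile.mkdtemp()
    path = os.path.join(tmpdir, 'margins.txt')
    with open(path, 'w') as f:
        f.write('\n'.join(str(m) for m in margins))

    # the head parses the file into a per-class margin buffer
    head = MODELS.build(
        dict(type='ArcFaceClsHead', in_channels=10, num_classes=5,
             margins=path))
    assert torch.allclose(head.margins, torch.tensor(margins))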
@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+import os
 import random
+import tempfile
 from unittest import TestCase

 import numpy as np
@@ -486,9 +488,37 @@ class TestArcFaceClsHead(TestCase):
     DEFAULT_ARGS = dict(type='ArcFaceClsHead', in_channels=10, num_classes=5)

     def test_initialize(self):
-        with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
+        with self.assertRaises(AssertionError):
             MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})

+        with self.assertRaises(AssertionError):
+            MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 0})
+
+        # Test margins
+        with self.assertRaises(AssertionError):
+            MODELS.build({**self.DEFAULT_ARGS, 'margins': dict()})
+
+        with self.assertRaises(AssertionError):
+            MODELS.build({**self.DEFAULT_ARGS, 'margins': [0.1] * 4})
+
+        with self.assertRaises(AssertionError):
+            MODELS.build({**self.DEFAULT_ARGS, 'margins': [0.1] * 4 + ['0.1']})
+
+        arcface = MODELS.build(self.DEFAULT_ARGS)
+        self.assertTrue(
+            torch.allclose(arcface.margins, torch.tensor([0.5] * 5)))
+
+        arcface = MODELS.build({**self.DEFAULT_ARGS, 'margins': [0.1] * 5})
+        self.assertTrue(
+            torch.allclose(arcface.margins, torch.tensor([0.1] * 5)))
+
+        margins = [0.1, 0.2, 0.3, 0.4, 5]
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tmp_path = os.path.join(tmpdirname, 'margins.txt')
+            with open(tmp_path, 'w') as tmp_file:
+                for m in margins:
+                    tmp_file.write(f'{m}\n')
+            arcface = MODELS.build({**self.DEFAULT_ARGS, 'margins': tmp_path})
+        self.assertTrue(
+            torch.allclose(arcface.margins, torch.tensor(margins)))

     def test_pre_logits(self):
         head = MODELS.build(self.DEFAULT_ARGS)

@@ -497,10 +527,29 @@ class TestArcFaceClsHead(TestCase):
         pre_logits = head.pre_logits(feats)
         self.assertIs(pre_logits, feats[-1])

+        # Test with SubCenterArcFace
+        head = MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 3})
+        feats = (torch.rand(4, 10), torch.rand(4, 10))
+        pre_logits = head.pre_logits(feats)
+        self.assertIs(pre_logits, feats[-1])
+
     def test_forward(self):
         head = MODELS.build(self.DEFAULT_ARGS)
         # target is not None
         feats = (torch.rand(4, 10), torch.rand(4, 10))
         target = torch.zeros(4).long()
         outs = head(feats, target)
         self.assertEqual(outs.shape, (4, 5))

         # target is None
         feats = (torch.rand(4, 10), torch.rand(4, 10))
         outs = head(feats)
         self.assertEqual(outs.shape, (4, 5))

+        # Test with SubCenterArcFace
+        head = MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 3})
+        # target is not None
+        feats = (torch.rand(4, 10), torch.rand(4, 10))
+        target = torch.zeros(4)
+        outs = head(feats, target)
+        self.assertEqual(outs.shape, (4, 5))

@@ -519,3 +568,10 @@ class TestArcFaceClsHead(TestCase):
         losses = head.loss(feats, data_samples)
         self.assertEqual(losses.keys(), {'loss'})
         self.assertGreater(losses['loss'].item(), 0)
+
+        # Test with SubCenterArcFace
+        head = MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 3})
+        losses = head.loss(feats, data_samples)
+        self.assertEqual(losses.keys(), {'loss'})
+        self.assertGreater(losses['loss'].item(), 0)
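For completeness, the loss path exercised above can be driven standalone; a sketch assuming mmcls's ClsDataSample API (set_gt_label, as used elsewhere in this test file), not part of the commit:

    import torch
    from mmcls.registry import MODELS
    from mmcls.structures import ClsDataSample

    head = MODELS.build(
        dict(type='ArcFaceClsHead', in_channels=10, num_classes=5))
    feats = (torch.rand(4, 10), )
    # one index-format ground-truth label per sample
    data_samples = [ClsDataSample().set_gt_label(i % 5) for i in range(4)]
    losses = head.loss(feats, data_samples)
    print(losses['loss'])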