[Enhance] Enhance ArcFaceClsHead. (#1181)

* update arcface

* fix unit tests

* add adv-margins

add adv-margins

update arcface

* rebase

* update doc and fix ut

* rebase

* update code

* rebase

* use label data

* update set-margins

* Modify Arcface related method names.

Co-authored-by: mzr1996 <mzr1996@163.com>
Ezra-Yu 2022-11-21 18:10:39 +08:00 committed by GitHub
parent 4fb44f8770
commit b0007812d6
10 changed files with 535 additions and 185 deletions


@@ -31,7 +31,8 @@ Hooks
ClassNumCheckHook
PreciseBNHook
VisualizationHook
SwitchRecipeHook
PrepareProtoBeforeValLoopHook
SetAdaptiveMarginsHook
.. module:: mmcls.engine.optimizers


@@ -140,6 +140,7 @@ Heads
EfficientFormerClsHead
DeiTClsHead
ConformerHead
ArcFaceClsHead
MultiLabelClsHead
MultiLabelLinearClsHead
CSRAClsHead

mmcls/engine/hooks/__init__.py

@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .class_num_check_hook import ClassNumCheckHook
from .margin_head_hooks import SetAdaptiveMarginsHook
from .precise_bn_hook import PreciseBNHook
from .retriever_hooks import PrepareProtoBeforeValLoopHook
from .switch_recipe_hook import SwitchRecipeHook
@@ -7,5 +8,6 @@ from .visualization_hook import VisualizationHook
__all__ = [
'ClassNumCheckHook', 'PreciseBNHook', 'VisualizationHook',
-    'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook'
+    'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook',
+    'SetAdaptiveMarginsHook'
]

mmcls/engine/hooks/margin_head_hooks.py

@@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmcls.models.heads import ArcFaceClsHead
from mmcls.registry import HOOKS
@HOOKS.register_module()
class SetAdaptiveMarginsHook(Hook):
r"""Set adaptive-margins in ArcFaceClsHead based on the power of
category-wise count.
A PyTorch implementation of paper `Google Landmark Recognition 2020
Competition Third Place Solution <https://arxiv.org/abs/2010.05350>`_.
    The margins will be
    :math:`\text{f}(n) = (\text{margin}_{max} - \text{margin}_{min})
    \cdot \text{norm}(n^p) + \text{margin}_{min}`.
    The `n` indicates the number of occurrences of a category.
Args:
margin_min (float): Lower bound of margins. Defaults to 0.05.
margin_max (float): Upper bound of margins. Defaults to 0.5.
        power (float): The power of the category frequency. Defaults to -0.25.
"""
def __init__(self, margin_min=0.05, margin_max=0.5, power=-0.25) -> None:
self.margin_min = margin_min
self.margin_max = margin_max
self.margin_range = margin_max - margin_min
self.p = power
def before_train(self, runner):
"""change the margins in ArcFaceClsHead.
Args:
runner (obj: `Runner`): Runner.
"""
model = runner.model
if is_model_wrapper(model):
model = model.module
        if (hasattr(model, 'head')
                and not isinstance(model.head, ArcFaceClsHead)):
            raise ValueError(
                'Hook ``SetAdaptiveMarginsHook`` can only be used with '
                f'``ArcFaceClsHead``, but got {type(model.head)}')
        # Generate margins based on the dataset statistics.
        gt_labels = runner.train_dataloader.dataset.get_gt_labels()
        label_count = np.bincount(gt_labels)
        label_count[label_count == 0] = 1  # At least one occurrence
        pow_freq = np.power(label_count, self.p)
        min_f, max_f = pow_freq.min(), pow_freq.max()
        normalized_pow_freq = (pow_freq - min_f) / (max_f - min_f)
        margins = normalized_pow_freq * self.margin_range + self.margin_min
        assert len(margins) == model.head.num_classes
        model.head.set_margins(margins)
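
For reference, the margin computation above can be reproduced standalone. A minimal numpy sketch, using the toy label distribution from the unit test below and the hook's default hyper-parameters:

import numpy as np

# Toy label distribution: 5 classes, class 3 never occurs.
gt_labels = np.array([0, 1, 2, 4, 0, 4, 1, 2, 2, 1])
margin_min, margin_max, power = 0.05, 0.5, -0.25   # hook defaults

label_count = np.bincount(gt_labels, minlength=5)  # [2, 3, 3, 0, 2]
label_count[label_count == 0] = 1                  # at least one occurrence
pow_freq = np.power(label_count, power)            # rarer class -> larger value
norm = (pow_freq - pow_freq.min()) / (pow_freq.max() - pow_freq.min())
margins = norm * (margin_max - margin_min) + margin_min
print(margins)  # ~[0.2019, 0.05, 0.05, 0.5, 0.2019]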

mmcls/models/backbones/hornet.py

@@ -250,13 +250,16 @@ class HorNetBlock(nn.Module):
@MODELS.register_module()
class HorNet(BaseBackbone):
"""HorNet
A PyTorch impl of : `HorNet: Efficient High-Order Spatial Interactions
with Recursive Gated Convolutions`
Inspiration from
https://github.com/raoyongming/HorNet
"""HorNet.
A PyTorch implementation of paper `HorNet: Efficient High-Order Spatial
Interactions with Recursive Gated Convolutions
<https://arxiv.org/abs/2207.14284>`_ .
Inspiration from https://github.com/raoyongming/HorNet
Args:
arch (str | dict): HorNet architecture.
If use string, choose from 'tiny', 'small', 'base' and 'large'.
If use dict, it should have below keys:
- **base_dim** (int): The base dimensions of embedding.
@@ -264,6 +267,7 @@ class HorNet(BaseBackbone):
- **orders** (List[int]): The number of order of gnConv in each
stage.
- **dw_cfg** (List[dict]): The Config for dw conv.
Defaults to 'tiny'.
in_channels (int): Number of input image channels. Defaults to 3.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.

mmcls/models/heads/__init__.py

@@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from .arcface_head import ArcFaceClsHead
from .cls_head import ClsHead
from .conformer_head import ConformerHead
from .deit_head import DeiTClsHead
from .efficientformer_head import EfficientFormerClsHead
from .linear_head import LinearClsHead
+from .margin_head import ArcFaceClsHead
from .multi_label_cls_head import MultiLabelClsHead
from .multi_label_csra_head import CSRAClsHead
from .multi_label_linear_head import MultiLabelLinearClsHead

mmcls/models/heads/arcface_head.py

@@ -1,176 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
from .cls_head import ClsHead
class NormLinear(nn.Linear):
"""An enhanced linear layer, which could normalize the input and the linear
weight.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample
bias (bool): Whether there is bias. If set to ``False``, the
layer will not learn an additive bias. Defaults to ``True``.
feature_norm (bool): Whether to normalize the input feature.
Defaults to ``True``.
weight_norm (bool):Whether to normalize the weight.
Defaults to ``True``.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = False,
feature_norm: bool = True,
weight_norm: bool = True):
super().__init__(in_features, out_features, bias=bias)
self.weight_norm = weight_norm
self.feature_norm = feature_norm
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.feature_norm:
input = F.normalize(input)
if self.weight_norm:
weight = F.normalize(self.weight)
else:
weight = self.weight
return F.linear(input, weight, self.bias)
@MODELS.register_module()
class ArcFaceClsHead(ClsHead):
"""ArcFace classifier head.
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
s (float): Norm of input feature. Defaults to 30.0.
m (float): Margin. Defaults to 0.5.
easy_margin (bool): Avoid theta + m >= PI. Defaults to False.
ls_eps (float): Label smoothing. Defaults to 0.
bias (bool): Whether to use bias in norm layer. Defaults to False.
loss (dict): Config of classification loss. Defaults to
``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
init_cfg (dict, optional): the config to control the initialization.
Defaults to None.
"""
def __init__(self,
num_classes: int,
in_channels: int,
s: float = 30.0,
m: float = 0.50,
easy_margin: bool = False,
ls_eps: float = 0.0,
bias: bool = False,
loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0),
init_cfg: Optional[dict] = None):
super(ArcFaceClsHead, self).__init__(init_cfg=init_cfg)
self.loss_module = MODELS.build(loss)
self.in_channels = in_channels
self.num_classes = num_classes
if self.num_classes <= 0:
raise ValueError(
f'num_classes={num_classes} must be a positive integer')
self.s = s
self.m = m
self.ls_eps = ls_eps
self.norm_linear = NormLinear(in_channels, num_classes, bias=bias)
self.easy_margin = easy_margin
self.th = math.cos(math.pi - m)
self.mm = math.sin(math.pi - m) * m
def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
"""The process before the final classification head.
The input ``feats`` is a tuple of tensor, and each tensor is the
feature of a backbone stage. In ``ArcFaceHead``, we just obtain the
feature of the last stage.
"""
# The ArcFaceHead doesn't have other module, just return after
# unpacking.
return feats[-1]
def forward(self,
feats: Tuple[torch.Tensor],
target: Optional[torch.Tensor] = None) -> torch.Tensor:
"""The forward process."""
pre_logits = self.pre_logits(feats)
# cos=(a*b)/(||a||*||b||)
cosine = self.norm_linear(pre_logits)
if target is None:
return self.s * cosine
phi = torch.cos(torch.acos(cosine) + self.m)
if self.easy_margin:
# when cosine>0, choose phi
# when cosine<=0, choose cosine
phi = torch.where(cosine > 0, phi, cosine)
else:
# when cos>th, choose phi
# when cos<=th, choose cosine-mm
phi = torch.where(cosine > self.th, phi, cosine - self.mm)
one_hot = torch.zeros(cosine.size(), device=pre_logits.device)
one_hot.scatter_(1, target.view(-1, 1).long(), 1)
if self.ls_eps > 0:
one_hot = (1 -
self.ls_eps) * one_hot + self.ls_eps / self.num_classes
output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
return output * self.s
def loss(self, feats: Tuple[torch.Tensor],
data_samples: List[ClsDataSample], **kwargs) -> dict:
"""Calculate losses from the classification score.
Args:
feats (tuple[Tensor]): The features extracted from the backbone.
Multiple stage inputs are acceptable but only the last stage
will be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
data_samples (List[ClsDataSample]): The annotation data of
every samples.
**kwargs: Other keyword arguments to forward the loss module.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
if 'score' in data_samples[0].gt_label:
# Batch augmentation may convert labels to one-hot format scores.
target = torch.stack([i.gt_label.score for i in data_samples])
else:
target = torch.cat([i.gt_label.label for i in data_samples])
# The part can be traced by torch.fx
cls_score = self(feats, target)
# compute loss
losses = dict()
loss = self.loss_module(
cls_score, target, avg_factor=cls_score.size(0), **kwargs)
losses['loss'] = loss
return losses

mmcls/models/heads/margin_head.py

@@ -0,0 +1,299 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.fileio import list_from_file
from mmengine.runner import autocast
from mmengine.utils import is_seq_of
from mmcls.models.losses import convert_to_one_hot
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
from .cls_head import ClsHead
class NormProduct(nn.Linear):
"""An enhanced linear layer with k clustering centers to calculate product
between normalized input and linear weight.
Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        k (int): The number of clustering centers. Defaults to 1.
        bias (bool): Whether there is bias. If set to ``False``, the
            layer will not learn an additive bias. Defaults to ``False``.
        feature_norm (bool): Whether to normalize the input feature.
            Defaults to ``True``.
        weight_norm (bool): Whether to normalize the weight.
            Defaults to ``True``.
"""
def __init__(self,
in_features: int,
out_features: int,
k=1,
bias: bool = False,
feature_norm: bool = True,
weight_norm: bool = True):
super().__init__(in_features, out_features * k, bias=bias)
self.weight_norm = weight_norm
self.feature_norm = feature_norm
self.out_features = out_features
self.k = k
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.feature_norm:
input = F.normalize(input)
if self.weight_norm:
weight = F.normalize(self.weight)
else:
weight = self.weight
cosine_all = F.linear(input, weight, self.bias)
if self.k == 1:
return cosine_all
else:
cosine_all = cosine_all.view(-1, self.out_features, self.k)
cosine, _ = torch.max(cosine_all, dim=2)
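            # e.g. with out_features=5 and k=3 (illustrative numbers), the
            # linear layer yields 15 cosines per sample; viewed as (N, 5, 3),
            # the max over dim=2 keeps the best-matching sub-center for each
            # class, as in Sub-center ArcFace.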
return cosine
@MODELS.register_module()
class ArcFaceClsHead(ClsHead):
"""ArcFace classifier head.
A PyTorch implementation of paper `ArcFace: Additive Angular Margin Loss
for Deep Face Recognition <https://arxiv.org/abs/1801.07698>`_ and
`Sub-center ArcFace: Boosting Face Recognition by Large-Scale Noisy Web
Faces <https://link.springer.com/chapter/10.1007/978-3-030-58621-8_43>`_
Example:
        To use ArcFace in config files:
1. use vanilla ArcFace
.. code:: python
            model = dict(
backbone = xxx,
neck = xxxx,
head=dict(
type='ArcFaceClsHead',
num_classes=5000,
in_channels=1024,
loss = dict(type='CrossEntropyLoss', loss_weight=1.0),
init_cfg=None),
)
2. use SubCenterArcFace with 3 sub-centers
.. code:: python
            model = dict(
backbone = xxx,
neck = xxxx,
head=dict(
type='ArcFaceClsHead',
num_classes=5000,
in_channels=1024,
num_subcenters=3,
loss = dict(type='CrossEntropyLoss', loss_weight=1.0),
init_cfg=None),
)
        3. use SubCenterArcFace with CountPowerAdaptiveMargins
.. code:: python
            model = dict(
backbone = xxx,
neck = xxxx,
head=dict(
type='ArcFaceClsHead',
num_classes=5000,
in_channels=1024,
num_subcenters=3,
loss = dict(type='CrossEntropyLoss', loss_weight=1.0),
init_cfg=None),
)
custom_hooks = [dict(type='SetAdaptiveMarginsHook')]
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
num_subcenters (int): Number of subcenters. Defaults to 1.
scale (float): Scale factor of output logit. Defaults to 64.0.
        margins (float | Sequence[float] | str): The penalty margin. Could
            be one of the following formats:
            - float: The margin, which would be the same for all categories.
            - Sequence[float]: The category-based margins list.
            - str: A '.txt' file path which contains a list. Each line
              represents the margin of a category, and the number in the
              i-th row indicates the margin of the i-th class.
            Defaults to 0.5.
easy_margin (bool): Avoid theta + m >= PI. Defaults to False.
loss (dict): Config of classification loss. Defaults to
``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
init_cfg (dict, optional): the config to control the initialization.
Defaults to None.
"""
def __init__(self,
num_classes: int,
in_channels: int,
num_subcenters: int = 1,
scale: float = 64.,
margins: Optional[Union[float, Sequence[float], str]] = 0.50,
easy_margin: bool = False,
loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0),
init_cfg: Optional[dict] = None):
super(ArcFaceClsHead, self).__init__(init_cfg=init_cfg)
self.loss_module = MODELS.build(loss)
assert num_subcenters >= 1 and num_classes >= 0
self.in_channels = in_channels
self.num_classes = num_classes
self.num_subcenters = num_subcenters
self.scale = scale
self.easy_margin = easy_margin
self.norm_product = NormProduct(in_channels, num_classes,
num_subcenters)
if isinstance(margins, float):
margins = [margins] * num_classes
elif isinstance(margins, str) and margins.endswith('.txt'):
margins = [float(item) for item in list_from_file(margins)]
else:
            assert is_seq_of(list(margins), (float, int)), (
                'the attribute `margins` in ``ArcFaceClsHead`` should be a '
                'float, a Sequence of float, or a ".txt" file path.')
        assert len(margins) == num_classes, \
            'The length of margins must be equal to num_classes.'
self.register_buffer(
'margins', torch.tensor(margins).float(), persistent=False)
        # To keep `phi` monotonically decreasing, refer to
        # https://github.com/deepinsight/insightface/issues/108
sinm_m = torch.sin(math.pi - self.margins) * self.margins
threshold = torch.cos(math.pi - self.margins)
self.register_buffer('sinm_m', sinm_m, persistent=False)
self.register_buffer('threshold', threshold, persistent=False)
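        # When theta + m exceeds pi, cos(theta + m) starts increasing again,
        # so for cos(theta) <= cos(pi - m) the logit falls back to the
        # monotonic linear penalty cos(theta) - m * sin(m).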
def set_margins(self, margins: Union[Sequence[float], float]) -> None:
"""set margins of arcface head.
Args:
margins (Union[Sequence[float], float]): The marigins.
"""
if isinstance(margins, float):
margins = [margins] * self.num_classes
        assert is_seq_of(
            list(margins), float) and (len(margins) == self.num_classes), (
                'margins must be a Sequence of float with length '
                f'num_classes, but got {margins}')
self.margins = torch.tensor(
margins, device=self.margins.device, dtype=torch.float32)
self.sinm_m = torch.sin(self.margins) * self.margins
self.threshold = -torch.cos(self.margins)
def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
"""The process before the final classification head.
The input ``feats`` is a tuple of tensor, and each tensor is the
        feature of a backbone stage. In ``ArcFaceClsHead``, we just obtain
        the feature of the last stage.
        """
        # The ArcFaceClsHead doesn't have other modules, just return after
        # unpacking.
return feats[-1]
def _get_logit_with_margin(self, pre_logits, target):
"""add arc margin to the cosine in target index.
The target must be in index format.
"""
assert target.dim() == 1 or (
target.dim() == 2 and target.shape[1] == 1), \
'The target must be in index format.'
cosine = self.norm_product(pre_logits)
phi = torch.cos(torch.acos(cosine) + self.margins)
if self.easy_margin:
# when cosine>0, choose phi
# when cosine<=0, choose cosine
phi = torch.where(cosine > 0, phi, cosine)
else:
            # when cos > threshold, choose phi
            # when cos <= threshold, choose cosine - sinm_m
            phi = torch.where(cosine > self.threshold, phi,
                              cosine - self.sinm_m)
target = convert_to_one_hot(target, self.num_classes)
output = target * phi + (1 - target) * cosine
return output
def forward(self,
feats: Tuple[torch.Tensor],
target: Optional[torch.Tensor] = None) -> torch.Tensor:
"""The forward process."""
# Disable AMP
with autocast(enabled=False):
pre_logits = self.pre_logits(feats)
if target is None:
# when eval, logit is the cosine between W and pre_logits;
# cos(theta_yj) = (x/||x||) * (W/||W||)
logit = self.norm_product(pre_logits)
else:
                # when training, add a margin to the angle at the target
                # class; the logit is then the cosine of the new angle
logit = self._get_logit_with_margin(pre_logits, target)
return self.scale * logit
def loss(self, feats: Tuple[torch.Tensor],
data_samples: List[ClsDataSample], **kwargs) -> dict:
"""Calculate losses from the classification score.
Args:
feats (tuple[Tensor]): The features extracted from the backbone.
Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, in_channels)``.
            data_samples (List[ClsDataSample]): The annotation data of
                every sample.
**kwargs: Other keyword arguments to forward the loss module.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
# Unpack data samples and pack targets
label_target = torch.cat([i.gt_label.label for i in data_samples])
if 'score' in data_samples[0].gt_label:
# Batch augmentation may convert labels to one-hot format scores.
target = torch.stack([i.gt_label.score for i in data_samples])
        else:
            # otherwise, use the index-format labels directly.
            target = label_target
        # The index-format target is required to place the margins.
        cls_score = self(feats, label_target)
# compute loss
losses = dict()
loss = self.loss_module(
cls_score, target, avg_factor=cls_score.size(0), **kwargs)
losses['loss'] = loss
return losses
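
As a quick sanity check, the new head can also be exercised standalone. A minimal sketch mirroring the unit tests (the feature dimension and class count are arbitrary):

import torch
from mmcls.models.heads import ArcFaceClsHead

head = ArcFaceClsHead(num_classes=5, in_channels=10, num_subcenters=3)
feats = (torch.rand(4, 10), )       # tuple of backbone-stage features
target = torch.zeros(4).long()      # index-format labels

train_logits = head(feats, target)  # margins applied at the target classes
assert train_logits.shape == (4, 5)
eval_logits = head(feats)           # plain scaled cosines, no margin
head.set_margins([0.1, 0.2, 0.3, 0.4, 0.5])  # adjust per-class margins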


@@ -0,0 +1,102 @@
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
from unittest import TestCase
import numpy as np
import torch
from mmengine.runner import Runner
from torch.utils.data import DataLoader, Dataset
class ExampleDataset(Dataset):
def __init__(self):
self.index = 0
self.metainfo = None
def __getitem__(self, idx):
results = dict(imgs=torch.rand((224, 224, 3)).float(), )
return results
def get_gt_labels(self):
gt_labels = np.array([0, 1, 2, 4, 0, 4, 1, 2, 2, 1])
return gt_labels
def __len__(self):
return 10
class TestSetAdaptiveMarginsHook(TestCase):
DEFAULT_HOOK_CFG = dict(type='SetAdaptiveMarginsHook')
DEFAULT_MODEL = dict(
type='ImageClassifier',
backbone=dict(
type='ResNet',
depth=34,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(type='ArcFaceClsHead', in_channels=512, num_classes=5))
def test_before_train(self):
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=None,
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', interval=1),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='VisualizationHook', enable=False),
)
tmpdir = tempfile.TemporaryDirectory()
loader = DataLoader(ExampleDataset(), batch_size=2)
self.runner = Runner(
model=self.DEFAULT_MODEL,
work_dir=tmpdir.name,
train_dataloader=loader,
train_cfg=dict(by_epoch=True, max_epochs=1),
log_level='WARNING',
optim_wrapper=dict(
optimizer=dict(type='SGD', lr=0.1, momentum=0.9)),
param_scheduler=dict(
type='MultiStepLR', milestones=[1, 2], gamma=0.1),
default_scope='mmcls',
default_hooks=default_hooks,
experiment_name='test_construct_with_arcface',
custom_hooks=[self.DEFAULT_HOOK_CFG])
        default_margins = torch.tensor([0.5] * 5)
        self.assertTrue(
            torch.allclose(self.runner.model.head.margins.cpu(),
                           default_margins))
        self.runner.call_hook('before_train')
        # counts = [2, 3, 3, 0, 2] -> [2, 3, 3, 1, 2] (at least one occurrence)
        # frequency**-0.25 = [0.84089642, 0.75983569, 0.75983569, 1., 0.84089642]
        # normalized = [0.33752196, 0., 0., 1., 0.33752196]
        # margins = [0.20188488, 0.05, 0.05, 0.5, 0.20188488]
        expected_margins = torch.tensor(
            [0.20188488, 0.05, 0.05, 0.5, 0.20188488])
        self.assertTrue(
            torch.allclose(self.runner.model.head.margins.cpu(),
                           expected_margins))
model_cfg = {**self.DEFAULT_MODEL}
model_cfg['head'] = dict(
type='LinearClsHead',
num_classes=1000,
in_channels=512,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
)
self.runner = Runner(
model=model_cfg,
work_dir=tmpdir.name,
train_dataloader=loader,
train_cfg=dict(by_epoch=True, max_epochs=1),
log_level='WARNING',
optim_wrapper=dict(
optimizer=dict(type='SGD', lr=0.1, momentum=0.9)),
param_scheduler=dict(
type='MultiStepLR', milestones=[1, 2], gamma=0.1),
default_scope='mmcls',
default_hooks=default_hooks,
experiment_name='test_construct_wo_arcface',
custom_hooks=[self.DEFAULT_HOOK_CFG])
with self.assertRaises(ValueError):
self.runner.call_hook('before_train')

tests/test_models/test_heads.py

@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import random
import tempfile
from unittest import TestCase
import numpy as np
@@ -486,9 +488,37 @@ class TestArcFaceClsHead(TestCase):
DEFAULT_ARGS = dict(type='ArcFaceClsHead', in_channels=10, num_classes=5)
def test_initialize(self):
-        with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
+        with self.assertRaises(AssertionError):
MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})
with self.assertRaises(AssertionError):
MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 0})
# Test margins
with self.assertRaises(AssertionError):
MODELS.build({**self.DEFAULT_ARGS, 'margins': dict()})
with self.assertRaises(AssertionError):
MODELS.build({**self.DEFAULT_ARGS, 'margins': [0.1] * 4})
with self.assertRaises(AssertionError):
MODELS.build({**self.DEFAULT_ARGS, 'margins': [0.1] * 4 + ['0.1']})
        arcface = MODELS.build(self.DEFAULT_ARGS)
        self.assertTrue(
            torch.allclose(arcface.margins, torch.tensor([0.5] * 5)))
        arcface = MODELS.build({**self.DEFAULT_ARGS, 'margins': [0.1] * 5})
        self.assertTrue(
            torch.allclose(arcface.margins, torch.tensor([0.1] * 5)))
margins = [0.1, 0.2, 0.3, 0.4, 5]
with tempfile.TemporaryDirectory() as tmpdirname:
tmp_path = os.path.join(tmpdirname, 'margins.txt')
with open(tmp_path, 'w') as tmp_file:
for m in margins:
tmp_file.write(f'{m}\n')
            arcface = MODELS.build({**self.DEFAULT_ARGS, 'margins': tmp_path})
            self.assertTrue(
                torch.allclose(arcface.margins, torch.tensor(margins)))
def test_pre_logits(self):
head = MODELS.build(self.DEFAULT_ARGS)
@@ -497,10 +527,29 @@ class TestArcFaceClsHead(TestCase):
pre_logits = head.pre_logits(feats)
self.assertIs(pre_logits, feats[-1])
# Test with SubCenterArcFace
head = MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 3})
feats = (torch.rand(4, 10), torch.rand(4, 10))
pre_logits = head.pre_logits(feats)
self.assertIs(pre_logits, feats[-1])
def test_forward(self):
head = MODELS.build(self.DEFAULT_ARGS)
# target is not None
feats = (torch.rand(4, 10), torch.rand(4, 10))
target = torch.zeros(4).long()
outs = head(feats, target)
self.assertEqual(outs.shape, (4, 5))
# target is None
feats = (torch.rand(4, 10), torch.rand(4, 10))
outs = head(feats)
self.assertEqual(outs.shape, (4, 5))
# Test with SubCenterArcFace
head = MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 3})
# target is not None
feats = (torch.rand(4, 10), torch.rand(4, 10))
target = torch.zeros(4)
outs = head(feats, target)
self.assertEqual(outs.shape, (4, 5))
@@ -519,3 +568,10 @@ class TestArcFaceClsHead(TestCase):
losses = head.loss(feats, data_samples)
self.assertEqual(losses.keys(), {'loss'})
self.assertGreater(losses['loss'].item(), 0)
# Test with SubCenterArcFace
head = MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 3})
losses = head.loss(feats, data_samples)
self.assertEqual(losses.keys(), {'loss'})
self.assertGreater(losses['loss'].item(), 0)