62 lines
2.2 KiB
Python
62 lines
2.2 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved
|
|
import numpy as np
|
|
from mmengine.hooks import Hook
|
|
from mmengine.model import is_model_wrapper
|
|
|
|
from mmpretrain.models.heads import ArcFaceClsHead
|
|
from mmpretrain.registry import HOOKS
|
|
|
|
|
|
@HOOKS.register_module()
class SetAdaptiveMarginsHook(Hook):
    r"""Set adaptive margins in ``ArcFaceClsHead`` based on the power of the
    category-wise sample count.

    Implements the frequency-based adaptive margins of `Google Landmark
    Recognition 2020 Competition Third Place Solution
    <https://arxiv.org/abs/2010.05350>`_. The margin of a category that
    occurs :math:`n` times in the training set is

    .. math::

        f(n) = (m_{max} - m_{min}) \cdot \mathrm{norm}(n^p) + m_{min},

    where :math:`\mathrm{norm}` min-max normalizes :math:`n^p` over all
    categories to :math:`[0, 1]`. With a negative ``power``, rarer
    categories receive larger margins than frequent ones.

    Args:
        margin_min (float): Lower bound of margins. Defaults to 0.05.
        margin_max (float): Upper bound of margins. Defaults to 0.5.
        power (float): The power applied to the category frequency.
            Defaults to -0.25.
    """

    def __init__(self, margin_min=0.05, margin_max=0.5, power=-0.25) -> None:
        self.margin_min = margin_min
        self.margin_max = margin_max
        self.margin_range = margin_max - margin_min
        self.p = power

    def before_train(self, runner) -> None:
        """Compute per-category margins from the training set and set them
        on the model's ``ArcFaceClsHead``.

        Args:
            runner (:obj:`Runner`): The runner of the training process.

        Raises:
            ValueError: If the (unwrapped) model has no ``head`` or its head
                is not an ``ArcFaceClsHead``, or if the training labels do
                not fit in ``head.num_classes`` categories.
        """
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        # A model without ``head`` is rejected here as well, instead of
        # failing later with an AttributeError.
        head = getattr(model, 'head', None)
        if not isinstance(head, ArcFaceClsHead):
            raise ValueError(
                'Hook ``SetAdaptiveMarginsHook`` could only be used '
                f'for ``ArcFaceClsHead``, but get {type(head)}')

        # Generate margins based on the dataset. ``head`` is taken from the
        # unwrapped model so this also works under model wrappers.
        gt_labels = runner.train_dataloader.dataset.get_gt_labels()
        margins = self._compute_margins(gt_labels, head.num_classes)
        head.set_margins(margins)

    def _compute_margins(self, gt_labels, num_classes):
        """Map per-category sample counts to margins in
        ``[margin_min, margin_max]``.

        Args:
            gt_labels (array-like): Non-negative integer ground-truth labels
                of the training samples.
            num_classes (int): Number of categories of the head.

        Returns:
            np.ndarray: Float margins, one per category.

        Raises:
            ValueError: If a label is ``>= num_classes``.
        """
        # ``minlength`` keeps trailing categories that never occur in the
        # dataset, so the counts always cover every class of the head.
        label_count = np.bincount(gt_labels, minlength=num_classes)
        if len(label_count) != num_classes:
            raise ValueError(
                f'The training dataset contains label {len(label_count) - 1}'
                f', but the head only has {num_classes} classes.')

        # Count unseen categories as one occurrence so a negative power stays
        # finite; use floats so integer negative powers are allowed too.
        label_count = np.maximum(label_count, 1).astype(np.float64)
        pow_freq = np.power(label_count, self.p)

        min_f, max_f = pow_freq.min(), pow_freq.max()
        span = max_f - min_f
        if span > 0:
            normalized_pow_freq = (pow_freq - min_f) / span
        else:
            # All categories share the same count (or there is only one):
            # min-max normalization is undefined, so every category falls
            # back to ``margin_min`` instead of receiving NaN.
            normalized_pow_freq = np.zeros_like(pow_freq)

        return normalized_pow_freq * self.margin_range + self.margin_min
|