# mmclassification/tests/test_engine/test_hooks/test_arcface_hooks.py
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
from unittest import TestCase

import numpy as np
import torch
from mmengine.runner import Runner
from torch.utils.data import DataLoader, Dataset


class ExampleDataset(Dataset):

    def __init__(self):
        self.index = 0
        self.metainfo = None

    def __getitem__(self, idx):
        results = dict(imgs=torch.rand((224, 224, 3)).float())
        return results

    def get_gt_labels(self):
        gt_labels = np.array([0, 1, 2, 4, 0, 4, 1, 2, 2, 1])
        return gt_labels

    def __len__(self):
        return 10


class TestSetAdaptiveMarginsHook(TestCase):
    DEFAULT_HOOK_CFG = dict(type='SetAdaptiveMarginsHook')
    DEFAULT_MODEL = dict(
        type='ImageClassifier',
        backbone=dict(
            type='ResNet',
            depth=34,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(type='ArcFaceClsHead', in_channels=512, num_classes=5))

    def test_before_train(self):
        default_hooks = dict(
            timer=dict(type='IterTimerHook'),
            logger=None,
            param_scheduler=dict(type='ParamSchedulerHook'),
            checkpoint=dict(type='CheckpointHook', interval=1),
            sampler_seed=dict(type='DistSamplerSeedHook'),
            visualization=dict(type='VisualizationHook', enable=False),
        )
        tmpdir = tempfile.TemporaryDirectory()
        loader = DataLoader(ExampleDataset(), batch_size=2)
        self.runner = Runner(
            model=self.DEFAULT_MODEL,
            work_dir=tmpdir.name,
            train_dataloader=loader,
            train_cfg=dict(by_epoch=True, max_epochs=1),
            log_level='WARNING',
            optim_wrapper=dict(
                optimizer=dict(type='SGD', lr=0.1, momentum=0.9)),
            param_scheduler=dict(
                type='MultiStepLR', milestones=[1, 2], gamma=0.1),
            default_scope='mmcls',
            default_hooks=default_hooks,
            experiment_name='test_construct_with_arcface',
            custom_hooks=[self.DEFAULT_HOOK_CFG])

        # Before ``before_train`` runs, the head uses the default margin for
        # every class.
        default_margins = torch.tensor([0.5] * 5)
        self.assertTrue(
            torch.allclose(self.runner.model.head.margins.cpu(),
                           default_margins))

        self.runner.call_hook('before_train')
        # counts = [2, 3, 3, 0, 2] -> [2, 3, 3, 1, 2] (each class counted at least once)
        # frequency**-0.25 = [0.84089642, 0.75983569, 0.75983569, 1., 0.84089642]
        # normalized = [0.33752196, 0., 0., 1., 0.33752196]
        # margins = [0.20188488, 0.05, 0.05, 0.5, 0.20188488]
        # (see the sketch at the end of this file for this arithmetic)
        expected_margins = torch.tensor(
            [0.20188488, 0.05, 0.05, 0.5, 0.20188488])
        self.assertTrue(
            torch.allclose(self.runner.model.head.margins.cpu(),
                           expected_margins))

        # A head without margins (e.g. ``LinearClsHead``) should make the
        # hook fail.
        model_cfg = {**self.DEFAULT_MODEL}
        model_cfg['head'] = dict(
            type='LinearClsHead',
            num_classes=1000,
            in_channels=512,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
            topk=(1, 5),
        )
        self.runner = Runner(
            model=model_cfg,
            work_dir=tmpdir.name,
            train_dataloader=loader,
            train_cfg=dict(by_epoch=True, max_epochs=1),
            log_level='WARNING',
            optim_wrapper=dict(
                optimizer=dict(type='SGD', lr=0.1, momentum=0.9)),
            param_scheduler=dict(
                type='MultiStepLR', milestones=[1, 2], gamma=0.1),
            default_scope='mmcls',
            default_hooks=default_hooks,
            experiment_name='test_construct_wo_arcface',
            custom_hooks=[self.DEFAULT_HOOK_CFG])
        with self.assertRaises(ValueError):
            self.runner.call_hook('before_train')

        tmpdir.cleanup()
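

# The expected margins above follow the inline comment in ``test_before_train``:
# per-class counts are clipped to at least one, raised to the power -0.25,
# min-max normalized, and rescaled to the [0.05, 0.5] margin range. The helper
# below is only a minimal sketch of that arithmetic for reference; its name and
# the 0.05/0.5 bounds are assumptions read off the expected values, not part of
# the hook's API.
def _expected_adaptive_margins(gt_labels, num_classes, max_margin=0.5,
                               min_margin=0.05, power=-0.25):
    counts = np.bincount(gt_labels, minlength=num_classes).clip(min=1)
    freq = counts.astype(np.float64)**power
    normalized = (freq - freq.min()) / (freq.max() - freq.min())
    return normalized * (max_margin - min_margin) + min_margin


# For example, ``_expected_adaptive_margins(ExampleDataset().get_gt_labels(), 5)``
# reproduces [0.20188488, 0.05, 0.05, 0.5, 0.20188488].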