103 lines
3.7 KiB
Python
103 lines
3.7 KiB
Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
import tempfile
|
||
|
from unittest import TestCase
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
from mmengine.runner import Runner
|
||
|
from torch.utils.data import DataLoader, Dataset
|
||
|
|
||
|
|
||
|
class ExampleDataset(Dataset):
    """Tiny stand-in dataset: 10 random images plus a fixed label list.

    Class 3 never occurs in the labels returned by :meth:`get_gt_labels`;
    the adaptive-margin test below relies on that zero-count class.
    """

    # Ground-truth labels; their length is also the dataset length.
    _GT_LABELS = (0, 1, 2, 4, 0, 4, 1, 2, 2, 1)

    def __init__(self):
        self.metainfo = None
        # Not read anywhere in this test module.
        self.index = 0

    def __getitem__(self, idx):
        """Return a dict holding one freshly sampled 224x224x3 float image.

        ``idx`` is ignored; every access draws a new random tensor.
        """
        img = torch.rand((224, 224, 3)).float()
        return {'imgs': img}

    def get_gt_labels(self):
        """Return the ground-truth labels as a new numpy integer array."""
        return np.array(self._GT_LABELS)

    def __len__(self):
        return len(self._GT_LABELS)
|
||
|
|
||
|
|
||
|
class TestSetAdaptiveMarginsHook(TestCase):
    """Tests for ``SetAdaptiveMarginsHook``.

    Checks that ``before_train`` rewrites the per-class margins of an
    ``ArcFaceClsHead`` from the training labels, and that it raises
    ``ValueError`` when the model's head is not an ``ArcFaceClsHead``.
    """

    DEFAULT_HOOK_CFG = dict(type='SetAdaptiveMarginsHook')
    DEFAULT_MODEL = dict(
        type='ImageClassifier',
        backbone=dict(
            type='ResNet',
            depth=34,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(type='ArcFaceClsHead', in_channels=512, num_classes=5))

    def _build_runner(self, model, work_dir, loader, experiment_name):
        """Build a one-epoch ``Runner`` whose only custom hook is the one under test.

        Args:
            model (dict): Model config passed to ``Runner``.
            work_dir (str): Working directory for logs/checkpoints.
            loader (DataLoader): Training dataloader.
            experiment_name (str): Unique experiment name for this runner.

        Returns:
            Runner: The constructed runner (``before_train`` not yet called).
        """
        # A fresh dict per runner so one Runner cannot observe changes
        # another made to the shared config.
        default_hooks = dict(
            timer=dict(type='IterTimerHook'),
            logger=None,
            param_scheduler=dict(type='ParamSchedulerHook'),
            checkpoint=dict(type='CheckpointHook', interval=1),
            sampler_seed=dict(type='DistSamplerSeedHook'),
            visualization=dict(type='VisualizationHook', enable=False),
        )
        return Runner(
            model=model,
            work_dir=work_dir,
            train_dataloader=loader,
            train_cfg=dict(by_epoch=True, max_epochs=1),
            log_level='WARNING',
            optim_wrapper=dict(
                optimizer=dict(type='SGD', lr=0.1, momentum=0.9)),
            param_scheduler=dict(
                type='MultiStepLR', milestones=[1, 2], gamma=0.1),
            default_scope='mmcls',
            default_hooks=default_hooks,
            experiment_name=experiment_name,
            custom_hooks=[self.DEFAULT_HOOK_CFG])

    def test_before_train(self):
        tmpdir = tempfile.TemporaryDirectory()
        # Remove the work dir even when an assertion below fails.
        self.addCleanup(tmpdir.cleanup)
        loader = DataLoader(ExampleDataset(), batch_size=2)

        # Case 1: ArcFace head -> margins are recomputed from label counts.
        self.runner = self._build_runner(self.DEFAULT_MODEL, tmpdir.name,
                                         loader,
                                         'test_construct_with_arcface')

        # torch.allclose only returns a bool; it must be asserted, otherwise
        # the comparison checks nothing.
        default_margins = torch.tensor([0.5] * 5)
        self.assertTrue(
            torch.allclose(self.runner.model.head.margins.cpu(),
                           default_margins))

        self.runner.call_hook('before_train')
        # label counts  = [2, 3, 3, 0, 2] -> [2, 3, 3, 1, 2]
        #                 (each class is counted at least once)
        # counts**-0.25 = [0.84089642, 0.75983569, 0.75983569, 1., 0.84089642]
        # min-max normalized = [0.33752196, 0., 0., 1., 0.33752196]
        # margins = normalized * (0.5 - 0.05) + 0.05
        #         = [0.20188488, 0.05, 0.05, 0.5, 0.20188488]
        expected_margins = torch.tensor(
            [0.20188488, 0.05, 0.05, 0.5, 0.20188488])
        self.assertTrue(
            torch.allclose(self.runner.model.head.margins.cpu(),
                           expected_margins))

        # Case 2: non-ArcFace head -> the hook must refuse to run.
        model_cfg = {**self.DEFAULT_MODEL}
        model_cfg['head'] = dict(
            type='LinearClsHead',
            num_classes=1000,
            in_channels=512,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
            topk=(1, 5),
        )
        self.runner = self._build_runner(model_cfg, tmpdir.name, loader,
                                         'test_construct_wo_arcface')
        with self.assertRaises(ValueError):
            self.runner.call_hook('before_train')
|