# mmpretrain/tests/test_models/test_heads.py

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import random
import tempfile
from unittest import TestCase
import numpy as np
import torch
from mmengine import is_seq_of
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample, MultiTaskDataSample
def setup_seed(seed):
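    """Fix the Python, NumPy and PyTorch random seeds (and force
    deterministic cuDNN) so that the loss values checked below are
    reproducible."""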
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
class TestClsHead(TestCase):
DEFAULT_ARGS = dict(type='ClsHead')
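    # Heads take the backbone output as a tuple of feature tensors; ClsHead
    # simply uses the last element (see test_pre_logits / test_forward).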
FAKE_FEATS = (torch.rand(4, 10), )
def test_pre_logits(self):
head = MODELS.build(self.DEFAULT_ARGS)
# return the last item
feats = (torch.rand(4, 10), torch.rand(4, 10))
pre_logits = head.pre_logits(feats)
self.assertIs(pre_logits, feats[-1])
def test_forward(self):
head = MODELS.build(self.DEFAULT_ARGS)
# return the last item (same as pre_logits)
feats = (torch.rand(4, 10), torch.rand(4, 10))
outs = head(feats)
self.assertIs(outs, feats[-1])
def test_loss(self):
feats = self.FAKE_FEATS
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
# with cal_acc = False
head = MODELS.build(self.DEFAULT_ARGS)
losses = head.loss(feats, data_samples)
self.assertEqual(losses.keys(), {'loss'})
self.assertGreater(losses['loss'].item(), 0)
# with cal_acc = True
cfg = {**self.DEFAULT_ARGS, 'topk': (1, 2), 'cal_acc': True}
head = MODELS.build(cfg)
losses = head.loss(feats, data_samples)
self.assertEqual(losses.keys(),
{'loss', 'accuracy_top-1', 'accuracy_top-2'})
self.assertGreater(losses['loss'].item(), 0)
        # test assertion when cal_acc is True but data is batch augmented.
data_samples = [
sample.set_gt_score(torch.rand(10)) for sample in data_samples
]
cfg = {
**self.DEFAULT_ARGS, 'cal_acc': True,
'loss': dict(type='CrossEntropyLoss', use_soft=True)
}
head = MODELS.build(cfg)
with self.assertRaisesRegex(AssertionError, 'batch augmentation'):
head.loss(feats, data_samples)
def test_predict(self):
feats = (torch.rand(4, 10), )
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
head = MODELS.build(self.DEFAULT_ARGS)
        # without data_samples
predictions = head.predict(feats)
self.assertTrue(is_seq_of(predictions, DataSample))
for pred in predictions:
self.assertIn('pred_label', pred)
self.assertIn('pred_score', pred)
        # with data_samples
predictions = head.predict(feats, data_samples)
self.assertTrue(is_seq_of(predictions, DataSample))
for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred)
self.assertIn('pred_label', pred)
self.assertIn('pred_score', pred)
class TestLinearClsHead(TestCase):
DEFAULT_ARGS = dict(type='LinearClsHead', in_channels=10, num_classes=5)
FAKE_FEATS = (torch.rand(4, 10), )
def test_initialize(self):
with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})
def test_pre_logits(self):
head = MODELS.build(self.DEFAULT_ARGS)
# return the last item
feats = (torch.rand(4, 10), torch.rand(4, 10))
pre_logits = head.pre_logits(feats)
self.assertIs(pre_logits, feats[-1])
def test_forward(self):
head = MODELS.build(self.DEFAULT_ARGS)
feats = (torch.rand(4, 10), torch.rand(4, 10))
outs = head(feats)
self.assertEqual(outs.shape, (4, 5))
class TestVisionTransformerClsHead(TestCase):
DEFAULT_ARGS = dict(
type='VisionTransformerClsHead', in_channels=10, num_classes=5)
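    # The final stage output is [patch_token, cls_token]; the ViT head should
    # classify from the cls token (see test_pre_logits).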
fake_feats = ([torch.rand(4, 7, 7, 16), torch.rand(4, 10)], )
def test_initialize(self):
with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})
# test vit head default
head = MODELS.build(self.DEFAULT_ARGS)
assert not hasattr(head.layers, 'pre_logits')
assert not hasattr(head.layers, 'act')
# test vit head hidden_dim
head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
assert hasattr(head.layers, 'pre_logits')
assert hasattr(head.layers, 'act')
# test vit head init_weights
head = MODELS.build(self.DEFAULT_ARGS)
head.init_weights()
# test vit head init_weights with hidden_dim
head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
head.init_weights()
assert abs(head.layers.pre_logits.weight).sum() > 0
def test_pre_logits(self):
# test default
head = MODELS.build(self.DEFAULT_ARGS)
pre_logits = head.pre_logits(self.fake_feats)
self.assertIs(pre_logits, self.fake_feats[-1][1])
# test hidden_dim
head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
pre_logits = head.pre_logits(self.fake_feats)
self.assertEqual(pre_logits.shape, (4, 30))
def test_forward(self):
# test default
head = MODELS.build(self.DEFAULT_ARGS)
outs = head(self.fake_feats)
self.assertEqual(outs.shape, (4, 5))
# test hidden_dim
head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
outs = head(self.fake_feats)
self.assertEqual(outs.shape, (4, 5))
class TestDeiTClsHead(TestVisionTransformerClsHead):
DEFAULT_ARGS = dict(type='DeiTClsHead', in_channels=10, num_classes=5)
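    # DeiT adds a distillation token, so the final stage output is
    # [patch_token, cls_token, dist_token] (see test_pre_logits).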
fake_feats = ([
torch.rand(4, 7, 7, 16),
torch.rand(4, 10),
torch.rand(4, 10)
], )
def test_pre_logits(self):
# test default
head = MODELS.build(self.DEFAULT_ARGS)
cls_token, dist_token = head.pre_logits(self.fake_feats)
self.assertIs(cls_token, self.fake_feats[-1][1])
self.assertIs(dist_token, self.fake_feats[-1][2])
# test hidden_dim
head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
cls_token, dist_token = head.pre_logits(self.fake_feats)
self.assertEqual(cls_token.shape, (4, 30))
self.assertEqual(dist_token.shape, (4, 30))
class TestConformerHead(TestCase):
DEFAULT_ARGS = dict(
type='ConformerHead', in_channels=[64, 96], num_classes=5)
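    # Conformer yields one feature per branch (conv and transformer); the head
    # owns a classifier for each and returns a pair of logits (test_forward).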
fake_feats = ([torch.rand(4, 64), torch.rand(4, 96)], )
def test_initialize(self):
with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})
# test default
head = MODELS.build(self.DEFAULT_ARGS)
assert hasattr(head, 'conv_cls_head')
assert hasattr(head, 'trans_cls_head')
# test init_weights
head = MODELS.build(self.DEFAULT_ARGS)
head.init_weights()
assert abs(head.conv_cls_head.weight).sum() > 0
assert abs(head.trans_cls_head.weight).sum() > 0
def test_pre_logits(self):
# test default
head = MODELS.build(self.DEFAULT_ARGS)
pre_logits = head.pre_logits(self.fake_feats)
self.assertIs(pre_logits, self.fake_feats[-1])
def test_forward(self):
head = MODELS.build(self.DEFAULT_ARGS)
outs = head(self.fake_feats)
self.assertEqual(outs[0].shape, (4, 5))
self.assertEqual(outs[1].shape, (4, 5))
def test_loss(self):
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
# with cal_acc = False
head = MODELS.build(self.DEFAULT_ARGS)
losses = head.loss(self.fake_feats, data_samples)
self.assertEqual(losses.keys(), {'loss'})
self.assertGreater(losses['loss'].item(), 0)
# with cal_acc = True
cfg = {**self.DEFAULT_ARGS, 'topk': (1, 2), 'cal_acc': True}
head = MODELS.build(cfg)
losses = head.loss(self.fake_feats, data_samples)
self.assertEqual(losses.keys(),
{'loss', 'accuracy_top-1', 'accuracy_top-2'})
self.assertGreater(losses['loss'].item(), 0)
        # test assertion when cal_acc is True but data is batch augmented.
data_samples = [
sample.set_gt_score(torch.rand(5)) for sample in data_samples
]
cfg = {
**self.DEFAULT_ARGS, 'cal_acc': True,
'loss': dict(type='CrossEntropyLoss', use_soft=True)
}
head = MODELS.build(cfg)
with self.assertRaisesRegex(AssertionError, 'batch augmentation'):
head.loss(self.fake_feats, data_samples)
def test_predict(self):
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
head = MODELS.build(self.DEFAULT_ARGS)
        # without data_samples
predictions = head.predict(self.fake_feats)
self.assertTrue(is_seq_of(predictions, DataSample))
for pred in predictions:
self.assertIn('pred_label', pred)
self.assertIn('pred_score', pred)
        # with data_samples
predictions = head.predict(self.fake_feats, data_samples)
self.assertTrue(is_seq_of(predictions, DataSample))
for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred)
self.assertIn('pred_label', pred)
self.assertIn('pred_score', pred)
class TestStackedLinearClsHead(TestCase):
DEFAULT_ARGS = dict(
type='StackedLinearClsHead', in_channels=10, num_classes=5)
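    # 'mid_channels' must be a sequence (see test_initialize), so it is left
    # out of DEFAULT_ARGS and supplied per test.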
fake_feats = (torch.rand(4, 10), )
def test_initialize(self):
with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
MODELS.build({
**self.DEFAULT_ARGS, 'num_classes': -5,
'mid_channels': 10
})
# test mid_channels
with self.assertRaisesRegex(AssertionError, 'should be a sequence'):
MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': 10})
# test default
head = MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': [20]})
assert len(head.layers) == 2
head.init_weights()
def test_pre_logits(self):
# test default
head = MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': [20, 30]})
pre_logits = head.pre_logits(self.fake_feats)
self.assertEqual(pre_logits.shape, (4, 30))
def test_forward(self):
# test default
head = MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': [20, 30]})
outs = head(self.fake_feats)
self.assertEqual(outs.shape, (4, 5))
head = MODELS.build({
**self.DEFAULT_ARGS, 'mid_channels': [8, 10],
'dropout_rate': 0.2,
'norm_cfg': dict(type='BN1d'),
'act_cfg': dict(type='HSwish')
})
outs = head(self.fake_feats)
self.assertEqual(outs.shape, (4, 5))
class TestMultiLabelClsHead(TestCase):
DEFAULT_ARGS = dict(type='MultiLabelClsHead')
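    # The default head thresholds predictions at thr=0.5 with topk=None; the
    # tests below also cover overriding thr, topk and both together.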
def test_pre_logits(self):
head = MODELS.build(self.DEFAULT_ARGS)
# return the last item
feats = (torch.rand(4, 10), torch.rand(4, 10))
pre_logits = head.pre_logits(feats)
self.assertIs(pre_logits, feats[-1])
def test_forward(self):
head = MODELS.build(self.DEFAULT_ARGS)
# return the last item (same as pre_logits)
feats = (torch.rand(4, 10), torch.rand(4, 10))
outs = head(feats)
self.assertIs(outs, feats[-1])
def test_loss(self):
feats = (torch.rand(4, 10), )
data_samples = [DataSample().set_gt_label([0, 3]) for _ in range(4)]
        # Test with neither thr nor topk set (defaults: thr=0.5, topk=None)
head = MODELS.build(self.DEFAULT_ARGS)
losses = head.loss(feats, data_samples)
self.assertEqual(head.thr, 0.5)
self.assertEqual(head.topk, None)
self.assertEqual(losses.keys(), {'loss'})
self.assertGreater(losses['loss'].item(), 0)
# Test with topk
cfg = copy.deepcopy(self.DEFAULT_ARGS)
cfg['topk'] = 2
head = MODELS.build(cfg)
losses = head.loss(feats, data_samples)
self.assertEqual(head.thr, None, cfg)
self.assertEqual(head.topk, 2)
self.assertEqual(losses.keys(), {'loss'})
self.assertGreater(losses['loss'].item(), 0)
# Test with thr
setup_seed(0)
cfg = copy.deepcopy(self.DEFAULT_ARGS)
cfg['thr'] = 0.1
head = MODELS.build(cfg)
thr_losses = head.loss(feats, data_samples)
self.assertEqual(head.thr, 0.1)
self.assertEqual(head.topk, None)
self.assertEqual(thr_losses.keys(), {'loss'})
self.assertGreater(thr_losses['loss'].item(), 0)
        # Test with both thr and topk set
setup_seed(0)
cfg = copy.deepcopy(self.DEFAULT_ARGS)
cfg['thr'] = 0.1
cfg['topk'] = 2
head = MODELS.build(cfg)
thr_topk_losses = head.loss(feats, data_samples)
self.assertEqual(head.thr, 0.1)
self.assertEqual(head.topk, 2)
self.assertEqual(thr_topk_losses.keys(), {'loss'})
self.assertGreater(thr_topk_losses['loss'].item(), 0)
        # Test with gt_score instead of gt_label
data_samples = [
DataSample().set_gt_score(torch.rand((10, ))) for _ in range(4)
]
head = MODELS.build(self.DEFAULT_ARGS)
losses = head.loss(feats, data_samples)
self.assertEqual(head.thr, 0.5)
self.assertEqual(head.topk, None)
self.assertEqual(losses.keys(), {'loss'})
self.assertGreater(losses['loss'].item(), 0)
def test_predict(self):
feats = (torch.rand(4, 10), )
data_samples = [DataSample().set_gt_label([1, 2]) for _ in range(4)]
head = MODELS.build(self.DEFAULT_ARGS)
        # without data_samples
predictions = head.predict(feats)
self.assertTrue(is_seq_of(predictions, DataSample))
for pred in predictions:
self.assertIn('pred_label', pred)
self.assertIn('pred_score', pred)
        # with data_samples
predictions = head.predict(feats, data_samples)
self.assertTrue(is_seq_of(predictions, DataSample))
for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred)
self.assertIn('pred_label', pred)
self.assertIn('pred_score', pred)
# Test with topk
cfg = copy.deepcopy(self.DEFAULT_ARGS)
cfg['topk'] = 2
head = MODELS.build(cfg)
predictions = head.predict(feats, data_samples)
self.assertEqual(head.thr, None)
self.assertTrue(is_seq_of(predictions, DataSample))
for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred)
self.assertIn('pred_label', pred)
self.assertIn('pred_score', pred)
class EfficientFormerClsHead(TestClsHead):
DEFAULT_ARGS = dict(
type='EfficientFormerClsHead',
in_channels=10,
num_classes=10,
distillation=False)
FAKE_FEATS = (torch.rand(4, 10), )
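    # With distillation=True the head builds an extra 'dist_head', and
    # computing a training loss is expected to raise NotImplementedError
    # (see test_forward / test_loss below).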
def test_forward(self):
# test with distillation head
cfg = copy.deepcopy(self.DEFAULT_ARGS)
cfg['distillation'] = True
head = MODELS.build(cfg)
self.assertTrue(hasattr(head, 'dist_head'))
feats = (torch.rand(4, 10), torch.rand(4, 10))
outs = head(feats)
self.assertEqual(outs.shape, (4, 10))
# test without distillation head
cfg = copy.deepcopy(self.DEFAULT_ARGS)
head = MODELS.build(cfg)
self.assertFalse(hasattr(head, 'dist_head'))
feats = (torch.rand(4, 10), torch.rand(4, 10))
outs = head(feats)
self.assertEqual(outs.shape, (4, 10))
def test_loss(self):
feats = (torch.rand(4, 10), )
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
# test with distillation head
cfg = copy.deepcopy(self.DEFAULT_ARGS)
cfg['distillation'] = True
head = MODELS.build(cfg)
with self.assertRaisesRegex(NotImplementedError, 'MMPretrain '):
head.loss(feats, data_samples)
# test without distillation head
super().test_loss()
class TestMultiLabelLinearClsHead(TestMultiLabelClsHead):
DEFAULT_ARGS = dict(
type='MultiLabelLinearClsHead', num_classes=10, in_channels=10)
def test_forward(self):
head = MODELS.build(self.DEFAULT_ARGS)
self.assertTrue(hasattr(head, 'fc'))
self.assertTrue(isinstance(head.fc, torch.nn.Linear))
# return the last item (same as pre_logits)
feats = (torch.rand(4, 10), torch.rand(4, 10))
head(feats)
class TestMultiTaskHead(TestCase):
DEFAULT_ARGS = dict(
type='MultiTaskHead', # <- Head config, depends on #675
task_heads={
'task0': dict(type='LinearClsHead', num_classes=3),
'task1': dict(type='LinearClsHead', num_classes=6),
},
in_channels=10,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
)
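    # A nested variant: 'task0' is itself a MultiTaskHead with two sub-tasks,
    # used by the nested forward/predict/loss tests below.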
DEFAULT_ARGS2 = dict(
type='MultiTaskHead', # <- Head config, depends on #675
task_heads={
'task0':
dict(
type='MultiTaskHead',
task_heads={
'task00': dict(type='LinearClsHead', num_classes=3),
'task01': dict(type='LinearClsHead', num_classes=6),
}),
'task1':
dict(type='LinearClsHead', num_classes=6)
},
in_channels=10,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
)
def test_forward(self):
head = MODELS.build(self.DEFAULT_ARGS)
# return the last item (same as pre_logits)
feats = (torch.rand(4, 10), )
outs = head(feats)
self.assertEqual(outs['task0'].shape, (4, 3))
self.assertEqual(outs['task1'].shape, (4, 6))
self.assertTrue(isinstance(outs, dict))
def test_loss(self):
feats = (torch.rand(4, 10), )
data_samples = []
for _ in range(4):
data_sample = MultiTaskDataSample()
for task_name in self.DEFAULT_ARGS['task_heads']:
task_sample = DataSample().set_gt_label(1)
data_sample.set_field(task_sample, task_name)
data_samples.append(data_sample)
# with cal_acc = False
head = MODELS.build(self.DEFAULT_ARGS)
losses = head.loss(feats, data_samples)
self.assertEqual(
losses.keys(),
{'task0_loss', 'task0_mask_size', 'task1_loss', 'task1_mask_size'})
self.assertGreater(losses['task0_loss'].item(), 0)
self.assertGreater(losses['task1_loss'].item(), 0)
def test_predict(self):
feats = (torch.rand(4, 10), )
data_samples = []
for _ in range(4):
data_sample = MultiTaskDataSample()
for task_name in self.DEFAULT_ARGS['task_heads']:
task_sample = DataSample().set_gt_label(1)
data_sample.set_field(task_sample, task_name)
data_samples.append(data_sample)
head = MODELS.build(self.DEFAULT_ARGS)
# without data_samples
predictions = head.predict(feats)
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
for pred in predictions:
self.assertIn('task0', pred)
task0_sample = predictions[0].task0
        self.assertIsInstance(task0_sample.pred_score, torch.Tensor)
        # with data_samples
predictions = head.predict(feats, data_samples)
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred)
self.assertIn('task0', pred)
# with data samples and nested
head_nested = MODELS.build(self.DEFAULT_ARGS2)
# adding a None data sample at the beginning
data_samples_nested = [None]
for _ in range(3):
data_sample_nested = MultiTaskDataSample()
data_sample_nested0 = MultiTaskDataSample()
data_sample_nested0.set_field(DataSample().set_gt_label(1),
'task00')
data_sample_nested0.set_field(DataSample().set_gt_label(1),
'task01')
data_sample_nested.set_field(data_sample_nested0, 'task0')
data_sample_nested.set_field(DataSample().set_gt_label(1), 'task1')
data_samples_nested.append(data_sample_nested)
predictions = head_nested.predict(feats, data_samples_nested)
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
for i in range(3):
sample = data_samples_nested[i + 1]
pred = predictions[i + 1]
self.assertIn('task0', pred)
self.assertIn('task1', pred)
self.assertIn('task01', pred.get('task0'))
self.assertEqual(
sample.get('task0').get('task01').gt_label.numpy()[0], 1)
def test_loss_empty_data_sample(self):
feats = (torch.rand(4, 10), )
data_samples = []
for _ in range(4):
data_sample = MultiTaskDataSample()
data_samples.append(data_sample)
        # every data sample is empty, so each task loss should be 0
head = MODELS.build(self.DEFAULT_ARGS)
losses = head.loss(feats, data_samples)
self.assertEqual(
losses.keys(),
{'task0_loss', 'task0_mask_size', 'task1_loss', 'task1_mask_size'})
self.assertEqual(losses['task0_loss'].item(), 0)
self.assertEqual(losses['task1_loss'].item(), 0)
def test_nested_multi_task_loss(self):
head = MODELS.build(self.DEFAULT_ARGS2)
# return the last item (same as pre_logits)
feats = (torch.rand(4, 10), )
outs = head(feats)
self.assertEqual(outs['task0']['task01'].shape, (4, 6))
self.assertTrue(isinstance(outs, dict))
self.assertTrue(isinstance(outs['task0'], dict))
def test_nested_invalid_sample(self):
feats = (torch.rand(4, 10), )
gt_label = {'task0': 1, 'task1': 1}
head = MODELS.build(self.DEFAULT_ARGS2)
data_sample = MultiTaskDataSample()
for task_name in gt_label:
task_sample = DataSample().set_gt_label(gt_label[task_name])
data_sample.set_field(task_sample, task_name)
with self.assertRaises(Exception):
head.loss(feats, data_sample)
def test_nested_invalid_sample2(self):
feats = (torch.rand(4, 10), )
gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1}
head = MODELS.build(self.DEFAULT_ARGS)
data_sample = MultiTaskDataSample()
task_sample = DataSample().set_gt_label(gt_label['task1'])
data_sample.set_field(task_sample, 'task1')
data_sample.set_field(MultiTaskDataSample(), 'task0')
for task_name in gt_label['task0']:
task_sample = DataSample().set_gt_label(
gt_label['task0'][task_name])
data_sample.task0.set_field(task_sample, task_name)
with self.assertRaises(Exception):
head.loss(feats, data_sample)
class TestArcFaceClsHead(TestCase):
DEFAULT_ARGS = dict(type='ArcFaceClsHead', in_channels=10, num_classes=5)
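    # 'margins' can be the default scalar (0.5), a per-class list, or a path
    # to a text file with one margin per line. num_subcenters > 1 switches to
    # the SubCenterArcFace variant (all covered in the tests below).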
def test_initialize(self):
with self.assertRaises(AssertionError):
MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})
with self.assertRaises(AssertionError):
MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 0})
# Test margins
with self.assertRaises(AssertionError):
MODELS.build({**self.DEFAULT_ARGS, 'margins': dict()})
with self.assertRaises(AssertionError):
MODELS.build({**self.DEFAULT_ARGS, 'margins': [0.1] * 4})
with self.assertRaises(AssertionError):
MODELS.build({**self.DEFAULT_ARGS, 'margins': [0.1] * 4 + ['0.1']})
arcface = MODELS.build(self.DEFAULT_ARGS)
        self.assertTrue(
            torch.allclose(arcface.margins, torch.tensor([0.5] * 5)))
arcface = MODELS.build({**self.DEFAULT_ARGS, 'margins': [0.1] * 5})
        self.assertTrue(
            torch.allclose(arcface.margins, torch.tensor([0.1] * 5)))
margins = [0.1, 0.2, 0.3, 0.4, 5]
with tempfile.TemporaryDirectory() as tmpdirname:
tmp_path = os.path.join(tmpdirname, 'margins.txt')
with open(tmp_path, 'w') as tmp_file:
for m in margins:
tmp_file.write(f'{m}\n')
arcface = MODELS.build({**self.DEFAULT_ARGS, 'margins': tmp_path})
            self.assertTrue(
                torch.allclose(arcface.margins, torch.tensor(margins)))
def test_pre_logits(self):
head = MODELS.build(self.DEFAULT_ARGS)
# return the last item
feats = (torch.rand(4, 10), torch.rand(4, 10))
pre_logits = head.pre_logits(feats)
self.assertIs(pre_logits, feats[-1])
# Test with SubCenterArcFace
head = MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 3})
feats = (torch.rand(4, 10), torch.rand(4, 10))
pre_logits = head.pre_logits(feats)
self.assertIs(pre_logits, feats[-1])
def test_forward(self):
head = MODELS.build(self.DEFAULT_ARGS)
# target is not None
feats = (torch.rand(4, 10), torch.rand(4, 10))
target = torch.zeros(4).long()
outs = head(feats, target)
self.assertEqual(outs.shape, (4, 5))
# target is None
feats = (torch.rand(4, 10), torch.rand(4, 10))
outs = head(feats)
self.assertEqual(outs.shape, (4, 5))
# Test with SubCenterArcFace
head = MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 3})
# target is not None
feats = (torch.rand(4, 10), torch.rand(4, 10))
        target = torch.zeros(4).long()
outs = head(feats, target)
self.assertEqual(outs.shape, (4, 5))
# target is None
feats = (torch.rand(4, 10), torch.rand(4, 10))
outs = head(feats)
self.assertEqual(outs.shape, (4, 5))
def test_loss(self):
feats = (torch.rand(4, 10), )
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
        # test loss with the default ArcFace head
head = MODELS.build(self.DEFAULT_ARGS)
losses = head.loss(feats, data_samples)
self.assertEqual(losses.keys(), {'loss'})
self.assertGreater(losses['loss'].item(), 0)
# Test with SubCenterArcFace
head = MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 3})
        # test loss with num_subcenters=3
losses = head.loss(feats, data_samples)
self.assertEqual(losses.keys(), {'loss'})
self.assertGreater(losses['loss'].item(), 0)