737 lines
26 KiB
Python
737 lines
26 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
import os
|
|
import random
|
|
import tempfile
|
|
from unittest import TestCase
|
|
|
|
import numpy as np
|
|
import torch
|
|
from mmengine import is_seq_of
|
|
|
|
from mmpretrain.registry import MODELS
|
|
from mmpretrain.structures import DataSample, MultiTaskDataSample
|
|
|
|
|
|
def setup_seed(seed):
    """Seed every random number generator used by these tests.

    Args:
        seed (int): Value passed to Python's ``random``, NumPy's global RNG
            and PyTorch (CPU and all CUDA devices).
    """
    # Python- and NumPy-level generators.
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch generators for the CPU and every visible CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Ask cuDNN to restrict itself to deterministic algorithms.
    torch.backends.cudnn.deterministic = True
|
|
torch.backends.cudnn.deterministic = True
|
|
|
|
|
|
class TestClsHead(TestCase):
    """Unit tests for the plain ``ClsHead`` (a head without learnable layers)."""

    DEFAULT_ARGS = dict(type='ClsHead')
    FAKE_FEATS = (torch.rand(4, 10), )

    def test_pre_logits(self):
        head = MODELS.build(self.DEFAULT_ARGS)

        # pre_logits hands back the last element of the feature tuple.
        inputs = (torch.rand(4, 10), torch.rand(4, 10))
        out = head.pre_logits(inputs)
        self.assertIs(out, inputs[-1])

    def test_forward(self):
        head = MODELS.build(self.DEFAULT_ARGS)

        # forward returns the very same object as pre_logits would.
        inputs = (torch.rand(4, 10), torch.rand(4, 10))
        out = head(inputs)
        self.assertIs(out, inputs[-1])

    def test_loss(self):
        feats = self.FAKE_FEATS
        samples = [DataSample().set_gt_label(1) for _ in range(4)]

        # cal_acc disabled: only the loss term is reported.
        head = MODELS.build(self.DEFAULT_ARGS)
        losses = head.loss(feats, samples)
        self.assertEqual(losses.keys(), {'loss'})
        self.assertGreater(losses['loss'].item(), 0)

        # cal_acc enabled: one accuracy entry per requested top-k value.
        acc_cfg = {**self.DEFAULT_ARGS, 'topk': (1, 2), 'cal_acc': True}
        head = MODELS.build(acc_cfg)
        losses = head.loss(feats, samples)
        self.assertEqual(losses.keys(),
                         {'loss', 'accuracy_top-1', 'accuracy_top-2'})
        self.assertGreater(losses['loss'].item(), 0)

        # Soft targets (batch-augmented data) cannot be combined with
        # cal_acc; the head must refuse them with an AssertionError.
        samples = [s.set_gt_score(torch.rand(10)) for s in samples]
        soft_cfg = dict(
            self.DEFAULT_ARGS,
            cal_acc=True,
            loss=dict(type='CrossEntropyLoss', use_soft=True))
        head = MODELS.build(soft_cfg)
        with self.assertRaisesRegex(AssertionError, 'batch augmentation'):
            head.loss(feats, samples)

    def test_predict(self):
        feats = (torch.rand(4, 10), )
        samples = [DataSample().set_gt_label(1) for _ in range(4)]
        head = MODELS.build(self.DEFAULT_ARGS)

        # Without data_samples: a sequence of DataSample is returned.
        predictions = head.predict(feats)
        self.assertTrue(is_seq_of(predictions, DataSample))
        for pred in predictions:
            self.assertIn('pred_label', pred)
            self.assertIn('pred_score', pred)

        # With data_samples: the given samples are filled and returned.
        predictions = head.predict(feats, samples)
        self.assertTrue(is_seq_of(predictions, DataSample))
        for given, pred in zip(samples, predictions):
            self.assertIs(given, pred)
            self.assertIn('pred_label', pred)
            self.assertIn('pred_score', pred)
|
|
|
|
|
|
class TestLinearClsHead(TestCase):
    """Unit tests for ``LinearClsHead``."""

    DEFAULT_ARGS = dict(type='LinearClsHead', in_channels=10, num_classes=5)
    FAKE_FEATS = (torch.rand(4, 10), )

    def test_initialize(self):
        # A negative class count must be rejected when building the head.
        bad_cfg = {**self.DEFAULT_ARGS, 'num_classes': -5}
        with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
            MODELS.build(bad_cfg)

    def test_pre_logits(self):
        head = MODELS.build(self.DEFAULT_ARGS)

        # pre_logits returns the last feature unchanged.
        inputs = (torch.rand(4, 10), torch.rand(4, 10))
        out = head.pre_logits(inputs)
        self.assertIs(out, inputs[-1])

    def test_forward(self):
        head = MODELS.build(self.DEFAULT_ARGS)

        # 4 samples in, one score per class (num_classes=5) out.
        inputs = (torch.rand(4, 10), torch.rand(4, 10))
        logits = head(inputs)
        self.assertEqual(logits.shape, (4, 5))
|
|
|
|
|
|
class TestVisionTransformerClsHead(TestCase):
    """Unit tests for ``VisionTransformerClsHead``."""

    DEFAULT_ARGS = dict(
        type='VisionTransformerClsHead', in_channels=10, num_classes=5)
    # The last stage output is a list; the head reads its element [1]
    # (see ``test_pre_logits``).
    fake_feats = ([torch.rand(4, 7, 7, 16), torch.rand(4, 10)], )

    def test_initialize(self):
        with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
            MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})

        hidden_cfg = {**self.DEFAULT_ARGS, 'hidden_dim': 30}

        # Default head: no hidden projection layers.
        plain_head = MODELS.build(self.DEFAULT_ARGS)
        assert not hasattr(plain_head.layers, 'pre_logits')
        assert not hasattr(plain_head.layers, 'act')

        # hidden_dim set: 'pre_logits' and 'act' layers are created.
        hidden_head = MODELS.build(hidden_cfg)
        assert hasattr(hidden_head.layers, 'pre_logits')
        assert hasattr(hidden_head.layers, 'act')

        # init_weights must run on the default head ...
        MODELS.build(self.DEFAULT_ARGS).init_weights()

        # ... and leave non-zero pre_logits weights when hidden_dim is set.
        hidden_head = MODELS.build(hidden_cfg)
        hidden_head.init_weights()
        assert abs(hidden_head.layers.pre_logits.weight).sum() > 0

    def test_pre_logits(self):
        # Default: the token at index 1 of the last stage is passed through.
        head = MODELS.build(self.DEFAULT_ARGS)
        out = head.pre_logits(self.fake_feats)
        self.assertIs(out, self.fake_feats[-1][1])

        # hidden_dim: the token is projected to 30 channels.
        head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
        out = head.pre_logits(self.fake_feats)
        self.assertEqual(out.shape, (4, 30))

    def test_forward(self):
        # Both variants produce (batch, num_classes) logits.
        for cfg in (self.DEFAULT_ARGS, {
                **self.DEFAULT_ARGS, 'hidden_dim': 30
        }):
            head = MODELS.build(cfg)
            logits = head(self.fake_feats)
            self.assertEqual(logits.shape, (4, 5))
|
|
|
|
|
|
class TestDeiTClsHead(TestVisionTransformerClsHead):
    """Unit tests for ``DeiTClsHead``.

    The ViT head tests are reused. The fake features carry one extra token
    (the ``dist_token``) at index 2 of the last stage output.
    """

    DEFAULT_ARGS = dict(type='DeiTClsHead', in_channels=10, num_classes=5)
    fake_feats = ([
        torch.rand(4, 7, 7, 16),
        torch.rand(4, 10),
        torch.rand(4, 10),
    ], )

    def test_pre_logits(self):
        last_stage = self.fake_feats[-1]

        # Default: both tokens are returned untouched.
        head = MODELS.build(self.DEFAULT_ARGS)
        cls_token, dist_token = head.pre_logits(self.fake_feats)
        self.assertIs(cls_token, last_stage[1])
        self.assertIs(dist_token, last_stage[2])

        # hidden_dim: both tokens are projected to 30 channels.
        head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
        cls_token, dist_token = head.pre_logits(self.fake_feats)
        for token in (cls_token, dist_token):
            self.assertEqual(token.shape, (4, 30))
|
|
|
|
|
|
class TestConformerHead(TestCase):
    """Unit tests for ``ConformerHead`` (conv and transformer classifiers)."""

    DEFAULT_ARGS = dict(
        type='ConformerHead', in_channels=[64, 96], num_classes=5)
    fake_feats = ([torch.rand(4, 64), torch.rand(4, 96)], )

    def test_initialize(self):
        with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
            MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})

        # Both branch classifiers exist by default.
        head = MODELS.build(self.DEFAULT_ARGS)
        assert hasattr(head, 'conv_cls_head')
        assert hasattr(head, 'trans_cls_head')

        # After init_weights neither classifier has all-zero weights.
        head = MODELS.build(self.DEFAULT_ARGS)
        head.init_weights()
        for fc in (head.conv_cls_head, head.trans_cls_head):
            assert abs(fc.weight).sum() > 0

    def test_pre_logits(self):
        # pre_logits returns the last stage output (the list) unchanged.
        head = MODELS.build(self.DEFAULT_ARGS)
        out = head.pre_logits(self.fake_feats)
        self.assertIs(out, self.fake_feats[-1])

    def test_forward(self):
        head = MODELS.build(self.DEFAULT_ARGS)
        outs = head(self.fake_feats)
        # One (batch, num_classes) output per branch.
        for branch_out in (outs[0], outs[1]):
            self.assertEqual(branch_out.shape, (4, 5))

    def test_loss(self):
        samples = [DataSample().set_gt_label(1) for _ in range(4)]

        # cal_acc disabled: only the loss term is reported.
        head = MODELS.build(self.DEFAULT_ARGS)
        losses = head.loss(self.fake_feats, samples)
        self.assertEqual(losses.keys(), {'loss'})
        self.assertGreater(losses['loss'].item(), 0)

        # cal_acc enabled: one accuracy entry per requested top-k value.
        acc_cfg = {**self.DEFAULT_ARGS, 'topk': (1, 2), 'cal_acc': True}
        head = MODELS.build(acc_cfg)
        losses = head.loss(self.fake_feats, samples)
        self.assertEqual(losses.keys(),
                         {'loss', 'accuracy_top-1', 'accuracy_top-2'})
        self.assertGreater(losses['loss'].item(), 0)

        # Soft targets (batch-augmented data) cannot be combined with
        # cal_acc; the head must refuse them with an AssertionError.
        samples = [s.set_gt_score(torch.rand(5)) for s in samples]
        soft_cfg = dict(
            self.DEFAULT_ARGS,
            cal_acc=True,
            loss=dict(type='CrossEntropyLoss', use_soft=True))
        head = MODELS.build(soft_cfg)
        with self.assertRaisesRegex(AssertionError, 'batch augmentation'):
            head.loss(self.fake_feats, samples)

    def test_predict(self):
        samples = [DataSample().set_gt_label(1) for _ in range(4)]
        head = MODELS.build(self.DEFAULT_ARGS)

        # Without data_samples: a sequence of DataSample is returned.
        predictions = head.predict(self.fake_feats)
        self.assertTrue(is_seq_of(predictions, DataSample))
        for pred in predictions:
            self.assertIn('pred_label', pred)
            self.assertIn('pred_score', pred)

        # With data_samples: the given samples are filled and returned.
        predictions = head.predict(self.fake_feats, samples)
        self.assertTrue(is_seq_of(predictions, DataSample))
        for given, pred in zip(samples, predictions):
            self.assertIs(given, pred)
            self.assertIn('pred_label', pred)
            self.assertIn('pred_score', pred)
|
|
|
|
|
|
class TestStackedLinearClsHead(TestCase):
    """Unit tests for ``StackedLinearClsHead``."""

    DEFAULT_ARGS = dict(
        type='StackedLinearClsHead', in_channels=10, num_classes=5)
    fake_feats = (torch.rand(4, 10), )

    def test_initialize(self):
        negative_classes = {
            **self.DEFAULT_ARGS, 'num_classes': -5,
            'mid_channels': 10
        }
        with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
            MODELS.build(negative_classes)

        # mid_channels must be a sequence; a bare int is rejected.
        with self.assertRaisesRegex(AssertionError, 'should be a sequence'):
            MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': 10})

        # mid_channels=[20] yields two layers; init_weights must succeed.
        head = MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': [20]})
        assert len(head.layers) == 2
        head.init_weights()

    def test_pre_logits(self):
        # pre_logits outputs the width of the last hidden layer (30).
        head = MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': [20, 30]})
        hidden = head.pre_logits(self.fake_feats)
        self.assertEqual(hidden.shape, (4, 30))

    def test_forward(self):
        plain_cfg = {**self.DEFAULT_ARGS, 'mid_channels': [20, 30]}
        # Exercise dropout, BN1d and HSwish on the hidden layers as well.
        custom_cfg = {
            **self.DEFAULT_ARGS, 'mid_channels': [8, 10],
            'dropout_rate': 0.2,
            'norm_cfg': dict(type='BN1d'),
            'act_cfg': dict(type='HSwish')
        }
        for cfg in (plain_cfg, custom_cfg):
            head = MODELS.build(cfg)
            logits = head(self.fake_feats)
            self.assertEqual(logits.shape, (4, 5))
|
|
|
|
|
|
class TestMultiLabelClsHead(TestCase):
    """Unit tests for the plain ``MultiLabelClsHead``."""

    DEFAULT_ARGS = dict(type='MultiLabelClsHead')

    def test_pre_logits(self):
        head = MODELS.build(self.DEFAULT_ARGS)

        # pre_logits returns the last feature unchanged.
        inputs = (torch.rand(4, 10), torch.rand(4, 10))
        out = head.pre_logits(inputs)
        self.assertIs(out, inputs[-1])

    def test_forward(self):
        head = MODELS.build(self.DEFAULT_ARGS)

        # forward returns the very same object as pre_logits would.
        inputs = (torch.rand(4, 10), torch.rand(4, 10))
        out = head(inputs)
        self.assertIs(out, inputs[-1])

    def test_loss(self):
        feats = (torch.rand(4, 10), )
        samples = [DataSample().set_gt_label([0, 3]) for _ in range(4)]

        # Neither thr nor topk given: thr falls back to 0.5.
        head = MODELS.build(self.DEFAULT_ARGS)
        losses = head.loss(feats, samples)
        self.assertEqual(head.thr, 0.5)
        self.assertIsNone(head.topk)
        self.assertEqual(losses.keys(), {'loss'})
        self.assertGreater(losses['loss'].item(), 0)

        # Only topk given: thr is disabled.
        topk_cfg = {**self.DEFAULT_ARGS, 'topk': 2}
        head = MODELS.build(topk_cfg)
        losses = head.loss(feats, samples)
        self.assertIsNone(head.thr, topk_cfg)
        self.assertEqual(head.topk, 2)
        self.assertEqual(losses.keys(), {'loss'})
        self.assertGreater(losses['loss'].item(), 0)

        # Only thr given.
        setup_seed(0)
        thr_cfg = {**self.DEFAULT_ARGS, 'thr': 0.1}
        head = MODELS.build(thr_cfg)
        thr_losses = head.loss(feats, samples)
        self.assertEqual(head.thr, 0.1)
        self.assertIsNone(head.topk)
        self.assertEqual(thr_losses.keys(), {'loss'})
        self.assertGreater(thr_losses['loss'].item(), 0)

        # Both thr and topk given: both are kept.
        setup_seed(0)
        both_cfg = {**self.DEFAULT_ARGS, 'thr': 0.1, 'topk': 2}
        head = MODELS.build(both_cfg)
        thr_topk_losses = head.loss(feats, samples)
        self.assertEqual(head.thr, 0.1)
        self.assertEqual(head.topk, 2)
        self.assertEqual(thr_topk_losses.keys(), {'loss'})
        self.assertGreater(thr_topk_losses['loss'].item(), 0)

        # Soft ground truth provided through gt_score instead of gt_label.
        score_samples = [
            DataSample().set_gt_score(torch.rand((10, ))) for _ in range(4)
        ]
        head = MODELS.build(self.DEFAULT_ARGS)
        losses = head.loss(feats, score_samples)
        self.assertEqual(head.thr, 0.5)
        self.assertIsNone(head.topk)
        self.assertEqual(losses.keys(), {'loss'})
        self.assertGreater(losses['loss'].item(), 0)

    def test_predict(self):
        feats = (torch.rand(4, 10), )
        samples = [DataSample().set_gt_label([1, 2]) for _ in range(4)]
        head = MODELS.build(self.DEFAULT_ARGS)

        # Without data_samples: a sequence of DataSample is returned.
        predictions = head.predict(feats)
        self.assertTrue(is_seq_of(predictions, DataSample))
        for pred in predictions:
            self.assertIn('pred_label', pred)
            self.assertIn('pred_score', pred)

        # With data_samples: the given samples are filled and returned.
        predictions = head.predict(feats, samples)
        self.assertTrue(is_seq_of(predictions, DataSample))
        for given, pred in zip(samples, predictions):
            self.assertIs(given, pred)
            self.assertIn('pred_label', pred)
            self.assertIn('pred_score', pred)

        # topk only: thr is disabled and predictions are still filled in.
        topk_head = MODELS.build({**self.DEFAULT_ARGS, 'topk': 2})
        predictions = topk_head.predict(feats, samples)
        self.assertIsNone(topk_head.thr)
        self.assertTrue(is_seq_of(predictions, DataSample))
        for given, pred in zip(samples, predictions):
            self.assertIs(given, pred)
            self.assertIn('pred_label', pred)
            self.assertIn('pred_score', pred)
|
|
|
|
|
|
class EfficientFormerClsHead(TestClsHead):
    """Tests for ``EfficientFormerClsHead``, built on the ``ClsHead`` suite.

    ``test_pre_logits`` and ``test_predict`` are inherited as-is. The
    overrides below also cover the optional distillation head.
    """

    DEFAULT_ARGS = dict(
        type='EfficientFormerClsHead',
        in_channels=10,
        num_classes=10,
        distillation=False)
    FAKE_FEATS = (torch.rand(4, 10), )

    def test_forward(self):
        # The dist_head exists only when distillation is enabled; the output
        # shape (batch, num_classes) is the same in both cases.
        for use_dist in (True, False):
            cfg = {**self.DEFAULT_ARGS, 'distillation': use_dist}
            head = MODELS.build(cfg)
            self.assertEqual(hasattr(head, 'dist_head'), use_dist)
            inputs = (torch.rand(4, 10), torch.rand(4, 10))
            logits = head(inputs)
            self.assertEqual(logits.shape, (4, 10))

    def test_loss(self):
        feats = (torch.rand(4, 10), )
        samples = [DataSample().set_gt_label(1) for _ in range(4)]

        # Computing the loss with the distillation head is not implemented.
        dist_head = MODELS.build({**self.DEFAULT_ARGS, 'distillation': True})
        with self.assertRaisesRegex(NotImplementedError, 'MMPretrain '):
            dist_head.loss(feats, samples)

        # Without distillation the generic ClsHead loss checks apply.
        super().test_loss()
|
|
|
|
|
|
class TestMultiLabelLinearClsHead(TestMultiLabelClsHead):
    """Tests for ``MultiLabelLinearClsHead``.

    ``test_pre_logits``, ``test_loss`` and ``test_predict`` are inherited from
    ``TestMultiLabelClsHead`` and run against ``DEFAULT_ARGS`` below.
    """

    DEFAULT_ARGS = dict(
        type='MultiLabelLinearClsHead', num_classes=10, in_channels=10)

    def test_forward(self):
        head = MODELS.build(self.DEFAULT_ARGS)
        self.assertTrue(hasattr(head, 'fc'))
        self.assertIsInstance(head.fc, torch.nn.Linear)

        # Unlike the plain multi-label head, forward is expected to go
        # through ``fc``: 4 samples in, num_classes=10 scores each out.
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        outs = head(feats)
        self.assertEqual(outs.shape, (4, 10))
|
|
|
|
|
|
class TestMultiTaskHead(TestCase):
    """Unit tests for ``MultiTaskHead`` in flat and nested configurations."""

    # Two independent linear heads sharing the same input features.
    DEFAULT_ARGS = dict(
        type='MultiTaskHead',  # <- Head config, depends on #675
        task_heads={
            'task0': dict(type='LinearClsHead', num_classes=3),
            'task1': dict(type='LinearClsHead', num_classes=6),
        },
        in_channels=10,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    )

    # 'task0' is itself a MultiTaskHead with the sub-tasks task00/task01.
    DEFAULT_ARGS2 = dict(
        type='MultiTaskHead',  # <- Head config, depends on #675
        task_heads={
            'task0':
            dict(
                type='MultiTaskHead',
                task_heads={
                    'task00': dict(type='LinearClsHead', num_classes=3),
                    'task01': dict(type='LinearClsHead', num_classes=6),
                }),
            'task1':
            dict(type='LinearClsHead', num_classes=6)
        },
        in_channels=10,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    )

    def _build_flat_samples(self, num_samples=4):
        """Build MultiTaskDataSamples labelled 1 for each DEFAULT_ARGS task."""
        data_samples = []
        for _ in range(num_samples):
            data_sample = MultiTaskDataSample()
            for task_name in self.DEFAULT_ARGS['task_heads']:
                task_sample = DataSample().set_gt_label(1)
                data_sample.set_field(task_sample, task_name)
            data_samples.append(data_sample)
        return data_samples

    def test_forward(self):
        head = MODELS.build(self.DEFAULT_ARGS)
        feats = (torch.rand(4, 10), )
        outs = head(feats)
        # Outputs are a dict keyed by task name.
        self.assertIsInstance(outs, dict)
        self.assertEqual(outs['task0'].shape, (4, 3))
        self.assertEqual(outs['task1'].shape, (4, 6))

    def test_loss(self):
        feats = (torch.rand(4, 10), )
        data_samples = self._build_flat_samples()

        # with cal_acc = False
        head = MODELS.build(self.DEFAULT_ARGS)
        losses = head.loss(feats, data_samples)
        self.assertEqual(
            losses.keys(),
            {'task0_loss', 'task0_mask_size', 'task1_loss', 'task1_mask_size'})
        self.assertGreater(losses['task0_loss'].item(), 0)
        self.assertGreater(losses['task1_loss'].item(), 0)

    def test_predict(self):
        feats = (torch.rand(4, 10), )
        data_samples = self._build_flat_samples()
        head = MODELS.build(self.DEFAULT_ARGS)

        # Without data_samples: MultiTaskDataSamples are returned.
        predictions = head.predict(feats)
        self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
        for pred in predictions:
            self.assertIn('task0', pred)
            # ``assertTrue(obj, msg)`` would pass for any truthy object, so
            # check the score type explicitly.
            self.assertIsInstance(pred.task0.pred_score, torch.Tensor)

        # With data_samples: the given samples are filled and returned.
        predictions = head.predict(feats, data_samples)
        self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
        for sample, pred in zip(data_samples, predictions):
            self.assertIs(sample, pred)
            self.assertIn('task0', pred)

        # Nested head; the leading None checks that a missing sample is
        # tolerated and a prediction is still produced for it.
        head_nested = MODELS.build(self.DEFAULT_ARGS2)
        data_samples_nested = [None]
        for _ in range(3):
            data_sample_nested = MultiTaskDataSample()
            data_sample_nested0 = MultiTaskDataSample()
            data_sample_nested0.set_field(DataSample().set_gt_label(1),
                                          'task00')
            data_sample_nested0.set_field(DataSample().set_gt_label(1),
                                          'task01')
            data_sample_nested.set_field(data_sample_nested0, 'task0')
            data_sample_nested.set_field(DataSample().set_gt_label(1), 'task1')
            data_samples_nested.append(data_sample_nested)

        predictions = head_nested.predict(feats, data_samples_nested)
        self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
        self.assertEqual(len(predictions), len(data_samples_nested))
        for sample, pred in zip(data_samples_nested[1:], predictions[1:]):
            self.assertIn('task0', pred)
            self.assertIn('task1', pred)
            self.assertIn('task01', pred.get('task0'))
            self.assertEqual(
                sample.get('task0').get('task01').gt_label.numpy()[0], 1)

    def test_loss_empty_data_sample(self):
        feats = (torch.rand(4, 10), )
        # Samples without any task field: every task loss must be zero.
        data_samples = [MultiTaskDataSample() for _ in range(4)]

        # with cal_acc = False
        head = MODELS.build(self.DEFAULT_ARGS)
        losses = head.loss(feats, data_samples)
        self.assertEqual(
            losses.keys(),
            {'task0_loss', 'task0_mask_size', 'task1_loss', 'task1_mask_size'})
        self.assertEqual(losses['task0_loss'].item(), 0)
        self.assertEqual(losses['task1_loss'].item(), 0)

    def test_nested_multi_task_loss(self):
        """Check the forward output structure of the nested head."""
        head = MODELS.build(self.DEFAULT_ARGS2)
        feats = (torch.rand(4, 10), )
        outs = head(feats)
        self.assertIsInstance(outs, dict)
        self.assertIsInstance(outs['task0'], dict)
        self.assertEqual(outs['task0']['task00'].shape, (4, 3))
        self.assertEqual(outs['task0']['task01'].shape, (4, 6))
        self.assertEqual(outs['task1'].shape, (4, 6))

    def test_nested_invalid_sample(self):
        # DEFAULT_ARGS2 expects 'task0' to hold nested task samples; here it
        # is a plain DataSample (and a single sample rather than a list) is
        # passed, which is expected to raise.
        feats = (torch.rand(4, 10), )
        gt_label = {'task0': 1, 'task1': 1}
        head = MODELS.build(self.DEFAULT_ARGS2)
        data_sample = MultiTaskDataSample()
        for task_name in gt_label:
            task_sample = DataSample().set_gt_label(gt_label[task_name])
            data_sample.set_field(task_sample, task_name)
        with self.assertRaises(Exception):
            head.loss(feats, data_sample)

    def test_nested_invalid_sample2(self):
        # Converse of the above: the flat head receives a nested 'task0'
        # sample (again passed as a single sample), which is expected to
        # raise.
        feats = (torch.rand(4, 10), )
        gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1}
        head = MODELS.build(self.DEFAULT_ARGS)
        data_sample = MultiTaskDataSample()
        task_sample = DataSample().set_gt_label(gt_label['task1'])
        data_sample.set_field(task_sample, 'task1')
        data_sample.set_field(MultiTaskDataSample(), 'task0')
        for task_name in gt_label['task0']:
            task_sample = DataSample().set_gt_label(
                gt_label['task0'][task_name])
            data_sample.task0.set_field(task_sample, task_name)
        with self.assertRaises(Exception):
            head.loss(feats, data_sample)
|
|
|
|
|
|
class TestArcFaceClsHead(TestCase):
    """Unit tests for ``ArcFaceClsHead``, including the SubCenter variant."""

    DEFAULT_ARGS = dict(type='ArcFaceClsHead', in_channels=10, num_classes=5)

    def test_initialize(self):
        # Invalid num_classes / num_subcenters are rejected by assertions.
        with self.assertRaises(AssertionError):
            MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})

        with self.assertRaises(AssertionError):
            MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 0})

        # Invalid margins configurations are rejected as well.
        with self.assertRaises(AssertionError):
            MODELS.build({**self.DEFAULT_ARGS, 'margins': dict()})

        with self.assertRaises(AssertionError):
            MODELS.build({**self.DEFAULT_ARGS, 'margins': [0.1] * 4})

        with self.assertRaises(AssertionError):
            MODELS.build({**self.DEFAULT_ARGS, 'margins': [0.1] * 4 + ['0.1']})

        # Valid margins. ``torch.allclose`` only returns a bool, so every
        # comparison must be wrapped in an assertion to be checked at all.
        arcface = MODELS.build(self.DEFAULT_ARGS)
        self.assertTrue(
            torch.allclose(arcface.margins, torch.tensor([0.5] * 5)))

        arcface = MODELS.build({**self.DEFAULT_ARGS, 'margins': [0.1] * 5})
        self.assertTrue(
            torch.allclose(arcface.margins, torch.tensor([0.1] * 5)))

        # Margins read from a text file with one value per line.
        margins = [0.1, 0.2, 0.3, 0.4, 5]
        with tempfile.TemporaryDirectory() as tmpdirname:
            tmp_path = os.path.join(tmpdirname, 'margins.txt')
            with open(tmp_path, 'w') as tmp_file:
                for m in margins:
                    tmp_file.write(f'{m}\n')
            arcface = MODELS.build({**self.DEFAULT_ARGS, 'margins': tmp_path})
        self.assertTrue(
            torch.allclose(arcface.margins, torch.tensor(margins).float()))

    def test_pre_logits(self):
        head = MODELS.build(self.DEFAULT_ARGS)

        # pre_logits returns the last feature unchanged.
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        pre_logits = head.pre_logits(feats)
        self.assertIs(pre_logits, feats[-1])

        # Test with SubCenterArcFace
        head = MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 3})
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        pre_logits = head.pre_logits(feats)
        self.assertIs(pre_logits, feats[-1])

    def test_forward(self):
        head = MODELS.build(self.DEFAULT_ARGS)
        # target is not None
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        target = torch.zeros(4).long()
        outs = head(feats, target)
        self.assertEqual(outs.shape, (4, 5))

        # target is None
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        outs = head(feats)
        self.assertEqual(outs.shape, (4, 5))

        # Test with SubCenterArcFace
        head = MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 3})
        # target is not None; class indices as int64, as above.
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        target = torch.zeros(4).long()
        outs = head(feats, target)
        self.assertEqual(outs.shape, (4, 5))

        # target is None
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        outs = head(feats)
        self.assertEqual(outs.shape, (4, 5))

    def test_loss(self):
        feats = (torch.rand(4, 10), )
        data_samples = [DataSample().set_gt_label(1) for _ in range(4)]

        # Default ArcFace head: only the loss term is reported.
        head = MODELS.build(self.DEFAULT_ARGS)
        losses = head.loss(feats, data_samples)
        self.assertEqual(losses.keys(), {'loss'})
        self.assertGreater(losses['loss'].item(), 0)

        # SubCenterArcFace head: same contract.
        head = MODELS.build({**self.DEFAULT_ARGS, 'num_subcenters': 3})
        losses = head.loss(feats, data_samples)
        self.assertEqual(losses.keys(), {'loss'})
        self.assertGreater(losses['loss'].item(), 0)
|