# Copyright (c) OpenMMLab. All rights reserved.
import copy
import random
from unittest import TestCase

import numpy as np
import torch
from mmengine import is_seq_of

from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
from mmcls.utils import register_all_modules

register_all_modules()


def setup_seed(seed):
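    """Fix random seeds of Python, NumPy and PyTorch for determinism."""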
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


class TestClsHead(TestCase):
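    """Test the base ClsHead: pre_logits, forward, loss and predict."""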
    DEFAULT_ARGS = dict(type='ClsHead')
    FAKE_FEATS = (torch.rand(4, 10), )

    def test_pre_logits(self):
        head = MODELS.build(self.DEFAULT_ARGS)

        # return the last item
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        pre_logits = head.pre_logits(feats)
        self.assertIs(pre_logits, feats[-1])

    def test_forward(self):
        head = MODELS.build(self.DEFAULT_ARGS)

        # return the last item (same as pre_logits)
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        outs = head(feats)
        self.assertIs(outs, feats[-1])

    def test_loss(self):
        feats = self.FAKE_FEATS
        data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]

        # with cal_acc = False
        head = MODELS.build(self.DEFAULT_ARGS)
        losses = head.loss(feats, data_samples)
        self.assertEqual(losses.keys(), {'loss'})
        self.assertGreater(losses['loss'].item(), 0)

        # with cal_acc = True
        cfg = {**self.DEFAULT_ARGS, 'topk': (1, 2), 'cal_acc': True}
        head = MODELS.build(cfg)
        losses = head.loss(feats, data_samples)
        self.assertEqual(losses.keys(),
                         {'loss', 'accuracy_top-1', 'accuracy_top-2'})
        self.assertGreater(losses['loss'].item(), 0)

        # test the assertion when cal_acc is True but data is batch-augmented
        data_samples = [
            sample.set_gt_score(torch.rand(10)) for sample in data_samples
        ]
        cfg = {
            **self.DEFAULT_ARGS, 'cal_acc': True,
            'loss': dict(type='CrossEntropyLoss', use_soft=True)
        }
        head = MODELS.build(cfg)
        with self.assertRaisesRegex(AssertionError, 'batch augmentation'):
            head.loss(feats, data_samples)

    def test_predict(self):
        feats = (torch.rand(4, 10), )
        data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
        head = MODELS.build(self.DEFAULT_ARGS)

        # without data_samples
        predictions = head.predict(feats)
        self.assertTrue(is_seq_of(predictions, ClsDataSample))
        for pred in predictions:
            self.assertIn('label', pred.pred_label)
            self.assertIn('score', pred.pred_label)

        # with data_samples
        predictions = head.predict(feats, data_samples)
        self.assertTrue(is_seq_of(predictions, ClsDataSample))
        for sample, pred in zip(data_samples, predictions):
            self.assertIs(sample, pred)
            self.assertIn('label', pred.pred_label)
            self.assertIn('score', pred.pred_label)


class TestLinearClsHead(TestCase):
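    """Test the LinearClsHead."""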
    DEFAULT_ARGS = dict(type='LinearClsHead', in_channels=10, num_classes=5)
    FAKE_FEATS = (torch.rand(4, 10), )

    def test_initialize(self):
        with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
            MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})

    def test_pre_logits(self):
        head = MODELS.build(self.DEFAULT_ARGS)

        # return the last item
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        pre_logits = head.pre_logits(feats)
        self.assertIs(pre_logits, feats[-1])

    def test_forward(self):
        head = MODELS.build(self.DEFAULT_ARGS)

        feats = (torch.rand(4, 10), torch.rand(4, 10))
        outs = head(feats)
        self.assertEqual(outs.shape, (4, 5))


class TestVisionTransformerClsHead(TestCase):
    DEFAULT_ARGS = dict(
        type='VisionTransformerClsHead', in_channels=10, num_classes=5)
    fake_feats = ([torch.rand(4, 7, 7, 16), torch.rand(4, 10)], )

    def test_initialize(self):
        with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
            MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})

        # test vit head default
        head = MODELS.build(self.DEFAULT_ARGS)
        assert not hasattr(head.layers, 'pre_logits')
        assert not hasattr(head.layers, 'act')

        # test vit head hidden_dim
        head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
        assert hasattr(head.layers, 'pre_logits')
        assert hasattr(head.layers, 'act')

        # test vit head init_weights
        head = MODELS.build(self.DEFAULT_ARGS)
        head.init_weights()

        # test vit head init_weights with hidden_dim
        head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
        head.init_weights()
        assert abs(head.layers.pre_logits.weight).sum() > 0

    def test_pre_logits(self):
        # test default
        head = MODELS.build(self.DEFAULT_ARGS)
        pre_logits = head.pre_logits(self.fake_feats)
        self.assertIs(pre_logits, self.fake_feats[-1][1])

        # test hidden_dim
        head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
        pre_logits = head.pre_logits(self.fake_feats)
        self.assertEqual(pre_logits.shape, (4, 30))

    def test_forward(self):
        # test default
        head = MODELS.build(self.DEFAULT_ARGS)
        outs = head(self.fake_feats)
        self.assertEqual(outs.shape, (4, 5))

        # test hidden_dim
        head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
        outs = head(self.fake_feats)
        self.assertEqual(outs.shape, (4, 5))


class TestDeiTClsHead(TestVisionTransformerClsHead):
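    """Test the DeiTClsHead, whose pre_logits returns cls and dist tokens."""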
    DEFAULT_ARGS = dict(type='DeiTClsHead', in_channels=10, num_classes=5)
    fake_feats = ([
        torch.rand(4, 7, 7, 16),
        torch.rand(4, 10),
        torch.rand(4, 10)
    ], )

    def test_pre_logits(self):
        # test default
        head = MODELS.build(self.DEFAULT_ARGS)
        cls_token, dist_token = head.pre_logits(self.fake_feats)
        self.assertIs(cls_token, self.fake_feats[-1][1])
        self.assertIs(dist_token, self.fake_feats[-1][2])

        # test hidden_dim
        head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
        cls_token, dist_token = head.pre_logits(self.fake_feats)
        self.assertEqual(cls_token.shape, (4, 30))
        self.assertEqual(dist_token.shape, (4, 30))


class TestConformerHead(TestCase):
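    """Test the ConformerHead, which outputs conv and transformer logits."""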
    DEFAULT_ARGS = dict(
        type='ConformerHead', in_channels=[64, 96], num_classes=5)
    fake_feats = ([torch.rand(4, 64), torch.rand(4, 96)], )

    def test_initialize(self):
        with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
            MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})

        # test default
        head = MODELS.build(self.DEFAULT_ARGS)
        assert hasattr(head, 'conv_cls_head')
        assert hasattr(head, 'trans_cls_head')

        # test init_weights
        head = MODELS.build(self.DEFAULT_ARGS)
        head.init_weights()
        assert abs(head.conv_cls_head.weight).sum() > 0
        assert abs(head.trans_cls_head.weight).sum() > 0

    def test_pre_logits(self):
        # test default
        head = MODELS.build(self.DEFAULT_ARGS)
        pre_logits = head.pre_logits(self.fake_feats)
        self.assertIs(pre_logits, self.fake_feats[-1])

    def test_forward(self):
        head = MODELS.build(self.DEFAULT_ARGS)
        outs = head(self.fake_feats)
        self.assertEqual(outs[0].shape, (4, 5))
        self.assertEqual(outs[1].shape, (4, 5))

    def test_loss(self):
        data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]

        # with cal_acc = False
        head = MODELS.build(self.DEFAULT_ARGS)
        losses = head.loss(self.fake_feats, data_samples)
        self.assertEqual(losses.keys(), {'loss'})
        self.assertGreater(losses['loss'].item(), 0)

        # with cal_acc = True
        cfg = {**self.DEFAULT_ARGS, 'topk': (1, 2), 'cal_acc': True}
        head = MODELS.build(cfg)
        losses = head.loss(self.fake_feats, data_samples)
        self.assertEqual(losses.keys(),
                         {'loss', 'accuracy_top-1', 'accuracy_top-2'})
        self.assertGreater(losses['loss'].item(), 0)

        # test the assertion when cal_acc is True but data is batch-augmented
        data_samples = [
            sample.set_gt_score(torch.rand(5)) for sample in data_samples
        ]
        cfg = {
            **self.DEFAULT_ARGS, 'cal_acc': True,
            'loss': dict(type='CrossEntropyLoss', use_soft=True)
        }
        head = MODELS.build(cfg)
        with self.assertRaisesRegex(AssertionError, 'batch augmentation'):
            head.loss(self.fake_feats, data_samples)

    def test_predict(self):
        data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
        head = MODELS.build(self.DEFAULT_ARGS)

        # without data_samples
        predictions = head.predict(self.fake_feats)
        self.assertTrue(is_seq_of(predictions, ClsDataSample))
        for pred in predictions:
            self.assertIn('label', pred.pred_label)
            self.assertIn('score', pred.pred_label)

        # with data_samples
        predictions = head.predict(self.fake_feats, data_samples)
        self.assertTrue(is_seq_of(predictions, ClsDataSample))
        for sample, pred in zip(data_samples, predictions):
            self.assertIs(sample, pred)
            self.assertIn('label', pred.pred_label)
            self.assertIn('score', pred.pred_label)


class TestStackedLinearClsHead(TestCase):
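    """Test the StackedLinearClsHead with configurable mid_channels."""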
    DEFAULT_ARGS = dict(
        type='StackedLinearClsHead', in_channels=10, num_classes=5)
    fake_feats = (torch.rand(4, 10), )

    def test_initialize(self):
        with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
            MODELS.build({
                **self.DEFAULT_ARGS, 'num_classes': -5,
                'mid_channels': 10
            })

        # test mid_channels
        with self.assertRaisesRegex(AssertionError, 'should be a sequence'):
            MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': 10})

        # test default
        head = MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': [20]})
        assert len(head.layers) == 2
        head.init_weights()

    def test_pre_logits(self):
        # test default
        head = MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': [20, 30]})
        pre_logits = head.pre_logits(self.fake_feats)
        self.assertEqual(pre_logits.shape, (4, 30))

    def test_forward(self):
        # test default
        head = MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': [20, 30]})
        outs = head(self.fake_feats)
        self.assertEqual(outs.shape, (4, 5))

        head = MODELS.build({
            **self.DEFAULT_ARGS, 'mid_channels': [8, 10],
            'dropout_rate': 0.2,
            'norm_cfg': dict(type='BN1d'),
            'act_cfg': dict(type='HSwish')
        })
        outs = head(self.fake_feats)
        self.assertEqual(outs.shape, (4, 5))


class TestMultiLabelClsHead(TestCase):
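    """Test the MultiLabelClsHead under different thr and topk settings."""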
    DEFAULT_ARGS = dict(type='MultiLabelClsHead')

    def test_pre_logits(self):
        head = MODELS.build(self.DEFAULT_ARGS)

        # return the last item
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        pre_logits = head.pre_logits(feats)
        self.assertIs(pre_logits, feats[-1])

    def test_forward(self):
        head = MODELS.build(self.DEFAULT_ARGS)

        # return the last item (same as pre_logits)
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        outs = head(feats)
        self.assertIs(outs, feats[-1])

    def test_loss(self):
        feats = (torch.rand(4, 10), )
        data_samples = [ClsDataSample().set_gt_label([0, 3]) for _ in range(4)]

        # Test when neither thr nor topk is set (thr defaults to 0.5)
        head = MODELS.build(self.DEFAULT_ARGS)
        losses = head.loss(feats, data_samples)
        self.assertEqual(head.thr, 0.5)
        self.assertEqual(head.topk, None)
        self.assertEqual(losses.keys(), {'loss'})
        self.assertGreater(losses['loss'].item(), 0)

        # Test with topk
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        cfg['topk'] = 2
        head = MODELS.build(cfg)
        losses = head.loss(feats, data_samples)
        self.assertEqual(head.thr, None, cfg)
        self.assertEqual(head.topk, 2)
        self.assertEqual(losses.keys(), {'loss'})
        self.assertGreater(losses['loss'].item(), 0)

        # Test with thr
        setup_seed(0)
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        cfg['thr'] = 0.1
        head = MODELS.build(cfg)
        thr_losses = head.loss(feats, data_samples)
        self.assertEqual(head.thr, 0.1)
        self.assertEqual(head.topk, None)
        self.assertEqual(thr_losses.keys(), {'loss'})
        self.assertGreater(thr_losses['loss'].item(), 0)

        # Test when both thr and topk are set
        setup_seed(0)
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        cfg['thr'] = 0.1
        cfg['topk'] = 2
        head = MODELS.build(cfg)
        thr_topk_losses = head.loss(feats, data_samples)
        self.assertEqual(head.thr, 0.1)
        self.assertEqual(head.topk, 2)
        self.assertEqual(thr_topk_losses.keys(), {'loss'})
        self.assertGreater(thr_topk_losses['loss'].item(), 0)

        # Test with gt_score (soft labels) instead of gt_label
        data_samples = [
            ClsDataSample().set_gt_score(torch.rand((10, ))) for _ in range(4)
        ]
        head = MODELS.build(self.DEFAULT_ARGS)
        losses = head.loss(feats, data_samples)
        self.assertEqual(head.thr, 0.5)
        self.assertEqual(head.topk, None)
        self.assertEqual(losses.keys(), {'loss'})
        self.assertGreater(losses['loss'].item(), 0)

    def test_predict(self):
        feats = (torch.rand(4, 10), )
        data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(4)]
        head = MODELS.build(self.DEFAULT_ARGS)

        # without data_samples
        predictions = head.predict(feats)
        self.assertTrue(is_seq_of(predictions, ClsDataSample))
        for pred in predictions:
            self.assertIn('label', pred.pred_label)
            self.assertIn('score', pred.pred_label)

        # with data_samples
        predictions = head.predict(feats, data_samples)
        self.assertTrue(is_seq_of(predictions, ClsDataSample))
        for sample, pred in zip(data_samples, predictions):
            self.assertIs(sample, pred)
            self.assertIn('label', pred.pred_label)
            self.assertIn('score', pred.pred_label)

        # Test with topk
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        cfg['topk'] = 2
        head = MODELS.build(cfg)
        predictions = head.predict(feats, data_samples)
        self.assertEqual(head.thr, None)
        self.assertTrue(is_seq_of(predictions, ClsDataSample))
        for sample, pred in zip(data_samples, predictions):
            self.assertIs(sample, pred)
            self.assertIn('label', pred.pred_label)
            self.assertIn('score', pred.pred_label)


class EfficientFormerClsHead(TestClsHead):
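    """Test the EfficientFormerClsHead with and without distillation."""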
    DEFAULT_ARGS = dict(
        type='EfficientFormerClsHead',
        in_channels=10,
        num_classes=10,
        distillation=False)
    FAKE_FEATS = (torch.rand(4, 10), )

    def test_forward(self):
        # test with distillation head
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        cfg['distillation'] = True
        head = MODELS.build(cfg)
        self.assertTrue(hasattr(head, 'dist_head'))
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        outs = head(feats)
        self.assertEqual(outs.shape, (4, 10))

        # test without distillation head
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        head = MODELS.build(cfg)
        self.assertFalse(hasattr(head, 'dist_head'))
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        outs = head(feats)
        self.assertEqual(outs.shape, (4, 10))

    def test_loss(self):
        feats = (torch.rand(4, 10), )
        data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]

        # test with distillation head
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        cfg['distillation'] = True
        head = MODELS.build(cfg)
        with self.assertRaisesRegex(NotImplementedError, 'MMClassification '):
            head.loss(feats, data_samples)

        # test without distillation head
        super().test_loss()


class TestMultiLabelLinearClsHead(TestMultiLabelClsHead):
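    """Test the MultiLabelLinearClsHead, which adds a linear fc layer."""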
    DEFAULT_ARGS = dict(
        type='MultiLabelLinearClsHead', num_classes=10, in_channels=10)

    def test_forward(self):
        head = MODELS.build(self.DEFAULT_ARGS)
        self.assertTrue(hasattr(head, 'fc'))
        self.assertTrue(isinstance(head.fc, torch.nn.Linear))

        # return the last item (same as pre_logits)
        feats = (torch.rand(4, 10), torch.rand(4, 10))
        head(feats)