215 lines
7.9 KiB
Python
215 lines
7.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
from unittest.mock import MagicMock
|
|
|
|
import torch
|
|
from mmengine import ConfigDict
|
|
|
|
from mmcls.models import ImageClassifier
|
|
from mmcls.registry import MODELS
|
|
from mmcls.structures import ClsDataSample
|
|
from mmcls.utils import register_all_modules
|
|
|
|
# Run once at import time, before any test builds a model: presumably this
# registers the mmcls modules so that `MODELS.build` can resolve the string
# `type` names used in the configs below ('ImageClassifier', 'ResNet',
# 'LinearClsHead', ...). TODO confirm against mmcls.utils.
register_all_modules()
|
|
|
|
|
|
class TestImageClassifier(TestCase):
    """Tests for ``ImageClassifier`` built through the ``MODELS`` registry.

    Each test starts from ``DEFAULT_ARGS`` (ResNet-18 backbone,
    ``GlobalAveragePooling`` neck, 10-class ``LinearClsHead``) and overrides
    or removes keys as needed.
    """

    # Baseline model config. Tests derive variants from it (via dict
    # unpacking or ``ConfigDict``) instead of repeating the whole config.
    DEFAULT_ARGS = dict(
        type='ImageClassifier',
        backbone=dict(type='ResNet', depth=18),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='LinearClsHead',
            num_classes=10,
            in_channels=512,
            loss=dict(type='CrossEntropyLoss')))

    # Expected shapes of the four backbone stage outputs for a
    # (1, 3, 224, 224) input when ``out_indices=(0, 1, 2, 3)``.
    BACKBONE_SHAPES = [(1, 64, 56, 56), (1, 128, 28, 28), (1, 256, 14, 14),
                       (1, 512, 7, 7)]
    # The same features after the pooling neck: only (batch, channels).
    NECK_SHAPES = [(1, 64), (1, 128), (1, 256), (1, 512)]

    def _assert_shapes(self, feats, expected_shapes):
        """Assert ``feats`` holds one tensor per expected shape, in order."""
        self.assertEqual(len(feats), len(expected_shapes))
        for feat, shape in zip(feats, expected_shapes):
            self.assertEqual(feat.shape, shape)

    def _build_with_preprocessor(self):
        """Build the default model with a fixed-normalization preprocessor.

        A fresh ``data_preprocessor`` dict is created on every call so that
        no config state is shared between tests.
        """
        cfg = {
            **self.DEFAULT_ARGS, 'data_preprocessor':
            dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
        }
        return MODELS.build(cfg)

    @staticmethod
    def _make_step_data():
        """Return a one-sample batch in the format passed to ``*_step``.

        The input is a random integer image with values in [0, 255], and the
        data sample carries ground-truth label 1.
        """
        return [{
            'inputs': torch.randint(0, 256, (3, 224, 224)),
            'data_sample': ClsDataSample().set_gt_label(1)
        }]

    def test_initialize(self):
        model = MODELS.build(self.DEFAULT_ARGS)
        self.assertTrue(model.with_neck)
        self.assertTrue(model.with_head)

        # `pretrained` is turned into a `Pretrained` init_cfg.
        cfg = {**self.DEFAULT_ARGS, 'pretrained': 'checkpoint'}
        model = MODELS.build(cfg)
        self.assertDictEqual(model.init_cfg,
                             dict(type='Pretrained', checkpoint='checkpoint'))

        # The neck is optional.
        cfg = ConfigDict(self.DEFAULT_ARGS)
        cfg.pop('neck')
        model = MODELS.build(cfg)
        self.assertFalse(model.with_neck)

        # The head is optional.
        cfg = ConfigDict(self.DEFAULT_ARGS)
        cfg.pop('head')
        model = MODELS.build(cfg)
        self.assertFalse(model.with_head)

        # test set batch augmentation from train_cfg
        cfg = {
            **self.DEFAULT_ARGS, 'train_cfg':
            dict(augments=dict(type='Mixup', alpha=1., num_classes=10))
        }
        model: ImageClassifier = MODELS.build(cfg)
        self.assertIsNotNone(model.data_preprocessor.batch_augments)

        # A train_cfg without `augments` leaves batch augmentation unset.
        cfg = {**self.DEFAULT_ARGS, 'train_cfg': dict()}
        model: ImageClassifier = MODELS.build(cfg)
        self.assertIsNone(model.data_preprocessor.batch_augments)

    def test_extract_feat(self):
        inputs = torch.rand(1, 3, 224, 224)
        cfg = ConfigDict(self.DEFAULT_ARGS)
        cfg.backbone.out_indices = (0, 1, 2, 3)
        model: ImageClassifier = MODELS.build(cfg)

        # test backbone output
        feats = model.extract_feat(inputs, stage='backbone')
        self._assert_shapes(feats, self.BACKBONE_SHAPES)

        # test neck output
        feats = model.extract_feat(inputs, stage='neck')
        self._assert_shapes(feats, self.NECK_SHAPES)

        # test pre_logits output
        feats = model.extract_feat(inputs, stage='pre_logits')
        self.assertEqual(feats.shape, (1, 512))

        # TODO: test transformer style feature extraction

        # test extract_feats: one result per input batch
        multi_feats = model.extract_feats([inputs, inputs], stage='backbone')
        self.assertEqual(len(multi_feats), 2)
        for feats in multi_feats:
            self._assert_shapes(feats, self.BACKBONE_SHAPES)

        # Without neck, return backbone
        cfg = ConfigDict(self.DEFAULT_ARGS)
        cfg.backbone.out_indices = (0, 1, 2, 3)
        cfg.pop('neck')
        model: ImageClassifier = MODELS.build(cfg)
        feats = model.extract_feat(inputs, stage='neck')
        self._assert_shapes(feats, self.BACKBONE_SHAPES)

        # Without head, raise error
        cfg = ConfigDict(self.DEFAULT_ARGS)
        cfg.backbone.out_indices = (0, 1, 2, 3)
        cfg.pop('head')
        model: ImageClassifier = MODELS.build(cfg)
        with self.assertRaisesRegex(AssertionError, 'No head or the head'):
            model.extract_feat(inputs, stage='pre_logits')

        # `extract_feats` expects a list of batches, not a single tensor.
        with self.assertRaisesRegex(AssertionError, 'use `extract_feat`'):
            model.extract_feats(inputs)

    def test_loss(self):
        inputs = torch.rand(1, 3, 224, 224)
        data_samples = [ClsDataSample().set_gt_label(1)]

        model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
        losses = model.loss(inputs, data_samples)
        self.assertGreater(losses['loss'].item(), 0)

    def test_predict(self):
        inputs = torch.rand(1, 3, 224, 224)
        data_samples = [ClsDataSample().set_gt_label(1)]

        model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
        predictions = model.predict(inputs)
        self.assertEqual(predictions[0].pred_label.score.shape, (10, ))

        # When data samples are given, the predictions are also written
        # into them.
        predictions = model.predict(inputs, data_samples)
        self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
        self.assertEqual(data_samples[0].pred_label.score.shape, (10, ))
        # `assert_allclose` is deprecated; `assert_close` replaces it.
        torch.testing.assert_close(data_samples[0].pred_label.score,
                                   predictions[0].pred_label.score)

    def test_forward(self):
        inputs = torch.rand(1, 3, 224, 224)
        data_samples = [ClsDataSample().set_gt_label(1)]
        model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)

        # test pure forward
        outs = model(inputs)
        self.assertIsInstance(outs, torch.Tensor)

        # test forward train
        losses = model(inputs, data_samples, mode='loss')
        self.assertGreater(losses['loss'].item(), 0)

        # test forward test
        predictions = model(inputs, mode='predict')
        self.assertEqual(predictions[0].pred_label.score.shape, (10, ))

        predictions = model(inputs, data_samples, mode='predict')
        self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
        self.assertEqual(data_samples[0].pred_label.score.shape, (10, ))
        torch.testing.assert_close(data_samples[0].pred_label.score,
                                   predictions[0].pred_label.score)

        # test forward with invalid mode
        with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
            model(inputs, mode='unknown')

    def test_train_step(self):
        model: ImageClassifier = self._build_with_preprocessor()
        data = self._make_step_data()

        # The optimizer wrapper is mocked; only the call is checked.
        optim_wrapper = MagicMock()
        log_vars = model.train_step(data, optim_wrapper)
        self.assertIn('loss', log_vars)
        optim_wrapper.update_params.assert_called_once()

    def test_val_step(self):
        model: ImageClassifier = self._build_with_preprocessor()
        data = self._make_step_data()

        predictions = model.val_step(data)
        self.assertEqual(predictions[0].pred_label.score.shape, (10, ))

    def test_test_step(self):
        model: ImageClassifier = self._build_with_preprocessor()
        data = self._make_step_data()

        predictions = model.test_step(data)
        self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
|