# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import MagicMock

import torch
from mmengine import ConfigDict

from mmcls.engine import ClsDataSample
from mmcls.models import ImageClassifier
from mmcls.registry import MODELS
from mmcls.utils import register_all_modules

register_all_modules()


class TestImageClassifier(TestCase):
    """Unit tests for ``ImageClassifier`` built through the ``MODELS`` registry.

    Every test builds the classifier from ``DEFAULT_ARGS`` (optionally
    extended or trimmed) and checks attributes, output shapes, losses or
    raised errors.
    """

    # Baseline config: ResNet-18 backbone, global-average-pooling neck and a
    # 10-class linear head with cross entropy loss.
    DEFAULT_ARGS = {
        'type': 'ImageClassifier',
        'backbone': {'type': 'ResNet', 'depth': 18},
        'neck': {'type': 'GlobalAveragePooling'},
        'head': {
            'type': 'LinearClsHead',
            'num_classes': 10,
            'in_channels': 512,
            'loss': {'type': 'CrossEntropyLoss'},
        },
    }

    # Expected per-stage feature shapes for a (1, 3, 224, 224) input when the
    # backbone is configured with ``out_indices=(0, 1, 2, 3)``.
    _BACKBONE_SHAPES = (
        (1, 64, 56, 56),
        (1, 128, 28, 28),
        (1, 256, 14, 14),
        (1, 512, 7, 7),
    )
    # Expected shapes of the same four stages after the neck.
    _NECK_SHAPES = ((1, 64), (1, 128), (1, 256), (1, 512))

    # ------------------------------------------------------------------
    # helpers (names do not start with ``test`` so they are not collected)
    # ------------------------------------------------------------------
    def _multi_stage_cfg(self):
        """Return a fresh ``ConfigDict`` with all four backbone stages on."""
        cfg = ConfigDict(self.DEFAULT_ARGS)
        cfg.backbone.out_indices = (0, 1, 2, 3)
        return cfg

    def _check_shapes(self, feats, expected):
        """Assert ``feats`` has ``len(expected)`` items with those shapes."""
        self.assertEqual(len(feats), len(expected))
        for feat, shape in zip(feats, expected):
            self.assertEqual(feat.shape, shape)

    def _check_score_shape(self, results):
        """Assert the first result carries a 10-element score vector."""
        self.assertEqual(results[0].pred_label.score.shape, (10, ))

    def _build_with_preprocessor(self):
        """Build the default model plus a 127.5 mean/std data preprocessor."""
        cfg = {
            **self.DEFAULT_ARGS,
            'data_preprocessor': {
                'mean': [127.5] * 3,
                'std': [127.5] * 3,
            },
        }
        return MODELS.build(cfg)

    @staticmethod
    def _single_item_batch():
        """Return a one-element batch: random integer image + label 1."""
        return [{
            'inputs': torch.randint(0, 256, (3, 224, 224)),
            'data_sample': ClsDataSample().set_gt_label(1),
        }]

    # ------------------------------------------------------------------
    # tests
    # ------------------------------------------------------------------
    def test_initialize(self):
        """Check construction with and without neck/head/train_cfg."""
        classifier = MODELS.build(self.DEFAULT_ARGS)
        self.assertTrue(classifier.with_neck)
        self.assertTrue(classifier.with_head)

        # ``pretrained`` is expected to be turned into an ``init_cfg``.
        classifier = MODELS.build(
            dict(self.DEFAULT_ARGS, pretrained='checkpoint'))
        self.assertDictEqual(classifier.init_cfg, {
            'type': 'Pretrained',
            'checkpoint': 'checkpoint',
        })

        # Dropping a component must flip the matching ``with_*`` flag.
        for part in ('neck', 'head'):
            trimmed = ConfigDict(self.DEFAULT_ARGS)
            trimmed.pop(part)
            classifier = MODELS.build(trimmed)
            self.assertFalse(getattr(classifier, f'with_{part}'))

        # test set batch augmentation from train_cfg
        mixup_cfg = dict(
            self.DEFAULT_ARGS,
            train_cfg={
                'augments': {
                    'type': 'Mixup',
                    'alpha': 1.,
                    'num_classes': 10,
                },
            })
        classifier: ImageClassifier = MODELS.build(mixup_cfg)
        self.assertIsNotNone(classifier.data_preprocessor.batch_augments)

        classifier = MODELS.build(dict(self.DEFAULT_ARGS, train_cfg={}))
        self.assertIsNone(classifier.data_preprocessor.batch_augments)

    def test_extract_feat(self):
        """Check ``extract_feat``/``extract_feats`` outputs for each stage."""
        images = torch.rand(1, 3, 224, 224)
        classifier: ImageClassifier = MODELS.build(self._multi_stage_cfg())

        # test backbone output
        self._check_shapes(
            classifier.extract_feat(images, stage='backbone'),
            self._BACKBONE_SHAPES)

        # test neck output
        self._check_shapes(
            classifier.extract_feat(images, stage='neck'),
            self._NECK_SHAPES)

        # test pre_logits output
        pre_logits = classifier.extract_feat(images, stage='pre_logits')
        self.assertEqual(pre_logits.shape, (1, 512))

        # TODO: test transformer style feature extraction

        # test extract_feats
        per_input = classifier.extract_feats([images, images],
                                             stage='backbone')
        self.assertEqual(len(per_input), 2)
        for stage_outs in per_input:
            self._check_shapes(stage_outs, self._BACKBONE_SHAPES)

        # Without neck, return backbone
        neckless = self._multi_stage_cfg()
        neckless.pop('neck')
        classifier = MODELS.build(neckless)
        self._check_shapes(
            classifier.extract_feat(images, stage='neck'),
            self._BACKBONE_SHAPES)

        # Without head, raise error
        headless = self._multi_stage_cfg()
        headless.pop('head')
        classifier = MODELS.build(headless)

        with self.assertRaisesRegex(AssertionError, 'No head or the head'):
            classifier.extract_feat(images, stage='pre_logits')

        with self.assertRaisesRegex(AssertionError, 'use `extract_feat`'):
            classifier.extract_feats(images)

    def test_loss(self):
        """``loss`` must return a dict whose ``'loss'`` entry is positive."""
        images = torch.rand(1, 3, 224, 224)
        samples = [ClsDataSample().set_gt_label(1)]
        classifier: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)

        loss_dict = classifier.loss(images, samples)
        self.assertGreater(loss_dict['loss'].item(), 0)

    def test_predict(self):
        """``predict`` must yield 10 scores, with or without data samples."""
        images = torch.rand(1, 3, 224, 224)
        samples = [ClsDataSample().set_gt_label(1)]
        classifier: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)

        self._check_score_shape(classifier.predict(images))

        results = classifier.predict(images, samples)
        self._check_score_shape(results)
        # The samples passed in are expected to carry the prediction too.
        self._check_score_shape(samples)
        # NOTE(review): ``assert_allclose`` is deprecated in recent PyTorch
        # releases in favour of ``assert_close`` - confirm the minimum
        # supported PyTorch version before migrating.
        torch.testing.assert_allclose(samples[0].pred_label.score,
                                      results[0].pred_label.score)

    def test_forward(self):
        """Exercise ``forward`` in default, loss, predict and invalid modes."""
        images = torch.rand(1, 3, 224, 224)
        samples = [ClsDataSample().set_gt_label(1)]
        classifier: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)

        # test pure forward
        raw_out = classifier(images)
        self.assertIsInstance(raw_out, torch.Tensor)

        # test forward train
        loss_dict = classifier(images, samples, mode='loss')
        self.assertGreater(loss_dict['loss'].item(), 0)

        # test forward test
        self._check_score_shape(classifier(images, mode='predict'))

        results = classifier(images, samples, mode='predict')
        self._check_score_shape(results)
        self._check_score_shape(samples)
        # NOTE(review): see the deprecation remark on ``assert_allclose`` in
        # the PyTorch ``torch.testing`` documentation.
        torch.testing.assert_allclose(samples[0].pred_label.score,
                                      results[0].pred_label.score)

        # test forward with invalid mode
        with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
            classifier(images, mode='unknown')

    def test_train_step(self):
        """``train_step`` must log a loss and update params exactly once."""
        classifier: ImageClassifier = self._build_with_preprocessor()
        batch = self._single_item_batch()

        # A mock stands in for the optimizer wrapper so the call can be
        # inspected afterwards.
        optim_wrapper = MagicMock()
        log_vars = classifier.train_step(batch, optim_wrapper)
        self.assertIn('loss', log_vars)
        optim_wrapper.update_params.assert_called_once()

    def test_val_step(self):
        """``val_step`` must return predictions with 10 scores."""
        classifier: ImageClassifier = self._build_with_preprocessor()
        batch = self._single_item_batch()

        self._check_score_shape(classifier.val_step(batch))

    def test_test_step(self):
        """``test_step`` must return predictions with 10 scores."""
        classifier: ImageClassifier = self._build_with_preprocessor()
        batch = self._single_item_batch()

        self._check_score_shape(classifier.test_step(batch))