mmpretrain/tests/test_models/test_classifiers.py

472 lines
17 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import unittest
from unittest import TestCase
from unittest.mock import MagicMock
import torch
import torch.nn as nn
from mmengine import ConfigDict
from mmpretrain.models import ImageClassifier
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
def has_timm() -> bool:
    """Report whether the optional ``timm`` package can be imported.

    Returns:
        bool: ``True`` when ``import timm`` succeeds, ``False`` when it
        raises ``ImportError``.
    """
    try:
        import timm  # noqa: F401
    except ImportError:
        return False
    else:
        return True
def has_huggingface() -> bool:
    """Report whether the optional ``transformers`` package can be imported.

    Returns:
        bool: ``True`` when ``import transformers`` succeeds, ``False`` when
        it raises ``ImportError``.
    """
    available = True
    try:
        import transformers  # noqa: F401
    except ImportError:
        available = False
    return available
class TestImageClassifier(TestCase):
DEFAULT_ARGS = dict(
type='ImageClassifier',
backbone=dict(type='ResNet', depth=18),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=10,
in_channels=512,
loss=dict(type='CrossEntropyLoss')))
def test_initialize(self):
model = MODELS.build(self.DEFAULT_ARGS)
self.assertTrue(model.with_neck)
self.assertTrue(model.with_head)
cfg = {**self.DEFAULT_ARGS, 'pretrained': 'checkpoint'}
model = MODELS.build(cfg)
self.assertDictEqual(model.init_cfg,
dict(type='Pretrained', checkpoint='checkpoint'))
cfg = ConfigDict(self.DEFAULT_ARGS)
cfg.pop('neck')
model = MODELS.build(cfg)
self.assertFalse(model.with_neck)
cfg = ConfigDict(self.DEFAULT_ARGS)
cfg.pop('head')
model = MODELS.build(cfg)
self.assertFalse(model.with_head)
# test set batch augmentation from train_cfg
cfg = {
**self.DEFAULT_ARGS, 'train_cfg':
dict(augments=dict(type='Mixup', alpha=1.))
}
model: ImageClassifier = MODELS.build(cfg)
self.assertIsNotNone(model.data_preprocessor.batch_augments)
cfg = {**self.DEFAULT_ARGS, 'train_cfg': dict()}
model: ImageClassifier = MODELS.build(cfg)
self.assertIsNone(model.data_preprocessor.batch_augments)
def test_extract_feat(self):
inputs = torch.rand(1, 3, 224, 224)
cfg = ConfigDict(self.DEFAULT_ARGS)
cfg.backbone.out_indices = (0, 1, 2, 3)
model: ImageClassifier = MODELS.build(cfg)
# test backbone output
feats = model.extract_feat(inputs, stage='backbone')
self.assertEqual(len(feats), 4)
self.assertEqual(feats[0].shape, (1, 64, 56, 56))
self.assertEqual(feats[1].shape, (1, 128, 28, 28))
self.assertEqual(feats[2].shape, (1, 256, 14, 14))
self.assertEqual(feats[3].shape, (1, 512, 7, 7))
# test neck output
feats = model.extract_feat(inputs, stage='neck')
self.assertEqual(len(feats), 4)
self.assertEqual(feats[0].shape, (1, 64))
self.assertEqual(feats[1].shape, (1, 128))
self.assertEqual(feats[2].shape, (1, 256))
self.assertEqual(feats[3].shape, (1, 512))
# test pre_logits output
feats = model.extract_feat(inputs, stage='pre_logits')
self.assertEqual(feats.shape, (1, 512))
# TODO: test transformer style feature extraction
# test extract_feats
multi_feats = model.extract_feats([inputs, inputs], stage='backbone')
self.assertEqual(len(multi_feats), 2)
for feats in multi_feats:
self.assertEqual(len(feats), 4)
self.assertEqual(feats[0].shape, (1, 64, 56, 56))
self.assertEqual(feats[1].shape, (1, 128, 28, 28))
self.assertEqual(feats[2].shape, (1, 256, 14, 14))
self.assertEqual(feats[3].shape, (1, 512, 7, 7))
# Without neck, return backbone
cfg = ConfigDict(self.DEFAULT_ARGS)
cfg.backbone.out_indices = (0, 1, 2, 3)
cfg.pop('neck')
model: ImageClassifier = MODELS.build(cfg)
feats = model.extract_feat(inputs, stage='neck')
self.assertEqual(len(feats), 4)
self.assertEqual(feats[0].shape, (1, 64, 56, 56))
self.assertEqual(feats[1].shape, (1, 128, 28, 28))
self.assertEqual(feats[2].shape, (1, 256, 14, 14))
self.assertEqual(feats[3].shape, (1, 512, 7, 7))
# Without head, raise error
cfg = ConfigDict(self.DEFAULT_ARGS)
cfg.backbone.out_indices = (0, 1, 2, 3)
cfg.pop('head')
model: ImageClassifier = MODELS.build(cfg)
with self.assertRaisesRegex(AssertionError, 'No head or the head'):
model.extract_feat(inputs, stage='pre_logits')
with self.assertRaisesRegex(AssertionError, 'use `extract_feat`'):
model.extract_feats(inputs)
def test_loss(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
losses = model.loss(inputs, data_samples)
self.assertGreater(losses['loss'].item(), 0)
def test_predict(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
predictions = model.predict(inputs)
self.assertEqual(predictions[0].pred_score.shape, (10, ))
predictions = model.predict(inputs, data_samples)
self.assertEqual(predictions[0].pred_score.shape, (10, ))
self.assertEqual(data_samples[0].pred_score.shape, (10, ))
torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_score)
def test_forward(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
# test pure forward
outs = model(inputs)
self.assertIsInstance(outs, torch.Tensor)
# test forward train
losses = model(inputs, data_samples, mode='loss')
self.assertGreater(losses['loss'].item(), 0)
# test forward test
predictions = model(inputs, mode='predict')
self.assertEqual(predictions[0].pred_score.shape, (10, ))
predictions = model(inputs, data_samples, mode='predict')
self.assertEqual(predictions[0].pred_score.shape, (10, ))
self.assertEqual(data_samples[0].pred_score.shape, (10, ))
torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_score)
# test forward with invalid mode
with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
model(inputs, mode='unknown')
def test_train_step(self):
cfg = {
**self.DEFAULT_ARGS, 'data_preprocessor':
dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
}
model: ImageClassifier = MODELS.build(cfg)
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [DataSample().set_gt_label(1)]
}
optim_wrapper = MagicMock()
log_vars = model.train_step(data, optim_wrapper)
self.assertIn('loss', log_vars)
optim_wrapper.update_params.assert_called_once()
def test_val_step(self):
cfg = {
**self.DEFAULT_ARGS, 'data_preprocessor':
dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
}
model: ImageClassifier = MODELS.build(cfg)
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [DataSample().set_gt_label(1)]
}
predictions = model.val_step(data)
self.assertEqual(predictions[0].pred_score.shape, (10, ))
def test_test_step(self):
cfg = {
**self.DEFAULT_ARGS, 'data_preprocessor':
dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
}
model: ImageClassifier = MODELS.build(cfg)
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [DataSample().set_gt_label(1)]
}
predictions = model.test_step(data)
self.assertEqual(predictions[0].pred_score.shape, (10, ))
@unittest.skipIf(not has_timm(), 'timm is not installed.')
class TestTimmClassifier(TestCase):
    """Tests for the model built from a ``TimmClassifier`` config.

    The whole class is skipped when ``timm`` cannot be imported.
    """

    # Baseline config shared by all tests.
    DEFAULT_ARGS = dict(
        type='TimmClassifier',
        model_name='resnet18',
        loss=dict(type='CrossEntropyLoss'),
    )

    # Shape asserted for ``pred_score`` throughout this class.
    SCORE_SHAPE = (1000, )

    def _assert_pred_scores(self, predictions, data_samples=None):
        """Assert the ``pred_score`` shape of the first prediction.

        Args:
            predictions: Sequence returned by the model in predict mode.
            data_samples: Optional list of samples that were passed to the
                model; when given, its first sample must carry a
                ``pred_score`` close to the returned one.
        """
        self.assertEqual(predictions[0].pred_score.shape, self.SCORE_SHAPE)
        if data_samples is None:
            return
        self.assertEqual(data_samples[0].pred_score.shape, self.SCORE_SHAPE)
        # ``torch.allclose`` keeps this check independent of the PyTorch
        # version in use.
        self.assertTrue(
            torch.allclose(data_samples[0].pred_score,
                           predictions[0].pred_score),
            'pred_score on the input samples differs from the returned '
            'predictions')

    def _step_fixture(self):
        """Build a model with a data preprocessor plus one input batch.

        Returns:
            tuple: ``(model, data)`` where ``data`` holds integer ``inputs``
            drawn from ``[0, 256)`` and one ``DataSample`` labelled ``1``.
        """
        cfg = {
            **self.DEFAULT_ARGS, 'data_preprocessor':
            dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
        }
        model = MODELS.build(cfg)
        data = {
            'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
            'data_samples': [DataSample().set_gt_label(1)]
        }
        return model, data

    def test_initialize(self):
        """Check the wrapped module and batch augmentation setup."""
        model = MODELS.build(self.DEFAULT_ARGS)
        # Use the TestCase assertion so the check also runs under ``-O``.
        self.assertIsInstance(model.model, nn.Module)

        # test set batch augmentation from train_cfg
        cfg = {
            **self.DEFAULT_ARGS, 'train_cfg':
            dict(augments=dict(type='Mixup', alpha=1.))
        }
        model = MODELS.build(cfg)
        self.assertIsNotNone(model.data_preprocessor.batch_augments)

        # An empty train_cfg must leave batch augmentation unset.
        cfg = {**self.DEFAULT_ARGS, 'train_cfg': dict()}
        model = MODELS.build(cfg)
        self.assertIsNone(model.data_preprocessor.batch_augments)

    def test_loss(self):
        """Check that ``loss`` returns a positive ``'loss'`` entry."""
        inputs = torch.rand(1, 3, 224, 224)
        data_samples = [DataSample().set_gt_label(1)]
        model = MODELS.build(self.DEFAULT_ARGS)

        losses = model.loss(inputs, data_samples)
        self.assertGreater(losses['loss'].item(), 0)

    def test_predict(self):
        """Check ``predict`` with and without input data samples."""
        inputs = torch.rand(1, 3, 224, 224)
        data_samples = [DataSample().set_gt_label(1)]
        model = MODELS.build(self.DEFAULT_ARGS)

        predictions = model.predict(inputs)
        self._assert_pred_scores(predictions)

        predictions = model.predict(inputs, data_samples)
        self._assert_pred_scores(predictions, data_samples)

    def test_forward(self):
        """Check ``forward`` in tensor, loss, predict and invalid modes."""
        inputs = torch.rand(1, 3, 224, 224)
        data_samples = [DataSample().set_gt_label(1)]
        model = MODELS.build(self.DEFAULT_ARGS)

        # test pure forward
        outs = model(inputs)
        self.assertIsInstance(outs, torch.Tensor)

        # test forward train
        losses = model(inputs, data_samples, mode='loss')
        self.assertGreater(losses['loss'].item(), 0)

        # test forward test
        predictions = model(inputs, mode='predict')
        self._assert_pred_scores(predictions)

        predictions = model(inputs, data_samples, mode='predict')
        self._assert_pred_scores(predictions, data_samples)

        # test forward with invalid mode
        with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
            model(inputs, mode='unknown')

    def test_train_step(self):
        """Check ``train_step`` logs a loss and updates parameters once."""
        model, data = self._step_fixture()

        # A mock stands in for the optimizer wrapper so the call can be
        # inspected afterwards.
        optim_wrapper = MagicMock()
        log_vars = model.train_step(data, optim_wrapper)
        self.assertIn('loss', log_vars)
        optim_wrapper.update_params.assert_called_once()

    def test_val_step(self):
        """Check the ``pred_score`` shape returned by ``val_step``."""
        model, data = self._step_fixture()

        predictions = model.val_step(data)
        self._assert_pred_scores(predictions)

    def test_test_step(self):
        """Check the ``pred_score`` shape returned by ``test_step``."""
        model, data = self._step_fixture()

        predictions = model.test_step(data)
        self._assert_pred_scores(predictions)
@unittest.skipIf(not has_huggingface(), 'huggingface is not installed.')
class TestHuggingFaceClassifier(TestCase):
    """Tests for the model built from a ``HuggingFaceClassifier`` config.

    The whole class is skipped when ``transformers`` cannot be imported.
    """

    # Baseline config shared by all tests.
    DEFAULT_ARGS = dict(
        type='HuggingFaceClassifier',
        model_name='microsoft/resnet-18',
        loss=dict(type='CrossEntropyLoss'),
    )

    # Shape asserted for ``pred_score`` throughout this class.
    SCORE_SHAPE = (1000, )

    def _assert_pred_scores(self, predictions, data_samples=None):
        """Assert the ``pred_score`` shape of the first prediction.

        Args:
            predictions: Sequence returned by the model in predict mode.
            data_samples: Optional list of samples that were passed to the
                model; when given, its first sample must carry a
                ``pred_score`` close to the returned one.
        """
        self.assertEqual(predictions[0].pred_score.shape, self.SCORE_SHAPE)
        if data_samples is None:
            return
        self.assertEqual(data_samples[0].pred_score.shape, self.SCORE_SHAPE)
        # ``torch.allclose`` keeps this check independent of the PyTorch
        # version in use.
        self.assertTrue(
            torch.allclose(data_samples[0].pred_score,
                           predictions[0].pred_score),
            'pred_score on the input samples differs from the returned '
            'predictions')

    def _step_fixture(self):
        """Build a model with a data preprocessor plus one input batch.

        Returns:
            tuple: ``(model, data)`` where ``data`` holds integer ``inputs``
            drawn from ``[0, 256)`` and one ``DataSample`` labelled ``1``.
        """
        cfg = {
            **self.DEFAULT_ARGS, 'data_preprocessor':
            dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
        }
        model = MODELS.build(cfg)
        data = {
            'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
            'data_samples': [DataSample().set_gt_label(1)]
        }
        return model, data

    def test_initialize(self):
        """Check the wrapped module and batch augmentation setup."""
        model = MODELS.build(self.DEFAULT_ARGS)
        # Use the TestCase assertion so the check also runs under ``-O``.
        self.assertIsInstance(model.model, nn.Module)

        # test set batch augmentation from train_cfg
        cfg = {
            **self.DEFAULT_ARGS, 'train_cfg':
            dict(augments=dict(type='Mixup', alpha=1.))
        }
        model = MODELS.build(cfg)
        self.assertIsNotNone(model.data_preprocessor.batch_augments)

        # An empty train_cfg must leave batch augmentation unset.
        cfg = {**self.DEFAULT_ARGS, 'train_cfg': dict()}
        model = MODELS.build(cfg)
        self.assertIsNone(model.data_preprocessor.batch_augments)

    def test_loss(self):
        """Check that ``loss`` returns a positive ``'loss'`` entry."""
        inputs = torch.rand(1, 3, 224, 224)
        data_samples = [DataSample().set_gt_label(1)]
        model = MODELS.build(self.DEFAULT_ARGS)

        losses = model.loss(inputs, data_samples)
        self.assertGreater(losses['loss'].item(), 0)

    def test_predict(self):
        """Check ``predict`` with and without input data samples."""
        inputs = torch.rand(1, 3, 224, 224)
        data_samples = [DataSample().set_gt_label(1)]
        model = MODELS.build(self.DEFAULT_ARGS)

        predictions = model.predict(inputs)
        self._assert_pred_scores(predictions)

        predictions = model.predict(inputs, data_samples)
        self._assert_pred_scores(predictions, data_samples)

    def test_forward(self):
        """Check ``forward`` in tensor, loss, predict and invalid modes."""
        inputs = torch.rand(1, 3, 224, 224)
        data_samples = [DataSample().set_gt_label(1)]
        model = MODELS.build(self.DEFAULT_ARGS)

        # test pure forward
        outs = model(inputs)
        self.assertIsInstance(outs, torch.Tensor)

        # test forward train
        losses = model(inputs, data_samples, mode='loss')
        self.assertGreater(losses['loss'].item(), 0)

        # test forward test
        predictions = model(inputs, mode='predict')
        self._assert_pred_scores(predictions)

        predictions = model(inputs, data_samples, mode='predict')
        self._assert_pred_scores(predictions, data_samples)

        # test forward with invalid mode
        with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
            model(inputs, mode='unknown')

    def test_train_step(self):
        """Check ``train_step`` logs a loss and updates parameters once."""
        model, data = self._step_fixture()

        # A mock stands in for the optimizer wrapper so the call can be
        # inspected afterwards.
        optim_wrapper = MagicMock()
        log_vars = model.train_step(data, optim_wrapper)
        self.assertIn('loss', log_vars)
        optim_wrapper.update_params.assert_called_once()

    def test_val_step(self):
        """Check the ``pred_score`` shape returned by ``val_step``."""
        model, data = self._step_fixture()

        predictions = model.val_step(data)
        self._assert_pred_scores(predictions)

    def test_test_step(self):
        """Check the ``pred_score`` shape returned by ``test_step``."""
        model, data = self._step_fixture()

        predictions = model.test_step(data)
        self._assert_pred_scores(predictions)