# Copyright (c) OpenMMLab. All rights reserved.
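"""Unit tests for the classifier models: ImageClassifier, TimmClassifier and
HuggingFaceClassifier."""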
import unittest
from unittest import TestCase
from unittest.mock import MagicMock
import torch
import torch.nn as nn
from mmengine import ConfigDict
from mmpretrain.models import ImageClassifier
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
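
# Helpers that report whether the optional `timm` / `transformers` packages are
# installed, so the corresponding wrapper-classifier tests can be skipped otherwise.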
def has_timm() -> bool:
try:
import timm # noqa: F401
return True
except ImportError:
return False
def has_huggingface() -> bool:
try:
import transformers # noqa: F401
return True
except ImportError:
return False
class TestImageClassifier(TestCase):
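    # A minimal ResNet-18 + linear head config reused by all tests in this class.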
DEFAULT_ARGS = dict(
type='ImageClassifier',
backbone=dict(type='ResNet', depth=18),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=10,
in_channels=512,
loss=dict(type='CrossEntropyLoss')))
def test_initialize(self):
model = MODELS.build(self.DEFAULT_ARGS)
self.assertTrue(model.with_neck)
self.assertTrue(model.with_head)
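
        # The `pretrained` shortcut should be converted into a `Pretrained` init_cfg.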
cfg = {**self.DEFAULT_ARGS, 'pretrained': 'checkpoint'}
model = MODELS.build(cfg)
self.assertDictEqual(model.init_cfg,
dict(type='Pretrained', checkpoint='checkpoint'))
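
        # Dropping the neck/head config should clear the corresponding flags.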
cfg = ConfigDict(self.DEFAULT_ARGS)
cfg.pop('neck')
model = MODELS.build(cfg)
self.assertFalse(model.with_neck)
cfg = ConfigDict(self.DEFAULT_ARGS)
cfg.pop('head')
model = MODELS.build(cfg)
self.assertFalse(model.with_head)
# test set batch augmentation from train_cfg
cfg = {
**self.DEFAULT_ARGS, 'train_cfg':
dict(augments=dict(type='Mixup', alpha=1.))
}
model: ImageClassifier = MODELS.build(cfg)
self.assertIsNotNone(model.data_preprocessor.batch_augments)
cfg = {**self.DEFAULT_ARGS, 'train_cfg': dict()}
model: ImageClassifier = MODELS.build(cfg)
self.assertIsNone(model.data_preprocessor.batch_augments)
def test_extract_feat(self):
inputs = torch.rand(1, 3, 224, 224)
cfg = ConfigDict(self.DEFAULT_ARGS)
cfg.backbone.out_indices = (0, 1, 2, 3)
model: ImageClassifier = MODELS.build(cfg)
# test backbone output
feats = model.extract_feat(inputs, stage='backbone')
self.assertEqual(len(feats), 4)
self.assertEqual(feats[0].shape, (1, 64, 56, 56))
self.assertEqual(feats[1].shape, (1, 128, 28, 28))
self.assertEqual(feats[2].shape, (1, 256, 14, 14))
self.assertEqual(feats[3].shape, (1, 512, 7, 7))
# test neck output
feats = model.extract_feat(inputs, stage='neck')
self.assertEqual(len(feats), 4)
self.assertEqual(feats[0].shape, (1, 64))
self.assertEqual(feats[1].shape, (1, 128))
self.assertEqual(feats[2].shape, (1, 256))
self.assertEqual(feats[3].shape, (1, 512))
# test pre_logits output
feats = model.extract_feat(inputs, stage='pre_logits')
self.assertEqual(feats.shape, (1, 512))
# TODO: test transformer style feature extraction
# test extract_feats
multi_feats = model.extract_feats([inputs, inputs], stage='backbone')
self.assertEqual(len(multi_feats), 2)
for feats in multi_feats:
self.assertEqual(len(feats), 4)
self.assertEqual(feats[0].shape, (1, 64, 56, 56))
self.assertEqual(feats[1].shape, (1, 128, 28, 28))
self.assertEqual(feats[2].shape, (1, 256, 14, 14))
self.assertEqual(feats[3].shape, (1, 512, 7, 7))
        # Without a neck, `extract_feat` falls back to the raw backbone features.
cfg = ConfigDict(self.DEFAULT_ARGS)
cfg.backbone.out_indices = (0, 1, 2, 3)
cfg.pop('neck')
model: ImageClassifier = MODELS.build(cfg)
feats = model.extract_feat(inputs, stage='neck')
self.assertEqual(len(feats), 4)
self.assertEqual(feats[0].shape, (1, 64, 56, 56))
self.assertEqual(feats[1].shape, (1, 128, 28, 28))
self.assertEqual(feats[2].shape, (1, 256, 14, 14))
self.assertEqual(feats[3].shape, (1, 512, 7, 7))
        # Without a head, pre_logits extraction raises an assertion error.
cfg = ConfigDict(self.DEFAULT_ARGS)
cfg.backbone.out_indices = (0, 1, 2, 3)
cfg.pop('head')
model: ImageClassifier = MODELS.build(cfg)
with self.assertRaisesRegex(AssertionError, 'No head or the head'):
model.extract_feat(inputs, stage='pre_logits')
with self.assertRaisesRegex(AssertionError, 'use `extract_feat`'):
model.extract_feats(inputs)
def test_loss(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
losses = model.loss(inputs, data_samples)
self.assertGreater(losses['loss'].item(), 0)
def test_predict(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
predictions = model.predict(inputs)
self.assertEqual(predictions[0].pred_score.shape, (10, ))
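
        # When data_samples are provided, predictions are also written back
        # onto those samples.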
predictions = model.predict(inputs, data_samples)
self.assertEqual(predictions[0].pred_score.shape, (10, ))
self.assertEqual(data_samples[0].pred_score.shape, (10, ))
        torch.testing.assert_close(data_samples[0].pred_score,
                                   predictions[0].pred_score)
def test_forward(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
# test pure forward
outs = model(inputs)
self.assertIsInstance(outs, torch.Tensor)
# test forward train
losses = model(inputs, data_samples, mode='loss')
self.assertGreater(losses['loss'].item(), 0)
# test forward test
predictions = model(inputs, mode='predict')
self.assertEqual(predictions[0].pred_score.shape, (10, ))
predictions = model(inputs, data_samples, mode='predict')
self.assertEqual(predictions[0].pred_score.shape, (10, ))
self.assertEqual(data_samples[0].pred_score.shape, (10, ))
        torch.testing.assert_close(data_samples[0].pred_score,
                                   predictions[0].pred_score)
# test forward with invalid mode
with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
model(inputs, mode='unknown')
def test_train_step(self):
cfg = {
**self.DEFAULT_ARGS, 'data_preprocessor':
dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
}
model: ImageClassifier = MODELS.build(cfg)
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [DataSample().set_gt_label(1)]
}
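
        # A mocked optimizer wrapper is enough here: train_step should return a
        # loss entry and call `update_params` exactly once.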
optim_wrapper = MagicMock()
log_vars = model.train_step(data, optim_wrapper)
self.assertIn('loss', log_vars)
optim_wrapper.update_params.assert_called_once()
def test_val_step(self):
cfg = {
**self.DEFAULT_ARGS, 'data_preprocessor':
dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
}
model: ImageClassifier = MODELS.build(cfg)
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [DataSample().set_gt_label(1)]
}
predictions = model.val_step(data)
self.assertEqual(predictions[0].pred_score.shape, (10, ))
def test_test_step(self):
cfg = {
**self.DEFAULT_ARGS, 'data_preprocessor':
dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
}
model: ImageClassifier = MODELS.build(cfg)
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [DataSample().set_gt_label(1)]
}
predictions = model.test_step(data)
self.assertEqual(predictions[0].pred_score.shape, (10, ))
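

# Tests for the thin wrapper around `timm` models; skipped when timm is not
# installed.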
@unittest.skipIf(not has_timm(), 'timm is not installed.')
class TestTimmClassifier(TestCase):
DEFAULT_ARGS = dict(
type='TimmClassifier',
model_name='resnet18',
loss=dict(type='CrossEntropyLoss'),
)
def test_initialize(self):
model = MODELS.build(self.DEFAULT_ARGS)
        self.assertIsInstance(model.model, nn.Module)
# test set batch augmentation from train_cfg
cfg = {
**self.DEFAULT_ARGS, 'train_cfg':
dict(augments=dict(type='Mixup', alpha=1.))
}
model: ImageClassifier = MODELS.build(cfg)
self.assertIsNotNone(model.data_preprocessor.batch_augments)
cfg = {**self.DEFAULT_ARGS, 'train_cfg': dict()}
model: ImageClassifier = MODELS.build(cfg)
self.assertIsNone(model.data_preprocessor.batch_augments)
def test_loss(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
losses = model.loss(inputs, data_samples)
self.assertGreater(losses['loss'].item(), 0)
def test_predict(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
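        # timm's resnet18 keeps its default 1000-class ImageNet head, hence the
        # expected score shape below.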
predictions = model.predict(inputs)
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
predictions = model.predict(inputs, data_samples)
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
self.assertEqual(data_samples[0].pred_score.shape, (1000, ))
        torch.testing.assert_close(data_samples[0].pred_score,
                                   predictions[0].pred_score)
def test_forward(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
# test pure forward
outs = model(inputs)
self.assertIsInstance(outs, torch.Tensor)
# test forward train
losses = model(inputs, data_samples, mode='loss')
self.assertGreater(losses['loss'].item(), 0)
# test forward test
predictions = model(inputs, mode='predict')
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
predictions = model(inputs, data_samples, mode='predict')
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
self.assertEqual(data_samples[0].pred_score.shape, (1000, ))
        torch.testing.assert_close(data_samples[0].pred_score,
                                   predictions[0].pred_score)
# test forward with invalid mode
with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
model(inputs, mode='unknown')
def test_train_step(self):
cfg = {
**self.DEFAULT_ARGS, 'data_preprocessor':
dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
}
model: ImageClassifier = MODELS.build(cfg)
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [DataSample().set_gt_label(1)]
}
optim_wrapper = MagicMock()
log_vars = model.train_step(data, optim_wrapper)
self.assertIn('loss', log_vars)
optim_wrapper.update_params.assert_called_once()
def test_val_step(self):
cfg = {
**self.DEFAULT_ARGS, 'data_preprocessor':
dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
}
model: ImageClassifier = MODELS.build(cfg)
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [DataSample().set_gt_label(1)]
}
predictions = model.val_step(data)
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
def test_test_step(self):
cfg = {
**self.DEFAULT_ARGS, 'data_preprocessor':
dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
}
model: ImageClassifier = MODELS.build(cfg)
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [DataSample().set_gt_label(1)]
}
predictions = model.test_step(data)
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
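

# Tests for the wrapper around HuggingFace `transformers` image models; skipped
# when transformers is not installed.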
@unittest.skipIf(not has_huggingface(), 'huggingface is not installed.')
class TestHuggingFaceClassifier(TestCase):
DEFAULT_ARGS = dict(
type='HuggingFaceClassifier',
model_name='microsoft/resnet-18',
loss=dict(type='CrossEntropyLoss'),
)
def test_initialize(self):
model = MODELS.build(self.DEFAULT_ARGS)
        self.assertIsInstance(model.model, nn.Module)
# test set batch augmentation from train_cfg
cfg = {
**self.DEFAULT_ARGS, 'train_cfg':
dict(augments=dict(type='Mixup', alpha=1.))
}
model: ImageClassifier = MODELS.build(cfg)
self.assertIsNotNone(model.data_preprocessor.batch_augments)
cfg = {**self.DEFAULT_ARGS, 'train_cfg': dict()}
model: ImageClassifier = MODELS.build(cfg)
self.assertIsNone(model.data_preprocessor.batch_augments)
def test_loss(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
losses = model.loss(inputs, data_samples)
self.assertGreater(losses['loss'].item(), 0)
def test_predict(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
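        # microsoft/resnet-18 is an ImageNet-1k model, hence 1000-class scores.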
predictions = model.predict(inputs)
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
predictions = model.predict(inputs, data_samples)
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
self.assertEqual(data_samples[0].pred_score.shape, (1000, ))
        torch.testing.assert_close(data_samples[0].pred_score,
                                   predictions[0].pred_score)
def test_forward(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
# test pure forward
outs = model(inputs)
self.assertIsInstance(outs, torch.Tensor)
# test forward train
losses = model(inputs, data_samples, mode='loss')
self.assertGreater(losses['loss'].item(), 0)
# test forward test
predictions = model(inputs, mode='predict')
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
predictions = model(inputs, data_samples, mode='predict')
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
self.assertEqual(data_samples[0].pred_score.shape, (1000, ))
        torch.testing.assert_close(data_samples[0].pred_score,
                                   predictions[0].pred_score)
# test forward with invalid mode
with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
model(inputs, mode='unknown')
def test_train_step(self):
cfg = {
**self.DEFAULT_ARGS, 'data_preprocessor':
dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
}
model: ImageClassifier = MODELS.build(cfg)
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [DataSample().set_gt_label(1)]
}
optim_wrapper = MagicMock()
log_vars = model.train_step(data, optim_wrapper)
self.assertIn('loss', log_vars)
optim_wrapper.update_params.assert_called_once()
def test_val_step(self):
cfg = {
**self.DEFAULT_ARGS, 'data_preprocessor':
dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
}
model: ImageClassifier = MODELS.build(cfg)
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [DataSample().set_gt_label(1)]
}
predictions = model.val_step(data)
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
def test_test_step(self):
cfg = {
**self.DEFAULT_ARGS, 'data_preprocessor':
dict(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
}
model: ImageClassifier = MODELS.build(cfg)
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [DataSample().set_gt_label(1)]
}
predictions = model.test_step(data)
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
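

if __name__ == '__main__':
    # Allow running this test module directly with `python test_classifiers.py`.
    unittest.main()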