mmpretrain/tests/test_models/test_classifiers.py

307 lines
9.8 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import torch
from mmcv import ConfigDict
from mmcls.models import CLASSIFIERS
from mmcls.models.classifiers import ImageClassifier
def test_image_classifier():
    """Exercise ``ImageClassifier`` built through the ``CLASSIFIERS`` registry.

    Covers the ``with_neck``/``with_head`` properties, ``train_step`` and
    ``val_step`` (with and without the optimizer argument), the loss and
    prediction paths of ``forward``, ``simple_test`` with its ``softmax`` and
    ``post_process`` switches, and the ``pretrained`` -> ``init_cfg`` mapping.
    """
    model_cfg = dict(
        type='ImageClassifier',
        backbone=dict(
            type='ResNet_CIFAR',
            depth=50,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='LinearClsHead',
            num_classes=10,
            in_channels=2048,
            loss=dict(type='CrossEntropyLoss')))

    imgs = torch.randn(16, 3, 32, 32)
    label = torch.randint(0, 10, (16, ))

    # Build from a copy so the shared ``model_cfg`` can be reused below.
    model_cfg_ = deepcopy(model_cfg)
    model = CLASSIFIERS.build(model_cfg_)

    # test property
    assert model.with_neck
    assert model.with_head

    # test train_step
    outputs = model.train_step({'img': imgs, 'gt_label': label}, None)
    assert outputs['loss'].item() > 0
    assert outputs['num_samples'] == 16

    # test train_step without optimizer
    outputs = model.train_step({'img': imgs, 'gt_label': label})
    assert outputs['loss'].item() > 0
    assert outputs['num_samples'] == 16

    # test val_step
    outputs = model.val_step({'img': imgs, 'gt_label': label}, None)
    assert outputs['loss'].item() > 0
    assert outputs['num_samples'] == 16

    # test val_step without optimizer
    outputs = model.val_step({'img': imgs, 'gt_label': label})
    assert outputs['loss'].item() > 0
    assert outputs['num_samples'] == 16

    # test forward
    losses = model(imgs, return_loss=True, gt_label=label)
    assert losses['loss'].item() > 0

    # test forward_test
    model_cfg_ = deepcopy(model_cfg)
    model = CLASSIFIERS.build(model_cfg_)
    pred = model(imgs, return_loss=False, img_metas=None)
    assert isinstance(pred, list) and len(pred) == 16

    single_img = torch.randn(1, 3, 32, 32)
    pred = model(single_img, return_loss=False, img_metas=None)
    assert isinstance(pred, list) and len(pred) == 1

    pred = model.simple_test(imgs, softmax=False)
    assert isinstance(pred, list) and len(pred) == 16
    # Each per-sample prediction must hold one score per class. The previous
    # form, ``len(pred[0] == 10)``, measured an elementwise comparison and so
    # passed for any non-empty prediction.
    assert len(pred[0]) == 10

    pred = model.simple_test(imgs, softmax=False, post_process=False)
    assert isinstance(pred, torch.Tensor)
    assert pred.shape == (16, 10)

    soft_pred = model.simple_test(imgs, softmax=True, post_process=False)
    assert isinstance(soft_pred, torch.Tensor)
    assert soft_pred.shape == (16, 10)
    # The softmax output must equal softmax applied to the raw scores.
    torch.testing.assert_allclose(soft_pred, torch.softmax(pred, dim=1))

    # test pretrained
    model_cfg_ = deepcopy(model_cfg)
    model_cfg_['pretrained'] = 'checkpoint'
    model = CLASSIFIERS.build(model_cfg_)
    assert model.init_cfg == dict(type='Pretrained', checkpoint='checkpoint')
def test_image_classifier_with_mixup():
    """Check that ``ImageClassifier`` with ``BatchMixup`` gives a positive loss."""
    # Test mixup in ImageClassifier
    backbone_cfg = {
        'type': 'ResNet_CIFAR',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (3, ),
        'style': 'pytorch',
    }
    soft_ce_loss = {
        'type': 'CrossEntropyLoss',
        'loss_weight': 1.0,
        'use_soft': True,
    }
    head_cfg = {
        'type': 'MultiLabelLinearClsHead',
        'num_classes': 10,
        'in_channels': 2048,
        'loss': soft_ce_loss,
    }
    mixup_cfg = {
        'type': 'BatchMixup',
        'alpha': 1.,
        'num_classes': 10,
        'prob': 1.,
    }

    classifier = ImageClassifier(
        backbone=backbone_cfg,
        neck={'type': 'GlobalAveragePooling'},
        head=head_cfg,
        train_cfg={'augments': mixup_cfg})
    classifier.init_weights()

    # Inputs are drawn after the model is initialised, as in the original.
    batch = torch.randn(16, 3, 32, 32)
    targets = torch.randint(0, 10, (16, ))

    loss_dict = classifier.forward_train(batch, targets)
    assert loss_dict['loss'].item() > 0
def test_image_classifier_with_cutmix():
    """Check that ``ImageClassifier`` with ``BatchCutMix`` gives a positive loss."""
    # Test cutmix in ImageClassifier
    backbone_cfg = {
        'type': 'ResNet_CIFAR',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (3, ),
        'style': 'pytorch',
    }
    soft_ce_loss = {
        'type': 'CrossEntropyLoss',
        'loss_weight': 1.0,
        'use_soft': True,
    }
    head_cfg = {
        'type': 'MultiLabelLinearClsHead',
        'num_classes': 10,
        'in_channels': 2048,
        'loss': soft_ce_loss,
    }
    cutmix_cfg = {
        'type': 'BatchCutMix',
        'alpha': 1.,
        'num_classes': 10,
        'prob': 1.,
    }

    classifier = ImageClassifier(
        backbone=backbone_cfg,
        neck={'type': 'GlobalAveragePooling'},
        head=head_cfg,
        train_cfg={'augments': cutmix_cfg})
    classifier.init_weights()

    # Inputs are drawn after the model is initialised, as in the original.
    batch = torch.randn(16, 3, 32, 32)
    targets = torch.randint(0, 10, (16, ))

    loss_dict = classifier.forward_train(batch, targets)
    assert loss_dict['loss'].item() > 0
def test_image_classifier_with_augments():
    """Run ``forward_train`` under several ``train_cfg`` augment settings.

    Four configurations are checked in turn, each expected to give a positive
    loss: a list of augments, a single ``BatchCutMix`` with ``cutmix_minmax``,
    no ``train_cfg`` at all, and ``train_cfg`` with ``augments=None``.
    """
    batch = torch.randn(16, 3, 32, 32)
    targets = torch.randint(0, 10, (16, ))

    # Test cutmix and mixup in ImageClassifier
    augment_list = [
        {'type': 'BatchCutMix', 'alpha': 1., 'num_classes': 10, 'prob': 0.5},
        {'type': 'BatchMixup', 'alpha': 1., 'num_classes': 10, 'prob': 0.3},
        {'type': 'Identity', 'num_classes': 10, 'prob': 0.2},
    ]
    soft_label_cfg = {
        'backbone': {
            'type': 'ResNet_CIFAR',
            'depth': 50,
            'num_stages': 4,
            'out_indices': (3, ),
            'style': 'pytorch',
        },
        'neck': {'type': 'GlobalAveragePooling'},
        'head': {
            'type': 'MultiLabelLinearClsHead',
            'num_classes': 10,
            'in_channels': 2048,
            'loss': {
                'type': 'CrossEntropyLoss',
                'loss_weight': 1.0,
                'use_soft': True,
            },
        },
        'train_cfg': {'augments': augment_list},
    }
    classifier = ImageClassifier(**soft_label_cfg)
    classifier.init_weights()
    loss_dict = classifier.forward_train(batch, targets)
    assert loss_dict['loss'].item() > 0

    # Test cutmix with cutmix_minmax in ImageClassifier
    soft_label_cfg['train_cfg'] = {
        'augments': {
            'type': 'BatchCutMix',
            'alpha': 1.,
            'num_classes': 10,
            'prob': 1.,
            'cutmix_minmax': [0.2, 0.8],
        },
    }
    classifier = ImageClassifier(**soft_label_cfg)
    classifier.init_weights()
    loss_dict = classifier.forward_train(batch, targets)
    assert loss_dict['loss'].item() > 0

    # Test not using train_cfg
    plain_cfg = {
        'backbone': {
            'type': 'ResNet_CIFAR',
            'depth': 50,
            'num_stages': 4,
            'out_indices': (3, ),
            'style': 'pytorch',
        },
        'neck': {'type': 'GlobalAveragePooling'},
        'head': {
            'type': 'LinearClsHead',
            'num_classes': 10,
            'in_channels': 2048,
            'loss': {'type': 'CrossEntropyLoss', 'loss_weight': 1.0},
        },
    }
    classifier = ImageClassifier(**plain_cfg)
    classifier.init_weights()
    # A fresh batch is drawn here, exactly as the original test does.
    batch = torch.randn(16, 3, 32, 32)
    targets = torch.randint(0, 10, (16, ))
    loss_dict = classifier.forward_train(batch, targets)
    assert loss_dict['loss'].item() > 0

    # Test not using cutmix and mixup in ImageClassifier
    plain_cfg['train_cfg'] = {'augments': None}
    classifier = ImageClassifier(**plain_cfg)
    classifier.init_weights()
    loss_dict = classifier.forward_train(batch, targets)
    assert loss_dict['loss'].item() > 0
def test_classifier_extract_feat():
    """Check feature shapes from ``extract_feat`` and ``extract_feats``.

    A ResNet-18 classifier and a VisionTransformer classifier are each queried
    at the ``backbone``, ``neck`` and ``pre_logits`` stages; the transformer
    model is additionally run through ``extract_feats`` on a list of images.
    """
    resnet_cfg = ConfigDict(
        type='ImageClassifier',
        backbone=dict(
            type='ResNet',
            depth=18,
            num_stages=4,
            out_indices=(0, 1, 2, 3),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='LinearClsHead',
            num_classes=1000,
            in_channels=512,
            loss=dict(type='CrossEntropyLoss'),
            topk=(1, 5),
        ))
    model = CLASSIFIERS.build(resnet_cfg)

    # Expected channel count and spatial size for each of the four stages.
    stage_channels = (64, 128, 256, 512)
    stage_sizes = (56, 28, 14, 7)

    # test backbone output
    feats = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone')
    for idx, (channels, size) in enumerate(zip(stage_channels, stage_sizes)):
        assert feats[idx].shape == (1, channels, size, size)

    # test neck output
    feats = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck')
    for idx, channels in enumerate(stage_channels):
        assert feats[idx].shape == (1, channels)

    # test pre_logits output
    feat = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits')
    assert feat.shape == (1, 512)

    # test transformer style feature extraction
    vit_cfg = dict(
        type='ImageClassifier',
        backbone=dict(
            type='VisionTransformer', arch='b', out_indices=[-3, -2, -1]),
        neck=None,
        head=dict(
            type='VisionTransformerClsHead',
            num_classes=1000,
            in_channels=768,
            hidden_dim=1024,
            loss=dict(type='CrossEntropyLoss'),
        ))
    model = CLASSIFIERS.build(vit_cfg)

    patch_shape = (1, 768, 14, 14)
    cls_shape = (1, 768)

    # test backbone output, then neck output (the same with backbone)
    for stage in ('backbone', 'neck'):
        feats = model.extract_feat(torch.rand(1, 3, 224, 224), stage=stage)
        for patch_token, cls_token in feats:
            assert patch_token.shape == patch_shape
            assert cls_token.shape == cls_shape

    # test pre_logits output
    feat = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits')
    assert feat.shape == (1, 1024)

    # test extract_feats
    multi_imgs = [torch.rand(1, 3, 224, 224) for _ in range(3)]
    per_img_feats = model.extract_feats(multi_imgs)
    for img_feats in per_img_feats:
        for patch_token, cls_token in img_feats:
            assert patch_token.shape == patch_shape
            assert cls_token.shape == cls_shape

    per_img_logits = model.extract_feats(multi_imgs, stage='pre_logits')
    for logits_feat in per_img_logits:
        assert logits_feat.shape == (1, 1024)