mmclassification/tests/test_models/test_classifiers.py

297 lines
9.1 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from copy import deepcopy
import numpy as np
import pytest
import torch
from mmcv import ConfigDict
from mmcv.runner.base_module import BaseModule
from mmcls.models import CLASSIFIERS
from mmcls.models.classifiers import ImageClassifier
def test_image_classifier():
model_cfg = dict(
type='ImageClassifier',
backbone=dict(
type='ResNet_CIFAR',
depth=50,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=10,
in_channels=2048,
loss=dict(type='CrossEntropyLoss')))
imgs = torch.randn(16, 3, 32, 32)
label = torch.randint(0, 10, (16, ))
model_cfg_ = deepcopy(model_cfg)
model = CLASSIFIERS.build(model_cfg_)
# test property
assert model.with_neck
assert model.with_head
# test train_step
outputs = model.train_step({'img': imgs, 'gt_label': label}, None)
assert outputs['loss'].item() > 0
assert outputs['num_samples'] == 16
# test val_step
outputs = model.val_step({'img': imgs, 'gt_label': label}, None)
assert outputs['loss'].item() > 0
assert outputs['num_samples'] == 16
# test forward
losses = model(imgs, return_loss=True, gt_label=label)
assert losses['loss'].item() > 0
# test forward_test
model_cfg_ = deepcopy(model_cfg)
model = CLASSIFIERS.build(model_cfg_)
pred = model(imgs, return_loss=False, img_metas=None)
assert isinstance(pred, list) and len(pred) == 16
single_img = torch.randn(1, 3, 32, 32)
pred = model(single_img, return_loss=False, img_metas=None)
assert isinstance(pred, list) and len(pred) == 1
# test pretrained
# TODO remove deprecated pretrained
with pytest.warns(UserWarning):
model_cfg_ = deepcopy(model_cfg)
model_cfg_['pretrained'] = 'checkpoint'
model = CLASSIFIERS.build(model_cfg_)
assert model.init_cfg == dict(
type='Pretrained', checkpoint='checkpoint')
# test show_result
img = np.random.random_integers(0, 255, (224, 224, 3)).astype(np.uint8)
result = dict(pred_class='cat', pred_label=0, pred_score=0.9)
with tempfile.TemporaryDirectory() as tmpdir:
out_file = osp.join(tmpdir, 'out.png')
model.show_result(img, result, out_file=out_file)
assert osp.exists(out_file)
with tempfile.TemporaryDirectory() as tmpdir:
out_file = osp.join(tmpdir, 'out.png')
model.show_result(img, result, out_file=out_file)
assert osp.exists(out_file)
def test_image_classifier_with_mixup():
# Test mixup in ImageClassifier
model_cfg = dict(
backbone=dict(
type='ResNet_CIFAR',
depth=50,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='MultiLabelLinearClsHead',
num_classes=10,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0,
use_soft=True)),
train_cfg=dict(
augments=dict(
type='BatchMixup', alpha=1., num_classes=10, prob=1.)))
img_classifier = ImageClassifier(**model_cfg)
img_classifier.init_weights()
imgs = torch.randn(16, 3, 32, 32)
label = torch.randint(0, 10, (16, ))
losses = img_classifier.forward_train(imgs, label)
assert losses['loss'].item() > 0
# Considering BC-breaking
# TODO remove deprecated mixup usage.
model_cfg['train_cfg'] = dict(mixup=dict(alpha=1.0, num_classes=10))
img_classifier = ImageClassifier(**model_cfg)
img_classifier.init_weights()
imgs = torch.randn(16, 3, 32, 32)
label = torch.randint(0, 10, (16, ))
losses = img_classifier.forward_train(imgs, label)
assert losses['loss'].item() > 0
def test_image_classifier_with_cutmix():
# Test cutmix in ImageClassifier
model_cfg = dict(
backbone=dict(
type='ResNet_CIFAR',
depth=50,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='MultiLabelLinearClsHead',
num_classes=10,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0,
use_soft=True)),
train_cfg=dict(
augments=dict(
type='BatchCutMix', alpha=1., num_classes=10, prob=1.)))
img_classifier = ImageClassifier(**model_cfg)
img_classifier.init_weights()
imgs = torch.randn(16, 3, 32, 32)
label = torch.randint(0, 10, (16, ))
losses = img_classifier.forward_train(imgs, label)
assert losses['loss'].item() > 0
# Considering BC-breaking
# TODO remove deprecated mixup usage.
model_cfg['train_cfg'] = dict(
cutmix=dict(alpha=1.0, num_classes=10, cutmix_prob=1.0))
img_classifier = ImageClassifier(**model_cfg)
img_classifier.init_weights()
imgs = torch.randn(16, 3, 32, 32)
label = torch.randint(0, 10, (16, ))
losses = img_classifier.forward_train(imgs, label)
assert losses['loss'].item() > 0
def test_image_classifier_with_augments():
imgs = torch.randn(16, 3, 32, 32)
label = torch.randint(0, 10, (16, ))
# Test cutmix and mixup in ImageClassifier
model_cfg = dict(
backbone=dict(
type='ResNet_CIFAR',
depth=50,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='MultiLabelLinearClsHead',
num_classes=10,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0,
use_soft=True)),
train_cfg=dict(augments=[
dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3),
dict(type='Identity', num_classes=10, prob=0.2)
]))
img_classifier = ImageClassifier(**model_cfg)
img_classifier.init_weights()
losses = img_classifier.forward_train(imgs, label)
assert losses['loss'].item() > 0
# Test cutmix with cutmix_minmax in ImageClassifier
model_cfg['train_cfg'] = dict(
augments=dict(
type='BatchCutMix',
alpha=1.,
num_classes=10,
prob=1.,
cutmix_minmax=[0.2, 0.8]))
img_classifier = ImageClassifier(**model_cfg)
img_classifier.init_weights()
losses = img_classifier.forward_train(imgs, label)
assert losses['loss'].item() > 0
# Test not using train_cfg
model_cfg = dict(
backbone=dict(
type='ResNet_CIFAR',
depth=50,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=10,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0)))
img_classifier = ImageClassifier(**model_cfg)
img_classifier.init_weights()
imgs = torch.randn(16, 3, 32, 32)
label = torch.randint(0, 10, (16, ))
losses = img_classifier.forward_train(imgs, label)
assert losses['loss'].item() > 0
# Test not using cutmix and mixup in ImageClassifier
model_cfg['train_cfg'] = dict(augments=None)
img_classifier = ImageClassifier(**model_cfg)
img_classifier.init_weights()
losses = img_classifier.forward_train(imgs, label)
assert losses['loss'].item() > 0
def test_image_classifier_return_tuple():
model_cfg = ConfigDict(
type='ImageClassifier',
backbone=dict(
type='ResNet_CIFAR',
depth=50,
num_stages=4,
out_indices=(3, ),
style='pytorch',
return_tuple=False),
head=dict(
type='LinearClsHead',
num_classes=10,
in_channels=2048,
loss=dict(type='CrossEntropyLoss')))
imgs = torch.randn(16, 3, 32, 32)
model_cfg_ = deepcopy(model_cfg)
with pytest.warns(DeprecationWarning):
model = CLASSIFIERS.build(model_cfg_)
# test backbone return tensor
feat = model.extract_feat(imgs)
assert isinstance(feat, torch.Tensor)
# test backbone return tuple
model_cfg_ = deepcopy(model_cfg)
model_cfg_.backbone.return_tuple = True
model = CLASSIFIERS.build(model_cfg_)
feat = model.extract_feat(imgs)
assert isinstance(feat, tuple)
# test warning if backbone return tensor
class ToyBackbone(BaseModule):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3)
def forward(self, x):
return self.conv(x)
model_cfg_ = deepcopy(model_cfg)
model_cfg_.backbone.return_tuple = True
model = CLASSIFIERS.build(model_cfg_)
model.backbone = ToyBackbone()
with pytest.warns(DeprecationWarning):
model.extract_feat(imgs)