mmclassification/tests/test_classifiers.py
import torch

from mmcls.models.classifiers import ImageClassifier


def test_image_classifier():

    # Test mixup in ImageClassifier
    model_cfg = dict(
        backbone=dict(
            type='ResNet_CIFAR',
            depth=50,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='MultiLabelLinearClsHead',
            num_classes=10,
            in_channels=2048,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0,
                      use_soft=True)),
        train_cfg=dict(
            augments=dict(
                type='BatchMixup', alpha=1., num_classes=10, prob=1.)))
    img_classifier = ImageClassifier(**model_cfg)
    img_classifier.init_weights()
    imgs = torch.randn(16, 3, 32, 32)
    label = torch.randint(0, 10, (16, ))

    losses = img_classifier.forward_train(imgs, label)
    assert losses['loss'].item() > 0

    # Considering BC-breaking
    model_cfg['train_cfg'] = dict(mixup=dict(alpha=1.0, num_classes=10))
    img_classifier = ImageClassifier(**model_cfg)
    img_classifier.init_weights()
    imgs = torch.randn(16, 3, 32, 32)
    label = torch.randint(0, 10, (16, ))

    losses = img_classifier.forward_train(imgs, label)
    assert losses['loss'].item() > 0


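# The test above only checks that forward_train returns a positive loss. As a
# rough illustration of what a BatchMixup-style augment does to a batch, here
# is a minimal sketch assuming the standard mixup formulation
# (lam ~ Beta(alpha, alpha), images and one-hot labels blended pairwise).
# `_mixup_sketch` is a hypothetical helper, not part of mmcls.
def _mixup_sketch(imgs, one_hot_label, alpha=1.0):
    lam = float(torch.distributions.Beta(alpha, alpha).sample())
    index = torch.randperm(imgs.size(0))
    mixed_imgs = lam * imgs + (1 - lam) * imgs[index]
    mixed_label = lam * one_hot_label + (1 - lam) * one_hot_label[index]
    return mixed_imgs, mixed_label

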
def test_image_classifier_with_cutmix():

    # Test cutmix in ImageClassifier
    model_cfg = dict(
        backbone=dict(
            type='ResNet_CIFAR',
            depth=50,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='MultiLabelLinearClsHead',
            num_classes=10,
            in_channels=2048,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0,
                      use_soft=True)),
        train_cfg=dict(
            augments=dict(
                type='BatchCutMix', alpha=1., num_classes=10, prob=1.)))
    img_classifier = ImageClassifier(**model_cfg)
    img_classifier.init_weights()
    imgs = torch.randn(16, 3, 32, 32)
    label = torch.randint(0, 10, (16, ))

    losses = img_classifier.forward_train(imgs, label)
    assert losses['loss'].item() > 0

    # Considering BC-breaking
    model_cfg['train_cfg'] = dict(
        cutmix=dict(alpha=1.0, num_classes=10, cutmix_prob=1.0))
    img_classifier = ImageClassifier(**model_cfg)
    img_classifier.init_weights()
    imgs = torch.randn(16, 3, 32, 32)
    label = torch.randint(0, 10, (16, ))

    losses = img_classifier.forward_train(imgs, label)
    assert losses['loss'].item() > 0


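# Rough illustration of the box/lam bookkeeping behind a BatchCutMix-style
# augment, assuming the common CutMix recipe: cut a random box whose area
# fraction is (1 - lam), paste pixels from a shuffled batch into it, and
# re-derive lam from the clipped box so the label mix matches the pixel mix.
# `_cutmix_bbox_sketch` is a hypothetical helper, not part of mmcls.
def _cutmix_bbox_sketch(img_shape, lam):
    _, _, h, w = img_shape
    cut_ratio = (1. - lam)**0.5
    cut_h, cut_w = int(h * cut_ratio), int(w * cut_ratio)
    cy = int(torch.randint(0, h, (1, )))
    cx = int(torch.randint(0, w, (1, )))
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, h)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, w)
    # Correct lam for any clipping at the image border.
    lam = 1. - (y2 - y1) * (x2 - x1) / (h * w)
    return (y1, y2, x1, x2), lam

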
def test_image_classifier_with_augments():

    imgs = torch.randn(16, 3, 32, 32)
    label = torch.randint(0, 10, (16, ))

    # Test cutmix and mixup in ImageClassifier
    model_cfg = dict(
        backbone=dict(
            type='ResNet_CIFAR',
            depth=50,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='MultiLabelLinearClsHead',
            num_classes=10,
            in_channels=2048,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0,
                      use_soft=True)),
        train_cfg=dict(augments=[
            dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
            dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3),
            dict(type='Identity', num_classes=10, prob=0.2)
        ]))
    img_classifier = ImageClassifier(**model_cfg)
    img_classifier.init_weights()

    losses = img_classifier.forward_train(imgs, label)
    assert losses['loss'].item() > 0

    # Test cutmix with cutmix_minmax in ImageClassifier
    model_cfg['train_cfg'] = dict(
        augments=dict(
            type='BatchCutMix',
            alpha=1.,
            num_classes=10,
            prob=1.,
            cutmix_minmax=[0.2, 0.8]))
    img_classifier = ImageClassifier(**model_cfg)
    img_classifier.init_weights()

    losses = img_classifier.forward_train(imgs, label)
    assert losses['loss'].item() > 0

    # Test not using cutmix and mixup in ImageClassifier
    model_cfg = dict(
        backbone=dict(
            type='ResNet_CIFAR',
            depth=50,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='LinearClsHead',
            num_classes=10,
            in_channels=2048,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0)))
    img_classifier = ImageClassifier(**model_cfg)
    img_classifier.init_weights()
    imgs = torch.randn(16, 3, 32, 32)
    label = torch.randint(0, 10, (16, ))

    losses = img_classifier.forward_train(imgs, label)
    assert losses['loss'].item() > 0

    # Test not using cutmix and mixup in ImageClassifier (explicit augments=None)
    model_cfg['train_cfg'] = dict(augments=None)
    img_classifier = ImageClassifier(**model_cfg)
    img_classifier.init_weights()

    losses = img_classifier.forward_train(imgs, label)
    assert losses['loss'].item() > 0


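# Rough illustration of how a list of augments with probabilities is meant to
# behave: the `prob` values above sum to 1 (0.5 + 0.3 + 0.2), and one augment
# is picked per batch according to those weights, with `Identity` covering the
# "leave the batch unchanged" case. `_pick_augment_sketch` is a hypothetical
# helper, not the mmcls implementation.
def _pick_augment_sketch(probs=(0.5, 0.3, 0.2)):
    index = int(torch.multinomial(torch.tensor(probs), 1))
    return index  # 0 -> BatchCutMix, 1 -> BatchMixup, 2 -> Identity

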
def test_image_classifier_with_label_smooth_loss():

    # Test mixup in ImageClassifier
    model_cfg = dict(
        backbone=dict(
            type='ResNet_CIFAR',
            depth=50,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='MultiLabelLinearClsHead',
            num_classes=10,
            in_channels=2048,
            loss=dict(type='LabelSmoothLoss', label_smooth_val=0.1)),
        train_cfg=dict(
            augments=dict(
                type='BatchMixup', alpha=1., num_classes=10, prob=1.)))
    img_classifier = ImageClassifier(**model_cfg)
    img_classifier.init_weights()
    imgs = torch.randn(16, 3, 32, 32)
    label = torch.randint(0, 10, (16, ))

    losses = img_classifier.forward_train(imgs, label)
    assert losses['loss'].item() > 0


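# Rough illustration of what label smoothing with label_smooth_val=0.1 does to
# a hard label, assuming the standard formulation: the one-hot target is
# softened towards the uniform distribution, so the true class becomes
# (1 - 0.1) + 0.1 / 10 = 0.91. `_smooth_label_sketch` is a hypothetical helper,
# not the mmcls LabelSmoothLoss implementation.
def _smooth_label_sketch(label, num_classes=10, eps=0.1):
    one_hot = torch.nn.functional.one_hot(label, num_classes).float()
    return one_hot * (1 - eps) + eps / num_classes

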
def test_image_classifier_vit():

    model_cfg = dict(
        backbone=dict(
            type='VisionTransformer',
            num_layers=12,
            embed_dim=768,
            num_heads=12,
            img_size=224,
            patch_size=16,
            in_channels=3,
            feedforward_channels=3072,
            drop_rate=0.1,
            attn_drop_rate=0.),
        neck=None,
        head=dict(
            type='VisionTransformerClsHead',
            num_classes=1000,
            in_channels=768,
            hidden_dim=3072,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0,
                      use_soft=True),
            topk=(1, 5),
        ),
        train_cfg=dict(
            augments=dict(
                type='BatchMixup', alpha=0.2, num_classes=1000, prob=1.)))
    img_classifier = ImageClassifier(**model_cfg)
    img_classifier.init_weights()
    imgs = torch.randn(2, 3, 224, 224)
    label = torch.randint(0, 1000, (2, ))

    losses = img_classifier.forward_train(imgs, label)
    assert losses['loss'].item() > 0