mmclassification/tests/test_classifiers.py
whcao affb39fe07
[Feature]Add Vit (#214)
* add imagenet bs 4096

* add vit_base_patch16_224_finetune

* add vit_base_patch16_224_pretrain

* add vit_base_patch16_384_finetune

* add vit_base_patch16_384_finetune

* add vit_b_p16_224_finetune_imagenet

* add vit_b_p16_224_pretrain_imagenet

* add vit_b_p16_384_finetune_imagenet

* add vit

* add vit

* add vit head

* vit unittest

* keep up with ClsHead

* test vit

* add flag to determine whether to calculate acc during training

* Changes related to mmcv1.3.0

* change checkpoint saving interval to 10

* add label smooth

* default_runtime.py recovery

* docformatter

* docformatter

* delete 2 lines of comments

* delete configs/_base_/schedules/imagenet_bs4096.py

* add configs/_base_/schedules/imagenet_bs2048_AdamW.py

* rename imagenet_bs4096.py to imagenet_bs2048_AdamW.py

* add helpers.py

* test vit hybrid backbone

* fix HybridEmbed

* use to_2tuple instead
2021-04-16 19:22:41 +08:00

118 lines
3.5 KiB
Python

import torch
from mmcls.models.classifiers import ImageClassifier
def test_image_classifier():
    """Run ImageClassifier.forward_train with mixup enabled.

    A ResNet_CIFAR backbone (depth 50, last stage only) is pooled by
    GlobalAveragePooling and fed to a MultiLabelLinearClsHead whose
    CrossEntropyLoss is configured with ``use_soft=True``. ``train_cfg``
    turns on mixup (``alpha=1.0``, 10 classes). The test only checks that
    the reported loss is strictly positive.
    """
    # Test mixup in ImageClassifier
    backbone_cfg = dict(
        type='ResNet_CIFAR',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch')
    head_cfg = dict(
        type='MultiLabelLinearClsHead',
        num_classes=10,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True))
    classifier = ImageClassifier(
        backbone=backbone_cfg,
        neck=dict(type='GlobalAveragePooling'),
        head=head_cfg,
        train_cfg=dict(mixup=dict(alpha=1.0, num_classes=10)))
    classifier.init_weights()

    # 16 random 32x32 RGB images with integer labels in [0, 10).
    batch = torch.randn(16, 3, 32, 32)
    targets = torch.randint(0, 10, (16, ))
    loss_dict = classifier.forward_train(batch, targets)
    assert loss_dict['loss'].item() > 0
def test_image_classifier_with_cutmix():
    """Run ImageClassifier.forward_train with cutmix enabled.

    Uses the same ResNet_CIFAR / GlobalAveragePooling /
    MultiLabelLinearClsHead configuration (CrossEntropyLoss with
    ``use_soft=True``), but ``train_cfg`` requests cutmix with
    ``alpha=1.0`` and ``cutmix_prob=1.0`` over 10 classes. The test only
    checks that the reported loss is strictly positive.
    """
    # Test cutmix in ImageClassifier
    backbone_cfg = dict(
        type='ResNet_CIFAR',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch')
    head_cfg = dict(
        type='MultiLabelLinearClsHead',
        num_classes=10,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True))
    cutmix_cfg = dict(alpha=1.0, num_classes=10, cutmix_prob=1.0)
    classifier = ImageClassifier(
        backbone=backbone_cfg,
        neck=dict(type='GlobalAveragePooling'),
        head=head_cfg,
        train_cfg=dict(cutmix=cutmix_cfg))
    classifier.init_weights()

    # 16 random 32x32 RGB images with integer labels in [0, 10).
    batch = torch.randn(16, 3, 32, 32)
    targets = torch.randint(0, 10, (16, ))
    loss_dict = classifier.forward_train(batch, targets)
    assert loss_dict['loss'].item() > 0
def test_image_classifier_with_label_smooth_loss():
    """Run ImageClassifier.forward_train with a LabelSmoothLoss head.

    A ResNet_CIFAR backbone (depth 50, last stage only) is pooled by
    GlobalAveragePooling and fed to a MultiLabelLinearClsHead whose loss is
    ``LabelSmoothLoss`` with ``label_smooth_val=0.1``. Mixup
    (``alpha=1.0``, 10 classes) is left enabled in ``train_cfg``, so this
    also exercises the label-smoothing loss together with mixup. The test
    only checks that the reported loss is strictly positive.
    """
    # Test LabelSmoothLoss (with mixup enabled) in ImageClassifier
    model_cfg = dict(
        backbone=dict(
            type='ResNet_CIFAR',
            depth=50,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='MultiLabelLinearClsHead',
            num_classes=10,
            in_channels=2048,
            loss=dict(type='LabelSmoothLoss', label_smooth_val=0.1)),
        train_cfg=dict(mixup=dict(alpha=1.0, num_classes=10)))
    img_classifier = ImageClassifier(**model_cfg)
    img_classifier.init_weights()
    # 16 random 32x32 RGB images with integer labels in [0, 10).
    imgs = torch.randn(16, 3, 32, 32)
    label = torch.randint(0, 10, (16, ))
    losses = img_classifier.forward_train(imgs, label)
    assert losses['loss'].item() > 0
def test_image_classifier_vit():
    """Run ImageClassifier.forward_train with a VisionTransformer backbone.

    The backbone uses 12 layers, 768-dim embeddings, 12 attention heads and
    16x16 patches on 224x224 RGB inputs (feedforward 3072, dropout 0.1, no
    attention dropout). It feeds a VisionTransformerClsHead directly, with
    no neck. ``train_cfg`` enables mixup (``alpha=0.2``, 1000 classes). The
    test only checks that the reported loss is strictly positive.
    """
    # Test VisionTransformer backbone + head with mixup in ImageClassifier
    vit_backbone = dict(
        type='VisionTransformer',
        num_layers=12,
        embed_dim=768,
        num_heads=12,
        img_size=224,
        patch_size=16,
        in_channels=3,
        feedforward_channels=3072,
        drop_rate=0.1,
        attn_drop_rate=0.)
    vit_head = dict(
        type='VisionTransformerClsHead',
        num_classes=1000,
        in_channels=768,
        hidden_dim=3072,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True),
        topk=(1, 5),
    )
    classifier = ImageClassifier(
        backbone=vit_backbone,
        neck=None,
        head=vit_head,
        train_cfg=dict(mixup=dict(alpha=0.2, num_classes=1000)))
    classifier.init_weights()

    # 16 random 224x224 RGB images with integer labels in [0, 1000).
    batch = torch.randn(16, 3, 224, 224)
    targets = torch.randint(0, 1000, (16, ))
    loss_dict = classifier.forward_train(batch, targets)
    assert loss_dict['loss'].item() > 0