import torch

from mmcls.models.classifiers import ImageClassifier


def test_image_classifier():

    # Test mixup in ImageClassifier
    model_cfg = dict(
        backbone=dict(
            type='ResNet_CIFAR',
            depth=50,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='MultiLabelLinearClsHead',
            num_classes=10,
            in_channels=2048,
            # mixup produces soft (mixed one-hot) targets, so the
            # cross-entropy loss must run in soft-label mode
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0,
                      use_soft=True)),
        train_cfg=dict(mixup=dict(alpha=1.0, num_classes=10)))
    img_classifier = ImageClassifier(**model_cfg)
    img_classifier.init_weights()
    imgs = torch.randn(16, 3, 32, 32)
    label = torch.randint(0, 10, (16, ))

    losses = img_classifier.forward_train(imgs, label)
    assert losses['loss'].item() > 0


def test_image_classifier_with_cutmix():

    # Test cutmix in ImageClassifier
    model_cfg = dict(
        backbone=dict(
            type='ResNet_CIFAR',
            depth=50,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='MultiLabelLinearClsHead',
            num_classes=10,
            in_channels=2048,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0,
                      use_soft=True)),
        train_cfg=dict(
            cutmix=dict(alpha=1.0, num_classes=10, cutmix_prob=1.0)))
    img_classifier = ImageClassifier(**model_cfg)
    img_classifier.init_weights()
    imgs = torch.randn(16, 3, 32, 32)
    label = torch.randint(0, 10, (16, ))

    losses = img_classifier.forward_train(imgs, label)
    assert losses['loss'].item() > 0


def test_image_classifier_with_label_smooth_loss():

    # Test label smoothing (combined with mixup) in ImageClassifier
    model_cfg = dict(
        backbone=dict(
            type='ResNet_CIFAR',
            depth=50,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='MultiLabelLinearClsHead',
            num_classes=10,
            in_channels=2048,
            loss=dict(type='LabelSmoothLoss', label_smooth_val=0.1)),
        train_cfg=dict(mixup=dict(alpha=1.0, num_classes=10)))
    img_classifier = ImageClassifier(**model_cfg)
    img_classifier.init_weights()
    imgs = torch.randn(16, 3, 32, 32)
    label = torch.randint(0, 10, (16, ))

    losses = img_classifier.forward_train(imgs, label)
    assert losses['loss'].item() > 0


def test_image_classifier_vit():

    # Test mixup in ImageClassifier with a VisionTransformer backbone
    model_cfg = dict(
        backbone=dict(
            type='VisionTransformer',
            num_layers=12,
            embed_dim=768,
            num_heads=12,
            img_size=224,
            patch_size=16,
            in_channels=3,
            feedforward_channels=3072,
            drop_rate=0.1,
            attn_drop_rate=0.),
        neck=None,
        head=dict(
            type='VisionTransformerClsHead',
            num_classes=1000,
            in_channels=768,
            hidden_dim=3072,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0,
                      use_soft=True),
            topk=(1, 5),
        ),
        train_cfg=dict(mixup=dict(alpha=0.2, num_classes=1000)))
    img_classifier = ImageClassifier(**model_cfg)
    img_classifier.init_weights()
    imgs = torch.randn(16, 3, 224, 224)
    label = torch.randint(0, 1000, (16, ))

    losses = img_classifier.forward_train(imgs, label)
    assert losses['loss'].item() > 0
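

# The functions above are plain pytest-style tests. A minimal sketch for
# invoking them directly as a script (an assumption on my part; normally
# they would be collected and run via `pytest` instead):
if __name__ == '__main__':
    # Run each check once; any failing assert raises immediately.
    test_image_classifier()
    test_image_classifier_with_cutmix()
    test_image_classifier_with_label_smooth_loss()
    test_image_classifier_vit()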