# File listing metadata: 31 lines, 913 B, Python
import torch
|
|
|
|
from mmcls.models.classifiers import ImageClassifier
|
|
|
|
|
|
def test_image_classifier():
    """Run one mixup training forward pass through ``ImageClassifier``.

    The classifier is assembled from a ``ResNet_CIFAR`` backbone, a
    ``GlobalAveragePooling`` neck and a ``MultiLabelLinearClsHead`` whose
    ``CrossEntropyLoss`` is configured with ``use_soft=True``. Mixup is
    switched on through ``train_cfg``. A random batch is pushed through
    ``forward_train`` and the reported ``'loss'`` entry must be strictly
    positive.
    """
    num_classes = 10
    batch_size = 16

    # Sub-configs are spelled out separately so each component is easy to spot.
    backbone_cfg = {
        'type': 'ResNet_CIFAR',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (3, ),
        'style': 'pytorch',
    }
    neck_cfg = {'type': 'GlobalAveragePooling'}
    loss_cfg = {'type': 'CrossEntropyLoss', 'loss_weight': 1.0, 'use_soft': True}
    head_cfg = {
        'type': 'MultiLabelLinearClsHead',
        'num_classes': num_classes,
        'in_channels': 2048,
        'loss': loss_cfg,
    }
    # Test mixup in ImageClassifier
    mixup_cfg = {'mixup': {'alpha': 1.0, 'num_classes': num_classes}}

    classifier = ImageClassifier(
        backbone=backbone_cfg,
        neck=neck_cfg,
        head=head_cfg,
        train_cfg=mixup_cfg)
    classifier.init_weights()

    # Random inputs are drawn only after the model is built and initialised,
    # images first and labels second, to keep the RNG consumption order.
    batch_imgs = torch.randn(batch_size, 3, 32, 32)
    batch_labels = torch.randint(0, num_classes, (batch_size, ))

    outputs = classifier.forward_train(batch_imgs, batch_labels)
    loss_value = outputs['loss'].item()
    assert loss_value > 0
|