import torch

from mmcls.models.classifiers import ImageClassifier


def test_image_classifier():
    """Smoke-test ``ImageClassifier.forward_train`` with mixup configured.

    Builds a classifier from a ``ResNet_CIFAR`` backbone, a
    ``GlobalAveragePooling`` neck and a ``MultiLabelLinearClsHead`` whose
    ``CrossEntropyLoss`` is created with ``use_soft=True``, and passes a
    ``mixup`` entry in ``train_cfg``. A random batch is then pushed through
    ``forward_train`` and the reported ``'loss'`` must be strictly positive.
    """
    num_classes = 10
    batch_size = 16

    # Test mixup in ImageClassifier
    backbone_cfg = {
        'type': 'ResNet_CIFAR',
        'depth': 50,
        'num_stages': 4,
        'out_indices': (3, ),
        'style': 'pytorch',
    }
    neck_cfg = {'type': 'GlobalAveragePooling'}
    # NOTE(review): ``use_soft=True`` is presumably required because mixup
    # produces non-one-hot targets -- confirm against the loss implementation.
    head_cfg = {
        'type': 'MultiLabelLinearClsHead',
        'num_classes': num_classes,
        'in_channels': 2048,
        'loss': {
            'type': 'CrossEntropyLoss',
            'loss_weight': 1.0,
            'use_soft': True,
        },
    }
    train_cfg = {'mixup': {'alpha': 1.0, 'num_classes': num_classes}}

    classifier = ImageClassifier(
        backbone=backbone_cfg,
        neck=neck_cfg,
        head=head_cfg,
        train_cfg=train_cfg)
    classifier.init_weights()

    # Random inputs: the images are drawn before the labels, so the sequence
    # of draws from torch's global RNG matches the original test exactly.
    images = torch.randn(batch_size, 3, 32, 32)
    targets = torch.randint(0, num_classes, (batch_size, ))

    outputs = classifier.forward_train(images, targets)
    assert outputs['loss'].item() > 0