mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
* add mytrain.py for test * test before layers * test attr in layers * test classifier * delete mytrain.py * set cal_acc in ClsHead defaults to False * set cal_acc defaults to False * use *args, **kwargs instead * change bs16 to 3 in test_image_classifier_vit * fix some comments * change cal_acc=True * test LinearClsHead
51 lines
1.4 KiB
Python
51 lines
1.4 KiB
Python
import torch
|
|
|
|
from mmcls.models.heads import (ClsHead, LinearClsHead, MultiLabelClsHead,
|
|
MultiLabelLinearClsHead)
|
|
|
|
|
|
def test_cls_head():
    """Smoke-test the ``loss`` method of ``ClsHead`` and ``LinearClsHead``.

    Three heads are built one after another: a default ``ClsHead``, a
    ``ClsHead`` constructed with ``cal_acc=True``, and a
    ``LinearClsHead(10, 100)``. Each one is handed freshly sampled random
    scores and random integer labels, and the ``'loss'`` entry of the
    returned mapping must be strictly positive.
    """
    num_samples = 4

    # Default construction (the cal_acc=False case of this test).
    plain_head = ClsHead()
    scores = torch.rand(num_samples, 3)
    labels = torch.randint(0, 2, (num_samples, ))
    result = plain_head.loss(scores, labels)
    loss_value = result['loss'].item()
    assert loss_value > 0

    # Same check with accuracy calculation switched on.
    acc_head = ClsHead(cal_acc=True)
    scores = torch.rand(num_samples, 3)
    labels = torch.randint(0, 2, (num_samples, ))
    result = acc_head.loss(scores, labels)
    loss_value = result['loss'].item()
    assert loss_value > 0

    # LinearClsHead: scores are passed straight to ``loss``, so their
    # second dimension (10) matches the upper bound of the sampled labels.
    linear_head = LinearClsHead(10, 100)
    scores = torch.rand(num_samples, 10)
    labels = torch.randint(0, 10, (num_samples, ))
    result = linear_head.loss(scores, labels)
    loss_value = result['loss'].item()
    assert loss_value > 0
|
|
|
|
|
|
def test_multilabel_head():
    """Smoke-test ``MultiLabelClsHead.loss`` on random multi-label data.

    Scores are a random ``(4, 3)`` float tensor and targets are a random
    ``(4, 3)`` tensor of 0/1 integers; the ``'loss'`` entry of the result
    must be strictly positive.
    """
    target_shape = (4, 3)

    ml_head = MultiLabelClsHead()
    scores = torch.rand(*target_shape)
    # One 0/1 indicator per (sample, class) position.
    targets = torch.randint(0, 2, target_shape)

    result = ml_head.loss(scores, targets)
    loss_value = result['loss'].item()
    assert loss_value > 0
|
|
|
|
|
|
def test_multilabel_linear_head():
    """Smoke-test ``MultiLabelLinearClsHead`` weight init and ``loss``.

    A ``MultiLabelLinearClsHead(3, 5)`` is built, random ``(4, 3)`` scores
    and 0/1 targets are sampled, ``init_weights`` is invoked, and the
    ``'loss'`` entry returned by ``loss`` must be strictly positive.
    """
    target_shape = (4, 3)

    ml_linear_head = MultiLabelLinearClsHead(3, 5)
    scores = torch.rand(*target_shape)
    targets = torch.randint(0, 2, target_shape)

    # Weight initialisation runs after the tensors are sampled and before
    # the loss is computed, mirroring the original statement order.
    ml_linear_head.init_weights()
    result = ml_linear_head.loss(scores, targets)
    loss_value = result['loss'].item()
    assert loss_value > 0
|