mmclassification/tests/test_heads.py
whcao b30f79ea4c
[Feature] Modify parameter passing in models.heads (#239)
* add mytrain.py for test

* test before layers

* test attr in layers

* test classifier

* delete mytrain.py

* set cal_acc in ClsHead to default to False

* set cal_acc to default to False

* use *args, **kwargs instead

* change the batch size from 16 to 3 in test_image_classifier_vit

* fix some comments

* change cal_acc=True

* test LinearClsHead
2021-05-10 14:56:55 +08:00

51 lines
1.4 KiB
Python

import torch
from mmcls.models.heads import (ClsHead, LinearClsHead, MultiLabelClsHead,
MultiLabelLinearClsHead)
def test_cls_head():
    """Check ClsHead (with and without accuracy) and LinearClsHead losses.

    ``ClsHead`` is exercised with ``cal_acc`` left at its default (False)
    and with ``cal_acc=True``. ``LinearClsHead`` is exercised through its
    ``loss`` method on pre-computed class scores.
    """
    # test ClsHead with cal_acc=False (the default)
    head = ClsHead()
    fake_cls_score = torch.rand(4, 3)
    # Draw labels from the full range [0, 3) so every class index can occur.
    fake_gt_label = torch.randint(0, 3, (4, ))
    losses = head.loss(fake_cls_score, fake_gt_label)
    assert losses['loss'].item() > 0
    # NOTE(review): assumes ClsHead.loss only adds an 'accuracy' entry when
    # cal_acc=True — confirm against mmcls/models/heads/cls_head.py.
    assert 'accuracy' not in losses

    # test ClsHead with cal_acc=True
    head = ClsHead(cal_acc=True)
    fake_cls_score = torch.rand(4, 3)
    fake_gt_label = torch.randint(0, 3, (4, ))
    losses = head.loss(fake_cls_score, fake_gt_label)
    assert losses['loss'].item() > 0
    # Without this check the cal_acc=True case is indistinguishable from the
    # default case above, so the accuracy branch would go untested.
    assert 'accuracy' in losses

    # test LinearClsHead; keyword arguments make the 10/100 roles explicit
    head = LinearClsHead(num_classes=10, in_channels=100)
    fake_cls_score = torch.rand(4, 10)
    fake_gt_label = torch.randint(0, 10, (4, ))
    losses = head.loss(fake_cls_score, fake_gt_label)
    assert losses['loss'].item() > 0
def test_multilabel_head():
    """MultiLabelClsHead must report a positive loss on random 0/1 targets."""
    batch_size, num_classes = 4, 3
    multilabel_head = MultiLabelClsHead()
    scores = torch.rand(batch_size, num_classes)
    # Multi-hot label matrix: every entry is independently 0 or 1.
    targets = torch.randint(0, 2, (batch_size, num_classes))
    loss_dict = multilabel_head.loss(scores, targets)
    assert loss_dict['loss'].item() > 0
def test_multilabel_linear_head():
    """MultiLabelLinearClsHead must report a positive loss after init."""
    batch_size, num_classes, in_channels = 4, 3, 5
    linear_head = MultiLabelLinearClsHead(
        num_classes=num_classes, in_channels=in_channels)
    scores = torch.rand(batch_size, num_classes)
    # Multi-hot label matrix: every entry is independently 0 or 1.
    targets = torch.randint(0, 2, (batch_size, num_classes))
    # Weight init is called after the random inputs are drawn, keeping the
    # same random-number consumption order as before.
    linear_head.init_weights()
    loss_dict = linear_head.loss(scores, targets)
    assert loss_dict['loss'].item() > 0