[Feature] Add cal_acc to cls_head.py (#206)

* add cal_acc to cls_head.py

* test ClsHead with cal_acc

* 4 spaces indentation
whcao 2021-04-13 13:52:14 +08:00 committed by GitHub
parent 5195932952
commit dcf61173f6
2 changed files with 35 additions and 6 deletions

mmcls/models/heads/cls_head.py

@@ -13,11 +13,15 @@ class ClsHead(BaseHead):
     Args:
         loss (dict): Config of classification loss.
         topk (int | tuple): Top-k accuracy.
+        cal_acc (bool): Whether to calculate accuracy during training.
+            If you use Mixup/CutMix or something like that during training,
+            it is not reasonable to calculate accuracy. Defaults to True.
     """  # noqa: W605

     def __init__(self,
                  loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
-                 topk=(1, )):
+                 topk=(1, ),
+                 cal_acc=True):
         super(ClsHead, self).__init__()

         assert isinstance(loss, dict)
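The docstring's caveat exists because Mixup/CutMix replace the hard ground-truth label with a soft mixture of two labels, so comparing an argmax prediction against a single class index is no longer meaningful. A minimal illustration (standard Mixup arithmetic, not code from this PR):

import torch

# Mixup blends two inputs and their one-hot targets with a weight lam.
lam = 0.7
y1 = torch.tensor([1., 0., 0.])            # sample 1: class 0
y2 = torch.tensor([0., 1., 0.])            # sample 2: class 1
mixed_target = lam * y1 + (1 - lam) * y2   # tensor([0.7, 0.3, 0.0])

# There is no single correct class to score top-k accuracy against,
# which is why cal_acc=False is the sensible setting in that regime.
print(mixed_target)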
@@ -30,17 +34,22 @@ class ClsHead(BaseHead):

         self.compute_loss = build_loss(loss)
         self.compute_accuracy = Accuracy(topk=self.topk)
+        self.cal_acc = cal_acc

     def loss(self, cls_score, gt_label):
         num_samples = len(cls_score)
         losses = dict()
         # compute loss
         loss = self.compute_loss(cls_score, gt_label, avg_factor=num_samples)
-        # compute accuracy
-        acc = self.compute_accuracy(cls_score, gt_label)
-        assert len(acc) == len(self.topk)
+        if self.cal_acc:
+            # compute accuracy
+            acc = self.compute_accuracy(cls_score, gt_label)
+            assert len(acc) == len(self.topk)
+            losses['accuracy'] = {
+                f'top-{k}': a
+                for k, a in zip(self.topk, acc)
+            }
         losses['loss'] = loss
-        losses['accuracy'] = {f'top-{k}': a for k, a in zip(self.topk, acc)}
         return losses

     def forward_train(self, cls_score, gt_label):
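For reference, a minimal usage sketch of the new flag (it mirrors the test added below and assumes mmcls is installed and importable):

import torch
from mmcls.models.heads import ClsHead

head = ClsHead(topk=(1, ), cal_acc=False)

cls_score = torch.rand(4, 3)           # fake logits: 4 samples, 3 classes
gt_label = torch.randint(0, 3, (4, ))  # fake hard labels

losses = head.loss(cls_score, gt_label)
# With cal_acc=False only 'loss' is returned; with the default
# cal_acc=True the dict also carries 'accuracy' with top-k entries.
print(losses.keys())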

tests/test_heads.py

@@ -1,6 +1,26 @@
 import torch

-from mmcls.models.heads import MultiLabelClsHead, MultiLabelLinearClsHead
+from mmcls.models.heads import (ClsHead, MultiLabelClsHead,
+                                MultiLabelLinearClsHead)
+
+
+def test_cls_head():
+
+    # test ClsHead with cal_acc=True
+    head = ClsHead()
+    fake_cls_score = torch.rand(4, 3)
+    fake_gt_label = torch.randint(0, 2, (4, ))
+
+    losses = head.loss(fake_cls_score, fake_gt_label)
+    assert losses['loss'].item() > 0
+
+    # test ClsHead with cal_acc=False
+    head = ClsHead(cal_acc=False)
+    fake_cls_score = torch.rand(4, 3)
+    fake_gt_label = torch.randint(0, 2, (4, ))
+
+    losses = head.loss(fake_cls_score, fake_gt_label)
+    assert losses['loss'].item() > 0


 def test_multilabel_head():
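Downstream, the flag would typically be set from a model config, since heads are built from config dicts. A hypothetical fragment (that LinearClsHead forwards extra kwargs to ClsHead is an assumption here, not something this diff touches):

# Hypothetical config fragment, not part of this commit.
model = dict(
    head=dict(
        type='LinearClsHead',  # assumes subclasses forward kwargs to ClsHead
        num_classes=10,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
        cal_acc=False,  # disable accuracy when training with Mixup/CutMix
    ))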