From dcf61173f623297ceab83f17a538ad3712990442 Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Tue, 13 Apr 2021 13:52:14 +0800 Subject: [PATCH] [Feature]Add cal_acc to cls_head.py (#206) * add cal_acc to cls_head.py * test ClsHead with cal_acc * 4 spaces indentation --- mmcls/models/heads/cls_head.py | 19 ++++++++++++++----- tests/test_heads.py | 22 +++++++++++++++++++++- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/mmcls/models/heads/cls_head.py b/mmcls/models/heads/cls_head.py index 77d0ba26..056fea06 100644 --- a/mmcls/models/heads/cls_head.py +++ b/mmcls/models/heads/cls_head.py @@ -13,11 +13,15 @@ class ClsHead(BaseHead): Args: loss (dict): Config of classification loss. topk (int | tuple): Top-k accuracy. + cal_acc (bool): Whether to calculate accuracy during training. + If you use Mixup/CutMix or something like that during training, + it is not reasonable to calculate accuracy. Defaults to True. """ # noqa: W605 def __init__(self, loss=dict(type='CrossEntropyLoss', loss_weight=1.0), - topk=(1, )): + topk=(1, ), + cal_acc=True): super(ClsHead, self).__init__() assert isinstance(loss, dict) @@ -30,17 +34,22 @@ class ClsHead(BaseHead): self.compute_loss = build_loss(loss) self.compute_accuracy = Accuracy(topk=self.topk) + self.cal_acc = cal_acc def loss(self, cls_score, gt_label): num_samples = len(cls_score) losses = dict() # compute loss loss = self.compute_loss(cls_score, gt_label, avg_factor=num_samples) - # compute accuracy - acc = self.compute_accuracy(cls_score, gt_label) - assert len(acc) == len(self.topk) + if self.cal_acc: + # compute accuracy + acc = self.compute_accuracy(cls_score, gt_label) + assert len(acc) == len(self.topk) + losses['accuracy'] = { + f'top-{k}': a + for k, a in zip(self.topk, acc) + } losses['loss'] = loss - losses['accuracy'] = {f'top-{k}': a for k, a in zip(self.topk, acc)} return losses def forward_train(self, cls_score, gt_label): diff --git a/tests/test_heads.py b/tests/test_heads.py index f3a70526..b40d8157 100644 --- a/tests/test_heads.py +++ b/tests/test_heads.py @@ -1,6 +1,26 @@ import torch -from mmcls.models.heads import MultiLabelClsHead, MultiLabelLinearClsHead +from mmcls.models.heads import (ClsHead, MultiLabelClsHead, + MultiLabelLinearClsHead) + + +def test_cls_head(): + + # test ClsHead with cal_acc=True + head = ClsHead() + fake_cls_score = torch.rand(4, 3) + fake_gt_label = torch.randint(0, 2, (4, )) + + losses = head.loss(fake_cls_score, fake_gt_label) + assert losses['loss'].item() > 0 + + # test ClsHead with cal_acc=False + head = ClsHead(cal_acc=False) + fake_cls_score = torch.rand(4, 3) + fake_gt_label = torch.randint(0, 2, (4, )) + + losses = head.loss(fake_cls_score, fake_gt_label) + assert losses['loss'].item() > 0 def test_multilabel_head():