diff --git a/mmcls/models/heads/cls_head.py b/mmcls/models/heads/cls_head.py
index 77d0ba26..056fea06 100644
--- a/mmcls/models/heads/cls_head.py
+++ b/mmcls/models/heads/cls_head.py
@@ -13,11 +13,15 @@ class ClsHead(BaseHead):
     Args:
         loss (dict): Config of classification loss.
         topk (int | tuple): Top-k accuracy.
+        cal_acc (bool): Whether to calculate accuracy during training.
+            Accuracy is not meaningful when training with mixed targets
+            (e.g. Mixup or CutMix), so disable it there. Defaults to True.
     """  # noqa: W605
 
     def __init__(self,
                  loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
-                 topk=(1, )):
+                 topk=(1, ),
+                 cal_acc=True):
         super(ClsHead, self).__init__()
 
         assert isinstance(loss, dict)
@@ -30,17 +34,22 @@ class ClsHead(BaseHead):
 
         self.compute_loss = build_loss(loss)
         self.compute_accuracy = Accuracy(topk=self.topk)
+        self.cal_acc = cal_acc
 
     def loss(self, cls_score, gt_label):
         num_samples = len(cls_score)
         losses = dict()
         # compute loss
         loss = self.compute_loss(cls_score, gt_label, avg_factor=num_samples)
-        # compute accuracy
-        acc = self.compute_accuracy(cls_score, gt_label)
-        assert len(acc) == len(self.topk)
+        if self.cal_acc:
+            # compute accuracy
+            acc = self.compute_accuracy(cls_score, gt_label)
+            assert len(acc) == len(self.topk)
+            losses['accuracy'] = {
+                f'top-{k}': a
+                for k, a in zip(self.topk, acc)
+            }
         losses['loss'] = loss
-        losses['accuracy'] = {f'top-{k}': a for k, a in zip(self.topk, acc)}
         return losses
 
     def forward_train(self, cls_score, gt_label):
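
For context, a minimal sketch of what the new flag changes in the returned dict, using only the API touched by this diff (the tensors are random placeholders):

```python
import torch
from mmcls.models.heads import ClsHead

cls_score = torch.rand(4, 3)            # fake logits: 4 samples, 3 classes
gt_label = torch.randint(0, 3, (4, ))   # fake hard labels

# Default (cal_acc=True): accuracy is computed alongside the loss.
head = ClsHead(topk=(1, ))
losses = head.loss(cls_score, gt_label)
# -> {'loss': tensor(...), 'accuracy': {'top-1': tensor(...)}}

# cal_acc=False: skip accuracy, e.g. when targets are mixed by Mixup/CutMix.
head = ClsHead(topk=(1, ), cal_acc=False)
losses = head.loss(cls_score, gt_label)
# -> {'loss': tensor(...)} with no 'accuracy' entry
```
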
diff --git a/tests/test_heads.py b/tests/test_heads.py
index f3a70526..b40d8157 100644
--- a/tests/test_heads.py
+++ b/tests/test_heads.py
@@ -1,6 +1,28 @@
 import torch
 
-from mmcls.models.heads import MultiLabelClsHead, MultiLabelLinearClsHead
+from mmcls.models.heads import (ClsHead, MultiLabelClsHead,
+                                MultiLabelLinearClsHead)
+
+
+def test_cls_head():
+
+    # test ClsHead with cal_acc=True
+    head = ClsHead()
+    fake_cls_score = torch.rand(4, 3)
+    fake_gt_label = torch.randint(0, 2, (4, ))
+
+    losses = head.loss(fake_cls_score, fake_gt_label)
+    assert losses['loss'].item() > 0
+    assert 'accuracy' in losses
+
+    # test ClsHead with cal_acc=False
+    head = ClsHead(cal_acc=False)
+    fake_cls_score = torch.rand(4, 3)
+    fake_gt_label = torch.randint(0, 2, (4, ))
+
+    losses = head.loss(fake_cls_score, fake_gt_label)
+    assert losses['loss'].item() > 0
+    assert 'accuracy' not in losses
 
 
 def test_multilabel_head():
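
For reference on why accuracy is disabled in such runs: Mixup blends pairs of inputs and labels, so there is no single correct class to compare a top-1 prediction against. A small self-contained illustration (independent of mmcls; `lam` and the tensors are made up):

```python
import torch

# Mixup: blend two samples and their one-hot labels with weight lam.
lam = 0.4
x1, y1 = torch.rand(3, 32, 32), torch.tensor([1., 0., 0.])
x2, y2 = torch.rand(3, 32, 32), torch.tensor([0., 0., 1.])

x_mixed = lam * x1 + (1 - lam) * x2
y_mixed = lam * y1 + (1 - lam) * y2   # [0.4, 0.0, 0.6]: no single "correct" class

# Top-1 accuracy against either hard label alone would be arbitrary here,
# which is why the head can skip it via cal_acc=False during such training.
```
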