diff --git a/mmcls/models/heads/cls_head.py b/mmcls/models/heads/cls_head.py
index 056fea06..e875394c 100644
--- a/mmcls/models/heads/cls_head.py
+++ b/mmcls/models/heads/cls_head.py
@@ -15,13 +15,13 @@ class ClsHead(BaseHead):
         topk (int | tuple): Top-k accuracy.
         cal_acc (bool): Whether to calculate accuracy during training.
             If you use Mixup/CutMix or something like that during training,
-            it is not reasonable to calculate accuracy. Defaults to True.
-    """  # noqa: W605
+            it is not reasonable to calculate accuracy. Defaults to False.
+    """
 
     def __init__(self,
                  loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
                  topk=(1, ),
-                 cal_acc=True):
+                 cal_acc=False):
         super(ClsHead, self).__init__()
 
         assert isinstance(loss, dict)
diff --git a/mmcls/models/heads/linear_head.py b/mmcls/models/heads/linear_head.py
index 12c4671a..0fc64fe8 100644
--- a/mmcls/models/heads/linear_head.py
+++ b/mmcls/models/heads/linear_head.py
@@ -15,15 +15,10 @@ class LinearClsHead(ClsHead):
         num_classes (int): Number of categories excluding the background
             category.
         in_channels (int): Number of channels in the input feature map.
-        loss (dict): Config of classification loss.
-    """  # noqa: W605
+    """
 
-    def __init__(self,
-                 num_classes,
-                 in_channels,
-                 loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
-                 topk=(1, )):
-        super(LinearClsHead, self).__init__(loss=loss, topk=topk)
+    def __init__(self, num_classes, in_channels, *args, **kwargs):
+        super(LinearClsHead, self).__init__(*args, **kwargs)
         self.in_channels = in_channels
         self.num_classes = num_classes
diff --git a/mmcls/models/heads/vision_transformer_head.py b/mmcls/models/heads/vision_transformer_head.py
index 9d6ec6ce..b1d36ce3 100644
--- a/mmcls/models/heads/vision_transformer_head.py
+++ b/mmcls/models/heads/vision_transformer_head.py
@@ -21,22 +21,16 @@ class VisionTransformerClsHead(ClsHead):
             available during pre-training. Default None.
         act_cfg (dict): The activation config. Only available during
             pre-training. Defalut Tanh.
-        loss (dict): Config of classification loss.
-        topk (int | tuple): Top-k accuracy.
-        cal_acc (bool): Whether to calculate accuracy during training.
-            If mixup is used, this should be False. Default False.
-    """  # noqa: W605
+    """
 
     def __init__(self,
                  num_classes,
                  in_channels,
                  hidden_dim=None,
                  act_cfg=dict(type='Tanh'),
-                 loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
-                 topk=(1, ),
-                 cal_acc=False):
-        super(VisionTransformerClsHead, self).__init__(
-            loss=loss, topk=topk, cal_acc=cal_acc)
+                 *args,
+                 **kwargs):
+        super(VisionTransformerClsHead, self).__init__(*args, **kwargs)
         self.in_channels = in_channels
         self.num_classes = num_classes
         self.hidden_dim = hidden_dim
diff --git a/tests/test_classifiers.py b/tests/test_classifiers.py
index c143c368..56979af2 100644
--- a/tests/test_classifiers.py
+++ b/tests/test_classifiers.py
@@ -110,8 +110,8 @@ def test_image_classifier_vit():
         train_cfg=dict(mixup=dict(alpha=0.2, num_classes=1000)))
     img_classifier = ImageClassifier(**model_cfg)
     img_classifier.init_weights()
-    imgs = torch.randn(16, 3, 224, 224)
-    label = torch.randint(0, 1000, (16, ))
+    imgs = torch.randn(3, 3, 224, 224)
+    label = torch.randint(0, 1000, (3, ))
     losses = img_classifier.forward_train(imgs, label)
     assert losses['loss'].item() > 0
diff --git a/tests/test_heads.py b/tests/test_heads.py
index b40d8157..0e7dcad1 100644
--- a/tests/test_heads.py
+++ b/tests/test_heads.py
@@ -1,12 +1,12 @@
 import torch
 
-from mmcls.models.heads import (ClsHead, MultiLabelClsHead,
+from mmcls.models.heads import (ClsHead, LinearClsHead, MultiLabelClsHead,
                                 MultiLabelLinearClsHead)
 
 
 def test_cls_head():
-    # test ClsHead with cal_acc=True
+    # test ClsHead with cal_acc=False
     head = ClsHead()
     fake_cls_score = torch.rand(4, 3)
     fake_gt_label = torch.randint(0, 2, (4, ))
 
     losses = head.loss(fake_cls_score, fake_gt_label)
     assert losses['loss'].item() > 0
@@ -14,14 +14,22 @@ def test_cls_head():
-    # test ClsHead with cal_acc=False
-    head = ClsHead(cal_acc=False)
+    # test ClsHead with cal_acc=True
+    head = ClsHead(cal_acc=True)
     fake_cls_score = torch.rand(4, 3)
     fake_gt_label = torch.randint(0, 2, (4, ))
 
     losses = head.loss(fake_cls_score, fake_gt_label)
     assert losses['loss'].item() > 0
 
+    # test LinearClsHead
+    head = LinearClsHead(10, 100)
+    fake_cls_score = torch.rand(4, 10)
+    fake_gt_label = torch.randint(0, 10, (4, ))
+
+    losses = head.loss(fake_cls_score, fake_gt_label)
+    assert losses['loss'].item() > 0
+
 
 def test_multilabel_head():
     head = MultiLabelClsHead()