From b30f79ea4ceb855d1a5540673d1951a3b981bf57 Mon Sep 17 00:00:00 2001
From: whcao <41630003+HIT-cwh@users.noreply.github.com>
Date: Mon, 10 May 2021 14:56:55 +0800
Subject: [PATCH] [Feature]Modify Parameters Passing in models.heads (#239)

* add mytrain.py for test

* test before layers

* test attr in layers

* test classifier

* delete mytrain.py

* set cal_acc in ClsHead defaults to False

* set cal_acc defaults to False

* use *args, **kwargs instead

* change bs16 to 3 in test_image_classifier_vit

* fix some comments

* change cal_acc=True

* test LinearClsHead
---
 mmcls/models/heads/cls_head.py                |  6 +++---
 mmcls/models/heads/linear_head.py             | 11 +++--------
 mmcls/models/heads/vision_transformer_head.py | 14 ++++----------
 tests/test_classifiers.py                     |  4 ++--
 tests/test_heads.py                           | 16 ++++++++++++----
 5 files changed, 24 insertions(+), 27 deletions(-)

diff --git a/mmcls/models/heads/cls_head.py b/mmcls/models/heads/cls_head.py
index 056fea063..e875394cc 100644
--- a/mmcls/models/heads/cls_head.py
+++ b/mmcls/models/heads/cls_head.py
@@ -15,13 +15,13 @@ class ClsHead(BaseHead):
         topk (int | tuple): Top-k accuracy.
         cal_acc (bool): Whether to calculate accuracy during training.
             If you use Mixup/CutMix or something like that during training,
-            it is not reasonable to calculate accuracy. Defaults to True.
-    """  # noqa: W605
+            it is not reasonable to calculate accuracy. Defaults to False.
+    """
 
     def __init__(self,
                  loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
                  topk=(1, ),
-                 cal_acc=True):
+                 cal_acc=False):
         super(ClsHead, self).__init__()
 
         assert isinstance(loss, dict)
diff --git a/mmcls/models/heads/linear_head.py b/mmcls/models/heads/linear_head.py
index 12c4671a6..0fc64fe8d 100644
--- a/mmcls/models/heads/linear_head.py
+++ b/mmcls/models/heads/linear_head.py
@@ -15,15 +15,10 @@ class LinearClsHead(ClsHead):
         num_classes (int): Number of categories excluding the background
             category.
         in_channels (int): Number of channels in the input feature map.
-        loss (dict): Config of classification loss.
-    """  # noqa: W605
+    """
 
-    def __init__(self,
-                 num_classes,
-                 in_channels,
-                 loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
-                 topk=(1, )):
-        super(LinearClsHead, self).__init__(loss=loss, topk=topk)
+    def __init__(self, num_classes, in_channels, *args, **kwargs):
+        super(LinearClsHead, self).__init__(*args, **kwargs)
 
         self.in_channels = in_channels
         self.num_classes = num_classes
diff --git a/mmcls/models/heads/vision_transformer_head.py b/mmcls/models/heads/vision_transformer_head.py
index 9d6ec6cee..b1d36ce32 100644
--- a/mmcls/models/heads/vision_transformer_head.py
+++ b/mmcls/models/heads/vision_transformer_head.py
@@ -21,22 +21,16 @@ class VisionTransformerClsHead(ClsHead):
             available during pre-training. Default None.
         act_cfg (dict): The activation config. Only available during
             pre-training. Defalut Tanh.
-        loss (dict): Config of classification loss.
-        topk (int | tuple): Top-k accuracy.
-        cal_acc (bool): Whether to calculate accuracy during training.
-            If mixup is used, this should be False. Default False.
-    """  # noqa: W605
+    """
 
     def __init__(self,
                  num_classes,
                  in_channels,
                  hidden_dim=None,
                  act_cfg=dict(type='Tanh'),
-                 loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
-                 topk=(1, ),
-                 cal_acc=False):
-        super(VisionTransformerClsHead, self).__init__(
-            loss=loss, topk=topk, cal_acc=cal_acc)
+                 *args,
+                 **kwargs):
+        super(VisionTransformerClsHead, self).__init__(*args, **kwargs)
         self.in_channels = in_channels
         self.num_classes = num_classes
         self.hidden_dim = hidden_dim
diff --git a/tests/test_classifiers.py b/tests/test_classifiers.py
index c143c368f..56979af25 100644
--- a/tests/test_classifiers.py
+++ b/tests/test_classifiers.py
@@ -110,8 +110,8 @@ def test_image_classifier_vit():
         train_cfg=dict(mixup=dict(alpha=0.2, num_classes=1000)))
     img_classifier = ImageClassifier(**model_cfg)
     img_classifier.init_weights()
-    imgs = torch.randn(16, 3, 224, 224)
-    label = torch.randint(0, 1000, (16, ))
+    imgs = torch.randn(3, 3, 224, 224)
+    label = torch.randint(0, 1000, (3, ))
     losses = img_classifier.forward_train(imgs, label)
 
     assert losses['loss'].item() > 0
diff --git a/tests/test_heads.py b/tests/test_heads.py
index b40d8157a..0e7dcad1a 100644
--- a/tests/test_heads.py
+++ b/tests/test_heads.py
@@ -1,12 +1,12 @@
 import torch
 
-from mmcls.models.heads import (ClsHead, MultiLabelClsHead,
+from mmcls.models.heads import (ClsHead, LinearClsHead, MultiLabelClsHead,
                                 MultiLabelLinearClsHead)
 
 
 def test_cls_head():
 
-    # test ClsHead with cal_acc=True
+    # test ClsHead with cal_acc=False
     head = ClsHead()
     fake_cls_score = torch.rand(4, 3)
     fake_gt_label = torch.randint(0, 2, (4, ))
@@ -14,14 +14,22 @@ def test_cls_head():
     losses = head.loss(fake_cls_score, fake_gt_label)
     assert losses['loss'].item() > 0
 
-    # test ClsHead with cal_acc=False
-    head = ClsHead(cal_acc=False)
+    # test ClsHead with cal_acc=True
+    head = ClsHead(cal_acc=True)
     fake_cls_score = torch.rand(4, 3)
     fake_gt_label = torch.randint(0, 2, (4, ))
 
     losses = head.loss(fake_cls_score, fake_gt_label)
     assert losses['loss'].item() > 0
 
+    # test LinearClsHead
+    head = LinearClsHead(10, 100)
+    fake_cls_score = torch.rand(4, 10)
+    fake_gt_label = torch.randint(0, 10, (4, ))
+
+    losses = head.loss(fake_cls_score, fake_gt_label)
+    assert losses['loss'].item() > 0
+
 
 def test_multilabel_head():
     head = MultiLabelClsHead()