[Feature] Modify Parameter Passing in models.heads (#239)

* add mytrain.py for test

* test before layers

* test attr in layers

* test classifier

* delete mytrain.py

* set cal_acc in ClsHead defaults to False

* set cal_acc defaults to False

* use *args, **kwargs instead

* change bs16 to 3 in test_image_classifier_vit

* fix some comments

* change cal_acc=True

* test LinearClsHead
pull/242/head
whcao 2021-05-10 14:56:55 +08:00 committed by GitHub
parent 37167158e7
commit b30f79ea4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 24 additions and 27 deletions

View File

@@ -15,13 +15,13 @@ class ClsHead(BaseHead):
topk (int | tuple): Top-k accuracy.
cal_acc (bool): Whether to calculate accuracy during training.
If you use Mixup/CutMix or something like that during training,
it is not reasonable to calculate accuracy. Defaults to True.
""" # noqa: W605
it is not reasonable to calculate accuracy. Defaults to False.
"""
def __init__(self,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, ),
cal_acc=True):
cal_acc=False):
super(ClsHead, self).__init__()
assert isinstance(loss, dict)

View File

@@ -15,15 +15,10 @@ class LinearClsHead(ClsHead):
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
loss (dict): Config of classification loss.
""" # noqa: W605
"""
def __init__(self,
num_classes,
in_channels,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, )):
super(LinearClsHead, self).__init__(loss=loss, topk=topk)
def __init__(self, num_classes, in_channels, *args, **kwargs):
super(LinearClsHead, self).__init__(*args, **kwargs)
self.in_channels = in_channels
self.num_classes = num_classes

View File

@@ -21,22 +21,16 @@ class VisionTransformerClsHead(ClsHead):
available during pre-training. Default None.
act_cfg (dict): The activation config. Only available during
pre-training. Default Tanh.
loss (dict): Config of classification loss.
topk (int | tuple): Top-k accuracy.
cal_acc (bool): Whether to calculate accuracy during training.
If mixup is used, this should be False. Default False.
""" # noqa: W605
"""
def __init__(self,
num_classes,
in_channels,
hidden_dim=None,
act_cfg=dict(type='Tanh'),
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, ),
cal_acc=False):
super(VisionTransformerClsHead, self).__init__(
loss=loss, topk=topk, cal_acc=cal_acc)
*args,
**kwargs):
super(VisionTransformerClsHead, self).__init__(*args, **kwargs)
self.in_channels = in_channels
self.num_classes = num_classes
self.hidden_dim = hidden_dim

View File

@@ -110,8 +110,8 @@ def test_image_classifier_vit():
train_cfg=dict(mixup=dict(alpha=0.2, num_classes=1000)))
img_classifier = ImageClassifier(**model_cfg)
img_classifier.init_weights()
imgs = torch.randn(16, 3, 224, 224)
label = torch.randint(0, 1000, (16, ))
imgs = torch.randn(3, 3, 224, 224)
label = torch.randint(0, 1000, (3, ))
losses = img_classifier.forward_train(imgs, label)
assert losses['loss'].item() > 0

View File

@@ -1,12 +1,12 @@
import torch
from mmcls.models.heads import (ClsHead, MultiLabelClsHead,
from mmcls.models.heads import (ClsHead, LinearClsHead, MultiLabelClsHead,
MultiLabelLinearClsHead)
def test_cls_head():
# test ClsHead with cal_acc=True
# test ClsHead with cal_acc=False
head = ClsHead()
fake_cls_score = torch.rand(4, 3)
fake_gt_label = torch.randint(0, 2, (4, ))
@@ -14,14 +14,22 @@ def test_cls_head():
losses = head.loss(fake_cls_score, fake_gt_label)
assert losses['loss'].item() > 0
# test ClsHead with cal_acc=False
head = ClsHead(cal_acc=False)
# test ClsHead with cal_acc=True
head = ClsHead(cal_acc=True)
fake_cls_score = torch.rand(4, 3)
fake_gt_label = torch.randint(0, 2, (4, ))
losses = head.loss(fake_cls_score, fake_gt_label)
assert losses['loss'].item() > 0
# test LinearClsHead
head = LinearClsHead(10, 100)
fake_cls_score = torch.rand(4, 10)
fake_gt_label = torch.randint(0, 10, (4, ))
losses = head.loss(fake_cls_score, fake_gt_label)
assert losses['loss'].item() > 0
def test_multilabel_head():
head = MultiLabelClsHead()