[Feature]Modify Parameters Passing in models.heads (#239)
* add mytrain.py for test * test before layers * test attr in layers * test classifier * delete mytrain.py * set cal_acc in ClsHead defaults to False * set cal_acc defaults to False * use *args, **kwargs instead * change bs16 to 3 in test_image_classifier_vit * fix some comments * change cal_acc=True * test LinearClsHeadpull/242/head
parent
37167158e7
commit
b30f79ea4c
|
@ -15,13 +15,13 @@ class ClsHead(BaseHead):
|
|||
topk (int | tuple): Top-k accuracy.
|
||||
cal_acc (bool): Whether to calculate accuracy during training.
|
||||
If you use Mixup/CutMix or something like that during training,
|
||||
it is not reasonable to calculate accuracy. Defaults to True.
|
||||
""" # noqa: W605
|
||||
it is not reasonable to calculate accuracy. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, ),
|
||||
cal_acc=True):
|
||||
cal_acc=False):
|
||||
super(ClsHead, self).__init__()
|
||||
|
||||
assert isinstance(loss, dict)
|
||||
|
|
|
@ -15,15 +15,10 @@ class LinearClsHead(ClsHead):
|
|||
num_classes (int): Number of categories excluding the background
|
||||
category.
|
||||
in_channels (int): Number of channels in the input feature map.
|
||||
loss (dict): Config of classification loss.
|
||||
""" # noqa: W605
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
in_channels,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, )):
|
||||
super(LinearClsHead, self).__init__(loss=loss, topk=topk)
|
||||
def __init__(self, num_classes, in_channels, *args, **kwargs):
|
||||
super(LinearClsHead, self).__init__(*args, **kwargs)
|
||||
self.in_channels = in_channels
|
||||
self.num_classes = num_classes
|
||||
|
||||
|
|
|
@ -21,22 +21,16 @@ class VisionTransformerClsHead(ClsHead):
|
|||
available during pre-training. Default None.
|
||||
act_cfg (dict): The activation config. Only available during
|
||||
pre-training. Defalut Tanh.
|
||||
loss (dict): Config of classification loss.
|
||||
topk (int | tuple): Top-k accuracy.
|
||||
cal_acc (bool): Whether to calculate accuracy during training.
|
||||
If mixup is used, this should be False. Default False.
|
||||
""" # noqa: W605
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
in_channels,
|
||||
hidden_dim=None,
|
||||
act_cfg=dict(type='Tanh'),
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, ),
|
||||
cal_acc=False):
|
||||
super(VisionTransformerClsHead, self).__init__(
|
||||
loss=loss, topk=topk, cal_acc=cal_acc)
|
||||
*args,
|
||||
**kwargs):
|
||||
super(VisionTransformerClsHead, self).__init__(*args, **kwargs)
|
||||
self.in_channels = in_channels
|
||||
self.num_classes = num_classes
|
||||
self.hidden_dim = hidden_dim
|
||||
|
|
|
@ -110,8 +110,8 @@ def test_image_classifier_vit():
|
|||
train_cfg=dict(mixup=dict(alpha=0.2, num_classes=1000)))
|
||||
img_classifier = ImageClassifier(**model_cfg)
|
||||
img_classifier.init_weights()
|
||||
imgs = torch.randn(16, 3, 224, 224)
|
||||
label = torch.randint(0, 1000, (16, ))
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
label = torch.randint(0, 1000, (3, ))
|
||||
|
||||
losses = img_classifier.forward_train(imgs, label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import torch
|
||||
|
||||
from mmcls.models.heads import (ClsHead, MultiLabelClsHead,
|
||||
from mmcls.models.heads import (ClsHead, LinearClsHead, MultiLabelClsHead,
|
||||
MultiLabelLinearClsHead)
|
||||
|
||||
|
||||
def test_cls_head():
|
||||
|
||||
# test ClsHead with cal_acc=True
|
||||
# test ClsHead with cal_acc=False
|
||||
head = ClsHead()
|
||||
fake_cls_score = torch.rand(4, 3)
|
||||
fake_gt_label = torch.randint(0, 2, (4, ))
|
||||
|
@ -14,14 +14,22 @@ def test_cls_head():
|
|||
losses = head.loss(fake_cls_score, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
# test ClsHead with cal_acc=False
|
||||
head = ClsHead(cal_acc=False)
|
||||
# test ClsHead with cal_acc=True
|
||||
head = ClsHead(cal_acc=True)
|
||||
fake_cls_score = torch.rand(4, 3)
|
||||
fake_gt_label = torch.randint(0, 2, (4, ))
|
||||
|
||||
losses = head.loss(fake_cls_score, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
# test LinearClsHead
|
||||
head = LinearClsHead(10, 100)
|
||||
fake_cls_score = torch.rand(4, 10)
|
||||
fake_gt_label = torch.randint(0, 10, (4, ))
|
||||
|
||||
losses = head.loss(fake_cls_score, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
|
||||
def test_multilabel_head():
|
||||
head = MultiLabelClsHead()
|
||||
|
|
Loading…
Reference in New Issue