Add multi_label heads

parent 4fcd7ee072
commit e9342d9e4c
@@ -407,7 +407,7 @@ def _average_precision(pred: torch.Tensor,
     total_pos = tps[-1].item()  # the last of tensor may change later

     # Calculate cumulative tp&fp(pred_poss) case numbers
-    pred_pos_nums = torch.arange(1, len(sorted_target) + 1)
+    pred_pos_nums = torch.arange(1, len(sorted_target) + 1).to(pred.device)
     pred_pos_nums[pred_pos_nums < eps] = eps

     tps[torch.logical_not(pos_inds)] = 0
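Aside (not part of the patch): a minimal sketch of the device issue the `.to(pred.device)` change presumably avoids. `torch.arange` allocates on the CPU by default, so mixing those indices with a CUDA `pred` in the later cumulative math would raise a device-mismatch error. Names and shapes below are illustrative only.

import torch

# illustrative scores; on GPU when one is available
pred = torch.rand(8, device='cuda' if torch.cuda.is_available() else 'cpu')
# without `.to(pred.device)` these indices would stay on the CPU
pred_pos_nums = torch.arange(1, len(pred) + 1).to(pred.device)
# both operands now live on the same device, so the division is safe
precision = torch.cumsum(pred, dim=0) / pred_pos_nums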
@@ -3,7 +3,7 @@ from .cls_head import ClsHead
 from .conformer_head import ConformerHead
 from .deit_head import DeiTClsHead
 from .linear_head import LinearClsHead
-from .multi_label_head import MultiLabelClsHead
+from .multi_label_cls_head import MultiLabelClsHead
 from .multi_label_linear_head import MultiLabelLinearClsHead
 from .stacked_head import StackedLinearClsHead
 from .vision_transformer_head import VisionTransformerClsHead
@@ -0,0 +1,157 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from mmengine.data import LabelData
+
+from mmcls.core import ClsDataSample
+from mmcls.registry import MODELS
+from .base_head import BaseHead
+
+
+@MODELS.register_module()
+class MultiLabelClsHead(BaseHead):
+    """Classification head for multilabel task.
+
+    Args:
+        loss (dict): Config of classification loss. Defaults to
+            dict(type='CrossEntropyLoss', use_sigmoid=True).
+        thr (float, optional): Predictions with scores under the threshold
+            are considered as negative. Defaults to None.
+        topk (int, optional): Predictions with the k-th highest scores are
+            considered as positive. Defaults to None.
+        init_cfg (dict, optional): The extra init config of layers.
+            Defaults to None.
+
+    Notes:
+        If both ``thr`` and ``topk`` are set, use ``thr`` to determine
+        positive predictions. If neither is set, use ``thr=0.5`` as
+        default.
+    """
+
+    def __init__(self,
+                 loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True),
+                 thr: Optional[float] = None,
+                 topk: Optional[int] = None,
+                 init_cfg: Optional[dict] = None):
+        super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg)
+
+        self.loss_module = MODELS.build(loss)
+
+        if thr is None and topk is None:
+            thr = 0.5
+
+        self.thr = thr
+        self.topk = topk
+
+    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
+        """The process before the final classification head.
+
+        The input ``feats`` is a tuple of tensor, and each tensor is the
+        feature of a backbone stage. In ``MultiLabelClsHead``, we just obtain
+        the feature of the last stage.
+        """
+        # The MultiLabelClsHead doesn't have other modules, just return after
+        # unpacking.
+        return feats[-1]
+
+    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
+        """The forward process."""
+        pre_logits = self.pre_logits(feats)
+        # The MultiLabelClsHead doesn't have the final classification head,
+        # just return the unpacked inputs.
+        return pre_logits
+
+    def loss(self, feats: Tuple[torch.Tensor],
+             data_samples: List[ClsDataSample], **kwargs) -> dict:
+        """Calculate losses from the classification score.
+
+        Args:
+            feats (tuple[Tensor]): The features extracted from the backbone.
+                Multiple stage inputs are acceptable but only the last stage
+                will be used to classify. The shape of every item should be
+                ``(num_samples, num_classes)``.
+            data_samples (List[ClsDataSample]): The annotation data of
+                every sample.
+            **kwargs: Other keyword arguments to forward the loss module.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        # The part can be traced by torch.fx
+        cls_score = self(feats)
+
+        # The part can not be traced by torch.fx
+        losses = self._get_loss(cls_score, data_samples, **kwargs)
+        return losses
+
+    def _get_loss(self, cls_score: torch.Tensor,
+                  data_samples: List[ClsDataSample], **kwargs):
+        """Unpack data samples and compute loss."""
+        num_classes = cls_score.size()[-1]
+        # Unpack data samples and pack targets
+        if 'score' in data_samples[0].gt_label:
+            target = torch.stack(
+                [i.gt_label.score.float() for i in data_samples])
+        else:
+            target = torch.stack([
+                LabelData.label_to_onehot(i.gt_label.label,
+                                          num_classes).float()
+                for i in data_samples
+            ])
+
+        # compute loss
+        losses = dict()
+        loss = self.loss_module(
+            cls_score, target, avg_factor=cls_score.size(0), **kwargs)
+        losses['loss'] = loss
+
+        return losses
+
+    def predict(
+            self,
+            feats: Tuple[torch.Tensor],
+            data_samples: List[ClsDataSample] = None) -> List[ClsDataSample]:
+        """Inference without augmentation.
+
+        Args:
+            feats (tuple[Tensor]): The features extracted from the backbone.
+                Multiple stage inputs are acceptable but only the last stage
+                will be used to classify. The shape of every item should be
+                ``(num_samples, num_classes)``.
+            data_samples (List[ClsDataSample], optional): The annotation
+                data of every sample. If not None, set ``pred_label`` of
+                the input data samples. Defaults to None.
+
+        Returns:
+            List[ClsDataSample]: A list of data samples which contains the
+            predicted results.
+        """
+        # The part can be traced by torch.fx
+        cls_score = self(feats)
+
+        # The part can not be traced by torch.fx
+        predictions = self._get_predictions(cls_score, data_samples)
+        return predictions
+
+    def _get_predictions(self, cls_score: torch.Tensor,
+                         data_samples: List[ClsDataSample]):
+        """Post-process the output of head.
+
+        Including sigmoid and set ``pred_label`` of data samples.
+        """
+        pred_scores = torch.sigmoid(cls_score)
+
+        if data_samples is None:
+            data_samples = [ClsDataSample() for _ in range(cls_score.size(0))]
+
+        for data_sample, score in zip(data_samples, pred_scores):
+            if self.thr is not None:
+                # a label is predicted positive if larger than thr
+                label = torch.where(score >= self.thr)[0]
+            else:
+                # top-k labels will be predicted positive for any example
+                _, label = score.topk(self.topk)
+            data_sample.set_pred_score(score).set_pred_label(label)
+
+        return data_samples
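Aside (not part of the patch): a rough usage sketch of the new head, pieced together from the unit tests later in this diff; treat the exact import paths and registry behaviour as assumptions of this branch rather than documented API.

import torch
from mmcls.core import ClsDataSample
from mmcls.registry import MODELS
from mmcls.utils import register_all_modules

register_all_modules()

# topk=2: the two highest-scoring labels are predicted positive
head = MODELS.build(dict(type='MultiLabelClsHead', topk=2))
feats = (torch.rand(4, 10), )  # one backbone stage, 10 classes
data_samples = [ClsDataSample().set_gt_label([0, 3]) for _ in range(4)]

losses = head.loss(feats, data_samples)  # {'loss': tensor(...)}
predictions = head.predict(feats)        # ClsDataSamples with pred_label set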
@@ -1,99 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-
-from mmcls.registry import MODELS
-from ..utils import is_tracing
-from .base_head import BaseHead
-
-
-@MODELS.register_module()
-class MultiLabelClsHead(BaseHead):
-    """Classification head for multilabel task.
-
-    Args:
-        loss (dict): Config of classification loss.
-    """
-
-    def __init__(self,
-                 loss=dict(
-                     type='CrossEntropyLoss',
-                     use_sigmoid=True,
-                     reduction='mean',
-                     loss_weight=1.0),
-                 init_cfg=None):
-        super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg)
-
-        assert isinstance(loss, dict)
-
-        self.compute_loss = MODELS.build(loss)
-
-    def loss(self, cls_score, gt_label):
-        gt_label = gt_label.type_as(cls_score)
-        num_samples = len(cls_score)
-        losses = dict()
-
-        # map difficult examples to positive ones
-        _gt_label = torch.abs(gt_label)
-        # compute loss
-        loss = self.compute_loss(cls_score, _gt_label, avg_factor=num_samples)
-        losses['loss'] = loss
-        return losses
-
-    def forward_train(self, cls_score, gt_label, **kwargs):
-        if isinstance(cls_score, tuple):
-            cls_score = cls_score[-1]
-        gt_label = gt_label.type_as(cls_score)
-        losses = self.loss(cls_score, gt_label, **kwargs)
-        return losses
-
-    def pre_logits(self, x):
-        if isinstance(x, tuple):
-            x = x[-1]
-
-        from mmcls.utils import get_root_logger
-        logger = get_root_logger()
-        logger.warning(
-            'The input of MultiLabelClsHead should be already logits. '
-            'Please modify the backbone if you want to get pre-logits feature.'
-        )
-        return x
-
-    def simple_test(self, x, sigmoid=True, post_process=True):
-        """Inference without augmentation.
-
-        Args:
-            cls_score (tuple[Tensor]): The input classification score logits.
-                Multi-stage inputs are acceptable but only the last stage will
-                be used to classify. The shape of every item should be
-                ``(num_samples, num_classes)``.
-            sigmoid (bool): Whether to sigmoid the classification score.
-            post_process (bool): Whether to do post processing the
-                inference results. It will convert the output to a list.
-
-        Returns:
-            Tensor | list: The inference results.
-
-                - If no post processing, the output is a tensor with shape
-                  ``(num_samples, num_classes)``.
-                - If post processing, the output is a multi-dimentional list of
-                  float and the dimensions are ``(num_samples, num_classes)``.
-        """
-        if isinstance(x, tuple):
-            x = x[-1]
-
-        if sigmoid:
-            pred = torch.sigmoid(x) if x is not None else None
-        else:
-            pred = x
-
-        if post_process:
-            return self.post_process(pred)
-        else:
-            return pred
-
-    def post_process(self, pred):
-        on_trace = is_tracing()
-        if torch.onnx.is_in_onnx_export() or on_trace:
-            return pred
-        pred = list(pred.detach().cpu().numpy())
-        return pred
@@ -1,9 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional, Tuple
+
 import torch
 import torch.nn as nn

 from mmcls.registry import MODELS
-from .multi_label_head import MultiLabelClsHead
+from .multi_label_cls_head import MultiLabelClsHead


 @MODELS.register_module()
@@ -11,75 +13,54 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
     """Linear classification head for multilabel task.

     Args:
-        num_classes (int): Number of categories.
-        in_channels (int): Number of channels in the input feature map.
-        loss (dict): Config of classification loss.
-        init_cfg (dict | optional): The extra init config of layers.
+        loss (dict): Config of classification loss. Defaults to
+            dict(type='CrossEntropyLoss', use_sigmoid=True).
+        thr (float, optional): Predictions with scores under the threshold
+            are considered as negative. Defaults to None.
+        topk (int, optional): Predictions with the k-th highest scores are
+            considered as positive. Defaults to None.
+        init_cfg (dict, optional): The extra init config of layers.
             Defaults to use dict(type='Normal', layer='Linear', std=0.01).
+
+    Notes:
+        If both ``thr`` and ``topk`` are set, use ``thr`` to determine
+        positive predictions. If neither is set, use ``thr=0.5`` as
+        default.
     """

     def __init__(self,
-                 num_classes,
-                 in_channels,
-                 loss=dict(
-                     type='CrossEntropyLoss',
-                     use_sigmoid=True,
-                     reduction='mean',
-                     loss_weight=1.0),
-                 init_cfg=dict(type='Normal', layer='Linear', std=0.01)):
+                 num_classes: int,
+                 in_channels: int,
+                 loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True),
+                 thr: Optional[float] = None,
+                 topk: Optional[int] = None,
+                 init_cfg: Optional[dict] = dict(
+                     type='Normal', layer='Linear', std=0.01)):
         super(MultiLabelLinearClsHead, self).__init__(
-            loss=loss, init_cfg=init_cfg)
+            loss=loss, thr=thr, topk=topk, init_cfg=init_cfg)

-        if num_classes <= 0:
-            raise ValueError(
-                f'num_classes={num_classes} must be a positive integer')
+        assert num_classes > 0, f'num_classes ({num_classes}) must be a ' \
+            'positive integer.'

         self.in_channels = in_channels
         self.num_classes = num_classes

         self.fc = nn.Linear(self.in_channels, self.num_classes)

-    def pre_logits(self, x):
-        if isinstance(x, tuple):
-            x = x[-1]
-        return x
-
-    def forward_train(self, x, gt_label, **kwargs):
-        x = self.pre_logits(x)
-        gt_label = gt_label.type_as(x)
-        cls_score = self.fc(x)
-        losses = self.loss(cls_score, gt_label, **kwargs)
-        return losses
-
-    def simple_test(self, x, sigmoid=True, post_process=True):
-        """Inference without augmentation.
-
-        Args:
-            x (tuple[Tensor]): The input features.
-                Multi-stage inputs are acceptable but only the last stage will
-                be used to classify. The shape of every item should be
-                ``(num_samples, in_channels)``.
-            sigmoid (bool): Whether to sigmoid the classification score.
-            post_process (bool): Whether to do post processing the
-                inference results. It will convert the output to a list.
-
-        Returns:
-            Tensor | list: The inference results.
-
-                - If no post processing, the output is a tensor with shape
-                  ``(num_samples, num_classes)``.
-                - If post processing, the output is a multi-dimentional list of
-                  float and the dimensions are ``(num_samples, num_classes)``.
+    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
+        """The process before the final classification head.
+
+        The input ``feats`` is a tuple of tensor, and each tensor is the
+        feature of a backbone stage. In ``MultiLabelLinearClsHead``, we just
+        obtain the feature of the last stage.
         """
-        x = self.pre_logits(x)
-        cls_score = self.fc(x)
-
-        if sigmoid:
-            pred = torch.sigmoid(cls_score) if cls_score is not None else None
-        else:
-            pred = cls_score
-
-        if post_process:
-            return self.post_process(pred)
-        else:
-            return pred
+        # The MultiLabelLinearClsHead doesn't have other modules before the
+        # final classification head, just return after unpacking.
+        return feats[-1]
+
+    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
+        """The forward process."""
+        pre_logits = self.pre_logits(feats)
+        # The final classification head.
+        cls_score = self.fc(pre_logits)
+        return cls_score
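Aside (not part of the patch): a hypothetical config fragment showing where MultiLabelLinearClsHead could sit in a classifier config. Only the head keys mirror what this patch's tests construct; the ImageClassifier/ResNet/GlobalAveragePooling entries and in_channels=512 are placeholder assumptions.

model = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet', depth=18),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='MultiLabelLinearClsHead',
        num_classes=10,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', use_sigmoid=True),
        thr=0.5))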
@@ -1,6 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import random
 from unittest import TestCase

+import numpy as np
 import torch
 from mmengine import is_seq_of

@@ -11,6 +14,14 @@ from mmcls.utils import register_all_modules
 register_all_modules()


+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
 class TestClsHead(TestCase):
     DEFAULT_ARGS = dict(type='ClsHead')

@@ -305,113 +316,124 @@ class TestStackedLinearClsHead(TestCase):
         self.assertEqual(outs.shape, (4, 5))


-"""Temporarily disabled.
-@pytest.mark.parametrize('feat', [torch.rand(4, 10), (torch.rand(4, 10), )])
-def test_multilabel_head(feat):
-    head = MultiLabelClsHead()
-    fake_gt_label = torch.randint(0, 2, (4, 10))
-
-    losses = head.forward_train(feat, fake_gt_label)
-    assert losses['loss'].item() > 0
-
-    # test simple_test with post_process
-    pred = head.simple_test(feat)
-    assert isinstance(pred, list) and len(pred) == 4
-    with patch('torch.onnx.is_in_onnx_export', return_value=True):
-        pred = head.simple_test(feat)
-    assert pred.shape == (4, 10)
-
-    # test simple_test without post_process
-    pred = head.simple_test(feat, post_process=False)
-    assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
-    logits = head.simple_test(feat, sigmoid=False, post_process=False)
-    torch.testing.assert_allclose(pred, torch.sigmoid(logits))
-
-    # test pre_logits
-    features = head.pre_logits(feat)
-    if isinstance(feat, tuple):
-        torch.testing.assert_allclose(features, feat[0])
-    else:
-        torch.testing.assert_allclose(features, feat)
-
-
-@pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )])
-def test_multilabel_linear_head(feat):
-    head = MultiLabelLinearClsHead(10, 5)
-    fake_gt_label = torch.randint(0, 2, (4, 10))
-
-    head.init_weights()
-    losses = head.forward_train(feat, fake_gt_label)
-    assert losses['loss'].item() > 0
-
-    # test simple_test with post_process
-    pred = head.simple_test(feat)
-    assert isinstance(pred, list) and len(pred) == 4
-    with patch('torch.onnx.is_in_onnx_export', return_value=True):
-        pred = head.simple_test(feat)
-    assert pred.shape == (4, 10)
-
-    # test simple_test without post_process
-    pred = head.simple_test(feat, post_process=False)
-    assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
-    logits = head.simple_test(feat, sigmoid=False, post_process=False)
-    torch.testing.assert_allclose(pred, torch.sigmoid(logits))
-
-    # test pre_logits
-    features = head.pre_logits(feat)
-    if isinstance(feat, tuple):
-        torch.testing.assert_allclose(features, feat[0])
-    else:
-        torch.testing.assert_allclose(features, feat)
-
-
-@pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )])
-def test_stacked_linear_cls_head(feat):
-    # test assertion
-    with pytest.raises(AssertionError):
-        StackedLinearClsHead(num_classes=3, in_channels=5, mid_channels=10)
-
-    with pytest.raises(AssertionError):
-        StackedLinearClsHead(num_classes=-1, in_channels=5, mid_channels=[10])
-
-    fake_gt_label = torch.randint(0, 2, (4, ))  # B, num_classes
-
-    # test forward with default setting
-    head = StackedLinearClsHead(
-        num_classes=10, in_channels=5, mid_channels=[20])
-    head.init_weights()
-
-    losses = head.forward_train(feat, fake_gt_label)
-    assert losses['loss'].item() > 0
-
-    # test simple_test with post_process
-    pred = head.simple_test(feat)
-    assert isinstance(pred, list) and len(pred) == 4
-    with patch('torch.onnx.is_in_onnx_export', return_value=True):
-        pred = head.simple_test(feat)
-    assert pred.shape == (4, 10)
-
-    # test simple_test without post_process
-    pred = head.simple_test(feat, post_process=False)
-    assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
-    logits = head.simple_test(feat, softmax=False, post_process=False)
-    torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
-
-    # test pre_logits
-    features = head.pre_logits(feat)
-    assert features.shape == (4, 20)
-
-    # test forward with full function
-    head = StackedLinearClsHead(
-        num_classes=3,
-        in_channels=5,
-        mid_channels=[8, 10],
-        dropout_rate=0.2,
-        norm_cfg=dict(type='BN1d'),
-        act_cfg=dict(type='HSwish'))
-    head.init_weights()
-
-    losses = head.forward_train(feat, fake_gt_label)
-    assert losses['loss'].item() > 0
-"""
+class TestMultiLabelClsHead(TestCase):
+    DEFAULT_ARGS = dict(type='MultiLabelClsHead')
+
+    def test_pre_logits(self):
+        head = MODELS.build(self.DEFAULT_ARGS)
+
+        # return the last item
+        feats = (torch.rand(4, 10), torch.rand(4, 10))
+        pre_logits = head.pre_logits(feats)
+        self.assertIs(pre_logits, feats[-1])
+
+    def test_forward(self):
+        head = MODELS.build(self.DEFAULT_ARGS)
+
+        # return the last item (same as pre_logits)
+        feats = (torch.rand(4, 10), torch.rand(4, 10))
+        outs = head(feats)
+        self.assertIs(outs, feats[-1])
+
+    def test_loss(self):
+        feats = (torch.rand(4, 10), )
+        data_samples = [ClsDataSample().set_gt_label([0, 3]) for _ in range(4)]
+
+        # Test with thr and topk are all None
+        head = MODELS.build(self.DEFAULT_ARGS)
+        losses = head.loss(feats, data_samples)
+        self.assertEqual(head.thr, 0.5)
+        self.assertEqual(head.topk, None)
+        self.assertEqual(losses.keys(), {'loss'})
+        self.assertGreater(losses['loss'].item(), 0)
+
+        # Test with topk
+        cfg = copy.deepcopy(self.DEFAULT_ARGS)
+        cfg['topk'] = 2
+        head = MODELS.build(cfg)
+        losses = head.loss(feats, data_samples)
+        self.assertEqual(head.thr, None, cfg)
+        self.assertEqual(head.topk, 2)
+        self.assertEqual(losses.keys(), {'loss'})
+        self.assertGreater(losses['loss'].item(), 0)
+
+        # Test with thr
+        setup_seed(0)
+        cfg = copy.deepcopy(self.DEFAULT_ARGS)
+        cfg['thr'] = 0.1
+        head = MODELS.build(cfg)
+        thr_losses = head.loss(feats, data_samples)
+        self.assertEqual(head.thr, 0.1)
+        self.assertEqual(head.topk, None)
+        self.assertEqual(thr_losses.keys(), {'loss'})
+        self.assertGreater(thr_losses['loss'].item(), 0)
+
+        # Test with thr and topk are all not None
+        setup_seed(0)
+        cfg = copy.deepcopy(self.DEFAULT_ARGS)
+        cfg['thr'] = 0.1
+        cfg['topk'] = 2
+        head = MODELS.build(cfg)
+        thr_topk_losses = head.loss(feats, data_samples)
+        self.assertEqual(head.thr, 0.1)
+        self.assertEqual(head.topk, 2)
+        self.assertEqual(thr_topk_losses.keys(), {'loss'})
+        self.assertGreater(thr_topk_losses['loss'].item(), 0)
+
+        # Test with gt_label with score
+        data_samples = [
+            ClsDataSample().set_gt_score(torch.rand((10, ))) for _ in range(4)
+        ]
+
+        head = MODELS.build(self.DEFAULT_ARGS)
+        losses = head.loss(feats, data_samples)
+        self.assertEqual(head.thr, 0.5)
+        self.assertEqual(head.topk, None)
+        self.assertEqual(losses.keys(), {'loss'})
+        self.assertGreater(losses['loss'].item(), 0)
+
+    def test_predict(self):
+        feats = (torch.rand(4, 10), )
+        data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(4)]
+        head = MODELS.build(self.DEFAULT_ARGS)
+
+        # without data_samples
+        predictions = head.predict(feats)
+        self.assertTrue(is_seq_of(predictions, ClsDataSample))
+        for pred in predictions:
+            self.assertIn('label', pred.pred_label)
+            self.assertIn('score', pred.pred_label)
+
+        # with data_samples
+        predictions = head.predict(feats, data_samples)
+        self.assertTrue(is_seq_of(predictions, ClsDataSample))
+        for sample, pred in zip(data_samples, predictions):
+            self.assertIs(sample, pred)
+            self.assertIn('label', pred.pred_label)
+            self.assertIn('score', pred.pred_label)
+
+        # Test with topk
+        cfg = copy.deepcopy(self.DEFAULT_ARGS)
+        cfg['topk'] = 2
+        head = MODELS.build(cfg)
+        predictions = head.predict(feats, data_samples)
+        self.assertEqual(head.thr, None)
+        self.assertTrue(is_seq_of(predictions, ClsDataSample))
+        for sample, pred in zip(data_samples, predictions):
+            self.assertIs(sample, pred)
+            self.assertIn('label', pred.pred_label)
+            self.assertIn('score', pred.pred_label)
+
+
+class TestMultiLabelLinearClsHead(TestMultiLabelClsHead):
+    DEFAULT_ARGS = dict(
+        type='MultiLabelLinearClsHead', num_classes=10, in_channels=10)
+
+    def test_forward(self):
+        head = MODELS.build(self.DEFAULT_ARGS)
+        self.assertTrue(hasattr(head, 'fc'))
+        self.assertTrue(isinstance(head.fc, torch.nn.Linear))
+
+        # return the last item (same as pre_logits)
+        feats = (torch.rand(4, 10), torch.rand(4, 10))
+        head(feats)