diff --git a/mmcls/metrics/multi_label.py b/mmcls/metrics/multi_label.py index c3b964171..fef8224c7 100644 --- a/mmcls/metrics/multi_label.py +++ b/mmcls/metrics/multi_label.py @@ -407,7 +407,7 @@ def _average_precision(pred: torch.Tensor, total_pos = tps[-1].item() # the last of tensor may change later # Calculate cumulative tp&fp(pred_poss) case numbers - pred_pos_nums = torch.arange(1, len(sorted_target) + 1) + pred_pos_nums = torch.arange(1, len(sorted_target) + 1).to(pred.device) pred_pos_nums[pred_pos_nums < eps] = eps tps[torch.logical_not(pos_inds)] = 0 diff --git a/mmcls/models/heads/__init__.py b/mmcls/models/heads/__init__.py index b81106fbe..13dbba982 100644 --- a/mmcls/models/heads/__init__.py +++ b/mmcls/models/heads/__init__.py @@ -3,7 +3,7 @@ from .cls_head import ClsHead from .conformer_head import ConformerHead from .deit_head import DeiTClsHead from .linear_head import LinearClsHead -from .multi_label_head import MultiLabelClsHead +from .multi_label_cls_head import MultiLabelClsHead from .multi_label_linear_head import MultiLabelLinearClsHead from .stacked_head import StackedLinearClsHead from .vision_transformer_head import VisionTransformerClsHead diff --git a/mmcls/models/heads/multi_label_cls_head.py b/mmcls/models/heads/multi_label_cls_head.py new file mode 100644 index 000000000..268bb7946 --- /dev/null +++ b/mmcls/models/heads/multi_label_cls_head.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple + +import torch +from mmengine.data import LabelData + +from mmcls.core import ClsDataSample +from mmcls.registry import MODELS +from .base_head import BaseHead + + +@MODELS.register_module() +class MultiLabelClsHead(BaseHead): + """Classification head for multilabel task. + + Args: + loss (dict): Config of classification loss. Defaults to + dict(type='CrossEntropyLoss', use_sigmoid=True). + thr (float, optional): Predictions with scores under the thresholds + are considered as negative. Defaults to None. + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. Defaults to None. + init_cfg (dict, optional): The extra init config of layers. + Defaults to None. + + Notes: + If both ``thr`` and ``topk`` are set, use ``thr` to determine + positive predictions. If neither is set, use ``thr=0.5`` as + default. + """ + + def __init__(self, + loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True), + thr: Optional[float] = None, + topk: Optional[int] = None, + init_cfg: Optional[dict] = None): + super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg) + + self.loss_module = MODELS.build(loss) + + if thr is None and topk is None: + thr = 0.5 + + self.thr = thr + self.topk = topk + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``MultiLabelClsHead``, we just obtain + the feature of the last stage. + """ + # The MultiLabelClsHead doesn't have other module, just return after + # unpacking. + return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The MultiLabelClsHead doesn't have the final classification head, + # just return the unpacked inputs. 
+ return pre_logits + + def loss(self, feats: Tuple[torch.Tensor], + data_samples: List[ClsDataSample], **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[ClsDataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # The part can be traced by torch.fx + cls_score = self(feats) + + # The part can not be traced by torch.fx + losses = self._get_loss(cls_score, data_samples, **kwargs) + return losses + + def _get_loss(self, cls_score: torch.Tensor, + data_samples: List[ClsDataSample], **kwargs): + """Unpack data samples and compute loss.""" + num_classes = cls_score.size()[-1] + # Unpack data samples and pack targets + if 'score' in data_samples[0].gt_label: + target = torch.stack( + [i.gt_label.score.float() for i in data_samples]) + else: + target = torch.stack([ + LabelData.label_to_onehot(i.gt_label.label, + num_classes).float() + for i in data_samples + ]) + + # compute loss + losses = dict() + loss = self.loss_module( + cls_score, target, avg_factor=cls_score.size(0), **kwargs) + losses['loss'] = loss + + return losses + + def predict( + self, + feats: Tuple[torch.Tensor], + data_samples: List[ClsDataSample] = None) -> List[ClsDataSample]: + """Inference without augmentation. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[ClsDataSample], optional): The annotation + data of every samples. If not None, set ``pred_label`` of + the input data samples. Defaults to None. + + Returns: + List[ClsDataSample]: A list of data samples which contains the + predicted results. + """ + # The part can be traced by torch.fx + cls_score = self(feats) + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_predictions(self, cls_score: torch.Tensor, + data_samples: List[ClsDataSample]): + """Post-process the output of head. + + Including softmax and set ``pred_label`` of data samples. + """ + pred_scores = torch.sigmoid(cls_score) + + if data_samples is None: + data_samples = [ClsDataSample() for _ in range(cls_score.size(0))] + + for data_sample, score in zip(data_samples, pred_scores): + if self.thr is not None: + # a label is predicted positive if larger than thr + label = torch.where(score >= self.thr)[0] + else: + # top-k labels will be predicted positive for any example + _, label = score.topk(self.topk) + data_sample.set_pred_score(score).set_pred_label(label) + + return data_samples diff --git a/mmcls/models/heads/multi_label_head.py b/mmcls/models/heads/multi_label_head.py deleted file mode 100644 index 0b6aceb43..000000000 --- a/mmcls/models/heads/multi_label_head.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch - -from mmcls.registry import MODELS -from ..utils import is_tracing -from .base_head import BaseHead - - -@MODELS.register_module() -class MultiLabelClsHead(BaseHead): - """Classification head for multilabel task. 
- - Args: - loss (dict): Config of classification loss. - """ - - def __init__(self, - loss=dict( - type='CrossEntropyLoss', - use_sigmoid=True, - reduction='mean', - loss_weight=1.0), - init_cfg=None): - super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg) - - assert isinstance(loss, dict) - - self.compute_loss = MODELS.build(loss) - - def loss(self, cls_score, gt_label): - gt_label = gt_label.type_as(cls_score) - num_samples = len(cls_score) - losses = dict() - - # map difficult examples to positive ones - _gt_label = torch.abs(gt_label) - # compute loss - loss = self.compute_loss(cls_score, _gt_label, avg_factor=num_samples) - losses['loss'] = loss - return losses - - def forward_train(self, cls_score, gt_label, **kwargs): - if isinstance(cls_score, tuple): - cls_score = cls_score[-1] - gt_label = gt_label.type_as(cls_score) - losses = self.loss(cls_score, gt_label, **kwargs) - return losses - - def pre_logits(self, x): - if isinstance(x, tuple): - x = x[-1] - - from mmcls.utils import get_root_logger - logger = get_root_logger() - logger.warning( - 'The input of MultiLabelClsHead should be already logits. ' - 'Please modify the backbone if you want to get pre-logits feature.' - ) - return x - - def simple_test(self, x, sigmoid=True, post_process=True): - """Inference without augmentation. - - Args: - cls_score (tuple[Tensor]): The input classification score logits. - Multi-stage inputs are acceptable but only the last stage will - be used to classify. The shape of every item should be - ``(num_samples, num_classes)``. - sigmoid (bool): Whether to sigmoid the classification score. - post_process (bool): Whether to do post processing the - inference results. It will convert the output to a list. - - Returns: - Tensor | list: The inference results. - - - If no post processing, the output is a tensor with shape - ``(num_samples, num_classes)``. - - If post processing, the output is a multi-dimentional list of - float and the dimensions are ``(num_samples, num_classes)``. - """ - if isinstance(x, tuple): - x = x[-1] - - if sigmoid: - pred = torch.sigmoid(x) if x is not None else None - else: - pred = x - - if post_process: - return self.post_process(pred) - else: - return pred - - def post_process(self, pred): - on_trace = is_tracing() - if torch.onnx.is_in_onnx_export() or on_trace: - return pred - pred = list(pred.detach().cpu().numpy()) - return pred diff --git a/mmcls/models/heads/multi_label_linear_head.py b/mmcls/models/heads/multi_label_linear_head.py index fadb6ec00..08742f9c8 100644 --- a/mmcls/models/heads/multi_label_linear_head.py +++ b/mmcls/models/heads/multi_label_linear_head.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Tuple + import torch import torch.nn as nn from mmcls.registry import MODELS -from .multi_label_head import MultiLabelClsHead +from .multi_label_cls_head import MultiLabelClsHead @MODELS.register_module() @@ -11,75 +13,54 @@ class MultiLabelLinearClsHead(MultiLabelClsHead): """Linear classification head for multilabel task. Args: - num_classes (int): Number of categories. - in_channels (int): Number of channels in the input feature map. - loss (dict): Config of classification loss. - init_cfg (dict | optional): The extra init config of layers. + loss (dict): Config of classification loss. Defaults to + dict(type='CrossEntropyLoss', use_sigmoid=True). + thr (float, optional): Predictions with scores under the thresholds + are considered as negative. Defaults to None. 
+ topk (int, optional): Predictions with the k-th highest scores are + considered as positive. Defaults to None. + init_cfg (dict, optional): The extra init config of layers. Defaults to use dict(type='Normal', layer='Linear', std=0.01). + + Notes: + If both ``thr`` and ``topk`` are set, use ``thr` to determine + positive predictions. If neither is set, use ``thr=0.5`` as + default. """ def __init__(self, - num_classes, - in_channels, - loss=dict( - type='CrossEntropyLoss', - use_sigmoid=True, - reduction='mean', - loss_weight=1.0), - init_cfg=dict(type='Normal', layer='Linear', std=0.01)): + num_classes: int, + in_channels: int, + loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True), + thr: Optional[float] = None, + topk: Optional[int] = None, + init_cfg: Optional[dict] = dict( + type='Normal', layer='Linear', std=0.01)): super(MultiLabelLinearClsHead, self).__init__( - loss=loss, init_cfg=init_cfg) + loss=loss, thr=thr, topk=topk, init_cfg=init_cfg) - if num_classes <= 0: - raise ValueError( - f'num_classes={num_classes} must be a positive integer') + assert num_classes > 0, f'num_classes ({num_classes}) must be a ' \ + 'positive integer.' self.in_channels = in_channels self.num_classes = num_classes self.fc = nn.Linear(self.in_channels, self.num_classes) - def pre_logits(self, x): - if isinstance(x, tuple): - x = x[-1] - return x + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. - def forward_train(self, x, gt_label, **kwargs): - x = self.pre_logits(x) - gt_label = gt_label.type_as(x) - cls_score = self.fc(x) - losses = self.loss(cls_score, gt_label, **kwargs) - return losses - - def simple_test(self, x, sigmoid=True, post_process=True): - """Inference without augmentation. - - Args: - x (tuple[Tensor]): The input features. - Multi-stage inputs are acceptable but only the last stage will - be used to classify. The shape of every item should be - ``(num_samples, in_channels)``. - sigmoid (bool): Whether to sigmoid the classification score. - post_process (bool): Whether to do post processing the - inference results. It will convert the output to a list. - - Returns: - Tensor | list: The inference results. - - - If no post processing, the output is a tensor with shape - ``(num_samples, num_classes)``. - - If post processing, the output is a multi-dimentional list of - float and the dimensions are ``(num_samples, num_classes)``. + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``MultiLabelLinearClsHead``, we just + obtain the feature of the last stage. """ - x = self.pre_logits(x) - cls_score = self.fc(x) + # The obtain the MultiLabelLinearClsHead doesn't have other module, + # just return after unpacking. + return feats[-1] - if sigmoid: - pred = torch.sigmoid(cls_score) if cls_score is not None else None - else: - pred = cls_score - - if post_process: - return self.post_process(pred) - else: - return pred + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The final classification head. + cls_score = self.fc(pre_logits) + return cls_score diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index 3cb659d04..f1c00cf4f 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -1,6 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import copy +import random from unittest import TestCase +import numpy as np import torch from mmengine import is_seq_of @@ -11,6 +14,14 @@ from mmcls.utils import register_all_modules register_all_modules() +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + class TestClsHead(TestCase): DEFAULT_ARGS = dict(type='ClsHead') @@ -305,113 +316,124 @@ class TestStackedLinearClsHead(TestCase): self.assertEqual(outs.shape, (4, 5)) -"""Temporarily disabled. -@pytest.mark.parametrize('feat', [torch.rand(4, 10), (torch.rand(4, 10), )]) -def test_multilabel_head(feat): - head = MultiLabelClsHead() - fake_gt_label = torch.randint(0, 2, (4, 10)) +class TestMultiLabelClsHead(TestCase): + DEFAULT_ARGS = dict(type='MultiLabelClsHead') - losses = head.forward_train(feat, fake_gt_label) - assert losses['loss'].item() > 0 + def test_pre_logits(self): + head = MODELS.build(self.DEFAULT_ARGS) - # test simple_test with post_process - pred = head.simple_test(feat) - assert isinstance(pred, list) and len(pred) == 4 - with patch('torch.onnx.is_in_onnx_export', return_value=True): - pred = head.simple_test(feat) - assert pred.shape == (4, 10) + # return the last item + feats = (torch.rand(4, 10), torch.rand(4, 10)) + pre_logits = head.pre_logits(feats) + self.assertIs(pre_logits, feats[-1]) - # test simple_test without post_process - pred = head.simple_test(feat, post_process=False) - assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10) - logits = head.simple_test(feat, sigmoid=False, post_process=False) - torch.testing.assert_allclose(pred, torch.sigmoid(logits)) + def test_forward(self): + head = MODELS.build(self.DEFAULT_ARGS) - # test pre_logits - features = head.pre_logits(feat) - if isinstance(feat, tuple): - torch.testing.assert_allclose(features, feat[0]) - else: - torch.testing.assert_allclose(features, feat) + # return the last item (same as pre_logits) + feats = (torch.rand(4, 10), torch.rand(4, 10)) + outs = head(feats) + self.assertIs(outs, feats[-1]) + + def test_loss(self): + feats = (torch.rand(4, 10), ) + data_samples = [ClsDataSample().set_gt_label([0, 3]) for _ in range(4)] + + # Test with thr and topk are all None + head = MODELS.build(self.DEFAULT_ARGS) + losses = head.loss(feats, data_samples) + self.assertEqual(head.thr, 0.5) + self.assertEqual(head.topk, None) + self.assertEqual(losses.keys(), {'loss'}) + self.assertGreater(losses['loss'].item(), 0) + + # Test with topk + cfg = copy.deepcopy(self.DEFAULT_ARGS) + cfg['topk'] = 2 + head = MODELS.build(cfg) + losses = head.loss(feats, data_samples) + self.assertEqual(head.thr, None, cfg) + self.assertEqual(head.topk, 2) + self.assertEqual(losses.keys(), {'loss'}) + self.assertGreater(losses['loss'].item(), 0) + + # Test with thr + setup_seed(0) + cfg = copy.deepcopy(self.DEFAULT_ARGS) + cfg['thr'] = 0.1 + head = MODELS.build(cfg) + thr_losses = head.loss(feats, data_samples) + self.assertEqual(head.thr, 0.1) + self.assertEqual(head.topk, None) + self.assertEqual(thr_losses.keys(), {'loss'}) + self.assertGreater(thr_losses['loss'].item(), 0) + + # Test with thr and topk are all not None + setup_seed(0) + cfg = copy.deepcopy(self.DEFAULT_ARGS) + cfg['thr'] = 0.1 + cfg['topk'] = 2 + head = MODELS.build(cfg) + thr_topk_losses = head.loss(feats, data_samples) + self.assertEqual(head.thr, 0.1) + self.assertEqual(head.topk, 2) + self.assertEqual(thr_topk_losses.keys(), {'loss'}) + 
self.assertGreater(thr_topk_losses['loss'].item(), 0) + + # Test with gt_lable with score + data_samples = [ + ClsDataSample().set_gt_score(torch.rand((10, ))) for _ in range(4) + ] + + head = MODELS.build(self.DEFAULT_ARGS) + losses = head.loss(feats, data_samples) + self.assertEqual(head.thr, 0.5) + self.assertEqual(head.topk, None) + self.assertEqual(losses.keys(), {'loss'}) + self.assertGreater(losses['loss'].item(), 0) + + def test_predict(self): + feats = (torch.rand(4, 10), ) + data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(4)] + head = MODELS.build(self.DEFAULT_ARGS) + + # with without data_samples + predictions = head.predict(feats) + self.assertTrue(is_seq_of(predictions, ClsDataSample)) + for pred in predictions: + self.assertIn('label', pred.pred_label) + self.assertIn('score', pred.pred_label) + + # with with data_samples + predictions = head.predict(feats, data_samples) + self.assertTrue(is_seq_of(predictions, ClsDataSample)) + for sample, pred in zip(data_samples, predictions): + self.assertIs(sample, pred) + self.assertIn('label', pred.pred_label) + self.assertIn('score', pred.pred_label) + + # Test with topk + cfg = copy.deepcopy(self.DEFAULT_ARGS) + cfg['topk'] = 2 + head = MODELS.build(cfg) + predictions = head.predict(feats, data_samples) + self.assertEqual(head.thr, None) + self.assertTrue(is_seq_of(predictions, ClsDataSample)) + for sample, pred in zip(data_samples, predictions): + self.assertIs(sample, pred) + self.assertIn('label', pred.pred_label) + self.assertIn('score', pred.pred_label) -@pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )]) -def test_multilabel_linear_head(feat): - head = MultiLabelLinearClsHead(10, 5) - fake_gt_label = torch.randint(0, 2, (4, 10)) +class TestMultiLabelLinearClsHead(TestMultiLabelClsHead): + DEFAULT_ARGS = dict( + type='MultiLabelLinearClsHead', num_classes=10, in_channels=10) - head.init_weights() - losses = head.forward_train(feat, fake_gt_label) - assert losses['loss'].item() > 0 + def test_forward(self): + head = MODELS.build(self.DEFAULT_ARGS) + self.assertTrue(hasattr(head, 'fc')) + self.assertTrue(isinstance(head.fc, torch.nn.Linear)) - # test simple_test with post_process - pred = head.simple_test(feat) - assert isinstance(pred, list) and len(pred) == 4 - with patch('torch.onnx.is_in_onnx_export', return_value=True): - pred = head.simple_test(feat) - assert pred.shape == (4, 10) - - # test simple_test without post_process - pred = head.simple_test(feat, post_process=False) - assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10) - logits = head.simple_test(feat, sigmoid=False, post_process=False) - torch.testing.assert_allclose(pred, torch.sigmoid(logits)) - - # test pre_logits - features = head.pre_logits(feat) - if isinstance(feat, tuple): - torch.testing.assert_allclose(features, feat[0]) - else: - torch.testing.assert_allclose(features, feat) - - -@pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )]) -def test_stacked_linear_cls_head(feat): - # test assertion - with pytest.raises(AssertionError): - StackedLinearClsHead(num_classes=3, in_channels=5, mid_channels=10) - - with pytest.raises(AssertionError): - StackedLinearClsHead(num_classes=-1, in_channels=5, mid_channels=[10]) - - fake_gt_label = torch.randint(0, 2, (4, )) # B, num_classes - - # test forward with default setting - head = StackedLinearClsHead( - num_classes=10, in_channels=5, mid_channels=[20]) - head.init_weights() - - losses = head.forward_train(feat, fake_gt_label) - assert 
losses['loss'].item() > 0 - - # test simple_test with post_process - pred = head.simple_test(feat) - assert isinstance(pred, list) and len(pred) == 4 - with patch('torch.onnx.is_in_onnx_export', return_value=True): - pred = head.simple_test(feat) - assert pred.shape == (4, 10) - - # test simple_test without post_process - pred = head.simple_test(feat, post_process=False) - assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10) - logits = head.simple_test(feat, softmax=False, post_process=False) - torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1)) - - # test pre_logits - features = head.pre_logits(feat) - assert features.shape == (4, 20) - - # test forward with full function - head = StackedLinearClsHead( - num_classes=3, - in_channels=5, - mid_channels=[8, 10], - dropout_rate=0.2, - norm_cfg=dict(type='BN1d'), - act_cfg=dict(type='HSwish')) - head.init_weights() - - losses = head.forward_train(feat, fake_gt_label) - assert losses['loss'].item() > 0 - -""" + # return the last item (same as pre_logits) + feats = (torch.rand(4, 10), torch.rand(4, 10)) + head(feats)
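For reference, a minimal usage sketch of the refactored multi-label heads, mirroring the new unit tests in this patch. It assumes the mmcls 1.x registry and data structures that the diff itself relies on (MODELS, ClsDataSample, register_all_modules) and is illustrative only, not part of the change set.

    import torch

    from mmcls.core import ClsDataSample
    from mmcls.registry import MODELS
    from mmcls.utils import register_all_modules

    register_all_modules()

    # Build the linear multi-label head. `thr` / `topk` control how positive
    # labels are selected in predict(); thr falls back to 0.5 when neither is set.
    head = MODELS.build(
        dict(type='MultiLabelLinearClsHead', num_classes=10, in_channels=10, thr=0.5))

    feats = (torch.rand(4, 10), )  # only the last backbone stage is used
    data_samples = [ClsDataSample().set_gt_label([0, 3]) for _ in range(4)]

    losses = head.loss(feats, data_samples)          # dict with a 'loss' tensor
    predictions = head.predict(feats, data_samples)  # fills pred_label / pred_score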