Add multi_label heads

pull/913/head
Ezra-Yu 2022-07-08 06:51:09 +00:00 committed by mzr1996
parent 4fcd7ee072
commit e9342d9e4c
6 changed files with 322 additions and 261 deletions

View File

@@ -407,7 +407,7 @@ def _average_precision(pred: torch.Tensor,
     total_pos = tps[-1].item()  # the last of tensor may change later
     # Calculate cumulative tp&fp(pred_poss) case numbers
-    pred_pos_nums = torch.arange(1, len(sorted_target) + 1)
+    pred_pos_nums = torch.arange(1, len(sorted_target) + 1).to(pred.device)
     pred_pos_nums[pred_pos_nums < eps] = eps
     tps[torch.logical_not(pos_inds)] = 0
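The only functional change in this hunk moves pred_pos_nums onto the same device as pred. Without it, the later element-wise division between tensors derived from pred (possibly on GPU) and a CPU-created arange raises a device-mismatch error. A minimal sketch of the failure mode (a standalone illustration, not part of the commit):

    import torch

    if torch.cuda.is_available():
        tps = torch.ones(5, device='cuda')       # cumulative true positives, on GPU like pred
        pred_pos_nums = torch.arange(1, 6)       # torch.arange defaults to CPU
        # tps / pred_pos_nums                    # RuntimeError: tensors on different devices
        precision = tps / pred_pos_nums.to(tps.device)  # fine once moved, as the fix does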

View File

@@ -3,7 +3,7 @@ from .cls_head import ClsHead
 from .conformer_head import ConformerHead
 from .deit_head import DeiTClsHead
 from .linear_head import LinearClsHead
-from .multi_label_head import MultiLabelClsHead
+from .multi_label_cls_head import MultiLabelClsHead
 from .multi_label_linear_head import MultiLabelLinearClsHead
 from .stacked_head import StackedLinearClsHead
 from .vision_transformer_head import VisionTransformerClsHead
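Only the source module is renamed here; the exported class name stays the same, so code that imports the head through the package rather than the renamed file keeps working. A quick sanity check, assuming the usual mmcls package layout for this __init__.py:

    from mmcls.models.heads import MultiLabelClsHead  # still resolves after the rename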

View File

@@ -0,0 +1,157 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple

import torch
from mmengine.data import LabelData

from mmcls.core import ClsDataSample
from mmcls.registry import MODELS
from .base_head import BaseHead


@MODELS.register_module()
class MultiLabelClsHead(BaseHead):
    """Classification head for multilabel task.

    Args:
        loss (dict): Config of classification loss. Defaults to
            dict(type='CrossEntropyLoss', use_sigmoid=True).
        thr (float, optional): Predictions with scores under the threshold
            are considered as negative. Defaults to None.
        topk (int, optional): Predictions with the top-k highest scores are
            considered as positive. Defaults to None.
        init_cfg (dict, optional): The extra init config of layers.
            Defaults to None.

    Notes:
        If both ``thr`` and ``topk`` are set, use ``thr`` to determine
        positive predictions. If neither is set, use ``thr=0.5`` as
        default.
    """
    def __init__(self,
                 loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True),
                 thr: Optional[float] = None,
                 topk: Optional[int] = None,
                 init_cfg: Optional[dict] = None):
        super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg)

        self.loss_module = MODELS.build(loss)

        if thr is None and topk is None:
            thr = 0.5

        self.thr = thr
        self.topk = topk

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The process before the final classification head.

        The input ``feats`` is a tuple of tensor, and each tensor is the
        feature of a backbone stage. In ``MultiLabelClsHead``, we just obtain
        the feature of the last stage.
        """
        # The MultiLabelClsHead doesn't have other module, just return after
        # unpacking.
        return feats[-1]

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The forward process."""
        pre_logits = self.pre_logits(feats)
        # The MultiLabelClsHead doesn't have the final classification head,
        # just return the unpacked inputs.
        return pre_logits
    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[ClsDataSample], **kwargs) -> dict:
        """Calculate losses from the classification score.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, num_classes)``.
            data_samples (List[ClsDataSample]): The annotation data of
                every sample.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # The part can be traced by torch.fx
        cls_score = self(feats)

        # The part can not be traced by torch.fx
        losses = self._get_loss(cls_score, data_samples, **kwargs)
        return losses

    def _get_loss(self, cls_score: torch.Tensor,
                  data_samples: List[ClsDataSample], **kwargs):
        """Unpack data samples and compute loss."""
        num_classes = cls_score.size()[-1]

        # Unpack data samples and pack targets
        if 'score' in data_samples[0].gt_label:
            target = torch.stack(
                [i.gt_label.score.float() for i in data_samples])
        else:
            target = torch.stack([
                LabelData.label_to_onehot(i.gt_label.label,
                                          num_classes).float()
                for i in data_samples
            ])

        # compute loss
        losses = dict()
        loss = self.loss_module(
            cls_score, target, avg_factor=cls_score.size(0), **kwargs)
        losses['loss'] = loss

        return losses
    def predict(
            self,
            feats: Tuple[torch.Tensor],
            data_samples: List[ClsDataSample] = None) -> List[ClsDataSample]:
        """Inference without augmentation.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, num_classes)``.
            data_samples (List[ClsDataSample], optional): The annotation
                data of every sample. If not None, set ``pred_label`` of
                the input data samples. Defaults to None.

        Returns:
            List[ClsDataSample]: A list of data samples which contains the
            predicted results.
        """
        # The part can be traced by torch.fx
        cls_score = self(feats)

        # The part can not be traced by torch.fx
        predictions = self._get_predictions(cls_score, data_samples)
        return predictions

    def _get_predictions(self, cls_score: torch.Tensor,
                         data_samples: List[ClsDataSample]):
        """Post-process the output of the head.

        Including applying sigmoid and setting ``pred_label`` of data samples.
        """
        pred_scores = torch.sigmoid(cls_score)

        if data_samples is None:
            data_samples = [ClsDataSample() for _ in range(cls_score.size(0))]

        for data_sample, score in zip(data_samples, pred_scores):
            if self.thr is not None:
                # a label is predicted positive if its score is larger than thr
                label = torch.where(score >= self.thr)[0]
            else:
                # top-k labels will be predicted positive for any example
                _, label = score.topk(self.topk)
            data_sample.set_pred_score(score).set_pred_label(label)

        return data_samples
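To make the new interface concrete, a minimal usage sketch (not part of the commit; it mirrors the unit tests added below, which is where the MODELS.build and ClsDataSample calls come from):

    import torch

    from mmcls.core import ClsDataSample
    from mmcls.registry import MODELS
    from mmcls.utils import register_all_modules

    register_all_modules()

    # Build the head from a config dict, the same way a model config would.
    head = MODELS.build(dict(type='MultiLabelClsHead', thr=0.5))

    feats = (torch.rand(4, 10), )  # stand-in for backbone output: 4 samples, 10 classes
    data_samples = [ClsDataSample().set_gt_label([0, 3]) for _ in range(4)]

    losses = head.loss(feats, data_samples)  # {'loss': ...}
    preds = head.predict(feats)              # ClsDataSamples with pred_score and pred_label set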

View File

@@ -1,99 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmcls.registry import MODELS
from ..utils import is_tracing
from .base_head import BaseHead


@MODELS.register_module()
class MultiLabelClsHead(BaseHead):
    """Classification head for multilabel task.

    Args:
        loss (dict): Config of classification loss.
    """

    def __init__(self,
                 loss=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='mean',
                     loss_weight=1.0),
                 init_cfg=None):
        super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg)

        assert isinstance(loss, dict)

        self.compute_loss = MODELS.build(loss)

    def loss(self, cls_score, gt_label):
        gt_label = gt_label.type_as(cls_score)
        num_samples = len(cls_score)
        losses = dict()

        # map difficult examples to positive ones
        _gt_label = torch.abs(gt_label)
        # compute loss
        loss = self.compute_loss(cls_score, _gt_label, avg_factor=num_samples)
        losses['loss'] = loss
        return losses

    def forward_train(self, cls_score, gt_label, **kwargs):
        if isinstance(cls_score, tuple):
            cls_score = cls_score[-1]
        gt_label = gt_label.type_as(cls_score)
        losses = self.loss(cls_score, gt_label, **kwargs)
        return losses

    def pre_logits(self, x):
        if isinstance(x, tuple):
            x = x[-1]

        from mmcls.utils import get_root_logger
        logger = get_root_logger()
        logger.warning(
            'The input of MultiLabelClsHead should be already logits. '
            'Please modify the backbone if you want to get pre-logits feature.'
        )
        return x

    def simple_test(self, x, sigmoid=True, post_process=True):
        """Inference without augmentation.

        Args:
            cls_score (tuple[Tensor]): The input classification score logits.
                Multi-stage inputs are acceptable but only the last stage will
                be used to classify. The shape of every item should be
                ``(num_samples, num_classes)``.
            sigmoid (bool): Whether to sigmoid the classification score.
            post_process (bool): Whether to do post processing the
                inference results. It will convert the output to a list.

        Returns:
            Tensor | list: The inference results.

                - If no post processing, the output is a tensor with shape
                  ``(num_samples, num_classes)``.
                - If post processing, the output is a multi-dimentional list of
                  float and the dimensions are ``(num_samples, num_classes)``.
        """
        if isinstance(x, tuple):
            x = x[-1]

        if sigmoid:
            pred = torch.sigmoid(x) if x is not None else None
        else:
            pred = x

        if post_process:
            return self.post_process(pred)
        else:
            return pred

    def post_process(self, pred):
        on_trace = is_tracing()
        if torch.onnx.is_in_onnx_export() or on_trace:
            return pred
        pred = list(pred.detach().cpu().numpy())
        return pred
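One behavioural difference between the two files: the removed head expected a dense (N, num_classes) 0/1 gt_label tensor, while the new head builds one-hot targets from index labels with LabelData.label_to_onehot inside _get_loss. A minimal sketch of that helper, assuming the mmengine.data API imported by the new file:

    import torch
    from mmengine.data import LabelData

    # Convert index-style labels into a one-hot target, as _get_loss does above.
    label = torch.tensor([0, 3])                   # the sample belongs to classes 0 and 3
    onehot = LabelData.label_to_onehot(label, 10)  # length-10 vector with ones at indices 0 and 3
    target = onehot.float()                        # float target for the use_sigmoid=True loss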

View File

@@ -1,9 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional, Tuple
+
 import torch
 import torch.nn as nn
 
 from mmcls.registry import MODELS
-from .multi_label_head import MultiLabelClsHead
+from .multi_label_cls_head import MultiLabelClsHead
 
 
 @MODELS.register_module()
@@ -11,75 +13,54 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
     """Linear classification head for multilabel task.
 
     Args:
-        num_classes (int): Number of categories.
-        in_channels (int): Number of channels in the input feature map.
-        loss (dict): Config of classification loss.
-        init_cfg (dict | optional): The extra init config of layers.
+        loss (dict): Config of classification loss. Defaults to
+            dict(type='CrossEntropyLoss', use_sigmoid=True).
+        thr (float, optional): Predictions with scores under the threshold
+            are considered as negative. Defaults to None.
+        topk (int, optional): Predictions with the top-k highest scores are
+            considered as positive. Defaults to None.
+        init_cfg (dict, optional): The extra init config of layers.
             Defaults to use dict(type='Normal', layer='Linear', std=0.01).
+
+    Notes:
+        If both ``thr`` and ``topk`` are set, use ``thr`` to determine
+        positive predictions. If neither is set, use ``thr=0.5`` as
+        default.
     """
 
     def __init__(self,
-                 num_classes,
-                 in_channels,
-                 loss=dict(
-                     type='CrossEntropyLoss',
-                     use_sigmoid=True,
-                     reduction='mean',
-                     loss_weight=1.0),
-                 init_cfg=dict(type='Normal', layer='Linear', std=0.01)):
+                 num_classes: int,
+                 in_channels: int,
+                 loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True),
+                 thr: Optional[float] = None,
+                 topk: Optional[int] = None,
+                 init_cfg: Optional[dict] = dict(
+                     type='Normal', layer='Linear', std=0.01)):
         super(MultiLabelLinearClsHead, self).__init__(
-            loss=loss, init_cfg=init_cfg)
+            loss=loss, thr=thr, topk=topk, init_cfg=init_cfg)
 
-        if num_classes <= 0:
-            raise ValueError(
-                f'num_classes={num_classes} must be a positive integer')
+        assert num_classes > 0, f'num_classes ({num_classes}) must be a ' \
+            'positive integer.'
 
         self.in_channels = in_channels
         self.num_classes = num_classes
 
         self.fc = nn.Linear(self.in_channels, self.num_classes)
 
-    def pre_logits(self, x):
-        if isinstance(x, tuple):
-            x = x[-1]
-        return x
+    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
+        """The process before the final classification head.
 
-    def forward_train(self, x, gt_label, **kwargs):
-        x = self.pre_logits(x)
-        gt_label = gt_label.type_as(x)
-        cls_score = self.fc(x)
-        losses = self.loss(cls_score, gt_label, **kwargs)
-        return losses
+        The input ``feats`` is a tuple of tensor, and each tensor is the
+        feature of a backbone stage. In ``MultiLabelLinearClsHead``, we just
+        obtain the feature of the last stage.
+        """
+        # The MultiLabelLinearClsHead doesn't have other module, just return
+        # after unpacking.
+        return feats[-1]
 
-    def simple_test(self, x, sigmoid=True, post_process=True):
-        """Inference without augmentation.
-
-        Args:
-            x (tuple[Tensor]): The input features.
-                Multi-stage inputs are acceptable but only the last stage will
-                be used to classify. The shape of every item should be
-                ``(num_samples, in_channels)``.
-            sigmoid (bool): Whether to sigmoid the classification score.
-            post_process (bool): Whether to do post processing the
-                inference results. It will convert the output to a list.
-
-        Returns:
-            Tensor | list: The inference results.
-
-                - If no post processing, the output is a tensor with shape
-                  ``(num_samples, num_classes)``.
-                - If post processing, the output is a multi-dimentional list of
-                  float and the dimensions are ``(num_samples, num_classes)``.
-        """
-        x = self.pre_logits(x)
-        cls_score = self.fc(x)
-
-        if sigmoid:
-            pred = torch.sigmoid(cls_score) if cls_score is not None else None
-        else:
-            pred = cls_score
-
-        if post_process:
-            return self.post_process(pred)
-        else:
-            return pred
+    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
+        """The forward process."""
+        pre_logits = self.pre_logits(feats)
+        # The final classification head.
+        cls_score = self.fc(pre_logits)
+        return cls_score
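For orientation, a sketch of how the reworked linear head might be referenced from a classifier config. Only the head fields come from this commit; the ImageClassifier/ResNet/GlobalAveragePooling entries and the channel numbers are illustrative assumptions:

    # Hypothetical config fragment; backbone and neck are placeholders.
    model = dict(
        type='ImageClassifier',
        backbone=dict(type='ResNet', depth=18, num_stages=4, out_indices=(3, )),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='MultiLabelLinearClsHead',
            num_classes=20,
            in_channels=512,
            loss=dict(type='CrossEntropyLoss', use_sigmoid=True),
            thr=0.5))  # scores below thr are predicted negative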

View File

@@ -1,6 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import random
 from unittest import TestCase
 
+import numpy as np
 import torch
 from mmengine import is_seq_of
@@ -11,6 +14,14 @@ from mmcls.utils import register_all_modules
 register_all_modules()
 
 
+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
 class TestClsHead(TestCase):
     DEFAULT_ARGS = dict(type='ClsHead')
@@ -305,113 +316,124 @@ class TestStackedLinearClsHead(TestCase):
         self.assertEqual(outs.shape, (4, 5))
 
 
-"""Temporarily disabled.
-
-@pytest.mark.parametrize('feat', [torch.rand(4, 10), (torch.rand(4, 10), )])
-def test_multilabel_head(feat):
-    head = MultiLabelClsHead()
-    fake_gt_label = torch.randint(0, 2, (4, 10))
-
-    losses = head.forward_train(feat, fake_gt_label)
-    assert losses['loss'].item() > 0
-
-    # test simple_test with post_process
-    pred = head.simple_test(feat)
-    assert isinstance(pred, list) and len(pred) == 4
-    with patch('torch.onnx.is_in_onnx_export', return_value=True):
-        pred = head.simple_test(feat)
-        assert pred.shape == (4, 10)
-
-    # test simple_test without post_process
-    pred = head.simple_test(feat, post_process=False)
-    assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
-    logits = head.simple_test(feat, sigmoid=False, post_process=False)
-    torch.testing.assert_allclose(pred, torch.sigmoid(logits))
-
-    # test pre_logits
-    features = head.pre_logits(feat)
-    if isinstance(feat, tuple):
-        torch.testing.assert_allclose(features, feat[0])
-    else:
-        torch.testing.assert_allclose(features, feat)
-
-
-@pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )])
-def test_multilabel_linear_head(feat):
-    head = MultiLabelLinearClsHead(10, 5)
-    fake_gt_label = torch.randint(0, 2, (4, 10))
-
-    head.init_weights()
-    losses = head.forward_train(feat, fake_gt_label)
-    assert losses['loss'].item() > 0
-
-    # test simple_test with post_process
-    pred = head.simple_test(feat)
-    assert isinstance(pred, list) and len(pred) == 4
-    with patch('torch.onnx.is_in_onnx_export', return_value=True):
-        pred = head.simple_test(feat)
-        assert pred.shape == (4, 10)
-
-    # test simple_test without post_process
-    pred = head.simple_test(feat, post_process=False)
-    assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
-    logits = head.simple_test(feat, sigmoid=False, post_process=False)
-    torch.testing.assert_allclose(pred, torch.sigmoid(logits))
-
-    # test pre_logits
-    features = head.pre_logits(feat)
-    if isinstance(feat, tuple):
-        torch.testing.assert_allclose(features, feat[0])
-    else:
-        torch.testing.assert_allclose(features, feat)
-
-
-@pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )])
-def test_stacked_linear_cls_head(feat):
-    # test assertion
-    with pytest.raises(AssertionError):
-        StackedLinearClsHead(num_classes=3, in_channels=5, mid_channels=10)
-
-    with pytest.raises(AssertionError):
-        StackedLinearClsHead(num_classes=-1, in_channels=5, mid_channels=[10])
-
-    fake_gt_label = torch.randint(0, 2, (4, ))  # B, num_classes
-
-    # test forward with default setting
-    head = StackedLinearClsHead(
-        num_classes=10, in_channels=5, mid_channels=[20])
-    head.init_weights()
-    losses = head.forward_train(feat, fake_gt_label)
-    assert losses['loss'].item() > 0
-
-    # test simple_test with post_process
-    pred = head.simple_test(feat)
-    assert isinstance(pred, list) and len(pred) == 4
-    with patch('torch.onnx.is_in_onnx_export', return_value=True):
-        pred = head.simple_test(feat)
-        assert pred.shape == (4, 10)
-
-    # test simple_test without post_process
-    pred = head.simple_test(feat, post_process=False)
-    assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
-    logits = head.simple_test(feat, softmax=False, post_process=False)
-    torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
-
-    # test pre_logits
-    features = head.pre_logits(feat)
-    assert features.shape == (4, 20)
-
-    # test forward with full function
-    head = StackedLinearClsHead(
-        num_classes=3,
-        in_channels=5,
-        mid_channels=[8, 10],
-        dropout_rate=0.2,
-        norm_cfg=dict(type='BN1d'),
-        act_cfg=dict(type='HSwish'))
-    head.init_weights()
-    losses = head.forward_train(feat, fake_gt_label)
-    assert losses['loss'].item() > 0
-"""
+class TestMultiLabelClsHead(TestCase):
+    DEFAULT_ARGS = dict(type='MultiLabelClsHead')
+
+    def test_pre_logits(self):
+        head = MODELS.build(self.DEFAULT_ARGS)
+
+        # return the last item
+        feats = (torch.rand(4, 10), torch.rand(4, 10))
+        pre_logits = head.pre_logits(feats)
+        self.assertIs(pre_logits, feats[-1])
+
+    def test_forward(self):
+        head = MODELS.build(self.DEFAULT_ARGS)
+
+        # return the last item (same as pre_logits)
+        feats = (torch.rand(4, 10), torch.rand(4, 10))
+        outs = head(feats)
+        self.assertIs(outs, feats[-1])
+
+    def test_loss(self):
+        feats = (torch.rand(4, 10), )
+        data_samples = [ClsDataSample().set_gt_label([0, 3]) for _ in range(4)]
+
+        # Test with thr and topk both None
+        head = MODELS.build(self.DEFAULT_ARGS)
+        losses = head.loss(feats, data_samples)
+        self.assertEqual(head.thr, 0.5)
+        self.assertEqual(head.topk, None)
+        self.assertEqual(losses.keys(), {'loss'})
+        self.assertGreater(losses['loss'].item(), 0)
+
+        # Test with topk
+        cfg = copy.deepcopy(self.DEFAULT_ARGS)
+        cfg['topk'] = 2
+        head = MODELS.build(cfg)
+        losses = head.loss(feats, data_samples)
+        self.assertEqual(head.thr, None, cfg)
+        self.assertEqual(head.topk, 2)
+        self.assertEqual(losses.keys(), {'loss'})
+        self.assertGreater(losses['loss'].item(), 0)
+
+        # Test with thr
+        setup_seed(0)
+        cfg = copy.deepcopy(self.DEFAULT_ARGS)
+        cfg['thr'] = 0.1
+        head = MODELS.build(cfg)
+        thr_losses = head.loss(feats, data_samples)
+        self.assertEqual(head.thr, 0.1)
+        self.assertEqual(head.topk, None)
+        self.assertEqual(thr_losses.keys(), {'loss'})
+        self.assertGreater(thr_losses['loss'].item(), 0)
+
+        # Test with thr and topk both set
+        setup_seed(0)
+        cfg = copy.deepcopy(self.DEFAULT_ARGS)
+        cfg['thr'] = 0.1
+        cfg['topk'] = 2
+        head = MODELS.build(cfg)
+        thr_topk_losses = head.loss(feats, data_samples)
+        self.assertEqual(head.thr, 0.1)
+        self.assertEqual(head.topk, 2)
+        self.assertEqual(thr_topk_losses.keys(), {'loss'})
+        self.assertGreater(thr_topk_losses['loss'].item(), 0)
+
+        # Test with gt_label with score
+        data_samples = [
+            ClsDataSample().set_gt_score(torch.rand((10, ))) for _ in range(4)
+        ]
+        head = MODELS.build(self.DEFAULT_ARGS)
+        losses = head.loss(feats, data_samples)
+        self.assertEqual(head.thr, 0.5)
+        self.assertEqual(head.topk, None)
+        self.assertEqual(losses.keys(), {'loss'})
+        self.assertGreater(losses['loss'].item(), 0)
+
+    def test_predict(self):
+        feats = (torch.rand(4, 10), )
+        data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(4)]
+        head = MODELS.build(self.DEFAULT_ARGS)
+
+        # without data_samples
+        predictions = head.predict(feats)
+        self.assertTrue(is_seq_of(predictions, ClsDataSample))
+        for pred in predictions:
+            self.assertIn('label', pred.pred_label)
+            self.assertIn('score', pred.pred_label)
+
+        # with data_samples
+        predictions = head.predict(feats, data_samples)
+        self.assertTrue(is_seq_of(predictions, ClsDataSample))
+        for sample, pred in zip(data_samples, predictions):
+            self.assertIs(sample, pred)
+            self.assertIn('label', pred.pred_label)
+            self.assertIn('score', pred.pred_label)
+
+        # Test with topk
+        cfg = copy.deepcopy(self.DEFAULT_ARGS)
+        cfg['topk'] = 2
+        head = MODELS.build(cfg)
+        predictions = head.predict(feats, data_samples)
+        self.assertEqual(head.thr, None)
+        self.assertTrue(is_seq_of(predictions, ClsDataSample))
+        for sample, pred in zip(data_samples, predictions):
+            self.assertIs(sample, pred)
+            self.assertIn('label', pred.pred_label)
+            self.assertIn('score', pred.pred_label)
+
+
+class TestMultiLabelLinearClsHead(TestMultiLabelClsHead):
+    DEFAULT_ARGS = dict(
+        type='MultiLabelLinearClsHead', num_classes=10, in_channels=10)
+
+    def test_forward(self):
+        head = MODELS.build(self.DEFAULT_ARGS)
+        self.assertTrue(hasattr(head, 'fc'))
+        self.assertTrue(isinstance(head.fc, torch.nn.Linear))
+
+        # return the last item (same as pre_logits)
+        feats = (torch.rand(4, 10), torch.rand(4, 10))
+        head(feats)