Add multi_label heads

pull/913/head
Ezra-Yu 2022-07-08 06:51:09 +00:00 committed by mzr1996
parent 4fcd7ee072
commit e9342d9e4c
6 changed files with 322 additions and 261 deletions

View File

@@ -407,7 +407,7 @@ def _average_precision(pred: torch.Tensor,
total_pos = tps[-1].item() # the last of tensor may change later
# Calculate cumulative tp&fp(pred_poss) case numbers
pred_pos_nums = torch.arange(1, len(sorted_target) + 1)
pred_pos_nums = torch.arange(1, len(sorted_target) + 1).to(pred.device)
pred_pos_nums[pred_pos_nums < eps] = eps
tps[torch.logical_not(pos_inds)] = 0
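The one-line change above moves the cumulative prediction counts onto the same device as the scores. A minimal sketch of the failure it avoids, assuming a CUDA device is available (values are illustrative):

import torch

tps = torch.tensor([1., 2., 3.], device='cuda')   # cumulative true positives on the GPU
pred_pos_nums = torch.arange(1, 4)                 # CPU tensor, as before the fix
# tps / pred_pos_nums                              # RuntimeError: tensors are on different devices
precision = tps / pred_pos_nums.to(tps.device)     # mirrors the added .to(pred.device)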

View File

@@ -3,7 +3,7 @@ from .cls_head import ClsHead
from .conformer_head import ConformerHead
from .deit_head import DeiTClsHead
from .linear_head import LinearClsHead
from .multi_label_head import MultiLabelClsHead
from .multi_label_cls_head import MultiLabelClsHead
from .multi_label_linear_head import MultiLabelLinearClsHead
from .stacked_head import StackedLinearClsHead
from .vision_transformer_head import VisionTransformerClsHead
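The rename above only moves the module file; the class is still re-exported from the heads package, so the public import path is unchanged. A quick sketch of which imports are affected, assuming the usual mmcls package layout:

# Unchanged: the package __init__ re-exports the class.
from mmcls.models.heads import MultiLabelClsHead

# Needs updating: imports that pointed at the old module file.
# from mmcls.models.heads.multi_label_head import MultiLabelClsHead      # old module, removed
from mmcls.models.heads.multi_label_cls_head import MultiLabelClsHead    # new module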

View File

@@ -0,0 +1,157 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple
import torch
from mmengine.data import LabelData
from mmcls.core import ClsDataSample
from mmcls.registry import MODELS
from .base_head import BaseHead
@MODELS.register_module()
class MultiLabelClsHead(BaseHead):
"""Classification head for multilabel task.
Args:
loss (dict): Config of classification loss. Defaults to
dict(type='CrossEntropyLoss', use_sigmoid=True).
thr (float, optional): Predictions with scores under the threshold
are considered negative. Defaults to None.
topk (int, optional): Predictions with the top-k highest scores are
considered positive. Defaults to None.
init_cfg (dict, optional): The extra init config of layers.
Defaults to None.
Notes:
If both ``thr`` and ``topk`` are set, use ``thr`` to determine
positive predictions. If neither is set, use ``thr=0.5`` as
the default.
"""
def __init__(self,
loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True),
thr: Optional[float] = None,
topk: Optional[int] = None,
init_cfg: Optional[dict] = None):
super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg)
self.loss_module = MODELS.build(loss)
if thr is None and topk is None:
thr = 0.5
self.thr = thr
self.topk = topk
def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
"""The process before the final classification head.
The input ``feats`` is a tuple of tensor, and each tensor is the
feature of a backbone stage. In ``MultiLabelClsHead``, we just obtain
the feature of the last stage.
"""
# The MultiLabelClsHead doesn't have any other modules, just return after
# unpacking.
return feats[-1]
def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
"""The forward process."""
pre_logits = self.pre_logits(feats)
# The MultiLabelClsHead doesn't have the final classification head,
# just return the unpacked inputs.
return pre_logits
def loss(self, feats: Tuple[torch.Tensor],
data_samples: List[ClsDataSample], **kwargs) -> dict:
"""Calculate losses from the classification score.
Args:
feats (tuple[Tensor]): The features extracted from the backbone.
Multiple stage inputs are acceptable but only the last stage
will be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
data_samples (List[ClsDataSample]): The annotation data of
every sample.
**kwargs: Other keyword arguments to forward the loss module.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
# The part can be traced by torch.fx
cls_score = self(feats)
# The part can not be traced by torch.fx
losses = self._get_loss(cls_score, data_samples, **kwargs)
return losses
def _get_loss(self, cls_score: torch.Tensor,
data_samples: List[ClsDataSample], **kwargs):
"""Unpack data samples and compute loss."""
num_classes = cls_score.size()[-1]
# Unpack data samples and pack targets
if 'score' in data_samples[0].gt_label:
target = torch.stack(
[i.gt_label.score.float() for i in data_samples])
else:
target = torch.stack([
LabelData.label_to_onehot(i.gt_label.label,
num_classes).float()
for i in data_samples
])
# compute loss
losses = dict()
loss = self.loss_module(
cls_score, target, avg_factor=cls_score.size(0), **kwargs)
losses['loss'] = loss
return losses
def predict(
self,
feats: Tuple[torch.Tensor],
data_samples: List[ClsDataSample] = None) -> List[ClsDataSample]:
"""Inference without augmentation.
Args:
feats (tuple[Tensor]): The features extracted from the backbone.
Multiple stage inputs are acceptable but only the last stage
will be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
data_samples (List[ClsDataSample], optional): The annotation
data of every sample. If not None, set ``pred_label`` of
the input data samples. Defaults to None.
Returns:
List[ClsDataSample]: A list of data samples which contains the
predicted results.
"""
# The part can be traced by torch.fx
cls_score = self(feats)
# The part can not be traced by torch.fx
predictions = self._get_predictions(cls_score, data_samples)
return predictions
def _get_predictions(self, cls_score: torch.Tensor,
data_samples: List[ClsDataSample]):
"""Post-process the output of head.
Including softmax and set ``pred_label`` of data samples.
"""
pred_scores = torch.sigmoid(cls_score)
if data_samples is None:
data_samples = [ClsDataSample() for _ in range(cls_score.size(0))]
for data_sample, score in zip(data_samples, pred_scores):
if self.thr is not None:
# a label is predicted positive if larger than thr
label = torch.where(score >= self.thr)[0]
else:
# top-k labels will be predicted positive for any example
_, label = score.topk(self.topk)
data_sample.set_pred_score(score).set_pred_label(label)
return data_samples
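A rough, untested usage sketch of the head above (sizes are illustrative). It demonstrates the thr-vs-topk behaviour from the Notes: with a threshold, the number of predicted labels varies per sample, while topk always yields exactly k labels.

import torch
from mmcls.models.heads import MultiLabelClsHead

feats = (torch.rand(2, 5), )              # the head treats the last stage as logits for 5 classes

thr_head = MultiLabelClsHead(thr=0.3)     # positive if the sigmoid score >= 0.3
topk_head = MultiLabelClsHead(topk=2)     # always the 2 highest-scoring labels

preds = thr_head.predict(feats)
print(preds[0].pred_label.label)          # anywhere from 0 to 5 indices
preds = topk_head.predict(feats)
print(preds[0].pred_label.label)          # exactly 2 indices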

View File

@@ -1,99 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcls.registry import MODELS
from ..utils import is_tracing
from .base_head import BaseHead
@MODELS.register_module()
class MultiLabelClsHead(BaseHead):
"""Classification head for multilabel task.
Args:
loss (dict): Config of classification loss.
"""
def __init__(self,
loss=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=1.0),
init_cfg=None):
super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg)
assert isinstance(loss, dict)
self.compute_loss = MODELS.build(loss)
def loss(self, cls_score, gt_label):
gt_label = gt_label.type_as(cls_score)
num_samples = len(cls_score)
losses = dict()
# map difficult examples to positive ones
_gt_label = torch.abs(gt_label)
# compute loss
loss = self.compute_loss(cls_score, _gt_label, avg_factor=num_samples)
losses['loss'] = loss
return losses
def forward_train(self, cls_score, gt_label, **kwargs):
if isinstance(cls_score, tuple):
cls_score = cls_score[-1]
gt_label = gt_label.type_as(cls_score)
losses = self.loss(cls_score, gt_label, **kwargs)
return losses
def pre_logits(self, x):
if isinstance(x, tuple):
x = x[-1]
from mmcls.utils import get_root_logger
logger = get_root_logger()
logger.warning(
'The input of MultiLabelClsHead should be already logits. '
'Please modify the backbone if you want to get pre-logits feature.'
)
return x
def simple_test(self, x, sigmoid=True, post_process=True):
"""Inference without augmentation.
Args:
cls_score (tuple[Tensor]): The input classification score logits.
Multi-stage inputs are acceptable but only the last stage will
be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
sigmoid (bool): Whether to sigmoid the classification score.
post_process (bool): Whether to do post processing the
inference results. It will convert the output to a list.
Returns:
Tensor | list: The inference results.
- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
"""
if isinstance(x, tuple):
x = x[-1]
if sigmoid:
pred = torch.sigmoid(x) if x is not None else None
else:
pred = x
if post_process:
return self.post_process(pred)
else:
return pred
def post_process(self, pred):
on_trace = is_tracing()
if torch.onnx.is_in_onnx_export() or on_trace:
return pred
pred = list(pred.detach().cpu().numpy())
return pred
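The removed head above took raw cls_score/gt_label tensors through forward_train and simple_test; the replacement consumes ClsDataSample annotations through loss and predict. A rough sketch of how training targets are packed now (class indices and the score vector are made up):

import torch
from mmcls.core import ClsDataSample
from mmcls.models.heads import MultiLabelClsHead

head = MultiLabelClsHead()
feats = (torch.rand(4, 10), )              # stands in for the backbone output

# Old (removed): losses = head.forward_train(cls_score, torch.randint(0, 2, (4, 10)))
# New: sparse labels are one-hot encoded inside _get_loss ...
hard = [ClsDataSample().set_gt_label([0, 3]) for _ in range(4)]
print(head.loss(feats, hard)['loss'])

# ... while soft ground-truth scores are used as the target directly.
soft = [ClsDataSample().set_gt_score(torch.rand(10)) for _ in range(4)]
print(head.loss(feats, soft)['loss'])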

View File

@@ -1,9 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
from mmcls.registry import MODELS
from .multi_label_head import MultiLabelClsHead
from .multi_label_cls_head import MultiLabelClsHead
@MODELS.register_module()
@@ -11,75 +13,54 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
"""Linear classification head for multilabel task.
Args:
num_classes (int): Number of categories.
in_channels (int): Number of channels in the input feature map.
loss (dict): Config of classification loss.
init_cfg (dict | optional): The extra init config of layers.
loss (dict): Config of classification loss. Defaults to
dict(type='CrossEntropyLoss', use_sigmoid=True).
thr (float, optional): Predictions with scores under the threshold
are considered negative. Defaults to None.
topk (int, optional): Predictions with the top-k highest scores are
considered positive. Defaults to None.
init_cfg (dict, optional): The extra init config of layers.
Defaults to dict(type='Normal', layer='Linear', std=0.01).
Notes:
If both ``thr`` and ``topk`` are set, use ``thr`` to determine
positive predictions. If neither is set, use ``thr=0.5`` as
the default.
"""
def __init__(self,
num_classes,
in_channels,
loss=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=1.0),
init_cfg=dict(type='Normal', layer='Linear', std=0.01)):
num_classes: int,
in_channels: int,
loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True),
thr: Optional[float] = None,
topk: Optional[int] = None,
init_cfg: Optional[dict] = dict(
type='Normal', layer='Linear', std=0.01)):
super(MultiLabelLinearClsHead, self).__init__(
loss=loss, init_cfg=init_cfg)
loss=loss, thr=thr, topk=topk, init_cfg=init_cfg)
if num_classes <= 0:
raise ValueError(
f'num_classes={num_classes} must be a positive integer')
assert num_classes > 0, f'num_classes ({num_classes}) must be a ' \
'positive integer.'
self.in_channels = in_channels
self.num_classes = num_classes
self.fc = nn.Linear(self.in_channels, self.num_classes)
def pre_logits(self, x):
if isinstance(x, tuple):
x = x[-1]
return x
def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
"""The process before the final classification head.
def forward_train(self, x, gt_label, **kwargs):
x = self.pre_logits(x)
gt_label = gt_label.type_as(x)
cls_score = self.fc(x)
losses = self.loss(cls_score, gt_label, **kwargs)
return losses
def simple_test(self, x, sigmoid=True, post_process=True):
"""Inference without augmentation.
Args:
x (tuple[Tensor]): The input features.
Multi-stage inputs are acceptable but only the last stage will
be used to classify. The shape of every item should be
``(num_samples, in_channels)``.
sigmoid (bool): Whether to sigmoid the classification score.
post_process (bool): Whether to do post processing the
inference results. It will convert the output to a list.
Returns:
Tensor | list: The inference results.
- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
The input ``feats`` is a tuple of tensor, and each tensor is the
feature of a backbone stage. In ``MultiLabelLinearClsHead``, we just
obtain the feature of the last stage.
"""
x = self.pre_logits(x)
cls_score = self.fc(x)
# The MultiLabelLinearClsHead doesn't have other modules before the
# final classification head, just return the last stage feature.
return feats[-1]
if sigmoid:
pred = torch.sigmoid(cls_score) if cls_score is not None else None
else:
pred = cls_score
if post_process:
return self.post_process(pred)
else:
return pred
def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
"""The forward process."""
pre_logits = self.pre_logits(feats)
# The final classification head.
cls_score = self.fc(pre_logits)
return cls_score
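In a config file this head would typically sit in the head field of an image classifier. A hypothetical config fragment, where the backbone choice and the channel/class numbers are made up for illustration:

# Hypothetical config fragment; only the head section reflects this PR.
model = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet', depth=50, num_stages=4, out_indices=(3, )),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='MultiLabelLinearClsHead',
        num_classes=20,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', use_sigmoid=True),
        thr=0.5))                 # or topk=...; thr wins if both are given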

View File

@@ -1,6 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import random
from unittest import TestCase
import numpy as np
import torch
from mmengine import is_seq_of
@@ -11,6 +14,14 @@ from mmcls.utils import register_all_modules
register_all_modules()
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
class TestClsHead(TestCase):
DEFAULT_ARGS = dict(type='ClsHead')
@@ -305,113 +316,124 @@ class TestStackedLinearClsHead(TestCase):
self.assertEqual(outs.shape, (4, 5))
"""Temporarily disabled.
@pytest.mark.parametrize('feat', [torch.rand(4, 10), (torch.rand(4, 10), )])
def test_multilabel_head(feat):
head = MultiLabelClsHead()
fake_gt_label = torch.randint(0, 2, (4, 10))
class TestMultiLabelClsHead(TestCase):
DEFAULT_ARGS = dict(type='MultiLabelClsHead')
losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0
def test_pre_logits(self):
head = MODELS.build(self.DEFAULT_ARGS)
# test simple_test with post_process
pred = head.simple_test(feat)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(feat)
assert pred.shape == (4, 10)
# return the last item
feats = (torch.rand(4, 10), torch.rand(4, 10))
pre_logits = head.pre_logits(feats)
self.assertIs(pre_logits, feats[-1])
# test simple_test without post_process
pred = head.simple_test(feat, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(feat, sigmoid=False, post_process=False)
torch.testing.assert_allclose(pred, torch.sigmoid(logits))
def test_forward(self):
head = MODELS.build(self.DEFAULT_ARGS)
# test pre_logits
features = head.pre_logits(feat)
if isinstance(feat, tuple):
torch.testing.assert_allclose(features, feat[0])
else:
torch.testing.assert_allclose(features, feat)
# return the last item (same as pre_logits)
feats = (torch.rand(4, 10), torch.rand(4, 10))
outs = head(feats)
self.assertIs(outs, feats[-1])
def test_loss(self):
feats = (torch.rand(4, 10), )
data_samples = [ClsDataSample().set_gt_label([0, 3]) for _ in range(4)]
# Test with thr and topk both None
head = MODELS.build(self.DEFAULT_ARGS)
losses = head.loss(feats, data_samples)
self.assertEqual(head.thr, 0.5)
self.assertEqual(head.topk, None)
self.assertEqual(losses.keys(), {'loss'})
self.assertGreater(losses['loss'].item(), 0)
# Test with topk
cfg = copy.deepcopy(self.DEFAULT_ARGS)
cfg['topk'] = 2
head = MODELS.build(cfg)
losses = head.loss(feats, data_samples)
self.assertEqual(head.thr, None, cfg)
self.assertEqual(head.topk, 2)
self.assertEqual(losses.keys(), {'loss'})
self.assertGreater(losses['loss'].item(), 0)
# Test with thr
setup_seed(0)
cfg = copy.deepcopy(self.DEFAULT_ARGS)
cfg['thr'] = 0.1
head = MODELS.build(cfg)
thr_losses = head.loss(feats, data_samples)
self.assertEqual(head.thr, 0.1)
self.assertEqual(head.topk, None)
self.assertEqual(thr_losses.keys(), {'loss'})
self.assertGreater(thr_losses['loss'].item(), 0)
# Test with both thr and topk set
setup_seed(0)
cfg = copy.deepcopy(self.DEFAULT_ARGS)
cfg['thr'] = 0.1
cfg['topk'] = 2
head = MODELS.build(cfg)
thr_topk_losses = head.loss(feats, data_samples)
self.assertEqual(head.thr, 0.1)
self.assertEqual(head.topk, 2)
self.assertEqual(thr_topk_losses.keys(), {'loss'})
self.assertGreater(thr_topk_losses['loss'].item(), 0)
# Test with gt_label carrying a score
data_samples = [
ClsDataSample().set_gt_score(torch.rand((10, ))) for _ in range(4)
]
head = MODELS.build(self.DEFAULT_ARGS)
losses = head.loss(feats, data_samples)
self.assertEqual(head.thr, 0.5)
self.assertEqual(head.topk, None)
self.assertEqual(losses.keys(), {'loss'})
self.assertGreater(losses['loss'].item(), 0)
def test_predict(self):
feats = (torch.rand(4, 10), )
data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(4)]
head = MODELS.build(self.DEFAULT_ARGS)
# without data_samples
predictions = head.predict(feats)
self.assertTrue(is_seq_of(predictions, ClsDataSample))
for pred in predictions:
self.assertIn('label', pred.pred_label)
self.assertIn('score', pred.pred_label)
# with data_samples
predictions = head.predict(feats, data_samples)
self.assertTrue(is_seq_of(predictions, ClsDataSample))
for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred)
self.assertIn('label', pred.pred_label)
self.assertIn('score', pred.pred_label)
# Test with topk
cfg = copy.deepcopy(self.DEFAULT_ARGS)
cfg['topk'] = 2
head = MODELS.build(cfg)
predictions = head.predict(feats, data_samples)
self.assertEqual(head.thr, None)
self.assertTrue(is_seq_of(predictions, ClsDataSample))
for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred)
self.assertIn('label', pred.pred_label)
self.assertIn('score', pred.pred_label)
@pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )])
def test_multilabel_linear_head(feat):
head = MultiLabelLinearClsHead(10, 5)
fake_gt_label = torch.randint(0, 2, (4, 10))
class TestMultiLabelLinearClsHead(TestMultiLabelClsHead):
DEFAULT_ARGS = dict(
type='MultiLabelLinearClsHead', num_classes=10, in_channels=10)
head.init_weights()
losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0
def test_forward(self):
head = MODELS.build(self.DEFAULT_ARGS)
self.assertTrue(hasattr(head, 'fc'))
self.assertTrue(isinstance(head.fc, torch.nn.Linear))
# test simple_test with post_process
pred = head.simple_test(feat)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(feat)
assert pred.shape == (4, 10)
# test simple_test without post_process
pred = head.simple_test(feat, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(feat, sigmoid=False, post_process=False)
torch.testing.assert_allclose(pred, torch.sigmoid(logits))
# test pre_logits
features = head.pre_logits(feat)
if isinstance(feat, tuple):
torch.testing.assert_allclose(features, feat[0])
else:
torch.testing.assert_allclose(features, feat)
@pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )])
def test_stacked_linear_cls_head(feat):
# test assertion
with pytest.raises(AssertionError):
StackedLinearClsHead(num_classes=3, in_channels=5, mid_channels=10)
with pytest.raises(AssertionError):
StackedLinearClsHead(num_classes=-1, in_channels=5, mid_channels=[10])
fake_gt_label = torch.randint(0, 2, (4, )) # B, num_classes
# test forward with default setting
head = StackedLinearClsHead(
num_classes=10, in_channels=5, mid_channels=[20])
head.init_weights()
losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0
# test simple_test with post_process
pred = head.simple_test(feat)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(feat)
assert pred.shape == (4, 10)
# test simple_test without post_process
pred = head.simple_test(feat, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(feat, softmax=False, post_process=False)
torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
# test pre_logits
features = head.pre_logits(feat)
assert features.shape == (4, 20)
# test forward with full function
head = StackedLinearClsHead(
num_classes=3,
in_channels=5,
mid_channels=[8, 10],
dropout_rate=0.2,
norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='HSwish'))
head.init_weights()
losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0
"""
# return the classification score of the fc layer
feats = (torch.rand(4, 10), torch.rand(4, 10))
outs = head(feats)
self.assertEqual(outs.shape, (4, 10))