Add multi_label heads
parent
4fcd7ee072
commit
e9342d9e4c
|
@ -407,7 +407,7 @@ def _average_precision(pred: torch.Tensor,
|
|||
total_pos = tps[-1].item() # the last of tensor may change later
|
||||
|
||||
# Calculate cumulative tp&fp(pred_poss) case numbers
|
||||
pred_pos_nums = torch.arange(1, len(sorted_target) + 1)
|
||||
pred_pos_nums = torch.arange(1, len(sorted_target) + 1).to(pred.device)
|
||||
pred_pos_nums[pred_pos_nums < eps] = eps
|
||||
|
||||
tps[torch.logical_not(pos_inds)] = 0
|
||||
|
|
|
@ -3,7 +3,7 @@ from .cls_head import ClsHead
|
|||
from .conformer_head import ConformerHead
|
||||
from .deit_head import DeiTClsHead
|
||||
from .linear_head import LinearClsHead
|
||||
from .multi_label_head import MultiLabelClsHead
|
||||
from .multi_label_cls_head import MultiLabelClsHead
|
||||
from .multi_label_linear_head import MultiLabelLinearClsHead
|
||||
from .stacked_head import StackedLinearClsHead
|
||||
from .vision_transformer_head import VisionTransformerClsHead
|
||||
|
|
|
@ -0,0 +1,157 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from mmengine.data import LabelData
|
||||
|
||||
from mmcls.core import ClsDataSample
|
||||
from mmcls.registry import MODELS
|
||||
from .base_head import BaseHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MultiLabelClsHead(BaseHead):
|
||||
"""Classification head for multilabel task.
|
||||
|
||||
Args:
|
||||
loss (dict): Config of classification loss. Defaults to
|
||||
dict(type='CrossEntropyLoss', use_sigmoid=True).
|
||||
thr (float, optional): Predictions with scores under the thresholds
|
||||
are considered as negative. Defaults to None.
|
||||
topk (int, optional): Predictions with the k-th highest scores are
|
||||
considered as positive. Defaults to None.
|
||||
init_cfg (dict, optional): The extra init config of layers.
|
||||
Defaults to None.
|
||||
|
||||
Notes:
|
||||
If both ``thr`` and ``topk`` are set, use ``thr` to determine
|
||||
positive predictions. If neither is set, use ``thr=0.5`` as
|
||||
default.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True),
|
||||
thr: Optional[float] = None,
|
||||
topk: Optional[int] = None,
|
||||
init_cfg: Optional[dict] = None):
|
||||
super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
self.loss_module = MODELS.build(loss)
|
||||
|
||||
if thr is None and topk is None:
|
||||
thr = 0.5
|
||||
|
||||
self.thr = thr
|
||||
self.topk = topk
|
||||
|
||||
def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
|
||||
"""The process before the final classification head.
|
||||
|
||||
The input ``feats`` is a tuple of tensor, and each tensor is the
|
||||
feature of a backbone stage. In ``MultiLabelClsHead``, we just obtain
|
||||
the feature of the last stage.
|
||||
"""
|
||||
# The MultiLabelClsHead doesn't have other module, just return after
|
||||
# unpacking.
|
||||
return feats[-1]
|
||||
|
||||
def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
|
||||
"""The forward process."""
|
||||
pre_logits = self.pre_logits(feats)
|
||||
# The MultiLabelClsHead doesn't have the final classification head,
|
||||
# just return the unpacked inputs.
|
||||
return pre_logits
|
||||
|
||||
def loss(self, feats: Tuple[torch.Tensor],
|
||||
data_samples: List[ClsDataSample], **kwargs) -> dict:
|
||||
"""Calculate losses from the classification score.
|
||||
|
||||
Args:
|
||||
feats (tuple[Tensor]): The features extracted from the backbone.
|
||||
Multiple stage inputs are acceptable but only the last stage
|
||||
will be used to classify. The shape of every item should be
|
||||
``(num_samples, num_classes)``.
|
||||
data_samples (List[ClsDataSample]): The annotation data of
|
||||
every samples.
|
||||
**kwargs: Other keyword arguments to forward the loss module.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
# The part can be traced by torch.fx
|
||||
cls_score = self(feats)
|
||||
|
||||
# The part can not be traced by torch.fx
|
||||
losses = self._get_loss(cls_score, data_samples, **kwargs)
|
||||
return losses
|
||||
|
||||
def _get_loss(self, cls_score: torch.Tensor,
|
||||
data_samples: List[ClsDataSample], **kwargs):
|
||||
"""Unpack data samples and compute loss."""
|
||||
num_classes = cls_score.size()[-1]
|
||||
# Unpack data samples and pack targets
|
||||
if 'score' in data_samples[0].gt_label:
|
||||
target = torch.stack(
|
||||
[i.gt_label.score.float() for i in data_samples])
|
||||
else:
|
||||
target = torch.stack([
|
||||
LabelData.label_to_onehot(i.gt_label.label,
|
||||
num_classes).float()
|
||||
for i in data_samples
|
||||
])
|
||||
|
||||
# compute loss
|
||||
losses = dict()
|
||||
loss = self.loss_module(
|
||||
cls_score, target, avg_factor=cls_score.size(0), **kwargs)
|
||||
losses['loss'] = loss
|
||||
|
||||
return losses
|
||||
|
||||
def predict(
|
||||
self,
|
||||
feats: Tuple[torch.Tensor],
|
||||
data_samples: List[ClsDataSample] = None) -> List[ClsDataSample]:
|
||||
"""Inference without augmentation.
|
||||
|
||||
Args:
|
||||
feats (tuple[Tensor]): The features extracted from the backbone.
|
||||
Multiple stage inputs are acceptable but only the last stage
|
||||
will be used to classify. The shape of every item should be
|
||||
``(num_samples, num_classes)``.
|
||||
data_samples (List[ClsDataSample], optional): The annotation
|
||||
data of every samples. If not None, set ``pred_label`` of
|
||||
the input data samples. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[ClsDataSample]: A list of data samples which contains the
|
||||
predicted results.
|
||||
"""
|
||||
# The part can be traced by torch.fx
|
||||
cls_score = self(feats)
|
||||
|
||||
# The part can not be traced by torch.fx
|
||||
predictions = self._get_predictions(cls_score, data_samples)
|
||||
return predictions
|
||||
|
||||
def _get_predictions(self, cls_score: torch.Tensor,
|
||||
data_samples: List[ClsDataSample]):
|
||||
"""Post-process the output of head.
|
||||
|
||||
Including softmax and set ``pred_label`` of data samples.
|
||||
"""
|
||||
pred_scores = torch.sigmoid(cls_score)
|
||||
|
||||
if data_samples is None:
|
||||
data_samples = [ClsDataSample() for _ in range(cls_score.size(0))]
|
||||
|
||||
for data_sample, score in zip(data_samples, pred_scores):
|
||||
if self.thr is not None:
|
||||
# a label is predicted positive if larger than thr
|
||||
label = torch.where(score >= self.thr)[0]
|
||||
else:
|
||||
# top-k labels will be predicted positive for any example
|
||||
_, label = score.topk(self.topk)
|
||||
data_sample.set_pred_score(score).set_pred_label(label)
|
||||
|
||||
return data_samples
|
|
@ -1,99 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from ..utils import is_tracing
|
||||
from .base_head import BaseHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MultiLabelClsHead(BaseHead):
|
||||
"""Classification head for multilabel task.
|
||||
|
||||
Args:
|
||||
loss (dict): Config of classification loss.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
loss=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
reduction='mean',
|
||||
loss_weight=1.0),
|
||||
init_cfg=None):
|
||||
super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
assert isinstance(loss, dict)
|
||||
|
||||
self.compute_loss = MODELS.build(loss)
|
||||
|
||||
def loss(self, cls_score, gt_label):
|
||||
gt_label = gt_label.type_as(cls_score)
|
||||
num_samples = len(cls_score)
|
||||
losses = dict()
|
||||
|
||||
# map difficult examples to positive ones
|
||||
_gt_label = torch.abs(gt_label)
|
||||
# compute loss
|
||||
loss = self.compute_loss(cls_score, _gt_label, avg_factor=num_samples)
|
||||
losses['loss'] = loss
|
||||
return losses
|
||||
|
||||
def forward_train(self, cls_score, gt_label, **kwargs):
|
||||
if isinstance(cls_score, tuple):
|
||||
cls_score = cls_score[-1]
|
||||
gt_label = gt_label.type_as(cls_score)
|
||||
losses = self.loss(cls_score, gt_label, **kwargs)
|
||||
return losses
|
||||
|
||||
def pre_logits(self, x):
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
|
||||
from mmcls.utils import get_root_logger
|
||||
logger = get_root_logger()
|
||||
logger.warning(
|
||||
'The input of MultiLabelClsHead should be already logits. '
|
||||
'Please modify the backbone if you want to get pre-logits feature.'
|
||||
)
|
||||
return x
|
||||
|
||||
def simple_test(self, x, sigmoid=True, post_process=True):
|
||||
"""Inference without augmentation.
|
||||
|
||||
Args:
|
||||
cls_score (tuple[Tensor]): The input classification score logits.
|
||||
Multi-stage inputs are acceptable but only the last stage will
|
||||
be used to classify. The shape of every item should be
|
||||
``(num_samples, num_classes)``.
|
||||
sigmoid (bool): Whether to sigmoid the classification score.
|
||||
post_process (bool): Whether to do post processing the
|
||||
inference results. It will convert the output to a list.
|
||||
|
||||
Returns:
|
||||
Tensor | list: The inference results.
|
||||
|
||||
- If no post processing, the output is a tensor with shape
|
||||
``(num_samples, num_classes)``.
|
||||
- If post processing, the output is a multi-dimentional list of
|
||||
float and the dimensions are ``(num_samples, num_classes)``.
|
||||
"""
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
|
||||
if sigmoid:
|
||||
pred = torch.sigmoid(x) if x is not None else None
|
||||
else:
|
||||
pred = x
|
||||
|
||||
if post_process:
|
||||
return self.post_process(pred)
|
||||
else:
|
||||
return pred
|
||||
|
||||
def post_process(self, pred):
|
||||
on_trace = is_tracing()
|
||||
if torch.onnx.is_in_onnx_export() or on_trace:
|
||||
return pred
|
||||
pred = list(pred.detach().cpu().numpy())
|
||||
return pred
|
|
@ -1,9 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from .multi_label_head import MultiLabelClsHead
|
||||
from .multi_label_cls_head import MultiLabelClsHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
@ -11,75 +13,54 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
|
|||
"""Linear classification head for multilabel task.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of categories.
|
||||
in_channels (int): Number of channels in the input feature map.
|
||||
loss (dict): Config of classification loss.
|
||||
init_cfg (dict | optional): The extra init config of layers.
|
||||
loss (dict): Config of classification loss. Defaults to
|
||||
dict(type='CrossEntropyLoss', use_sigmoid=True).
|
||||
thr (float, optional): Predictions with scores under the thresholds
|
||||
are considered as negative. Defaults to None.
|
||||
topk (int, optional): Predictions with the k-th highest scores are
|
||||
considered as positive. Defaults to None.
|
||||
init_cfg (dict, optional): The extra init config of layers.
|
||||
Defaults to use dict(type='Normal', layer='Linear', std=0.01).
|
||||
|
||||
Notes:
|
||||
If both ``thr`` and ``topk`` are set, use ``thr` to determine
|
||||
positive predictions. If neither is set, use ``thr=0.5`` as
|
||||
default.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
in_channels,
|
||||
loss=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
reduction='mean',
|
||||
loss_weight=1.0),
|
||||
init_cfg=dict(type='Normal', layer='Linear', std=0.01)):
|
||||
num_classes: int,
|
||||
in_channels: int,
|
||||
loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True),
|
||||
thr: Optional[float] = None,
|
||||
topk: Optional[int] = None,
|
||||
init_cfg: Optional[dict] = dict(
|
||||
type='Normal', layer='Linear', std=0.01)):
|
||||
super(MultiLabelLinearClsHead, self).__init__(
|
||||
loss=loss, init_cfg=init_cfg)
|
||||
loss=loss, thr=thr, topk=topk, init_cfg=init_cfg)
|
||||
|
||||
if num_classes <= 0:
|
||||
raise ValueError(
|
||||
f'num_classes={num_classes} must be a positive integer')
|
||||
assert num_classes > 0, f'num_classes ({num_classes}) must be a ' \
|
||||
'positive integer.'
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.num_classes = num_classes
|
||||
|
||||
self.fc = nn.Linear(self.in_channels, self.num_classes)
|
||||
|
||||
def pre_logits(self, x):
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
return x
|
||||
def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
|
||||
"""The process before the final classification head.
|
||||
|
||||
def forward_train(self, x, gt_label, **kwargs):
|
||||
x = self.pre_logits(x)
|
||||
gt_label = gt_label.type_as(x)
|
||||
cls_score = self.fc(x)
|
||||
losses = self.loss(cls_score, gt_label, **kwargs)
|
||||
return losses
|
||||
|
||||
def simple_test(self, x, sigmoid=True, post_process=True):
|
||||
"""Inference without augmentation.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): The input features.
|
||||
Multi-stage inputs are acceptable but only the last stage will
|
||||
be used to classify. The shape of every item should be
|
||||
``(num_samples, in_channels)``.
|
||||
sigmoid (bool): Whether to sigmoid the classification score.
|
||||
post_process (bool): Whether to do post processing the
|
||||
inference results. It will convert the output to a list.
|
||||
|
||||
Returns:
|
||||
Tensor | list: The inference results.
|
||||
|
||||
- If no post processing, the output is a tensor with shape
|
||||
``(num_samples, num_classes)``.
|
||||
- If post processing, the output is a multi-dimentional list of
|
||||
float and the dimensions are ``(num_samples, num_classes)``.
|
||||
The input ``feats`` is a tuple of tensor, and each tensor is the
|
||||
feature of a backbone stage. In ``MultiLabelLinearClsHead``, we just
|
||||
obtain the feature of the last stage.
|
||||
"""
|
||||
x = self.pre_logits(x)
|
||||
cls_score = self.fc(x)
|
||||
# The obtain the MultiLabelLinearClsHead doesn't have other module,
|
||||
# just return after unpacking.
|
||||
return feats[-1]
|
||||
|
||||
if sigmoid:
|
||||
pred = torch.sigmoid(cls_score) if cls_score is not None else None
|
||||
else:
|
||||
pred = cls_score
|
||||
|
||||
if post_process:
|
||||
return self.post_process(pred)
|
||||
else:
|
||||
return pred
|
||||
def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
|
||||
"""The forward process."""
|
||||
pre_logits = self.pre_logits(feats)
|
||||
# The final classification head.
|
||||
cls_score = self.fc(pre_logits)
|
||||
return cls_score
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import random
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import is_seq_of
|
||||
|
||||
|
@ -11,6 +14,14 @@ from mmcls.utils import register_all_modules
|
|||
register_all_modules()
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
class TestClsHead(TestCase):
|
||||
DEFAULT_ARGS = dict(type='ClsHead')
|
||||
|
||||
|
@ -305,113 +316,124 @@ class TestStackedLinearClsHead(TestCase):
|
|||
self.assertEqual(outs.shape, (4, 5))
|
||||
|
||||
|
||||
"""Temporarily disabled.
|
||||
@pytest.mark.parametrize('feat', [torch.rand(4, 10), (torch.rand(4, 10), )])
|
||||
def test_multilabel_head(feat):
|
||||
head = MultiLabelClsHead()
|
||||
fake_gt_label = torch.randint(0, 2, (4, 10))
|
||||
class TestMultiLabelClsHead(TestCase):
|
||||
DEFAULT_ARGS = dict(type='MultiLabelClsHead')
|
||||
|
||||
losses = head.forward_train(feat, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
def test_pre_logits(self):
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
||||
# test simple_test with post_process
|
||||
pred = head.simple_test(feat)
|
||||
assert isinstance(pred, list) and len(pred) == 4
|
||||
with patch('torch.onnx.is_in_onnx_export', return_value=True):
|
||||
pred = head.simple_test(feat)
|
||||
assert pred.shape == (4, 10)
|
||||
# return the last item
|
||||
feats = (torch.rand(4, 10), torch.rand(4, 10))
|
||||
pre_logits = head.pre_logits(feats)
|
||||
self.assertIs(pre_logits, feats[-1])
|
||||
|
||||
# test simple_test without post_process
|
||||
pred = head.simple_test(feat, post_process=False)
|
||||
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
|
||||
logits = head.simple_test(feat, sigmoid=False, post_process=False)
|
||||
torch.testing.assert_allclose(pred, torch.sigmoid(logits))
|
||||
def test_forward(self):
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
||||
# test pre_logits
|
||||
features = head.pre_logits(feat)
|
||||
if isinstance(feat, tuple):
|
||||
torch.testing.assert_allclose(features, feat[0])
|
||||
else:
|
||||
torch.testing.assert_allclose(features, feat)
|
||||
# return the last item (same as pre_logits)
|
||||
feats = (torch.rand(4, 10), torch.rand(4, 10))
|
||||
outs = head(feats)
|
||||
self.assertIs(outs, feats[-1])
|
||||
|
||||
def test_loss(self):
|
||||
feats = (torch.rand(4, 10), )
|
||||
data_samples = [ClsDataSample().set_gt_label([0, 3]) for _ in range(4)]
|
||||
|
||||
# Test with thr and topk are all None
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
losses = head.loss(feats, data_samples)
|
||||
self.assertEqual(head.thr, 0.5)
|
||||
self.assertEqual(head.topk, None)
|
||||
self.assertEqual(losses.keys(), {'loss'})
|
||||
self.assertGreater(losses['loss'].item(), 0)
|
||||
|
||||
# Test with topk
|
||||
cfg = copy.deepcopy(self.DEFAULT_ARGS)
|
||||
cfg['topk'] = 2
|
||||
head = MODELS.build(cfg)
|
||||
losses = head.loss(feats, data_samples)
|
||||
self.assertEqual(head.thr, None, cfg)
|
||||
self.assertEqual(head.topk, 2)
|
||||
self.assertEqual(losses.keys(), {'loss'})
|
||||
self.assertGreater(losses['loss'].item(), 0)
|
||||
|
||||
# Test with thr
|
||||
setup_seed(0)
|
||||
cfg = copy.deepcopy(self.DEFAULT_ARGS)
|
||||
cfg['thr'] = 0.1
|
||||
head = MODELS.build(cfg)
|
||||
thr_losses = head.loss(feats, data_samples)
|
||||
self.assertEqual(head.thr, 0.1)
|
||||
self.assertEqual(head.topk, None)
|
||||
self.assertEqual(thr_losses.keys(), {'loss'})
|
||||
self.assertGreater(thr_losses['loss'].item(), 0)
|
||||
|
||||
# Test with thr and topk are all not None
|
||||
setup_seed(0)
|
||||
cfg = copy.deepcopy(self.DEFAULT_ARGS)
|
||||
cfg['thr'] = 0.1
|
||||
cfg['topk'] = 2
|
||||
head = MODELS.build(cfg)
|
||||
thr_topk_losses = head.loss(feats, data_samples)
|
||||
self.assertEqual(head.thr, 0.1)
|
||||
self.assertEqual(head.topk, 2)
|
||||
self.assertEqual(thr_topk_losses.keys(), {'loss'})
|
||||
self.assertGreater(thr_topk_losses['loss'].item(), 0)
|
||||
|
||||
# Test with gt_lable with score
|
||||
data_samples = [
|
||||
ClsDataSample().set_gt_score(torch.rand((10, ))) for _ in range(4)
|
||||
]
|
||||
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
losses = head.loss(feats, data_samples)
|
||||
self.assertEqual(head.thr, 0.5)
|
||||
self.assertEqual(head.topk, None)
|
||||
self.assertEqual(losses.keys(), {'loss'})
|
||||
self.assertGreater(losses['loss'].item(), 0)
|
||||
|
||||
def test_predict(self):
|
||||
feats = (torch.rand(4, 10), )
|
||||
data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(4)]
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
||||
# with without data_samples
|
||||
predictions = head.predict(feats)
|
||||
self.assertTrue(is_seq_of(predictions, ClsDataSample))
|
||||
for pred in predictions:
|
||||
self.assertIn('label', pred.pred_label)
|
||||
self.assertIn('score', pred.pred_label)
|
||||
|
||||
# with with data_samples
|
||||
predictions = head.predict(feats, data_samples)
|
||||
self.assertTrue(is_seq_of(predictions, ClsDataSample))
|
||||
for sample, pred in zip(data_samples, predictions):
|
||||
self.assertIs(sample, pred)
|
||||
self.assertIn('label', pred.pred_label)
|
||||
self.assertIn('score', pred.pred_label)
|
||||
|
||||
# Test with topk
|
||||
cfg = copy.deepcopy(self.DEFAULT_ARGS)
|
||||
cfg['topk'] = 2
|
||||
head = MODELS.build(cfg)
|
||||
predictions = head.predict(feats, data_samples)
|
||||
self.assertEqual(head.thr, None)
|
||||
self.assertTrue(is_seq_of(predictions, ClsDataSample))
|
||||
for sample, pred in zip(data_samples, predictions):
|
||||
self.assertIs(sample, pred)
|
||||
self.assertIn('label', pred.pred_label)
|
||||
self.assertIn('score', pred.pred_label)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )])
|
||||
def test_multilabel_linear_head(feat):
|
||||
head = MultiLabelLinearClsHead(10, 5)
|
||||
fake_gt_label = torch.randint(0, 2, (4, 10))
|
||||
class TestMultiLabelLinearClsHead(TestMultiLabelClsHead):
|
||||
DEFAULT_ARGS = dict(
|
||||
type='MultiLabelLinearClsHead', num_classes=10, in_channels=10)
|
||||
|
||||
head.init_weights()
|
||||
losses = head.forward_train(feat, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
def test_forward(self):
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
self.assertTrue(hasattr(head, 'fc'))
|
||||
self.assertTrue(isinstance(head.fc, torch.nn.Linear))
|
||||
|
||||
# test simple_test with post_process
|
||||
pred = head.simple_test(feat)
|
||||
assert isinstance(pred, list) and len(pred) == 4
|
||||
with patch('torch.onnx.is_in_onnx_export', return_value=True):
|
||||
pred = head.simple_test(feat)
|
||||
assert pred.shape == (4, 10)
|
||||
|
||||
# test simple_test without post_process
|
||||
pred = head.simple_test(feat, post_process=False)
|
||||
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
|
||||
logits = head.simple_test(feat, sigmoid=False, post_process=False)
|
||||
torch.testing.assert_allclose(pred, torch.sigmoid(logits))
|
||||
|
||||
# test pre_logits
|
||||
features = head.pre_logits(feat)
|
||||
if isinstance(feat, tuple):
|
||||
torch.testing.assert_allclose(features, feat[0])
|
||||
else:
|
||||
torch.testing.assert_allclose(features, feat)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )])
|
||||
def test_stacked_linear_cls_head(feat):
|
||||
# test assertion
|
||||
with pytest.raises(AssertionError):
|
||||
StackedLinearClsHead(num_classes=3, in_channels=5, mid_channels=10)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
StackedLinearClsHead(num_classes=-1, in_channels=5, mid_channels=[10])
|
||||
|
||||
fake_gt_label = torch.randint(0, 2, (4, )) # B, num_classes
|
||||
|
||||
# test forward with default setting
|
||||
head = StackedLinearClsHead(
|
||||
num_classes=10, in_channels=5, mid_channels=[20])
|
||||
head.init_weights()
|
||||
|
||||
losses = head.forward_train(feat, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
# test simple_test with post_process
|
||||
pred = head.simple_test(feat)
|
||||
assert isinstance(pred, list) and len(pred) == 4
|
||||
with patch('torch.onnx.is_in_onnx_export', return_value=True):
|
||||
pred = head.simple_test(feat)
|
||||
assert pred.shape == (4, 10)
|
||||
|
||||
# test simple_test without post_process
|
||||
pred = head.simple_test(feat, post_process=False)
|
||||
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
|
||||
logits = head.simple_test(feat, softmax=False, post_process=False)
|
||||
torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
|
||||
|
||||
# test pre_logits
|
||||
features = head.pre_logits(feat)
|
||||
assert features.shape == (4, 20)
|
||||
|
||||
# test forward with full function
|
||||
head = StackedLinearClsHead(
|
||||
num_classes=3,
|
||||
in_channels=5,
|
||||
mid_channels=[8, 10],
|
||||
dropout_rate=0.2,
|
||||
norm_cfg=dict(type='BN1d'),
|
||||
act_cfg=dict(type='HSwish'))
|
||||
head.init_weights()
|
||||
|
||||
losses = head.forward_train(feat, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
"""
|
||||
# return the last item (same as pre_logits)
|
||||
feats = (torch.rand(4, 10), torch.rand(4, 10))
|
||||
head(feats)
|
||||
|
|
Loading…
Reference in New Issue