152 lines
5.8 KiB
Python
152 lines
5.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from mmcls.evaluation.metrics import Accuracy
|
|
from mmcls.registry import MODELS
|
|
from mmcls.structures import ClsDataSample
|
|
from .base_head import BaseHead
|
|
|
|
|
|
@MODELS.register_module()
class ClsHead(BaseHead):
    """Classification head.

    This head has no learnable classification layer of its own: it takes
    the feature of the last backbone stage as the classification score,
    and builds the loss / prediction post-processing on top of it.

    Args:
        loss (dict): Config of classification loss. Defaults to
            ``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
        topk (int | Tuple[int]): Top-k accuracy. A single int is treated
            as a one-element tuple. Defaults to ``(1, )``.
        cal_acc (bool): Whether to calculate accuracy during training.
            If you use batch augmentations like Mixup and CutMix during
            training, it is pointless to calculate accuracy.
            Defaults to False.
        init_cfg (dict, optional): the config to control the initialization.
            Defaults to None.
    """

    def __init__(self,
                 loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0),
                 topk: Union[int, Tuple[int]] = (1, ),
                 cal_acc: bool = False,
                 init_cfg: Optional[dict] = None):
        super().__init__(init_cfg=init_cfg)

        # ``_get_loss`` iterates over ``self.topk`` (``zip(self.topk, acc)``),
        # so a bare int, which the signature explicitly allows, must be
        # wrapped into a tuple here; otherwise ``cal_acc=True`` would fail
        # with "'int' object is not iterable".
        if isinstance(topk, int):
            topk = (topk, )
        self.topk = topk
        self.loss_module = MODELS.build(loss)
        self.cal_acc = cal_acc

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The process before the final classification head.

        The input ``feats`` is a tuple of tensor, and each tensor is the
        feature of a backbone stage. In ``ClsHead``, we just obtain the
        feature of the last stage.

        Args:
            feats (tuple[Tensor]): The features of the backbone stages.

        Returns:
            Tensor: The last item of ``feats``.
        """
        # The ClsHead doesn't have other module, just return after unpacking.
        return feats[-1]

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The forward process.

        Args:
            feats (tuple[Tensor]): The features of the backbone stages.

        Returns:
            Tensor: The output of :meth:`pre_logits`, used directly as the
            classification score.
        """
        pre_logits = self.pre_logits(feats)
        # The ClsHead doesn't have the final classification head,
        # just return the unpacked inputs.
        return pre_logits

    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[ClsDataSample], **kwargs) -> dict:
        """Calculate losses from the classification score.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, num_classes)``.
            data_samples (List[ClsDataSample]): The annotation data of
                every samples.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # The part can be traced by torch.fx
        cls_score = self(feats)

        # The part can not be traced by torch.fx
        losses = self._get_loss(cls_score, data_samples, **kwargs)
        return losses

    def _get_loss(self, cls_score: torch.Tensor,
                  data_samples: List[ClsDataSample], **kwargs) -> dict:
        """Unpack data samples and compute loss.

        Args:
            cls_score (Tensor): The classification score returned by
                :meth:`forward`.
            data_samples (List[ClsDataSample]): The annotation data of
                every samples. Must be non-empty, since the first sample
                decides which target format is used.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict: Contains the key ``'loss'`` and, when ``self.cal_acc`` is
            enabled, one ``'accuracy_top-{k}'`` entry per ``k`` in
            ``self.topk``.
        """
        # Unpack data samples and pack targets
        if 'score' in data_samples[0].gt_label:
            # Batch augmentation may convert labels to one-hot format scores.
            target = torch.stack([i.gt_label.score for i in data_samples])
        else:
            target = torch.hstack([i.gt_label.label for i in data_samples])

        # compute loss
        losses = dict()
        loss = self.loss_module(
            cls_score, target, avg_factor=cls_score.size(0), **kwargs)
        losses['loss'] = loss

        # compute accuracy
        if self.cal_acc:
            # Accuracy is only computed against 1-D index labels; score
            # targets stacked above have more than one dimension.
            assert target.ndim == 1, 'If you enable batch augmentation ' \
                'like mixup during training, `cal_acc` is pointless.'
            acc = Accuracy.calculate(cls_score, target, topk=self.topk)
            losses.update(
                {f'accuracy_top-{k}': a
                 for k, a in zip(self.topk, acc)})

        return losses

    def predict(
        self,
        feats: Tuple[torch.Tensor],
        data_samples: Optional[List[ClsDataSample]] = None
    ) -> List[ClsDataSample]:
        """Inference without augmentation.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, num_classes)``.
            data_samples (List[ClsDataSample], optional): The annotation
                data of every samples. If not None, set ``pred_label`` of
                the input data samples. Defaults to None.

        Returns:
            List[ClsDataSample]: A list of data samples which contains the
            predicted results.
        """
        # The part can be traced by torch.fx
        cls_score = self(feats)

        # The part can not be traced by torch.fx
        predictions = self._get_predictions(cls_score, data_samples)
        return predictions

    def _get_predictions(
            self, cls_score: torch.Tensor,
            data_samples: Optional[List[ClsDataSample]]
    ) -> List[ClsDataSample]:
        """Post-process the output of head.

        Including softmax and set ``pred_label`` of data samples.

        Args:
            cls_score (Tensor): The classification score; softmax is taken
                over dimension 1.
            data_samples (List[ClsDataSample], optional): Data samples to
                be updated in place. If None, new ``ClsDataSample`` objects
                are created, one per row of ``cls_score``.

        Returns:
            List[ClsDataSample]: The data samples with ``pred_score`` and
            ``pred_label`` set.
        """
        pred_scores = F.softmax(cls_score, dim=1)
        # ``keepdim=True`` keeps one label element per row, so iterating
        # over dimension 0 below yields a 1-element tensor per sample.
        pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()

        if data_samples is not None:
            for data_sample, score, label in zip(data_samples, pred_scores,
                                                 pred_labels):
                data_sample.set_pred_score(score).set_pred_label(label)
        else:
            data_samples = []
            for score, label in zip(pred_scores, pred_labels):
                data_samples.append(ClsDataSample().set_pred_score(
                    score).set_pred_label(label))

        return data_samples
|