67 lines
2.5 KiB
Python
67 lines
2.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import Dict, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from mmcls.registry import MODELS
|
|
from .multi_label_cls_head import MultiLabelClsHead
|
|
|
|
|
|
@MODELS.register_module()
class MultiLabelLinearClsHead(MultiLabelClsHead):
    """Multi-label classification head built from a single linear layer.

    The head takes the feature of the last backbone stage and projects it
    with one ``nn.Linear`` layer to obtain a score for every class.

    Args:
        num_classes (int): Number of classes, used as the output size of the
            linear layer. Must be greater than 0.
        in_channels (int): Number of channels of the input feature, used as
            the input size of the linear layer.
        loss (dict): Config of classification loss. Defaults to
            dict(type='CrossEntropyLoss', use_sigmoid=True).
        thr (float, optional): Predictions with scores under the thresholds
            are considered as negative. Defaults to None.
        topk (int, optional): Predictions with the k-th highest scores are
            considered as positive. Defaults to None.
        init_cfg (dict, optional): The extra init config of layers.
            Defaults to use dict(type='Normal', layer='Linear', std=0.01).

    Notes:
        If both ``thr`` and ``topk`` are set, use ``thr`` to determine
        positive predictions. If neither is set, use ``thr=0.5`` as
        default.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True),
                 thr: Optional[float] = None,
                 topk: Optional[int] = None,
                 init_cfg: Optional[dict] = dict(
                     type='Normal', layer='Linear', std=0.01)):
        # Loss, threshold/top-k settings and weight-init config are all
        # forwarded untouched to the parent multi-label head.
        super().__init__(loss=loss, thr=thr, topk=topk, init_cfg=init_cfg)

        assert num_classes > 0, (
            f'num_classes ({num_classes}) must be a positive integer.')

        self.num_classes = num_classes
        self.in_channels = in_channels

        # The only learnable module of this head: features -> class scores.
        self.fc = nn.Linear(in_channels, num_classes)

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """Select the feature that is fed to the final classifier.

        ``feats`` is a tuple of tensors, one per backbone stage. This head
        only consumes the feature of the last stage.
        """
        # This head has no extra module ahead of the classifier, so the
        # last-stage feature is returned as is.
        last_stage_feat = feats[-1]
        return last_stage_feat

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """Compute the classification scores from the backbone features."""
        # Pick the last-stage feature, then apply the linear classifier.
        return self.fc(self.pre_logits(feats))
|