109 lines
4.2 KiB
Python
109 lines
4.2 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from abc import ABCMeta, abstractmethod
|
|
from typing import List, Optional, Sequence
|
|
|
|
import torch
|
|
from mmengine.model import BaseModel
|
|
from mmengine.structures import BaseDataElement
|
|
|
|
|
|
class BaseClassifier(BaseModel, metaclass=ABCMeta):
    """Base class for classifiers.

    Args:
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing input
            data. If None, it will use "BaseDataPreprocessor" as type, see
            :class:`mmengine.model.BaseDataPreprocessor` for more details.
            Defaults to None.

    Attributes:
        init_cfg (dict): Initialization config dict.
        data_preprocessor (:obj:`mmengine.model.BaseDataPreprocessor`): An
            extra data pre-processing module, which processes data from
            dataloader to the format accepted by :meth:`forward`.
    """

    def __init__(self,
                 init_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None):
        # Both configs are forwarded unchanged; how they are consumed is
        # delegated entirely to the parent ``BaseModel``.
        super().__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

    @property
    def with_neck(self) -> bool:
        """Whether the classifier has a neck.

        Returns:
            bool: True if a ``neck`` attribute exists and is not None.
        """
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_head(self) -> bool:
        """Whether the classifier has a head.

        Returns:
            bool: True if a ``head`` attribute exists and is not None.
        """
        return hasattr(self, 'head') and self.head is not None

    @abstractmethod
    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[BaseDataElement]] = None,
                mode: str = 'tensor'):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`BaseDataElement`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor optimizer
        updating; those are done in :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape (N, C, ...)
                in general.
            data_samples (List[BaseDataElement], optional): The annotation
                data of each sample. It's required if ``mode="loss"``.
                Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of
              :obj:`mmengine.BaseDataElement`.
            - If ``mode="loss"``, return a dict of tensor.
        """
        pass

    def extract_feat(self, inputs: torch.Tensor):
        """Extract features from the input tensor with shape (N, C, ...).

        The sub-classes are recommended to implement this method to extract
        features from backbone and neck.

        Args:
            inputs (Tensor): A batch of inputs. The shape of it should be
                ``(num_samples, num_channels, *img_shape)``.

        Raises:
            NotImplementedError: Always, in this base class; sub-classes
                must override it.
        """
        raise NotImplementedError

    def extract_feats(self, multi_inputs: Sequence[torch.Tensor],
                      **kwargs) -> list:
        """Extract features from a sequence of input tensors.

        Args:
            multi_inputs (Sequence[torch.Tensor]): A sequence of input
                tensors. It can be used in augmented inference.
            **kwargs: Other keyword arguments accepted by :meth:`extract_feat`.
                They are forwarded unchanged to every call.

        Returns:
            list: Features of every input tensor, in the same order as
            ``multi_inputs``.
        """
        # Guard against being handed a single batch tensor by mistake; the
        # message points such callers to ``extract_feat`` instead.
        assert isinstance(multi_inputs, Sequence), \
            '`extract_feats` is used for a sequence of inputs tensor. If you '\
            'want to extract on single inputs tensor, use `extract_feat`.'
        return [self.extract_feat(inputs, **kwargs) for inputs in multi_inputs]
|