242 lines
9.7 KiB
Python
242 lines
9.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from mmcls.registry import MODELS
|
|
from mmcls.structures import ClsDataSample
|
|
from .base import BaseClassifier
|
|
|
|
|
|
@MODELS.register_module()
class ImageClassifier(BaseClassifier):
    """Image classifiers for supervised classification task.

    Args:
        backbone (dict): The backbone module. See
            :mod:`mmcls.models.backbones`. An already built ``nn.Module``
            is used as-is.
        neck (dict, optional): The neck module to process features from
            backbone. See :mod:`mmcls.models.necks`. An already built
            ``nn.Module`` is used as-is. Defaults to None.
        head (dict, optional): The head module to do prediction and calculate
            loss from processed features. See :mod:`mmcls.models.heads`.
            An already built ``nn.Module`` is used as-is.
            Notice that if the head is not set, almost all methods cannot be
            used except :meth:`extract_feat`. Defaults to None.
        pretrained (str, optional): The pretrained checkpoint path, support
            local path and remote path. If given, it replaces ``init_cfg``
            with a ``Pretrained`` initialization config. Defaults to None.
        train_cfg (dict, optional): The training setting. The acceptable
            fields are:

            - augments (List[dict]): The batch augmentation methods to use.
              More details can be found in :mod:`mmcls.models.utils.augment`.

            Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing input
            data. If None or no specified type, it will use
            "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for
            more details. The passed dict is never modified. Defaults to None.
        init_cfg (dict, optional): the config to control the initialization.
            Ignored when ``pretrained`` is given. Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: Optional[dict] = None,
                 head: Optional[dict] = None,
                 pretrained: Optional[str] = None,
                 train_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[dict] = None):
        if pretrained is not None:
            # `pretrained` is a shortcut that takes precedence over any
            # user supplied `init_cfg`.
            init_cfg = dict(type='Pretrained', checkpoint=pretrained)

        # Work on a shallow copy so that the caller's config is never
        # mutated: otherwise the injected `type` / `batch_augments` keys
        # would leak into every later model built from the same dict.
        data_preprocessor = (
            {} if data_preprocessor is None else dict(data_preprocessor))
        # The build process is in MMEngine, so we need to add scope here.
        data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')

        if train_cfg is not None and 'augments' in train_cfg:
            # Set batch augmentations by `train_cfg`
            data_preprocessor['batch_augments'] = train_cfg

        super().__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        # Each component may be given either as a config (built through the
        # registry) or as an already constructed module (used directly).
        if not isinstance(backbone, nn.Module):
            backbone = MODELS.build(backbone)
        if neck is not None and not isinstance(neck, nn.Module):
            neck = MODELS.build(neck)
        if head is not None and not isinstance(head, nn.Module):
            head = MODELS.build(head)

        self.backbone = backbone
        self.neck = neck
        self.head = head

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[ClsDataSample]] = None,
                mode: str = 'tensor'):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`ClsDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[ClsDataSample], optional): The annotation
                data of every samples. It's required if ``mode="loss"``.
                Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of
              :obj:`mmcls.structures.ClsDataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            RuntimeError: If ``mode`` is not one of the three modes above.
        """
        if mode == 'tensor':
            feats = self.extract_feat(inputs)
            # Without a head, the raw (neck-stage) features are the output.
            return self.head(feats) if self.with_head else feats
        elif mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            return self.predict(inputs, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}".')

    def extract_feat(self, inputs: torch.Tensor, stage: str = 'neck'):
        """Extract features from the input tensor with shape (N, C, ...).

        Args:
            inputs (Tensor): A batch of inputs. The shape of it should be
                ``(num_samples, num_channels, *img_shape)``.
            stage (str): Which stage to output the feature. Choose from:

                - "backbone": The output of backbone network. Returns a tuple
                  including multiple stages features.
                - "neck": The output of neck module. Returns a tuple including
                  multiple stages features.
                - "pre_logits": The feature before the final classification
                  linear layer. Usually returns a tensor.

                Defaults to "neck".

        Returns:
            tuple | Tensor: The output of specified stage.
            The output depends on detailed implementation. In general, the
            output of backbone and neck is a tuple and the output of
            pre_logits is a tensor.

        Examples:
            1. Backbone output

            >>> import torch
            >>> from mmengine import Config
            >>> from mmcls.models import build_classifier
            >>>
            >>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model
            >>> cfg.backbone.out_indices = (0, 1, 2, 3)  # Output multi-scale feature maps
            >>> model = build_classifier(cfg)
            >>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone')
            >>> for out in outs:
            ...     print(out.shape)
            torch.Size([1, 64, 56, 56])
            torch.Size([1, 128, 28, 28])
            torch.Size([1, 256, 14, 14])
            torch.Size([1, 512, 7, 7])

            2. Neck output

            >>> import torch
            >>> from mmengine import Config
            >>> from mmcls.models import build_classifier
            >>>
            >>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model
            >>> cfg.backbone.out_indices = (0, 1, 2, 3)  # Output multi-scale feature maps
            >>> model = build_classifier(cfg)
            >>>
            >>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck')
            >>> for out in outs:
            ...     print(out.shape)
            torch.Size([1, 64])
            torch.Size([1, 128])
            torch.Size([1, 256])
            torch.Size([1, 512])

            3. Pre-logits output (without the final linear classifier head)

            >>> import torch
            >>> from mmengine import Config
            >>> from mmcls.models import build_classifier
            >>>
            >>> cfg = Config.fromfile('configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py').model
            >>> model = build_classifier(cfg)
            >>>
            >>> out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits')
            >>> print(out.shape)  # The hidden dims in head is 3072
            torch.Size([1, 3072])
        """  # noqa: E501
        assert stage in ['backbone', 'neck', 'pre_logits'], \
            (f'Invalid output stage "{stage}", please choose from "backbone", '
             '"neck" and "pre_logits"')

        x = self.backbone(inputs)

        if stage == 'backbone':
            return x

        # The neck is optional; without one the backbone output is passed on
        # unchanged as the "neck" stage feature.
        if self.with_neck:
            x = self.neck(x)
        if stage == 'neck':
            return x

        # Only the "pre_logits" stage remains, which needs head support.
        assert self.with_head and hasattr(self.head, 'pre_logits'), \
            "No head or the head doesn't implement `pre_logits` method."
        return self.head.pre_logits(x)

    def loss(self, inputs: torch.Tensor,
             data_samples: List[ClsDataSample]) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        The features are extracted with the default stage of
        :meth:`extract_feat` and handed to the ``loss`` method of the head,
        so a head is required.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[ClsDataSample]): The annotation data of
                every samples.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        feats = self.extract_feat(inputs)
        return self.head.loss(feats, data_samples)

    def predict(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[ClsDataSample]] = None,
                **kwargs) -> List[ClsDataSample]:
        """Predict results from a batch of inputs.

        The features are extracted with the default stage of
        :meth:`extract_feat` and handed to the ``predict`` method of the
        head, so a head is required.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[ClsDataSample], optional): The annotation
                data of every samples. Defaults to None.
            **kwargs: Other keyword arguments accepted by the ``predict``
                method of :attr:`head`.

        Returns:
            List[ClsDataSample]: Whatever the head's ``predict`` method
            returns for the extracted features.
        """
        feats = self.extract_feat(inputs)
        return self.head.predict(feats, data_samples, **kwargs)