mmpretrain/mmcls/models/utils/data_preprocessor.py

113 lines
4.3 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from numbers import Number
from typing import List, Optional, Sequence, Tuple
import torch
from mmengine.model import BaseDataPreprocessor, stack_batch
from mmcls.registry import MODELS
from .batch_augments import RandomBatchAugment
@MODELS.register_module()
class ClsDataPreprocessor(BaseDataPreprocessor):
    """Image pre-processor for classification tasks.

    Compared with :class:`mmengine.model.ImgDataPreprocessor`,

    1. It won't do normalization if ``mean`` is not specified.
    2. It does normalization and color space conversion after stacking batch.
    3. It supports batch augmentations like mixup and cutmix.

    It provides the data pre-processing as follows:

    - Collate and move data to the target device.
    - Pad inputs to the maximum size of current batch with defined
      ``pad_value``. The padding size can be divisible by a defined
      ``pad_size_divisor``.
    - Stack inputs to batch_inputs.
    - Convert inputs from BGR to RGB if ``to_rgb`` is True and the stacked
      batch has 3 channels (dim 1).
    - Normalize image with defined std and mean. The batch is cast to a
      floating point dtype first, so the output dtype does not depend on
      the dtype of the incoming tensors.
    - Do batch augmentations like Mixup and Cutmix during training.

    Args:
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        to_rgb (bool): whether to convert image from BGR to RGB.
            Defaults to False.
        batch_augments (dict, optional): The batch augmentations settings,
            including "augments" and "probs". For more details, see
            :class:`mmcls.models.RandomBatchAugment`.
    """

    def __init__(self,
                 mean: Optional[Sequence[Number]] = None,
                 std: Optional[Sequence[Number]] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Number = 0,
                 to_rgb: bool = False,
                 batch_augments: Optional[List[dict]] = None):
        super().__init__()
        self.pad_size_divisor = pad_size_divisor
        self.pad_value = pad_value
        self.to_rgb = to_rgb

        if mean is not None:
            assert std is not None, 'To enable the normalization in ' \
                'preprocessing, please specify both `mean` and `std`.'
            # Enable the normalization in preprocessing.
            self._enable_normalize = True
            # Shape (C, 1, 1) so the statistics broadcast against the channel
            # dimension (dim 1) of the stacked batch. The trailing ``False``
            # is ``persistent``: the buffers move with the module (``.to()``,
            # ``.cuda()``) but are not written to the state dict.
            self.register_buffer('mean',
                                 torch.tensor(mean).view(-1, 1, 1), False)
            self.register_buffer('std',
                                 torch.tensor(std).view(-1, 1, 1), False)
        else:
            self._enable_normalize = False

        if batch_augments is not None:
            self.batch_augments = RandomBatchAugment(batch_augments)
        else:
            self.batch_augments = None

    def forward(self,
                data: Sequence[dict],
                training: bool = False) -> Tuple[torch.Tensor, list]:
        """Perform normalization, padding, bgr2rgb conversion and batch
        augmentation based on ``BaseDataPreprocessor``.

        Args:
            data (Sequence[dict]): data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            Tuple[torch.Tensor, list]: Data in the same format as the model
            input. The first element is always a floating point tensor.
        """
        inputs, batch_data_samples = self.collate_data(data)

        # --- Pad and stack --
        batch_inputs = stack_batch(inputs, self.pad_size_divisor,
                                   self.pad_value)

        # ------ To RGB ------
        if self.to_rgb and batch_inputs.size(1) == 3:
            batch_inputs = batch_inputs[:, [2, 1, 0], ...]

        # -- Normalization ---
        if self._enable_normalize:
            # Cast before normalizing so the result dtype is decided by the
            # statistics, not by the incoming tensor (e.g. float64 inputs
            # would otherwise yield a float64 batch). Follow the buffers'
            # floating dtype so a half-precision module still receives half
            # inputs. Fall back to float32 for integer statistics, matching
            # the un-normalized branch below. For uint8 inputs this is
            # value-identical to subtracting/dividing directly.
            stat_dtype = torch.promote_types(self.mean.dtype, self.std.dtype)
            if not stat_dtype.is_floating_point:
                stat_dtype = torch.float32
            batch_inputs = batch_inputs.to(stat_dtype)
            batch_inputs = (batch_inputs - self.mean) / self.std
        else:
            batch_inputs = batch_inputs.to(torch.float32)

        # ----- Batch Aug ----
        if training and self.batch_augments is not None:
            batch_inputs, batch_data_samples = self.batch_augments(
                batch_inputs, batch_data_samples)

        return batch_inputs, batch_data_samples