# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import Optional, Sequence

import torch
import torch.nn.functional as F
from mmengine.model import BaseDataPreprocessor, stack_batch

from mmcls.registry import MODELS
from .batch_augments import RandomBatchAugment


@MODELS.register_module()
class ClsDataPreprocessor(BaseDataPreprocessor):
    """Image pre-processor for classification tasks.

    Compared with the :class:`mmengine.model.ImgDataPreprocessor`,

    1. It won't do normalization if ``mean`` is not specified.
    2. It does normalization and color space conversion after stacking batch.
    3. It supports batch augmentations like mixup and cutmix.

    It provides the data pre-processing as follows:

    - Collate and move data to the target device.
    - Pad inputs to the maximum size of the current batch with the defined
      ``pad_value``. The padded size is divisible by the defined
      ``pad_size_divisor``.
    - Stack inputs to batch_inputs.
    - Convert inputs from BGR to RGB if the shape of the input is (3, H, W).
    - Normalize images with the defined std and mean.
    - Do batch augmentations like Mixup and Cutmix during training.

    Args:
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of the padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        to_rgb (bool): Whether to convert images from BGR to RGB.
            Defaults to False.
        batch_augments (dict, optional): The batch augmentation settings,
            including "augments" and "probs". For more details, see
            :class:`mmcls.models.RandomBatchAugment`.
    """

    def __init__(self,
                 mean: Optional[Sequence[Number]] = None,
                 std: Optional[Sequence[Number]] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Number = 0,
                 to_rgb: bool = False,
                 batch_augments: Optional[dict] = None):
        super().__init__()
        self.pad_size_divisor = pad_size_divisor
        self.pad_value = pad_value
        self.to_rgb = to_rgb

        if mean is not None:
            assert std is not None, 'To enable the normalization in ' \
                'preprocessing, please specify both `mean` and `std`.'
            # Enable the normalization in preprocessing.
            self._enable_normalize = True
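            # The buffers are shaped (C, 1, 1) so they broadcast over
            # (N, C, H, W) batches during normalization. The final ``False``
            # marks them as non-persistent: they follow the module across
            # devices but are not stored in checkpoints.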
            self.register_buffer('mean',
                                 torch.tensor(mean).view(-1, 1, 1), False)
            self.register_buffer('std',
                                 torch.tensor(std).view(-1, 1, 1), False)
        else:
            self._enable_normalize = False

        if batch_augments is not None:
            self.batch_augments = RandomBatchAugment(**batch_augments)
        else:
            self.batch_augments = None

    def forward(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding, BGR-to-RGB conversion and batch
        augmentation based on ``BaseDataPreprocessor``.

        Args:
            data (dict): Data sampled from the dataloader.
            training (bool): Whether to enable training-time augmentation.

        Returns:
            dict: Data in the same format as the model input.
        """
        data = self.cast_data(data)
        inputs = data['inputs']

        if isinstance(inputs, torch.Tensor):
            # This branch is taken if `default_collate` is used as the
            # collate_fn in the dataloader, i.e. `inputs` is already a
            # batched (N, C, H, W) tensor.

            # ------ To RGB ------
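            # Flipping the channel dimension reorders B, G, R to R, G, B.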
            if self.to_rgb and inputs.size(1) == 3:
                inputs = inputs.flip(1)

            # -- Normalization ---
            inputs = inputs.float()
            if self._enable_normalize:
                inputs = (inputs - self.mean) / self.std

            # ------ Padding -----
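            # Pad the bottom and right sides up to the nearest multiple of
            # ``pad_size_divisor``. For example, with ``pad_size_divisor=32``,
            # a 100x100 input becomes 128x128 (pad_h = pad_w = 28).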
            if self.pad_size_divisor > 1:
                h, w = inputs.shape[-2:]

                target_h = math.ceil(
                    h / self.pad_size_divisor) * self.pad_size_divisor
                target_w = math.ceil(
                    w / self.pad_size_divisor) * self.pad_size_divisor
                pad_h = target_h - h
                pad_w = target_w - w
                inputs = F.pad(inputs, (0, pad_w, 0, pad_h), 'constant',
                               self.pad_value)
        else:
            # This branch is taken if `pseudo_collate` is used as the
            # collate_fn in the dataloader, i.e. `inputs` is a list of
            # per-sample tensors that may have different shapes.

            processed_inputs = []
            for input_ in inputs:
                # ------ To RGB ------
                if self.to_rgb and input_.size(0) == 3:
                    input_ = input_.flip(0)

                # -- Normalization ---
                input_ = input_.float()
                if self._enable_normalize:
                    input_ = (input_ - self.mean) / self.std

                processed_inputs.append(input_)
            # Combine padding and stack
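            # ``stack_batch`` pads every image (bottom/right) to the maximum
            # height and width in the list, rounded up to a multiple of
            # ``pad_size_divisor``, and stacks the results into one
            # (N, C, H, W) tensor.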
            inputs = stack_batch(processed_inputs, self.pad_size_divisor,
                                 self.pad_value)

        # ----- Batch Aug ----
        if training and self.batch_augments is not None:
            data_samples = data['data_samples']
            inputs, data_samples = self.batch_augments(inputs, data_samples)
            data['data_samples'] = data_samples

        data['inputs'] = inputs

        return data
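

# A minimal usage sketch; the tensor sizes and normalization values below are
# illustrative assumptions:
#
#     preprocessor = ClsDataPreprocessor(
#         mean=[127.5, 127.5, 127.5],
#         std=[127.5, 127.5, 127.5],
#         to_rgb=True)
#     data = {'inputs': torch.randint(0, 256, (2, 3, 224, 224),
#                                     dtype=torch.uint8)}
#     out = preprocessor(data)
#     # out['inputs'] is a float tensor of shape (2, 3, 224, 224), flipped to
#     # RGB and normalized to roughly [-1, 1].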