# Copyright (c) OpenMMLab. All rights reserved.
from numbers import Number
from typing import List, Optional, Sequence, Tuple

import torch
from mmengine.model import BaseDataPreprocessor, stack_batch

from mmcls.registry import MODELS
from .batch_augments import RandomBatchAugment


@MODELS.register_module()
class ClsDataPreprocessor(BaseDataPreprocessor):
    """Image pre-processor for classification tasks.

    Comparing with the :class:`mmengine.model.ImgDataPreprocessor`,

    1. It won't do normalization if ``mean`` is not specified.
    2. It does normalization and color space conversion after stacking batch.
    3. It supports batch augmentations like mixup and cutmix.

    It provides the data pre-processing as follows

    - Collate and move data to the target device.
    - Pad inputs to the maximum size of current batch with defined
      ``pad_value``. The padding size can be divisible by a defined
      ``pad_size_divisor``
    - Stack inputs to batch_inputs.
    - Convert inputs from bgr to rgb if the shape of input is (3, H, W).
    - Cast the batch to ``torch.float32``.
    - Normalize image with defined std and mean.
    - Do batch augmentations like Mixup and Cutmix during training.

    Args:
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Normalization is only enabled when ``mean`` is given.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Required when ``mean`` is given; ignored when
            ``mean`` is None. Defaults to None.
        pad_size_divisor (int): The size of padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        to_rgb (bool): whether to convert image from BGR to RGB.
            Defaults to False.
        batch_augments (dict, optional): The batch augmentations settings,
            including "augments" and "probs". For more details, see
            :class:`mmcls.models.RandomBatchAugment`.

    Raises:
        AssertionError: If ``mean`` is specified but ``std`` is not.
    """

    def __init__(self,
                 mean: Optional[Sequence[Number]] = None,
                 std: Optional[Sequence[Number]] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Number = 0,
                 to_rgb: bool = False,
                 batch_augments: Optional[List[dict]] = None):
        super().__init__()
        self.pad_size_divisor = pad_size_divisor
        self.pad_value = pad_value
        self.to_rgb = to_rgb

        if mean is not None:
            # Raised explicitly (instead of a bare ``assert``) so the check
            # still runs under ``python -O``; the exception type is kept as
            # AssertionError for backward compatibility.
            if std is None:
                raise AssertionError(
                    'To enable the normalization in '
                    'preprocessing, please specify both `mean` and `std`.')
            # Enable the normalization in preprocessing.
            self._enable_normalize = True
            # Reshape to (C, 1, 1) so the statistics broadcast over the
            # trailing two dimensions of the stacked batch. The trailing
            # ``False`` is passed through unchanged from the original call
            # (presumably ``persistent=False`` of
            # ``torch.nn.Module.register_buffer`` - confirm against the
            # base class).
            self.register_buffer('mean',
                                 torch.tensor(mean).view(-1, 1, 1), False)
            self.register_buffer('std',
                                 torch.tensor(std).view(-1, 1, 1), False)
        else:
            self._enable_normalize = False

        if batch_augments is not None:
            self.batch_augments = RandomBatchAugment(batch_augments)
        else:
            self.batch_augments = None

    def forward(self,
                data: Sequence[dict],
                training: bool = False) -> Tuple[torch.Tensor, list]:
        """Perform normalization, padding, bgr2rgb conversion and batch
        augmentation based on ``BaseDataPreprocessor``.

        Args:
            data (Sequence[dict]): data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            Tuple[torch.Tensor, list]: Data in the same format as the model
            input. The tensor is cast to ``torch.float32`` before the
            optional normalization.
        """
        inputs, batch_data_samples = self.collate_data(data)

        # --- Pad and stack --
        batch_inputs = stack_batch(inputs, self.pad_size_divisor,
                                   self.pad_value)

        # ------ To RGB ------
        # Only swap channels when dimension 1 of the batch has size 3.
        if self.to_rgb and batch_inputs.size(1) == 3:
            batch_inputs = batch_inputs[:, [2, 1, 0], ...]

        # -- Normalization ---
        # Always cast to float32 first, so the output dtype does not depend
        # on type promotion between the (possibly integer or float64) images
        # and the ``mean`` / ``std`` buffers, and matches the
        # non-normalized path.
        batch_inputs = batch_inputs.to(torch.float32)
        if self._enable_normalize:
            batch_inputs = (batch_inputs - self.mean) / self.std

        # ----- Batch Aug ----
        if training and self.batch_augments is not None:
            batch_inputs, batch_data_samples = self.batch_augments(
                batch_inputs, batch_data_samples)

        return batch_inputs, batch_data_samples