# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import Optional, Sequence

import torch
import torch.nn.functional as F
from mmengine.model import BaseDataPreprocessor, stack_batch

from mmcls.registry import MODELS
from .batch_augments import RandomBatchAugment


@MODELS.register_module()
class ClsDataPreprocessor(BaseDataPreprocessor):
    """Data pre-processor for image classification models.

    Differences from :class:`mmengine.model.ImgDataPreprocessor`:

    1. Normalization is skipped entirely when ``mean`` is not given.
    2. Batch augmentations such as mixup and cutmix are supported.

    The processing steps, in the order they are applied, are:

    - Cast/move the data with ``cast_data`` of the base class.
    - If ``to_rgb`` is set and the channel axis has size 3, reverse the
      channel order (BGR -> RGB).
    - Convert the images to float.
    - If normalization is enabled, subtract ``mean`` and divide by ``std``.
    - Pad so that height and width are multiples of ``pad_size_divisor``
      (filled with ``pad_value``), and stack list inputs into one batch
      through ``stack_batch``.
    - When ``training`` is true and ``batch_augments`` is configured,
      apply the batch augmentations to the inputs and the data samples.

    Note that padding happens after normalization, so ``pad_value`` is
    written into the already-normalized images.

    Args:
        mean (Sequence[Number], optional): Per-channel pixel mean.
            Normalization is disabled when this is None.
            Defaults to None.
        std (Sequence[Number], optional): Per-channel pixel standard
            deviation. Required whenever ``mean`` is given.
            Defaults to None.
        pad_size_divisor (int): The padded height and width are multiples
            of this value. Defaults to 1.
        pad_value (Number): The value used to fill padded pixels.
            Defaults to 0.
        to_rgb (bool): Whether to reverse the channel order of 3-channel
            images (BGR to RGB). Defaults to False.
        batch_augments (dict, optional): Keyword arguments used to build
            :class:`mmcls.models.RandomBatchAugment`. No batch
            augmentation is applied when None. Defaults to None.
    """

    def __init__(self,
                 mean: Sequence[Number] = None,
                 std: Sequence[Number] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Number = 0,
                 to_rgb: bool = False,
                 batch_augments: Optional[dict] = None):
        super().__init__()
        self.pad_size_divisor = pad_size_divisor
        self.pad_value = pad_value
        self.to_rgb = to_rgb

        # Normalization is switched on solely by the presence of ``mean``.
        normalize = mean is not None
        if normalize:
            assert std is not None, 'To enable the normalization in ' \
                'preprocessing, please specify both `mean` and `std`.'
            # Shape (C, 1, 1) broadcasts over both (C, H, W) images and
            # (N, C, H, W) batches. The buffers are non-persistent, i.e.
            # they are not saved in the state dict.
            for buffer_name, values in (('mean', mean), ('std', std)):
                self.register_buffer(
                    buffer_name,
                    torch.tensor(values).view(-1, 1, 1),
                    persistent=False)
        self._enable_normalize = normalize

        self.batch_augments = (
            RandomBatchAugment(**batch_augments)
            if batch_augments is not None else None)

    def _to_rgb_and_normalize(self, img: torch.Tensor,
                              channel_dim: int) -> torch.Tensor:
        """Reverse channels if requested, cast to float and normalize.

        Args:
            img (torch.Tensor): A single image or a batch of images.
            channel_dim (int): Index of the channel axis in ``img``.

        Returns:
            torch.Tensor: The converted float tensor.
        """
        # Only 3-channel inputs are flipped; others keep their order.
        if self.to_rgb and img.size(channel_dim) == 3:
            img = img.flip(channel_dim)
        img = img.float()
        if self._enable_normalize:
            img = (img - self.mean) / self.std
        return img

    def forward(self, data: dict, training: bool = False) -> dict:
        """Run color conversion, normalization, padding and batch augments.

        Args:
            data (dict): Data sampled from the dataloader. ``data['inputs']``
                is either one tensor or an iterable of tensors;
                ``data['data_samples']`` is read only when batch
                augmentation is applied.
            training (bool): Whether to enable training-time augmentation.
                Defaults to False.

        Returns:
            dict: The same dict, with ``'inputs'`` replaced by the processed
            batch and, if batch augmentation ran, ``'data_samples'``
            replaced by the augmented samples.
        """
        data = self.cast_data(data)
        raw_inputs = data['inputs']

        if isinstance(raw_inputs, torch.Tensor):
            # Already-batched input (e.g. produced by `default_collate`):
            # the channel axis is dim 1.
            batch = self._to_rgb_and_normalize(raw_inputs, channel_dim=1)

            divisor = self.pad_size_divisor
            if divisor > 1:
                height, width = batch.shape[-2:]
                # Round each spatial size up to the next multiple of
                # ``divisor`` and pad only on the bottom / right side.
                extra_h = math.ceil(height / divisor) * divisor - height
                extra_w = math.ceil(width / divisor) * divisor - width
                batch = F.pad(batch, (0, extra_w, 0, extra_h), 'constant',
                              self.pad_value)
        else:
            # Un-batched input (e.g. produced by `pseudo_collate`): process
            # every image separately (channel axis is dim 0), then pad and
            # stack them in one go.
            converted = [
                self._to_rgb_and_normalize(single, channel_dim=0)
                for single in raw_inputs
            ]
            batch = stack_batch(converted, self.pad_size_divisor,
                                self.pad_value)

        # Batch augmentations are a training-only step.
        if training and self.batch_augments is not None:
            batch, data['data_samples'] = self.batch_augments(
                batch, data['data_samples'])

        data['inputs'] = batch
        return data