mmclassification/mmcls/models/utils/batch_augments/resizemix.py

90 lines
3.6 KiB
Python
Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn.functional as F
from mmcls.registry import BATCH_AUGMENTS
from .cutmix import CutMix
@BATCH_AUGMENTS.register_module()
class ResizeMix(CutMix):
    r"""ResizeMix Random Paste layer for a batch of data.

    The ResizeMix will resize an image to a small patch and paste it on another
    image. It's proposed in `ResizeMix: Mixing Data with Preserved Object
    Information and True Labels <https://arxiv.org/abs/2012.11101>`_

    Args:
        alpha (float): Parameters for Beta distribution to generate the
            mixing ratio. It should be a positive number. More details
            can be found in :class:`Mixup`.
        num_classes (int, optional): The number of classes. If not specified,
            will try to get it from data samples during training.
            Defaults to None.
        lam_min (float): The minimum value of lam. Defaults to 0.1.
        lam_max (float): The maximum value of lam. Defaults to 0.8.
        interpolation (str): Algorithm used to resize the pasted patch. It
            is forwarded as ``mode`` to
            :func:`torch.nn.functional.interpolate`; for 4-D image batches
            the applicable values are 'nearest' | 'bilinear' | 'bicubic' |
            'area'. Defaults to 'bilinear'.
        cutmix_minmax (List[float], optional): The min/max area ratio of the
            patches. If not None, the bounding-box of patches is uniform
            sampled within this ratio range, and the ``alpha`` will be ignored.
            Otherwise, the bounding-box is generated according to the
            ``alpha``. Defaults to None.
        correct_lam (bool): Whether to apply lambda correction when cutmix bbox
            clipped by image borders. Defaults to True.

    Note:
        The :math:`\lambda` (``lam``) is the mixing ratio. It's a random
        variable which follows :math:`Beta(\alpha, \alpha)` and is mapped
        to the range [``lam_min``, ``lam_max``].

        .. math::
            \lambda = Beta(\alpha, \alpha) \times
                (\lambda_{max} - \lambda_{min}) + \lambda_{min}

        And the resize ratio of source images is calculated by :math:`\lambda`:

        .. math::
            \text{ratio} = \sqrt{1-\lambda}
    """

    # ``F.interpolate`` only accepts an explicit ``align_corners`` for these
    # modes and raises ``ValueError`` if it is given with any other mode
    # (e.g. 'nearest' or 'area').
    _ALIGN_CORNERS_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self,
                 alpha: float,
                 num_classes=None,
                 lam_min: float = 0.1,
                 lam_max: float = 0.8,
                 interpolation: str = 'bilinear',
                 cutmix_minmax=None,
                 correct_lam: bool = True):
        super().__init__(
            alpha=alpha,
            num_classes=num_classes,
            cutmix_minmax=cutmix_minmax,
            correct_lam=correct_lam)
        self.lam_min = lam_min
        self.lam_max = lam_max
        self.interpolation = interpolation

    def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor):
        """Mix the batch inputs and batch one-hot format ground truth.

        Args:
            batch_inputs (Tensor): A batch of images. The last two dimensions
                are treated as height and width. It is modified in place.
            batch_score (Tensor): The one-hot format (or soft) labels of the
                batch, indexed by sample along the first dimension.

        Returns:
            Tuple[Tensor, Tensor]: ``batch_inputs`` (the same tensor, with the
            resized patches pasted in) and the mixed scores.
        """
        # Sample the mixing ratio and map it into [lam_min, lam_max].
        lam = np.random.beta(self.alpha, self.alpha)
        lam = lam * (self.lam_max - self.lam_min) + self.lam_min

        img_shape = batch_inputs.shape[-2:]
        batch_size = batch_inputs.size(0)
        # Each sample receives a patch made from a randomly chosen sample.
        index = torch.randperm(batch_size)

        # The inherited helper returns the paste box together with the
        # ``lam`` value to use for label mixing.
        (y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam)

        patch_h, patch_w = y2 - y1, x2 - x1
        # A zero-area box leaves nothing to paste, and ``F.interpolate``
        # rejects zero output sizes, so only resize for a non-empty box.
        if patch_h > 0 and patch_w > 0:
            interp_kwargs = {}
            if self.interpolation in self._ALIGN_CORNERS_MODES:
                interp_kwargs['align_corners'] = False
            # ``batch_inputs[index]`` is a copy (advanced indexing), so the
            # in-place assignment below does not read from what it overwrites.
            batch_inputs[:, :, y1:y2, x1:x2] = F.interpolate(
                batch_inputs[index],
                size=(patch_h, patch_w),
                mode=self.interpolation,
                **interp_kwargs)

        mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :]
        return batch_inputs, mixed_score