[Feature] Support resizemix. (#676)

* add resizemix

* skip torch.__version__ < 1.7.0

* Update mmcls/models/utils/augment/resizemix.py

Co-authored-by: Ma Zerun <mzr1996@163.com>

* Update mmcls/models/utils/augment/resizemix.py

Co-authored-by: Ma Zerun <mzr1996@163.com>

* resize -> F.interpolate

* fix docs

* fix test

* add Copyright

* add argument interpolation

Co-authored-by: Ma Zerun <mzr1996@163.com>
pull/703/head
takuoko 2022-03-07 13:11:20 +09:00 committed by GitHub
parent 2037260ea6
commit c1534f9126
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 101 additions and 1 deletions

View File

@ -3,5 +3,7 @@ from .augments import Augments
from .cutmix import BatchCutMixLayer
from .identity import Identity
from .mixup import BatchMixupLayer
from .resizemix import BatchResizeMixLayer
__all__ = ('Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer')
__all__ = ('Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer',
'BatchResizeMixLayer')

View File

@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn.functional as F
from mmcls.models.utils.augment.builder import AUGMENT
from .cutmix import BatchCutMixLayer
from .utils import one_hot_encoding
@AUGMENT.register_module(name='BatchResizeMix')
class BatchResizeMixLayer(BatchCutMixLayer):
r"""ResizeMix Random Paste layer for batch ResizeMix.
The ResizeMix will resize an image to a small patch and paste it on another
image. More details can be found in `ResizeMix: Mixing Data with Preserved
Object Information and True Labels <https://arxiv.org/abs/2012.11101>`_
Args:
alpha (float): Parameters for Beta distribution. Positive(>0)
num_classes (int): The number of classes.
lam_min(float): The minimum value of lam. Defaults to 0.1.
lam_max(float): The maximum value of lam. Defaults to 0.8.
interpolation (str): algorithm used for upsampling:
'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'.
Default to 'bilinear'.
prob (float): mix probability. It should be in range [0, 1].
Default to 1.0.
cutmix_minmax (List[float], optional): cutmix min/max image ratio.
(as percent of image size). When cutmix_minmax is not None, we
generate cutmix bounding-box using cutmix_minmax instead of alpha
correct_lam (bool): Whether to apply lambda correction when cutmix bbox
clipped by image borders. Default to True
**kwargs: Any other parameters accpeted by :class:`BatchCutMixLayer`.
Note:
The :math:`\lambda` (``lam``) is the mixing ratio. It's a random
variable which follows :math:`Beta(\alpha, \alpha)` and is mapped
to the range [``lam_min``, ``lam_max``].
.. math::
\lambda = \frac{Beta(\alpha, \alpha)}
{\lambda_{max} - \lambda_{min}} + \lambda_{min}
And the resize ratio of source images is calculated by :math:`\lambda`:
.. math::
\text{ratio} = \sqrt{1-lam}
"""
def __init__(self,
alpha,
num_classes,
lam_min: float = 0.1,
lam_max: float = 0.8,
interpolation='bilinear',
prob=1.0,
cutmix_minmax=None,
correct_lam=True,
**kwargs):
super(BatchResizeMixLayer, self).__init__(
alpha=alpha,
num_classes=num_classes,
prob=prob,
cutmix_minmax=cutmix_minmax,
correct_lam=correct_lam,
**kwargs)
self.lam_min = lam_min
self.lam_max = lam_max
self.interpolation = interpolation
def cutmix(self, img, gt_label):
one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
lam = np.random.beta(self.alpha, self.alpha)
lam = lam * (self.lam_max - self.lam_min) + self.lam_min
batch_size = img.size(0)
index = torch.randperm(batch_size)
(bby1, bby2, bbx1,
bbx2), lam = self.cutmix_bbox_and_lam(img.shape, lam)
img[:, :, bby1:bby2, bbx1:bbx2] = F.interpolate(
img[index],
size=(bby2 - bby1, bbx2 - bbx1),
mode=self.interpolation)
mixed_gt_label = lam * one_hot_gt_label + (
1 - lam) * one_hot_gt_label[index, :]
return img, mixed_gt_label

View File

@ -8,6 +8,7 @@ augment_cfgs = [
dict(type='BatchCutMix', alpha=1., prob=1.),
dict(type='BatchMixup', alpha=1., prob=1.),
dict(type='Identity', prob=1.),
dict(type='BatchResizeMix', alpha=1., prob=1.)
]
@ -29,6 +30,14 @@ def test_augments():
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
assert mixed_labels.shape == torch.Size((4, 10))
# Test resizemix
augments_cfg = dict(
type='BatchResizeMix', alpha=1., num_classes=10, prob=1.)
augs = Augments(augments_cfg)
mixed_imgs, mixed_labels = augs(imgs, labels)
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
assert mixed_labels.shape == torch.Size((4, 10))
# Test cutmixup
augments_cfg = [
dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),