[Feature] Support resizemix. (#676)

* add resizemix

* skip torch.__version__ < 1.7.0

* Update mmcls/models/utils/augment/resizemix.py

Co-authored-by: Ma Zerun <mzr1996@163.com>

* Update mmcls/models/utils/augment/resizemix.py

Co-authored-by: Ma Zerun <mzr1996@163.com>

* resize -> F.interpolate

* fix docs

* fix test

* add Copyright

* add argument interpolation

Co-authored-by: Ma Zerun <mzr1996@163.com>
pull/703/head
takuoko 2022-03-07 13:11:20 +09:00 committed by GitHub
parent 2037260ea6
commit c1534f9126
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 101 additions and 1 deletions

View File

@ -3,5 +3,7 @@ from .augments import Augments
from .cutmix import BatchCutMixLayer from .cutmix import BatchCutMixLayer
from .identity import Identity from .identity import Identity
from .mixup import BatchMixupLayer from .mixup import BatchMixupLayer
from .resizemix import BatchResizeMixLayer
__all__ = ('Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer') __all__ = ('Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer',
'BatchResizeMixLayer')

View File

@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn.functional as F
from mmcls.models.utils.augment.builder import AUGMENT
from .cutmix import BatchCutMixLayer
from .utils import one_hot_encoding
@AUGMENT.register_module(name='BatchResizeMix')
class BatchResizeMixLayer(BatchCutMixLayer):
r"""ResizeMix Random Paste layer for batch ResizeMix.
The ResizeMix will resize an image to a small patch and paste it on another
image. More details can be found in `ResizeMix: Mixing Data with Preserved
Object Information and True Labels <https://arxiv.org/abs/2012.11101>`_
Args:
alpha (float): Parameters for Beta distribution. Positive(>0)
num_classes (int): The number of classes.
lam_min(float): The minimum value of lam. Defaults to 0.1.
lam_max(float): The maximum value of lam. Defaults to 0.8.
interpolation (str): algorithm used for upsampling:
'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'.
Default to 'bilinear'.
prob (float): mix probability. It should be in range [0, 1].
Default to 1.0.
cutmix_minmax (List[float], optional): cutmix min/max image ratio.
(as percent of image size). When cutmix_minmax is not None, we
generate cutmix bounding-box using cutmix_minmax instead of alpha
correct_lam (bool): Whether to apply lambda correction when cutmix bbox
clipped by image borders. Default to True
**kwargs: Any other parameters accpeted by :class:`BatchCutMixLayer`.
Note:
The :math:`\lambda` (``lam``) is the mixing ratio. It's a random
variable which follows :math:`Beta(\alpha, \alpha)` and is mapped
to the range [``lam_min``, ``lam_max``].
.. math::
\lambda = \frac{Beta(\alpha, \alpha)}
{\lambda_{max} - \lambda_{min}} + \lambda_{min}
And the resize ratio of source images is calculated by :math:`\lambda`:
.. math::
\text{ratio} = \sqrt{1-lam}
"""
def __init__(self,
alpha,
num_classes,
lam_min: float = 0.1,
lam_max: float = 0.8,
interpolation='bilinear',
prob=1.0,
cutmix_minmax=None,
correct_lam=True,
**kwargs):
super(BatchResizeMixLayer, self).__init__(
alpha=alpha,
num_classes=num_classes,
prob=prob,
cutmix_minmax=cutmix_minmax,
correct_lam=correct_lam,
**kwargs)
self.lam_min = lam_min
self.lam_max = lam_max
self.interpolation = interpolation
def cutmix(self, img, gt_label):
one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
lam = np.random.beta(self.alpha, self.alpha)
lam = lam * (self.lam_max - self.lam_min) + self.lam_min
batch_size = img.size(0)
index = torch.randperm(batch_size)
(bby1, bby2, bbx1,
bbx2), lam = self.cutmix_bbox_and_lam(img.shape, lam)
img[:, :, bby1:bby2, bbx1:bbx2] = F.interpolate(
img[index],
size=(bby2 - bby1, bbx2 - bbx1),
mode=self.interpolation)
mixed_gt_label = lam * one_hot_gt_label + (
1 - lam) * one_hot_gt_label[index, :]
return img, mixed_gt_label

View File

@ -8,6 +8,7 @@ augment_cfgs = [
dict(type='BatchCutMix', alpha=1., prob=1.), dict(type='BatchCutMix', alpha=1., prob=1.),
dict(type='BatchMixup', alpha=1., prob=1.), dict(type='BatchMixup', alpha=1., prob=1.),
dict(type='Identity', prob=1.), dict(type='Identity', prob=1.),
dict(type='BatchResizeMix', alpha=1., prob=1.)
] ]
@ -29,6 +30,14 @@ def test_augments():
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32)) assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
assert mixed_labels.shape == torch.Size((4, 10)) assert mixed_labels.shape == torch.Size((4, 10))
# Test resizemix
augments_cfg = dict(
type='BatchResizeMix', alpha=1., num_classes=10, prob=1.)
augs = Augments(augments_cfg)
mixed_imgs, mixed_labels = augs(imgs, labels)
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
assert mixed_labels.shape == torch.Size((4, 10))
# Test cutmixup # Test cutmixup
augments_cfg = [ augments_cfg = [
dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5), dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),