[Feature] Support resizemix. (#676)
* add resizemix
* skip torch.__version__ < 1.7.0
* Update mmcls/models/utils/augment/resizemix.py
* resize -> F.interpolate
* fix docs
* fix test
* add Copyright
* add argument interpolation

Co-authored-by: Ma Zerun <mzr1996@163.com>
parent 2037260ea6
commit c1534f9126
mmcls/models/utils/augment/__init__.py

@@ -3,5 +3,7 @@ from .augments import Augments
 from .cutmix import BatchCutMixLayer
 from .identity import Identity
 from .mixup import BatchMixupLayer
+from .resizemix import BatchResizeMixLayer
 
-__all__ = ('Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer')
+__all__ = ('Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer',
+           'BatchResizeMixLayer')
mmcls/models/utils/augment/resizemix.py

@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from mmcls.models.utils.augment.builder import AUGMENT
+from .cutmix import BatchCutMixLayer
+from .utils import one_hot_encoding
+
+
+@AUGMENT.register_module(name='BatchResizeMix')
+class BatchResizeMixLayer(BatchCutMixLayer):
+    r"""ResizeMix Random Paste layer for a batch of data.
+
+    ResizeMix resizes an image to a small patch and pastes it onto another
+    image. More details can be found in `ResizeMix: Mixing Data with Preserved
+    Object Information and True Labels <https://arxiv.org/abs/2012.11101>`_.
+
+    Args:
+        alpha (float): Parameters for Beta distribution. Positive (> 0).
+        num_classes (int): The number of classes.
+        lam_min (float): The minimum value of lam. Defaults to 0.1.
+        lam_max (float): The maximum value of lam. Defaults to 0.8.
+        interpolation (str): Algorithm used for upsampling:
+            'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' |
+            'area'. Defaults to 'bilinear'.
+        prob (float): Mix probability. It should be in range [0, 1].
+            Defaults to 1.0.
+        cutmix_minmax (List[float], optional): Min/max image ratio of the
+            patch (as a percent of image size). When cutmix_minmax is not
+            None, generate the bbox using cutmix_minmax instead of alpha.
+        correct_lam (bool): Whether to apply lambda correction when the cutmix
+            bbox is clipped by image borders. Defaults to True.
+        **kwargs: Any other parameters accepted by :class:`BatchCutMixLayer`.
+
+    Note:
+        The :math:`\lambda` (``lam``) is the mixing ratio. It's a random
+        variable which follows :math:`Beta(\alpha, \alpha)` and is mapped
+        to the range [``lam_min``, ``lam_max``].
+
+        .. math::
+            \lambda = Beta(\alpha, \alpha)
+                \times (\lambda_{max} - \lambda_{min}) + \lambda_{min}
+
+        And the resize ratio of source images is calculated by :math:`\lambda`:
+
+        .. math::
+            \text{ratio} = \sqrt{1-\lambda}
+    """
+
+    def __init__(self,
+                 alpha,
+                 num_classes,
+                 lam_min: float = 0.1,
+                 lam_max: float = 0.8,
+                 interpolation='bilinear',
+                 prob=1.0,
+                 cutmix_minmax=None,
+                 correct_lam=True,
+                 **kwargs):
+        super(BatchResizeMixLayer, self).__init__(
+            alpha=alpha,
+            num_classes=num_classes,
+            prob=prob,
+            cutmix_minmax=cutmix_minmax,
+            correct_lam=correct_lam,
+            **kwargs)
+        self.lam_min = lam_min
+        self.lam_max = lam_max
+        self.interpolation = interpolation
+
+    def cutmix(self, img, gt_label):
+        one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
+
+        lam = np.random.beta(self.alpha, self.alpha)
+        lam = lam * (self.lam_max - self.lam_min) + self.lam_min
+        batch_size = img.size(0)
+        index = torch.randperm(batch_size)
+
+        (bby1, bby2, bbx1,
+         bbx2), lam = self.cutmix_bbox_and_lam(img.shape, lam)
+
+        img[:, :, bby1:bby2, bbx1:bbx2] = F.interpolate(
+            img[index],
+            size=(bby2 - bby1, bbx2 - bbx1),
+            mode=self.interpolation)
+        mixed_gt_label = lam * one_hot_gt_label + (
+            1 - lam) * one_hot_gt_label[index, :]
+        return img, mixed_gt_label
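As a quick illustration of what the new cutmix override does, here is a self-contained sketch that reproduces the lam mapping and the resize-and-paste step in plain PyTorch. The rand_bbox helper below is a simplified stand-in for the inherited cutmix_bbox_and_lam (it ignores cutmix_minmax and lambda correction), so treat it as an approximation of the layer's behaviour rather than the mmcls implementation.

import numpy as np
import torch
import torch.nn.functional as F


def rand_bbox(img_shape, lam):
    """Simplified stand-in for cutmix_bbox_and_lam: pick a random paste
    region covering a (1 - lam) fraction of the image area."""
    H, W = img_shape[-2:]
    ratio = np.sqrt(1. - lam)  # the resize ratio from the docstring
    cut_h, cut_w = int(H * ratio), int(W * ratio)
    cy = np.random.randint(0, H - cut_h + 1)
    cx = np.random.randint(0, W - cut_w + 1)
    return cy, cy + cut_h, cx, cx + cut_w


def resizemix(img, alpha=1.0, lam_min=0.1, lam_max=0.8):
    """Resize shuffled source images down and paste them onto the batch."""
    # Map Beta(alpha, alpha) into [lam_min, lam_max], as in the new layer.
    lam = np.random.beta(alpha, alpha) * (lam_max - lam_min) + lam_min
    index = torch.randperm(img.size(0))
    bby1, bby2, bbx1, bbx2 = rand_bbox(img.shape, lam)
    # Paste the resized source images; labels would then be mixed as
    # lam * y + (1 - lam) * y[index], exactly as in the diff above.
    img[:, :, bby1:bby2, bbx1:bbx2] = F.interpolate(
        img[index], size=(bby2 - bby1, bbx2 - bbx1), mode='bilinear',
        align_corners=False)
    return img, lam, index


imgs = torch.randn(4, 3, 32, 32)
mixed, lam, index = resizemix(imgs.clone())
assert mixed.shape == imgs.shape and 0.1 <= lam <= 0.8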
tests/test_models/test_utils/test_augment.py

@@ -8,6 +8,7 @@ augment_cfgs = [
     dict(type='BatchCutMix', alpha=1., prob=1.),
     dict(type='BatchMixup', alpha=1., prob=1.),
     dict(type='Identity', prob=1.),
+    dict(type='BatchResizeMix', alpha=1., prob=1.)
 ]
 
 
@@ -29,6 +30,14 @@ def test_augments():
     assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
     assert mixed_labels.shape == torch.Size((4, 10))
 
+    # Test resizemix
+    augments_cfg = dict(
+        type='BatchResizeMix', alpha=1., num_classes=10, prob=1.)
+    augs = Augments(augments_cfg)
+    mixed_imgs, mixed_labels = augs(imgs, labels)
+    assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
+    assert mixed_labels.shape == torch.Size((4, 10))
+
     # Test cutmixup
     augments_cfg = [
         dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
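Since the tests only exercise the layer through the Augments wrapper, here is a hedged sketch of how ResizeMix could be enabled in a full training config. The train_cfg=dict(augments=...) hook is the standard mmcls mechanism used by the existing mixup/cutmix configs; the backbone, head, and hyper-parameter values below are illustrative assumptions, not settings taken from this PR.

# Illustrative config snippet: a ResNet-50 classifier trained with
# ResizeMix. Soft cross-entropy is used because the mixed labels are
# one-hot distributions rather than class indices.
model = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet', depth=50, num_stages=4, out_indices=(3, )),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)),
    train_cfg=dict(
        augments=dict(
            type='BatchResizeMix',
            alpha=1.0,  # assumed value, tune per dataset
            num_classes=1000,
            lam_min=0.1,
            lam_max=0.8,
            prob=1.0)))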