[Refactor] Refactor batch augmentations

pull/913/head
mzr1996 2022-06-01 15:29:30 +08:00
parent dd660ed99e
commit f3299b4ca2
44 changed files with 603 additions and 548 deletions

View File

@ -17,6 +17,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
]),
)
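
In the new-style config above, the per-augmentation prob=0.5 keys are simply dropped: the RandomBatchAugment wrapper introduced later in this commit chooses evenly among the listed augmentations when no probabilities are given, which keeps the old 50/50 split between Mixup and CutMix. A minimal sketch, mirroring the wrapper's own docstring (the numbers are illustrative), of how unequal probabilities can still be expressed:

from mmcls.models import RandomBatchAugment

batch_augment = RandomBatchAugment(
    augments=[
        dict(type='Mixup', alpha=0.8, num_classes=1000),
        dict(type='CutMix', alpha=1.0, num_classes=1000),
    ],
    probs=[0.5, 0.3])  # the remaining 0.2 is the chance to apply no augmentation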

View File

@ -17,6 +17,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
]),
)

View File

@ -21,6 +21,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
]),
)

View File

@ -17,6 +17,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
]),
)

View File

@ -18,6 +18,5 @@ model = dict(
num_classes=1000),
topk=(1, 5),
),
train_cfg=dict(
augments=dict(type='BatchMixup', alpha=0.2, num_classes=1000,
prob=1.)))
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
)

View File

@ -13,5 +13,5 @@ model = dict(
num_classes=10,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)),
train_cfg=dict(
augments=dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.)))
train_cfg=dict(augments=dict(type='Mixup', alpha=1., num_classes=10)),
)

View File

@ -13,6 +13,5 @@ model = dict(
num_classes=1000,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)),
train_cfg=dict(
augments=dict(type='BatchMixup', alpha=0.2, num_classes=1000,
prob=1.)))
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
)

View File

@ -17,6 +17,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
]),
)

View File

@ -18,6 +18,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
]),
)

View File

@ -17,6 +17,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
]),
)

View File

@ -36,6 +36,7 @@ model = dict(
topk=(1, 5),
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
]))
dict(type='Mixup', alpha=0.8, num_classes=num_classes),
dict(type='CutMix', alpha=1.0, num_classes=num_classes),
]),
)

View File

@ -36,6 +36,7 @@ model = dict(
topk=(1, 5),
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
]))
dict(type='Mixup', alpha=0.8, num_classes=num_classes),
dict(type='CutMix', alpha=1.0, num_classes=num_classes),
]),
)

View File

@ -36,6 +36,7 @@ model = dict(
topk=(1, 5),
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
]))
dict(type='Mixup', alpha=0.8, num_classes=num_classes),
dict(type='CutMix', alpha=1.0, num_classes=num_classes),
]),
)

View File

@ -25,6 +25,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
]),
)

View File

@ -25,6 +25,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
]),
)

View File

@ -16,6 +16,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
]),
)

View File

@ -16,6 +16,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
]),
)

View File

@ -27,9 +27,10 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
]),
)
# data settings
train_dataloader = dict(batch_size=256)

View File

@ -18,9 +18,10 @@ model = dict(
mode='original',
)),
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.2, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))
dict(type='Mixup', alpha=0.2, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
]),
)
# dataset settings
train_dataloader = dict(sampler=dict(type='RepeatAugSampler', shuffle=True))

View File

@ -13,8 +13,8 @@ model = dict(
),
head=dict(loss=dict(use_sigmoid=True)),
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.1, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
dict(type='Mixup', alpha=0.1, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
]))
# dataset settings

View File

@ -10,9 +10,10 @@ model = dict(
backbone=dict(norm_cfg=dict(type='SyncBN', requires_grad=True)),
head=dict(loss=dict(use_sigmoid=True)),
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.1, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))
dict(type='Mixup', alpha=0.1, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
]),
)
# schedule settings
optim_wrapper = dict(

View File

@ -8,13 +8,8 @@ _base_ = [
# model setting
model = dict(
head=dict(hidden_dim=3072),
train_cfg=dict(
augments=dict(
type='BatchMixup',
alpha=0.2,
num_classes=1000,
prob=1.,
)))
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
)
# schedule setting
optim_wrapper = dict(clip_grad=dict(max_norm=1.0))

View File

@ -8,13 +8,8 @@ _base_ = [
# model setting
model = dict(
head=dict(hidden_dim=3072),
train_cfg=dict(
augments=dict(
type='BatchMixup',
alpha=0.2,
num_classes=1000,
prob=1.,
)))
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
)
# schedule setting
optim_wrapper = dict(clip_grad=dict(max_norm=1.0))

View File

@ -8,13 +8,8 @@ _base_ = [
# model setting
model = dict(
head=dict(hidden_dim=3072),
train_cfg=dict(
augments=dict(
type='BatchMixup',
alpha=0.2,
num_classes=1000,
prob=1.,
)))
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
)
# schedule setting
optim_wrapper = dict(clip_grad=dict(max_norm=1.0))

View File

@ -8,13 +8,8 @@ _base_ = [
# model setting
model = dict(
head=dict(hidden_dim=3072),
train_cfg=dict(
augments=dict(
type='BatchMixup',
alpha=0.2,
num_classes=1000,
prob=1.,
)))
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
)
# schedule setting
optim_wrapper = dict(clip_grad=dict(max_norm=1.0))

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .attention import MultiheadAttention, ShiftWindowMSA
from .augment.augments import Augments
from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix
from .channel_shuffle import channel_shuffle
from .data_preprocessor import ClsDataPreprocessor
from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed,
@ -14,7 +14,8 @@ from .se_layer import SELayer
__all__ = [
'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer',
'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'PatchEmbed',
'PatchMerging', 'HybridEmbed', 'Augments', 'ShiftWindowMSA', 'is_tracing',
'MultiheadAttention', 'ConditionalPositionEncoding', 'resize_pos_embed',
'resize_relative_position_bias_table', 'ClsDataPreprocessor'
'PatchMerging', 'HybridEmbed', 'RandomBatchAugment', 'ShiftWindowMSA',
'is_tracing', 'MultiheadAttention', 'ConditionalPositionEncoding',
'resize_pos_embed', 'resize_relative_position_bias_table',
'ClsDataPreprocessor', 'Mixup', 'CutMix', 'ResizeMix'
]

View File

@ -1,9 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .augments import Augments
from .cutmix import BatchCutMixLayer
from .identity import Identity
from .mixup import BatchMixupLayer
from .resizemix import BatchResizeMixLayer
__all__ = ('Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer',
'BatchResizeMixLayer')

View File

@ -1,73 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
import numpy as np
from .builder import build_augment
class Augments(object):
"""Data augments.
We implement some data augmentation methods, such as mixup, cutmix.
Args:
augments_cfg (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict`):
Config dict of augments
Example:
>>> augments_cfg = [
dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3)
]
>>> augments = Augments(augments_cfg)
>>> imgs = torch.randn(16, 3, 32, 32)
>>> label = torch.randint(0, 10, (16, ))
>>> imgs, label = augments(imgs, label)
To decide which augmentation within Augments block is used
the following rule is applied.
We pick augmentation based on the probabilities. In the example above,
we decide if we should use BatchCutMix with probability 0.5,
BatchMixup 0.3. As Identity is not in augments_cfg, we use Identity with
probability 1 - 0.5 - 0.3 = 0.2.
"""
def __init__(self, augments_cfg):
super(Augments, self).__init__()
if isinstance(augments_cfg, dict):
augments_cfg = [augments_cfg]
assert len(augments_cfg) > 0, \
'The length of augments_cfg should be positive.'
self.augments = [build_augment(cfg) for cfg in augments_cfg]
self.augment_probs = [aug.prob for aug in self.augments]
has_identity = any([cfg['type'] == 'Identity' for cfg in augments_cfg])
if has_identity:
assert sum(self.augment_probs) == 1.0,\
'The sum of augmentation probabilities should equal to 1,' \
' but got {:.2f}'.format(sum(self.augment_probs))
else:
assert sum(self.augment_probs) <= 1.0,\
'The sum of augmentation probabilities should less than or ' \
'equal to 1, but got {:.2f}'.format(sum(self.augment_probs))
identity_prob = 1 - sum(self.augment_probs)
if identity_prob > 0:
num_classes = self.augments[0].num_classes
self.augments += [
build_augment(
dict(
type='Identity',
num_classes=num_classes,
prob=identity_prob))
]
self.augment_probs += [identity_prob]
def __call__(self, img, gt_label):
if self.augments:
random_state = np.random.RandomState(random.randint(0, 2**32 - 1))
aug = random_state.choice(self.augments, p=self.augment_probs)
return aug(img, gt_label)
return img, gt_label

View File

@ -1,8 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, build_from_cfg
AUGMENT = Registry('augment')
def build_augment(cfg, default_args=None):
return build_from_cfg(cfg, AUGMENT, default_args)

View File

@ -1,29 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import AUGMENT
from .utils import one_hot_encoding
@AUGMENT.register_module(name='Identity')
class Identity(object):
"""Change gt_label to one_hot encoding and keep img as the same.
Args:
num_classes (int): The number of classes.
prob (float): MixUp probability. It should be in range [0, 1].
Default to 1.0
"""
def __init__(self, num_classes, prob=1.0):
super(Identity, self).__init__()
assert isinstance(num_classes, int)
assert isinstance(prob, float) and 0.0 <= prob <= 1.0
self.num_classes = num_classes
self.prob = prob
def one_hot(self, gt_label):
return one_hot_encoding(gt_label, self.num_classes)
def __call__(self, img, gt_label):
return img, self.one_hot(gt_label)

View File

@ -1,80 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
import numpy as np
import torch
from .builder import AUGMENT
from .utils import one_hot_encoding
class BaseMixupLayer(object, metaclass=ABCMeta):
"""Base class for MixupLayer.
Args:
alpha (float): Parameters for Beta distribution to generate the
mixing ratio. It should be a positive number.
num_classes (int): The number of classes.
prob (float): MixUp probability. It should be in range [0, 1].
Default to 1.0
"""
def __init__(self, alpha, num_classes, prob=1.0):
super(BaseMixupLayer, self).__init__()
assert isinstance(alpha, float) and alpha > 0
assert isinstance(num_classes, int)
assert isinstance(prob, float) and 0.0 <= prob <= 1.0
self.alpha = alpha
self.num_classes = num_classes
self.prob = prob
@abstractmethod
def mixup(self, imgs, gt_label):
pass
@AUGMENT.register_module(name='BatchMixup')
class BatchMixupLayer(BaseMixupLayer):
r"""Mixup layer for a batch of data.
Mixup is a method to reduces the memorization of corrupt labels and
increases the robustness to adversarial examples. It's
proposed in `mixup: Beyond Empirical Risk Minimization
<https://arxiv.org/abs/1710.09412>`
This method simply linearly mix pairs of data and their labels.
Args:
alpha (float): Parameters for Beta distribution to generate the
mixing ratio. It should be a positive number. More details
are in the note.
num_classes (int): The number of classes.
prob (float): The probability to execute mixup. It should be in
range [0, 1]. Default sto 1.0.
Note:
The :math:`\alpha` (``alpha``) determines a random distribution
:math:`Beta(\alpha, \alpha)`. For each batch of data, we sample
a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random
distribution.
"""
def __init__(self, *args, **kwargs):
super(BatchMixupLayer, self).__init__(*args, **kwargs)
def mixup(self, img, gt_label):
one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
lam = np.random.beta(self.alpha, self.alpha)
batch_size = img.size(0)
index = torch.randperm(batch_size)
mixed_img = lam * img + (1 - lam) * img[index, :]
mixed_gt_label = lam * one_hot_gt_label + (
1 - lam) * one_hot_gt_label[index, :]
return mixed_img, mixed_gt_label
def __call__(self, img, gt_label):
return self.mixup(img, gt_label)

View File

@ -1,24 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
def one_hot_encoding(gt, num_classes):
"""Change gt_label to one_hot encoding.
If the shape has 2 or more
dimensions, return it without encoding.
Args:
gt (Tensor): The gt label with shape (N,) or shape (N, */).
num_classes (int): The number of classes.
Return:
Tensor: One hot gt label.
"""
if gt.ndim == 1:
# multi-class classification
return F.one_hot(gt, num_classes=num_classes)
else:
# binary classification
# example. [[0], [1], [1]]
# multi-label classification
# example. [[0, 1, 1], [1, 0, 0], [1, 1, 1]]
return gt

View File

@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cutmix import CutMix
from .mixup import Mixup
from .resizemix import ResizeMix
from .wrapper import RandomBatchAugment
__all__ = ('RandomBatchAugment', 'CutMix', 'Mixup', 'ResizeMix')

View File

@ -1,43 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
import numpy as np
import torch
from .builder import AUGMENT
from .utils import one_hot_encoding
from mmcls.registry import BATCH_AUGMENTS
from .mixup import Mixup
class BaseCutMixLayer(object, metaclass=ABCMeta):
"""Base class for CutMixLayer.
@BATCH_AUGMENTS.register_module()
class CutMix(Mixup):
r"""CutMix batch agumentation.
CutMix is a method to improve the network's generalization capability. It's
proposed in `CutMix: Regularization Strategy to Train Strong Classifiers
with Localizable Features <https://arxiv.org/abs/1905.04899>`
With this method, patches are cut and pasted among training images where
the ground truth labels are also mixed proportionally to the area of the
patches.
Args:
alpha (float): Parameters for Beta distribution. Positive(>0)
num_classes (int): The number of classes
prob (float): MixUp probability. It should be in range [0, 1].
Default to 1.0
cutmix_minmax (List[float], optional): cutmix min/max image ratio.
(as percent of image size). When cutmix_minmax is not None, we
generate cutmix bounding-box using cutmix_minmax instead of alpha
alpha (float): Parameters for Beta distribution to generate the
mixing ratio. It should be a positive number. More details
can be found in :class:`BatchMixupLayer`.
num_classes (int, optional): The number of classes. If not specified,
will try to get it from data samples during training.
Defaults to None.
cutmix_minmax (List[float], optional): The min/max area ratio of the
patches. If not None, the bounding-box of patches is uniformly
sampled within this ratio range, and the ``alpha`` will be ignored.
Otherwise, the bounding-box is generated according to the
``alpha``. Defaults to None.
correct_lam (bool): Whether to apply lambda correction when cutmix bbox
clipped by image borders. Default to True
clipped by image borders. Defaults to True.
.. note ::
If the ``cutmix_minmax`` is None, how to generate the bounding-box of
patches according to the ``alpha``?
First, generate a :math:`\lambda`, details can be found in
:class:`Mixup`. And then, the area ratio of the bounding-box
is calculated by:
.. math::
\text{ratio} = \sqrt{1-\lambda}
"""
def __init__(self,
alpha,
num_classes,
prob=1.0,
num_classes=None,
cutmix_minmax=None,
correct_lam=True):
super(BaseCutMixLayer, self).__init__()
super().__init__(alpha=alpha, num_classes=num_classes)
assert isinstance(alpha, float) and alpha > 0
assert isinstance(num_classes, int)
assert isinstance(prob, float) and 0.0 <= prob <= 1.0
self.alpha = alpha
self.num_classes = num_classes
self.prob = prob
self.cutmix_minmax = cutmix_minmax
self.correct_lam = correct_lam
@ -54,7 +68,7 @@ class BaseCutMixLayer(object, metaclass=ABCMeta):
count (int, optional): Number of bbox to generate. Default to None
"""
assert len(self.cutmix_minmax) == 2
img_h, img_w = img_shape[-2:]
img_h, img_w = img_shape
cut_h = np.random.randint(
int(img_h * self.cutmix_minmax[0]),
int(img_h * self.cutmix_minmax[1]),
@ -82,7 +96,7 @@ class BaseCutMixLayer(object, metaclass=ABCMeta):
count (int, optional): Number of bbox to generate. Default to None
"""
ratio = np.sqrt(1 - lam)
img_h, img_w = img_shape[-2:]
img_h, img_w = img_shape
cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
@ -107,69 +121,18 @@ class BaseCutMixLayer(object, metaclass=ABCMeta):
yl, yu, xl, xu = self.rand_bbox(img_shape, lam, count=count)
if self.correct_lam or self.cutmix_minmax is not None:
bbox_area = (yu - yl) * (xu - xl)
lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
lam = 1. - bbox_area / float(img_shape[0] * img_shape[1])
return (yl, yu, xl, xu), lam
@abstractmethod
def cutmix(self, imgs, gt_label):
pass
@AUGMENT.register_module(name='BatchCutMix')
class BatchCutMixLayer(BaseCutMixLayer):
r"""CutMix layer for a batch of data.
CutMix is a method to improve the network's generalization capability. It's
proposed in `CutMix: Regularization Strategy to Train Strong Classifiers
with Localizable Features <https://arxiv.org/abs/1905.04899>`
With this method, patches are cut and pasted among training images where
the ground truth labels are also mixed proportionally to the area of the
patches.
Args:
alpha (float): Parameters for Beta distribution to generate the
mixing ratio. It should be a positive number. More details
can be found in :class:`BatchMixupLayer`.
num_classes (int): The number of classes
prob (float): The probability to execute cutmix. It should be in
range [0, 1]. Defaults to 1.0.
cutmix_minmax (List[float], optional): The min/max area ratio of the
patches. If not None, the bounding-box of patches is uniform
sampled within this ratio range, and the ``alpha`` will be ignored.
Otherwise, the bounding-box is generated according to the
``alpha``. Defaults to None.
correct_lam (bool): Whether to apply lambda correction when cutmix bbox
clipped by image borders. Defaults to True.
Note:
If the ``cutmix_minmax`` is None, how to generate the bounding-box of
patches according to the ``alpha``?
First, generate a :math:`\lambda`, details can be found in
:class:`BatchMixupLayer`. And then, the area ratio of the bounding-box
is calculated by:
.. math::
\text{ratio} = \sqrt{1-\lambda}
"""
def __init__(self, *args, **kwargs):
super(BatchCutMixLayer, self).__init__(*args, **kwargs)
def cutmix(self, img, gt_label):
one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor):
"""Mix the batch inputs and batch one-hot format ground truth."""
lam = np.random.beta(self.alpha, self.alpha)
batch_size = img.size(0)
batch_size = batch_inputs.size(0)
img_shape = batch_inputs.shape[-2:]
index = torch.randperm(batch_size)
(bby1, bby2, bbx1,
bbx2), lam = self.cutmix_bbox_and_lam(img.shape, lam)
img[:, :, bby1:bby2, bbx1:bbx2] = \
img[index, :, bby1:bby2, bbx1:bbx2]
mixed_gt_label = lam * one_hot_gt_label + (
1 - lam) * one_hot_gt_label[index, :]
return img, mixed_gt_label
(y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam)
batch_inputs[:, :, y1:y2, x1:x2] = batch_inputs[index, :, y1:y2, x1:x2]
mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :]
def __call__(self, img, gt_label):
return self.cutmix(img, gt_label)
return batch_inputs, mixed_score
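
A small worked sketch (toy numbers, not from this commit) of the bounding-box rule documented above: when cutmix_minmax is None, the patch side ratio is sqrt(1 - lam), so the pasted area is roughly 1 - lam and the mixed label weights stay consistent with the pixel areas.

import numpy as np

lam = 0.75
img_h, img_w = 224, 224
ratio = np.sqrt(1 - lam)                               # 0.5
cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)  # 112 x 112
area = (cut_h * cut_w) / float(img_h * img_w)          # 0.25, i.e. 1 - lam
# If the box is clipped at the image border, correct_lam recomputes
# lam = 1 - bbox_area / (img_h * img_w), as in cutmix_bbox_and_lam above.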

View File

@ -0,0 +1,78 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import numpy as np
import torch
from mmengine.data import LabelData
from mmcls.core import ClsDataSample
from mmcls.registry import BATCH_AUGMENTS
@BATCH_AUGMENTS.register_module()
class Mixup:
r"""Mixup batch augmentation.
Mixup is a method to reduce the memorization of corrupt labels and
increase the robustness to adversarial examples. It's proposed in
`mixup: Beyond Empirical Risk Minimization
<https://arxiv.org/abs/1710.09412>`_
Args:
alpha (float): Parameters for Beta distribution to generate the
mixing ratio. It should be a positive number. More details
are in the note.
num_classes (int, optional): The number of classes. If not specified,
will try to get it from data samples during training.
Defaults to None.
Note:
The :math:`\alpha` (``alpha``) determines a random distribution
:math:`Beta(\alpha, \alpha)`. For each batch of data, we sample
a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random
distribution.
"""
def __init__(self, alpha, num_classes=None):
assert isinstance(alpha, float) and alpha > 0
assert isinstance(num_classes, int) or num_classes is None
self.alpha = alpha
self.num_classes = num_classes
def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor):
"""Mix the batch inputs and batch one-hot format ground truth."""
lam = np.random.beta(self.alpha, self.alpha)
batch_size = batch_inputs.size(0)
index = torch.randperm(batch_size)
mixed_inputs = lam * batch_inputs + (1 - lam) * batch_inputs[index, :]
mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :]
return mixed_inputs, mixed_score
def __call__(self, batch_inputs: torch.Tensor,
data_samples: List[ClsDataSample]):
"""Mix the batch inputs and batch data samples."""
assert data_samples is not None, f'{self.__class__.__name__} ' \
'requires data_samples. If you only want to inference, please ' \
'disable it from preprocessing.'
if self.num_classes is None and 'num_classes' not in data_samples[0]:
raise RuntimeError(
'Not specify the `num_classes` and cannot get it from '
'data samples. Please specify `num_classes` in the '
f'{self.__class__.__name__}.')
num_classes = self.num_classes or data_samples[0].get('num_classes')
batch_score = torch.stack([
LabelData.label_to_onehot(sample.gt_label.label, num_classes)
for sample in data_samples
])
mixed_inputs, mixed_score = self.mix(batch_inputs, batch_score)
for i, sample in enumerate(data_samples):
sample.set_gt_score(mixed_score[i])
return mixed_inputs, data_samples
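
A minimal, self-contained sketch of the mixing rule implemented in Mixup.mix() above (toy tensors; lam is drawn from Beta(alpha, alpha)):

import numpy as np
import torch
import torch.nn.functional as F

alpha = 0.8
inputs = torch.randn(4, 3, 32, 32)
onehot = F.one_hot(torch.randint(0, 10, (4, )), num_classes=10).float()

lam = np.random.beta(alpha, alpha)
index = torch.randperm(inputs.size(0))
# Every image and its one-hot label are blended with a randomly paired
# sample from the same batch, weighted by lam and (1 - lam).
mixed_inputs = lam * inputs + (1 - lam) * inputs[index]
mixed_onehot = lam * onehot + (1 - lam) * onehot[index]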

View File

@ -3,13 +3,12 @@ import numpy as np
import torch
import torch.nn.functional as F
from mmcls.models.utils.augment.builder import AUGMENT
from .cutmix import BatchCutMixLayer
from .utils import one_hot_encoding
from mmcls.registry import BATCH_AUGMENTS
from .cutmix import CutMix
@AUGMENT.register_module(name='BatchResizeMix')
class BatchResizeMixLayer(BatchCutMixLayer):
@BATCH_AUGMENTS.register_module()
class ResizeMix(CutMix):
r"""ResizeMix Random Paste layer for a batch of data.
The ResizeMix will resize an image to a small patch and paste it on another
@ -19,8 +18,10 @@ class BatchResizeMixLayer(BatchCutMixLayer):
Args:
alpha (float): Parameters for Beta distribution to generate the
mixing ratio. It should be a positive number. More details
can be found in :class:`BatchMixupLayer`.
num_classes (int): The number of classes.
can be found in :class:`Mixup`.
num_classes (int, optional): The number of classes. If not specified,
will try to get it from data samples during training.
Defaults to None.
lam_min(float): The minimum value of lam. Defaults to 0.1.
lam_max(float): The maximum value of lam. Defaults to 0.8.
interpolation (str): algorithm used for upsampling:
@ -35,7 +36,7 @@ class BatchResizeMixLayer(BatchCutMixLayer):
``alpha``. Defaults to None.
correct_lam (bool): Whether to apply lambda correction when cutmix bbox
clipped by image borders. Defaults to True
**kwargs: Any other parameters accpeted by :class:`BatchCutMixLayer`.
**kwargs: Any other parameters accepted by :class:`CutMix`.
Note:
The :math:`\lambda` (``lam``) is the mixing ratio. It's a random
@ -54,40 +55,35 @@ class BatchResizeMixLayer(BatchCutMixLayer):
def __init__(self,
alpha,
num_classes,
num_classes=None,
lam_min: float = 0.1,
lam_max: float = 0.8,
interpolation='bilinear',
prob=1.0,
cutmix_minmax=None,
correct_lam=True,
**kwargs):
super(BatchResizeMixLayer, self).__init__(
correct_lam=True):
super().__init__(
alpha=alpha,
num_classes=num_classes,
prob=prob,
cutmix_minmax=cutmix_minmax,
correct_lam=correct_lam,
**kwargs)
correct_lam=correct_lam)
self.lam_min = lam_min
self.lam_max = lam_max
self.interpolation = interpolation
def cutmix(self, img, gt_label):
one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor):
"""Mix the batch inputs and batch one-hot format ground truth."""
lam = np.random.beta(self.alpha, self.alpha)
lam = lam * (self.lam_max - self.lam_min) + self.lam_min
batch_size = img.size(0)
img_shape = batch_inputs.shape[-2:]
batch_size = batch_inputs.size(0)
index = torch.randperm(batch_size)
(bby1, bby2, bbx1,
bbx2), lam = self.cutmix_bbox_and_lam(img.shape, lam)
(y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam)
batch_inputs[:, :, y1:y2, x1:x2] = F.interpolate(
batch_inputs[index],
size=(y2 - y1, x2 - x1),
mode=self.interpolation,
align_corners=False)
mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :]
img[:, :, bby1:bby2, bbx1:bbx2] = F.interpolate(
img[index],
size=(bby2 - bby1, bbx2 - bbx1),
mode=self.interpolation)
mixed_gt_label = lam * one_hot_gt_label + (
1 - lam) * one_hot_gt_label[index, :]
return img, mixed_gt_label
return batch_inputs, mixed_score
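
A short sketch (toy values) of how ResizeMix.mix() above rescales the Beta sample into [lam_min, lam_max] before the resize-and-paste step:

import numpy as np

alpha, lam_min, lam_max = 1.0, 0.1, 0.8
lam = np.random.beta(alpha, alpha)         # somewhere in [0, 1]
lam = lam * (lam_max - lam_min) + lam_min  # rescaled into [0.1, 0.8]
# The paired image is then resized into the (1 - lam)-area bounding box and
# pasted in, instead of being cut out of the other image as in CutMix.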

View File

@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Union
import numpy as np
import torch
from mmcls.registry import BATCH_AUGMENTS
class RandomBatchAugment:
"""Randomly choose one batch augmentation to apply.
Args:
augments (Callable | dict | list): configs of batch
augmentations.
probs (float | List[float] | None): The probabilities of each batch
augmentation. If None, choose evenly. Defaults to None.
Example:
>>> augments_cfg = [
... dict(type='CutMix', alpha=1., num_classes=10),
... dict(type='Mixup', alpha=1., num_classes=10)
... ]
>>> batch_augment = RandomBatchAugment(augments_cfg, probs=[0.5, 0.3])
>>> imgs = torch.randn(16, 3, 32, 32)
>>> label = torch.randint(0, 10, (16, ))
>>> imgs, label = batch_augment(imgs, label)
.. note ::
To decide which batch augmentation will be used, it picks one of
``augments`` based on the probabilities. In the example above, the
probability to use CutMix is 0.5, to use Mixup is 0.3, and to do
nothing is 0.2.
"""
def __init__(self, augments: Union[Callable, dict, list], probs=None):
if not isinstance(augments, (tuple, list)):
augments = [augments]
self.augments = []
for aug in augments:
if isinstance(aug, dict):
self.augments.append(BATCH_AUGMENTS.build(aug))
else:
self.augments.append(aug)
if isinstance(probs, float):
probs = [probs]
if probs is not None:
assert len(augments) == len(probs), \
'``augments`` and ``probs`` must have same lengths. ' \
f'Got {len(augments)} vs {len(probs)}.'
assert sum(probs) <= 1, \
'The total probability of batch augments exceeds 1.'
self.augments.append(None)
probs.append(1 - sum(probs))
self.probs = probs
def __call__(self, inputs: torch.Tensor, data_samples: Union[list, None]):
"""Randomly apply batch augmentations to the batch inputs and batch
data samples."""
aug_index = np.random.choice(len(self.augments), p=self.probs)
aug = self.augments[aug_index]
if aug is not None:
return aug(inputs, data_samples)
else:
return inputs, data_samples
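
The wrapper hands whole batches and their ClsDataSample ground truths to whichever augmentation it samples, or returns them untouched for the implicit identity branch. A minimal call sketch, mirroring the new unit tests further below:

import torch

from mmcls.core import ClsDataSample
from mmcls.models import RandomBatchAugment

batch_augment = RandomBatchAugment([
    dict(type='Mixup', alpha=0.8, num_classes=10),
    dict(type='CutMix', alpha=1.0, num_classes=10),
])
inputs = torch.rand(2, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
# Returns mixed inputs and data samples whose gt_label now carries a soft score.
mixed_inputs, mixed_samples = batch_augment(inputs, data_samples)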

View File

@ -6,13 +6,14 @@ import torch
from mmengine.model import BaseDataPreprocessor, stack_batch
from mmcls.registry import MODELS
from .batch_augments import RandomBatchAugment
@MODELS.register_module()
class ClsDataPreprocessor(BaseDataPreprocessor):
"""Image pre-processor for classification tasks.
Comparing with the :class:`mmengine.ImgDataPreprocessor`,
Comparing with the :class:`mmengine.model.ImgDataPreprocessor`,
1. It won't do normalization if ``mean`` is not specified.
2. It does normalization and color space conversion after stacking batch.
@ -39,6 +40,9 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
pad_value (Number): The padded pixel value. Defaults to 0.
to_rgb (bool): whether to convert image from BGR to RGB.
Defaults to False.
batch_augments (dict, optional): The batch augmentations settings,
including "augments" and "probs". For more details, see
:class:`mmcls.models.RandomBatchAugment`.
"""
def __init__(self,
@ -65,14 +69,16 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
else:
self._enable_normalize = False
# TODO: support batch augmentations.
self.batch_augments = batch_augments
if batch_augments is not None:
self.batch_augments = RandomBatchAugment(batch_augments)
else:
self.batch_augments = None
def forward(self,
data: Sequence[dict],
training: bool = False) -> Tuple[torch.Tensor, list]:
"""Perform normalization、padding and bgr2rgb conversion based on
``BaseDataPreprocessor``.
"""Perform normalization, padding, bgr2rgb conversion and batch
augmentation based on ``BaseDataPreprocessor``.
Args:
data (Sequence[dict]): data sampled from dataloader.
@ -98,8 +104,9 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
else:
batch_inputs = batch_inputs.to(torch.float32)
# ----- Batch Aug ----
if training and self.batch_augments is not None:
inputs, batch_data_samples = self.batch_augments(
inputs, batch_data_samples)
batch_inputs, batch_data_samples = self.batch_augments(
batch_inputs, batch_data_samples)
return batch_inputs, batch_data_samples
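
A minimal config sketch, mirroring the updated ClsDataPreprocessor unit test at the end of this diff: a batch_augments setting is built into a RandomBatchAugment, while None skips the batch-augmentation step entirely.

cfg = dict(
    type='ClsDataPreprocessor',
    batch_augments=[
        dict(type='Mixup', alpha=0.8, num_classes=10),
        dict(type='CutMix', alpha=1., num_classes=10)
    ])
# MODELS.build(cfg) yields a preprocessor whose forward(data, training=True)
# mixes the stacked batch inputs and their ClsDataSample ground truths.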

View File

@ -1,14 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS,
OPTIM_WRAPPERS, OPTIMIZERS, PARAM_SCHEDULERS,
RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS,
VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS)
from .registry import (BATCH_AUGMENTS, DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
METRICS, MODEL_WRAPPERS, MODELS,
OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS,
PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS,
TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS,
WEIGHT_INITIALIZERS)
__all__ = [
'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS',
'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS',
'OPTIM_WRAPPERS', 'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS',
'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS',
'VISUALIZERS'
'VISUALIZERS', 'BATCH_AUGMENTS'
]

View File

@ -53,6 +53,8 @@ MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
# manage all kinds of weight initialization modules like `Uniform`
WEIGHT_INITIALIZERS = Registry(
'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
# manage all kinds of batch augmentations like Mixup and CutMix.
BATCH_AUGMENTS = Registry('batch augment')
# Registries For Optimizer and the related
# manage all kinds of optimizers like `SGD` and `Adam`
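
A hedged sketch of extending the new registry with a project-specific batch augmentation; the class below is purely illustrative and not part of this commit:

from mmcls.registry import BATCH_AUGMENTS

@BATCH_AUGMENTS.register_module()
class NoOpBatchAugment:
    """Illustrative only: hand the batch back unchanged."""

    def __call__(self, batch_inputs, data_samples):
        return batch_inputs, data_samples

# It can then be built from config like the built-in augments, e.g.
# BATCH_AUGMENTS.build(dict(type='NoOpBatchAugment')).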

View File

@ -47,9 +47,7 @@ class TestImageClassifier(TestCase):
# test set batch augmentation from train_cfg
cfg = {
**self.DEFAULT_ARGS, 'train_cfg':
dict(augments=[
dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.)
])
dict(augments=dict(type='Mixup', alpha=1., num_classes=10))
}
model: ImageClassifier = MODELS.build(cfg)
self.assertIsNotNone(model.data_preprocessor.batch_augments)

View File

@ -1,96 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcls.models.utils import Augments
augment_cfgs = [
dict(type='BatchCutMix', alpha=1., prob=1.),
dict(type='BatchMixup', alpha=1., prob=1.),
dict(type='Identity', prob=1.),
dict(type='BatchResizeMix', alpha=1., prob=1.)
]
def test_augments():
imgs = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 10, (4, ))
# Test cutmix
augments_cfg = dict(type='BatchCutMix', alpha=1., num_classes=10, prob=1.)
augs = Augments(augments_cfg)
mixed_imgs, mixed_labels = augs(imgs, labels)
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
assert mixed_labels.shape == torch.Size((4, 10))
# Test mixup
augments_cfg = dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.)
augs = Augments(augments_cfg)
mixed_imgs, mixed_labels = augs(imgs, labels)
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
assert mixed_labels.shape == torch.Size((4, 10))
# Test resizemix
augments_cfg = dict(
type='BatchResizeMix', alpha=1., num_classes=10, prob=1.)
augs = Augments(augments_cfg)
mixed_imgs, mixed_labels = augs(imgs, labels)
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
assert mixed_labels.shape == torch.Size((4, 10))
# Test cutmixup
augments_cfg = [
dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3)
]
augs = Augments(augments_cfg)
mixed_imgs, mixed_labels = augs(imgs, labels)
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
assert mixed_labels.shape == torch.Size((4, 10))
augments_cfg = [
dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.5)
]
augs = Augments(augments_cfg)
mixed_imgs, mixed_labels = augs(imgs, labels)
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
assert mixed_labels.shape == torch.Size((4, 10))
augments_cfg = [
dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3),
dict(type='Identity', num_classes=10, prob=0.2)
]
augs = Augments(augments_cfg)
mixed_imgs, mixed_labels = augs(imgs, labels)
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
assert mixed_labels.shape == torch.Size((4, 10))
@pytest.mark.parametrize('cfg', augment_cfgs)
def test_binary_augment(cfg):
cfg_ = dict(num_classes=1, **cfg)
augs = Augments(cfg_)
imgs = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 2, (4, 1)).float()
mixed_imgs, mixed_labels = augs(imgs, labels)
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
assert mixed_labels.shape == torch.Size((4, 1))
@pytest.mark.parametrize('cfg', augment_cfgs)
def test_multilabel_augment(cfg):
cfg_ = dict(num_classes=10, **cfg)
augs = Augments(cfg_)
imgs = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 2, (4, 10)).float()
mixed_imgs, mixed_labels = augs(imgs, labels)
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
assert mixed_labels.shape == torch.Size((4, 10))

View File

@ -0,0 +1,246 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import MagicMock, patch
import numpy as np
import torch
from mmcls.core import ClsDataSample
from mmcls.models import Mixup, RandomBatchAugment
from mmcls.registry import BATCH_AUGMENTS
augment_cfgs = [
dict(type='BatchCutMix', alpha=1., prob=1.),
dict(type='BatchMixup', alpha=1., prob=1.),
dict(type='Identity', prob=1.),
dict(type='BatchResizeMix', alpha=1., prob=1.)
]
class TestRandomBatchAugment(TestCase):
def test_initialize(self):
# test single augmentation
augments = dict(type='Mixup', alpha=1.)
batch_augments = RandomBatchAugment(augments)
self.assertIsInstance(batch_augments.augments, list)
self.assertEqual(len(batch_augments.augments), 1)
# test specify augments with object
augments = Mixup(alpha=1.)
batch_augments = RandomBatchAugment(augments)
self.assertIsInstance(batch_augments.augments, list)
self.assertEqual(len(batch_augments.augments), 1)
# test multiple augmentation
augments = [
dict(type='Mixup', alpha=1.),
dict(type='CutMix', alpha=0.8),
]
batch_augments = RandomBatchAugment(augments)
# mixup, cutmix
self.assertEqual(len(batch_augments.augments), 2)
self.assertIsNone(batch_augments.probs)
# test specify probs
augments = [
dict(type='Mixup', alpha=1.),
dict(type='CutMix', alpha=0.8),
]
batch_augments = RandomBatchAugment(augments, probs=[0.5, 0.3])
# mixup, cutmix and None
self.assertEqual(len(batch_augments.augments), 3)
self.assertAlmostEqual(batch_augments.probs[-1], 0.2)
# test assertion
with self.assertRaisesRegex(AssertionError, 'Got 2 vs 1'):
RandomBatchAugment(augments, probs=0.5)
with self.assertRaisesRegex(AssertionError, 'exceeds 1.'):
RandomBatchAugment(augments, probs=[0.5, 0.6])
def test_call(self):
inputs = torch.rand(2, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
augments = [
dict(type='Mixup', alpha=1.),
dict(type='CutMix', alpha=0.8),
]
batch_augments = RandomBatchAugment(augments, probs=[0.5, 0.3])
with patch('numpy.random', np.random.RandomState(0)):
batch_augments.augments[1] = MagicMock()
batch_augments(inputs, data_samples)
batch_augments.augments[1].assert_called_once_with(
inputs, data_samples)
augments = [
dict(type='Mixup', alpha=1.),
dict(type='CutMix', alpha=0.8),
]
batch_augments = RandomBatchAugment(augments, probs=[0.0, 0.0])
mixed_inputs, mixed_samples = batch_augments(inputs, data_samples)
self.assertIs(mixed_inputs, inputs)
self.assertIs(mixed_samples, data_samples)
class TestMixup(TestCase):
DEFAULT_ARGS = dict(type='Mixup', alpha=1.)
def test_initialize(self):
with self.assertRaises(AssertionError):
cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'}
BATCH_AUGMENTS.build(cfg)
with self.assertRaises(AssertionError):
cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'}
BATCH_AUGMENTS.build(cfg)
def test_call(self):
inputs = torch.rand(2, 3, 224, 224)
data_samples = [
ClsDataSample(metainfo={
'num_classes': 10
}).set_gt_label(1) for _ in range(2)
]
# test get num_classes from data_samples
mixup = BATCH_AUGMENTS.build(self.DEFAULT_ARGS)
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
with self.assertRaisesRegex(RuntimeError, 'Not specify'):
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
mixup(inputs, data_samples)
# test binary classification
cfg = {**self.DEFAULT_ARGS, 'num_classes': 1}
mixup = BATCH_AUGMENTS.build(cfg)
data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)]
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, ))
# test multi-label classification
cfg = {**self.DEFAULT_ARGS, 'num_classes': 5}
mixup = BATCH_AUGMENTS.build(cfg)
data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)]
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, ))
class TestCutMix(TestCase):
DEFAULT_ARGS = dict(type='CutMix', alpha=1.)
def test_initialize(self):
with self.assertRaises(AssertionError):
cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'}
BATCH_AUGMENTS.build(cfg)
with self.assertRaises(AssertionError):
cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'}
BATCH_AUGMENTS.build(cfg)
def test_call(self):
inputs = torch.rand(2, 3, 224, 224)
data_samples = [
ClsDataSample(metainfo={
'num_classes': 10
}).set_gt_label(1) for _ in range(2)
]
# test with cutmix_minmax
cfg = {**self.DEFAULT_ARGS, 'cutmix_minmax': (0.1, 0.2)}
cutmix = BATCH_AUGMENTS.build(cfg)
mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
# test without correct_lam
cfg = {**self.DEFAULT_ARGS, 'correct_lam': False}
cutmix = BATCH_AUGMENTS.build(cfg)
mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
# test get num_classes from data_samples
cutmix = BATCH_AUGMENTS.build(self.DEFAULT_ARGS)
mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
with self.assertRaisesRegex(RuntimeError, 'Not specify'):
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
cutmix(inputs, data_samples)
# test binary classification
cfg = {**self.DEFAULT_ARGS, 'num_classes': 1}
cutmix = BATCH_AUGMENTS.build(cfg)
data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)]
mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, ))
# test multi-label classification
cfg = {**self.DEFAULT_ARGS, 'num_classes': 5}
cutmix = BATCH_AUGMENTS.build(cfg)
data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)]
mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, ))
class TestResizeMix(TestCase):
DEFAULT_ARGS = dict(type='ResizeMix', alpha=1.)
def test_initialize(self):
with self.assertRaises(AssertionError):
cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'}
BATCH_AUGMENTS.build(cfg)
with self.assertRaises(AssertionError):
cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'}
BATCH_AUGMENTS.build(cfg)
def test_call(self):
inputs = torch.rand(2, 3, 224, 224)
data_samples = [
ClsDataSample(metainfo={
'num_classes': 10
}).set_gt_label(1) for _ in range(2)
]
# test get num_classes from data_samples
mixup = BATCH_AUGMENTS.build(self.DEFAULT_ARGS)
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
with self.assertRaisesRegex(RuntimeError, 'Not specify'):
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
mixup(inputs, data_samples)
# test binary classification
cfg = {**self.DEFAULT_ARGS, 'num_classes': 1}
mixup = BATCH_AUGMENTS.build(cfg)
data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)]
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, ))
# test multi-label classification
cfg = {**self.DEFAULT_ARGS, 'num_classes': 5}
mixup = BATCH_AUGMENTS.build(cfg)
data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)]
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, ))

View File

@ -4,7 +4,7 @@ from unittest import TestCase
import torch
from mmcls.core import ClsDataSample
from mmcls.models import ClsDataPreprocessor
from mmcls.models import ClsDataPreprocessor, RandomBatchAugment
from mmcls.registry import MODELS
from mmcls.utils import register_all_modules
@ -62,16 +62,25 @@ class TestClsDataPreprocessor(TestCase):
self.assertIsNone(data_samples)
def test_batch_augmentation(self):
# TODO: complete this test after refactoring batch augmentation
cfg = dict(
type='ClsDataPreprocessor',
batch_augments=[
dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.)
dict(type='Mixup', alpha=0.8, num_classes=10),
dict(type='CutMix', alpha=1., num_classes=10)
])
processor: ClsDataPreprocessor = MODELS.build(cfg)
self.assertIsNotNone(processor.batch_augments)
self.assertIsInstance(processor.batch_augments, RandomBatchAugment)
data = [{
'inputs': torch.randint(0, 256, (3, 224, 224)),
'data_sample': ClsDataSample().set_gt_label(1)
}]
_, data_samples = processor(data, training=True)
cfg['batch_augments'] = None
processor: ClsDataPreprocessor = MODELS.build(cfg)
self.assertIsNone(processor.batch_augments)
data = [{
'inputs': torch.randint(0, 256, (3, 224, 224)),
}]
_, data_samples = processor(data, training=True)
self.assertIsNone(data_samples)