[Refactor] Refactor batch augmentations
parent: dd660ed99e
commit: f3299b4ca2
@@ -17,6 +17,7 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
     ],
     train_cfg=dict(augments=[
-        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
-        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
-    ]))
+        dict(type='Mixup', alpha=0.8, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
+    ]),
+)
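Note: the config hunks in this commit all follow the same migration pattern: the old `BatchMixup`/`BatchCutMix` entries, which each carried their own `prob`, become `Mixup`/`CutMix` entries without a probability, since selection probabilities are now handled by the `RandomBatchAugment` wrapper introduced further down. As a rough illustration only (the surrounding keys are assumed, not copied from any particular config in this commit), a migrated config section looks like:

# Hypothetical excerpt of a migrated config file.
model = dict(
    type='ImageClassifier',  # assumed; backbone/head settings omitted
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8, num_classes=1000),
        dict(type='CutMix', alpha=1.0, num_classes=1000),
    ]),
)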
@@ -21,6 +21,7 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
     ],
     train_cfg=dict(augments=[
-        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
-        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
-    ]))
+        dict(type='Mixup', alpha=0.8, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
+    ]),
+)
@@ -18,6 +18,5 @@ model = dict(
         num_classes=1000),
         topk=(1, 5),
     ),
-    train_cfg=dict(
-        augments=dict(type='BatchMixup', alpha=0.2, num_classes=1000,
-                      prob=1.)))
+    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
+)

@@ -13,5 +13,5 @@ model = dict(
         num_classes=10,
         in_channels=2048,
         loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)),
-    train_cfg=dict(
-        augments=dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.)))
+    train_cfg=dict(augments=dict(type='Mixup', alpha=1., num_classes=10)),
+)

@@ -13,6 +13,5 @@ model = dict(
         num_classes=1000,
         in_channels=2048,
         loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)),
-    train_cfg=dict(
-        augments=dict(type='BatchMixup', alpha=0.2, num_classes=1000,
-                      prob=1.)))
+    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
+)
@@ -18,6 +18,7 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
     ],
     train_cfg=dict(augments=[
-        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
-        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
-    ]))
+        dict(type='Mixup', alpha=0.8, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
+    ]),
+)
@@ -36,6 +36,7 @@ model = dict(
         topk=(1, 5),
         init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
     train_cfg=dict(augments=[
-        dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
-        dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
-    ]))
+        dict(type='Mixup', alpha=0.8, num_classes=num_classes),
+        dict(type='CutMix', alpha=1.0, num_classes=num_classes),
+    ]),
+)
@@ -25,6 +25,7 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
     ],
     train_cfg=dict(augments=[
-        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
-        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
-    ]))
+        dict(type='Mixup', alpha=0.8, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
+    ]),
+)
@@ -16,6 +16,7 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
     ],
     train_cfg=dict(augments=[
-        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
-        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
-    ]))
+        dict(type='Mixup', alpha=0.8, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
+    ]),
+)
@@ -27,9 +27,10 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
     ],
     train_cfg=dict(augments=[
-        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
-        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
-    ]))
+        dict(type='Mixup', alpha=0.8, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
+    ]),
+)

 # data settings
 train_dataloader = dict(batch_size=256)

@@ -18,9 +18,10 @@ model = dict(
         mode='original',
     )),
     train_cfg=dict(augments=[
-        dict(type='BatchMixup', alpha=0.2, num_classes=1000, prob=0.5),
-        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
-    ]))
+        dict(type='Mixup', alpha=0.2, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
+    ]),
+)

 # dataset settings
 train_dataloader = dict(sampler=dict(type='RepeatAugSampler', shuffle=True))

@@ -13,8 +13,8 @@ model = dict(
     ),
     head=dict(loss=dict(use_sigmoid=True)),
     train_cfg=dict(augments=[
-        dict(type='BatchMixup', alpha=0.1, num_classes=1000, prob=0.5),
-        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+        dict(type='Mixup', alpha=0.1, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
     ]))

 # dataset settings

@@ -10,9 +10,10 @@ model = dict(
     backbone=dict(norm_cfg=dict(type='SyncBN', requires_grad=True)),
     head=dict(loss=dict(use_sigmoid=True)),
     train_cfg=dict(augments=[
-        dict(type='BatchMixup', alpha=0.1, num_classes=1000, prob=0.5),
-        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
-    ]))
+        dict(type='Mixup', alpha=0.1, num_classes=1000),
+        dict(type='CutMix', alpha=1.0, num_classes=1000)
+    ]),
+)

 # schedule settings
 optim_wrapper = dict(
@@ -8,13 +8,8 @@ _base_ = [
 # model setting
 model = dict(
     head=dict(hidden_dim=3072),
-    train_cfg=dict(
-        augments=dict(
-            type='BatchMixup',
-            alpha=0.2,
-            num_classes=1000,
-            prob=1.,
-        )))
+    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
+)

 # schedule setting
 optim_wrapper = dict(clip_grad=dict(max_norm=1.0))
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .attention import MultiheadAttention, ShiftWindowMSA
-from .augment.augments import Augments
+from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix
 from .channel_shuffle import channel_shuffle
 from .data_preprocessor import ClsDataPreprocessor
 from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed,

@@ -14,7 +14,8 @@ from .se_layer import SELayer
 __all__ = [
     'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer',
     'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'PatchEmbed',
-    'PatchMerging', 'HybridEmbed', 'Augments', 'ShiftWindowMSA', 'is_tracing',
-    'MultiheadAttention', 'ConditionalPositionEncoding', 'resize_pos_embed',
-    'resize_relative_position_bias_table', 'ClsDataPreprocessor'
+    'PatchMerging', 'HybridEmbed', 'RandomBatchAugment', 'ShiftWindowMSA',
+    'is_tracing', 'MultiheadAttention', 'ConditionalPositionEncoding',
+    'resize_pos_embed', 'resize_relative_position_bias_table',
+    'ClsDataPreprocessor', 'Mixup', 'CutMix', 'ResizeMix'
 ]
@ -1,9 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .augments import Augments
|
||||
from .cutmix import BatchCutMixLayer
|
||||
from .identity import Identity
|
||||
from .mixup import BatchMixupLayer
|
||||
from .resizemix import BatchResizeMixLayer
|
||||
|
||||
__all__ = ('Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer',
|
||||
'BatchResizeMixLayer')
|
|
@ -1,73 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .builder import build_augment
|
||||
|
||||
|
||||
class Augments(object):
|
||||
"""Data augments.
|
||||
|
||||
We implement some data augmentation methods, such as mixup, cutmix.
|
||||
|
||||
Args:
|
||||
augments_cfg (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict`):
|
||||
Config dict of augments
|
||||
|
||||
Example:
|
||||
>>> augments_cfg = [
|
||||
dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
|
||||
dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3)
|
||||
]
|
||||
>>> augments = Augments(augments_cfg)
|
||||
>>> imgs = torch.randn(16, 3, 32, 32)
|
||||
>>> label = torch.randint(0, 10, (16, ))
|
||||
>>> imgs, label = augments(imgs, label)
|
||||
|
||||
To decide which augmentation within Augments block is used
|
||||
the following rule is applied.
|
||||
We pick augmentation based on the probabilities. In the example above,
|
||||
we decide if we should use BatchCutMix with probability 0.5,
|
||||
BatchMixup 0.3. As Identity is not in augments_cfg, we use Identity with
|
||||
probability 1 - 0.5 - 0.3 = 0.2.
|
||||
"""
|
||||
|
||||
def __init__(self, augments_cfg):
|
||||
super(Augments, self).__init__()
|
||||
|
||||
if isinstance(augments_cfg, dict):
|
||||
augments_cfg = [augments_cfg]
|
||||
|
||||
assert len(augments_cfg) > 0, \
|
||||
'The length of augments_cfg should be positive.'
|
||||
self.augments = [build_augment(cfg) for cfg in augments_cfg]
|
||||
self.augment_probs = [aug.prob for aug in self.augments]
|
||||
|
||||
has_identity = any([cfg['type'] == 'Identity' for cfg in augments_cfg])
|
||||
if has_identity:
|
||||
assert sum(self.augment_probs) == 1.0,\
|
||||
'The sum of augmentation probabilities should equal to 1,' \
|
||||
' but got {:.2f}'.format(sum(self.augment_probs))
|
||||
else:
|
||||
assert sum(self.augment_probs) <= 1.0,\
|
||||
'The sum of augmentation probabilities should less than or ' \
|
||||
'equal to 1, but got {:.2f}'.format(sum(self.augment_probs))
|
||||
identity_prob = 1 - sum(self.augment_probs)
|
||||
if identity_prob > 0:
|
||||
num_classes = self.augments[0].num_classes
|
||||
self.augments += [
|
||||
build_augment(
|
||||
dict(
|
||||
type='Identity',
|
||||
num_classes=num_classes,
|
||||
prob=identity_prob))
|
||||
]
|
||||
self.augment_probs += [identity_prob]
|
||||
|
||||
def __call__(self, img, gt_label):
|
||||
if self.augments:
|
||||
random_state = np.random.RandomState(random.randint(0, 2**32 - 1))
|
||||
aug = random_state.choice(self.augments, p=self.augment_probs)
|
||||
return aug(img, gt_label)
|
||||
return img, gt_label
|
|
@ -1,8 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmcv.utils import Registry, build_from_cfg
|
||||
|
||||
AUGMENT = Registry('augment')
|
||||
|
||||
|
||||
def build_augment(cfg, default_args=None):
|
||||
return build_from_cfg(cfg, AUGMENT, default_args)
|
|
@ -1,29 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .builder import AUGMENT
|
||||
from .utils import one_hot_encoding
|
||||
|
||||
|
||||
@AUGMENT.register_module(name='Identity')
|
||||
class Identity(object):
|
||||
"""Change gt_label to one_hot encoding and keep img as the same.
|
||||
|
||||
Args:
|
||||
num_classes (int): The number of classes.
|
||||
prob (float): MixUp probability. It should be in range [0, 1].
|
||||
Default to 1.0
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes, prob=1.0):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
assert isinstance(num_classes, int)
|
||||
assert isinstance(prob, float) and 0.0 <= prob <= 1.0
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.prob = prob
|
||||
|
||||
def one_hot(self, gt_label):
|
||||
return one_hot_encoding(gt_label, self.num_classes)
|
||||
|
||||
def __call__(self, img, gt_label):
|
||||
return img, self.one_hot(gt_label)
|
|
@ -1,80 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .builder import AUGMENT
|
||||
from .utils import one_hot_encoding
|
||||
|
||||
|
||||
class BaseMixupLayer(object, metaclass=ABCMeta):
|
||||
"""Base class for MixupLayer.
|
||||
|
||||
Args:
|
||||
alpha (float): Parameters for Beta distribution to generate the
|
||||
mixing ratio. It should be a positive number.
|
||||
num_classes (int): The number of classes.
|
||||
prob (float): MixUp probability. It should be in range [0, 1].
|
||||
Default to 1.0
|
||||
"""
|
||||
|
||||
def __init__(self, alpha, num_classes, prob=1.0):
|
||||
super(BaseMixupLayer, self).__init__()
|
||||
|
||||
assert isinstance(alpha, float) and alpha > 0
|
||||
assert isinstance(num_classes, int)
|
||||
assert isinstance(prob, float) and 0.0 <= prob <= 1.0
|
||||
|
||||
self.alpha = alpha
|
||||
self.num_classes = num_classes
|
||||
self.prob = prob
|
||||
|
||||
@abstractmethod
|
||||
def mixup(self, imgs, gt_label):
|
||||
pass
|
||||
|
||||
|
||||
@AUGMENT.register_module(name='BatchMixup')
|
||||
class BatchMixupLayer(BaseMixupLayer):
|
||||
r"""Mixup layer for a batch of data.
|
||||
|
||||
Mixup is a method to reduces the memorization of corrupt labels and
|
||||
increases the robustness to adversarial examples. It's
|
||||
proposed in `mixup: Beyond Empirical Risk Minimization
|
||||
<https://arxiv.org/abs/1710.09412>`
|
||||
|
||||
This method simply linearly mix pairs of data and their labels.
|
||||
|
||||
Args:
|
||||
alpha (float): Parameters for Beta distribution to generate the
|
||||
mixing ratio. It should be a positive number. More details
|
||||
are in the note.
|
||||
num_classes (int): The number of classes.
|
||||
prob (float): The probability to execute mixup. It should be in
|
||||
range [0, 1]. Default sto 1.0.
|
||||
|
||||
Note:
|
||||
The :math:`\alpha` (``alpha``) determines a random distribution
|
||||
:math:`Beta(\alpha, \alpha)`. For each batch of data, we sample
|
||||
a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random
|
||||
distribution.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(BatchMixupLayer, self).__init__(*args, **kwargs)
|
||||
|
||||
def mixup(self, img, gt_label):
|
||||
one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
|
||||
lam = np.random.beta(self.alpha, self.alpha)
|
||||
batch_size = img.size(0)
|
||||
index = torch.randperm(batch_size)
|
||||
|
||||
mixed_img = lam * img + (1 - lam) * img[index, :]
|
||||
mixed_gt_label = lam * one_hot_gt_label + (
|
||||
1 - lam) * one_hot_gt_label[index, :]
|
||||
|
||||
return mixed_img, mixed_gt_label
|
||||
|
||||
def __call__(self, img, gt_label):
|
||||
return self.mixup(img, gt_label)
|
|
@ -1,24 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def one_hot_encoding(gt, num_classes):
|
||||
"""Change gt_label to one_hot encoding.
|
||||
|
||||
If the shape has 2 or more
|
||||
dimensions, return it without encoding.
|
||||
Args:
|
||||
gt (Tensor): The gt label with shape (N,) or shape (N, */).
|
||||
num_classes (int): The number of classes.
|
||||
Return:
|
||||
Tensor: One hot gt label.
|
||||
"""
|
||||
if gt.ndim == 1:
|
||||
# multi-class classification
|
||||
return F.one_hot(gt, num_classes=num_classes)
|
||||
else:
|
||||
# binary classification
|
||||
# example. [[0], [1], [1]]
|
||||
# multi-label classification
|
||||
# example. [[0, 1, 1], [1, 0, 0], [1, 1, 1]]
|
||||
return gt
|
|
@@ -0,0 +1,7 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .cutmix import CutMix
+from .mixup import Mixup
+from .resizemix import ResizeMix
+from .wrapper import RandomBatchAugment
+
+__all__ = ('RandomBatchAugment', 'CutMix', 'Mixup', 'ResizeMix')
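The new package gathers the batch augmentations under mmcls.models.utils.batch_augments. A minimal usage sketch, assuming the import paths used by this commit's unit tests (ClsDataSample from mmcls.core, Mixup/CutMix/RandomBatchAugment re-exported from mmcls.models):

import torch

from mmcls.core import ClsDataSample
from mmcls.models import RandomBatchAugment

# Apply Mixup with probability 0.5, CutMix with 0.3, and nothing otherwise.
batch_augment = RandomBatchAugment(
    augments=[
        dict(type='Mixup', alpha=0.8, num_classes=10),
        dict(type='CutMix', alpha=1.0, num_classes=10),
    ],
    probs=[0.5, 0.3])

inputs = torch.rand(4, 3, 32, 32)
data_samples = [ClsDataSample().set_gt_label(i % 10) for i in range(4)]
inputs, data_samples = batch_augment(inputs, data_samples)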
@ -1,43 +1,57 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .builder import AUGMENT
|
||||
from .utils import one_hot_encoding
|
||||
from mmcls.registry import BATCH_AUGMENTS
|
||||
from .mixup import Mixup
|
||||
|
||||
|
||||
class BaseCutMixLayer(object, metaclass=ABCMeta):
|
||||
"""Base class for CutMixLayer.
|
||||
@BATCH_AUGMENTS.register_module()
|
||||
class CutMix(Mixup):
|
||||
r"""CutMix batch agumentation.
|
||||
|
||||
CutMix is a method to improve the network's generalization capability. It's
|
||||
proposed in `CutMix: Regularization Strategy to Train Strong Classifiers
|
||||
with Localizable Features <https://arxiv.org/abs/1905.04899>`
|
||||
|
||||
With this method, patches are cut and pasted among training images where
|
||||
the ground truth labels are also mixed proportionally to the area of the
|
||||
patches.
|
||||
|
||||
Args:
|
||||
alpha (float): Parameters for Beta distribution. Positive(>0)
|
||||
num_classes (int): The number of classes
|
||||
prob (float): MixUp probability. It should be in range [0, 1].
|
||||
Default to 1.0
|
||||
cutmix_minmax (List[float], optional): cutmix min/max image ratio.
|
||||
(as percent of image size). When cutmix_minmax is not None, we
|
||||
generate cutmix bounding-box using cutmix_minmax instead of alpha
|
||||
alpha (float): Parameters for Beta distribution to generate the
|
||||
mixing ratio. It should be a positive number. More details
|
||||
can be found in :class:`BatchMixupLayer`.
|
||||
num_classes (int, optional): The number of classes. If not specified,
|
||||
will try to get it from data samples during training.
|
||||
Defaults to None.
|
||||
cutmix_minmax (List[float], optional): The min/max area ratio of the
|
||||
patches. If not None, the bounding-box of patches is uniform
|
||||
sampled within this ratio range, and the ``alpha`` will be ignored.
|
||||
Otherwise, the bounding-box is generated according to the
|
||||
``alpha``. Defaults to None.
|
||||
correct_lam (bool): Whether to apply lambda correction when cutmix bbox
|
||||
clipped by image borders. Default to True
|
||||
clipped by image borders. Defaults to True.
|
||||
|
||||
.. note ::
|
||||
If the ``cutmix_minmax`` is None, how to generate the bounding-box of
|
||||
patches according to the ``alpha``?
|
||||
|
||||
First, generate a :math:`\lambda`, details can be found in
|
||||
:class:`Mixup`. And then, the area ratio of the bounding-box
|
||||
is calculated by:
|
||||
|
||||
.. math::
|
||||
\text{ratio} = \sqrt{1-\lambda}
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
alpha,
|
||||
num_classes,
|
||||
prob=1.0,
|
||||
num_classes=None,
|
||||
cutmix_minmax=None,
|
||||
correct_lam=True):
|
||||
super(BaseCutMixLayer, self).__init__()
|
||||
super().__init__(alpha=alpha, num_classes=num_classes)
|
||||
|
||||
assert isinstance(alpha, float) and alpha > 0
|
||||
assert isinstance(num_classes, int)
|
||||
assert isinstance(prob, float) and 0.0 <= prob <= 1.0
|
||||
|
||||
self.alpha = alpha
|
||||
self.num_classes = num_classes
|
||||
self.prob = prob
|
||||
self.cutmix_minmax = cutmix_minmax
|
||||
self.correct_lam = correct_lam
|
||||
|
||||
|
@ -54,7 +68,7 @@ class BaseCutMixLayer(object, metaclass=ABCMeta):
|
|||
count (int, optional): Number of bbox to generate. Default to None
|
||||
"""
|
||||
assert len(self.cutmix_minmax) == 2
|
||||
img_h, img_w = img_shape[-2:]
|
||||
img_h, img_w = img_shape
|
||||
cut_h = np.random.randint(
|
||||
int(img_h * self.cutmix_minmax[0]),
|
||||
int(img_h * self.cutmix_minmax[1]),
|
||||
|
@ -82,7 +96,7 @@ class BaseCutMixLayer(object, metaclass=ABCMeta):
|
|||
count (int, optional): Number of bbox to generate. Default to None
|
||||
"""
|
||||
ratio = np.sqrt(1 - lam)
|
||||
img_h, img_w = img_shape[-2:]
|
||||
img_h, img_w = img_shape
|
||||
cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
|
||||
margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
|
||||
cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
|
||||
|
@ -107,69 +121,18 @@ class BaseCutMixLayer(object, metaclass=ABCMeta):
|
|||
yl, yu, xl, xu = self.rand_bbox(img_shape, lam, count=count)
|
||||
if self.correct_lam or self.cutmix_minmax is not None:
|
||||
bbox_area = (yu - yl) * (xu - xl)
|
||||
lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
|
||||
lam = 1. - bbox_area / float(img_shape[0] * img_shape[1])
|
||||
return (yl, yu, xl, xu), lam
|
||||
|
||||
@abstractmethod
|
||||
def cutmix(self, imgs, gt_label):
|
||||
pass
|
||||
|
||||
|
||||
@AUGMENT.register_module(name='BatchCutMix')
|
||||
class BatchCutMixLayer(BaseCutMixLayer):
|
||||
r"""CutMix layer for a batch of data.
|
||||
|
||||
CutMix is a method to improve the network's generalization capability. It's
|
||||
proposed in `CutMix: Regularization Strategy to Train Strong Classifiers
|
||||
with Localizable Features <https://arxiv.org/abs/1905.04899>`
|
||||
|
||||
With this method, patches are cut and pasted among training images where
|
||||
the ground truth labels are also mixed proportionally to the area of the
|
||||
patches.
|
||||
|
||||
Args:
|
||||
alpha (float): Parameters for Beta distribution to generate the
|
||||
mixing ratio. It should be a positive number. More details
|
||||
can be found in :class:`BatchMixupLayer`.
|
||||
num_classes (int): The number of classes
|
||||
prob (float): The probability to execute cutmix. It should be in
|
||||
range [0, 1]. Defaults to 1.0.
|
||||
cutmix_minmax (List[float], optional): The min/max area ratio of the
|
||||
patches. If not None, the bounding-box of patches is uniform
|
||||
sampled within this ratio range, and the ``alpha`` will be ignored.
|
||||
Otherwise, the bounding-box is generated according to the
|
||||
``alpha``. Defaults to None.
|
||||
correct_lam (bool): Whether to apply lambda correction when cutmix bbox
|
||||
clipped by image borders. Defaults to True.
|
||||
|
||||
Note:
|
||||
If the ``cutmix_minmax`` is None, how to generate the bounding-box of
|
||||
patches according to the ``alpha``?
|
||||
|
||||
First, generate a :math:`\lambda`, details can be found in
|
||||
:class:`BatchMixupLayer`. And then, the area ratio of the bounding-box
|
||||
is calculated by:
|
||||
|
||||
.. math::
|
||||
\text{ratio} = \sqrt{1-\lambda}
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(BatchCutMixLayer, self).__init__(*args, **kwargs)
|
||||
|
||||
def cutmix(self, img, gt_label):
|
||||
one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
|
||||
def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor):
|
||||
"""Mix the batch inputs and batch one-hot format ground truth."""
|
||||
lam = np.random.beta(self.alpha, self.alpha)
|
||||
batch_size = img.size(0)
|
||||
batch_size = batch_inputs.size(0)
|
||||
img_shape = batch_inputs.shape[-2:]
|
||||
index = torch.randperm(batch_size)
|
||||
|
||||
(bby1, bby2, bbx1,
|
||||
bbx2), lam = self.cutmix_bbox_and_lam(img.shape, lam)
|
||||
img[:, :, bby1:bby2, bbx1:bbx2] = \
|
||||
img[index, :, bby1:bby2, bbx1:bbx2]
|
||||
mixed_gt_label = lam * one_hot_gt_label + (
|
||||
1 - lam) * one_hot_gt_label[index, :]
|
||||
return img, mixed_gt_label
|
||||
(y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam)
|
||||
batch_inputs[:, :, y1:y2, x1:x2] = batch_inputs[index, :, y1:y2, x1:x2]
|
||||
mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :]
|
||||
|
||||
def __call__(self, img, gt_label):
|
||||
return self.cutmix(img, gt_label)
|
||||
return batch_inputs, mixed_score
|
|
@@ -0,0 +1,78 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import numpy as np
+import torch
+from mmengine.data import LabelData
+
+from mmcls.core import ClsDataSample
+from mmcls.registry import BATCH_AUGMENTS
+
+
+@BATCH_AUGMENTS.register_module()
+class Mixup:
+    r"""Mixup batch augmentation.
+
+    Mixup is a method to reduces the memorization of corrupt labels and
+    increases the robustness to adversarial examples. It's proposed in
+    `mixup: Beyond Empirical Risk Minimization
+    <https://arxiv.org/abs/1710.09412>`_
+
+    Args:
+        alpha (float): Parameters for Beta distribution to generate the
+            mixing ratio. It should be a positive number. More details
+            are in the note.
+        num_classes (int, optional): The number of classes. If not specified,
+            will try to get it from data samples during training.
+            Defaults to None.
+
+    Note:
+        The :math:`\alpha` (``alpha``) determines a random distribution
+        :math:`Beta(\alpha, \alpha)`. For each batch of data, we sample
+        a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random
+        distribution.
+    """
+
+    def __init__(self, alpha, num_classes=None):
+        assert isinstance(alpha, float) and alpha > 0
+        assert isinstance(num_classes, int) or num_classes is None
+
+        self.alpha = alpha
+        self.num_classes = num_classes
+
+    def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor):
+        """Mix the batch inputs and batch one-hot format ground truth."""
+        lam = np.random.beta(self.alpha, self.alpha)
+        batch_size = batch_inputs.size(0)
+        index = torch.randperm(batch_size)
+
+        mixed_inputs = lam * batch_inputs + (1 - lam) * batch_inputs[index, :]
+        mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :]
+
+        return mixed_inputs, mixed_score
+
+    def __call__(self, batch_inputs: torch.Tensor,
+                 data_samples: List[ClsDataSample]):
+        """Mix the batch inputs and batch data samples."""
+        assert data_samples is not None, f'{self.__class__.__name__} ' \
+            'requires data_samples. If you only want to inference, please ' \
+            'disable it from preprocessing.'
+
+        if self.num_classes is None and 'num_classes' not in data_samples[0]:
+            raise RuntimeError(
+                'Not specify the `num_classes` and cannot get it from '
+                'data samples. Please specify `num_classes` in the '
+                f'{self.__class__.__name__}.')
+        num_classes = self.num_classes or data_samples[0].get('num_classes')
+
+        batch_score = torch.stack([
+            LabelData.label_to_onehot(sample.gt_label.label, num_classes)
+            for sample in data_samples
+        ])
+
+        mixed_inputs, mixed_score = self.mix(batch_inputs, batch_score)
+
+        for i, sample in enumerate(data_samples):
+            sample.set_gt_score(mixed_score[i])
+
+        return mixed_inputs, data_samples
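For reference, a small sketch of calling the new Mixup directly, mirroring the unit tests added later in this commit (num_classes is passed explicitly here so it does not need to be read from the data samples):

import torch

from mmcls.core import ClsDataSample
from mmcls.models import Mixup

mixup = Mixup(alpha=1.0, num_classes=10)
inputs = torch.rand(2, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]

mixed_inputs, data_samples = mixup(inputs, data_samples)
# Each sample now carries a soft gt score of shape (10,) instead of a hard label.
print(data_samples[0].gt_label.score.shape)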
@ -3,13 +3,12 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmcls.models.utils.augment.builder import AUGMENT
|
||||
from .cutmix import BatchCutMixLayer
|
||||
from .utils import one_hot_encoding
|
||||
from mmcls.registry import BATCH_AUGMENTS
|
||||
from .cutmix import CutMix
|
||||
|
||||
|
||||
@AUGMENT.register_module(name='BatchResizeMix')
|
||||
class BatchResizeMixLayer(BatchCutMixLayer):
|
||||
@BATCH_AUGMENTS.register_module()
|
||||
class ResizeMix(CutMix):
|
||||
r"""ResizeMix Random Paste layer for a batch of data.
|
||||
|
||||
The ResizeMix will resize an image to a small patch and paste it on another
|
||||
|
@ -19,8 +18,10 @@ class BatchResizeMixLayer(BatchCutMixLayer):
|
|||
Args:
|
||||
alpha (float): Parameters for Beta distribution to generate the
|
||||
mixing ratio. It should be a positive number. More details
|
||||
can be found in :class:`BatchMixupLayer`.
|
||||
num_classes (int): The number of classes.
|
||||
can be found in :class:`Mixup`.
|
||||
num_classes (int, optional): The number of classes. If not specified,
|
||||
will try to get it from data samples during training.
|
||||
Defaults to None.
|
||||
lam_min(float): The minimum value of lam. Defaults to 0.1.
|
||||
lam_max(float): The maximum value of lam. Defaults to 0.8.
|
||||
interpolation (str): algorithm used for upsampling:
|
||||
|
@ -35,7 +36,7 @@ class BatchResizeMixLayer(BatchCutMixLayer):
|
|||
``alpha``. Defaults to None.
|
||||
correct_lam (bool): Whether to apply lambda correction when cutmix bbox
|
||||
clipped by image borders. Defaults to True
|
||||
**kwargs: Any other parameters accpeted by :class:`BatchCutMixLayer`.
|
||||
**kwargs: Any other parameters accpeted by :class:`CutMix`.
|
||||
|
||||
Note:
|
||||
The :math:`\lambda` (``lam``) is the mixing ratio. It's a random
|
||||
|
@ -54,40 +55,35 @@ class BatchResizeMixLayer(BatchCutMixLayer):
|
|||
|
||||
def __init__(self,
|
||||
alpha,
|
||||
num_classes,
|
||||
num_classes=None,
|
||||
lam_min: float = 0.1,
|
||||
lam_max: float = 0.8,
|
||||
interpolation='bilinear',
|
||||
prob=1.0,
|
||||
cutmix_minmax=None,
|
||||
correct_lam=True,
|
||||
**kwargs):
|
||||
super(BatchResizeMixLayer, self).__init__(
|
||||
correct_lam=True):
|
||||
super().__init__(
|
||||
alpha=alpha,
|
||||
num_classes=num_classes,
|
||||
prob=prob,
|
||||
cutmix_minmax=cutmix_minmax,
|
||||
correct_lam=correct_lam,
|
||||
**kwargs)
|
||||
correct_lam=correct_lam)
|
||||
self.lam_min = lam_min
|
||||
self.lam_max = lam_max
|
||||
self.interpolation = interpolation
|
||||
|
||||
def cutmix(self, img, gt_label):
|
||||
one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
|
||||
|
||||
def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor):
|
||||
"""Mix the batch inputs and batch one-hot format ground truth."""
|
||||
lam = np.random.beta(self.alpha, self.alpha)
|
||||
lam = lam * (self.lam_max - self.lam_min) + self.lam_min
|
||||
batch_size = img.size(0)
|
||||
img_shape = batch_inputs.shape[-2:]
|
||||
batch_size = batch_inputs.size(0)
|
||||
index = torch.randperm(batch_size)
|
||||
|
||||
(bby1, bby2, bbx1,
|
||||
bbx2), lam = self.cutmix_bbox_and_lam(img.shape, lam)
|
||||
(y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam)
|
||||
batch_inputs[:, :, y1:y2, x1:x2] = F.interpolate(
|
||||
batch_inputs[index],
|
||||
size=(y2 - y1, x2 - x1),
|
||||
mode=self.interpolation,
|
||||
align_corners=False)
|
||||
mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :]
|
||||
|
||||
img[:, :, bby1:bby2, bbx1:bbx2] = F.interpolate(
|
||||
img[index],
|
||||
size=(bby2 - bby1, bbx2 - bbx1),
|
||||
mode=self.interpolation)
|
||||
mixed_gt_label = lam * one_hot_gt_label + (
|
||||
1 - lam) * one_hot_gt_label[index, :]
|
||||
return img, mixed_gt_label
|
||||
return batch_inputs, mixed_score
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Callable, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmcls.registry import BATCH_AUGMENTS
|
||||
|
||||
|
||||
class RandomBatchAugment:
|
||||
"""Randomly choose one batch augmentation to apply.
|
||||
|
||||
Args:
|
||||
augments (Callable | dict | list): configs of batch
|
||||
augmentations.
|
||||
probs (float | List[float] | None): The probabilities of each batch
|
||||
augmentations. If None, choose evenly. Defaults to None.
|
||||
|
||||
Example:
|
||||
>>> augments_cfg = [
|
||||
... dict(type='CutMix', alpha=1., num_classes=10),
|
||||
... dict(type='Mixup', alpha=1., num_classes=10)
|
||||
... ]
|
||||
>>> batch_augment = RandomBatchAugment(augments_cfg, probs=[0.5, 0.3])
|
||||
>>> imgs = torch.randn(16, 3, 32, 32)
|
||||
>>> label = torch.randint(0, 10, (16, ))
|
||||
>>> imgs, label = batch_augment(imgs, label)
|
||||
|
||||
.. note ::
|
||||
|
||||
To decide which batch augmentation will be used, it picks one of
|
||||
``augments`` based on the probabilities. In the example above, the
|
||||
probability to use CutMix is 0.5, to use Mixup is 0.3, and to do
|
||||
nothing is 0.2.
|
||||
"""
|
||||
|
||||
def __init__(self, augments: Union[Callable, dict, list], probs=None):
|
||||
if not isinstance(augments, (tuple, list)):
|
||||
augments = [augments]
|
||||
|
||||
self.augments = []
|
||||
for aug in augments:
|
||||
if isinstance(aug, dict):
|
||||
self.augments.append(BATCH_AUGMENTS.build(aug))
|
||||
else:
|
||||
self.augments.append(aug)
|
||||
|
||||
if isinstance(probs, float):
|
||||
probs = [probs]
|
||||
|
||||
if probs is not None:
|
||||
assert len(augments) == len(probs), \
|
||||
'``augments`` and ``probs`` must have same lengths. ' \
|
||||
f'Got {len(augments)} vs {len(probs)}.'
|
||||
assert sum(probs) <= 1, \
|
||||
'The total probability of batch augments exceeds 1.'
|
||||
self.augments.append(None)
|
||||
probs.append(1 - sum(probs))
|
||||
|
||||
self.probs = probs
|
||||
|
||||
def __call__(self, inputs: torch.Tensor, data_samples: Union[list, None]):
|
||||
"""Randomly apply batch augmentations to the batch inputs and batch
|
||||
data samples."""
|
||||
aug_index = np.random.choice(len(self.augments), p=self.probs)
|
||||
aug = self.augments[aug_index]
|
||||
|
||||
if aug is not None:
|
||||
return aug(inputs, data_samples)
|
||||
else:
|
||||
return inputs, data_samples
|
|
@ -6,13 +6,14 @@ import torch
|
|||
from mmengine.model import BaseDataPreprocessor, stack_batch
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from .batch_augments import RandomBatchAugment
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ClsDataPreprocessor(BaseDataPreprocessor):
|
||||
"""Image pre-processor for classification tasks.
|
||||
|
||||
Comparing with the :class:`mmengine.ImgDataPreprocessor`,
|
||||
Comparing with the :class:`mmengine.model.ImgDataPreprocessor`,
|
||||
|
||||
1. It won't do normalization if ``mean`` is not specified.
|
||||
2. It does normalization and color space conversion after stacking batch.
|
||||
|
@ -39,6 +40,9 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
|
|||
pad_value (Number): The padded pixel value. Defaults to 0.
|
||||
to_rgb (bool): whether to convert image from BGR to RGB.
|
||||
Defaults to False.
|
||||
batch_augments (dict, optional): The batch augmentations settings,
|
||||
including "augments" and "probs". For more details, see
|
||||
:class:`mmcls.models.RandomBatchAugment`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -65,14 +69,16 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
|
|||
else:
|
||||
self._enable_normalize = False
|
||||
|
||||
# TODO: support batch augmentations.
|
||||
self.batch_augments = batch_augments
|
||||
if batch_augments is not None:
|
||||
self.batch_augments = RandomBatchAugment(batch_augments)
|
||||
else:
|
||||
self.batch_augments = None
|
||||
|
||||
def forward(self,
|
||||
data: Sequence[dict],
|
||||
training: bool = False) -> Tuple[torch.Tensor, list]:
|
||||
"""Perform normalization、padding and bgr2rgb conversion based on
|
||||
``BaseDataPreprocessor``.
|
||||
"""Perform normalization, padding, bgr2rgb conversion and batch
|
||||
augmentation based on ``BaseDataPreprocessor``.
|
||||
|
||||
Args:
|
||||
data (Sequence[dict]): data sampled from dataloader.
|
||||
|
@ -98,8 +104,9 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
|
|||
else:
|
||||
batch_inputs = batch_inputs.to(torch.float32)
|
||||
|
||||
# ----- Batch Aug ----
|
||||
if training and self.batch_augments is not None:
|
||||
inputs, batch_data_samples = self.batch_augments(
|
||||
inputs, batch_data_samples)
|
||||
batch_inputs, batch_data_samples = self.batch_augments(
|
||||
batch_inputs, batch_data_samples)
|
||||
|
||||
return batch_inputs, batch_data_samples
|
||||
|
|
|
@@ -1,14 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
-                       MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS,
-                       OPTIM_WRAPPERS, OPTIMIZERS, PARAM_SCHEDULERS,
-                       RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS,
-                       VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS)
+from .registry import (BATCH_AUGMENTS, DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
+                       METRICS, MODEL_WRAPPERS, MODELS,
+                       OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS,
+                       PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS,
+                       TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS,
+                       WEIGHT_INITIALIZERS)
 
 __all__ = [
     'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS',
     'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS',
     'OPTIM_WRAPPERS', 'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS',
     'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS',
-    'VISUALIZERS'
+    'VISUALIZERS', 'BATCH_AUGMENTS'
 ]

@@ -53,6 +53,8 @@ MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
 # manage all kinds of weight initialization modules like `Uniform`
 WEIGHT_INITIALIZERS = Registry(
     'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
+# manage all kinds of batch augmentations like Mixup and CutMix.
+BATCH_AUGMENTS = Registry('batch augment')
 
 # Registries For Optimizer and the related
 # manage all kinds of optimizers like `SGD` and `Adam`
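Because BATCH_AUGMENTS is a standalone registry, downstream code can register custom batch augmentations the same way the built-in ones are registered here. A hedged sketch (the class name below is purely illustrative, not part of this commit):

from mmcls.registry import BATCH_AUGMENTS


@BATCH_AUGMENTS.register_module()
class IdentityBatchAugment:
    """Illustrative no-op batch augmentation; returns the batch unchanged."""

    def __call__(self, batch_inputs, data_samples):
        return batch_inputs, data_samples


aug = BATCH_AUGMENTS.build(dict(type='IdentityBatchAugment'))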
@@ -47,9 +47,7 @@ class TestImageClassifier(TestCase):
         # test set batch augmentation from train_cfg
         cfg = {
             **self.DEFAULT_ARGS, 'train_cfg':
-            dict(augments=[
-                dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.)
-            ])
+            dict(augments=dict(type='Mixup', alpha=1., num_classes=10))
         }
         model: ImageClassifier = MODELS.build(cfg)
         self.assertIsNotNone(model.data_preprocessor.batch_augments)
@ -1,96 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcls.models.utils import Augments
|
||||
|
||||
augment_cfgs = [
|
||||
dict(type='BatchCutMix', alpha=1., prob=1.),
|
||||
dict(type='BatchMixup', alpha=1., prob=1.),
|
||||
dict(type='Identity', prob=1.),
|
||||
dict(type='BatchResizeMix', alpha=1., prob=1.)
|
||||
]
|
||||
|
||||
|
||||
def test_augments():
|
||||
imgs = torch.randn(4, 3, 32, 32)
|
||||
labels = torch.randint(0, 10, (4, ))
|
||||
|
||||
# Test cutmix
|
||||
augments_cfg = dict(type='BatchCutMix', alpha=1., num_classes=10, prob=1.)
|
||||
augs = Augments(augments_cfg)
|
||||
mixed_imgs, mixed_labels = augs(imgs, labels)
|
||||
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
|
||||
assert mixed_labels.shape == torch.Size((4, 10))
|
||||
|
||||
# Test mixup
|
||||
augments_cfg = dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.)
|
||||
augs = Augments(augments_cfg)
|
||||
mixed_imgs, mixed_labels = augs(imgs, labels)
|
||||
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
|
||||
assert mixed_labels.shape == torch.Size((4, 10))
|
||||
|
||||
# Test resizemix
|
||||
augments_cfg = dict(
|
||||
type='BatchResizeMix', alpha=1., num_classes=10, prob=1.)
|
||||
augs = Augments(augments_cfg)
|
||||
mixed_imgs, mixed_labels = augs(imgs, labels)
|
||||
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
|
||||
assert mixed_labels.shape == torch.Size((4, 10))
|
||||
|
||||
# Test cutmixup
|
||||
augments_cfg = [
|
||||
dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
|
||||
dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3)
|
||||
]
|
||||
augs = Augments(augments_cfg)
|
||||
mixed_imgs, mixed_labels = augs(imgs, labels)
|
||||
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
|
||||
assert mixed_labels.shape == torch.Size((4, 10))
|
||||
|
||||
augments_cfg = [
|
||||
dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
|
||||
dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.5)
|
||||
]
|
||||
augs = Augments(augments_cfg)
|
||||
mixed_imgs, mixed_labels = augs(imgs, labels)
|
||||
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
|
||||
assert mixed_labels.shape == torch.Size((4, 10))
|
||||
|
||||
augments_cfg = [
|
||||
dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
|
||||
dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3),
|
||||
dict(type='Identity', num_classes=10, prob=0.2)
|
||||
]
|
||||
augs = Augments(augments_cfg)
|
||||
mixed_imgs, mixed_labels = augs(imgs, labels)
|
||||
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
|
||||
assert mixed_labels.shape == torch.Size((4, 10))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cfg', augment_cfgs)
|
||||
def test_binary_augment(cfg):
|
||||
|
||||
cfg_ = dict(num_classes=1, **cfg)
|
||||
augs = Augments(cfg_)
|
||||
|
||||
imgs = torch.randn(4, 3, 32, 32)
|
||||
labels = torch.randint(0, 2, (4, 1)).float()
|
||||
|
||||
mixed_imgs, mixed_labels = augs(imgs, labels)
|
||||
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
|
||||
assert mixed_labels.shape == torch.Size((4, 1))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cfg', augment_cfgs)
|
||||
def test_multilabel_augment(cfg):
|
||||
|
||||
cfg_ = dict(num_classes=10, **cfg)
|
||||
augs = Augments(cfg_)
|
||||
|
||||
imgs = torch.randn(4, 3, 32, 32)
|
||||
labels = torch.randint(0, 2, (4, 10)).float()
|
||||
|
||||
mixed_imgs, mixed_labels = augs(imgs, labels)
|
||||
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
|
||||
assert mixed_labels.shape == torch.Size((4, 10))
|
|
@ -0,0 +1,246 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmcls.core import ClsDataSample
|
||||
from mmcls.models import Mixup, RandomBatchAugment
|
||||
from mmcls.registry import BATCH_AUGMENTS
|
||||
|
||||
augment_cfgs = [
|
||||
dict(type='BatchCutMix', alpha=1., prob=1.),
|
||||
dict(type='BatchMixup', alpha=1., prob=1.),
|
||||
dict(type='Identity', prob=1.),
|
||||
dict(type='BatchResizeMix', alpha=1., prob=1.)
|
||||
]
|
||||
|
||||
|
||||
class TestRandomBatchAugment(TestCase):
|
||||
|
||||
def test_initialize(self):
|
||||
# test single augmentation
|
||||
augments = dict(type='Mixup', alpha=1.)
|
||||
batch_augments = RandomBatchAugment(augments)
|
||||
self.assertIsInstance(batch_augments.augments, list)
|
||||
self.assertEqual(len(batch_augments.augments), 1)
|
||||
|
||||
# test specify augments with object
|
||||
augments = Mixup(alpha=1.)
|
||||
batch_augments = RandomBatchAugment(augments)
|
||||
self.assertIsInstance(batch_augments.augments, list)
|
||||
self.assertEqual(len(batch_augments.augments), 1)
|
||||
|
||||
# test multiple augmentation
|
||||
augments = [
|
||||
dict(type='Mixup', alpha=1.),
|
||||
dict(type='CutMix', alpha=0.8),
|
||||
]
|
||||
batch_augments = RandomBatchAugment(augments)
|
||||
# mixup, cutmix
|
||||
self.assertEqual(len(batch_augments.augments), 2)
|
||||
self.assertIsNone(batch_augments.probs)
|
||||
|
||||
# test specify probs
|
||||
augments = [
|
||||
dict(type='Mixup', alpha=1.),
|
||||
dict(type='CutMix', alpha=0.8),
|
||||
]
|
||||
batch_augments = RandomBatchAugment(augments, probs=[0.5, 0.3])
|
||||
# mixup, cutmix and None
|
||||
self.assertEqual(len(batch_augments.augments), 3)
|
||||
self.assertAlmostEqual(batch_augments.probs[-1], 0.2)
|
||||
|
||||
# test assertion
|
||||
with self.assertRaisesRegex(AssertionError, 'Got 2 vs 1'):
|
||||
RandomBatchAugment(augments, probs=0.5)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, 'exceeds 1.'):
|
||||
RandomBatchAugment(augments, probs=[0.5, 0.6])
|
||||
|
||||
def test_call(self):
|
||||
inputs = torch.rand(2, 3, 224, 224)
|
||||
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
|
||||
|
||||
augments = [
|
||||
dict(type='Mixup', alpha=1.),
|
||||
dict(type='CutMix', alpha=0.8),
|
||||
]
|
||||
batch_augments = RandomBatchAugment(augments, probs=[0.5, 0.3])
|
||||
|
||||
with patch('numpy.random', np.random.RandomState(0)):
|
||||
batch_augments.augments[1] = MagicMock()
|
||||
batch_augments(inputs, data_samples)
|
||||
batch_augments.augments[1].assert_called_once_with(
|
||||
inputs, data_samples)
|
||||
|
||||
augments = [
|
||||
dict(type='Mixup', alpha=1.),
|
||||
dict(type='CutMix', alpha=0.8),
|
||||
]
|
||||
batch_augments = RandomBatchAugment(augments, probs=[0.0, 0.0])
|
||||
mixed_inputs, mixed_samples = batch_augments(inputs, data_samples)
|
||||
self.assertIs(mixed_inputs, inputs)
|
||||
self.assertIs(mixed_samples, data_samples)
|
||||
|
||||
|
||||
class TestMixup(TestCase):
|
||||
DEFAULT_ARGS = dict(type='Mixup', alpha=1.)
|
||||
|
||||
def test_initialize(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'}
|
||||
BATCH_AUGMENTS.build(cfg)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'}
|
||||
BATCH_AUGMENTS.build(cfg)
|
||||
|
||||
def test_call(self):
|
||||
inputs = torch.rand(2, 3, 224, 224)
|
||||
data_samples = [
|
||||
ClsDataSample(metainfo={
|
||||
'num_classes': 10
|
||||
}).set_gt_label(1) for _ in range(2)
|
||||
]
|
||||
|
||||
# test get num_classes from data_samples
|
||||
mixup = BATCH_AUGMENTS.build(self.DEFAULT_ARGS)
|
||||
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
|
||||
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
|
||||
self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, 'Not specify'):
|
||||
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
|
||||
mixup(inputs, data_samples)
|
||||
|
||||
# test binary classification
|
||||
cfg = {**self.DEFAULT_ARGS, 'num_classes': 1}
|
||||
mixup = BATCH_AUGMENTS.build(cfg)
|
||||
data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)]
|
||||
|
||||
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
|
||||
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
|
||||
self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, ))
|
||||
|
||||
# test multi-label classification
|
||||
cfg = {**self.DEFAULT_ARGS, 'num_classes': 5}
|
||||
mixup = BATCH_AUGMENTS.build(cfg)
|
||||
data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)]
|
||||
|
||||
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
|
||||
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
|
||||
self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, ))
|
||||
|
||||
|
||||
class TestCutMix(TestCase):
|
||||
DEFAULT_ARGS = dict(type='CutMix', alpha=1.)
|
||||
|
||||
def test_initialize(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'}
|
||||
BATCH_AUGMENTS.build(cfg)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'}
|
||||
BATCH_AUGMENTS.build(cfg)
|
||||
|
||||
def test_call(self):
|
||||
inputs = torch.rand(2, 3, 224, 224)
|
||||
data_samples = [
|
||||
ClsDataSample(metainfo={
|
||||
'num_classes': 10
|
||||
}).set_gt_label(1) for _ in range(2)
|
||||
]
|
||||
|
||||
# test with cutmix_minmax
|
||||
cfg = {**self.DEFAULT_ARGS, 'cutmix_minmax': (0.1, 0.2)}
|
||||
cutmix = BATCH_AUGMENTS.build(cfg)
|
||||
mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
|
||||
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
|
||||
self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
|
||||
|
||||
# test without correct_lam
|
||||
cfg = {**self.DEFAULT_ARGS, 'correct_lam': False}
|
||||
cutmix = BATCH_AUGMENTS.build(cfg)
|
||||
mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
|
||||
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
|
||||
self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
|
||||
|
||||
# test get num_classes from data_samples
|
||||
cutmix = BATCH_AUGMENTS.build(self.DEFAULT_ARGS)
|
||||
mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
|
||||
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
|
||||
self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, 'Not specify'):
|
||||
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
|
||||
cutmix(inputs, data_samples)
|
||||
|
||||
# test binary classification
|
||||
cfg = {**self.DEFAULT_ARGS, 'num_classes': 1}
|
||||
cutmix = BATCH_AUGMENTS.build(cfg)
|
||||
data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)]
|
||||
|
||||
mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
|
||||
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
|
||||
self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, ))
|
||||
|
||||
# test multi-label classification
|
||||
cfg = {**self.DEFAULT_ARGS, 'num_classes': 5}
|
||||
cutmix = BATCH_AUGMENTS.build(cfg)
|
||||
data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)]
|
||||
|
||||
mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
|
||||
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
|
||||
self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, ))
|
||||
|
||||
|
||||
class TestResizeMix(TestCase):
|
||||
DEFAULT_ARGS = dict(type='ResizeMix', alpha=1.)
|
||||
|
||||
def test_initialize(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'}
|
||||
BATCH_AUGMENTS.build(cfg)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'}
|
||||
BATCH_AUGMENTS.build(cfg)
|
||||
|
||||
def test_call(self):
|
||||
inputs = torch.rand(2, 3, 224, 224)
|
||||
data_samples = [
|
||||
ClsDataSample(metainfo={
|
||||
'num_classes': 10
|
||||
}).set_gt_label(1) for _ in range(2)
|
||||
]
|
||||
|
||||
# test get num_classes from data_samples
|
||||
mixup = BATCH_AUGMENTS.build(self.DEFAULT_ARGS)
|
||||
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
|
||||
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
|
||||
self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, 'Not specify'):
|
||||
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
|
||||
mixup(inputs, data_samples)
|
||||
|
||||
# test binary classification
|
||||
cfg = {**self.DEFAULT_ARGS, 'num_classes': 1}
|
||||
mixup = BATCH_AUGMENTS.build(cfg)
|
||||
data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)]
|
||||
|
||||
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
|
||||
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
|
||||
self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, ))
|
||||
|
||||
# test multi-label classification
|
||||
cfg = {**self.DEFAULT_ARGS, 'num_classes': 5}
|
||||
mixup = BATCH_AUGMENTS.build(cfg)
|
||||
data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)]
|
||||
|
||||
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
|
||||
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
|
||||
self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, ))
|
|
@ -4,7 +4,7 @@ from unittest import TestCase
|
|||
import torch
|
||||
|
||||
from mmcls.core import ClsDataSample
|
||||
from mmcls.models import ClsDataPreprocessor
|
||||
from mmcls.models import ClsDataPreprocessor, RandomBatchAugment
|
||||
from mmcls.registry import MODELS
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
|
@ -62,16 +62,25 @@ class TestClsDataPreprocessor(TestCase):
|
|||
self.assertIsNone(data_samples)
|
||||
|
||||
def test_batch_augmentation(self):
|
||||
# TODO: complete this test after refactoring batch augmentation
|
||||
|
||||
cfg = dict(
|
||||
type='ClsDataPreprocessor',
|
||||
batch_augments=[
|
||||
dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.)
|
||||
dict(type='Mixup', alpha=0.8, num_classes=10),
|
||||
dict(type='CutMix', alpha=1., num_classes=10)
|
||||
])
|
||||
processor: ClsDataPreprocessor = MODELS.build(cfg)
|
||||
self.assertIsNotNone(processor.batch_augments)
|
||||
self.assertIsInstance(processor.batch_augments, RandomBatchAugment)
|
||||
data = [{
|
||||
'inputs': torch.randint(0, 256, (3, 224, 224)),
|
||||
'data_sample': ClsDataSample().set_gt_label(1)
|
||||
}]
|
||||
_, data_samples = processor(data, training=True)
|
||||
|
||||
cfg['batch_augments'] = None
|
||||
processor: ClsDataPreprocessor = MODELS.build(cfg)
|
||||
self.assertIsNone(processor.batch_augments)
|
||||
data = [{
|
||||
'inputs': torch.randint(0, 256, (3, 224, 224)),
|
||||
}]
|
||||
_, data_samples = processor(data, training=True)
|
||||
self.assertIsNone(data_samples)
|
||||
|
|