diff --git a/configs/_base_/models/conformer/base-p16.py b/configs/_base_/models/conformer/base-p16.py index 157dcc98..f840914a 100644 --- a/configs/_base_/models/conformer/base-p16.py +++ b/configs/_base_/models/conformer/base-p16.py @@ -17,6 +17,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) ], train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), - dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) - ])) + dict(type='Mixup', alpha=0.8, num_classes=1000), + dict(type='CutMix', alpha=1.0, num_classes=1000) + ]), +) diff --git a/configs/_base_/models/conformer/small-p16.py b/configs/_base_/models/conformer/small-p16.py index 17298089..a913abcb 100644 --- a/configs/_base_/models/conformer/small-p16.py +++ b/configs/_base_/models/conformer/small-p16.py @@ -17,6 +17,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) ], train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), - dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) - ])) + dict(type='Mixup', alpha=0.8, num_classes=1000), + dict(type='CutMix', alpha=1.0, num_classes=1000) + ]), +) diff --git a/configs/_base_/models/conformer/small-p32.py b/configs/_base_/models/conformer/small-p32.py index 593aba12..bd9c00b4 100644 --- a/configs/_base_/models/conformer/small-p32.py +++ b/configs/_base_/models/conformer/small-p32.py @@ -21,6 +21,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) ], train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), - dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) - ])) + dict(type='Mixup', alpha=0.8, num_classes=1000), + dict(type='CutMix', alpha=1.0, num_classes=1000) + ]), +) diff --git a/configs/_base_/models/conformer/tiny-p16.py b/configs/_base_/models/conformer/tiny-p16.py index dad8ecae..1edb3388 100644 --- a/configs/_base_/models/conformer/tiny-p16.py +++ b/configs/_base_/models/conformer/tiny-p16.py @@ -17,6 +17,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) 
], train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), - dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) - ])) + dict(type='Mixup', alpha=0.8, num_classes=1000), + dict(type='CutMix', alpha=1.0, num_classes=1000) + ]), +) diff --git a/configs/_base_/models/repvgg-B3_lbs-mixup_in1k.py b/configs/_base_/models/repvgg-B3_lbs-mixup_in1k.py index 5bb07db5..05897733 100644 --- a/configs/_base_/models/repvgg-B3_lbs-mixup_in1k.py +++ b/configs/_base_/models/repvgg-B3_lbs-mixup_in1k.py @@ -18,6 +18,5 @@ model = dict( num_classes=1000), topk=(1, 5), ), - train_cfg=dict( - augments=dict(type='BatchMixup', alpha=0.2, num_classes=1000, - prob=1.))) + train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)), +) diff --git a/configs/_base_/models/resnet50_cifar_mixup.py b/configs/_base_/models/resnet50_cifar_mixup.py index 3de14f3f..7805cbd7 100644 --- a/configs/_base_/models/resnet50_cifar_mixup.py +++ b/configs/_base_/models/resnet50_cifar_mixup.py @@ -13,5 +13,5 @@ model = dict( num_classes=10, in_channels=2048, loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)), - train_cfg=dict( - augments=dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.))) + train_cfg=dict(augments=dict(type='Mixup', alpha=1., num_classes=10)), +) diff --git a/configs/_base_/models/resnet50_mixup.py b/configs/_base_/models/resnet50_mixup.py index 8ff95226..8a783a1e 100644 --- a/configs/_base_/models/resnet50_mixup.py +++ b/configs/_base_/models/resnet50_mixup.py @@ -13,6 +13,5 @@ model = dict( num_classes=1000, in_channels=2048, loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)), - train_cfg=dict( - augments=dict(type='BatchMixup', alpha=0.2, num_classes=1000, - prob=1.))) + train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)), +) diff --git a/configs/_base_/models/swin_transformer/base_224.py b/configs/_base_/models/swin_transformer/base_224.py index e16b4e60..28d5b529 100644 --- a/configs/_base_/models/swin_transformer/base_224.py +++ b/configs/_base_/models/swin_transformer/base_224.py @@ -17,6 +17,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) ], train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), - dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) - ])) + dict(type='Mixup', alpha=0.8, num_classes=1000), + dict(type='CutMix', alpha=1.0, num_classes=1000) + ]), +) diff --git a/configs/_base_/models/swin_transformer/small_224.py b/configs/_base_/models/swin_transformer/small_224.py index 78739866..ea4e070e 100644 --- a/configs/_base_/models/swin_transformer/small_224.py +++ b/configs/_base_/models/swin_transformer/small_224.py @@ -18,6 +18,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) ], train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), - dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) - ])) + dict(type='Mixup', alpha=0.8, num_classes=1000), + dict(type='CutMix', alpha=1.0, num_classes=1000) + ]), +) diff --git a/configs/_base_/models/swin_transformer/tiny_224.py b/configs/_base_/models/swin_transformer/tiny_224.py index 2d68d66b..feddc581 100644 --- a/configs/_base_/models/swin_transformer/tiny_224.py +++ b/configs/_base_/models/swin_transformer/tiny_224.py @@ -17,6 +17,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) 
], train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), - dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) - ])) + dict(type='Mixup', alpha=0.8, num_classes=1000), + dict(type='CutMix', alpha=1.0, num_classes=1000) + ]), +) diff --git a/configs/_base_/models/t2t-vit-t-14.py b/configs/_base_/models/t2t-vit-t-14.py index 91dbb676..3f4c8603 100644 --- a/configs/_base_/models/t2t-vit-t-14.py +++ b/configs/_base_/models/t2t-vit-t-14.py @@ -36,6 +36,7 @@ model = dict( topk=(1, 5), init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)), train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes), - dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes), - ])) + dict(type='Mixup', alpha=0.8, num_classes=num_classes), + dict(type='CutMix', alpha=1.0, num_classes=num_classes), + ]), +) diff --git a/configs/_base_/models/t2t-vit-t-19.py b/configs/_base_/models/t2t-vit-t-19.py index 8ab139d6..65e4dc99 100644 --- a/configs/_base_/models/t2t-vit-t-19.py +++ b/configs/_base_/models/t2t-vit-t-19.py @@ -36,6 +36,7 @@ model = dict( topk=(1, 5), init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)), train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes), - dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes), - ])) + dict(type='Mixup', alpha=0.8, num_classes=num_classes), + dict(type='CutMix', alpha=1.0, num_classes=num_classes), + ]), +) diff --git a/configs/_base_/models/t2t-vit-t-24.py b/configs/_base_/models/t2t-vit-t-24.py index 5990960a..f2e60185 100644 --- a/configs/_base_/models/t2t-vit-t-24.py +++ b/configs/_base_/models/t2t-vit-t-24.py @@ -36,6 +36,7 @@ model = dict( topk=(1, 5), init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)), train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes), - dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes), - ])) + dict(type='Mixup', alpha=0.8, num_classes=num_classes), + dict(type='CutMix', alpha=1.0, num_classes=num_classes), + ]), +) diff --git a/configs/_base_/models/twins_pcpvt_base.py b/configs/_base_/models/twins_pcpvt_base.py index 473d7ee8..32c8e4a3 100644 --- a/configs/_base_/models/twins_pcpvt_base.py +++ b/configs/_base_/models/twins_pcpvt_base.py @@ -25,6 +25,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) ], train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), - dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) - ])) + dict(type='Mixup', alpha=0.8, num_classes=1000), + dict(type='CutMix', alpha=1.0, num_classes=1000) + ]), +) diff --git a/configs/_base_/models/twins_svt_base.py b/configs/_base_/models/twins_svt_base.py index cabd3739..974e76b4 100644 --- a/configs/_base_/models/twins_svt_base.py +++ b/configs/_base_/models/twins_svt_base.py @@ -25,6 +25,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) 
], train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), - dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) - ])) + dict(type='Mixup', alpha=0.8, num_classes=1000), + dict(type='CutMix', alpha=1.0, num_classes=1000) + ]), +) diff --git a/configs/_base_/models/van/van_small.py b/configs/_base_/models/van/van_small.py index 320e90af..68228c5c 100644 --- a/configs/_base_/models/van/van_small.py +++ b/configs/_base_/models/van/van_small.py @@ -16,6 +16,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) ], train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), - dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) - ])) + dict(type='Mixup', alpha=0.8, num_classes=1000), + dict(type='CutMix', alpha=1.0, num_classes=1000) + ]), +) diff --git a/configs/_base_/models/van/van_tiny.py b/configs/_base_/models/van/van_tiny.py index 42791ac3..c765dbc9 100644 --- a/configs/_base_/models/van/van_tiny.py +++ b/configs/_base_/models/van/van_tiny.py @@ -16,6 +16,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) ], train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), - dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) - ])) + dict(type='Mixup', alpha=0.8, num_classes=1000), + dict(type='CutMix', alpha=1.0, num_classes=1000) + ]), +) diff --git a/configs/deit/deit-small_pt-4xb256_in1k.py b/configs/deit/deit-small_pt-4xb256_in1k.py index 11336dad..e28d12f3 100644 --- a/configs/deit/deit-small_pt-4xb256_in1k.py +++ b/configs/deit/deit-small_pt-4xb256_in1k.py @@ -27,9 +27,10 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.), ], train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), - dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) - ])) + dict(type='Mixup', alpha=0.8, num_classes=1000), + dict(type='CutMix', alpha=1.0, num_classes=1000) + ]), +) # data settings train_dataloader = dict(batch_size=256) diff --git a/configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py b/configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py index 86d7c2c9..74129565 100644 --- a/configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py +++ b/configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py @@ -18,9 +18,10 @@ model = dict( mode='original', )), train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.2, num_classes=1000, prob=0.5), - dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) - ])) + dict(type='Mixup', alpha=0.2, num_classes=1000), + dict(type='CutMix', alpha=1.0, num_classes=1000) + ]), +) # dataset settings train_dataloader = dict(sampler=dict(type='RepeatAugSampler', shuffle=True)) diff --git a/configs/resnet/resnet50_8xb256-rsb-a2-300e_in1k.py b/configs/resnet/resnet50_8xb256-rsb-a2-300e_in1k.py index 6562a18d..2fb735f2 100644 --- a/configs/resnet/resnet50_8xb256-rsb-a2-300e_in1k.py +++ b/configs/resnet/resnet50_8xb256-rsb-a2-300e_in1k.py @@ -13,8 +13,8 @@ model = dict( ), head=dict(loss=dict(use_sigmoid=True)), train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.1, num_classes=1000, prob=0.5), - dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) + dict(type='Mixup', alpha=0.1, num_classes=1000), + dict(type='CutMix', alpha=1.0, num_classes=1000) ])) # dataset settings diff --git a/configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py b/configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py index 6399226b..e6872a3b 100644 --- 
a/configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py +++ b/configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py @@ -10,9 +10,10 @@ model = dict( backbone=dict(norm_cfg=dict(type='SyncBN', requires_grad=True)), head=dict(loss=dict(use_sigmoid=True)), train_cfg=dict(augments=[ - dict(type='BatchMixup', alpha=0.1, num_classes=1000, prob=0.5), - dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) - ])) + dict(type='Mixup', alpha=0.1, num_classes=1000), + dict(type='CutMix', alpha=1.0, num_classes=1000) + ]), +) # schedule settings optim_wrapper = dict( diff --git a/configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py b/configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py index e1af07b8..0a9e5156 100644 --- a/configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py +++ b/configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py @@ -8,13 +8,8 @@ _base_ = [ # model setting model = dict( head=dict(hidden_dim=3072), - train_cfg=dict( - augments=dict( - type='BatchMixup', - alpha=0.2, - num_classes=1000, - prob=1., - ))) + train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)), +) # schedule setting optim_wrapper = dict(clip_grad=dict(max_norm=1.0)) diff --git a/configs/vision_transformer/vit-base-p32_pt-64xb64_in1k-224.py b/configs/vision_transformer/vit-base-p32_pt-64xb64_in1k-224.py index 815abe6b..83a92fca 100644 --- a/configs/vision_transformer/vit-base-p32_pt-64xb64_in1k-224.py +++ b/configs/vision_transformer/vit-base-p32_pt-64xb64_in1k-224.py @@ -8,13 +8,8 @@ _base_ = [ # model setting model = dict( head=dict(hidden_dim=3072), - train_cfg=dict( - augments=dict( - type='BatchMixup', - alpha=0.2, - num_classes=1000, - prob=1., - ))) + train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)), +) # schedule setting optim_wrapper = dict(clip_grad=dict(max_norm=1.0)) diff --git a/configs/vision_transformer/vit-large-p16_pt-64xb64_in1k-224.py b/configs/vision_transformer/vit-large-p16_pt-64xb64_in1k-224.py index 87e79cb3..0cf9d8e1 100644 --- a/configs/vision_transformer/vit-large-p16_pt-64xb64_in1k-224.py +++ b/configs/vision_transformer/vit-large-p16_pt-64xb64_in1k-224.py @@ -8,13 +8,8 @@ _base_ = [ # model setting model = dict( head=dict(hidden_dim=3072), - train_cfg=dict( - augments=dict( - type='BatchMixup', - alpha=0.2, - num_classes=1000, - prob=1., - ))) + train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)), +) # schedule setting optim_wrapper = dict(clip_grad=dict(max_norm=1.0)) diff --git a/configs/vision_transformer/vit-large-p32_pt-64xb64_in1k-224.py b/configs/vision_transformer/vit-large-p32_pt-64xb64_in1k-224.py index 6bf4e28e..c1b5a3d8 100644 --- a/configs/vision_transformer/vit-large-p32_pt-64xb64_in1k-224.py +++ b/configs/vision_transformer/vit-large-p32_pt-64xb64_in1k-224.py @@ -8,13 +8,8 @@ _base_ = [ # model setting model = dict( head=dict(hidden_dim=3072), - train_cfg=dict( - augments=dict( - type='BatchMixup', - alpha=0.2, - num_classes=1000, - prob=1., - ))) + train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)), +) # schedule setting optim_wrapper = dict(clip_grad=dict(max_norm=1.0)) diff --git a/mmcls/models/utils/__init__.py b/mmcls/models/utils/__init__.py index 645467d2..98ec41a6 100644 --- a/mmcls/models/utils/__init__.py +++ b/mmcls/models/utils/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .attention import MultiheadAttention, ShiftWindowMSA -from .augment.augments import Augments +from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix from .channel_shuffle import channel_shuffle from .data_preprocessor import ClsDataPreprocessor from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed, @@ -14,7 +14,8 @@ from .se_layer import SELayer __all__ = [ 'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer', 'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'PatchEmbed', - 'PatchMerging', 'HybridEmbed', 'Augments', 'ShiftWindowMSA', 'is_tracing', - 'MultiheadAttention', 'ConditionalPositionEncoding', 'resize_pos_embed', - 'resize_relative_position_bias_table', 'ClsDataPreprocessor' + 'PatchMerging', 'HybridEmbed', 'RandomBatchAugment', 'ShiftWindowMSA', + 'is_tracing', 'MultiheadAttention', 'ConditionalPositionEncoding', + 'resize_pos_embed', 'resize_relative_position_bias_table', + 'ClsDataPreprocessor', 'Mixup', 'CutMix', 'ResizeMix' ] diff --git a/mmcls/models/utils/augment/__init__.py b/mmcls/models/utils/augment/__init__.py deleted file mode 100644 index 9f92cd54..00000000 --- a/mmcls/models/utils/augment/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .augments import Augments -from .cutmix import BatchCutMixLayer -from .identity import Identity -from .mixup import BatchMixupLayer -from .resizemix import BatchResizeMixLayer - -__all__ = ('Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer', - 'BatchResizeMixLayer') diff --git a/mmcls/models/utils/augment/augments.py b/mmcls/models/utils/augment/augments.py deleted file mode 100644 index 8455e935..00000000 --- a/mmcls/models/utils/augment/augments.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import random - -import numpy as np - -from .builder import build_augment - - -class Augments(object): - """Data augments. - - We implement some data augmentation methods, such as mixup, cutmix. - - Args: - augments_cfg (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict`): - Config dict of augments - - Example: - >>> augments_cfg = [ - dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5), - dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3) - ] - >>> augments = Augments(augments_cfg) - >>> imgs = torch.randn(16, 3, 32, 32) - >>> label = torch.randint(0, 10, (16, )) - >>> imgs, label = augments(imgs, label) - - To decide which augmentation within Augments block is used - the following rule is applied. - We pick augmentation based on the probabilities. In the example above, - we decide if we should use BatchCutMix with probability 0.5, - BatchMixup 0.3. As Identity is not in augments_cfg, we use Identity with - probability 1 - 0.5 - 0.3 = 0.2. - """ - - def __init__(self, augments_cfg): - super(Augments, self).__init__() - - if isinstance(augments_cfg, dict): - augments_cfg = [augments_cfg] - - assert len(augments_cfg) > 0, \ - 'The length of augments_cfg should be positive.' 
- self.augments = [build_augment(cfg) for cfg in augments_cfg] - self.augment_probs = [aug.prob for aug in self.augments] - - has_identity = any([cfg['type'] == 'Identity' for cfg in augments_cfg]) - if has_identity: - assert sum(self.augment_probs) == 1.0,\ - 'The sum of augmentation probabilities should equal to 1,' \ - ' but got {:.2f}'.format(sum(self.augment_probs)) - else: - assert sum(self.augment_probs) <= 1.0,\ - 'The sum of augmentation probabilities should less than or ' \ - 'equal to 1, but got {:.2f}'.format(sum(self.augment_probs)) - identity_prob = 1 - sum(self.augment_probs) - if identity_prob > 0: - num_classes = self.augments[0].num_classes - self.augments += [ - build_augment( - dict( - type='Identity', - num_classes=num_classes, - prob=identity_prob)) - ] - self.augment_probs += [identity_prob] - - def __call__(self, img, gt_label): - if self.augments: - random_state = np.random.RandomState(random.randint(0, 2**32 - 1)) - aug = random_state.choice(self.augments, p=self.augment_probs) - return aug(img, gt_label) - return img, gt_label diff --git a/mmcls/models/utils/augment/builder.py b/mmcls/models/utils/augment/builder.py deleted file mode 100644 index 5d1205ee..00000000 --- a/mmcls/models/utils/augment/builder.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from mmcv.utils import Registry, build_from_cfg - -AUGMENT = Registry('augment') - - -def build_augment(cfg, default_args=None): - return build_from_cfg(cfg, AUGMENT, default_args) diff --git a/mmcls/models/utils/augment/identity.py b/mmcls/models/utils/augment/identity.py deleted file mode 100644 index ae3a3df5..00000000 --- a/mmcls/models/utils/augment/identity.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .builder import AUGMENT -from .utils import one_hot_encoding - - -@AUGMENT.register_module(name='Identity') -class Identity(object): - """Change gt_label to one_hot encoding and keep img as the same. - - Args: - num_classes (int): The number of classes. - prob (float): MixUp probability. It should be in range [0, 1]. - Default to 1.0 - """ - - def __init__(self, num_classes, prob=1.0): - super(Identity, self).__init__() - - assert isinstance(num_classes, int) - assert isinstance(prob, float) and 0.0 <= prob <= 1.0 - - self.num_classes = num_classes - self.prob = prob - - def one_hot(self, gt_label): - return one_hot_encoding(gt_label, self.num_classes) - - def __call__(self, img, gt_label): - return img, self.one_hot(gt_label) diff --git a/mmcls/models/utils/augment/mixup.py b/mmcls/models/utils/augment/mixup.py deleted file mode 100644 index e8899dd3..00000000 --- a/mmcls/models/utils/augment/mixup.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from abc import ABCMeta, abstractmethod - -import numpy as np -import torch - -from .builder import AUGMENT -from .utils import one_hot_encoding - - -class BaseMixupLayer(object, metaclass=ABCMeta): - """Base class for MixupLayer. - - Args: - alpha (float): Parameters for Beta distribution to generate the - mixing ratio. It should be a positive number. - num_classes (int): The number of classes. - prob (float): MixUp probability. It should be in range [0, 1]. 
- Default to 1.0 - """ - - def __init__(self, alpha, num_classes, prob=1.0): - super(BaseMixupLayer, self).__init__() - - assert isinstance(alpha, float) and alpha > 0 - assert isinstance(num_classes, int) - assert isinstance(prob, float) and 0.0 <= prob <= 1.0 - - self.alpha = alpha - self.num_classes = num_classes - self.prob = prob - - @abstractmethod - def mixup(self, imgs, gt_label): - pass - - -@AUGMENT.register_module(name='BatchMixup') -class BatchMixupLayer(BaseMixupLayer): - r"""Mixup layer for a batch of data. - - Mixup is a method to reduces the memorization of corrupt labels and - increases the robustness to adversarial examples. It's - proposed in `mixup: Beyond Empirical Risk Minimization - ` - - This method simply linearly mix pairs of data and their labels. - - Args: - alpha (float): Parameters for Beta distribution to generate the - mixing ratio. It should be a positive number. More details - are in the note. - num_classes (int): The number of classes. - prob (float): The probability to execute mixup. It should be in - range [0, 1]. Default sto 1.0. - - Note: - The :math:`\alpha` (``alpha``) determines a random distribution - :math:`Beta(\alpha, \alpha)`. For each batch of data, we sample - a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random - distribution. - """ - - def __init__(self, *args, **kwargs): - super(BatchMixupLayer, self).__init__(*args, **kwargs) - - def mixup(self, img, gt_label): - one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes) - lam = np.random.beta(self.alpha, self.alpha) - batch_size = img.size(0) - index = torch.randperm(batch_size) - - mixed_img = lam * img + (1 - lam) * img[index, :] - mixed_gt_label = lam * one_hot_gt_label + ( - 1 - lam) * one_hot_gt_label[index, :] - - return mixed_img, mixed_gt_label - - def __call__(self, img, gt_label): - return self.mixup(img, gt_label) diff --git a/mmcls/models/utils/augment/utils.py b/mmcls/models/utils/augment/utils.py deleted file mode 100644 index e972d54b..00000000 --- a/mmcls/models/utils/augment/utils.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch.nn.functional as F - - -def one_hot_encoding(gt, num_classes): - """Change gt_label to one_hot encoding. - - If the shape has 2 or more - dimensions, return it without encoding. - Args: - gt (Tensor): The gt label with shape (N,) or shape (N, */). - num_classes (int): The number of classes. - Return: - Tensor: One hot gt label. - """ - if gt.ndim == 1: - # multi-class classification - return F.one_hot(gt, num_classes=num_classes) - else: - # binary classification - # example. [[0], [1], [1]] - # multi-label classification - # example. [[0, 1, 1], [1, 0, 0], [1, 1, 1]] - return gt diff --git a/mmcls/models/utils/batch_augments/__init__.py b/mmcls/models/utils/batch_augments/__init__.py new file mode 100644 index 00000000..2fbc4e17 --- /dev/null +++ b/mmcls/models/utils/batch_augments/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .cutmix import CutMix +from .mixup import Mixup +from .resizemix import ResizeMix +from .wrapper import RandomBatchAugment + +__all__ = ('RandomBatchAugment', 'CutMix', 'Mixup', 'ResizeMix') diff --git a/mmcls/models/utils/augment/cutmix.py b/mmcls/models/utils/batch_augments/cutmix.py similarity index 66% rename from mmcls/models/utils/augment/cutmix.py rename to mmcls/models/utils/batch_augments/cutmix.py index 0d8ba9dd..e4b04bdd 100644 --- a/mmcls/models/utils/augment/cutmix.py +++ b/mmcls/models/utils/batch_augments/cutmix.py @@ -1,43 +1,57 @@ # Copyright (c) OpenMMLab. All rights reserved. -from abc import ABCMeta, abstractmethod - import numpy as np import torch -from .builder import AUGMENT -from .utils import one_hot_encoding +from mmcls.registry import BATCH_AUGMENTS +from .mixup import Mixup -class BaseCutMixLayer(object, metaclass=ABCMeta): - """Base class for CutMixLayer. +@BATCH_AUGMENTS.register_module() +class CutMix(Mixup): + r"""CutMix batch agumentation. + + CutMix is a method to improve the network's generalization capability. It's + proposed in `CutMix: Regularization Strategy to Train Strong Classifiers + with Localizable Features ` + + With this method, patches are cut and pasted among training images where + the ground truth labels are also mixed proportionally to the area of the + patches. Args: - alpha (float): Parameters for Beta distribution. Positive(>0) - num_classes (int): The number of classes - prob (float): MixUp probability. It should be in range [0, 1]. - Default to 1.0 - cutmix_minmax (List[float], optional): cutmix min/max image ratio. - (as percent of image size). When cutmix_minmax is not None, we - generate cutmix bounding-box using cutmix_minmax instead of alpha + alpha (float): Parameters for Beta distribution to generate the + mixing ratio. It should be a positive number. More details + can be found in :class:`BatchMixupLayer`. + num_classes (int, optional): The number of classes. If not specified, + will try to get it from data samples during training. + Defaults to None. + cutmix_minmax (List[float], optional): The min/max area ratio of the + patches. If not None, the bounding-box of patches is uniform + sampled within this ratio range, and the ``alpha`` will be ignored. + Otherwise, the bounding-box is generated according to the + ``alpha``. Defaults to None. correct_lam (bool): Whether to apply lambda correction when cutmix bbox - clipped by image borders. Default to True + clipped by image borders. Defaults to True. + + .. note :: + If the ``cutmix_minmax`` is None, how to generate the bounding-box of + patches according to the ``alpha``? + + First, generate a :math:`\lambda`, details can be found in + :class:`Mixup`. And then, the area ratio of the bounding-box + is calculated by: + + .. math:: + \text{ratio} = \sqrt{1-\lambda} """ def __init__(self, alpha, - num_classes, - prob=1.0, + num_classes=None, cutmix_minmax=None, correct_lam=True): - super(BaseCutMixLayer, self).__init__() + super().__init__(alpha=alpha, num_classes=num_classes) - assert isinstance(alpha, float) and alpha > 0 - assert isinstance(num_classes, int) - assert isinstance(prob, float) and 0.0 <= prob <= 1.0 - - self.alpha = alpha - self.num_classes = num_classes - self.prob = prob self.cutmix_minmax = cutmix_minmax self.correct_lam = correct_lam @@ -54,7 +68,7 @@ class BaseCutMixLayer(object, metaclass=ABCMeta): count (int, optional): Number of bbox to generate. 
Default to None """ assert len(self.cutmix_minmax) == 2 - img_h, img_w = img_shape[-2:] + img_h, img_w = img_shape cut_h = np.random.randint( int(img_h * self.cutmix_minmax[0]), int(img_h * self.cutmix_minmax[1]), @@ -82,7 +96,7 @@ class BaseCutMixLayer(object, metaclass=ABCMeta): count (int, optional): Number of bbox to generate. Default to None """ ratio = np.sqrt(1 - lam) - img_h, img_w = img_shape[-2:] + img_h, img_w = img_shape cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) @@ -107,69 +121,18 @@ class BaseCutMixLayer(object, metaclass=ABCMeta): yl, yu, xl, xu = self.rand_bbox(img_shape, lam, count=count) if self.correct_lam or self.cutmix_minmax is not None: bbox_area = (yu - yl) * (xu - xl) - lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1]) + lam = 1. - bbox_area / float(img_shape[0] * img_shape[1]) return (yl, yu, xl, xu), lam - @abstractmethod - def cutmix(self, imgs, gt_label): - pass - - -@AUGMENT.register_module(name='BatchCutMix') -class BatchCutMixLayer(BaseCutMixLayer): - r"""CutMix layer for a batch of data. - - CutMix is a method to improve the network's generalization capability. It's - proposed in `CutMix: Regularization Strategy to Train Strong Classifiers - with Localizable Features ` - - With this method, patches are cut and pasted among training images where - the ground truth labels are also mixed proportionally to the area of the - patches. - - Args: - alpha (float): Parameters for Beta distribution to generate the - mixing ratio. It should be a positive number. More details - can be found in :class:`BatchMixupLayer`. - num_classes (int): The number of classes - prob (float): The probability to execute cutmix. It should be in - range [0, 1]. Defaults to 1.0. - cutmix_minmax (List[float], optional): The min/max area ratio of the - patches. If not None, the bounding-box of patches is uniform - sampled within this ratio range, and the ``alpha`` will be ignored. - Otherwise, the bounding-box is generated according to the - ``alpha``. Defaults to None. - correct_lam (bool): Whether to apply lambda correction when cutmix bbox - clipped by image borders. Defaults to True. - - Note: - If the ``cutmix_minmax`` is None, how to generate the bounding-box of - patches according to the ``alpha``? - - First, generate a :math:`\lambda`, details can be found in - :class:`BatchMixupLayer`. And then, the area ratio of the bounding-box - is calculated by: - - .. 
math:: - \text{ratio} = \sqrt{1-\lambda} - """ - - def __init__(self, *args, **kwargs): - super(BatchCutMixLayer, self).__init__(*args, **kwargs) - - def cutmix(self, img, gt_label): - one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes) + def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor): + """Mix the batch inputs and batch one-hot format ground truth.""" lam = np.random.beta(self.alpha, self.alpha) - batch_size = img.size(0) + batch_size = batch_inputs.size(0) + img_shape = batch_inputs.shape[-2:] index = torch.randperm(batch_size) - (bby1, bby2, bbx1, - bbx2), lam = self.cutmix_bbox_and_lam(img.shape, lam) - img[:, :, bby1:bby2, bbx1:bbx2] = \ - img[index, :, bby1:bby2, bbx1:bbx2] - mixed_gt_label = lam * one_hot_gt_label + ( - 1 - lam) * one_hot_gt_label[index, :] - return img, mixed_gt_label + (y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam) + batch_inputs[:, :, y1:y2, x1:x2] = batch_inputs[index, :, y1:y2, x1:x2] + mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :] - def __call__(self, img, gt_label): - return self.cutmix(img, gt_label) + return batch_inputs, mixed_score diff --git a/mmcls/models/utils/batch_augments/mixup.py b/mmcls/models/utils/batch_augments/mixup.py new file mode 100644 index 00000000..4884849b --- /dev/null +++ b/mmcls/models/utils/batch_augments/mixup.py @@ -0,0 +1,78 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import numpy as np +import torch +from mmengine.data import LabelData + +from mmcls.core import ClsDataSample +from mmcls.registry import BATCH_AUGMENTS + + +@BATCH_AUGMENTS.register_module() +class Mixup: + r"""Mixup batch augmentation. + + Mixup is a method to reduces the memorization of corrupt labels and + increases the robustness to adversarial examples. It's proposed in + `mixup: Beyond Empirical Risk Minimization + `_ + + Args: + alpha (float): Parameters for Beta distribution to generate the + mixing ratio. It should be a positive number. More details + are in the note. + num_classes (int, optional): The number of classes. If not specified, + will try to get it from data samples during training. + Defaults to None. + + Note: + The :math:`\alpha` (``alpha``) determines a random distribution + :math:`Beta(\alpha, \alpha)`. For each batch of data, we sample + a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random + distribution. + """ + + def __init__(self, alpha, num_classes=None): + assert isinstance(alpha, float) and alpha > 0 + assert isinstance(num_classes, int) or num_classes is None + + self.alpha = alpha + self.num_classes = num_classes + + def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor): + """Mix the batch inputs and batch one-hot format ground truth.""" + lam = np.random.beta(self.alpha, self.alpha) + batch_size = batch_inputs.size(0) + index = torch.randperm(batch_size) + + mixed_inputs = lam * batch_inputs + (1 - lam) * batch_inputs[index, :] + mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :] + + return mixed_inputs, mixed_score + + def __call__(self, batch_inputs: torch.Tensor, + data_samples: List[ClsDataSample]): + """Mix the batch inputs and batch data samples.""" + assert data_samples is not None, f'{self.__class__.__name__} ' \ + 'requires data_samples. If you only want to inference, please ' \ + 'disable it from preprocessing.' 
+ + if self.num_classes is None and 'num_classes' not in data_samples[0]: + raise RuntimeError( + 'Not specify the `num_classes` and cannot get it from ' + 'data samples. Please specify `num_classes` in the ' + f'{self.__class__.__name__}.') + num_classes = self.num_classes or data_samples[0].get('num_classes') + + batch_score = torch.stack([ + LabelData.label_to_onehot(sample.gt_label.label, num_classes) + for sample in data_samples + ]) + + mixed_inputs, mixed_score = self.mix(batch_inputs, batch_score) + + for i, sample in enumerate(data_samples): + sample.set_gt_score(mixed_score[i]) + + return mixed_inputs, data_samples diff --git a/mmcls/models/utils/augment/resizemix.py b/mmcls/models/utils/batch_augments/resizemix.py similarity index 67% rename from mmcls/models/utils/augment/resizemix.py rename to mmcls/models/utils/batch_augments/resizemix.py index 1506cc37..ae759687 100644 --- a/mmcls/models/utils/augment/resizemix.py +++ b/mmcls/models/utils/batch_augments/resizemix.py @@ -3,13 +3,12 @@ import numpy as np import torch import torch.nn.functional as F -from mmcls.models.utils.augment.builder import AUGMENT -from .cutmix import BatchCutMixLayer -from .utils import one_hot_encoding +from mmcls.registry import BATCH_AUGMENTS +from .cutmix import CutMix -@AUGMENT.register_module(name='BatchResizeMix') -class BatchResizeMixLayer(BatchCutMixLayer): +@BATCH_AUGMENTS.register_module() +class ResizeMix(CutMix): r"""ResizeMix Random Paste layer for a batch of data. The ResizeMix will resize an image to a small patch and paste it on another @@ -19,8 +18,10 @@ class BatchResizeMixLayer(BatchCutMixLayer): Args: alpha (float): Parameters for Beta distribution to generate the mixing ratio. It should be a positive number. More details - can be found in :class:`BatchMixupLayer`. - num_classes (int): The number of classes. + can be found in :class:`Mixup`. + num_classes (int, optional): The number of classes. If not specified, + will try to get it from data samples during training. + Defaults to None. lam_min(float): The minimum value of lam. Defaults to 0.1. lam_max(float): The maximum value of lam. Defaults to 0.8. interpolation (str): algorithm used for upsampling: @@ -35,7 +36,7 @@ class BatchResizeMixLayer(BatchCutMixLayer): ``alpha``. Defaults to None. correct_lam (bool): Whether to apply lambda correction when cutmix bbox clipped by image borders. Defaults to True - **kwargs: Any other parameters accpeted by :class:`BatchCutMixLayer`. + **kwargs: Any other parameters accpeted by :class:`CutMix`. Note: The :math:`\lambda` (``lam``) is the mixing ratio. 
It's a random @@ -54,40 +55,35 @@ class BatchResizeMixLayer(BatchCutMixLayer): def __init__(self, alpha, - num_classes, + num_classes=None, lam_min: float = 0.1, lam_max: float = 0.8, interpolation='bilinear', - prob=1.0, cutmix_minmax=None, - correct_lam=True, - **kwargs): - super(BatchResizeMixLayer, self).__init__( + correct_lam=True): + super().__init__( alpha=alpha, num_classes=num_classes, - prob=prob, cutmix_minmax=cutmix_minmax, - correct_lam=correct_lam, - **kwargs) + correct_lam=correct_lam) self.lam_min = lam_min self.lam_max = lam_max self.interpolation = interpolation - def cutmix(self, img, gt_label): - one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes) - + def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor): + """Mix the batch inputs and batch one-hot format ground truth.""" lam = np.random.beta(self.alpha, self.alpha) lam = lam * (self.lam_max - self.lam_min) + self.lam_min - batch_size = img.size(0) + img_shape = batch_inputs.shape[-2:] + batch_size = batch_inputs.size(0) index = torch.randperm(batch_size) - (bby1, bby2, bbx1, - bbx2), lam = self.cutmix_bbox_and_lam(img.shape, lam) + (y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam) + batch_inputs[:, :, y1:y2, x1:x2] = F.interpolate( + batch_inputs[index], + size=(y2 - y1, x2 - x1), + mode=self.interpolation, + align_corners=False) + mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :] - img[:, :, bby1:bby2, bbx1:bbx2] = F.interpolate( - img[index], - size=(bby2 - bby1, bbx2 - bbx1), - mode=self.interpolation) - mixed_gt_label = lam * one_hot_gt_label + ( - 1 - lam) * one_hot_gt_label[index, :] - return img, mixed_gt_label + return batch_inputs, mixed_score diff --git a/mmcls/models/utils/batch_augments/wrapper.py b/mmcls/models/utils/batch_augments/wrapper.py new file mode 100644 index 00000000..2759aa69 --- /dev/null +++ b/mmcls/models/utils/batch_augments/wrapper.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, Union + +import numpy as np +import torch + +from mmcls.registry import BATCH_AUGMENTS + + +class RandomBatchAugment: + """Randomly choose one batch augmentation to apply. + + Args: + augments (Callable | dict | list): configs of batch + augmentations. + probs (float | List[float] | None): The probabilities of each batch + augmentations. If None, choose evenly. Defaults to None. + + Example: + >>> augments_cfg = [ + ... dict(type='CutMix', alpha=1., num_classes=10), + ... dict(type='Mixup', alpha=1., num_classes=10) + ... ] + >>> batch_augment = RandomBatchAugment(augments_cfg, probs=[0.5, 0.3]) + >>> imgs = torch.randn(16, 3, 32, 32) + >>> label = torch.randint(0, 10, (16, )) + >>> imgs, label = batch_augment(imgs, label) + + .. note :: + + To decide which batch augmentation will be used, it picks one of + ``augments`` based on the probabilities. In the example above, the + probability to use CutMix is 0.5, to use Mixup is 0.3, and to do + nothing is 0.2. + """ + + def __init__(self, augments: Union[Callable, dict, list], probs=None): + if not isinstance(augments, (tuple, list)): + augments = [augments] + + self.augments = [] + for aug in augments: + if isinstance(aug, dict): + self.augments.append(BATCH_AUGMENTS.build(aug)) + else: + self.augments.append(aug) + + if isinstance(probs, float): + probs = [probs] + + if probs is not None: + assert len(augments) == len(probs), \ + '``augments`` and ``probs`` must have same lengths. ' \ + f'Got {len(augments)} vs {len(probs)}.' 
+ assert sum(probs) <= 1, \ + 'The total probability of batch augments exceeds 1.' + self.augments.append(None) + probs.append(1 - sum(probs)) + + self.probs = probs + + def __call__(self, inputs: torch.Tensor, data_samples: Union[list, None]): + """Randomly apply batch augmentations to the batch inputs and batch + data samples.""" + aug_index = np.random.choice(len(self.augments), p=self.probs) + aug = self.augments[aug_index] + + if aug is not None: + return aug(inputs, data_samples) + else: + return inputs, data_samples diff --git a/mmcls/models/utils/data_preprocessor.py b/mmcls/models/utils/data_preprocessor.py index 4fd26264..d869eebd 100644 --- a/mmcls/models/utils/data_preprocessor.py +++ b/mmcls/models/utils/data_preprocessor.py @@ -6,13 +6,14 @@ import torch from mmengine.model import BaseDataPreprocessor, stack_batch from mmcls.registry import MODELS +from .batch_augments import RandomBatchAugment @MODELS.register_module() class ClsDataPreprocessor(BaseDataPreprocessor): """Image pre-processor for classification tasks. - Comparing with the :class:`mmengine.ImgDataPreprocessor`, + Comparing with the :class:`mmengine.model.ImgDataPreprocessor`, 1. It won't do normalization if ``mean`` is not specified. 2. It does normalization and color space conversion after stacking batch. @@ -39,6 +40,9 @@ class ClsDataPreprocessor(BaseDataPreprocessor): pad_value (Number): The padded pixel value. Defaults to 0. to_rgb (bool): whether to convert image from BGR to RGB. Defaults to False. + batch_augments (dict, optional): The batch augmentations settings, + including "augments" and "probs". For more details, see + :class:`mmcls.models.RandomBatchAugment`. """ def __init__(self, @@ -65,14 +69,16 @@ class ClsDataPreprocessor(BaseDataPreprocessor): else: self._enable_normalize = False - # TODO: support batch augmentations. - self.batch_augments = batch_augments + if batch_augments is not None: + self.batch_augments = RandomBatchAugment(batch_augments) + else: + self.batch_augments = None def forward(self, data: Sequence[dict], training: bool = False) -> Tuple[torch.Tensor, list]: - """Perform normalization、padding and bgr2rgb conversion based on - ``BaseDataPreprocessor``. + """Perform normalization, padding, bgr2rgb conversion and batch + augmentation based on ``BaseDataPreprocessor``. Args: data (Sequence[dict]): data sampled from dataloader. @@ -98,8 +104,9 @@ class ClsDataPreprocessor(BaseDataPreprocessor): else: batch_inputs = batch_inputs.to(torch.float32) + # ----- Batch Aug ---- if training and self.batch_augments is not None: - inputs, batch_data_samples = self.batch_augments( - inputs, batch_data_samples) + batch_inputs, batch_data_samples = self.batch_augments( + batch_inputs, batch_data_samples) return batch_inputs, batch_data_samples diff --git a/mmcls/registry/__init__.py b/mmcls/registry/__init__.py index df038768..f074e69a 100644 --- a/mmcls/registry/__init__.py +++ b/mmcls/registry/__init__.py @@ -1,14 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS, - MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS, - OPTIM_WRAPPERS, OPTIMIZERS, PARAM_SCHEDULERS, - RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS, - VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS) +from .registry import (BATCH_AUGMENTS, DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, + METRICS, MODEL_WRAPPERS, MODELS, + OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS, + PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, + TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS, + WEIGHT_INITIALIZERS) __all__ = [ 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'OPTIM_WRAPPERS', 'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', - 'VISUALIZERS' + 'VISUALIZERS', 'BATCH_AUGMENTS' ] diff --git a/mmcls/registry/registry.py b/mmcls/registry/registry.py index b108cd0e..e524f9a6 100644 --- a/mmcls/registry/registry.py +++ b/mmcls/registry/registry.py @@ -53,6 +53,8 @@ MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS) # manage all kinds of weight initialization modules like `Uniform` WEIGHT_INITIALIZERS = Registry( 'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS) +# manage all kinds of batch augmentations like Mixup and CutMix. +BATCH_AUGMENTS = Registry('batch augment') # Registries For Optimizer and the related # manage all kinds of optimizers like `SGD` and `Adam` diff --git a/tests/test_models/test_classifiers.py b/tests/test_models/test_classifiers.py index be8b39af..dd554bff 100644 --- a/tests/test_models/test_classifiers.py +++ b/tests/test_models/test_classifiers.py @@ -47,9 +47,7 @@ class TestImageClassifier(TestCase): # test set batch augmentation from train_cfg cfg = { **self.DEFAULT_ARGS, 'train_cfg': - dict(augments=[ - dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.) - ]) + dict(augments=dict(type='Mixup', alpha=1., num_classes=10)) } model: ImageClassifier = MODELS.build(cfg) self.assertIsNotNone(model.data_preprocessor.batch_augments) diff --git a/tests/test_models/test_utils/test_augment.py b/tests/test_models/test_utils/test_augment.py deleted file mode 100644 index d1987fae..00000000 --- a/tests/test_models/test_utils/test_augment.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import pytest -import torch - -from mmcls.models.utils import Augments - -augment_cfgs = [ - dict(type='BatchCutMix', alpha=1., prob=1.), - dict(type='BatchMixup', alpha=1., prob=1.), - dict(type='Identity', prob=1.), - dict(type='BatchResizeMix', alpha=1., prob=1.) -] - - -def test_augments(): - imgs = torch.randn(4, 3, 32, 32) - labels = torch.randint(0, 10, (4, )) - - # Test cutmix - augments_cfg = dict(type='BatchCutMix', alpha=1., num_classes=10, prob=1.) - augs = Augments(augments_cfg) - mixed_imgs, mixed_labels = augs(imgs, labels) - assert mixed_imgs.shape == torch.Size((4, 3, 32, 32)) - assert mixed_labels.shape == torch.Size((4, 10)) - - # Test mixup - augments_cfg = dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.) - augs = Augments(augments_cfg) - mixed_imgs, mixed_labels = augs(imgs, labels) - assert mixed_imgs.shape == torch.Size((4, 3, 32, 32)) - assert mixed_labels.shape == torch.Size((4, 10)) - - # Test resizemix - augments_cfg = dict( - type='BatchResizeMix', alpha=1., num_classes=10, prob=1.) 
- augs = Augments(augments_cfg) - mixed_imgs, mixed_labels = augs(imgs, labels) - assert mixed_imgs.shape == torch.Size((4, 3, 32, 32)) - assert mixed_labels.shape == torch.Size((4, 10)) - - # Test cutmixup - augments_cfg = [ - dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5), - dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3) - ] - augs = Augments(augments_cfg) - mixed_imgs, mixed_labels = augs(imgs, labels) - assert mixed_imgs.shape == torch.Size((4, 3, 32, 32)) - assert mixed_labels.shape == torch.Size((4, 10)) - - augments_cfg = [ - dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5), - dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.5) - ] - augs = Augments(augments_cfg) - mixed_imgs, mixed_labels = augs(imgs, labels) - assert mixed_imgs.shape == torch.Size((4, 3, 32, 32)) - assert mixed_labels.shape == torch.Size((4, 10)) - - augments_cfg = [ - dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5), - dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3), - dict(type='Identity', num_classes=10, prob=0.2) - ] - augs = Augments(augments_cfg) - mixed_imgs, mixed_labels = augs(imgs, labels) - assert mixed_imgs.shape == torch.Size((4, 3, 32, 32)) - assert mixed_labels.shape == torch.Size((4, 10)) - - -@pytest.mark.parametrize('cfg', augment_cfgs) -def test_binary_augment(cfg): - - cfg_ = dict(num_classes=1, **cfg) - augs = Augments(cfg_) - - imgs = torch.randn(4, 3, 32, 32) - labels = torch.randint(0, 2, (4, 1)).float() - - mixed_imgs, mixed_labels = augs(imgs, labels) - assert mixed_imgs.shape == torch.Size((4, 3, 32, 32)) - assert mixed_labels.shape == torch.Size((4, 1)) - - -@pytest.mark.parametrize('cfg', augment_cfgs) -def test_multilabel_augment(cfg): - - cfg_ = dict(num_classes=10, **cfg) - augs = Augments(cfg_) - - imgs = torch.randn(4, 3, 32, 32) - labels = torch.randint(0, 2, (4, 10)).float() - - mixed_imgs, mixed_labels = augs(imgs, labels) - assert mixed_imgs.shape == torch.Size((4, 3, 32, 32)) - assert mixed_labels.shape == torch.Size((4, 10)) diff --git a/tests/test_models/test_utils/test_batch_augments.py b/tests/test_models/test_utils/test_batch_augments.py new file mode 100644 index 00000000..2a84b99c --- /dev/null +++ b/tests/test_models/test_utils/test_batch_augments.py @@ -0,0 +1,246 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase +from unittest.mock import MagicMock, patch + +import numpy as np +import torch + +from mmcls.core import ClsDataSample +from mmcls.models import Mixup, RandomBatchAugment +from mmcls.registry import BATCH_AUGMENTS + +augment_cfgs = [ + dict(type='BatchCutMix', alpha=1., prob=1.), + dict(type='BatchMixup', alpha=1., prob=1.), + dict(type='Identity', prob=1.), + dict(type='BatchResizeMix', alpha=1., prob=1.) +] + + +class TestRandomBatchAugment(TestCase): + + def test_initialize(self): + # test single augmentation + augments = dict(type='Mixup', alpha=1.) + batch_augments = RandomBatchAugment(augments) + self.assertIsInstance(batch_augments.augments, list) + self.assertEqual(len(batch_augments.augments), 1) + + # test specify augments with object + augments = Mixup(alpha=1.) 
+ batch_augments = RandomBatchAugment(augments) + self.assertIsInstance(batch_augments.augments, list) + self.assertEqual(len(batch_augments.augments), 1) + + # test multiple augmentation + augments = [ + dict(type='Mixup', alpha=1.), + dict(type='CutMix', alpha=0.8), + ] + batch_augments = RandomBatchAugment(augments) + # mixup, cutmix + self.assertEqual(len(batch_augments.augments), 2) + self.assertIsNone(batch_augments.probs) + + # test specify probs + augments = [ + dict(type='Mixup', alpha=1.), + dict(type='CutMix', alpha=0.8), + ] + batch_augments = RandomBatchAugment(augments, probs=[0.5, 0.3]) + # mixup, cutmix and None + self.assertEqual(len(batch_augments.augments), 3) + self.assertAlmostEqual(batch_augments.probs[-1], 0.2) + + # test assertion + with self.assertRaisesRegex(AssertionError, 'Got 2 vs 1'): + RandomBatchAugment(augments, probs=0.5) + + with self.assertRaisesRegex(AssertionError, 'exceeds 1.'): + RandomBatchAugment(augments, probs=[0.5, 0.6]) + + def test_call(self): + inputs = torch.rand(2, 3, 224, 224) + data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)] + + augments = [ + dict(type='Mixup', alpha=1.), + dict(type='CutMix', alpha=0.8), + ] + batch_augments = RandomBatchAugment(augments, probs=[0.5, 0.3]) + + with patch('numpy.random', np.random.RandomState(0)): + batch_augments.augments[1] = MagicMock() + batch_augments(inputs, data_samples) + batch_augments.augments[1].assert_called_once_with( + inputs, data_samples) + + augments = [ + dict(type='Mixup', alpha=1.), + dict(type='CutMix', alpha=0.8), + ] + batch_augments = RandomBatchAugment(augments, probs=[0.0, 0.0]) + mixed_inputs, mixed_samples = batch_augments(inputs, data_samples) + self.assertIs(mixed_inputs, inputs) + self.assertIs(mixed_samples, data_samples) + + +class TestMixup(TestCase): + DEFAULT_ARGS = dict(type='Mixup', alpha=1.) 
+ + def test_initialize(self): + with self.assertRaises(AssertionError): + cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'} + BATCH_AUGMENTS.build(cfg) + + with self.assertRaises(AssertionError): + cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'} + BATCH_AUGMENTS.build(cfg) + + def test_call(self): + inputs = torch.rand(2, 3, 224, 224) + data_samples = [ + ClsDataSample(metainfo={ + 'num_classes': 10 + }).set_gt_label(1) for _ in range(2) + ] + + # test get num_classes from data_samples + mixup = BATCH_AUGMENTS.build(self.DEFAULT_ARGS) + mixed_inputs, mixed_samples = mixup(inputs, data_samples) + self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) + self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, )) + + with self.assertRaisesRegex(RuntimeError, 'Not specify'): + data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)] + mixup(inputs, data_samples) + + # test binary classification + cfg = {**self.DEFAULT_ARGS, 'num_classes': 1} + mixup = BATCH_AUGMENTS.build(cfg) + data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)] + + mixed_inputs, mixed_samples = mixup(inputs, data_samples) + self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) + self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, )) + + # test multi-label classification + cfg = {**self.DEFAULT_ARGS, 'num_classes': 5} + mixup = BATCH_AUGMENTS.build(cfg) + data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)] + + mixed_inputs, mixed_samples = mixup(inputs, data_samples) + self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) + self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, )) + + +class TestCutMix(TestCase): + DEFAULT_ARGS = dict(type='CutMix', alpha=1.) + + def test_initialize(self): + with self.assertRaises(AssertionError): + cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'} + BATCH_AUGMENTS.build(cfg) + + with self.assertRaises(AssertionError): + cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'} + BATCH_AUGMENTS.build(cfg) + + def test_call(self): + inputs = torch.rand(2, 3, 224, 224) + data_samples = [ + ClsDataSample(metainfo={ + 'num_classes': 10 + }).set_gt_label(1) for _ in range(2) + ] + + # test with cutmix_minmax + cfg = {**self.DEFAULT_ARGS, 'cutmix_minmax': (0.1, 0.2)} + cutmix = BATCH_AUGMENTS.build(cfg) + mixed_inputs, mixed_samples = cutmix(inputs, data_samples) + self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) + self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, )) + + # test without correct_lam + cfg = {**self.DEFAULT_ARGS, 'correct_lam': False} + cutmix = BATCH_AUGMENTS.build(cfg) + mixed_inputs, mixed_samples = cutmix(inputs, data_samples) + self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) + self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, )) + + # test get num_classes from data_samples + cutmix = BATCH_AUGMENTS.build(self.DEFAULT_ARGS) + mixed_inputs, mixed_samples = cutmix(inputs, data_samples) + self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) + self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, )) + + with self.assertRaisesRegex(RuntimeError, 'Not specify'): + data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)] + cutmix(inputs, data_samples) + + # test binary classification + cfg = {**self.DEFAULT_ARGS, 'num_classes': 1} + cutmix = BATCH_AUGMENTS.build(cfg) + data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)] + + mixed_inputs, mixed_samples = cutmix(inputs, data_samples) + self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) + 
self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, )) + + # test multi-label classification + cfg = {**self.DEFAULT_ARGS, 'num_classes': 5} + cutmix = BATCH_AUGMENTS.build(cfg) + data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)] + + mixed_inputs, mixed_samples = cutmix(inputs, data_samples) + self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) + self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, )) + + +class TestResizeMix(TestCase): + DEFAULT_ARGS = dict(type='ResizeMix', alpha=1.) + + def test_initialize(self): + with self.assertRaises(AssertionError): + cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'} + BATCH_AUGMENTS.build(cfg) + + with self.assertRaises(AssertionError): + cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'} + BATCH_AUGMENTS.build(cfg) + + def test_call(self): + inputs = torch.rand(2, 3, 224, 224) + data_samples = [ + ClsDataSample(metainfo={ + 'num_classes': 10 + }).set_gt_label(1) for _ in range(2) + ] + + # test get num_classes from data_samples + mixup = BATCH_AUGMENTS.build(self.DEFAULT_ARGS) + mixed_inputs, mixed_samples = mixup(inputs, data_samples) + self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) + self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, )) + + with self.assertRaisesRegex(RuntimeError, 'Not specify'): + data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)] + mixup(inputs, data_samples) + + # test binary classification + cfg = {**self.DEFAULT_ARGS, 'num_classes': 1} + mixup = BATCH_AUGMENTS.build(cfg) + data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)] + + mixed_inputs, mixed_samples = mixup(inputs, data_samples) + self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) + self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, )) + + # test multi-label classification + cfg = {**self.DEFAULT_ARGS, 'num_classes': 5} + mixup = BATCH_AUGMENTS.build(cfg) + data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)] + + mixed_inputs, mixed_samples = mixup(inputs, data_samples) + self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) + self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, )) diff --git a/tests/test_models/test_utils/test_data_preprocessor.py b/tests/test_models/test_utils/test_data_preprocessor.py index bb465e8f..854848ba 100644 --- a/tests/test_models/test_utils/test_data_preprocessor.py +++ b/tests/test_models/test_utils/test_data_preprocessor.py @@ -4,7 +4,7 @@ from unittest import TestCase import torch from mmcls.core import ClsDataSample -from mmcls.models import ClsDataPreprocessor +from mmcls.models import ClsDataPreprocessor, RandomBatchAugment from mmcls.registry import MODELS from mmcls.utils import register_all_modules @@ -62,16 +62,25 @@ class TestClsDataPreprocessor(TestCase): self.assertIsNone(data_samples) def test_batch_augmentation(self): - # TODO: complete this test after refactoring batch augmentation - cfg = dict( type='ClsDataPreprocessor', batch_augments=[ - dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.) 
+ dict(type='Mixup', alpha=0.8, num_classes=10), + dict(type='CutMix', alpha=1., num_classes=10) ]) processor: ClsDataPreprocessor = MODELS.build(cfg) - self.assertIsNotNone(processor.batch_augments) + self.assertIsInstance(processor.batch_augments, RandomBatchAugment) + data = [{ + 'inputs': torch.randint(0, 256, (3, 224, 224)), + 'data_sample': ClsDataSample().set_gt_label(1) + }] + _, data_samples = processor(data, training=True) cfg['batch_augments'] = None processor: ClsDataPreprocessor = MODELS.build(cfg) self.assertIsNone(processor.batch_augments) + data = [{ + 'inputs': torch.randint(0, 256, (3, 224, 224)), + }] + _, data_samples = processor(data, training=True) + self.assertIsNone(data_samples)
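
A minimal usage sketch of the refactored batch augments, based on the configs and tests in this patch; the `probs` argument and the `(inputs, data_samples)` calling convention come from `wrapper.py` and `test_batch_augments.py`, while the concrete numbers below are only illustrative:

import torch

from mmcls.core import ClsDataSample
from mmcls.models import RandomBatchAugment

# Old-style train_cfg (removed by this patch):
#     dict(augments=[
#         dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
#         dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5),
#     ])
# New style: the per-augment `prob` is dropped; selection probabilities (if any)
# live on the RandomBatchAugment wrapper. The updated configs omit `probs`,
# in which case one augment is picked uniformly at random.
augments_cfg = [
    dict(type='Mixup', alpha=0.8, num_classes=10),
    dict(type='CutMix', alpha=1.0, num_classes=10),
]
# 0.5 + 0.3 < 1, so the remaining 0.2 is the probability of applying no augment.
batch_augment = RandomBatchAugment(augments_cfg, probs=[0.5, 0.3])

inputs = torch.rand(16, 3, 32, 32)
data_samples = [ClsDataSample().set_gt_label(i % 10) for i in range(16)]

# The call signature is now (inputs, data_samples) rather than (imgs, gt_label);
# when an augment fires, the mixed soft labels land in each sample's gt_label.score.
inputs, data_samples = batch_augment(inputs, data_samples)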