From 29f066f7fba8a87ba982bbed57b0af253bb3d9e1 Mon Sep 17 00:00:00 2001 From: Ma Zerun Date: Mon, 17 Oct 2022 17:08:18 +0800 Subject: [PATCH] [Improve] Speed up data preprocessor. (#1064) * [Improve] Speed up data preprocessor. * Add ClsDataSample serialization override functions. * Add unit tests * Modify configs to fit new mixup args. * Fix `num_classes` of the ImageNet-21k config. * Update docs. --- configs/_base_/datasets/cifar100_bs16.py | 1 + configs/_base_/datasets/cifar10_bs16.py | 1 + configs/_base_/datasets/cub_bs8_384.py | 1 + configs/_base_/datasets/cub_bs8_448.py | 1 + configs/_base_/datasets/imagenet21k_bs128.py | 1 + .../_base_/datasets/imagenet_bs128_mbv3.py | 1 + .../imagenet_bs128_poolformer_medium_224.py | 1 + .../imagenet_bs128_poolformer_small_224.py | 1 + .../_base_/datasets/imagenet_bs256_rsb_a12.py | 1 + .../_base_/datasets/imagenet_bs256_rsb_a3.py | 1 + configs/_base_/datasets/imagenet_bs32.py | 1 + .../datasets/imagenet_bs32_pil_bicubic.py | 1 + .../datasets/imagenet_bs32_pil_resize.py | 1 + configs/_base_/datasets/imagenet_bs64.py | 1 + .../_base_/datasets/imagenet_bs64_autoaug.py | 1 + .../datasets/imagenet_bs64_convmixer_224.py | 1 + .../datasets/imagenet_bs64_mixer_224.py | 1 + .../datasets/imagenet_bs64_pil_resize.py | 1 + .../imagenet_bs64_pil_resize_autoaug.py | 1 + .../_base_/datasets/imagenet_bs64_swin_224.py | 1 + .../_base_/datasets/imagenet_bs64_swin_384.py | 1 + .../_base_/datasets/imagenet_bs64_t2t_224.py | 1 + configs/_base_/datasets/voc_bs16.py | 3 + configs/_base_/models/conformer/base-p16.py | 4 +- configs/_base_/models/conformer/small-p16.py | 4 +- configs/_base_/models/conformer/small-p32.py | 4 +- configs/_base_/models/conformer/tiny-p16.py | 4 +- .../_base_/models/repvgg-B3_lbs-mixup_in1k.py | 2 +- configs/_base_/models/resnest101.py | 5 +- configs/_base_/models/resnest200.py | 5 +- configs/_base_/models/resnest269.py | 5 +- configs/_base_/models/resnest50.py | 5 +- configs/_base_/models/resnet50_cifar_mixup.py | 2 +- 
configs/_base_/models/resnet50_mixup.py | 2 +- .../models/swin_transformer/base_224.py | 4 +- .../models/swin_transformer/small_224.py | 4 +- .../models/swin_transformer/tiny_224.py | 4 +- configs/_base_/models/t2t-vit-t-14.py | 4 +- configs/_base_/models/t2t-vit-t-19.py | 4 +- configs/_base_/models/t2t-vit-t-24.py | 4 +- configs/_base_/models/twins_pcpvt_base.py | 4 +- configs/_base_/models/twins_svt_base.py | 4 +- configs/_base_/models/van/van_small.py | 4 +- configs/_base_/models/van/van_tiny.py | 4 +- configs/deit/deit-small_pt-4xb256_in1k.py | 4 +- configs/lenet/lenet5_mnist.py | 2 +- configs/resnet/metafile.yml | 38 ----- .../resnet50_8xb256-rsb-a1-600e_in1k.py | 4 +- .../resnet50_8xb256-rsb-a2-300e_in1k.py | 4 +- .../resnet50_8xb256-rsb-a3-100e_in1k.py | 4 +- .../vit-base-p16_pt-32xb128-mae_in1k-224.py | 4 +- .../vit-base-p16_pt-64xb64_in1k-224.py | 2 +- .../vit-base-p32_pt-64xb64_in1k-224.py | 2 +- .../vit-large-p16_pt-64xb64_in1k-224.py | 2 +- .../vit-large-p32_pt-64xb64_in1k-224.py | 2 +- docs/en/api/data_process.rst | 8 +- docs/en/user_guides/config.md | 2 +- docs/zh_CN/user_guides/config.md | 6 +- mmcls/evaluation/metrics/single_label.py | 15 +- mmcls/models/utils/batch_augments/cutmix.py | 6 +- mmcls/models/utils/batch_augments/mixup.py | 40 +---- .../models/utils/batch_augments/resizemix.py | 9 +- mmcls/models/utils/batch_augments/wrapper.py | 17 +- mmcls/models/utils/data_preprocessor.py | 57 +++++-- mmcls/structures/__init__.py | 7 +- mmcls/structures/cls_data_sample.py | 147 +++++++++++------- mmcls/structures/utils.py | 96 ++++++++++++ .../test_metrics/test_single_label.py | 15 +- tests/test_models/test_classifiers.py | 2 +- .../test_utils/test_batch_augments.py | 129 ++++----------- .../test_utils/test_data_preprocessor.py | 9 +- tests/test_structures/test_datasample.py | 50 +----- tests/test_structures/test_utils.py | 93 +++++++++++ 73 files changed, 505 insertions(+), 378 deletions(-) create mode 100644 mmcls/structures/utils.py create mode 
100644 tests/test_structures/test_utils.py diff --git a/configs/_base_/datasets/cifar100_bs16.py b/configs/_base_/datasets/cifar100_bs16.py index c6b3dfeb..78a74fa6 100644 --- a/configs/_base_/datasets/cifar100_bs16.py +++ b/configs/_base_/datasets/cifar100_bs16.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'CIFAR100' data_preprocessor = dict( + num_classes=100, # RGB format normalization parameters mean=[129.304, 124.070, 112.434], std=[68.170, 65.392, 70.418], diff --git a/configs/_base_/datasets/cifar10_bs16.py b/configs/_base_/datasets/cifar10_bs16.py index 198f0a9e..f29cfcd9 100644 --- a/configs/_base_/datasets/cifar10_bs16.py +++ b/configs/_base_/datasets/cifar10_bs16.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'CIFAR10' data_preprocessor = dict( + num_classes=10, # RGB format normalization parameters mean=[125.307, 122.961, 113.8575], std=[51.5865, 50.847, 51.255], diff --git a/configs/_base_/datasets/cub_bs8_384.py b/configs/_base_/datasets/cub_bs8_384.py index 1feb9437..17139dcb 100644 --- a/configs/_base_/datasets/cub_bs8_384.py +++ b/configs/_base_/datasets/cub_bs8_384.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'CUB' data_preprocessor = dict( + num_classes=200, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/cub_bs8_448.py b/configs/_base_/datasets/cub_bs8_448.py index 094cf618..0b07a1a0 100644 --- a/configs/_base_/datasets/cub_bs8_448.py +++ b/configs/_base_/datasets/cub_bs8_448.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'CUB' data_preprocessor = dict( + num_classes=200, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], # convert image from BGR to RGB diff --git a/configs/_base_/datasets/imagenet21k_bs128.py b/configs/_base_/datasets/imagenet21k_bs128.py index 704832c8..0f24b8a0 100644 --- a/configs/_base_/datasets/imagenet21k_bs128.py +++ b/configs/_base_/datasets/imagenet21k_bs128.py @@ -1,6 +1,7 @@ # dataset settings 
dataset_type = 'ImageNet21k' data_preprocessor = dict( + num_classes=21842, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/imagenet_bs128_mbv3.py b/configs/_base_/datasets/imagenet_bs128_mbv3.py index 42d723fe..d64f258b 100644 --- a/configs/_base_/datasets/imagenet_bs128_mbv3.py +++ b/configs/_base_/datasets/imagenet_bs128_mbv3.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'ImageNet' data_preprocessor = dict( + num_classes=1000, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/imagenet_bs128_poolformer_medium_224.py b/configs/_base_/datasets/imagenet_bs128_poolformer_medium_224.py index ae9a36f8..1f03d96d 100644 --- a/configs/_base_/datasets/imagenet_bs128_poolformer_medium_224.py +++ b/configs/_base_/datasets/imagenet_bs128_poolformer_medium_224.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'ImageNet' data_preprocessor = dict( + num_classes=1000, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/imagenet_bs128_poolformer_small_224.py b/configs/_base_/datasets/imagenet_bs128_poolformer_small_224.py index 4add920f..d8785707 100644 --- a/configs/_base_/datasets/imagenet_bs128_poolformer_small_224.py +++ b/configs/_base_/datasets/imagenet_bs128_poolformer_small_224.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'ImageNet' data_preprocessor = dict( + num_classes=1000, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/imagenet_bs256_rsb_a12.py b/configs/_base_/datasets/imagenet_bs256_rsb_a12.py index 18a15957..3038d46a 100644 --- a/configs/_base_/datasets/imagenet_bs256_rsb_a12.py +++ b/configs/_base_/datasets/imagenet_bs256_rsb_a12.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'ImageNet' 
data_preprocessor = dict( + num_classes=1000, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/imagenet_bs256_rsb_a3.py b/configs/_base_/datasets/imagenet_bs256_rsb_a3.py index b24bd3ea..53a17c20 100644 --- a/configs/_base_/datasets/imagenet_bs256_rsb_a3.py +++ b/configs/_base_/datasets/imagenet_bs256_rsb_a3.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'ImageNet' data_preprocessor = dict( + num_classes=1000, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/imagenet_bs32.py b/configs/_base_/datasets/imagenet_bs32.py index 3345e195..5bfa94aa 100644 --- a/configs/_base_/datasets/imagenet_bs32.py +++ b/configs/_base_/datasets/imagenet_bs32.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'ImageNet' data_preprocessor = dict( + num_classes=1000, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/imagenet_bs32_pil_bicubic.py b/configs/_base_/datasets/imagenet_bs32_pil_bicubic.py index d0190463..aa34c574 100644 --- a/configs/_base_/datasets/imagenet_bs32_pil_bicubic.py +++ b/configs/_base_/datasets/imagenet_bs32_pil_bicubic.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'ImageNet' data_preprocessor = dict( + num_classes=1000, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/imagenet_bs32_pil_resize.py b/configs/_base_/datasets/imagenet_bs32_pil_resize.py index fb7c8837..48234eb1 100644 --- a/configs/_base_/datasets/imagenet_bs32_pil_resize.py +++ b/configs/_base_/datasets/imagenet_bs32_pil_resize.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'ImageNet' data_preprocessor = dict( + num_classes=1000, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git 
a/configs/_base_/datasets/imagenet_bs64.py b/configs/_base_/datasets/imagenet_bs64.py index 44c237dc..ea2db282 100644 --- a/configs/_base_/datasets/imagenet_bs64.py +++ b/configs/_base_/datasets/imagenet_bs64.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'ImageNet' data_preprocessor = dict( + num_classes=1000, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/imagenet_bs64_autoaug.py b/configs/_base_/datasets/imagenet_bs64_autoaug.py index ea0c08a9..2d4c4469 100644 --- a/configs/_base_/datasets/imagenet_bs64_autoaug.py +++ b/configs/_base_/datasets/imagenet_bs64_autoaug.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'ImageNet' data_preprocessor = dict( + num_classes=1000, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/imagenet_bs64_convmixer_224.py b/configs/_base_/datasets/imagenet_bs64_convmixer_224.py index a2c924f8..14932cfb 100644 --- a/configs/_base_/datasets/imagenet_bs64_convmixer_224.py +++ b/configs/_base_/datasets/imagenet_bs64_convmixer_224.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'ImageNet' data_preprocessor = dict( + num_classes=1000, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/imagenet_bs64_mixer_224.py b/configs/_base_/datasets/imagenet_bs64_mixer_224.py index fe3ed25c..9a4a6d44 100644 --- a/configs/_base_/datasets/imagenet_bs64_mixer_224.py +++ b/configs/_base_/datasets/imagenet_bs64_mixer_224.py @@ -3,6 +3,7 @@ dataset_type = 'ImageNet' # Google research usually use the below normalization setting. 
data_preprocessor = dict( + num_classes=1000, mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], # convert image from BGR to RGB diff --git a/configs/_base_/datasets/imagenet_bs64_pil_resize.py b/configs/_base_/datasets/imagenet_bs64_pil_resize.py index d4a08f7b..022dda52 100644 --- a/configs/_base_/datasets/imagenet_bs64_pil_resize.py +++ b/configs/_base_/datasets/imagenet_bs64_pil_resize.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'ImageNet' data_preprocessor = dict( + num_classes=1000, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py b/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py index bb4d6eec..fd2709f4 100644 --- a/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py +++ b/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'ImageNet' data_preprocessor = dict( + num_classes=1000, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/imagenet_bs64_swin_224.py b/configs/_base_/datasets/imagenet_bs64_swin_224.py index 30135214..1a54932a 100644 --- a/configs/_base_/datasets/imagenet_bs64_swin_224.py +++ b/configs/_base_/datasets/imagenet_bs64_swin_224.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'ImageNet' data_preprocessor = dict( + num_classes=1000, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/imagenet_bs64_swin_384.py b/configs/_base_/datasets/imagenet_bs64_swin_384.py index ea351b4e..1e64f6aa 100644 --- a/configs/_base_/datasets/imagenet_bs64_swin_384.py +++ b/configs/_base_/datasets/imagenet_bs64_swin_384.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'ImageNet' data_preprocessor = dict( + num_classes=1000, # RGB format normalization parameters mean=[123.675, 
116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/imagenet_bs64_t2t_224.py b/configs/_base_/datasets/imagenet_bs64_t2t_224.py index 7cf4aca5..249806ab 100644 --- a/configs/_base_/datasets/imagenet_bs64_t2t_224.py +++ b/configs/_base_/datasets/imagenet_bs64_t2t_224.py @@ -1,6 +1,7 @@ # dataset settings dataset_type = 'ImageNet' data_preprocessor = dict( + num_classes=1000, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], diff --git a/configs/_base_/datasets/voc_bs16.py b/configs/_base_/datasets/voc_bs16.py index 1e68b121..c5540237 100644 --- a/configs/_base_/datasets/voc_bs16.py +++ b/configs/_base_/datasets/voc_bs16.py @@ -1,11 +1,14 @@ # dataset settings dataset_type = 'VOC' data_preprocessor = dict( + num_classes=20, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], # convert image from BGR to RGB to_rgb=True, + # generate onehot-format labels for multi-label classification. + to_onehot=True, ) train_pipeline = [ diff --git a/configs/_base_/models/conformer/base-p16.py b/configs/_base_/models/conformer/base-p16.py index f840914a..959da505 100644 --- a/configs/_base_/models/conformer/base-p16.py +++ b/configs/_base_/models/conformer/base-p16.py @@ -17,7 +17,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) ], train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=1000), - dict(type='CutMix', alpha=1.0, num_classes=1000) + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) ]), ) diff --git a/configs/_base_/models/conformer/small-p16.py b/configs/_base_/models/conformer/small-p16.py index a913abcb..2e4f9f80 100644 --- a/configs/_base_/models/conformer/small-p16.py +++ b/configs/_base_/models/conformer/small-p16.py @@ -17,7 +17,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) 
], train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=1000), - dict(type='CutMix', alpha=1.0, num_classes=1000) + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) ]), ) diff --git a/configs/_base_/models/conformer/small-p32.py b/configs/_base_/models/conformer/small-p32.py index bd9c00b4..f73811ff 100644 --- a/configs/_base_/models/conformer/small-p32.py +++ b/configs/_base_/models/conformer/small-p32.py @@ -21,7 +21,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) ], train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=1000), - dict(type='CutMix', alpha=1.0, num_classes=1000) + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) ]), ) diff --git a/configs/_base_/models/conformer/tiny-p16.py b/configs/_base_/models/conformer/tiny-p16.py index 1edb3388..fa9753b6 100644 --- a/configs/_base_/models/conformer/tiny-p16.py +++ b/configs/_base_/models/conformer/tiny-p16.py @@ -17,7 +17,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) 
], train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=1000), - dict(type='CutMix', alpha=1.0, num_classes=1000) + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) ]), ) diff --git a/configs/_base_/models/repvgg-B3_lbs-mixup_in1k.py b/configs/_base_/models/repvgg-B3_lbs-mixup_in1k.py index 05897733..d88e687b 100644 --- a/configs/_base_/models/repvgg-B3_lbs-mixup_in1k.py +++ b/configs/_base_/models/repvgg-B3_lbs-mixup_in1k.py @@ -18,5 +18,5 @@ model = dict( num_classes=1000), topk=(1, 5), ), - train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)), + train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)), ) diff --git a/configs/_base_/models/resnest101.py b/configs/_base_/models/resnest101.py index 97f7749c..3780c154 100644 --- a/configs/_base_/models/resnest101.py +++ b/configs/_base_/models/resnest101.py @@ -20,5 +20,6 @@ model = dict( reduction='mean', loss_weight=1.0), topk=(1, 5), - cal_acc=False)) -train_cfg = dict(mixup=dict(alpha=0.2, num_classes=1000)) + cal_acc=False), + train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)), +) diff --git a/configs/_base_/models/resnest200.py b/configs/_base_/models/resnest200.py index 46100178..40d8f03e 100644 --- a/configs/_base_/models/resnest200.py +++ b/configs/_base_/models/resnest200.py @@ -20,5 +20,6 @@ model = dict( reduction='mean', loss_weight=1.0), topk=(1, 5), - cal_acc=False)) -train_cfg = dict(mixup=dict(alpha=0.2, num_classes=1000)) + cal_acc=False), + train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)), +) diff --git a/configs/_base_/models/resnest269.py b/configs/_base_/models/resnest269.py index ad365d03..c37626f5 100644 --- a/configs/_base_/models/resnest269.py +++ b/configs/_base_/models/resnest269.py @@ -20,5 +20,6 @@ model = dict( reduction='mean', loss_weight=1.0), topk=(1, 5), - cal_acc=False)) -train_cfg = dict(mixup=dict(alpha=0.2, num_classes=1000)) + cal_acc=False), + train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)), +) diff --git 
a/configs/_base_/models/resnest50.py b/configs/_base_/models/resnest50.py index 15269d4a..51c90e86 100644 --- a/configs/_base_/models/resnest50.py +++ b/configs/_base_/models/resnest50.py @@ -19,5 +19,6 @@ model = dict( reduction='mean', loss_weight=1.0), topk=(1, 5), - cal_acc=False)) -train_cfg = dict(mixup=dict(alpha=0.2, num_classes=1000)) + cal_acc=False), + train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)), +) diff --git a/configs/_base_/models/resnet50_cifar_mixup.py b/configs/_base_/models/resnet50_cifar_mixup.py index 7805cbd7..f165c246 100644 --- a/configs/_base_/models/resnet50_cifar_mixup.py +++ b/configs/_base_/models/resnet50_cifar_mixup.py @@ -13,5 +13,5 @@ model = dict( num_classes=10, in_channels=2048, loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)), - train_cfg=dict(augments=dict(type='Mixup', alpha=1., num_classes=10)), + train_cfg=dict(augments=dict(type='Mixup', alpha=1.)), ) diff --git a/configs/_base_/models/resnet50_mixup.py b/configs/_base_/models/resnet50_mixup.py index 8a783a1e..23130a69 100644 --- a/configs/_base_/models/resnet50_mixup.py +++ b/configs/_base_/models/resnet50_mixup.py @@ -13,5 +13,5 @@ model = dict( num_classes=1000, in_channels=2048, loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)), - train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)), + train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)), ) diff --git a/configs/_base_/models/swin_transformer/base_224.py b/configs/_base_/models/swin_transformer/base_224.py index 28d5b529..b7c277f2 100644 --- a/configs/_base_/models/swin_transformer/base_224.py +++ b/configs/_base_/models/swin_transformer/base_224.py @@ -17,7 +17,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) 
], train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=1000), - dict(type='CutMix', alpha=1.0, num_classes=1000) + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) ]), ) diff --git a/configs/_base_/models/swin_transformer/small_224.py b/configs/_base_/models/swin_transformer/small_224.py index ea4e070e..d87d9d9a 100644 --- a/configs/_base_/models/swin_transformer/small_224.py +++ b/configs/_base_/models/swin_transformer/small_224.py @@ -18,7 +18,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) ], train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=1000), - dict(type='CutMix', alpha=1.0, num_classes=1000) + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) ]), ) diff --git a/configs/_base_/models/swin_transformer/tiny_224.py b/configs/_base_/models/swin_transformer/tiny_224.py index feddc581..f1781cf5 100644 --- a/configs/_base_/models/swin_transformer/tiny_224.py +++ b/configs/_base_/models/swin_transformer/tiny_224.py @@ -17,7 +17,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) 
], train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=1000), - dict(type='CutMix', alpha=1.0, num_classes=1000) + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) ]), ) diff --git a/configs/_base_/models/t2t-vit-t-14.py b/configs/_base_/models/t2t-vit-t-14.py index 3f4c8603..58ea660e 100644 --- a/configs/_base_/models/t2t-vit-t-14.py +++ b/configs/_base_/models/t2t-vit-t-14.py @@ -36,7 +36,7 @@ model = dict( topk=(1, 5), init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)), train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=num_classes), - dict(type='CutMix', alpha=1.0, num_classes=num_classes), + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0), ]), ) diff --git a/configs/_base_/models/t2t-vit-t-19.py b/configs/_base_/models/t2t-vit-t-19.py index 65e4dc99..51741c7a 100644 --- a/configs/_base_/models/t2t-vit-t-19.py +++ b/configs/_base_/models/t2t-vit-t-19.py @@ -36,7 +36,7 @@ model = dict( topk=(1, 5), init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)), train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=num_classes), - dict(type='CutMix', alpha=1.0, num_classes=num_classes), + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0), ]), ) diff --git a/configs/_base_/models/t2t-vit-t-24.py b/configs/_base_/models/t2t-vit-t-24.py index f2e60185..ad772cf6 100644 --- a/configs/_base_/models/t2t-vit-t-24.py +++ b/configs/_base_/models/t2t-vit-t-24.py @@ -36,7 +36,7 @@ model = dict( topk=(1, 5), init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)), train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=num_classes), - dict(type='CutMix', alpha=1.0, num_classes=num_classes), + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0), ]), ) diff --git a/configs/_base_/models/twins_pcpvt_base.py b/configs/_base_/models/twins_pcpvt_base.py index 32c8e4a3..14e46bae 100644 --- a/configs/_base_/models/twins_pcpvt_base.py +++ 
b/configs/_base_/models/twins_pcpvt_base.py @@ -25,7 +25,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) ], train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=1000), - dict(type='CutMix', alpha=1.0, num_classes=1000) + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) ]), ) diff --git a/configs/_base_/models/twins_svt_base.py b/configs/_base_/models/twins_svt_base.py index 974e76b4..a37385b0 100644 --- a/configs/_base_/models/twins_svt_base.py +++ b/configs/_base_/models/twins_svt_base.py @@ -25,7 +25,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) ], train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=1000), - dict(type='CutMix', alpha=1.0, num_classes=1000) + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) ]), ) diff --git a/configs/_base_/models/van/van_small.py b/configs/_base_/models/van/van_small.py index 68228c5c..29393c63 100644 --- a/configs/_base_/models/van/van_small.py +++ b/configs/_base_/models/van/van_small.py @@ -16,7 +16,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) ], train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=1000), - dict(type='CutMix', alpha=1.0, num_classes=1000) + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) ]), ) diff --git a/configs/_base_/models/van/van_tiny.py b/configs/_base_/models/van/van_tiny.py index c765dbc9..9cf5b288 100644 --- a/configs/_base_/models/van/van_tiny.py +++ b/configs/_base_/models/van/van_tiny.py @@ -16,7 +16,7 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.) 
], train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=1000), - dict(type='CutMix', alpha=1.0, num_classes=1000) + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) ]), ) diff --git a/configs/deit/deit-small_pt-4xb256_in1k.py b/configs/deit/deit-small_pt-4xb256_in1k.py index e28d12f3..b96d84ec 100644 --- a/configs/deit/deit-small_pt-4xb256_in1k.py +++ b/configs/deit/deit-small_pt-4xb256_in1k.py @@ -27,8 +27,8 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.), ], train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=1000), - dict(type='CutMix', alpha=1.0, num_classes=1000) + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) ]), ) diff --git a/configs/lenet/lenet5_mnist.py b/configs/lenet/lenet5_mnist.py index e5cfd80d..b74fc0ec 100644 --- a/configs/lenet/lenet5_mnist.py +++ b/configs/lenet/lenet5_mnist.py @@ -10,7 +10,7 @@ model = dict( # dataset settings dataset_type = 'MNIST' -data_preprocessor = dict(mean=[33.46], std=[78.87]) +data_preprocessor = dict(mean=[33.46], std=[78.87], num_classes=10) pipeline = [dict(type='Resize', scale=32), dict(type='PackClsInputs')] diff --git a/configs/resnet/metafile.yml b/configs/resnet/metafile.yml index b29c2205..29aa84df 100644 --- a/configs/resnet/metafile.yml +++ b/configs/resnet/metafile.yml @@ -298,44 +298,6 @@ Models: Top 5 Accuracy: 93.80 Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a3-100e_in1k_20211228-3493673c.pth Config: configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py - - Name: wide-resnet50_3rdparty_8xb32_in1k - Metadata: - FLOPs: 11440000000 # 11.44G - Parameters: 68880000 # 68.88M - Training Techniques: - - SGD with Momentum - - Weight Decay - In Collection: ResNet - Results: - - Task: Image Classification - Dataset: ImageNet-1k - Metrics: - Top 1 Accuracy: 78.48 - Top 5 Accuracy: 94.08 - Weights: 
https://download.openmmlab.com/mmclassification/v0/resnet/wide-resnet50_3rdparty_8xb32_in1k_20220304-66678344.pth - Config: configs/resnet/wide-resnet50_8xb32_in1k.py - Converted From: - Weights: https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth - Code: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py - - Name: wide-resnet101_3rdparty_8xb32_in1k - Metadata: - FLOPs: 22810000000 # 22.81G - Parameters: 126890000 # 126.89M - Training Techniques: - - SGD with Momentum - - Weight Decay - In Collection: ResNet - Results: - - Task: Image Classification - Dataset: ImageNet-1k - Metrics: - Top 1 Accuracy: 78.84 - Top 5 Accuracy: 94.28 - Weights: https://download.openmmlab.com/mmclassification/v0/resnet/wide-resnet101_3rdparty_8xb32_in1k_20220304-8d5f9d61.pth - Config: configs/resnet/wide-resnet101_8xb32_in1k.py - Converted From: - Weights: https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth - Code: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py - Name: resnetv1c50_8xb32_in1k Metadata: FLOPs: 4360000000 diff --git a/configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py b/configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py index 1c213127..a4ea1598 100644 --- a/configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py +++ b/configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py @@ -19,8 +19,8 @@ model = dict( use_sigmoid=True, )), train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.2, num_classes=1000), - dict(type='CutMix', alpha=1.0, num_classes=1000) + dict(type='Mixup', alpha=0.2), + dict(type='CutMix', alpha=1.0) ]), ) diff --git a/configs/resnet/resnet50_8xb256-rsb-a2-300e_in1k.py b/configs/resnet/resnet50_8xb256-rsb-a2-300e_in1k.py index a8e93003..df8edc03 100644 --- a/configs/resnet/resnet50_8xb256-rsb-a2-300e_in1k.py +++ b/configs/resnet/resnet50_8xb256-rsb-a2-300e_in1k.py @@ -13,8 +13,8 @@ model = dict( ), head=dict(loss=dict(use_sigmoid=True)), train_cfg=dict(augments=[ - dict(type='Mixup', 
alpha=0.1, num_classes=1000), - dict(type='CutMix', alpha=1.0, num_classes=1000) + dict(type='Mixup', alpha=0.1), + dict(type='CutMix', alpha=1.0) ])) # dataset settings diff --git a/configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py b/configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py index e6872a3b..3a36c584 100644 --- a/configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py +++ b/configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py @@ -10,8 +10,8 @@ model = dict( backbone=dict(norm_cfg=dict(type='SyncBN', requires_grad=True)), head=dict(loss=dict(use_sigmoid=True)), train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.1, num_classes=1000), - dict(type='CutMix', alpha=1.0, num_classes=1000) + dict(type='Mixup', alpha=0.1), + dict(type='CutMix', alpha=1.0) ]), ) diff --git a/configs/vision_transformer/vit-base-p16_pt-32xb128-mae_in1k-224.py b/configs/vision_transformer/vit-base-p16_pt-32xb128-mae_in1k-224.py index 44319f8c..a46bbb21 100644 --- a/configs/vision_transformer/vit-base-p16_pt-32xb128-mae_in1k-224.py +++ b/configs/vision_transformer/vit-base-p16_pt-32xb128-mae_in1k-224.py @@ -26,8 +26,8 @@ model = dict( dict(type='Constant', layer='LayerNorm', val=1., bias=0.), ], train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=1000), - dict(type='CutMix', alpha=1.0, num_classes=1000) + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) ])) # dataset settings diff --git a/configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py b/configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py index 0a9e5156..07be0e9a 100644 --- a/configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py +++ b/configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py @@ -8,7 +8,7 @@ _base_ = [ # model setting model = dict( head=dict(hidden_dim=3072), - train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)), + train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)), ) # schedule setting diff --git 
a/configs/vision_transformer/vit-base-p32_pt-64xb64_in1k-224.py b/configs/vision_transformer/vit-base-p32_pt-64xb64_in1k-224.py index 83a92fca..9cfc7c47 100644 --- a/configs/vision_transformer/vit-base-p32_pt-64xb64_in1k-224.py +++ b/configs/vision_transformer/vit-base-p32_pt-64xb64_in1k-224.py @@ -8,7 +8,7 @@ _base_ = [ # model setting model = dict( head=dict(hidden_dim=3072), - train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)), + train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)), ) # schedule setting diff --git a/configs/vision_transformer/vit-large-p16_pt-64xb64_in1k-224.py b/configs/vision_transformer/vit-large-p16_pt-64xb64_in1k-224.py index 0cf9d8e1..0d9bd283 100644 --- a/configs/vision_transformer/vit-large-p16_pt-64xb64_in1k-224.py +++ b/configs/vision_transformer/vit-large-p16_pt-64xb64_in1k-224.py @@ -8,7 +8,7 @@ _base_ = [ # model setting model = dict( head=dict(hidden_dim=3072), - train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)), + train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)), ) # schedule setting diff --git a/configs/vision_transformer/vit-large-p32_pt-64xb64_in1k-224.py b/configs/vision_transformer/vit-large-p32_pt-64xb64_in1k-224.py index c1b5a3d8..61e17916 100644 --- a/configs/vision_transformer/vit-large-p32_pt-64xb64_in1k-224.py +++ b/configs/vision_transformer/vit-large-p32_pt-64xb64_in1k-224.py @@ -8,7 +8,7 @@ _base_ = [ # model setting model = dict( head=dict(hidden_dim=3072), - train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)), + train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)), ) # schedule setting diff --git a/docs/en/api/data_process.rst b/docs/en/api/data_process.rst index 0b1bb083..8a4cc193 100644 --- a/docs/en/api/data_process.rst +++ b/docs/en/api/data_process.rst @@ -233,8 +233,8 @@ These augmentations are usually only used during training, therefore, we use the neck=..., head=..., train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, 
num_classes=num_classes), - dict(type='CutMix', alpha=1.0, num_classes=num_classes), + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0), ]), ) @@ -247,8 +247,8 @@ You can also speicy the probabilities of every batch augmentation by the ``probs neck=..., head=..., train_cfg=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=num_classes), - dict(type='CutMix', alpha=1.0, num_classes=num_classes), + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0), ], probs=[0.3, 0.7]) ) diff --git a/docs/en/user_guides/config.md b/docs/en/user_guides/config.md index c4fe19fb..39e95c7c 100644 --- a/docs/en/user_guides/config.md +++ b/docs/en/user_guides/config.md @@ -289,7 +289,7 @@ _base_ = './resnet50_8xb32_in1k.py' # using CutMix batch augment model = dict( train_cfg=dict( - augments=dict(type='CutMix', alpha=1.0, num_classes=1000, prob=1.0) + augments=dict(type='CutMix', alpha=1.0) ) ) diff --git a/docs/zh_CN/user_guides/config.md b/docs/zh_CN/user_guides/config.md index 04917d67..1943529b 100644 --- a/docs/zh_CN/user_guides/config.md +++ b/docs/zh_CN/user_guides/config.md @@ -71,7 +71,7 @@ model = dict( backbone=dict( type='ResNet', # 主干网络类型 # 除了 `type` 之外的所有字段都来自 `ResNet` 类的 __init__ 方法 - # 您查阅 https://mmclassification.readthedocs.io/zh_CN/1.x/api/generated/mmcls.models.ResNet.html + # 可查阅 https://mmclassification.readthedocs.io/zh_CN/1.x/api/generated/mmcls.models.ResNet.html depth=50, num_stages=4, # 主干网络状态(stages)的数目,这些状态产生的特征图作为后续的 head 的输入。 out_indices=(3, ), # 输出的特征图输出索引。 @@ -81,7 +81,7 @@ model = dict( head=dict( type='LinearClsHead', # 分类颈网络类型 # 除了 `type` 之外的所有字段都来自 `LinearClsHead` 类的 __init__ 方法 - # 您查阅 https://mmclassification.readthedocs.io/zh_CN/1.x/api/generated/mmcls.models.LinearClsHead.html + # 可查阅 https://mmclassification.readthedocs.io/zh_CN/1.x/api/generated/mmcls.models.LinearClsHead.html num_classes=1000, in_channels=2048, loss=dict(type='CrossEntropyLoss', loss_weight=1.0), # 损失函数配置信息 @@ -278,7 +278,7 @@ _base_ = 
'./resnet50_8xb32_in1k.py' # 模型在之前的基础上使用 CutMix 训练增强 model = dict( train_cfg=dict( - augments=dict(type='CutMix', alpha=1.0, num_classes=1000, prob=1.0) + augments=dict(type='CutMix', alpha=1.0) ) ) diff --git a/mmcls/evaluation/metrics/single_label.py b/mmcls/evaluation/metrics/single_label.py index 78120ad7..b18b649f 100644 --- a/mmcls/evaluation/metrics/single_label.py +++ b/mmcls/evaluation/metrics/single_label.py @@ -296,6 +296,7 @@ class SingleLabelMetric(BaseMetric): true positives, false negatives and false positives. Defaults to "macro". + num_classes (Optional, int): The number of classes. Defaults to None. collect_device (str): Device name used for collecting results from different ranks during distributed training. Must be 'cpu' or 'gpu'. Defaults to 'cpu'. @@ -349,6 +350,7 @@ class SingleLabelMetric(BaseMetric): thrs: Union[float, Sequence[Union[float, None]], None] = 0., items: Sequence[str] = ('precision', 'recall', 'f1-score'), average: Optional[str] = 'macro', + num_classes: Optional[int] = None, collect_device: str = 'cpu', prefix: Optional[str] = None) -> None: super().__init__(collect_device=collect_device, prefix=prefix) @@ -365,6 +367,7 @@ class SingleLabelMetric(BaseMetric): '"support".' self.items = tuple(items) self.average = average + self.num_classes = num_classes def process(self, data_batch, data_samples: Sequence[dict]): """Process one batch of data samples. 
@@ -383,12 +386,14 @@ class SingleLabelMetric(BaseMetric): gt_label = data_sample['gt_label'] if 'score' in pred_label: result['pred_score'] = pred_label['score'].cpu() - elif ('num_classes' in pred_label): - result['pred_label'] = pred_label['label'].cpu() - result['num_classes'] = pred_label['num_classes'] else: - raise ValueError('The `pred_label` in data_samples do not ' - 'have neither `score` nor `num_classes`.') + num_classes = self.num_classes or data_sample.get( + 'num_classes') + assert num_classes is not None, \ + 'The `num_classes` must be specified if `pred_label` has '\ + 'only `label`.' + result['pred_label'] = pred_label['label'].cpu() + result['num_classes'] = num_classes result['gt_label'] = gt_label['label'].cpu() # Save the result to `self.results`. self.results.append(result) diff --git a/mmcls/models/utils/batch_augments/cutmix.py b/mmcls/models/utils/batch_augments/cutmix.py index 500ac738..5d0920e7 100644 --- a/mmcls/models/utils/batch_augments/cutmix.py +++ b/mmcls/models/utils/batch_augments/cutmix.py @@ -24,9 +24,6 @@ class CutMix(Mixup): alpha (float): Parameters for Beta distribution to generate the mixing ratio. It should be a positive number. More details can be found in :class:`Mixup`. - num_classes (int, optional): The number of classes. If not specified, - will try to get it from data samples during training. - Defaults to None. cutmix_minmax (List[float], optional): The min/max area ratio of the patches. If not None, the bounding-box of patches is uniform sampled within this ratio range, and the ``alpha`` will be ignored. 
@@ -49,10 +46,9 @@ class CutMix(Mixup): def __init__(self, alpha: float, - num_classes: Optional[int] = None, cutmix_minmax: Optional[List[float]] = None, correct_lam: bool = True): - super().__init__(alpha=alpha, num_classes=num_classes) + super().__init__(alpha=alpha) self.cutmix_minmax = cutmix_minmax self.correct_lam = correct_lam diff --git a/mmcls/models/utils/batch_augments/mixup.py b/mmcls/models/utils/batch_augments/mixup.py index 891fc8af..bbf249e4 100644 --- a/mmcls/models/utils/batch_augments/mixup.py +++ b/mmcls/models/utils/batch_augments/mixup.py @@ -1,12 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional, Tuple +from typing import Tuple import numpy as np import torch -from mmengine.structures import LabelData from mmcls.registry import BATCH_AUGMENTS -from mmcls.structures import ClsDataSample @BATCH_AUGMENTS.register_module() @@ -22,9 +20,6 @@ class Mixup: alpha (float): Parameters for Beta distribution to generate the mixing ratio. It should be a positive number. More details are in the note. - num_classes (int, optional): The number of classes. If not specified, - will try to get it from data samples during training. - Defaults to None. Note: The :math:`\alpha` (``alpha``) determines a random distribution @@ -33,12 +28,10 @@ class Mixup: distribution. 
""" - def __init__(self, alpha: float, num_classes: Optional[int] = None): + def __init__(self, alpha: float): assert isinstance(alpha, float) and alpha > 0 - assert isinstance(num_classes, int) or num_classes is None self.alpha = alpha - self.num_classes = num_classes def mix(self, batch_inputs: torch.Tensor, batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -62,28 +55,11 @@ class Mixup: return mixed_inputs, mixed_scores - def __call__(self, batch_inputs: torch.Tensor, - data_samples: List[ClsDataSample]): + def __call__(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor): """Mix the batch inputs and batch data samples.""" - assert data_samples is not None, f'{self.__class__.__name__} ' \ - 'requires data_samples. If you only want to inference, please ' \ - 'disable it from preprocessing.' + assert batch_score.ndim == 2, \ + 'The input `batch_score` should be a one-hot format tensor, '\ + 'which shape should be ``(N, num_classes)``.' - if self.num_classes is None and 'num_classes' not in data_samples[0]: - raise RuntimeError( - 'Not specify the `num_classes` and cannot get it from ' - 'data samples. 
Please specify `num_classes` in the ' - f'{self.__class__.__name__}.') - num_classes = self.num_classes or data_samples[0].get('num_classes') - - batch_score = torch.stack([ - LabelData.label_to_onehot(sample.gt_label.label, num_classes) - for sample in data_samples - ]) - - mixed_inputs, mixed_score = self.mix(batch_inputs, batch_score) - - for i, sample in enumerate(data_samples): - sample.set_gt_score(mixed_score[i]) - - return mixed_inputs, data_samples + mixed_inputs, mixed_score = self.mix(batch_inputs, batch_score.float()) + return mixed_inputs, mixed_score diff --git a/mmcls/models/utils/batch_augments/resizemix.py b/mmcls/models/utils/batch_augments/resizemix.py index 4864ce06..fe01532e 100644 --- a/mmcls/models/utils/batch_augments/resizemix.py +++ b/mmcls/models/utils/batch_augments/resizemix.py @@ -21,9 +21,6 @@ class ResizeMix(CutMix): alpha (float): Parameters for Beta distribution to generate the mixing ratio. It should be a positive number. More details can be found in :class:`Mixup`. - num_classes (int, optional): The number of classes. If not specified, - will try to get it from data samples during training. - Defaults to None. lam_min(float): The minimum value of lam. Defaults to 0.1. lam_max(float): The maximum value of lam. Defaults to 0.8. 
interpolation (str): algorithm used for upsampling: @@ -57,17 +54,13 @@ class ResizeMix(CutMix): def __init__(self, alpha: float, - num_classes: Optional[int] = None, lam_min: float = 0.1, lam_max: float = 0.8, interpolation: str = 'bilinear', cutmix_minmax: Optional[List[float]] = None, correct_lam: bool = True): super().__init__( - alpha=alpha, - num_classes=num_classes, - cutmix_minmax=cutmix_minmax, - correct_lam=correct_lam) + alpha=alpha, cutmix_minmax=cutmix_minmax, correct_lam=correct_lam) self.lam_min = lam_min self.lam_max = lam_max self.interpolation = interpolation diff --git a/mmcls/models/utils/batch_augments/wrapper.py b/mmcls/models/utils/batch_augments/wrapper.py index 2759aa69..2b84dde5 100644 --- a/mmcls/models/utils/batch_augments/wrapper.py +++ b/mmcls/models/utils/batch_augments/wrapper.py @@ -17,13 +17,16 @@ class RandomBatchAugment: augmentations. If None, choose evenly. Defaults to None. Example: + >>> import torch + >>> import torch.nn.functional as F + >>> from mmcls.models import RandomBatchAugment >>> augments_cfg = [ - ... dict(type='CutMix', alpha=1., num_classes=10), - ... dict(type='Mixup', alpha=1., num_classes=10) + ... dict(type='CutMix', alpha=1.), + ... dict(type='Mixup', alpha=1.) ... ] >>> batch_augment = RandomBatchAugment(augments_cfg, probs=[0.5, 0.3]) - >>> imgs = torch.randn(16, 3, 32, 32) - >>> label = torch.randint(0, 10, (16, )) + >>> imgs = torch.rand(16, 3, 32, 32) + >>> label = F.one_hot(torch.randint(0, 10, (16, )), num_classes=10) >>> imgs, label = batch_augment(imgs, label) .. 
note :: @@ -59,13 +62,13 @@ class RandomBatchAugment: self.probs = probs - def __call__(self, inputs: torch.Tensor, data_samples: Union[list, None]): + def __call__(self, batch_input: torch.Tensor, batch_score: torch.Tensor): """Randomly apply batch augmentations to the batch inputs and batch data samples.""" aug_index = np.random.choice(len(self.augments), p=self.probs) aug = self.augments[aug_index] if aug is not None: - return aug(inputs, data_samples) + return aug(batch_input, batch_score) else: - return inputs, data_samples + return batch_input, batch_score.float() diff --git a/mmcls/models/utils/data_preprocessor.py b/mmcls/models/utils/data_preprocessor.py index 5fb6b551..7e338510 100644 --- a/mmcls/models/utils/data_preprocessor.py +++ b/mmcls/models/utils/data_preprocessor.py @@ -8,6 +8,8 @@ import torch.nn.functional as F from mmengine.model import BaseDataPreprocessor, stack_batch from mmcls.registry import MODELS +from mmcls.structures import (batch_label_to_onehot, cat_batch_labels, + stack_batch_scores, tensor_split) from .batch_augments import RandomBatchAugment @@ -42,6 +44,9 @@ class ClsDataPreprocessor(BaseDataPreprocessor): pad_value (Number): The padded pixel value. Defaults to 0. to_rgb (bool): whether to convert image from BGR to RGB. Defaults to False. + to_onehot (bool): Whether to generate one-hot format gt-labels and set + to data samples. Defaults to False. + num_classes (int, optional): The number of classes. Defaults to None. batch_augments (dict, optional): The batch augmentations settings, including "augments" and "probs". For more details, see :class:`mmcls.models.RandomBatchAugment`. 
@@ -53,11 +58,15 @@ class ClsDataPreprocessor(BaseDataPreprocessor): pad_size_divisor: int = 1, pad_value: Number = 0, to_rgb: bool = False, + to_onehot: bool = False, + num_classes: Optional[int] = None, batch_augments: Optional[dict] = None): super().__init__() self.pad_size_divisor = pad_size_divisor self.pad_value = pad_value self.to_rgb = to_rgb + self.to_onehot = to_onehot + self.num_classes = num_classes if mean is not None: assert std is not None, 'To enable the normalization in ' \ @@ -73,6 +82,13 @@ class ClsDataPreprocessor(BaseDataPreprocessor): if batch_augments is not None: self.batch_augments = RandomBatchAugment(**batch_augments) + if not self.to_onehot: + from mmengine.logging import MMLogger + MMLogger.get_current_instance().info( + 'Because batch augmentations are enabled, the data ' + 'preprocessor automatically enables the `to_onehot` ' + 'option to generate one-hot format labels.') + self.to_onehot = True else: self.batch_augments = None @@ -87,8 +103,7 @@ class ClsDataPreprocessor(BaseDataPreprocessor): Returns: dict: Data in the same format as the model input. 
""" - data = self.cast_data(data) - inputs = data['inputs'] + inputs = self.cast_data(data['inputs']) if isinstance(inputs, torch.Tensor): # The branch if use `default_collate` as the collate_fn in the @@ -135,12 +150,36 @@ class ClsDataPreprocessor(BaseDataPreprocessor): inputs = stack_batch(processed_inputs, self.pad_size_divisor, self.pad_value) - # ----- Batch Aug ---- - if training and self.batch_augments is not None: - data_samples = data['data_samples'] - inputs, data_samples = self.batch_augments(inputs, data_samples) - data['data_samples'] = data_samples + data_samples = data.get('data_samples', None) + if data_samples is not None: + gt_labels = [sample.gt_label for sample in data_samples] + batch_label, label_indices = cat_batch_labels( + gt_labels, device=self.device) - data['inputs'] = inputs + batch_score = stack_batch_scores(gt_labels, device=self.device) + if batch_score is None and self.to_onehot: + assert batch_label is not None, \ + 'Cannot generate onehot format labels because no labels.' + num_classes = self.num_classes or data_samples[0].get( + 'num_classes') + assert num_classes is not None, \ + 'Cannot generate one-hot format labels because not set ' \ + '`num_classes` in `data_preprocessor`.' 
+ batch_score = batch_label_to_onehot(batch_label, label_indices, + num_classes) - return data + # ----- Batch Augmentations ---- + if training and self.batch_augments is not None: + inputs, batch_score = self.batch_augments(inputs, batch_score) + + # ----- scatter labels and scores to data samples --- + if batch_label is not None: + for sample, label in zip( + data_samples, tensor_split(batch_label, + label_indices)): + sample.set_gt_label(label) + if batch_score is not None: + for sample, score in zip(data_samples, batch_score): + sample.set_gt_score(score) + + return {'inputs': inputs, 'data_samples': data_samples} diff --git a/mmcls/structures/__init__.py b/mmcls/structures/__init__.py index 61fa8ee3..0dc08443 100644 --- a/mmcls/structures/__init__.py +++ b/mmcls/structures/__init__.py @@ -1,4 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .cls_data_sample import ClsDataSample +from .utils import (batch_label_to_onehot, cat_batch_labels, + stack_batch_scores, tensor_split) -__all__ = ['ClsDataSample'] +__all__ = [ + 'ClsDataSample', 'batch_label_to_onehot', 'cat_batch_labels', + 'stack_batch_scores', 'tensor_split' +] diff --git a/mmcls/structures/cls_data_sample.py b/mmcls/structures/cls_data_sample.py index a3568eff..9e319a7b 100644 --- a/mmcls/structures/cls_data_sample.py +++ b/mmcls/structures/cls_data_sample.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from multiprocessing.reduction import ForkingPickler from numbers import Number from typing import Sequence, Union @@ -9,20 +10,18 @@ from mmengine.structures import BaseDataElement, LabelData from mmengine.utils import is_str -def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int], - num_classes: int = None) -> LabelData: - """Convert label of various python types to :obj:`mmengine.LabelData`. +def format_label( + value: Union[torch.Tensor, np.ndarray, Sequence, int]) -> torch.Tensor: + """Convert various python types to label-format tensor. 
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`Sequence`, :class:`int`. Args: value (torch.Tensor | numpy.ndarray | Sequence | int): Label value. - num_classes (int, optional): The number of classes. If not None, set - it to the metainfo. Defaults to None. Returns: - :obj:`mmengine.LabelData`: The foramtted label data. + :obj:`torch.Tensor`: The formatted label tensor. """ # Handle single number @@ -37,15 +36,36 @@ def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int], value = torch.LongTensor([value]) elif not isinstance(value, torch.Tensor): raise TypeError(f'Type {type(value)} is not an available label type.') + assert value.ndim == 1, \ + f'The dims of value should be 1, but got {value.ndim}.' - metainfo = {} - if num_classes is not None: - metainfo['num_classes'] = num_classes - if value.max() >= num_classes: - raise ValueError(f'The label data ({value}) should not ' - f'exceed num_classes ({num_classes}).') - label = LabelData(label=value, metainfo=metainfo) - return label + return value + + +def format_score( + value: Union[torch.Tensor, np.ndarray, Sequence, int]) -> torch.Tensor: + """Convert various python types to score-format tensor. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`. + + Args: + value (torch.Tensor | numpy.ndarray | Sequence): Score values. + + Returns: + :obj:`torch.Tensor`: The formatted score tensor. + """ + + if isinstance(value, np.ndarray): + value = torch.from_numpy(value).float() + elif isinstance(value, Sequence) and not is_str(value): + value = torch.tensor(value).float() + elif not isinstance(value, torch.Tensor): + raise TypeError(f'Type {type(value)} is not an available label type.') + assert value.ndim == 1, \ + f'The dims of value should be 1, but got {value.ndim}.'
+ + return value class ClsDataSample(BaseDataElement): @@ -115,64 +135,50 @@ class ClsDataSample(BaseDataElement): self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number] ) -> 'ClsDataSample': """Set label of ``gt_label``.""" - label = format_label(value, self.get('num_classes')) - if 'gt_label' in self: - self.gt_label.label = label.label - else: - self.gt_label = label + label_data = getattr(self, '_gt_label', LabelData()) + label_data.label = format_label(value) + self.gt_label = label_data return self def set_gt_score(self, value: torch.Tensor) -> 'ClsDataSample': """Set score of ``gt_label``.""" - assert isinstance(value, torch.Tensor), \ - f'The value should be a torch.Tensor but got {type(value)}.' - assert value.ndim == 1, \ - f'The dims of value should be 1, but got {value.ndim}.' - - if 'num_classes' in self: - assert value.size(0) == self.num_classes, \ - f"The length of value ({value.size(0)}) doesn't "\ - f'match the num_classes ({self.num_classes}).' - metainfo = {'num_classes': self.num_classes} + label_data = getattr(self, '_gt_label', LabelData()) + label_data.score = format_score(value) + if hasattr(self, 'num_classes'): + assert len(label_data.score) == self.num_classes, \ + f'The length of score {len(label_data.score)} should be '\ + f'equal to the num_classes {self.num_classes}.' 
else: - metainfo = {'num_classes': value.size(0)} - - if 'gt_label' in self: - self.gt_label.score = value - else: - self.gt_label = LabelData(score=value, metainfo=metainfo) + self.set_field( + name='num_classes', + value=len(label_data.score), + field_type='metainfo') + self.gt_label = label_data return self def set_pred_label( self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number] ) -> 'ClsDataSample': """Set label of ``pred_label``.""" - label = format_label(value, self.get('num_classes')) - if 'pred_label' in self: - self.pred_label.label = label.label - else: - self.pred_label = label + label_data = getattr(self, '_pred_label', LabelData()) + label_data.label = format_label(value) + self.pred_label = label_data return self def set_pred_score(self, value: torch.Tensor) -> 'ClsDataSample': """Set score of ``pred_label``.""" - assert isinstance(value, torch.Tensor), \ - f'The value should be a torch.Tensor but got {type(value)}.' - assert value.ndim == 1, \ - f'The dims of value should be 1, but got {value.ndim}.' - - if 'num_classes' in self: - assert value.size(0) == self.num_classes, \ - f"The length of value ({value.size(0)}) doesn't "\ - f'match the num_classes ({self.num_classes}).' - metainfo = {'num_classes': self.num_classes} + label_data = getattr(self, '_pred_label', LabelData()) + label_data.score = format_score(value) + if hasattr(self, 'num_classes'): + assert len(label_data.score) == self.num_classes, \ + f'The length of score {len(label_data.score)} should be '\ + f'equal to the num_classes {self.num_classes}.' 
else: - metainfo = {'num_classes': value.size(0)} - - if 'pred_label' in self: - self.pred_label.score = value - else: - self.pred_label = LabelData(score=value, metainfo=metainfo) + self.set_field( + name='num_classes', + value=len(label_data.score), + field_type='metainfo') + self.pred_label = label_data return self @property @@ -198,3 +204,32 @@ class ClsDataSample(BaseDataElement): @pred_label.deleter def pred_label(self): del self._pred_label + + +def _reduce_cls_datasample(data_sample): + """reduce ClsDataSample.""" + attr_dict = data_sample.__dict__ + convert_keys = [] + for k, v in attr_dict.items(): + if isinstance(v, LabelData): + attr_dict[k] = v.numpy() + convert_keys.append(k) + return _rebuild_cls_datasample, (attr_dict, convert_keys) + + +def _rebuild_cls_datasample(attr_dict, convert_keys): + """rebuild ClsDataSample.""" + data_sample = ClsDataSample() + for k in convert_keys: + attr_dict[k] = attr_dict[k].to_tensor() + data_sample.__dict__ = attr_dict + return data_sample + + +# Due to the multi-processing strategy of PyTorch, ClsDataSample may consume +# many file descriptors because it contains multiple LabelData with tensors. +# Here we overwrite the reduce function of ClsDataSample in ForkingPickler and +# convert these tensors to np.ndarray during pickling. It may influence the +# performance of dataloader, but slightly because these tensors in LabelData +# are very small. +ForkingPickler.register(ClsDataSample, _reduce_cls_datasample) diff --git a/mmcls/structures/utils.py b/mmcls/structures/utils.py new file mode 100644 index 00000000..8c8f0f3d --- /dev/null +++ b/mmcls/structures/utils.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn.functional as F +from mmengine.structures import LabelData + +if hasattr(torch, 'tensor_split'): + tensor_split = torch.tensor_split +else: + # A simple implementation of `tensor_split`. 
+ def tensor_split(input: torch.Tensor, indices: list): + outs = [] + for start, end in zip([0] + indices, indices + [input.size(0)]): + outs.append(input[start:end]) + return outs + + +def cat_batch_labels(elements: List[LabelData], device=None): + """Concat the ``label`` of a batch of :obj:`LabelData` to a tensor. + + Args: + elements (List[LabelData]): A batch of :obj`LabelData`. + device (torch.device, optional): The output device of the batch label. + Defaults to None. + + Returns: + Tuple[torch.Tensor, List[int]]: The first item is the concated label + tensor, and the second item is the split indices of every sample. + """ + item = elements[0] + if 'label' not in item._data_fields: + return None, None + + labels = [] + splits = [0] + for element in elements: + labels.append(element.label) + splits.append(splits[-1] + element.label.size(0)) + batch_label = torch.cat(labels) + if device is not None: + batch_label = batch_label.to(device=device) + return batch_label, splits[1:-1] + + +def batch_label_to_onehot(batch_label, split_indices, num_classes): + """Convert a concated label tensor to onehot format. + + Args: + batch_label (torch.Tensor): A concated label tensor from multiple + samples. + split_indices (List[int]): The split indices of every sample. + num_classes (int): The number of classes. + + Returns: + torch.Tensor: The onehot format label tensor. + + Examples: + >>> import torch + >>> from mmcls.structures import batch_label_to_onehot + >>> # Assume a concated label from 3 samples. 
+ >>> # label 1: [0, 1], label 2: [0, 2, 4], label 3: [3, 1] + >>> batch_label = torch.tensor([0, 1, 0, 2, 4, 3, 1]) + >>> split_indices = [2, 5] + >>> batch_label_to_onehot(batch_label, split_indices, num_classes=5) + tensor([[1, 1, 0, 0, 0], + [1, 0, 1, 0, 1], + [0, 1, 0, 1, 0]]) + """ + sparse_onehot_list = F.one_hot(batch_label, num_classes) + onehot_list = [ + sparse_onehot.sum(0) + for sparse_onehot in tensor_split(sparse_onehot_list, split_indices) + ] + return torch.stack(onehot_list) + + +def stack_batch_scores(elements, device=None): + """Stack the ``score`` of a batch of :obj:`LabelData` to a tensor. + + Args: + elements (List[LabelData]): A batch of :obj`LabelData`. + device (torch.device, optional): The output device of the batch label. + Defaults to None. + + Returns: + torch.Tensor: The stacked score tensor. + """ + item = elements[0] + if 'score' not in item._data_fields: + return None + + batch_score = torch.stack([element.score for element in elements]) + if device is not None: + batch_score = batch_score.to(device) + return batch_score diff --git a/tests/test_evaluation/test_metrics/test_single_label.py b/tests/test_evaluation/test_metrics/test_single_label.py index 077e3ce2..108f80b7 100644 --- a/tests/test_evaluation/test_metrics/test_single_label.py +++ b/tests/test_evaluation/test_metrics/test_single_label.py @@ -212,7 +212,9 @@ class TestSingleLabel(TestCase): pred_no_score = copy.deepcopy(pred) for sample in pred_no_score: del sample['pred_label']['score'] - metric = METRICS.build(dict(type='SingleLabelMetric', thrs=(0., 0.6))) + del sample['num_classes'] + metric = METRICS.build( + dict(type='SingleLabelMetric', thrs=(0., 0.6), num_classes=3)) metric.process(None, pred_no_score) res = metric.evaluate(6) self.assertIsInstance(res, dict) @@ -221,14 +223,13 @@ class TestSingleLabel(TestCase): self.assertAlmostEqual(res['single-label/recall'], 72.222, places=2) self.assertAlmostEqual(res['single-label/f1-score'], 65.555, places=2) - 
pred_no_num_classes = copy.deepcopy(pred_no_score) - for sample in pred_no_num_classes: - del sample['pred_label']['num_classes'] - with self.assertRaisesRegex(ValueError, 'neither `score` nor'): - metric.process(None, pred_no_num_classes) + metric = METRICS.build(dict(type='SingleLabelMetric', thrs=(0., 0.6))) + with self.assertRaisesRegex(AssertionError, 'must be specified'): + metric.process(None, pred_no_score) # Test with empty items - metric = METRICS.build(dict(type='SingleLabelMetric', items=tuple())) + metric = METRICS.build( + dict(type='SingleLabelMetric', items=tuple(), num_classes=3)) metric.process(None, pred) res = metric.evaluate(6) self.assertIsInstance(res, dict) diff --git a/tests/test_models/test_classifiers.py b/tests/test_models/test_classifiers.py index 59a6adcd..6ed4e46a 100644 --- a/tests/test_models/test_classifiers.py +++ b/tests/test_models/test_classifiers.py @@ -47,7 +47,7 @@ class TestImageClassifier(TestCase): # test set batch augmentation from train_cfg cfg = { **self.DEFAULT_ARGS, 'train_cfg': - dict(augments=dict(type='Mixup', alpha=1., num_classes=10)) + dict(augments=dict(type='Mixup', alpha=1.)) } model: ImageClassifier = MODELS.build(cfg) self.assertIsNotNone(model.data_preprocessor.batch_augments) diff --git a/tests/test_models/test_utils/test_batch_augments.py b/tests/test_models/test_utils/test_batch_augments.py index 3890df49..65eef588 100644 --- a/tests/test_models/test_utils/test_batch_augments.py +++ b/tests/test_models/test_utils/test_batch_augments.py @@ -7,7 +7,6 @@ import torch from mmcls.models import Mixup, RandomBatchAugment from mmcls.registry import BATCH_AUGMENTS -from mmcls.structures import ClsDataSample class TestRandomBatchAugment(TestCase): @@ -54,7 +53,7 @@ class TestRandomBatchAugment(TestCase): def test_call(self): inputs = torch.rand(2, 3, 224, 224) - data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)] + scores = torch.rand(2, 10) augments = [ dict(type='Mixup', alpha=1.), @@ -64,18 
+63,17 @@ class TestRandomBatchAugment(TestCase): with patch('numpy.random', np.random.RandomState(0)): batch_augments.augments[1] = MagicMock() - batch_augments(inputs, data_samples) - batch_augments.augments[1].assert_called_once_with( - inputs, data_samples) + batch_augments(inputs, scores) + batch_augments.augments[1].assert_called_once_with(inputs, scores) augments = [ dict(type='Mixup', alpha=1.), dict(type='CutMix', alpha=0.8), ] batch_augments = RandomBatchAugment(augments, probs=[0.0, 0.0]) - mixed_inputs, mixed_samples = batch_augments(inputs, data_samples) + mixed_inputs, mixed_samples = batch_augments(inputs, scores) self.assertIs(mixed_inputs, inputs) - self.assertIs(mixed_samples, data_samples) + self.assertIs(mixed_samples, scores) class TestMixup(TestCase): @@ -86,45 +84,21 @@ class TestMixup(TestCase): cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'} BATCH_AUGMENTS.build(cfg) - with self.assertRaises(AssertionError): - cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'} - BATCH_AUGMENTS.build(cfg) - def test_call(self): inputs = torch.rand(2, 3, 224, 224) - data_samples = [ - ClsDataSample(metainfo={ - 'num_classes': 10 - }).set_gt_label(1) for _ in range(2) - ] + scores = torch.rand(2, 10) - # test get num_classes from data_samples mixup = BATCH_AUGMENTS.build(self.DEFAULT_ARGS) - mixed_inputs, mixed_samples = mixup(inputs, data_samples) + mixed_inputs, mixed_scores = mixup(inputs, scores) self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) - self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, )) - - with self.assertRaisesRegex(RuntimeError, 'Not specify'): - data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)] - mixup(inputs, data_samples) + self.assertEqual(mixed_scores.shape, (2, 10)) # test binary classification - cfg = {**self.DEFAULT_ARGS, 'num_classes': 1} - mixup = BATCH_AUGMENTS.build(cfg) - data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)] + scores = torch.rand(2, 1) - mixed_inputs, 
mixed_samples = mixup(inputs, data_samples) + mixed_inputs, mixed_scores = mixup(inputs, scores) self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) - self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, )) - - # test multi-label classification - cfg = {**self.DEFAULT_ARGS, 'num_classes': 5} - mixup = BATCH_AUGMENTS.build(cfg) - data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)] - - mixed_inputs, mixed_samples = mixup(inputs, data_samples) - self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) - self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, )) + self.assertEqual(mixed_scores.shape, (2, 1)) class TestCutMix(TestCase): @@ -135,59 +109,36 @@ class TestCutMix(TestCase): cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'} BATCH_AUGMENTS.build(cfg) - with self.assertRaises(AssertionError): - cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'} - BATCH_AUGMENTS.build(cfg) - def test_call(self): inputs = torch.rand(2, 3, 224, 224) - data_samples = [ - ClsDataSample(metainfo={ - 'num_classes': 10 - }).set_gt_label(1) for _ in range(2) - ] + scores = torch.rand(2, 10) # test with cutmix_minmax cfg = {**self.DEFAULT_ARGS, 'cutmix_minmax': (0.1, 0.2)} cutmix = BATCH_AUGMENTS.build(cfg) - mixed_inputs, mixed_samples = cutmix(inputs, data_samples) + mixed_inputs, mixed_scores = cutmix(inputs, scores) self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) - self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, )) + self.assertEqual(mixed_scores.shape, (2, 10)) # test without correct_lam cfg = {**self.DEFAULT_ARGS, 'correct_lam': False} cutmix = BATCH_AUGMENTS.build(cfg) - mixed_inputs, mixed_samples = cutmix(inputs, data_samples) + mixed_inputs, mixed_scores = cutmix(inputs, scores) self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) - self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, )) + self.assertEqual(mixed_scores.shape, (2, 10)) - # test get num_classes from data_samples + # test default settings cutmix = 
BATCH_AUGMENTS.build(self.DEFAULT_ARGS) - mixed_inputs, mixed_samples = cutmix(inputs, data_samples) + mixed_inputs, mixed_scores = cutmix(inputs, scores) self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) - self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, )) - - with self.assertRaisesRegex(RuntimeError, 'Not specify'): - data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)] - cutmix(inputs, data_samples) + self.assertEqual(mixed_scores.shape, (2, 10)) # test binary classification - cfg = {**self.DEFAULT_ARGS, 'num_classes': 1} - cutmix = BATCH_AUGMENTS.build(cfg) - data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)] + scores = torch.rand(2, 1) - mixed_inputs, mixed_samples = cutmix(inputs, data_samples) + mixed_inputs, mixed_scores = cutmix(inputs, scores) self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) - self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, )) - - # test multi-label classification - cfg = {**self.DEFAULT_ARGS, 'num_classes': 5} - cutmix = BATCH_AUGMENTS.build(cfg) - data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)] - - mixed_inputs, mixed_samples = cutmix(inputs, data_samples) - self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) - self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, )) + self.assertEqual(mixed_scores.shape, (2, 1)) class TestResizeMix(TestCase): @@ -198,42 +149,18 @@ class TestResizeMix(TestCase): cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'} BATCH_AUGMENTS.build(cfg) - with self.assertRaises(AssertionError): - cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'} - BATCH_AUGMENTS.build(cfg) - def test_call(self): inputs = torch.rand(2, 3, 224, 224) - data_samples = [ - ClsDataSample(metainfo={ - 'num_classes': 10 - }).set_gt_label(1) for _ in range(2) - ] + scores = torch.rand(2, 10) - # test get num_classes from data_samples mixup = BATCH_AUGMENTS.build(self.DEFAULT_ARGS) - mixed_inputs, mixed_samples = mixup(inputs, data_samples) + 
mixed_inputs, mixed_scores = mixup(inputs, scores) self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) - self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, )) - - with self.assertRaisesRegex(RuntimeError, 'Not specify'): - data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)] - mixup(inputs, data_samples) + self.assertEqual(mixed_scores.shape, (2, 10)) # test binary classification - cfg = {**self.DEFAULT_ARGS, 'num_classes': 1} - mixup = BATCH_AUGMENTS.build(cfg) - data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)] + scores = torch.rand(2, 1) - mixed_inputs, mixed_samples = mixup(inputs, data_samples) + mixed_inputs, mixed_scores = mixup(inputs, scores) self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) - self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, )) - - # test multi-label classification - cfg = {**self.DEFAULT_ARGS, 'num_classes': 5} - mixup = BATCH_AUGMENTS.build(cfg) - data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)] - - mixed_inputs, mixed_samples = mixup(inputs, data_samples) - self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224)) - self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, )) + self.assertEqual(mixed_scores.shape, (2, 1)) diff --git a/tests/test_models/test_utils/test_data_preprocessor.py b/tests/test_models/test_utils/test_data_preprocessor.py index 84b2f563..704270e1 100644 --- a/tests/test_models/test_utils/test_data_preprocessor.py +++ b/tests/test_models/test_utils/test_data_preprocessor.py @@ -71,7 +71,7 @@ class TestClsDataPreprocessor(TestCase): inputs = processed_data['inputs'] self.assertTrue((inputs >= -1).all()) self.assertTrue((inputs <= 1).all()) - self.assertNotIn('data_samples', processed_data) + self.assertIsNone(processed_data['data_samples']) data = {'inputs': torch.randint(0, 256, (1, 3, 224, 224))} inputs = processor(data)['inputs'] @@ -81,9 +81,10 @@ class TestClsDataPreprocessor(TestCase): def test_batch_augmentation(self): cfg = 
dict( type='ClsDataPreprocessor', + num_classes=10, batch_augments=dict(augments=[ - dict(type='Mixup', alpha=0.8, num_classes=10), - dict(type='CutMix', alpha=1., num_classes=10) + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.) ])) processor: ClsDataPreprocessor = MODELS.build(cfg) self.assertIsInstance(processor.batch_augments, RandomBatchAugment) @@ -101,4 +102,4 @@ class TestClsDataPreprocessor(TestCase): data = {'inputs': [torch.randint(0, 256, (3, 224, 224))]} processed_data = processor(data, training=True) self.assertIn('inputs', processed_data) - self.assertNotIn('data_samples', processed_data) + self.assertIsNone(processed_data['data_samples']) diff --git a/tests/test_structures/test_datasample.py b/tests/test_structures/test_datasample.py index 202a74d2..ee45c3f2 100644 --- a/tests/test_structures/test_datasample.py +++ b/tests/test_structures/test_datasample.py @@ -60,20 +60,6 @@ class TestClsDataSample(TestCase): with self.assertRaisesRegex(TypeError, " is not"): method('hi') - # Test set num_classes - data_sample = ClsDataSample(metainfo={'num_classes': 10}) - method = getattr(data_sample, 'set_' + key) - method(5) - self.assertIn(key, data_sample) - label = getattr(data_sample, key) - self.assertIsInstance(label, LabelData) - self.assertIn('num_classes', label) - self.assertEqual(label.num_classes, 10) - - # Test unavailable label - with self.assertRaisesRegex(ValueError, r'data .*[15].* should '): - method(15) - def test_set_gt_label(self): self._test_set_label('gt_label') @@ -97,62 +83,42 @@ class TestClsDataSample(TestCase): self.assertNotIn('pred_label', data_sample) def test_set_gt_score(self): - data_sample = ClsDataSample(metainfo={'num_classes': 5}) + data_sample = ClsDataSample() data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1])) self.assertIn('score', data_sample.gt_label) torch.testing.assert_allclose(data_sample.gt_label.score, [0.1, 0.1, 0.6, 0.1, 0.1]) - self.assertEqual(data_sample.gt_label.num_classes, 5) 
# Test set again data_sample.set_gt_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1])) torch.testing.assert_allclose(data_sample.gt_label.score, [0.2, 0.1, 0.5, 0.1, 0.1]) - # Test invalid type - with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'): - data_sample.set_gt_score([1, 2, 3]) + # Test invalid length + with self.assertRaisesRegex(AssertionError, 'should be equal to'): + data_sample.set_gt_score([1, 2]) # Test invalid dims with self.assertRaisesRegex(AssertionError, 'but got 2'): data_sample.set_gt_score(torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]])) - # Test invalid num_classes - with self.assertRaisesRegex(AssertionError, r'length of value \(4\)'): - data_sample.set_gt_score(torch.tensor([0.1, 0.2, 0.3, 0.4])) - - # Test auto inter num_classes - data_sample = ClsDataSample() - data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1])) - self.assertEqual(data_sample.gt_label.num_classes, 5) - def test_set_pred_score(self): - data_sample = ClsDataSample(metainfo={'num_classes': 5}) + data_sample = ClsDataSample() data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1])) self.assertIn('score', data_sample.pred_label) torch.testing.assert_allclose(data_sample.pred_label.score, [0.1, 0.1, 0.6, 0.1, 0.1]) - self.assertEqual(data_sample.pred_label.num_classes, 5) # Test set again data_sample.set_pred_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1])) torch.testing.assert_allclose(data_sample.pred_label.score, [0.2, 0.1, 0.5, 0.1, 0.1]) - # Test invalid type - with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'): - data_sample.set_pred_score([1, 2, 3]) + # Test invalid length + with self.assertRaisesRegex(AssertionError, 'should be equal to'): + data_sample.set_pred_score([1, 2]) # Test invalid dims with self.assertRaisesRegex(AssertionError, 'but got 2'): data_sample.set_pred_score( torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]])) - - # Test invalid num_classes - with self.assertRaisesRegex(AssertionError, r'length of value \(4\)'): - 
data_sample.set_pred_score(torch.tensor([0.1, 0.2, 0.3, 0.4])) - - # Test auto inter num_classes - data_sample = ClsDataSample() - data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1])) - self.assertEqual(data_sample.pred_label.num_classes, 5) diff --git a/tests/test_structures/test_utils.py b/tests/test_structures/test_utils.py new file mode 100644 index 00000000..998e8b7c --- /dev/null +++ b/tests/test_structures/test_utils.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch +from mmengine.structures import LabelData + +from mmcls.structures import (batch_label_to_onehot, cat_batch_labels, + stack_batch_scores, tensor_split) + + +class TestStructureUtils(TestCase): + + def test_tensor_split(self): + tensor = torch.tensor([0, 1, 2, 3, 4, 5, 6]) + split_indices = [0, 2, 6, 6] + outs = tensor_split(tensor, split_indices) + self.assertEqual(len(outs), len(split_indices) + 1) + self.assertEqual(outs[0].size(0), 0) + self.assertEqual(outs[1].size(0), 2) + self.assertEqual(outs[2].size(0), 4) + self.assertEqual(outs[3].size(0), 0) + self.assertEqual(outs[4].size(0), 1) + + tensor = torch.tensor([]) + split_indices = [0, 0, 0, 0] + outs = tensor_split(tensor, split_indices) + self.assertEqual(len(outs), len(split_indices) + 1) + + def test_cat_batch_labels(self): + labels = [ + LabelData(label=torch.tensor([1])), + LabelData(label=torch.tensor([3, 2])), + LabelData(label=torch.tensor([0, 1, 4])), + LabelData(label=torch.tensor([], dtype=torch.int64)), + LabelData(label=torch.tensor([], dtype=torch.int64)), + ] + + batch_label, split_indices = cat_batch_labels(labels) + self.assertEqual(split_indices, [1, 3, 6, 6]) + self.assertEqual(len(batch_label), 6) + labels = tensor_split(batch_label, split_indices) + self.assertEqual(labels[0].tolist(), [1]) + self.assertEqual(labels[1].tolist(), [3, 2]) + self.assertEqual(labels[2].tolist(), [0, 1, 4]) + self.assertEqual(labels[3].tolist(), []) + 
self.assertEqual(labels[4].tolist(), []) + + labels = [ + LabelData(score=torch.tensor([0, 1, 0, 0, 1])), + LabelData(score=torch.tensor([0, 0, 1, 0, 0])), + LabelData(score=torch.tensor([1, 0, 0, 1, 0])), + ] + batch_label, split_indices = cat_batch_labels(labels) + self.assertIsNone(batch_label) + self.assertIsNone(split_indices) + + def test_stack_batch_scores(self): + labels = [ + LabelData(score=torch.tensor([0, 1, 0, 0, 1])), + LabelData(score=torch.tensor([0, 0, 1, 0, 0])), + LabelData(score=torch.tensor([1, 0, 0, 1, 0])), + ] + + batch_score = stack_batch_scores(labels) + self.assertEqual(batch_score.shape, (3, 5)) + + labels = [ + LabelData(label=torch.tensor([1])), + LabelData(label=torch.tensor([3, 2])), + LabelData(label=torch.tensor([0, 1, 4])), + LabelData(label=torch.tensor([], dtype=torch.int64)), + LabelData(label=torch.tensor([], dtype=torch.int64)), + ] + batch_score = stack_batch_scores(labels) + self.assertIsNone(batch_score) + + def test_batch_label_to_onehot(self): + labels = [ + LabelData(label=torch.tensor([1])), + LabelData(label=torch.tensor([3, 2])), + LabelData(label=torch.tensor([0, 1, 4])), + LabelData(label=torch.tensor([], dtype=torch.int64)), + LabelData(label=torch.tensor([], dtype=torch.int64)), + ] + + batch_label, split_indices = cat_batch_labels(labels) + batch_score = batch_label_to_onehot( + batch_label, split_indices, num_classes=5) + self.assertEqual(batch_score[0].tolist(), [0, 1, 0, 0, 0]) + self.assertEqual(batch_score[1].tolist(), [0, 0, 1, 1, 0]) + self.assertEqual(batch_score[2].tolist(), [1, 1, 0, 0, 1]) + self.assertEqual(batch_score[3].tolist(), [0, 0, 0, 0, 0]) + self.assertEqual(batch_score[4].tolist(), [0, 0, 0, 0, 0])