[Improve] Speed up data preprocessor. (#1064)

* [Improve] Speed up data preprocessor.

* Add ClsDataSample serialization override functions.

* Add unit tests

* Modify configs to fit new mixup args.

* Fix `num_classes` of the ImageNet-21k config.

* Update docs.
pull/1123/head
Ma Zerun 2022-10-17 17:08:18 +08:00 committed by GitHub
parent 06c919efc2
commit 29f066f7fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
73 changed files with 505 additions and 378 deletions

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'CIFAR100'
data_preprocessor = dict(
num_classes=100,
# RGB format normalization parameters
mean=[129.304, 124.070, 112.434],
std=[68.170, 65.392, 70.418],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'CIFAR10'
data_preprocessor = dict(
num_classes=10,
# RGB format normalization parameters
mean=[125.307, 122.961, 113.8575],
std=[51.5865, 50.847, 51.255],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'CUB'
data_preprocessor = dict(
num_classes=200,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'CUB'
data_preprocessor = dict(
num_classes=200,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet21k'
data_preprocessor = dict(
num_classes=21842,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -3,6 +3,7 @@ dataset_type = 'ImageNet'
# Google research usually use the below normalization setting.
data_preprocessor = dict(
num_classes=1000,
mean=[127.5, 127.5, 127.5],
std=[127.5, 127.5, 127.5],
# convert image from BGR to RGB

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,6 +1,7 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],

View File

@ -1,11 +1,14 @@
# dataset settings
dataset_type = 'VOC'
data_preprocessor = dict(
num_classes=20,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
# generate onehot-format labels for multi-label classification.
to_onehot=True,
)
train_pipeline = [

View File

@ -17,7 +17,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)

View File

@ -17,7 +17,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)

View File

@ -21,7 +21,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)

View File

@ -17,7 +17,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)

View File

@ -18,5 +18,5 @@ model = dict(
num_classes=1000),
topk=(1, 5),
),
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
)

View File

@ -20,5 +20,6 @@ model = dict(
reduction='mean',
loss_weight=1.0),
topk=(1, 5),
cal_acc=False))
train_cfg = dict(mixup=dict(alpha=0.2, num_classes=1000))
cal_acc=False),
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
)

View File

@ -20,5 +20,6 @@ model = dict(
reduction='mean',
loss_weight=1.0),
topk=(1, 5),
cal_acc=False))
train_cfg = dict(mixup=dict(alpha=0.2, num_classes=1000))
cal_acc=False),
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
)

View File

@ -20,5 +20,6 @@ model = dict(
reduction='mean',
loss_weight=1.0),
topk=(1, 5),
cal_acc=False))
train_cfg = dict(mixup=dict(alpha=0.2, num_classes=1000))
cal_acc=False),
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
)

View File

@ -19,5 +19,6 @@ model = dict(
reduction='mean',
loss_weight=1.0),
topk=(1, 5),
cal_acc=False))
train_cfg = dict(mixup=dict(alpha=0.2, num_classes=1000))
cal_acc=False),
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
)

View File

@ -13,5 +13,5 @@ model = dict(
num_classes=10,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)),
train_cfg=dict(augments=dict(type='Mixup', alpha=1., num_classes=10)),
train_cfg=dict(augments=dict(type='Mixup', alpha=1.)),
)

View File

@ -13,5 +13,5 @@ model = dict(
num_classes=1000,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)),
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
)

View File

@ -17,7 +17,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)

View File

@ -18,7 +18,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)

View File

@ -17,7 +17,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)

View File

@ -36,7 +36,7 @@ model = dict(
topk=(1, 5),
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=num_classes),
dict(type='CutMix', alpha=1.0, num_classes=num_classes),
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)

View File

@ -36,7 +36,7 @@ model = dict(
topk=(1, 5),
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=num_classes),
dict(type='CutMix', alpha=1.0, num_classes=num_classes),
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)

View File

@ -36,7 +36,7 @@ model = dict(
topk=(1, 5),
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=num_classes),
dict(type='CutMix', alpha=1.0, num_classes=num_classes),
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)

View File

@ -25,7 +25,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)

View File

@ -25,7 +25,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)

View File

@ -16,7 +16,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)

View File

@ -16,7 +16,7 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)

View File

@ -27,8 +27,8 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)

View File

@ -10,7 +10,7 @@ model = dict(
# dataset settings
dataset_type = 'MNIST'
data_preprocessor = dict(mean=[33.46], std=[78.87])
data_preprocessor = dict(mean=[33.46], std=[78.87], num_classes=10)
pipeline = [dict(type='Resize', scale=32), dict(type='PackClsInputs')]

View File

@ -298,44 +298,6 @@ Models:
Top 5 Accuracy: 93.80
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a3-100e_in1k_20211228-3493673c.pth
Config: configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py
- Name: wide-resnet50_3rdparty_8xb32_in1k
Metadata:
FLOPs: 11440000000 # 11.44G
Parameters: 68880000 # 68.88M
Training Techniques:
- SGD with Momentum
- Weight Decay
In Collection: ResNet
Results:
- Task: Image Classification
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 78.48
Top 5 Accuracy: 94.08
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/wide-resnet50_3rdparty_8xb32_in1k_20220304-66678344.pth
Config: configs/resnet/wide-resnet50_8xb32_in1k.py
Converted From:
Weights: https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth
Code: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
- Name: wide-resnet101_3rdparty_8xb32_in1k
Metadata:
FLOPs: 22810000000 # 22.81G
Parameters: 126890000 # 126.89M
Training Techniques:
- SGD with Momentum
- Weight Decay
In Collection: ResNet
Results:
- Task: Image Classification
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 78.84
Top 5 Accuracy: 94.28
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/wide-resnet101_3rdparty_8xb32_in1k_20220304-8d5f9d61.pth
Config: configs/resnet/wide-resnet101_8xb32_in1k.py
Converted From:
Weights: https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth
Code: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
- Name: resnetv1c50_8xb32_in1k
Metadata:
FLOPs: 4360000000

View File

@ -19,8 +19,8 @@ model = dict(
use_sigmoid=True,
)),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.2, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
dict(type='Mixup', alpha=0.2),
dict(type='CutMix', alpha=1.0)
]),
)

View File

@ -13,8 +13,8 @@ model = dict(
),
head=dict(loss=dict(use_sigmoid=True)),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.1, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
dict(type='Mixup', alpha=0.1),
dict(type='CutMix', alpha=1.0)
]))
# dataset settings

View File

@ -10,8 +10,8 @@ model = dict(
backbone=dict(norm_cfg=dict(type='SyncBN', requires_grad=True)),
head=dict(loss=dict(use_sigmoid=True)),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.1, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
dict(type='Mixup', alpha=0.1),
dict(type='CutMix', alpha=1.0)
]),
)

View File

@ -26,8 +26,8 @@ model = dict(
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=1000),
dict(type='CutMix', alpha=1.0, num_classes=1000)
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]))
# dataset settings

View File

@ -8,7 +8,7 @@ _base_ = [
# model setting
model = dict(
head=dict(hidden_dim=3072),
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
)
# schedule setting

View File

@ -8,7 +8,7 @@ _base_ = [
# model setting
model = dict(
head=dict(hidden_dim=3072),
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
)
# schedule setting

View File

@ -8,7 +8,7 @@ _base_ = [
# model setting
model = dict(
head=dict(hidden_dim=3072),
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
)
# schedule setting

View File

@ -8,7 +8,7 @@ _base_ = [
# model setting
model = dict(
head=dict(hidden_dim=3072),
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
)
# schedule setting

View File

@ -233,8 +233,8 @@ These augmentations are usually only used during training, therefore, we use the
neck=...,
head=...,
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=num_classes),
dict(type='CutMix', alpha=1.0, num_classes=num_classes),
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
@ -247,8 +247,8 @@ You can also speicy the probabilities of every batch augmentation by the ``probs
neck=...,
head=...,
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=num_classes),
dict(type='CutMix', alpha=1.0, num_classes=num_classes),
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
], probs=[0.3, 0.7])
)

View File

@ -289,7 +289,7 @@ _base_ = './resnet50_8xb32_in1k.py'
# using CutMix batch augment
model = dict(
train_cfg=dict(
augments=dict(type='CutMix', alpha=1.0, num_classes=1000, prob=1.0)
augments=dict(type='CutMix', alpha=1.0)
)
)

View File

@ -71,7 +71,7 @@ model = dict(
backbone=dict(
type='ResNet', # 主干网络类型
# 除了 `type` 之外的所有字段都来自 `ResNet` 类的 __init__ 方法
# 查阅 https://mmclassification.readthedocs.io/zh_CN/1.x/api/generated/mmcls.models.ResNet.html
# 查阅 https://mmclassification.readthedocs.io/zh_CN/1.x/api/generated/mmcls.models.ResNet.html
depth=50,
num_stages=4, # 主干网络状态(stages)的数目,这些状态产生的特征图作为后续的 head 的输入。
out_indices=(3, ), # 输出的特征图输出索引。
@ -81,7 +81,7 @@ model = dict(
head=dict(
type='LinearClsHead', # 分类颈网络类型
# 除了 `type` 之外的所有字段都来自 `LinearClsHead` 类的 __init__ 方法
# 查阅 https://mmclassification.readthedocs.io/zh_CN/1.x/api/generated/mmcls.models.LinearClsHead.html
# 查阅 https://mmclassification.readthedocs.io/zh_CN/1.x/api/generated/mmcls.models.LinearClsHead.html
num_classes=1000,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0), # 损失函数配置信息
@ -278,7 +278,7 @@ _base_ = './resnet50_8xb32_in1k.py'
# 模型在之前的基础上使用 CutMix 训练增强
model = dict(
train_cfg=dict(
augments=dict(type='CutMix', alpha=1.0, num_classes=1000, prob=1.0)
augments=dict(type='CutMix', alpha=1.0)
)
)

View File

@ -296,6 +296,7 @@ class SingleLabelMetric(BaseMetric):
true positives, false negatives and false positives.
Defaults to "macro".
num_classes (Optional, int): The number of classes. Defaults to None.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
@ -349,6 +350,7 @@ class SingleLabelMetric(BaseMetric):
thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
items: Sequence[str] = ('precision', 'recall', 'f1-score'),
average: Optional[str] = 'macro',
num_classes: Optional[int] = None,
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)
@ -365,6 +367,7 @@ class SingleLabelMetric(BaseMetric):
'"support".'
self.items = tuple(items)
self.average = average
self.num_classes = num_classes
def process(self, data_batch, data_samples: Sequence[dict]):
"""Process one batch of data samples.
@ -383,12 +386,14 @@ class SingleLabelMetric(BaseMetric):
gt_label = data_sample['gt_label']
if 'score' in pred_label:
result['pred_score'] = pred_label['score'].cpu()
elif ('num_classes' in pred_label):
result['pred_label'] = pred_label['label'].cpu()
result['num_classes'] = pred_label['num_classes']
else:
raise ValueError('The `pred_label` in data_samples do not '
'have neither `score` nor `num_classes`.')
num_classes = self.num_classes or data_sample.get(
'num_classes')
assert num_classes is not None, \
'The `num_classes` must be specified if `pred_label` has '\
'only `label`.'
result['pred_label'] = pred_label['label'].cpu()
result['num_classes'] = num_classes
result['gt_label'] = gt_label['label'].cpu()
# Save the result to `self.results`.
self.results.append(result)

View File

@ -24,9 +24,6 @@ class CutMix(Mixup):
alpha (float): Parameters for Beta distribution to generate the
mixing ratio. It should be a positive number. More details
can be found in :class:`Mixup`.
num_classes (int, optional): The number of classes. If not specified,
will try to get it from data samples during training.
Defaults to None.
cutmix_minmax (List[float], optional): The min/max area ratio of the
patches. If not None, the bounding-box of patches is uniform
sampled within this ratio range, and the ``alpha`` will be ignored.
@ -49,10 +46,9 @@ class CutMix(Mixup):
def __init__(self,
alpha: float,
num_classes: Optional[int] = None,
cutmix_minmax: Optional[List[float]] = None,
correct_lam: bool = True):
super().__init__(alpha=alpha, num_classes=num_classes)
super().__init__(alpha=alpha)
self.cutmix_minmax = cutmix_minmax
self.correct_lam = correct_lam

View File

@ -1,12 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
from typing import Tuple
import numpy as np
import torch
from mmengine.structures import LabelData
from mmcls.registry import BATCH_AUGMENTS
from mmcls.structures import ClsDataSample
@BATCH_AUGMENTS.register_module()
@ -22,9 +20,6 @@ class Mixup:
alpha (float): Parameters for Beta distribution to generate the
mixing ratio. It should be a positive number. More details
are in the note.
num_classes (int, optional): The number of classes. If not specified,
will try to get it from data samples during training.
Defaults to None.
Note:
The :math:`\alpha` (``alpha``) determines a random distribution
@ -33,12 +28,10 @@ class Mixup:
distribution.
"""
def __init__(self, alpha: float, num_classes: Optional[int] = None):
def __init__(self, alpha: float):
assert isinstance(alpha, float) and alpha > 0
assert isinstance(num_classes, int) or num_classes is None
self.alpha = alpha
self.num_classes = num_classes
def mix(self, batch_inputs: torch.Tensor,
batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@ -62,28 +55,11 @@ class Mixup:
return mixed_inputs, mixed_scores
def __call__(self, batch_inputs: torch.Tensor,
data_samples: List[ClsDataSample]):
def __call__(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor):
"""Mix the batch inputs and batch data samples."""
assert data_samples is not None, f'{self.__class__.__name__} ' \
'requires data_samples. If you only want to inference, please ' \
'disable it from preprocessing.'
assert batch_score.ndim == 2, \
'The input `batch_score` should be a one-hot format tensor, '\
'which shape should be ``(N, num_classes)``.'
if self.num_classes is None and 'num_classes' not in data_samples[0]:
raise RuntimeError(
'Not specify the `num_classes` and cannot get it from '
'data samples. Please specify `num_classes` in the '
f'{self.__class__.__name__}.')
num_classes = self.num_classes or data_samples[0].get('num_classes')
batch_score = torch.stack([
LabelData.label_to_onehot(sample.gt_label.label, num_classes)
for sample in data_samples
])
mixed_inputs, mixed_score = self.mix(batch_inputs, batch_score)
for i, sample in enumerate(data_samples):
sample.set_gt_score(mixed_score[i])
return mixed_inputs, data_samples
mixed_inputs, mixed_score = self.mix(batch_inputs, batch_score.float())
return mixed_inputs, mixed_score

View File

@ -21,9 +21,6 @@ class ResizeMix(CutMix):
alpha (float): Parameters for Beta distribution to generate the
mixing ratio. It should be a positive number. More details
can be found in :class:`Mixup`.
num_classes (int, optional): The number of classes. If not specified,
will try to get it from data samples during training.
Defaults to None.
lam_min(float): The minimum value of lam. Defaults to 0.1.
lam_max(float): The maximum value of lam. Defaults to 0.8.
interpolation (str): algorithm used for upsampling:
@ -57,17 +54,13 @@ class ResizeMix(CutMix):
def __init__(self,
alpha: float,
num_classes: Optional[int] = None,
lam_min: float = 0.1,
lam_max: float = 0.8,
interpolation: str = 'bilinear',
cutmix_minmax: Optional[List[float]] = None,
correct_lam: bool = True):
super().__init__(
alpha=alpha,
num_classes=num_classes,
cutmix_minmax=cutmix_minmax,
correct_lam=correct_lam)
alpha=alpha, cutmix_minmax=cutmix_minmax, correct_lam=correct_lam)
self.lam_min = lam_min
self.lam_max = lam_max
self.interpolation = interpolation

View File

@ -17,13 +17,16 @@ class RandomBatchAugment:
augmentations. If None, choose evenly. Defaults to None.
Example:
>>> import torch
>>> import torch.nn.functional as F
>>> from mmcls.models import RandomBatchAugment
>>> augments_cfg = [
... dict(type='CutMix', alpha=1., num_classes=10),
... dict(type='Mixup', alpha=1., num_classes=10)
... dict(type='CutMix', alpha=1.),
... dict(type='Mixup', alpha=1.)
... ]
>>> batch_augment = RandomBatchAugment(augments_cfg, probs=[0.5, 0.3])
>>> imgs = torch.randn(16, 3, 32, 32)
>>> label = torch.randint(0, 10, (16, ))
>>> imgs = torch.rand(16, 3, 32, 32)
>>> label = F.one_hot(torch.randint(0, 10, (16, )), num_classes=10)
>>> imgs, label = batch_augment(imgs, label)
.. note ::
@ -59,13 +62,13 @@ class RandomBatchAugment:
self.probs = probs
def __call__(self, inputs: torch.Tensor, data_samples: Union[list, None]):
def __call__(self, batch_input: torch.Tensor, batch_score: torch.Tensor):
"""Randomly apply batch augmentations to the batch inputs and batch
data samples."""
aug_index = np.random.choice(len(self.augments), p=self.probs)
aug = self.augments[aug_index]
if aug is not None:
return aug(inputs, data_samples)
return aug(batch_input, batch_score)
else:
return inputs, data_samples
return batch_input, batch_score.float()

View File

@ -8,6 +8,8 @@ import torch.nn.functional as F
from mmengine.model import BaseDataPreprocessor, stack_batch
from mmcls.registry import MODELS
from mmcls.structures import (batch_label_to_onehot, cat_batch_labels,
stack_batch_scores, tensor_split)
from .batch_augments import RandomBatchAugment
@ -42,6 +44,9 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
pad_value (Number): The padded pixel value. Defaults to 0.
to_rgb (bool): whether to convert image from BGR to RGB.
Defaults to False.
to_onehot (bool): Whether to generate one-hot format gt-labels and set
to data samples. Defaults to False.
num_classes (int, optional): The number of classes. Defaults to None.
batch_augments (dict, optional): The batch augmentations settings,
including "augments" and "probs". For more details, see
:class:`mmcls.models.RandomBatchAugment`.
@ -53,11 +58,15 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
pad_size_divisor: int = 1,
pad_value: Number = 0,
to_rgb: bool = False,
to_onehot: bool = False,
num_classes: Optional[int] = None,
batch_augments: Optional[dict] = None):
super().__init__()
self.pad_size_divisor = pad_size_divisor
self.pad_value = pad_value
self.to_rgb = to_rgb
self.to_onehot = to_onehot
self.num_classes = num_classes
if mean is not None:
assert std is not None, 'To enable the normalization in ' \
@ -73,6 +82,13 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
if batch_augments is not None:
self.batch_augments = RandomBatchAugment(**batch_augments)
if not self.to_onehot:
from mmengine.logging import MMLogger
MMLogger.get_current_instance().info(
'Because batch augmentations are enabled, the data '
'preprocessor automatically enables the `to_onehot` '
'option to generate one-hot format labels.')
self.to_onehot = True
else:
self.batch_augments = None
@ -87,8 +103,7 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
Returns:
dict: Data in the same format as the model input.
"""
data = self.cast_data(data)
inputs = data['inputs']
inputs = self.cast_data(data['inputs'])
if isinstance(inputs, torch.Tensor):
# The branch if use `default_collate` as the collate_fn in the
@ -135,12 +150,36 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
inputs = stack_batch(processed_inputs, self.pad_size_divisor,
self.pad_value)
# ----- Batch Aug ----
if training and self.batch_augments is not None:
data_samples = data['data_samples']
inputs, data_samples = self.batch_augments(inputs, data_samples)
data['data_samples'] = data_samples
data_samples = data.get('data_samples', None)
if data_samples is not None:
gt_labels = [sample.gt_label for sample in data_samples]
batch_label, label_indices = cat_batch_labels(
gt_labels, device=self.device)
data['inputs'] = inputs
batch_score = stack_batch_scores(gt_labels, device=self.device)
if batch_score is None and self.to_onehot:
assert batch_label is not None, \
'Cannot generate onehot format labels because no labels.'
num_classes = self.num_classes or data_samples[0].get(
'num_classes')
assert num_classes is not None, \
'Cannot generate one-hot format labels because not set ' \
'`num_classes` in `data_preprocessor`.'
batch_score = batch_label_to_onehot(batch_label, label_indices,
num_classes)
return data
# ----- Batch Augmentations ----
if training and self.batch_augments is not None:
inputs, batch_score = self.batch_augments(inputs, batch_score)
# ----- scatter labels and scores to data samples ---
if batch_label is not None:
for sample, label in zip(
data_samples, tensor_split(batch_label,
label_indices)):
sample.set_gt_label(label)
if batch_score is not None:
for sample, score in zip(data_samples, batch_score):
sample.set_gt_score(score)
return {'inputs': inputs, 'data_samples': data_samples}

View File

@ -1,4 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cls_data_sample import ClsDataSample
from .utils import (batch_label_to_onehot, cat_batch_labels,
stack_batch_scores, tensor_split)
__all__ = ['ClsDataSample']
__all__ = [
'ClsDataSample', 'batch_label_to_onehot', 'cat_batch_labels',
'stack_batch_scores', 'tensor_split'
]

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from multiprocessing.reduction import ForkingPickler
from numbers import Number
from typing import Sequence, Union
@ -9,20 +10,18 @@ from mmengine.structures import BaseDataElement, LabelData
from mmengine.utils import is_str
def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
num_classes: int = None) -> LabelData:
"""Convert label of various python types to :obj:`mmengine.LabelData`.
def format_label(
value: Union[torch.Tensor, np.ndarray, Sequence, int]) -> torch.Tensor:
"""Convert various python types to label-format tensor.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int`.
Args:
value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.
num_classes (int, optional): The number of classes. If not None, set
it to the metainfo. Defaults to None.
Returns:
:obj:`mmengine.LabelData`: The foramtted label data.
:obj:`torch.Tensor`: The foramtted label tensor.
"""
# Handle single number
@ -37,15 +36,36 @@ def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
value = torch.LongTensor([value])
elif not isinstance(value, torch.Tensor):
raise TypeError(f'Type {type(value)} is not an available label type.')
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'
metainfo = {}
if num_classes is not None:
metainfo['num_classes'] = num_classes
if value.max() >= num_classes:
raise ValueError(f'The label data ({value}) should not '
f'exceed num_classes ({num_classes}).')
label = LabelData(label=value, metainfo=metainfo)
return label
return value
def format_score(
value: Union[torch.Tensor, np.ndarray, Sequence, int]) -> torch.Tensor:
"""Convert various python types to score-format tensor.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`.
Args:
value (torch.Tensor | numpy.ndarray | Sequence): Score values.
Returns:
:obj:`torch.Tensor`: The foramtted score tensor.
"""
if isinstance(value, np.ndarray):
value = torch.from_numpy(value).float()
elif isinstance(value, Sequence) and not is_str(value):
value = torch.tensor(value).float()
elif not isinstance(value, torch.Tensor):
raise TypeError(f'Type {type(value)} is not an available label type.')
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'
return value
class ClsDataSample(BaseDataElement):
@ -115,64 +135,50 @@ class ClsDataSample(BaseDataElement):
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
) -> 'ClsDataSample':
"""Set label of ``gt_label``."""
label = format_label(value, self.get('num_classes'))
if 'gt_label' in self:
self.gt_label.label = label.label
else:
self.gt_label = label
label_data = getattr(self, '_gt_label', LabelData())
label_data.label = format_label(value)
self.gt_label = label_data
return self
def set_gt_score(self, value: torch.Tensor) -> 'ClsDataSample':
"""Set score of ``gt_label``."""
assert isinstance(value, torch.Tensor), \
f'The value should be a torch.Tensor but got {type(value)}.'
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'
if 'num_classes' in self:
assert value.size(0) == self.num_classes, \
f"The length of value ({value.size(0)}) doesn't "\
f'match the num_classes ({self.num_classes}).'
metainfo = {'num_classes': self.num_classes}
label_data = getattr(self, '_gt_label', LabelData())
label_data.score = format_score(value)
if hasattr(self, 'num_classes'):
assert len(label_data.score) == self.num_classes, \
f'The length of score {len(label_data.score)} should be '\
f'equal to the num_classes {self.num_classes}.'
else:
metainfo = {'num_classes': value.size(0)}
if 'gt_label' in self:
self.gt_label.score = value
else:
self.gt_label = LabelData(score=value, metainfo=metainfo)
self.set_field(
name='num_classes',
value=len(label_data.score),
field_type='metainfo')
self.gt_label = label_data
return self
def set_pred_label(
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
) -> 'ClsDataSample':
"""Set label of ``pred_label``."""
label = format_label(value, self.get('num_classes'))
if 'pred_label' in self:
self.pred_label.label = label.label
else:
self.pred_label = label
label_data = getattr(self, '_pred_label', LabelData())
label_data.label = format_label(value)
self.pred_label = label_data
return self
def set_pred_score(self, value: torch.Tensor) -> 'ClsDataSample':
"""Set score of ``pred_label``."""
assert isinstance(value, torch.Tensor), \
f'The value should be a torch.Tensor but got {type(value)}.'
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'
if 'num_classes' in self:
assert value.size(0) == self.num_classes, \
f"The length of value ({value.size(0)}) doesn't "\
f'match the num_classes ({self.num_classes}).'
metainfo = {'num_classes': self.num_classes}
label_data = getattr(self, '_pred_label', LabelData())
label_data.score = format_score(value)
if hasattr(self, 'num_classes'):
assert len(label_data.score) == self.num_classes, \
f'The length of score {len(label_data.score)} should be '\
f'equal to the num_classes {self.num_classes}.'
else:
metainfo = {'num_classes': value.size(0)}
if 'pred_label' in self:
self.pred_label.score = value
else:
self.pred_label = LabelData(score=value, metainfo=metainfo)
self.set_field(
name='num_classes',
value=len(label_data.score),
field_type='metainfo')
self.pred_label = label_data
return self
@property
@ -198,3 +204,32 @@ class ClsDataSample(BaseDataElement):
@pred_label.deleter
def pred_label(self):
del self._pred_label
def _reduce_cls_datasample(data_sample):
"""reduce ClsDataSample."""
attr_dict = data_sample.__dict__
convert_keys = []
for k, v in attr_dict.items():
if isinstance(v, LabelData):
attr_dict[k] = v.numpy()
convert_keys.append(k)
return _rebuild_cls_datasample, (attr_dict, convert_keys)
def _rebuild_cls_datasample(attr_dict, convert_keys):
"""rebuild ClsDataSample."""
data_sample = ClsDataSample()
for k in convert_keys:
attr_dict[k] = attr_dict[k].to_tensor()
data_sample.__dict__ = attr_dict
return data_sample
# Due to the multi-processing strategy of PyTorch, ClsDataSample may consume
# many file descriptors because it contains multiple LabelData with tensors.
# Here we overwrite the reduce function of ClsDataSample in ForkingPickler and
# convert these tensors to np.ndarray during pickling. It may influence the
# performance of dataloader, but slightly because these tensors in LabelData
# are very small.
ForkingPickler.register(ClsDataSample, _reduce_cls_datasample)

View File

@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
import torch.nn.functional as F
from mmengine.structures import LabelData
if hasattr(torch, 'tensor_split'):
tensor_split = torch.tensor_split
else:
# A simple implementation of `tensor_split`.
def tensor_split(input: torch.Tensor, indices: list):
outs = []
for start, end in zip([0] + indices, indices + [input.size(0)]):
outs.append(input[start:end])
return outs
def cat_batch_labels(elements: List[LabelData], device=None):
"""Concat the ``label`` of a batch of :obj:`LabelData` to a tensor.
Args:
elements (List[LabelData]): A batch of :obj`LabelData`.
device (torch.device, optional): The output device of the batch label.
Defaults to None.
Returns:
Tuple[torch.Tensor, List[int]]: The first item is the concated label
tensor, and the second item is the split indices of every sample.
"""
item = elements[0]
if 'label' not in item._data_fields:
return None, None
labels = []
splits = [0]
for element in elements:
labels.append(element.label)
splits.append(splits[-1] + element.label.size(0))
batch_label = torch.cat(labels)
if device is not None:
batch_label = batch_label.to(device=device)
return batch_label, splits[1:-1]
def batch_label_to_onehot(batch_label, split_indices, num_classes):
"""Convert a concated label tensor to onehot format.
Args:
batch_label (torch.Tensor): A concated label tensor from multiple
samples.
split_indices (List[int]): The split indices of every sample.
num_classes (int): The number of classes.
Returns:
torch.Tensor: The onehot format label tensor.
Examples:
>>> import torch
>>> from mmcls.structures import batch_label_to_onehot
>>> # Assume a concated label from 3 samples.
>>> # label 1: [0, 1], label 2: [0, 2, 4], label 3: [3, 1]
>>> batch_label = torch.tensor([0, 1, 0, 2, 4, 3, 1])
>>> split_indices = [2, 5]
>>> batch_label_to_onehot(batch_label, split_indices, num_classes=5)
tensor([[1, 1, 0, 0, 0],
[1, 0, 1, 0, 1],
[0, 1, 0, 1, 0]])
"""
sparse_onehot_list = F.one_hot(batch_label, num_classes)
onehot_list = [
sparse_onehot.sum(0)
for sparse_onehot in tensor_split(sparse_onehot_list, split_indices)
]
return torch.stack(onehot_list)
def stack_batch_scores(elements, device=None):
"""Stack the ``score`` of a batch of :obj:`LabelData` to a tensor.
Args:
elements (List[LabelData]): A batch of :obj`LabelData`.
device (torch.device, optional): The output device of the batch label.
Defaults to None.
Returns:
torch.Tensor: The stacked score tensor.
"""
item = elements[0]
if 'score' not in item._data_fields:
return None
batch_score = torch.stack([element.score for element in elements])
if device is not None:
batch_score = batch_score.to(device)
return batch_score

View File

@ -212,7 +212,9 @@ class TestSingleLabel(TestCase):
pred_no_score = copy.deepcopy(pred)
for sample in pred_no_score:
del sample['pred_label']['score']
metric = METRICS.build(dict(type='SingleLabelMetric', thrs=(0., 0.6)))
del sample['num_classes']
metric = METRICS.build(
dict(type='SingleLabelMetric', thrs=(0., 0.6), num_classes=3))
metric.process(None, pred_no_score)
res = metric.evaluate(6)
self.assertIsInstance(res, dict)
@ -221,14 +223,13 @@ class TestSingleLabel(TestCase):
self.assertAlmostEqual(res['single-label/recall'], 72.222, places=2)
self.assertAlmostEqual(res['single-label/f1-score'], 65.555, places=2)
pred_no_num_classes = copy.deepcopy(pred_no_score)
for sample in pred_no_num_classes:
del sample['pred_label']['num_classes']
with self.assertRaisesRegex(ValueError, 'neither `score` nor'):
metric.process(None, pred_no_num_classes)
metric = METRICS.build(dict(type='SingleLabelMetric', thrs=(0., 0.6)))
with self.assertRaisesRegex(AssertionError, 'must be specified'):
metric.process(None, pred_no_score)
# Test with empty items
metric = METRICS.build(dict(type='SingleLabelMetric', items=tuple()))
metric = METRICS.build(
dict(type='SingleLabelMetric', items=tuple(), num_classes=3))
metric.process(None, pred)
res = metric.evaluate(6)
self.assertIsInstance(res, dict)

View File

@ -47,7 +47,7 @@ class TestImageClassifier(TestCase):
# test set batch augmentation from train_cfg
cfg = {
**self.DEFAULT_ARGS, 'train_cfg':
dict(augments=dict(type='Mixup', alpha=1., num_classes=10))
dict(augments=dict(type='Mixup', alpha=1.))
}
model: ImageClassifier = MODELS.build(cfg)
self.assertIsNotNone(model.data_preprocessor.batch_augments)

View File

@ -7,7 +7,6 @@ import torch
from mmcls.models import Mixup, RandomBatchAugment
from mmcls.registry import BATCH_AUGMENTS
from mmcls.structures import ClsDataSample
class TestRandomBatchAugment(TestCase):
@ -54,7 +53,7 @@ class TestRandomBatchAugment(TestCase):
def test_call(self):
inputs = torch.rand(2, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
scores = torch.rand(2, 10)
augments = [
dict(type='Mixup', alpha=1.),
@ -64,18 +63,17 @@ class TestRandomBatchAugment(TestCase):
with patch('numpy.random', np.random.RandomState(0)):
batch_augments.augments[1] = MagicMock()
batch_augments(inputs, data_samples)
batch_augments.augments[1].assert_called_once_with(
inputs, data_samples)
batch_augments(inputs, scores)
batch_augments.augments[1].assert_called_once_with(inputs, scores)
augments = [
dict(type='Mixup', alpha=1.),
dict(type='CutMix', alpha=0.8),
]
batch_augments = RandomBatchAugment(augments, probs=[0.0, 0.0])
mixed_inputs, mixed_samples = batch_augments(inputs, data_samples)
mixed_inputs, mixed_samples = batch_augments(inputs, scores)
self.assertIs(mixed_inputs, inputs)
self.assertIs(mixed_samples, data_samples)
self.assertIs(mixed_samples, scores)
class TestMixup(TestCase):
@ -86,45 +84,21 @@ class TestMixup(TestCase):
cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'}
BATCH_AUGMENTS.build(cfg)
with self.assertRaises(AssertionError):
cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'}
BATCH_AUGMENTS.build(cfg)
def test_call(self):
inputs = torch.rand(2, 3, 224, 224)
data_samples = [
ClsDataSample(metainfo={
'num_classes': 10
}).set_gt_label(1) for _ in range(2)
]
scores = torch.rand(2, 10)
# test get num_classes from data_samples
mixup = BATCH_AUGMENTS.build(self.DEFAULT_ARGS)
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
mixed_inputs, mixed_scores = mixup(inputs, scores)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
with self.assertRaisesRegex(RuntimeError, 'Not specify'):
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
mixup(inputs, data_samples)
self.assertEqual(mixed_scores.shape, (2, 10))
# test binary classification
cfg = {**self.DEFAULT_ARGS, 'num_classes': 1}
mixup = BATCH_AUGMENTS.build(cfg)
data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)]
scores = torch.rand(2, 1)
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
mixed_inputs, mixed_scores = mixup(inputs, scores)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, ))
# test multi-label classification
cfg = {**self.DEFAULT_ARGS, 'num_classes': 5}
mixup = BATCH_AUGMENTS.build(cfg)
data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)]
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, ))
self.assertEqual(mixed_scores.shape, (2, 1))
class TestCutMix(TestCase):
@ -135,59 +109,36 @@ class TestCutMix(TestCase):
cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'}
BATCH_AUGMENTS.build(cfg)
with self.assertRaises(AssertionError):
cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'}
BATCH_AUGMENTS.build(cfg)
def test_call(self):
inputs = torch.rand(2, 3, 224, 224)
data_samples = [
ClsDataSample(metainfo={
'num_classes': 10
}).set_gt_label(1) for _ in range(2)
]
scores = torch.rand(2, 10)
# test with cutmix_minmax
cfg = {**self.DEFAULT_ARGS, 'cutmix_minmax': (0.1, 0.2)}
cutmix = BATCH_AUGMENTS.build(cfg)
mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
mixed_inputs, mixed_scores = cutmix(inputs, scores)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
self.assertEqual(mixed_scores.shape, (2, 10))
# test without correct_lam
cfg = {**self.DEFAULT_ARGS, 'correct_lam': False}
cutmix = BATCH_AUGMENTS.build(cfg)
mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
mixed_inputs, mixed_scores = cutmix(inputs, scores)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
self.assertEqual(mixed_scores.shape, (2, 10))
# test get num_classes from data_samples
# test default settings
cutmix = BATCH_AUGMENTS.build(self.DEFAULT_ARGS)
mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
mixed_inputs, mixed_scores = cutmix(inputs, scores)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
with self.assertRaisesRegex(RuntimeError, 'Not specify'):
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
cutmix(inputs, data_samples)
self.assertEqual(mixed_scores.shape, (2, 10))
# test binary classification
cfg = {**self.DEFAULT_ARGS, 'num_classes': 1}
cutmix = BATCH_AUGMENTS.build(cfg)
data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)]
scores = torch.rand(2, 1)
mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
mixed_inputs, mixed_scores = cutmix(inputs, scores)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, ))
# test multi-label classification
cfg = {**self.DEFAULT_ARGS, 'num_classes': 5}
cutmix = BATCH_AUGMENTS.build(cfg)
data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)]
mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, ))
self.assertEqual(mixed_scores.shape, (2, 1))
class TestResizeMix(TestCase):
@ -198,42 +149,18 @@ class TestResizeMix(TestCase):
cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'}
BATCH_AUGMENTS.build(cfg)
with self.assertRaises(AssertionError):
cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'}
BATCH_AUGMENTS.build(cfg)
def test_call(self):
inputs = torch.rand(2, 3, 224, 224)
data_samples = [
ClsDataSample(metainfo={
'num_classes': 10
}).set_gt_label(1) for _ in range(2)
]
scores = torch.rand(2, 10)
# test get num_classes from data_samples
mixup = BATCH_AUGMENTS.build(self.DEFAULT_ARGS)
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
mixed_inputs, mixed_scores = mixup(inputs, scores)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
with self.assertRaisesRegex(RuntimeError, 'Not specify'):
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
mixup(inputs, data_samples)
self.assertEqual(mixed_scores.shape, (2, 10))
# test binary classification
cfg = {**self.DEFAULT_ARGS, 'num_classes': 1}
mixup = BATCH_AUGMENTS.build(cfg)
data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)]
scores = torch.rand(2, 1)
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
mixed_inputs, mixed_scores = mixup(inputs, scores)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, ))
# test multi-label classification
cfg = {**self.DEFAULT_ARGS, 'num_classes': 5}
mixup = BATCH_AUGMENTS.build(cfg)
data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)]
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, ))
self.assertEqual(mixed_scores.shape, (2, 1))

View File

@ -71,7 +71,7 @@ class TestClsDataPreprocessor(TestCase):
inputs = processed_data['inputs']
self.assertTrue((inputs >= -1).all())
self.assertTrue((inputs <= 1).all())
self.assertNotIn('data_samples', processed_data)
self.assertIsNone(processed_data['data_samples'])
data = {'inputs': torch.randint(0, 256, (1, 3, 224, 224))}
inputs = processor(data)['inputs']
@ -81,9 +81,10 @@ class TestClsDataPreprocessor(TestCase):
def test_batch_augmentation(self):
cfg = dict(
type='ClsDataPreprocessor',
num_classes=10,
batch_augments=dict(augments=[
dict(type='Mixup', alpha=0.8, num_classes=10),
dict(type='CutMix', alpha=1., num_classes=10)
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.)
]))
processor: ClsDataPreprocessor = MODELS.build(cfg)
self.assertIsInstance(processor.batch_augments, RandomBatchAugment)
@ -101,4 +102,4 @@ class TestClsDataPreprocessor(TestCase):
data = {'inputs': [torch.randint(0, 256, (3, 224, 224))]}
processed_data = processor(data, training=True)
self.assertIn('inputs', processed_data)
self.assertNotIn('data_samples', processed_data)
self.assertIsNone(processed_data['data_samples'])

View File

@ -60,20 +60,6 @@ class TestClsDataSample(TestCase):
with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
method('hi')
# Test set num_classes
data_sample = ClsDataSample(metainfo={'num_classes': 10})
method = getattr(data_sample, 'set_' + key)
method(5)
self.assertIn(key, data_sample)
label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData)
self.assertIn('num_classes', label)
self.assertEqual(label.num_classes, 10)
# Test unavailable label
with self.assertRaisesRegex(ValueError, r'data .*[15].* should '):
method(15)
def test_set_gt_label(self):
self._test_set_label('gt_label')
@ -97,62 +83,42 @@ class TestClsDataSample(TestCase):
self.assertNotIn('pred_label', data_sample)
def test_set_gt_score(self):
data_sample = ClsDataSample(metainfo={'num_classes': 5})
data_sample = ClsDataSample()
data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
self.assertIn('score', data_sample.gt_label)
torch.testing.assert_allclose(data_sample.gt_label.score,
[0.1, 0.1, 0.6, 0.1, 0.1])
self.assertEqual(data_sample.gt_label.num_classes, 5)
# Test set again
data_sample.set_gt_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
torch.testing.assert_allclose(data_sample.gt_label.score,
[0.2, 0.1, 0.5, 0.1, 0.1])
# Test invalid type
with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'):
data_sample.set_gt_score([1, 2, 3])
# Test invalid length
with self.assertRaisesRegex(AssertionError, 'should be equal to'):
data_sample.set_gt_score([1, 2])
# Test invalid dims
with self.assertRaisesRegex(AssertionError, 'but got 2'):
data_sample.set_gt_score(torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))
# Test invalid num_classes
with self.assertRaisesRegex(AssertionError, r'length of value \(4\)'):
data_sample.set_gt_score(torch.tensor([0.1, 0.2, 0.3, 0.4]))
# Test auto inter num_classes
data_sample = ClsDataSample()
data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
self.assertEqual(data_sample.gt_label.num_classes, 5)
def test_set_pred_score(self):
data_sample = ClsDataSample(metainfo={'num_classes': 5})
data_sample = ClsDataSample()
data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
self.assertIn('score', data_sample.pred_label)
torch.testing.assert_allclose(data_sample.pred_label.score,
[0.1, 0.1, 0.6, 0.1, 0.1])
self.assertEqual(data_sample.pred_label.num_classes, 5)
# Test set again
data_sample.set_pred_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
torch.testing.assert_allclose(data_sample.pred_label.score,
[0.2, 0.1, 0.5, 0.1, 0.1])
# Test invalid type
with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'):
data_sample.set_pred_score([1, 2, 3])
# Test invalid length
with self.assertRaisesRegex(AssertionError, 'should be equal to'):
data_sample.set_gt_score([1, 2])
# Test invalid dims
with self.assertRaisesRegex(AssertionError, 'but got 2'):
data_sample.set_pred_score(
torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))
# Test invalid num_classes
with self.assertRaisesRegex(AssertionError, r'length of value \(4\)'):
data_sample.set_pred_score(torch.tensor([0.1, 0.2, 0.3, 0.4]))
# Test auto inter num_classes
data_sample = ClsDataSample()
data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
self.assertEqual(data_sample.pred_label.num_classes, 5)

View File

@ -0,0 +1,93 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmengine.structures import LabelData
from mmcls.structures import (batch_label_to_onehot, cat_batch_labels,
stack_batch_scores, tensor_split)
class TestStructureUtils(TestCase):
def test_tensor_split(self):
tensor = torch.tensor([0, 1, 2, 3, 4, 5, 6])
split_indices = [0, 2, 6, 6]
outs = tensor_split(tensor, split_indices)
self.assertEqual(len(outs), len(split_indices) + 1)
self.assertEqual(outs[0].size(0), 0)
self.assertEqual(outs[1].size(0), 2)
self.assertEqual(outs[2].size(0), 4)
self.assertEqual(outs[3].size(0), 0)
self.assertEqual(outs[4].size(0), 1)
tensor = torch.tensor([])
split_indices = [0, 0, 0, 0]
outs = tensor_split(tensor, split_indices)
self.assertEqual(len(outs), len(split_indices) + 1)
def test_cat_batch_labels(self):
labels = [
LabelData(label=torch.tensor([1])),
LabelData(label=torch.tensor([3, 2])),
LabelData(label=torch.tensor([0, 1, 4])),
LabelData(label=torch.tensor([], dtype=torch.int64)),
LabelData(label=torch.tensor([], dtype=torch.int64)),
]
batch_label, split_indices = cat_batch_labels(labels)
self.assertEqual(split_indices, [1, 3, 6, 6])
self.assertEqual(len(batch_label), 6)
labels = tensor_split(batch_label, split_indices)
self.assertEqual(labels[0].tolist(), [1])
self.assertEqual(labels[1].tolist(), [3, 2])
self.assertEqual(labels[2].tolist(), [0, 1, 4])
self.assertEqual(labels[3].tolist(), [])
self.assertEqual(labels[4].tolist(), [])
labels = [
LabelData(score=torch.tensor([0, 1, 0, 0, 1])),
LabelData(score=torch.tensor([0, 0, 1, 0, 0])),
LabelData(score=torch.tensor([1, 0, 0, 1, 0])),
]
batch_label, split_indices = cat_batch_labels(labels)
self.assertIsNone(batch_label)
self.assertIsNone(split_indices)
def test_stack_batch_scores(self):
labels = [
LabelData(score=torch.tensor([0, 1, 0, 0, 1])),
LabelData(score=torch.tensor([0, 0, 1, 0, 0])),
LabelData(score=torch.tensor([1, 0, 0, 1, 0])),
]
batch_score = stack_batch_scores(labels)
self.assertEqual(batch_score.shape, (3, 5))
labels = [
LabelData(label=torch.tensor([1])),
LabelData(label=torch.tensor([3, 2])),
LabelData(label=torch.tensor([0, 1, 4])),
LabelData(label=torch.tensor([], dtype=torch.int64)),
LabelData(label=torch.tensor([], dtype=torch.int64)),
]
batch_score = stack_batch_scores(labels)
self.assertIsNone(batch_score)
def test_batch_label_to_onehot(self):
labels = [
LabelData(label=torch.tensor([1])),
LabelData(label=torch.tensor([3, 2])),
LabelData(label=torch.tensor([0, 1, 4])),
LabelData(label=torch.tensor([], dtype=torch.int64)),
LabelData(label=torch.tensor([], dtype=torch.int64)),
]
batch_label, split_indices = cat_batch_labels(labels)
batch_score = batch_label_to_onehot(
batch_label, split_indices, num_classes=5)
self.assertEqual(batch_score[0].tolist(), [0, 1, 0, 0, 0])
self.assertEqual(batch_score[1].tolist(), [0, 0, 1, 1, 0])
self.assertEqual(batch_score[2].tolist(), [1, 1, 0, 0, 1])
self.assertEqual(batch_score[3].tolist(), [0, 0, 0, 0, 0])
self.assertEqual(batch_score[4].tolist(), [0, 0, 0, 0, 0])