[Improve] Speed up data preprocessor. (#1064)
* [Improve] Speed up data preprocessor.
* Add ClsDataSample serialization override functions.
* Add unit tests.
* Modify configs to fit new mixup args.
* Fix `num_classes` of the ImageNet-21k config.
* Update docs.
parent 06c919efc2
commit 29f066f7fb
@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'CIFAR100'
 data_preprocessor = dict(
+    num_classes=100,
     # RGB format normalization parameters
     mean=[129.304, 124.070, 112.434],
     std=[68.170, 65.392, 70.418],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'CIFAR10'
 data_preprocessor = dict(
+    num_classes=10,
     # RGB format normalization parameters
     mean=[125.307, 122.961, 113.8575],
     std=[51.5865, 50.847, 51.255],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'CUB'
 data_preprocessor = dict(
+    num_classes=200,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'CUB'
 data_preprocessor = dict(
+    num_classes=200,
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],
     # convert image from BGR to RGB

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet21k'
 data_preprocessor = dict(
+    num_classes=21842,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet'
 data_preprocessor = dict(
+    num_classes=1000,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet'
 data_preprocessor = dict(
+    num_classes=1000,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet'
 data_preprocessor = dict(
+    num_classes=1000,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet'
 data_preprocessor = dict(
+    num_classes=1000,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet'
 data_preprocessor = dict(
+    num_classes=1000,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet'
 data_preprocessor = dict(
+    num_classes=1000,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet'
 data_preprocessor = dict(
+    num_classes=1000,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet'
 data_preprocessor = dict(
+    num_classes=1000,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet'
 data_preprocessor = dict(
+    num_classes=1000,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet'
 data_preprocessor = dict(
+    num_classes=1000,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet'
 data_preprocessor = dict(
+    num_classes=1000,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -3,6 +3,7 @@ dataset_type = 'ImageNet'

 # Google research usually use the below normalization setting.
 data_preprocessor = dict(
+    num_classes=1000,
     mean=[127.5, 127.5, 127.5],
     std=[127.5, 127.5, 127.5],
     # convert image from BGR to RGB

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet'
 data_preprocessor = dict(
+    num_classes=1000,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet'
 data_preprocessor = dict(
+    num_classes=1000,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet'
 data_preprocessor = dict(
+    num_classes=1000,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet'
 data_preprocessor = dict(
+    num_classes=1000,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,6 +1,7 @@
 # dataset settings
 dataset_type = 'ImageNet'
 data_preprocessor = dict(
+    num_classes=1000,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],

@@ -1,11 +1,14 @@
 # dataset settings
 dataset_type = 'VOC'
 data_preprocessor = dict(
+    num_classes=20,
     # RGB format normalization parameters
     mean=[123.675, 116.28, 103.53],
     std=[58.395, 57.12, 57.375],
     # convert image from BGR to RGB
     to_rgb=True,
+    # generate onehot-format labels for multi-label classification.
+    to_onehot=True,
 )

 train_pipeline = [

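Illustrative note (not part of the commit): after this change, the label bookkeeping lives in `data_preprocessor` instead of in each batch augment. A minimal sketch of the resulting config style, reusing only values that appear in the hunks above:

# Hypothetical minimal config in the new style: the preprocessor owns
# `num_classes` and `to_onehot`, and the augments only need `alpha`.
data_preprocessor = dict(
    num_classes=20,
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    to_rgb=True,
    to_onehot=True,
)
model = dict(
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0),
    ]),
)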
@@ -17,7 +17,7 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
     ],
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=1000),
-        dict(type='CutMix', alpha=1.0, num_classes=1000)
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
     ]),
 )

@@ -17,7 +17,7 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
     ],
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=1000),
-        dict(type='CutMix', alpha=1.0, num_classes=1000)
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
     ]),
 )

@@ -21,7 +21,7 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
     ],
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=1000),
-        dict(type='CutMix', alpha=1.0, num_classes=1000)
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
     ]),
 )

@@ -17,7 +17,7 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
     ],
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=1000),
-        dict(type='CutMix', alpha=1.0, num_classes=1000)
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
     ]),
 )

@@ -18,5 +18,5 @@ model = dict(
             num_classes=1000),
         topk=(1, 5),
     ),
-    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
+    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
 )

@@ -20,5 +20,6 @@ model = dict(
             reduction='mean',
             loss_weight=1.0),
         topk=(1, 5),
-        cal_acc=False))
-train_cfg = dict(mixup=dict(alpha=0.2, num_classes=1000))
+        cal_acc=False),
+    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
+)

@@ -20,5 +20,6 @@ model = dict(
             reduction='mean',
             loss_weight=1.0),
         topk=(1, 5),
-        cal_acc=False))
-train_cfg = dict(mixup=dict(alpha=0.2, num_classes=1000))
+        cal_acc=False),
+    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
+)

@@ -20,5 +20,6 @@ model = dict(
             reduction='mean',
             loss_weight=1.0),
         topk=(1, 5),
-        cal_acc=False))
-train_cfg = dict(mixup=dict(alpha=0.2, num_classes=1000))
+        cal_acc=False),
+    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
+)

@@ -19,5 +19,6 @@ model = dict(
             reduction='mean',
             loss_weight=1.0),
         topk=(1, 5),
-        cal_acc=False))
-train_cfg = dict(mixup=dict(alpha=0.2, num_classes=1000))
+        cal_acc=False),
+    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
+)

@@ -13,5 +13,5 @@ model = dict(
         num_classes=10,
         in_channels=2048,
         loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)),
-    train_cfg=dict(augments=dict(type='Mixup', alpha=1., num_classes=10)),
+    train_cfg=dict(augments=dict(type='Mixup', alpha=1.)),
 )

@@ -13,5 +13,5 @@ model = dict(
         num_classes=1000,
         in_channels=2048,
         loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)),
-    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
+    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
 )

@@ -17,7 +17,7 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
     ],
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=1000),
-        dict(type='CutMix', alpha=1.0, num_classes=1000)
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
     ]),
 )

@@ -18,7 +18,7 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
     ],
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=1000),
-        dict(type='CutMix', alpha=1.0, num_classes=1000)
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
     ]),
 )

@@ -17,7 +17,7 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
     ],
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=1000),
-        dict(type='CutMix', alpha=1.0, num_classes=1000)
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
     ]),
 )

@@ -36,7 +36,7 @@ model = dict(
         topk=(1, 5),
         init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=num_classes),
-        dict(type='CutMix', alpha=1.0, num_classes=num_classes),
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0),
     ]),
 )

@@ -36,7 +36,7 @@ model = dict(
         topk=(1, 5),
         init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=num_classes),
-        dict(type='CutMix', alpha=1.0, num_classes=num_classes),
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0),
     ]),
 )

@@ -36,7 +36,7 @@ model = dict(
         topk=(1, 5),
         init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=num_classes),
-        dict(type='CutMix', alpha=1.0, num_classes=num_classes),
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0),
     ]),
 )

@@ -25,7 +25,7 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
     ],
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=1000),
-        dict(type='CutMix', alpha=1.0, num_classes=1000)
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
     ]),
 )

@@ -25,7 +25,7 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
     ],
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=1000),
-        dict(type='CutMix', alpha=1.0, num_classes=1000)
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
     ]),
 )

@@ -16,7 +16,7 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
     ],
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=1000),
-        dict(type='CutMix', alpha=1.0, num_classes=1000)
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
     ]),
 )

@@ -16,7 +16,7 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
     ],
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=1000),
-        dict(type='CutMix', alpha=1.0, num_classes=1000)
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
     ]),
 )

@@ -27,8 +27,8 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
     ],
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=1000),
-        dict(type='CutMix', alpha=1.0, num_classes=1000)
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
     ]),
 )

@@ -10,7 +10,7 @@ model = dict(

 # dataset settings
 dataset_type = 'MNIST'
-data_preprocessor = dict(mean=[33.46], std=[78.87])
+data_preprocessor = dict(mean=[33.46], std=[78.87], num_classes=10)

 pipeline = [dict(type='Resize', scale=32), dict(type='PackClsInputs')]

@@ -298,44 +298,6 @@ Models:
           Top 5 Accuracy: 93.80
     Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a3-100e_in1k_20211228-3493673c.pth
     Config: configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py
-  - Name: wide-resnet50_3rdparty_8xb32_in1k
-    Metadata:
-      FLOPs: 11440000000 # 11.44G
-      Parameters: 68880000 # 68.88M
-      Training Techniques:
-        - SGD with Momentum
-        - Weight Decay
-    In Collection: ResNet
-    Results:
-      - Task: Image Classification
-        Dataset: ImageNet-1k
-        Metrics:
-          Top 1 Accuracy: 78.48
-          Top 5 Accuracy: 94.08
-    Weights: https://download.openmmlab.com/mmclassification/v0/resnet/wide-resnet50_3rdparty_8xb32_in1k_20220304-66678344.pth
-    Config: configs/resnet/wide-resnet50_8xb32_in1k.py
-    Converted From:
-      Weights: https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth
-      Code: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
-  - Name: wide-resnet101_3rdparty_8xb32_in1k
-    Metadata:
-      FLOPs: 22810000000 # 22.81G
-      Parameters: 126890000 # 126.89M
-      Training Techniques:
-        - SGD with Momentum
-        - Weight Decay
-    In Collection: ResNet
-    Results:
-      - Task: Image Classification
-        Dataset: ImageNet-1k
-        Metrics:
-          Top 1 Accuracy: 78.84
-          Top 5 Accuracy: 94.28
-    Weights: https://download.openmmlab.com/mmclassification/v0/resnet/wide-resnet101_3rdparty_8xb32_in1k_20220304-8d5f9d61.pth
-    Config: configs/resnet/wide-resnet101_8xb32_in1k.py
-    Converted From:
-      Weights: https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth
-      Code: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
   - Name: resnetv1c50_8xb32_in1k
     Metadata:
       FLOPs: 4360000000

@@ -19,8 +19,8 @@ model = dict(
             use_sigmoid=True,
         )),
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.2, num_classes=1000),
-        dict(type='CutMix', alpha=1.0, num_classes=1000)
+        dict(type='Mixup', alpha=0.2),
+        dict(type='CutMix', alpha=1.0)
     ]),
 )

@@ -13,8 +13,8 @@ model = dict(
     ),
     head=dict(loss=dict(use_sigmoid=True)),
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.1, num_classes=1000),
-        dict(type='CutMix', alpha=1.0, num_classes=1000)
+        dict(type='Mixup', alpha=0.1),
+        dict(type='CutMix', alpha=1.0)
     ]))

 # dataset settings

@@ -10,8 +10,8 @@ model = dict(
     backbone=dict(norm_cfg=dict(type='SyncBN', requires_grad=True)),
     head=dict(loss=dict(use_sigmoid=True)),
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.1, num_classes=1000),
-        dict(type='CutMix', alpha=1.0, num_classes=1000)
+        dict(type='Mixup', alpha=0.1),
+        dict(type='CutMix', alpha=1.0)
     ]),
 )

@@ -26,8 +26,8 @@ model = dict(
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
     ],
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=1000),
-        dict(type='CutMix', alpha=1.0, num_classes=1000)
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
     ]))

 # dataset settings

@@ -8,7 +8,7 @@ _base_ = [
 # model setting
 model = dict(
     head=dict(hidden_dim=3072),
-    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
+    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
 )

 # schedule setting

@@ -8,7 +8,7 @@ _base_ = [
 # model setting
 model = dict(
     head=dict(hidden_dim=3072),
-    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
+    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
 )

 # schedule setting

@@ -8,7 +8,7 @@ _base_ = [
 # model setting
 model = dict(
     head=dict(hidden_dim=3072),
-    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
+    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
 )

 # schedule setting

@@ -8,7 +8,7 @@ _base_ = [
 # model setting
 model = dict(
     head=dict(hidden_dim=3072),
-    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2, num_classes=1000)),
+    train_cfg=dict(augments=dict(type='Mixup', alpha=0.2)),
 )

 # schedule setting

@@ -233,8 +233,8 @@ These augmentations are usually only used during training, therefore, we use the
     neck=...,
     head=...,
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=num_classes),
-        dict(type='CutMix', alpha=1.0, num_classes=num_classes),
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0),
     ]),
 )

@@ -247,8 +247,8 @@ You can also speicy the probabilities of every batch augmentation by the ``probs
     neck=...,
     head=...,
     train_cfg=dict(augments=[
-        dict(type='Mixup', alpha=0.8, num_classes=num_classes),
-        dict(type='CutMix', alpha=1.0, num_classes=num_classes),
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0),
     ], probs=[0.3, 0.7])
 )

@@ -289,7 +289,7 @@ _base_ = './resnet50_8xb32_in1k.py'

 # using CutMix batch augment
 model = dict(
     train_cfg=dict(
-        augments=dict(type='CutMix', alpha=1.0, num_classes=1000, prob=1.0)
+        augments=dict(type='CutMix', alpha=1.0)
     )
 )

@@ -71,7 +71,7 @@ model = dict(
     backbone=dict(
         type='ResNet',  # Backbone type
         # All fields except `type` come from the `__init__` method of the `ResNet` class
-        # You consult https://mmclassification.readthedocs.io/zh_CN/1.x/api/generated/mmcls.models.ResNet.html
+        # You can refer to https://mmclassification.readthedocs.io/zh_CN/1.x/api/generated/mmcls.models.ResNet.html
         depth=50,
         num_stages=4,  # Number of backbone stages; the feature maps they produce feed the following head.
         out_indices=(3, ),  # Indices of the output feature maps.

@@ -81,7 +81,7 @@ model = dict(
     head=dict(
         type='LinearClsHead',  # Classification head type
         # All fields except `type` come from the `__init__` method of the `LinearClsHead` class
-        # You consult https://mmclassification.readthedocs.io/zh_CN/1.x/api/generated/mmcls.models.LinearClsHead.html
+        # You can refer to https://mmclassification.readthedocs.io/zh_CN/1.x/api/generated/mmcls.models.LinearClsHead.html
         num_classes=1000,
         in_channels=2048,
         loss=dict(type='CrossEntropyLoss', loss_weight=1.0),  # Loss function config

@@ -278,7 +278,7 @@ _base_ = './resnet50_8xb32_in1k.py'

 # On top of the previous config, the model uses the CutMix batch augmentation
 model = dict(
     train_cfg=dict(
-        augments=dict(type='CutMix', alpha=1.0, num_classes=1000, prob=1.0)
+        augments=dict(type='CutMix', alpha=1.0)
     )
 )

@@ -296,6 +296,7 @@ class SingleLabelMetric(BaseMetric):
             true positives, false negatives and false positives.

             Defaults to "macro".
+        num_classes (Optional, int): The number of classes. Defaults to None.
         collect_device (str): Device name used for collecting results from
             different ranks during distributed training. Must be 'cpu' or
             'gpu'. Defaults to 'cpu'.

@@ -349,6 +350,7 @@ class SingleLabelMetric(BaseMetric):
                  thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
                  items: Sequence[str] = ('precision', 'recall', 'f1-score'),
                  average: Optional[str] = 'macro',
+                 num_classes: Optional[int] = None,
                  collect_device: str = 'cpu',
                  prefix: Optional[str] = None) -> None:
         super().__init__(collect_device=collect_device, prefix=prefix)

@@ -365,6 +367,7 @@ class SingleLabelMetric(BaseMetric):
            '"support".'
         self.items = tuple(items)
         self.average = average
+        self.num_classes = num_classes

     def process(self, data_batch, data_samples: Sequence[dict]):
         """Process one batch of data samples.

@@ -383,12 +386,14 @@ class SingleLabelMetric(BaseMetric):
             gt_label = data_sample['gt_label']
             if 'score' in pred_label:
                 result['pred_score'] = pred_label['score'].cpu()
-            elif ('num_classes' in pred_label):
-                result['pred_label'] = pred_label['label'].cpu()
-                result['num_classes'] = pred_label['num_classes']
             else:
-                raise ValueError('The `pred_label` in data_samples do not '
-                                 'have neither `score` nor `num_classes`.')
+                num_classes = self.num_classes or data_sample.get(
+                    'num_classes')
+                assert num_classes is not None, \
+                    'The `num_classes` must be specified if `pred_label` has '\
+                    'only `label`.'
+                result['pred_label'] = pred_label['label'].cpu()
+                result['num_classes'] = num_classes
             result['gt_label'] = gt_label['label'].cpu()
             # Save the result to `self.results`.
             self.results.append(result)

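Illustrative note (not part of the commit; it reuses the registry pattern from the unit tests further below): with the new fallback, predictions that carry only hard labels need `num_classes` on the metric itself.

import torch
from mmcls.registry import METRICS

# Hypothetical minimal usage: `pred_label` has no `score`, so the metric's
# own `num_classes` supplies the class count.
metric = METRICS.build(dict(type='SingleLabelMetric', num_classes=3))
data_samples = [{
    'pred_label': {'label': torch.tensor([1])},
    'gt_label': {'label': torch.tensor([1])},
}]
metric.process(None, data_samples)
print(metric.evaluate(1))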
@@ -24,9 +24,6 @@ class CutMix(Mixup):
         alpha (float): Parameters for Beta distribution to generate the
             mixing ratio. It should be a positive number. More details
             can be found in :class:`Mixup`.
-        num_classes (int, optional): The number of classes. If not specified,
-            will try to get it from data samples during training.
-            Defaults to None.
         cutmix_minmax (List[float], optional): The min/max area ratio of the
             patches. If not None, the bounding-box of patches is uniform
             sampled within this ratio range, and the ``alpha`` will be ignored.

@@ -49,10 +46,9 @@ class CutMix(Mixup):

     def __init__(self,
                  alpha: float,
-                 num_classes: Optional[int] = None,
                  cutmix_minmax: Optional[List[float]] = None,
                  correct_lam: bool = True):
-        super().__init__(alpha=alpha, num_classes=num_classes)
+        super().__init__(alpha=alpha)

         self.cutmix_minmax = cutmix_minmax
         self.correct_lam = correct_lam

@@ -1,12 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import List, Optional, Tuple
+from typing import Tuple

 import numpy as np
 import torch
-from mmengine.structures import LabelData

 from mmcls.registry import BATCH_AUGMENTS
-from mmcls.structures import ClsDataSample


 @BATCH_AUGMENTS.register_module()

@@ -22,9 +20,6 @@ class Mixup:
         alpha (float): Parameters for Beta distribution to generate the
             mixing ratio. It should be a positive number. More details
             are in the note.
-        num_classes (int, optional): The number of classes. If not specified,
-            will try to get it from data samples during training.
-            Defaults to None.

     Note:
         The :math:`\alpha` (``alpha``) determines a random distribution

@@ -33,12 +28,10 @@ class Mixup:
         distribution.
     """

-    def __init__(self, alpha: float, num_classes: Optional[int] = None):
+    def __init__(self, alpha: float):
         assert isinstance(alpha, float) and alpha > 0
-        assert isinstance(num_classes, int) or num_classes is None

         self.alpha = alpha
-        self.num_classes = num_classes

     def mix(self, batch_inputs: torch.Tensor,
             batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

@@ -62,28 +55,11 @@ class Mixup:

         return mixed_inputs, mixed_scores

-    def __call__(self, batch_inputs: torch.Tensor,
-                 data_samples: List[ClsDataSample]):
+    def __call__(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor):
         """Mix the batch inputs and batch data samples."""
-        assert data_samples is not None, f'{self.__class__.__name__} ' \
-            'requires data_samples. If you only want to inference, please ' \
-            'disable it from preprocessing.'
+        assert batch_score.ndim == 2, \
+            'The input `batch_score` should be a one-hot format tensor, '\
+            'which shape should be ``(N, num_classes)``.'

-        if self.num_classes is None and 'num_classes' not in data_samples[0]:
-            raise RuntimeError(
-                'Not specify the `num_classes` and cannot get it from '
-                'data samples. Please specify `num_classes` in the '
-                f'{self.__class__.__name__}.')
-        num_classes = self.num_classes or data_samples[0].get('num_classes')
-
-        batch_score = torch.stack([
-            LabelData.label_to_onehot(sample.gt_label.label, num_classes)
-            for sample in data_samples
-        ])
-
-        mixed_inputs, mixed_score = self.mix(batch_inputs, batch_score)
-
-        for i, sample in enumerate(data_samples):
-            sample.set_gt_score(mixed_score[i])
-
-        return mixed_inputs, data_samples
+        mixed_inputs, mixed_score = self.mix(batch_inputs, batch_score.float())
+        return mixed_inputs, mixed_score

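Illustrative note (not part of the commit): the new `__call__` contract is pure tensor-in, tensor-out. The sketch below is a hypothetical stand-in for `Mixup.mix`; it shows the mixing rule the class relies on, with the ratio drawn from Beta(alpha, alpha) and each sample blended with a shuffled partner.

import numpy as np
import torch
import torch.nn.functional as F

def mixup_sketch(inputs, scores, alpha=0.8):
    # lam ~ Beta(alpha, alpha); one ratio for the whole batch.
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(inputs.size(0))
    mixed_inputs = lam * inputs + (1 - lam) * inputs[index]
    mixed_scores = lam * scores + (1 - lam) * scores[index]
    return mixed_inputs, mixed_scores

inputs = torch.rand(4, 3, 32, 32)
scores = F.one_hot(torch.tensor([0, 1, 2, 3]), num_classes=10).float()
mixed_inputs, mixed_scores = mixup_sketch(inputs, scores)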
@@ -21,9 +21,6 @@ class ResizeMix(CutMix):
         alpha (float): Parameters for Beta distribution to generate the
             mixing ratio. It should be a positive number. More details
             can be found in :class:`Mixup`.
-        num_classes (int, optional): The number of classes. If not specified,
-            will try to get it from data samples during training.
-            Defaults to None.
         lam_min(float): The minimum value of lam. Defaults to 0.1.
         lam_max(float): The maximum value of lam. Defaults to 0.8.
         interpolation (str): algorithm used for upsampling:

@@ -57,17 +54,13 @@ class ResizeMix(CutMix):

     def __init__(self,
                  alpha: float,
-                 num_classes: Optional[int] = None,
                  lam_min: float = 0.1,
                  lam_max: float = 0.8,
                  interpolation: str = 'bilinear',
                  cutmix_minmax: Optional[List[float]] = None,
                  correct_lam: bool = True):
         super().__init__(
-            alpha=alpha,
-            num_classes=num_classes,
-            cutmix_minmax=cutmix_minmax,
-            correct_lam=correct_lam)
+            alpha=alpha, cutmix_minmax=cutmix_minmax, correct_lam=correct_lam)
         self.lam_min = lam_min
         self.lam_max = lam_max
         self.interpolation = interpolation

@@ -17,13 +17,16 @@ class RandomBatchAugment:
         augmentations. If None, choose evenly. Defaults to None.

     Example:
         >>> import torch
+        >>> import torch.nn.functional as F
         >>> from mmcls.models import RandomBatchAugment
         >>> augments_cfg = [
-        ...     dict(type='CutMix', alpha=1., num_classes=10),
-        ...     dict(type='Mixup', alpha=1., num_classes=10)
+        ...     dict(type='CutMix', alpha=1.),
+        ...     dict(type='Mixup', alpha=1.)
         ... ]
         >>> batch_augment = RandomBatchAugment(augments_cfg, probs=[0.5, 0.3])
-        >>> imgs = torch.randn(16, 3, 32, 32)
-        >>> label = torch.randint(0, 10, (16, ))
+        >>> imgs = torch.rand(16, 3, 32, 32)
+        >>> label = F.one_hot(torch.randint(0, 10, (16, )), num_classes=10)
         >>> imgs, label = batch_augment(imgs, label)

     .. note ::

@@ -59,13 +62,13 @@ class RandomBatchAugment:
         self.probs = probs

-    def __call__(self, inputs: torch.Tensor, data_samples: Union[list, None]):
+    def __call__(self, batch_input: torch.Tensor, batch_score: torch.Tensor):
         """Randomly apply batch augmentations to the batch inputs and batch
         data samples."""
         aug_index = np.random.choice(len(self.augments), p=self.probs)
         aug = self.augments[aug_index]

         if aug is not None:
-            return aug(inputs, data_samples)
+            return aug(batch_input, batch_score)
         else:
-            return inputs, data_samples
+            return batch_input, batch_score.float()

@@ -8,6 +8,8 @@ import torch.nn.functional as F
 from mmengine.model import BaseDataPreprocessor, stack_batch

 from mmcls.registry import MODELS
+from mmcls.structures import (batch_label_to_onehot, cat_batch_labels,
+                              stack_batch_scores, tensor_split)
 from .batch_augments import RandomBatchAugment

@@ -42,6 +44,9 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
         pad_value (Number): The padded pixel value. Defaults to 0.
         to_rgb (bool): whether to convert image from BGR to RGB.
             Defaults to False.
+        to_onehot (bool): Whether to generate one-hot format gt-labels and set
+            to data samples. Defaults to False.
+        num_classes (int, optional): The number of classes. Defaults to None.
         batch_augments (dict, optional): The batch augmentations settings,
             including "augments" and "probs". For more details, see
             :class:`mmcls.models.RandomBatchAugment`.

@@ -53,11 +58,15 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
                  pad_size_divisor: int = 1,
                  pad_value: Number = 0,
                  to_rgb: bool = False,
+                 to_onehot: bool = False,
+                 num_classes: Optional[int] = None,
                  batch_augments: Optional[dict] = None):
         super().__init__()
         self.pad_size_divisor = pad_size_divisor
         self.pad_value = pad_value
         self.to_rgb = to_rgb
+        self.to_onehot = to_onehot
+        self.num_classes = num_classes

         if mean is not None:
             assert std is not None, 'To enable the normalization in ' \

@@ -73,6 +82,13 @@ class ClsDataPreprocessor(BaseDataPreprocessor):

         if batch_augments is not None:
             self.batch_augments = RandomBatchAugment(**batch_augments)
+            if not self.to_onehot:
+                from mmengine.logging import MMLogger
+                MMLogger.get_current_instance().info(
+                    'Because batch augmentations are enabled, the data '
+                    'preprocessor automatically enables the `to_onehot` '
+                    'option to generate one-hot format labels.')
+                self.to_onehot = True
         else:
             self.batch_augments = None

@@ -87,8 +103,7 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
         Returns:
             dict: Data in the same format as the model input.
         """
-        data = self.cast_data(data)
-        inputs = data['inputs']
+        inputs = self.cast_data(data['inputs'])

         if isinstance(inputs, torch.Tensor):
             # The branch if use `default_collate` as the collate_fn in the

@@ -135,12 +150,36 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
             inputs = stack_batch(processed_inputs, self.pad_size_divisor,
                                  self.pad_value)

-        # ----- Batch Aug ----
-        if training and self.batch_augments is not None:
-            data_samples = data['data_samples']
-            inputs, data_samples = self.batch_augments(inputs, data_samples)
-            data['data_samples'] = data_samples
+        data_samples = data.get('data_samples', None)
+        if data_samples is not None:
+            gt_labels = [sample.gt_label for sample in data_samples]
+            batch_label, label_indices = cat_batch_labels(
+                gt_labels, device=self.device)

-        data['inputs'] = inputs
+            batch_score = stack_batch_scores(gt_labels, device=self.device)
+            if batch_score is None and self.to_onehot:
+                assert batch_label is not None, \
+                    'Cannot generate onehot format labels because no labels.'
+                num_classes = self.num_classes or data_samples[0].get(
+                    'num_classes')
+                assert num_classes is not None, \
+                    'Cannot generate one-hot format labels because not set ' \
+                    '`num_classes` in `data_preprocessor`.'
+                batch_score = batch_label_to_onehot(batch_label, label_indices,
+                                                    num_classes)

-        return data
+            # ----- Batch Augmentations ----
+            if training and self.batch_augments is not None:
+                inputs, batch_score = self.batch_augments(inputs, batch_score)
+
+            # ----- scatter labels and scores to data samples ---
+            if batch_label is not None:
+                for sample, label in zip(
+                        data_samples, tensor_split(batch_label,
+                                                   label_indices)):
+                    sample.set_gt_label(label)
+            if batch_score is not None:
+                for sample, score in zip(data_samples, batch_score):
+                    sample.set_gt_score(score)
+
+        return {'inputs': inputs, 'data_samples': data_samples}

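Illustrative note (not part of the commit; it mirrors the unit tests at the end of this diff): building the preprocessor from a config and pushing one fake batch through it exercises the whole new path. Labels are concatenated once per batch, converted to one-hot scores, mixed, and scattered back onto the samples.

import torch
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample

cfg = dict(
    type='ClsDataPreprocessor',
    num_classes=10,
    batch_augments=dict(augments=[dict(type='Mixup', alpha=0.8)]))
processor = MODELS.build(cfg)
data = {
    'inputs': [torch.randint(0, 256, (3, 32, 32)) for _ in range(4)],
    'data_samples': [ClsDataSample().set_gt_label(i) for i in range(4)],
}
out = processor(data, training=True)
# Each sample now carries a mixed one-hot `gt_label.score` of length 10.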
@@ -1,4 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .cls_data_sample import ClsDataSample
+from .utils import (batch_label_to_onehot, cat_batch_labels,
+                    stack_batch_scores, tensor_split)

-__all__ = ['ClsDataSample']
+__all__ = [
+    'ClsDataSample', 'batch_label_to_onehot', 'cat_batch_labels',
+    'stack_batch_scores', 'tensor_split'
+]

@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from multiprocessing.reduction import ForkingPickler
 from numbers import Number
 from typing import Sequence, Union

@@ -9,20 +10,18 @@ from mmengine.structures import BaseDataElement, LabelData
 from mmengine.utils import is_str


-def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
-                 num_classes: int = None) -> LabelData:
-    """Convert label of various python types to :obj:`mmengine.LabelData`.
+def format_label(
+        value: Union[torch.Tensor, np.ndarray, Sequence, int]) -> torch.Tensor:
+    """Convert various python types to label-format tensor.

     Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
     :class:`Sequence`, :class:`int`.

     Args:
         value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.
-        num_classes (int, optional): The number of classes. If not None, set
-            it to the metainfo. Defaults to None.

     Returns:
-        :obj:`mmengine.LabelData`: The foramtted label data.
+        :obj:`torch.Tensor`: The foramtted label tensor.
     """

     # Handle single number

@@ -37,15 +36,36 @@ def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
         value = torch.LongTensor([value])
     elif not isinstance(value, torch.Tensor):
         raise TypeError(f'Type {type(value)} is not an available label type.')
+    assert value.ndim == 1, \
+        f'The dims of value should be 1, but got {value.ndim}.'

-    metainfo = {}
-    if num_classes is not None:
-        metainfo['num_classes'] = num_classes
-        if value.max() >= num_classes:
-            raise ValueError(f'The label data ({value}) should not '
-                             f'exceed num_classes ({num_classes}).')
-    label = LabelData(label=value, metainfo=metainfo)
-    return label
+    return value
+
+
+def format_score(
+        value: Union[torch.Tensor, np.ndarray, Sequence, int]) -> torch.Tensor:
+    """Convert various python types to score-format tensor.
+
+    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
+    :class:`Sequence`.
+
+    Args:
+        value (torch.Tensor | numpy.ndarray | Sequence): Score values.
+
+    Returns:
+        :obj:`torch.Tensor`: The foramtted score tensor.
+    """
+
+    if isinstance(value, np.ndarray):
+        value = torch.from_numpy(value).float()
+    elif isinstance(value, Sequence) and not is_str(value):
+        value = torch.tensor(value).float()
+    elif not isinstance(value, torch.Tensor):
+        raise TypeError(f'Type {type(value)} is not an available label type.')
+    assert value.ndim == 1, \
+        f'The dims of value should be 1, but got {value.ndim}.'
+
+    return value


 class ClsDataSample(BaseDataElement):

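Illustrative note (not part of the commit; it assumes direct access to the module-level helpers in `mmcls/structures/cls_data_sample.py`, which are not re-exported): a quick behavioral sketch of the two converters.

import numpy as np
from mmcls.structures.cls_data_sample import format_label, format_score

format_label(1)                 # -> tensor([1])
format_label(np.array([1, 2]))  # -> tensor([1, 2]), multi-label style
format_score([0.1, 0.9])        # -> tensor([0.1000, 0.9000])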
@@ -115,64 +135,50 @@ class ClsDataSample(BaseDataElement):
         self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
     ) -> 'ClsDataSample':
         """Set label of ``gt_label``."""
-        label = format_label(value, self.get('num_classes'))
-        if 'gt_label' in self:
-            self.gt_label.label = label.label
-        else:
-            self.gt_label = label
+        label_data = getattr(self, '_gt_label', LabelData())
+        label_data.label = format_label(value)
+        self.gt_label = label_data
         return self

     def set_gt_score(self, value: torch.Tensor) -> 'ClsDataSample':
         """Set score of ``gt_label``."""
-        assert isinstance(value, torch.Tensor), \
-            f'The value should be a torch.Tensor but got {type(value)}.'
-        assert value.ndim == 1, \
-            f'The dims of value should be 1, but got {value.ndim}.'
-
-        if 'num_classes' in self:
-            assert value.size(0) == self.num_classes, \
-                f"The length of value ({value.size(0)}) doesn't "\
-                f'match the num_classes ({self.num_classes}).'
-            metainfo = {'num_classes': self.num_classes}
+        label_data = getattr(self, '_gt_label', LabelData())
+        label_data.score = format_score(value)
+        if hasattr(self, 'num_classes'):
+            assert len(label_data.score) == self.num_classes, \
+                f'The length of score {len(label_data.score)} should be '\
+                f'equal to the num_classes {self.num_classes}.'
         else:
-            metainfo = {'num_classes': value.size(0)}
-
-        if 'gt_label' in self:
-            self.gt_label.score = value
-        else:
-            self.gt_label = LabelData(score=value, metainfo=metainfo)
+            self.set_field(
+                name='num_classes',
+                value=len(label_data.score),
+                field_type='metainfo')
+        self.gt_label = label_data
         return self

     def set_pred_label(
         self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
     ) -> 'ClsDataSample':
         """Set label of ``pred_label``."""
-        label = format_label(value, self.get('num_classes'))
-        if 'pred_label' in self:
-            self.pred_label.label = label.label
-        else:
-            self.pred_label = label
+        label_data = getattr(self, '_pred_label', LabelData())
+        label_data.label = format_label(value)
+        self.pred_label = label_data
         return self

     def set_pred_score(self, value: torch.Tensor) -> 'ClsDataSample':
         """Set score of ``pred_label``."""
-        assert isinstance(value, torch.Tensor), \
-            f'The value should be a torch.Tensor but got {type(value)}.'
-        assert value.ndim == 1, \
-            f'The dims of value should be 1, but got {value.ndim}.'
-
-        if 'num_classes' in self:
-            assert value.size(0) == self.num_classes, \
-                f"The length of value ({value.size(0)}) doesn't "\
-                f'match the num_classes ({self.num_classes}).'
-            metainfo = {'num_classes': self.num_classes}
+        label_data = getattr(self, '_pred_label', LabelData())
+        label_data.score = format_score(value)
+        if hasattr(self, 'num_classes'):
+            assert len(label_data.score) == self.num_classes, \
+                f'The length of score {len(label_data.score)} should be '\
+                f'equal to the num_classes {self.num_classes}.'
         else:
-            metainfo = {'num_classes': value.size(0)}
-
-        if 'pred_label' in self:
-            self.pred_label.score = value
-        else:
-            self.pred_label = LabelData(score=value, metainfo=metainfo)
+            self.set_field(
+                name='num_classes',
+                value=len(label_data.score),
+                field_type='metainfo')
+        self.pred_label = label_data
         return self

     @property

@@ -198,3 +204,32 @@ class ClsDataSample(BaseDataElement):
     @pred_label.deleter
     def pred_label(self):
         del self._pred_label
+
+
+def _reduce_cls_datasample(data_sample):
+    """reduce ClsDataSample."""
+    attr_dict = data_sample.__dict__
+    convert_keys = []
+    for k, v in attr_dict.items():
+        if isinstance(v, LabelData):
+            attr_dict[k] = v.numpy()
+            convert_keys.append(k)
+    return _rebuild_cls_datasample, (attr_dict, convert_keys)
+
+
+def _rebuild_cls_datasample(attr_dict, convert_keys):
+    """rebuild ClsDataSample."""
+    data_sample = ClsDataSample()
+    for k in convert_keys:
+        attr_dict[k] = attr_dict[k].to_tensor()
+    data_sample.__dict__ = attr_dict
+    return data_sample
+
+
+# Due to the multi-processing strategy of PyTorch, ClsDataSample may consume
+# many file descriptors because it contains multiple LabelData with tensors.
+# Here we overwrite the reduce function of ClsDataSample in ForkingPickler and
+# convert these tensors to np.ndarray during pickling. It may influence the
+# performance of dataloader, but slightly because these tensors in LabelData
+# are very small.
+ForkingPickler.register(ClsDataSample, _reduce_cls_datasample)

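Illustrative note (not part of the commit): `ForkingPickler.register` is plain standard-library machinery, so the trick generalizes. A minimal self-contained demonstration on a toy class; the names `Payload`, `_reduce_payload`, and `_rebuild_payload` are made up for this sketch.

import io
import pickle
from multiprocessing.reduction import ForkingPickler

class Payload:
    def __init__(self, data):
        self.data = data

def _rebuild_payload(data):
    return Payload(data)

def _reduce_payload(obj):
    # Serialize as a plain list; the rebuild function restores the object.
    return _rebuild_payload, (list(obj.data), )

ForkingPickler.register(Payload, _reduce_payload)

buf = io.BytesIO()
ForkingPickler(buf).dump(Payload((1, 2, 3)))
restored = pickle.loads(buf.getvalue())  # Payload with data == [1, 2, 3]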
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from mmengine.structures import LabelData
+
+if hasattr(torch, 'tensor_split'):
+    tensor_split = torch.tensor_split
+else:
+    # A simple implementation of `tensor_split`.
+    def tensor_split(input: torch.Tensor, indices: list):
+        outs = []
+        for start, end in zip([0] + indices, indices + [input.size(0)]):
+            outs.append(input[start:end])
+        return outs
+
+
+def cat_batch_labels(elements: List[LabelData], device=None):
+    """Concat the ``label`` of a batch of :obj:`LabelData` to a tensor.
+
+    Args:
+        elements (List[LabelData]): A batch of :obj`LabelData`.
+        device (torch.device, optional): The output device of the batch label.
+            Defaults to None.
+
+    Returns:
+        Tuple[torch.Tensor, List[int]]: The first item is the concated label
+        tensor, and the second item is the split indices of every sample.
+    """
+    item = elements[0]
+    if 'label' not in item._data_fields:
+        return None, None
+
+    labels = []
+    splits = [0]
+    for element in elements:
+        labels.append(element.label)
+        splits.append(splits[-1] + element.label.size(0))
+    batch_label = torch.cat(labels)
+    if device is not None:
+        batch_label = batch_label.to(device=device)
+    return batch_label, splits[1:-1]
+
+
+def batch_label_to_onehot(batch_label, split_indices, num_classes):
+    """Convert a concated label tensor to onehot format.
+
+    Args:
+        batch_label (torch.Tensor): A concated label tensor from multiple
+            samples.
+        split_indices (List[int]): The split indices of every sample.
+        num_classes (int): The number of classes.
+
+    Returns:
+        torch.Tensor: The onehot format label tensor.
+
+    Examples:
+        >>> import torch
+        >>> from mmcls.structures import batch_label_to_onehot
+        >>> # Assume a concated label from 3 samples.
+        >>> # label 1: [0, 1], label 2: [0, 2, 4], label 3: [3, 1]
+        >>> batch_label = torch.tensor([0, 1, 0, 2, 4, 3, 1])
+        >>> split_indices = [2, 5]
+        >>> batch_label_to_onehot(batch_label, split_indices, num_classes=5)
+        tensor([[1, 1, 0, 0, 0],
+                [1, 0, 1, 0, 1],
+                [0, 1, 0, 1, 0]])
+    """
+    sparse_onehot_list = F.one_hot(batch_label, num_classes)
+    onehot_list = [
+        sparse_onehot.sum(0)
+        for sparse_onehot in tensor_split(sparse_onehot_list, split_indices)
+    ]
+    return torch.stack(onehot_list)
+
+
+def stack_batch_scores(elements, device=None):
+    """Stack the ``score`` of a batch of :obj:`LabelData` to a tensor.
+
+    Args:
+        elements (List[LabelData]): A batch of :obj`LabelData`.
+        device (torch.device, optional): The output device of the batch label.
+            Defaults to None.
+
+    Returns:
+        torch.Tensor: The stacked score tensor.
+    """
+    item = elements[0]
+    if 'score' not in item._data_fields:
+        return None
+
+    batch_score = torch.stack([element.score for element in elements])
+    if device is not None:
+        batch_score = batch_score.to(device)
+    return batch_score

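Illustrative note (not part of the commit; the imports come from the `__init__.py` hunk above): a round trip through the three helpers.

import torch
from mmengine.structures import LabelData
from mmcls.structures import (batch_label_to_onehot, cat_batch_labels,
                              tensor_split)

gt_labels = [LabelData(label=torch.tensor([0, 1])),
             LabelData(label=torch.tensor([2]))]
batch_label, splits = cat_batch_labels(gt_labels)  # tensor([0, 1, 2]), [2]
onehot = batch_label_to_onehot(batch_label, splits, num_classes=5)
# -> tensor([[1, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
chunks = tensor_split(batch_label, splits)  # [tensor([0, 1]), tensor([2])]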
@@ -212,7 +212,9 @@ class TestSingleLabel(TestCase):
         pred_no_score = copy.deepcopy(pred)
         for sample in pred_no_score:
             del sample['pred_label']['score']
-        metric = METRICS.build(dict(type='SingleLabelMetric', thrs=(0., 0.6)))
+            del sample['num_classes']
+        metric = METRICS.build(
+            dict(type='SingleLabelMetric', thrs=(0., 0.6), num_classes=3))
         metric.process(None, pred_no_score)
         res = metric.evaluate(6)
         self.assertIsInstance(res, dict)

@@ -221,14 +223,13 @@ class TestSingleLabel(TestCase):
         self.assertAlmostEqual(res['single-label/recall'], 72.222, places=2)
         self.assertAlmostEqual(res['single-label/f1-score'], 65.555, places=2)

-        pred_no_num_classes = copy.deepcopy(pred_no_score)
-        for sample in pred_no_num_classes:
-            del sample['pred_label']['num_classes']
-        with self.assertRaisesRegex(ValueError, 'neither `score` nor'):
-            metric.process(None, pred_no_num_classes)
+        metric = METRICS.build(dict(type='SingleLabelMetric', thrs=(0., 0.6)))
+        with self.assertRaisesRegex(AssertionError, 'must be specified'):
+            metric.process(None, pred_no_score)

         # Test with empty items
-        metric = METRICS.build(dict(type='SingleLabelMetric', items=tuple()))
+        metric = METRICS.build(
+            dict(type='SingleLabelMetric', items=tuple(), num_classes=3))
         metric.process(None, pred)
         res = metric.evaluate(6)
         self.assertIsInstance(res, dict)

@@ -47,7 +47,7 @@ class TestImageClassifier(TestCase):
         # test set batch augmentation from train_cfg
         cfg = {
             **self.DEFAULT_ARGS, 'train_cfg':
-            dict(augments=dict(type='Mixup', alpha=1., num_classes=10))
+            dict(augments=dict(type='Mixup', alpha=1.))
         }
         model: ImageClassifier = MODELS.build(cfg)
         self.assertIsNotNone(model.data_preprocessor.batch_augments)

@@ -7,7 +7,6 @@ import torch

 from mmcls.models import Mixup, RandomBatchAugment
 from mmcls.registry import BATCH_AUGMENTS
-from mmcls.structures import ClsDataSample


 class TestRandomBatchAugment(TestCase):

@@ -54,7 +53,7 @@ class TestRandomBatchAugment(TestCase):

     def test_call(self):
         inputs = torch.rand(2, 3, 224, 224)
-        data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
+        scores = torch.rand(2, 10)

         augments = [
             dict(type='Mixup', alpha=1.),

@@ -64,18 +63,17 @@ class TestRandomBatchAugment(TestCase):

         with patch('numpy.random', np.random.RandomState(0)):
             batch_augments.augments[1] = MagicMock()
-            batch_augments(inputs, data_samples)
-            batch_augments.augments[1].assert_called_once_with(
-                inputs, data_samples)
+            batch_augments(inputs, scores)
+            batch_augments.augments[1].assert_called_once_with(inputs, scores)

         augments = [
             dict(type='Mixup', alpha=1.),
             dict(type='CutMix', alpha=0.8),
         ]
         batch_augments = RandomBatchAugment(augments, probs=[0.0, 0.0])
-        mixed_inputs, mixed_samples = batch_augments(inputs, data_samples)
+        mixed_inputs, mixed_samples = batch_augments(inputs, scores)
         self.assertIs(mixed_inputs, inputs)
-        self.assertIs(mixed_samples, data_samples)
+        self.assertIs(mixed_samples, scores)


 class TestMixup(TestCase):

class TestMixup(TestCase):
|
||||
|
@ -86,45 +84,21 @@ class TestMixup(TestCase):
|
|||
cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'}
|
||||
BATCH_AUGMENTS.build(cfg)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'}
|
||||
BATCH_AUGMENTS.build(cfg)
|
||||
|
||||
def test_call(self):
|
||||
inputs = torch.rand(2, 3, 224, 224)
|
||||
data_samples = [
|
||||
ClsDataSample(metainfo={
|
||||
'num_classes': 10
|
||||
}).set_gt_label(1) for _ in range(2)
|
||||
]
|
||||
scores = torch.rand(2, 10)
|
||||
|
||||
# test get num_classes from data_samples
|
||||
mixup = BATCH_AUGMENTS.build(self.DEFAULT_ARGS)
|
||||
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
|
||||
mixed_inputs, mixed_scores = mixup(inputs, scores)
|
||||
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
|
||||
self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, 'Not specify'):
|
||||
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
|
||||
mixup(inputs, data_samples)
|
||||
self.assertEqual(mixed_scores.shape, (2, 10))
|
||||
|
||||
# test binary classification
|
||||
cfg = {**self.DEFAULT_ARGS, 'num_classes': 1}
|
||||
mixup = BATCH_AUGMENTS.build(cfg)
|
||||
data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)]
|
||||
scores = torch.rand(2, 1)
|
||||
|
||||
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
|
||||
mixed_inputs, mixed_scores = mixup(inputs, scores)
|
||||
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
|
||||
self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, ))
|
||||
|
||||
# test multi-label classification
|
||||
cfg = {**self.DEFAULT_ARGS, 'num_classes': 5}
|
||||
mixup = BATCH_AUGMENTS.build(cfg)
|
||||
data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)]
|
||||
|
||||
mixed_inputs, mixed_samples = mixup(inputs, data_samples)
|
||||
self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
|
||||
self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, ))
|
||||
self.assertEqual(mixed_scores.shape, (2, 1))
|
||||
|
||||
|
||||
class TestCutMix(TestCase):
|
||||
|
@@ -135,59 +109,36 @@ class TestCutMix(TestCase):
             cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'}
             BATCH_AUGMENTS.build(cfg)

-        with self.assertRaises(AssertionError):
-            cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'}
-            BATCH_AUGMENTS.build(cfg)
-
     def test_call(self):
         inputs = torch.rand(2, 3, 224, 224)
-        data_samples = [
-            ClsDataSample(metainfo={
-                'num_classes': 10
-            }).set_gt_label(1) for _ in range(2)
-        ]
+        scores = torch.rand(2, 10)

         # test with cutmix_minmax
         cfg = {**self.DEFAULT_ARGS, 'cutmix_minmax': (0.1, 0.2)}
         cutmix = BATCH_AUGMENTS.build(cfg)
-        mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
+        mixed_inputs, mixed_scores = cutmix(inputs, scores)
         self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
-        self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
+        self.assertEqual(mixed_scores.shape, (2, 10))

         # test without correct_lam
         cfg = {**self.DEFAULT_ARGS, 'correct_lam': False}
         cutmix = BATCH_AUGMENTS.build(cfg)
-        mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
+        mixed_inputs, mixed_scores = cutmix(inputs, scores)
         self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
-        self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
+        self.assertEqual(mixed_scores.shape, (2, 10))

-        # test get num_classes from data_samples
+        # test default settings
         cutmix = BATCH_AUGMENTS.build(self.DEFAULT_ARGS)
-        mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
+        mixed_inputs, mixed_scores = cutmix(inputs, scores)
         self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
-        self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
-
-        with self.assertRaisesRegex(RuntimeError, 'Not specify'):
-            data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
-            cutmix(inputs, data_samples)
+        self.assertEqual(mixed_scores.shape, (2, 10))

         # test binary classification
-        cfg = {**self.DEFAULT_ARGS, 'num_classes': 1}
-        cutmix = BATCH_AUGMENTS.build(cfg)
-        data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)]
+        scores = torch.rand(2, 1)

-        mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
+        mixed_inputs, mixed_scores = cutmix(inputs, scores)
         self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
-        self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, ))
-
-        # test multi-label classification
-        cfg = {**self.DEFAULT_ARGS, 'num_classes': 5}
-        cutmix = BATCH_AUGMENTS.build(cfg)
-        data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)]
-
-        mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
-        self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
-        self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, ))
+        self.assertEqual(mixed_scores.shape, (2, 1))


 class TestResizeMix(TestCase):

@ -198,42 +149,18 @@ class TestResizeMix(TestCase):
|
|||
            cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'}
            BATCH_AUGMENTS.build(cfg)

        with self.assertRaises(AssertionError):
            cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'}
            BATCH_AUGMENTS.build(cfg)

    def test_call(self):
        inputs = torch.rand(2, 3, 224, 224)
        data_samples = [
            ClsDataSample(metainfo={
                'num_classes': 10
            }).set_gt_label(1) for _ in range(2)
        ]
        scores = torch.rand(2, 10)

        # test get num_classes from data_samples
        mixup = BATCH_AUGMENTS.build(self.DEFAULT_ARGS)
        mixed_inputs, mixed_samples = mixup(inputs, data_samples)
        mixed_inputs, mixed_scores = mixup(inputs, scores)
        self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
        self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))

        with self.assertRaisesRegex(RuntimeError, 'Not specify'):
            data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
            mixup(inputs, data_samples)
        self.assertEqual(mixed_scores.shape, (2, 10))

        # test binary classification
        cfg = {**self.DEFAULT_ARGS, 'num_classes': 1}
        mixup = BATCH_AUGMENTS.build(cfg)
        data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)]
        scores = torch.rand(2, 1)

        mixed_inputs, mixed_samples = mixup(inputs, data_samples)
        mixed_inputs, mixed_scores = mixup(inputs, scores)
        self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
        self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, ))

        # test multi-label classification
        cfg = {**self.DEFAULT_ARGS, 'num_classes': 5}
        mixup = BATCH_AUGMENTS.build(cfg)
        data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)]

        mixed_inputs, mixed_samples = mixup(inputs, data_samples)
        self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
        self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, ))
        self.assertEqual(mixed_scores.shape, (2, 1))

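The 'Not specify' case above pins down how num_classes is resolved: an
explicit value in the augment config wins, otherwise it is read from the
metainfo of the incoming data samples, and with neither available the call
raises. A hedged sketch of that lookup (the function name and message are
illustrative, not mmcls's exact code):

    def resolve_num_classes(cfg_num_classes, data_samples):
        # Prefer the value from the augment config, then the samples' metainfo.
        if cfg_num_classes is not None:
            return cfg_num_classes
        if data_samples and 'num_classes' in data_samples[0]:
            return data_samples[0].num_classes
        raise RuntimeError('Not specify the `num_classes` in the config, and '
                           'cannot infer it from the data samples.')
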
@@ -71,7 +71,7 @@ class TestClsDataPreprocessor(TestCase):
        inputs = processed_data['inputs']
        self.assertTrue((inputs >= -1).all())
        self.assertTrue((inputs <= 1).all())
        self.assertNotIn('data_samples', processed_data)
        self.assertIsNone(processed_data['data_samples'])

        data = {'inputs': torch.randint(0, 256, (1, 3, 224, 224))}
        inputs = processor(data)['inputs']

@@ -81,9 +81,10 @@ class TestClsDataPreprocessor(TestCase):
    def test_batch_augmentation(self):
        cfg = dict(
            type='ClsDataPreprocessor',
            num_classes=10,
            batch_augments=dict(augments=[
                dict(type='Mixup', alpha=0.8, num_classes=10),
                dict(type='CutMix', alpha=1., num_classes=10)
                dict(type='Mixup', alpha=0.8),
                dict(type='CutMix', alpha=1.)
            ]))
        processor: ClsDataPreprocessor = MODELS.build(cfg)
        self.assertIsInstance(processor.batch_augments, RandomBatchAugment)

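This hunk is the config migration in miniature: num_classes moves off the
individual augments and onto ClsDataPreprocessor, so it is declared once.
Side by side, with the same values the test uses:

    # Old style: every augment repeats num_classes.
    old_cfg = dict(
        type='ClsDataPreprocessor',
        batch_augments=dict(augments=[
            dict(type='Mixup', alpha=0.8, num_classes=10),
            dict(type='CutMix', alpha=1., num_classes=10),
        ]))

    # New style: num_classes is set once on the preprocessor and shared.
    new_cfg = dict(
        type='ClsDataPreprocessor',
        num_classes=10,
        batch_augments=dict(augments=[
            dict(type='Mixup', alpha=0.8),
            dict(type='CutMix', alpha=1.),
        ]))
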
@@ -101,4 +102,4 @@ class TestClsDataPreprocessor(TestCase):
        data = {'inputs': [torch.randint(0, 256, (3, 224, 224))]}
        processed_data = processor(data, training=True)
        self.assertIn('inputs', processed_data)
        self.assertNotIn('data_samples', processed_data)
        self.assertIsNone(processed_data['data_samples'])

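The swapped assertion encodes a small contract change: the preprocessor
output now always carries a 'data_samples' key, set to None when the batch
has no samples, rather than omitting the key. A usage sketch, assuming the
registry import path used by mmcls 1.x:

    import torch
    from mmcls.registry import MODELS  # assumed import path

    processor = MODELS.build(dict(type='ClsDataPreprocessor', num_classes=10))
    processed = processor({'inputs': [torch.randint(0, 256, (3, 224, 224))]})
    assert 'inputs' in processed
    assert processed['data_samples'] is None  # key present, value None
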
@@ -60,20 +60,6 @@ class TestClsDataSample(TestCase):
        with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
            method('hi')

        # Test set num_classes
        data_sample = ClsDataSample(metainfo={'num_classes': 10})
        method = getattr(data_sample, 'set_' + key)
        method(5)
        self.assertIn(key, data_sample)
        label = getattr(data_sample, key)
        self.assertIsInstance(label, LabelData)
        self.assertIn('num_classes', label)
        self.assertEqual(label.num_classes, 10)

        # Test unavailable label
        with self.assertRaisesRegex(ValueError, r'data .*[15].* should '):
            method(15)

    def test_set_gt_label(self):
        self._test_set_label('gt_label')

@@ -97,62 +83,42 @@ class TestClsDataSample(TestCase):
        self.assertNotIn('pred_label', data_sample)

    def test_set_gt_score(self):
        data_sample = ClsDataSample(metainfo={'num_classes': 5})
        data_sample = ClsDataSample()
        data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
        self.assertIn('score', data_sample.gt_label)
        torch.testing.assert_allclose(data_sample.gt_label.score,
                                      [0.1, 0.1, 0.6, 0.1, 0.1])
        self.assertEqual(data_sample.gt_label.num_classes, 5)

        # Test set again
        data_sample.set_gt_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
        torch.testing.assert_allclose(data_sample.gt_label.score,
                                      [0.2, 0.1, 0.5, 0.1, 0.1])

        # Test invalid type
        with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'):
            data_sample.set_gt_score([1, 2, 3])
        # Test invalid length
        with self.assertRaisesRegex(AssertionError, 'should be equal to'):
            data_sample.set_gt_score([1, 2])

        # Test invalid dims
        with self.assertRaisesRegex(AssertionError, 'but got 2'):
            data_sample.set_gt_score(torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))

        # Test invalid num_classes
        with self.assertRaisesRegex(AssertionError, r'length of value \(4\)'):
            data_sample.set_gt_score(torch.tensor([0.1, 0.2, 0.3, 0.4]))

        # Test auto infer num_classes
        data_sample = ClsDataSample()
        data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
        self.assertEqual(data_sample.gt_label.num_classes, 5)

    def test_set_pred_score(self):
        data_sample = ClsDataSample(metainfo={'num_classes': 5})
        data_sample = ClsDataSample()
        data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
        self.assertIn('score', data_sample.pred_label)
        torch.testing.assert_allclose(data_sample.pred_label.score,
                                      [0.1, 0.1, 0.6, 0.1, 0.1])
        self.assertEqual(data_sample.pred_label.num_classes, 5)

        # Test set again
        data_sample.set_pred_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
        torch.testing.assert_allclose(data_sample.pred_label.score,
                                      [0.2, 0.1, 0.5, 0.1, 0.1])

        # Test invalid type
        with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'):
            data_sample.set_pred_score([1, 2, 3])
        # Test invalid length
        with self.assertRaisesRegex(AssertionError, 'should be equal to'):
            data_sample.set_pred_score([1, 2])

        # Test invalid dims
        with self.assertRaisesRegex(AssertionError, 'but got 2'):
            data_sample.set_pred_score(
                torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))

        # Test invalid num_classes
        with self.assertRaisesRegex(AssertionError, r'length of value \(4\)'):
            data_sample.set_pred_score(torch.tensor([0.1, 0.2, 0.3, 0.4]))

        # Test auto infer num_classes
        data_sample = ClsDataSample()
        data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
        self.assertEqual(data_sample.pred_label.num_classes, 5)

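Taken together, these assertions spell out the new score-setting contract on
ClsDataSample: the value must be a one-dimensional torch.Tensor, its length
must match num_classes when that is already recorded, and num_classes is
inferred from the length otherwise. A condensed sketch of that validation
(illustrative, not the exact mmcls code):

    import torch
    from mmengine.structures import LabelData

    def format_score(value, num_classes=None):
        # Mirrors the assertions tested above; the wording is approximate.
        assert isinstance(value, torch.Tensor), \
            f'The value should be a torch.Tensor, but got {type(value)}.'
        assert value.ndim == 1, f'The dims should be 1, but got {value.ndim}.'
        if num_classes is not None:
            assert value.size(0) == num_classes, \
                f'The length of value ({value.size(0)}) should be equal to ' \
                f'num_classes ({num_classes}).'
        else:
            num_classes = value.size(0)  # auto-infer from the score length
        return LabelData(score=value, num_classes=num_classes)
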
@@ -0,0 +1,93 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from mmengine.structures import LabelData

from mmcls.structures import (batch_label_to_onehot, cat_batch_labels,
                              stack_batch_scores, tensor_split)


class TestStructureUtils(TestCase):

    def test_tensor_split(self):
        tensor = torch.tensor([0, 1, 2, 3, 4, 5, 6])
        split_indices = [0, 2, 6, 6]
        outs = tensor_split(tensor, split_indices)
        self.assertEqual(len(outs), len(split_indices) + 1)
        self.assertEqual(outs[0].size(0), 0)
        self.assertEqual(outs[1].size(0), 2)
        self.assertEqual(outs[2].size(0), 4)
        self.assertEqual(outs[3].size(0), 0)
        self.assertEqual(outs[4].size(0), 1)

        tensor = torch.tensor([])
        split_indices = [0, 0, 0, 0]
        outs = tensor_split(tensor, split_indices)
        self.assertEqual(len(outs), len(split_indices) + 1)

    def test_cat_batch_labels(self):
        labels = [
            LabelData(label=torch.tensor([1])),
            LabelData(label=torch.tensor([3, 2])),
            LabelData(label=torch.tensor([0, 1, 4])),
            LabelData(label=torch.tensor([], dtype=torch.int64)),
            LabelData(label=torch.tensor([], dtype=torch.int64)),
        ]

        batch_label, split_indices = cat_batch_labels(labels)
        self.assertEqual(split_indices, [1, 3, 6, 6])
        self.assertEqual(len(batch_label), 6)
        labels = tensor_split(batch_label, split_indices)
        self.assertEqual(labels[0].tolist(), [1])
        self.assertEqual(labels[1].tolist(), [3, 2])
        self.assertEqual(labels[2].tolist(), [0, 1, 4])
        self.assertEqual(labels[3].tolist(), [])
        self.assertEqual(labels[4].tolist(), [])

        labels = [
            LabelData(score=torch.tensor([0, 1, 0, 0, 1])),
            LabelData(score=torch.tensor([0, 0, 1, 0, 0])),
            LabelData(score=torch.tensor([1, 0, 0, 1, 0])),
        ]
        batch_label, split_indices = cat_batch_labels(labels)
        self.assertIsNone(batch_label)
        self.assertIsNone(split_indices)

    def test_stack_batch_scores(self):
        labels = [
            LabelData(score=torch.tensor([0, 1, 0, 0, 1])),
            LabelData(score=torch.tensor([0, 0, 1, 0, 0])),
            LabelData(score=torch.tensor([1, 0, 0, 1, 0])),
        ]

        batch_score = stack_batch_scores(labels)
        self.assertEqual(batch_score.shape, (3, 5))

        labels = [
            LabelData(label=torch.tensor([1])),
            LabelData(label=torch.tensor([3, 2])),
            LabelData(label=torch.tensor([0, 1, 4])),
            LabelData(label=torch.tensor([], dtype=torch.int64)),
            LabelData(label=torch.tensor([], dtype=torch.int64)),
        ]
        batch_score = stack_batch_scores(labels)
        self.assertIsNone(batch_score)

    def test_batch_label_to_onehot(self):
        labels = [
            LabelData(label=torch.tensor([1])),
            LabelData(label=torch.tensor([3, 2])),
            LabelData(label=torch.tensor([0, 1, 4])),
            LabelData(label=torch.tensor([], dtype=torch.int64)),
            LabelData(label=torch.tensor([], dtype=torch.int64)),
        ]

        batch_label, split_indices = cat_batch_labels(labels)
        batch_score = batch_label_to_onehot(
            batch_label, split_indices, num_classes=5)
        self.assertEqual(batch_score[0].tolist(), [0, 1, 0, 0, 0])
        self.assertEqual(batch_score[1].tolist(), [0, 0, 1, 1, 0])
        self.assertEqual(batch_score[2].tolist(), [1, 1, 0, 0, 1])
        self.assertEqual(batch_score[3].tolist(), [0, 0, 0, 0, 0])
        self.assertEqual(batch_score[4].tolist(), [0, 0, 0, 0, 0])

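The new test file pins down the semantics of four batching helpers without
showing their bodies; the sketch below is consistent with the expected
outputs above, though the real implementations in mmcls.structures may
differ in detail.

    import torch
    from mmengine.structures import LabelData

    def tensor_split(tensor, indices):
        # N split indices yield N + 1 chunks; empty chunks are allowed.
        outs, start = [], 0
        for end in [*indices, tensor.size(0)]:
            outs.append(tensor[start:end])
            start = end
        return outs

    def cat_batch_labels(elements):
        # Concatenate per-sample label tensors into one flat tensor plus the
        # indices needed to split it back; (None, None) if any sample lacks
        # hard labels (e.g. it only carries soft scores).
        if not all('label' in element for element in elements):
            return None, None
        labels = [element.label for element in elements]
        lengths = torch.tensor([label.numel() for label in labels])
        return torch.cat(labels), lengths.cumsum(0)[:-1].tolist()

    def stack_batch_scores(elements):
        # Stack per-sample score vectors into an (N, num_classes) tensor,
        # or return None if any sample has no score.
        if not all('score' in element for element in elements):
            return None
        return torch.stack([element.score for element in elements])

    def batch_label_to_onehot(batch_label, split_indices, num_classes):
        # One-hot encode each sample's (possibly empty) label set.
        onehots = []
        for labels in tensor_split(batch_label, split_indices):
            onehot = torch.zeros(num_classes, dtype=torch.int64)
            onehot[labels] = 1
            onehots.append(onehot)
        return torch.stack(onehots)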