[Feature] Add `PackClsInputs` and use `LoadImageFromFile`, `Resize` & `RandomFlip` in MMCV.
parent
0537c4d70c
commit
93a27c8324
|
@ -92,7 +92,7 @@ def inference(config_file, checkpoint, classes, args):
|
||||||
cfg.data.test.pipeline.insert(0, dict(type='LoadImageFromFile'))
|
cfg.data.test.pipeline.insert(0, dict(type='LoadImageFromFile'))
|
||||||
if cfg.data.test.type in ['CIFAR10', 'CIFAR100']:
|
if cfg.data.test.type in ['CIFAR10', 'CIFAR100']:
|
||||||
# The image shape of CIFAR is (32, 32, 3)
|
# The image shape of CIFAR is (32, 32, 3)
|
||||||
cfg.data.test.pipeline.insert(1, dict(type='Resize', size=32))
|
cfg.data.test.pipeline.insert(1, dict(type='Resize', scale=32))
|
||||||
|
|
||||||
data = dict(img_info=dict(filename=args.img), img_prefix=None)
|
data = dict(img_info=dict(filename=args.img), img_prefix=None)
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ img_norm_cfg = dict(
|
||||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||||
train_pipeline = [
|
train_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=510),
|
dict(type='Resize', scale=510),
|
||||||
dict(type='RandomCrop', size=384),
|
dict(type='RandomCrop', size=384),
|
||||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
|
@ -14,7 +14,7 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=510),
|
dict(type='Resize', scale=510),
|
||||||
dict(type='CenterCrop', crop_size=384),
|
dict(type='CenterCrop', crop_size=384),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -4,7 +4,7 @@ img_norm_cfg = dict(
|
||||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||||
train_pipeline = [
|
train_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=600),
|
dict(type='Resize', scale=600),
|
||||||
dict(type='RandomCrop', size=448),
|
dict(type='RandomCrop', size=448),
|
||||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
|
@ -14,7 +14,7 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=600),
|
dict(type='Resize', scale=600),
|
||||||
dict(type='CenterCrop', crop_size=448),
|
dict(type='CenterCrop', crop_size=448),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -13,7 +13,7 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(256, -1)),
|
dict(type='Resize', scale=(256, -1), keep_ratio=True),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -41,7 +41,8 @@ test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(
|
dict(
|
||||||
type='Resize',
|
type='Resize',
|
||||||
size=(236, -1),
|
scale=(236, -1),
|
||||||
|
keep_ratio=True,
|
||||||
backend='pillow',
|
backend='pillow',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
|
|
|
@ -25,7 +25,7 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(236, -1)),
|
dict(type='Resize', scale=(236, -1), keep_ratio=True),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -25,7 +25,7 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(236, -1)),
|
dict(type='Resize', scale=(236, -1), keep_ratio=True),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -13,7 +13,7 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(256, -1)),
|
dict(type='Resize', scale=(256, -1), keep_ratio=True),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -19,7 +19,8 @@ test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(
|
dict(
|
||||||
type='Resize',
|
type='Resize',
|
||||||
size=(256, -1),
|
scale=(256, -1),
|
||||||
|
keep_ratio=True,
|
||||||
backend='pillow',
|
backend='pillow',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
|
|
|
@ -13,7 +13,7 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(256, -1), backend='pillow'),
|
dict(type='Resize', scale=(256, -1), keep_ratio=True, backend='pillow'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -13,7 +13,7 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(256, -1)),
|
dict(type='Resize', scale=(256, -1), keep_ratio=True),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -16,7 +16,7 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(256, -1)),
|
dict(type='Resize', scale=(256, -1), keep_ratio=True),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -41,7 +41,8 @@ test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(
|
dict(
|
||||||
type='Resize',
|
type='Resize',
|
||||||
size=(233, -1),
|
scale=(233, -1),
|
||||||
|
keep_ratio=True,
|
||||||
backend='pillow',
|
backend='pillow',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
|
|
|
@ -20,7 +20,11 @@ train_pipeline = [
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(
|
dict(
|
||||||
type='Resize', size=(256, -1), backend='cv2', interpolation='bicubic'),
|
type='Resize',
|
||||||
|
scale=(256, -1),
|
||||||
|
keep_ratio=True,
|
||||||
|
backend='cv2',
|
||||||
|
interpolation='bicubic'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -13,7 +13,7 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(256, -1), backend='pillow'),
|
dict(type='Resize', scale=(256, -1), keep_ratio=True, backend='pillow'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -24,7 +24,8 @@ test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(
|
dict(
|
||||||
type='Resize',
|
type='Resize',
|
||||||
size=(256, -1),
|
scale=(256, -1),
|
||||||
|
keep_ratio=True,
|
||||||
backend='pillow',
|
backend='pillow',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
|
|
|
@ -41,7 +41,8 @@ test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(
|
dict(
|
||||||
type='Resize',
|
type='Resize',
|
||||||
size=(256, -1),
|
scale=(256, -1),
|
||||||
|
keep_ratio=True,
|
||||||
backend='pillow',
|
backend='pillow',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
|
|
|
@ -17,7 +17,12 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=384, backend='pillow', interpolation='bicubic'),
|
dict(
|
||||||
|
type='Resize',
|
||||||
|
scale=(256, -1),
|
||||||
|
keep_ratio=True,
|
||||||
|
backend='pillow',
|
||||||
|
interpolation='bicubic'),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
dict(type='Collect', keys=['img'])
|
dict(type='Collect', keys=['img'])
|
||||||
|
|
|
@ -41,7 +41,8 @@ test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(
|
dict(
|
||||||
type='Resize',
|
type='Resize',
|
||||||
size=(248, -1),
|
scale=(248, -1),
|
||||||
|
keep_ratio=True,
|
||||||
backend='pillow',
|
backend='pillow',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
|
|
|
@ -13,7 +13,7 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(256, -1)),
|
dict(type='Resize', scale=(256, -1), keep_ratio=True),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -36,7 +36,8 @@ test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(
|
dict(
|
||||||
type='Resize',
|
type='Resize',
|
||||||
size=(288, -1),
|
scale=(288, -1),
|
||||||
|
keep_ratio=True,
|
||||||
backend='pillow',
|
backend='pillow',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
dict(type='CenterCrop', crop_size=256),
|
dict(type='CenterCrop', crop_size=256),
|
||||||
|
|
|
@ -37,7 +37,8 @@ test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(
|
dict(
|
||||||
type='Resize',
|
type='Resize',
|
||||||
size=(288, -1),
|
scale=(288, -1),
|
||||||
|
keep_ratio=True,
|
||||||
backend='pillow',
|
backend='pillow',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
dict(type='CenterCrop', crop_size=256),
|
dict(type='CenterCrop', crop_size=256),
|
||||||
|
|
|
@ -36,7 +36,8 @@ test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(
|
dict(
|
||||||
type='Resize',
|
type='Resize',
|
||||||
size=(256, -1),
|
scale=(256, -1),
|
||||||
|
keep_ratio=True,
|
||||||
backend='pillow',
|
backend='pillow',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
|
|
|
@ -11,14 +11,14 @@ model = dict(
|
||||||
dataset_type = 'MNIST'
|
dataset_type = 'MNIST'
|
||||||
img_norm_cfg = dict(mean=[33.46], std=[78.87], to_rgb=True)
|
img_norm_cfg = dict(mean=[33.46], std=[78.87], to_rgb=True)
|
||||||
train_pipeline = [
|
train_pipeline = [
|
||||||
dict(type='Resize', size=32),
|
dict(type='Resize', scale=32),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
dict(type='ToTensor', keys=['gt_label']),
|
dict(type='ToTensor', keys=['gt_label']),
|
||||||
dict(type='Collect', keys=['img', 'gt_label']),
|
dict(type='Collect', keys=['img', 'gt_label']),
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='Resize', size=32),
|
dict(type='Resize', scale=32),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
dict(type='Collect', keys=['img']),
|
dict(type='Collect', keys=['img']),
|
||||||
|
|
|
@ -51,7 +51,7 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(256, -1)),
|
dict(type='Resize', scale=(256, -1), keep_ratio=True),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', mean=NORM_MEAN, std=NORM_STD, to_rgb=False),
|
dict(type='Normalize', mean=NORM_MEAN, std=NORM_STD, to_rgb=False),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -11,7 +11,11 @@ img_norm_cfg = dict(
|
||||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(256 * 256 // 224, -1), backend='pillow'),
|
dict(
|
||||||
|
type='Resize',
|
||||||
|
scale_factor=(256 * 256 // 224, -1),
|
||||||
|
keep_ratio=True,
|
||||||
|
backend='pillow'),
|
||||||
dict(type='CenterCrop', crop_size=256),
|
dict(type='CenterCrop', crop_size=256),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -10,7 +10,7 @@ img_norm_cfg = dict(
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
# resizing to (256, 256) here, different with resizing shorter edge to 256
|
# resizing to (256, 256) here, different with resizing shorter edge to 256
|
||||||
dict(type='Resize', size=(256, 256), backend='pillow'),
|
dict(type='Resize', scale=(256, 256), backend='pillow'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -12,7 +12,8 @@ test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(
|
dict(
|
||||||
type='Resize',
|
type='Resize',
|
||||||
size=(248, -1),
|
scale=(248, -1),
|
||||||
|
keep_ratio=True,
|
||||||
interpolation='bicubic',
|
interpolation='bicubic',
|
||||||
backend='pillow'),
|
backend='pillow'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
|
|
|
@ -45,7 +45,8 @@ test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(
|
dict(
|
||||||
type='Resize',
|
type='Resize',
|
||||||
size=(248, -1),
|
scale=(248, -1),
|
||||||
|
keep_ratio=True,
|
||||||
backend='pillow',
|
backend='pillow',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
|
|
|
@ -45,7 +45,8 @@ test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(
|
dict(
|
||||||
type='Resize',
|
type='Resize',
|
||||||
size=(248, -1),
|
scale=(248, -1),
|
||||||
|
keep_ratio=True,
|
||||||
backend='pillow',
|
backend='pillow',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
|
|
|
@ -45,7 +45,8 @@ test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(
|
dict(
|
||||||
type='Resize',
|
type='Resize',
|
||||||
size=(248, -1),
|
scale=(248, -1),
|
||||||
|
keep_ratio=True,
|
||||||
backend='pillow',
|
backend='pillow',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
|
|
|
@ -45,7 +45,8 @@ test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(
|
dict(
|
||||||
type='Resize',
|
type='Resize',
|
||||||
size=(248, -1),
|
scale=(248, -1),
|
||||||
|
keep_ratio=True,
|
||||||
backend='pillow',
|
backend='pillow',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
|
|
|
@ -39,7 +39,7 @@ train_pipeline = [
|
||||||
|
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(224, -1), backend='pillow'),
|
dict(type='Resize', scale=(224, -1), keep_ratio=True, backend='pillow'),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -22,7 +22,7 @@ train_pipeline = [
|
||||||
|
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(384, -1), backend='pillow'),
|
dict(type='Resize', scale=(384, -1), keep_ratio=True, backend='pillow'),
|
||||||
dict(type='CenterCrop', crop_size=384),
|
dict(type='CenterCrop', crop_size=384),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -22,7 +22,7 @@ train_pipeline = [
|
||||||
|
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(384, -1), backend='pillow'),
|
dict(type='Resize', scale=(384, -1), keep_ratio=True, backend='pillow'),
|
||||||
dict(type='CenterCrop', crop_size=384),
|
dict(type='CenterCrop', crop_size=384),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -22,7 +22,7 @@ train_pipeline = [
|
||||||
|
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(384, -1), backend='pillow'),
|
dict(type='Resize', scale=(384, -1), keep_ratio=True, backend='pillow'),
|
||||||
dict(type='CenterCrop', crop_size=384),
|
dict(type='CenterCrop', crop_size=384),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -23,7 +23,7 @@ train_pipeline = [
|
||||||
|
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(384, -1), backend='pillow'),
|
dict(type='Resize', scale=(384, -1), keep_ratio=True, backend='pillow'),
|
||||||
dict(type='CenterCrop', crop_size=384),
|
dict(type='CenterCrop', crop_size=384),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -30,7 +30,7 @@ for example:
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=256),
|
dict(type='Resize', scale=256),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -193,7 +193,7 @@ train_pipeline = [
|
||||||
# test data pipeline
|
# test data pipeline
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(256, -1)),
|
dict(type='Resize', scale=(256, -1), keep_ratio=True),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
@ -309,7 +309,7 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=384, backend='pillow'),
|
dict(type='Resize', scale=(384, -1), keep_ratio=True, backend='pillow'),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
dict(type='Collect', keys=['img'])
|
dict(type='Collect', keys=['img'])
|
||||||
|
|
|
@ -28,7 +28,7 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=256),
|
dict(type='Resize', scale=256),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -112,14 +112,14 @@ img_norm_cfg = dict(
|
||||||
train_pipeline = [
|
train_pipeline = [
|
||||||
dict(type='RandomCrop', size=32, padding=4),
|
dict(type='RandomCrop', size=32, padding=4),
|
||||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||||
dict(type='Resize', size=224),
|
dict(type='Resize', scale=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
dict(type='ToTensor', keys=['gt_label']),
|
dict(type='ToTensor', keys=['gt_label']),
|
||||||
dict(type='Collect', keys=['img', 'gt_label']),
|
dict(type='Collect', keys=['img', 'gt_label']),
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='Resize', size=224),
|
dict(type='Resize', scale=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
dict(type='Collect', keys=['img']),
|
dict(type='Collect', keys=['img']),
|
||||||
|
@ -177,14 +177,14 @@ img_norm_cfg = dict(
|
||||||
train_pipeline = [
|
train_pipeline = [
|
||||||
dict(type='RandomCrop', size=32, padding=4),
|
dict(type='RandomCrop', size=32, padding=4),
|
||||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||||
dict(type='Resize', size=224),
|
dict(type='Resize', scale=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
dict(type='ToTensor', keys=['gt_label']),
|
dict(type='ToTensor', keys=['gt_label']),
|
||||||
dict(type='Collect', keys=['img', 'gt_label']),
|
dict(type='Collect', keys=['img', 'gt_label']),
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='Resize', size=224),
|
dict(type='Resize', scale=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
dict(type='Collect', keys=['img']),
|
dict(type='Collect', keys=['img']),
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -194,7 +194,7 @@ train_pipeline = [
|
||||||
# 测试数据流水线
|
# 测试数据流水线
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=(256, -1)),
|
dict(type='Resize', scale=(256, -1), keep_ratio=True),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
@ -310,7 +310,7 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=384, backend='pillow'),
|
dict(type='Resize', scale=384, backend='pillow'),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
dict(type='Collect', keys=['img'])
|
dict(type='Collect', keys=['img'])
|
||||||
|
|
|
@ -27,7 +27,7 @@ train_pipeline = [
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='LoadImageFromFile'),
|
dict(type='LoadImageFromFile'),
|
||||||
dict(type='Resize', size=256),
|
dict(type='Resize', scale=256),
|
||||||
dict(type='CenterCrop', crop_size=224),
|
dict(type='CenterCrop', crop_size=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
|
|
@ -102,14 +102,14 @@ img_norm_cfg = dict(
|
||||||
train_pipeline = [
|
train_pipeline = [
|
||||||
dict(type='RandomCrop', size=32, padding=4),
|
dict(type='RandomCrop', size=32, padding=4),
|
||||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||||
dict(type='Resize', size=224),
|
dict(type='Resize', scale=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
dict(type='ToTensor', keys=['gt_label']),
|
dict(type='ToTensor', keys=['gt_label']),
|
||||||
dict(type='Collect', keys=['img', 'gt_label']),
|
dict(type='Collect', keys=['img', 'gt_label']),
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='Resize', size=224),
|
dict(type='Resize', scale=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
dict(type='Collect', keys=['img']),
|
dict(type='Collect', keys=['img']),
|
||||||
|
@ -166,14 +166,14 @@ img_norm_cfg = dict(
|
||||||
train_pipeline = [
|
train_pipeline = [
|
||||||
dict(type='RandomCrop', size=32, padding=4),
|
dict(type='RandomCrop', size=32, padding=4),
|
||||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||||
dict(type='Resize', size=224),
|
dict(type='Resize', scale=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
dict(type='ToTensor', keys=['gt_label']),
|
dict(type='ToTensor', keys=['gt_label']),
|
||||||
dict(type='Collect', keys=['img', 'gt_label']),
|
dict(type='Collect', keys=['img', 'gt_label']),
|
||||||
]
|
]
|
||||||
test_pipeline = [
|
test_pipeline = [
|
||||||
dict(type='Resize', size=224),
|
dict(type='Resize', scale=224),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
dict(type='ImageToTensor', keys=['img']),
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
dict(type='Collect', keys=['img']),
|
dict(type='Collect', keys=['img']),
|
||||||
|
|
|
@ -4,19 +4,18 @@ from .auto_augment import (AutoAugment, AutoContrast, Brightness,
|
||||||
Posterize, RandAugment, Rotate, Sharpness, Shear,
|
Posterize, RandAugment, Rotate, Sharpness, Shear,
|
||||||
Solarize, SolarizeAdd, Translate)
|
Solarize, SolarizeAdd, Translate)
|
||||||
from .compose import Compose
|
from .compose import Compose
|
||||||
from .formatting import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
|
from .formatting import (Collect, ImageToTensor, PackClsInputs, ToNumpy, ToPIL,
|
||||||
Transpose, to_tensor)
|
ToTensor, Transpose, to_tensor)
|
||||||
from .loading import LoadImageFromFile
|
|
||||||
from .transforms import (CenterCrop, ColorJitter, Lighting, Normalize, Pad,
|
from .transforms import (CenterCrop, ColorJitter, Lighting, Normalize, Pad,
|
||||||
RandomCrop, RandomErasing, RandomFlip,
|
RandomCrop, RandomErasing, RandomGrayscale,
|
||||||
RandomGrayscale, RandomResizedCrop, Resize)
|
RandomResizedCrop)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
|
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
|
||||||
'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
|
'Transpose', 'Collect', 'CenterCrop', 'Normalize', 'RandomCrop',
|
||||||
'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
|
'RandomResizedCrop', 'RandomGrayscale', 'Shear', 'Translate', 'Rotate',
|
||||||
'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert',
|
'Invert', 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast',
|
||||||
'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
|
'Equalize', 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment',
|
||||||
'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
|
'SolarizeAdd', 'Cutout', 'RandAugment', 'Lighting', 'ColorJitter',
|
||||||
'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', 'Pad'
|
'RandomErasing', 'Pad', 'PackClsInputs'
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,12 +1,15 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import warnings
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from mmcv.parallel import DataContainer as DC
|
from mmcv.parallel import DataContainer as DC
|
||||||
|
from mmcv.transforms.base import BaseTransform
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from mmcls.core import ClsDataSample
|
||||||
from mmcls.registry import TRANSFORMS
|
from mmcls.registry import TRANSFORMS
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,8 +36,87 @@ def to_tensor(data):
|
||||||
'`Sequence`, `int` and `float`')
|
'`Sequence`, `int` and `float`')
|
||||||
|
|
||||||
|
|
||||||
|
@TRANSFORMS.register_module()
|
||||||
|
class PackClsInputs(BaseTransform):
|
||||||
|
"""Pack the inputs data for the classification.
|
||||||
|
|
||||||
|
The ``img_meta`` item is always populated. The contents of the
|
||||||
|
``img_meta`` dictionary depends on ``meta_keys``. By default this includes:
|
||||||
|
|
||||||
|
- ``sample_idx``: id of the image sample
|
||||||
|
|
||||||
|
- ``img_path``: path to the image file
|
||||||
|
|
||||||
|
- ``ori_shape``: original shape of the image as a tuple (H, W).
|
||||||
|
|
||||||
|
- ``img_shape``: shape of the image input to the network as a tuple
|
||||||
|
(H, W). Note that images may be zero padded on the bottom/right
|
||||||
|
if the batch tensor is larger than this shape.
|
||||||
|
|
||||||
|
- ``scale_factor``: a float indicating the preprocessing scale
|
||||||
|
|
||||||
|
- ``flip``: a boolean indicating if image flip transform was used
|
||||||
|
|
||||||
|
- ``flip_direction``: the flipping direction
|
||||||
|
|
||||||
|
Args:
|
||||||
|
meta_keys (Sequence[str], optional): The meta keys to saved in the
|
||||||
|
``metainfo`` of the packed ``data_sample``.
|
||||||
|
Default: ``('sample_idx', 'img_path', 'ori_shape', 'img_shape',
|
||||||
|
'scale_factor', 'flip', 'flip_direction')``
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
meta_keys=('sample_idx', 'img_path', 'ori_shape', 'img_shape',
|
||||||
|
'scale_factor', 'flip', 'flip_direction')):
|
||||||
|
self.meta_keys = meta_keys
|
||||||
|
|
||||||
|
def transform(self, results: dict) -> dict:
|
||||||
|
"""Method to pack the input data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results (dict): Result dict from the data pipeline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict:
|
||||||
|
- 'inputs' (obj:`torch.Tensor`): The forward data of models.
|
||||||
|
- 'data_sample' (obj:`ClsDataSample`): The annotation info of the
|
||||||
|
sample.
|
||||||
|
"""
|
||||||
|
packed_results = dict()
|
||||||
|
if 'img' in results:
|
||||||
|
img = results['img']
|
||||||
|
if len(img.shape) < 3:
|
||||||
|
img = np.expand_dims(img, -1)
|
||||||
|
img = np.ascontiguousarray(img.transpose(2, 0, 1))
|
||||||
|
packed_results['inputs'] = to_tensor(img)
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
'Cannot get "img" in the input dict of `PackClsInputs`,'
|
||||||
|
'please make sure `LoadImageFromFile` has been added '
|
||||||
|
'in the data pipeline or images have been loaded in '
|
||||||
|
'the dataset.')
|
||||||
|
|
||||||
|
data_sample = ClsDataSample()
|
||||||
|
if 'gt_label' in results:
|
||||||
|
gt_label = results['gt_label']
|
||||||
|
data_sample.set_gt_label(gt_label)
|
||||||
|
|
||||||
|
img_meta = {k: results[k] for k in self.meta_keys if k in results}
|
||||||
|
data_sample.set_metainfo(img_meta)
|
||||||
|
packed_results['data_sample'] = data_sample
|
||||||
|
|
||||||
|
return packed_results
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
repr_str = self.__class__.__name__
|
||||||
|
repr_str += f'(meta_keys={self.meta_keys})'
|
||||||
|
return repr_str
|
||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class ToTensor(object):
|
class ToTensor(object):
|
||||||
|
"""Convert objects of various python types to :obj:`torch.Tensor`."""
|
||||||
|
|
||||||
def __init__(self, keys):
|
def __init__(self, keys):
|
||||||
self.keys = keys
|
self.keys = keys
|
||||||
|
@ -50,6 +132,7 @@ class ToTensor(object):
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class ImageToTensor(object):
|
class ImageToTensor(object):
|
||||||
|
"""Convert objects :obj:`PIL.Image` to :obj:`torch.Tensor`."""
|
||||||
|
|
||||||
def __init__(self, keys):
|
def __init__(self, keys):
|
||||||
self.keys = keys
|
self.keys = keys
|
||||||
|
@ -68,6 +151,7 @@ class ImageToTensor(object):
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class Transpose(object):
|
class Transpose(object):
|
||||||
|
"""matrix transpose."""
|
||||||
|
|
||||||
def __init__(self, keys, order):
|
def __init__(self, keys, order):
|
||||||
self.keys = keys
|
self.keys = keys
|
||||||
|
@ -85,6 +169,7 @@ class Transpose(object):
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class ToPIL(object):
|
class ToPIL(object):
|
||||||
|
"""Convert tensor to :obj:`PIL.Image`."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
@ -96,6 +181,7 @@ class ToPIL(object):
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class ToNumpy(object):
|
class ToNumpy(object):
|
||||||
|
"""Convert tensor to :obj:`np.ndarray`."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1,70 +0,0 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import os.path as osp
|
|
||||||
|
|
||||||
import mmcv
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from mmcls.registry import TRANSFORMS
|
|
||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
|
||||||
class LoadImageFromFile(object):
|
|
||||||
"""Load an image from file.
|
|
||||||
|
|
||||||
Required keys are "img_prefix" and "img_info" (a dict that must contain the
|
|
||||||
key "filename"). Added or updated keys are "filename", "img", "img_shape",
|
|
||||||
"ori_shape" (same as `img_shape`) and "img_norm_cfg" (means=0 and stds=1).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
|
||||||
numpy array. If set to False, the loaded image is an uint8 array.
|
|
||||||
Defaults to False.
|
|
||||||
color_type (str): The flag argument for :func:`mmcv.imfrombytes()`.
|
|
||||||
Defaults to 'color'.
|
|
||||||
file_client_args (dict): Arguments to instantiate a FileClient.
|
|
||||||
See :class:`mmcv.fileio.FileClient` for details.
|
|
||||||
Defaults to ``dict(backend='disk')``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
to_float32=False,
|
|
||||||
color_type='color',
|
|
||||||
file_client_args=dict(backend='disk')):
|
|
||||||
self.to_float32 = to_float32
|
|
||||||
self.color_type = color_type
|
|
||||||
self.file_client_args = file_client_args.copy()
|
|
||||||
self.file_client = None
|
|
||||||
|
|
||||||
def __call__(self, results):
|
|
||||||
if self.file_client is None:
|
|
||||||
self.file_client = mmcv.FileClient(**self.file_client_args)
|
|
||||||
|
|
||||||
if results['img_prefix'] is not None:
|
|
||||||
filename = osp.join(results['img_prefix'],
|
|
||||||
results['img_info']['filename'])
|
|
||||||
else:
|
|
||||||
filename = results['img_info']['filename']
|
|
||||||
|
|
||||||
img_bytes = self.file_client.get(filename)
|
|
||||||
img = mmcv.imfrombytes(img_bytes, flag=self.color_type)
|
|
||||||
if self.to_float32:
|
|
||||||
img = img.astype(np.float32)
|
|
||||||
|
|
||||||
results['filename'] = filename
|
|
||||||
results['ori_filename'] = results['img_info']['filename']
|
|
||||||
results['img'] = img
|
|
||||||
results['img_shape'] = img.shape
|
|
||||||
results['ori_shape'] = img.shape
|
|
||||||
num_channels = 1 if len(img.shape) < 3 else img.shape[2]
|
|
||||||
results['img_norm_cfg'] = dict(
|
|
||||||
mean=np.zeros(num_channels, dtype=np.float32),
|
|
||||||
std=np.ones(num_channels, dtype=np.float32),
|
|
||||||
to_rgb=False)
|
|
||||||
return results
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
repr_str = (f'{self.__class__.__name__}('
|
|
||||||
f'to_float32={self.to_float32}, '
|
|
||||||
f"color_type='{self.color_type}', "
|
|
||||||
f'file_client_args={self.file_client_args})')
|
|
||||||
return repr_str
|
|
|
@ -430,48 +430,6 @@ class RandomGrayscale(object):
|
||||||
return self.__class__.__name__ + f'(gray_prob={self.gray_prob})'
|
return self.__class__.__name__ + f'(gray_prob={self.gray_prob})'
|
||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
|
||||||
class RandomFlip(object):
|
|
||||||
"""Flip the image randomly.
|
|
||||||
|
|
||||||
Flip the image randomly based on flip probaility and flip direction.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
flip_prob (float): probability of the image being flipped. Default: 0.5
|
|
||||||
direction (str): The flipping direction. Options are
|
|
||||||
'horizontal' and 'vertical'. Default: 'horizontal'.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, flip_prob=0.5, direction='horizontal'):
|
|
||||||
assert 0 <= flip_prob <= 1
|
|
||||||
assert direction in ['horizontal', 'vertical']
|
|
||||||
self.flip_prob = flip_prob
|
|
||||||
self.direction = direction
|
|
||||||
|
|
||||||
def __call__(self, results):
|
|
||||||
"""Call function to flip image.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
results (dict): Result dict from loading pipeline.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Flipped results, 'flip', 'flip_direction' keys are added into
|
|
||||||
result dict.
|
|
||||||
"""
|
|
||||||
flip = True if np.random.rand() < self.flip_prob else False
|
|
||||||
results['flip'] = flip
|
|
||||||
results['flip_direction'] = self.direction
|
|
||||||
if results['flip']:
|
|
||||||
# flip image
|
|
||||||
for key in results.get('img_fields', ['img']):
|
|
||||||
results[key] = mmcv.imflip(
|
|
||||||
results[key], direction=results['flip_direction'])
|
|
||||||
return results
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return self.__class__.__name__ + f'(flip_prob={self.flip_prob})'
|
|
||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class RandomErasing(object):
|
class RandomErasing(object):
|
||||||
"""Randomly selects a rectangle region in an image and erase pixels.
|
"""Randomly selects a rectangle region in an image and erase pixels.
|
||||||
|
@ -664,111 +622,6 @@ class Pad(object):
|
||||||
return repr_str
|
return repr_str
|
||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
|
||||||
class Resize(object):
|
|
||||||
"""Resize images.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
size (int | tuple): Images scales for resizing (h, w).
|
|
||||||
When size is int, the default behavior is to resize an image
|
|
||||||
to (size, size). When size is tuple and the second value is -1,
|
|
||||||
the image will be resized according to adaptive_side. For example,
|
|
||||||
when size is 224, the image is resized to 224x224. When size is
|
|
||||||
(224, -1) and adaptive_size is "short", the short side is resized
|
|
||||||
to 224 and the other side is computed based on the short side,
|
|
||||||
maintaining the aspect ratio.
|
|
||||||
interpolation (str): Interpolation method. For "cv2" backend, accepted
|
|
||||||
values are "nearest", "bilinear", "bicubic", "area", "lanczos". For
|
|
||||||
"pillow" backend, accepted values are "nearest", "bilinear",
|
|
||||||
"bicubic", "box", "lanczos", "hamming".
|
|
||||||
More details can be found in `mmcv.image.geometric`.
|
|
||||||
adaptive_side(str): Adaptive resize policy, accepted values are
|
|
||||||
"short", "long", "height", "width". Default to "short".
|
|
||||||
backend (str): The image resize backend type, accepted values are
|
|
||||||
`cv2` and `pillow`. Default: `cv2`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
size,
|
|
||||||
interpolation='bilinear',
|
|
||||||
adaptive_side='short',
|
|
||||||
backend='cv2'):
|
|
||||||
assert isinstance(size, int) or (isinstance(size, tuple)
|
|
||||||
and len(size) == 2)
|
|
||||||
assert adaptive_side in {'short', 'long', 'height', 'width'}
|
|
||||||
|
|
||||||
self.adaptive_side = adaptive_side
|
|
||||||
self.adaptive_resize = False
|
|
||||||
if isinstance(size, int):
|
|
||||||
assert size > 0
|
|
||||||
size = (size, size)
|
|
||||||
else:
|
|
||||||
assert size[0] > 0 and (size[1] > 0 or size[1] == -1)
|
|
||||||
if size[1] == -1:
|
|
||||||
self.adaptive_resize = True
|
|
||||||
if backend not in ['cv2', 'pillow']:
|
|
||||||
raise ValueError(f'backend: {backend} is not supported for resize.'
|
|
||||||
'Supported backends are "cv2", "pillow"')
|
|
||||||
if backend == 'cv2':
|
|
||||||
assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
|
|
||||||
'lanczos')
|
|
||||||
else:
|
|
||||||
assert interpolation in ('nearest', 'bilinear', 'bicubic', 'box',
|
|
||||||
'lanczos', 'hamming')
|
|
||||||
self.size = size
|
|
||||||
self.interpolation = interpolation
|
|
||||||
self.backend = backend
|
|
||||||
|
|
||||||
def _resize_img(self, results):
|
|
||||||
for key in results.get('img_fields', ['img']):
|
|
||||||
img = results[key]
|
|
||||||
ignore_resize = False
|
|
||||||
if self.adaptive_resize:
|
|
||||||
h, w = img.shape[:2]
|
|
||||||
target_size = self.size[0]
|
|
||||||
|
|
||||||
condition_ignore_resize = {
|
|
||||||
'short': min(h, w) == target_size,
|
|
||||||
'long': max(h, w) == target_size,
|
|
||||||
'height': h == target_size,
|
|
||||||
'width': w == target_size
|
|
||||||
}
|
|
||||||
|
|
||||||
if condition_ignore_resize[self.adaptive_side]:
|
|
||||||
ignore_resize = True
|
|
||||||
elif any([
|
|
||||||
self.adaptive_side == 'short' and w < h,
|
|
||||||
self.adaptive_side == 'long' and w > h,
|
|
||||||
self.adaptive_side == 'width',
|
|
||||||
]):
|
|
||||||
width = target_size
|
|
||||||
height = int(target_size * h / w)
|
|
||||||
else:
|
|
||||||
height = target_size
|
|
||||||
width = int(target_size * w / h)
|
|
||||||
else:
|
|
||||||
height, width = self.size
|
|
||||||
if not ignore_resize:
|
|
||||||
img = mmcv.imresize(
|
|
||||||
img,
|
|
||||||
size=(width, height),
|
|
||||||
interpolation=self.interpolation,
|
|
||||||
return_scale=False,
|
|
||||||
backend=self.backend)
|
|
||||||
results[key] = img
|
|
||||||
results['img_shape'] = img.shape
|
|
||||||
|
|
||||||
def __call__(self, results):
|
|
||||||
self._resize_img(results)
|
|
||||||
return results
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
repr_str = self.__class__.__name__
|
|
||||||
repr_str += f'(size={self.size}, '
|
|
||||||
repr_str += f'interpolation={self.interpolation})'
|
|
||||||
return repr_str
|
|
||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class CenterCrop(object):
|
class CenterCrop(object):
|
||||||
r"""Center crop the image.
|
r"""Center crop the image.
|
||||||
|
|
|
@ -1,272 +0,0 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import os.path as osp
|
|
||||||
from copy import deepcopy
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from mmcv.utils import digit_version
|
|
||||||
|
|
||||||
from mmcls.datasets import ImageNet, build_dataloader, build_dataset
|
|
||||||
from mmcls.datasets.dataset_wrappers import (ClassBalancedDataset,
|
|
||||||
ConcatDataset, KFoldDataset,
|
|
||||||
RepeatDataset)
|
|
||||||
|
|
||||||
|
|
||||||
class TestDataloaderBuilder():
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setup_class(cls):
|
|
||||||
cls.data = list(range(20))
|
|
||||||
cls.samples_per_gpu = 5
|
|
||||||
cls.workers_per_gpu = 1
|
|
||||||
|
|
||||||
@patch('mmcls.datasets.builder.get_dist_info', return_value=(0, 1))
|
|
||||||
def test_single_gpu(self, _):
|
|
||||||
common_cfg = dict(
|
|
||||||
dataset=self.data,
|
|
||||||
samples_per_gpu=self.samples_per_gpu,
|
|
||||||
workers_per_gpu=self.workers_per_gpu,
|
|
||||||
dist=False)
|
|
||||||
|
|
||||||
# Test default config
|
|
||||||
dataloader = build_dataloader(**common_cfg)
|
|
||||||
|
|
||||||
if digit_version(torch.__version__) >= digit_version('1.8.0'):
|
|
||||||
assert dataloader.persistent_workers
|
|
||||||
elif hasattr(dataloader, 'persistent_workers'):
|
|
||||||
assert not dataloader.persistent_workers
|
|
||||||
|
|
||||||
assert dataloader.batch_size == self.samples_per_gpu
|
|
||||||
assert dataloader.num_workers == self.workers_per_gpu
|
|
||||||
assert not all(
|
|
||||||
torch.cat(list(iter(dataloader))) == torch.tensor(self.data))
|
|
||||||
|
|
||||||
# Test without shuffle
|
|
||||||
dataloader = build_dataloader(**common_cfg, shuffle=False)
|
|
||||||
assert all(
|
|
||||||
torch.cat(list(iter(dataloader))) == torch.tensor(self.data))
|
|
||||||
|
|
||||||
# Test with custom sampler_cfg
|
|
||||||
dataloader = build_dataloader(
|
|
||||||
**common_cfg,
|
|
||||||
sampler_cfg=dict(type='RepeatAugSampler', selected_round=0),
|
|
||||||
shuffle=False)
|
|
||||||
expect = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6]
|
|
||||||
assert all(torch.cat(list(iter(dataloader))) == torch.tensor(expect))
|
|
||||||
|
|
||||||
@patch('mmcls.datasets.builder.get_dist_info', return_value=(0, 1))
|
|
||||||
def test_multi_gpu(self, _):
|
|
||||||
common_cfg = dict(
|
|
||||||
dataset=self.data,
|
|
||||||
samples_per_gpu=self.samples_per_gpu,
|
|
||||||
workers_per_gpu=self.workers_per_gpu,
|
|
||||||
num_gpus=2,
|
|
||||||
dist=False)
|
|
||||||
|
|
||||||
# Test default config
|
|
||||||
dataloader = build_dataloader(**common_cfg)
|
|
||||||
|
|
||||||
if digit_version(torch.__version__) >= digit_version('1.8.0'):
|
|
||||||
assert dataloader.persistent_workers
|
|
||||||
elif hasattr(dataloader, 'persistent_workers'):
|
|
||||||
assert not dataloader.persistent_workers
|
|
||||||
|
|
||||||
assert dataloader.batch_size == self.samples_per_gpu * 2
|
|
||||||
assert dataloader.num_workers == self.workers_per_gpu * 2
|
|
||||||
assert not all(
|
|
||||||
torch.cat(list(iter(dataloader))) == torch.tensor(self.data))
|
|
||||||
|
|
||||||
# Test without shuffle
|
|
||||||
dataloader = build_dataloader(**common_cfg, shuffle=False)
|
|
||||||
assert all(
|
|
||||||
torch.cat(list(iter(dataloader))) == torch.tensor(self.data))
|
|
||||||
|
|
||||||
# Test with custom sampler_cfg
|
|
||||||
dataloader = build_dataloader(
|
|
||||||
**common_cfg,
|
|
||||||
sampler_cfg=dict(type='RepeatAugSampler', selected_round=0),
|
|
||||||
shuffle=False)
|
|
||||||
expect = torch.tensor(
|
|
||||||
[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6])
|
|
||||||
assert all(torch.cat(list(iter(dataloader))) == expect)
|
|
||||||
|
|
||||||
@patch('mmcls.datasets.builder.get_dist_info', return_value=(1, 2))
|
|
||||||
def test_distributed(self, _):
|
|
||||||
common_cfg = dict(
|
|
||||||
dataset=self.data,
|
|
||||||
samples_per_gpu=self.samples_per_gpu,
|
|
||||||
workers_per_gpu=self.workers_per_gpu,
|
|
||||||
num_gpus=2, # num_gpus will be ignored in distributed environment.
|
|
||||||
dist=True)
|
|
||||||
|
|
||||||
# Test default config
|
|
||||||
dataloader = build_dataloader(**common_cfg)
|
|
||||||
|
|
||||||
if digit_version(torch.__version__) >= digit_version('1.8.0'):
|
|
||||||
assert dataloader.persistent_workers
|
|
||||||
elif hasattr(dataloader, 'persistent_workers'):
|
|
||||||
assert not dataloader.persistent_workers
|
|
||||||
|
|
||||||
assert dataloader.batch_size == self.samples_per_gpu
|
|
||||||
assert dataloader.num_workers == self.workers_per_gpu
|
|
||||||
non_expect = torch.tensor(self.data[1::2])
|
|
||||||
assert not all(torch.cat(list(iter(dataloader))) == non_expect)
|
|
||||||
|
|
||||||
# Test without shuffle
|
|
||||||
dataloader = build_dataloader(**common_cfg, shuffle=False)
|
|
||||||
expect = torch.tensor(self.data[1::2])
|
|
||||||
assert all(torch.cat(list(iter(dataloader))) == expect)
|
|
||||||
|
|
||||||
# Test with custom sampler_cfg
|
|
||||||
dataloader = build_dataloader(
|
|
||||||
**common_cfg,
|
|
||||||
sampler_cfg=dict(type='RepeatAugSampler', selected_round=0),
|
|
||||||
shuffle=False)
|
|
||||||
expect = torch.tensor(
|
|
||||||
[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6][1::2])
|
|
||||||
assert all(torch.cat(list(iter(dataloader))) == expect)
|
|
||||||
|
|
||||||
|
|
||||||
class TestDatasetBuilder():
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setup_class(cls):
|
|
||||||
data_prefix = osp.join(osp.dirname(__file__), '../data/dataset')
|
|
||||||
cls.dataset_cfg = dict(
|
|
||||||
type='ImageNet',
|
|
||||||
data_prefix=data_prefix,
|
|
||||||
ann_file=osp.join(data_prefix, 'ann.txt'),
|
|
||||||
pipeline=[],
|
|
||||||
test_mode=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_normal_dataset(self):
|
|
||||||
# Test build
|
|
||||||
dataset = build_dataset(self.dataset_cfg)
|
|
||||||
assert isinstance(dataset, ImageNet)
|
|
||||||
assert dataset.test_mode == self.dataset_cfg['test_mode']
|
|
||||||
|
|
||||||
# Test default_args
|
|
||||||
dataset = build_dataset(self.dataset_cfg, {'test_mode': True})
|
|
||||||
assert dataset.test_mode == self.dataset_cfg['test_mode']
|
|
||||||
|
|
||||||
cp_cfg = deepcopy(self.dataset_cfg)
|
|
||||||
cp_cfg.pop('test_mode')
|
|
||||||
dataset = build_dataset(cp_cfg, {'test_mode': True})
|
|
||||||
assert dataset.test_mode
|
|
||||||
|
|
||||||
def test_concat_dataset(self):
|
|
||||||
# Test build
|
|
||||||
dataset = build_dataset([self.dataset_cfg, self.dataset_cfg])
|
|
||||||
assert isinstance(dataset, ConcatDataset)
|
|
||||||
assert dataset.datasets[0].test_mode == self.dataset_cfg['test_mode']
|
|
||||||
|
|
||||||
# Test default_args
|
|
||||||
dataset = build_dataset([self.dataset_cfg, self.dataset_cfg],
|
|
||||||
{'test_mode': True})
|
|
||||||
assert dataset.datasets[0].test_mode == self.dataset_cfg['test_mode']
|
|
||||||
|
|
||||||
cp_cfg = deepcopy(self.dataset_cfg)
|
|
||||||
cp_cfg.pop('test_mode')
|
|
||||||
dataset = build_dataset([cp_cfg, cp_cfg], {'test_mode': True})
|
|
||||||
assert dataset.datasets[0].test_mode
|
|
||||||
|
|
||||||
def test_repeat_dataset(self):
|
|
||||||
# Test build
|
|
||||||
dataset = build_dataset(
|
|
||||||
dict(type='RepeatDataset', dataset=self.dataset_cfg, times=3))
|
|
||||||
assert isinstance(dataset, RepeatDataset)
|
|
||||||
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']
|
|
||||||
|
|
||||||
# Test default_args
|
|
||||||
dataset = build_dataset(
|
|
||||||
dict(type='RepeatDataset', dataset=self.dataset_cfg, times=3),
|
|
||||||
{'test_mode': True})
|
|
||||||
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']
|
|
||||||
|
|
||||||
cp_cfg = deepcopy(self.dataset_cfg)
|
|
||||||
cp_cfg.pop('test_mode')
|
|
||||||
dataset = build_dataset(
|
|
||||||
dict(type='RepeatDataset', dataset=cp_cfg, times=3),
|
|
||||||
{'test_mode': True})
|
|
||||||
assert dataset.dataset.test_mode
|
|
||||||
|
|
||||||
def test_class_balance_dataset(self):
|
|
||||||
# Test build
|
|
||||||
dataset = build_dataset(
|
|
||||||
dict(
|
|
||||||
type='ClassBalancedDataset',
|
|
||||||
dataset=self.dataset_cfg,
|
|
||||||
oversample_thr=1.,
|
|
||||||
))
|
|
||||||
assert isinstance(dataset, ClassBalancedDataset)
|
|
||||||
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']
|
|
||||||
|
|
||||||
# Test default_args
|
|
||||||
dataset = build_dataset(
|
|
||||||
dict(
|
|
||||||
type='ClassBalancedDataset',
|
|
||||||
dataset=self.dataset_cfg,
|
|
||||||
oversample_thr=1.,
|
|
||||||
), {'test_mode': True})
|
|
||||||
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']
|
|
||||||
|
|
||||||
cp_cfg = deepcopy(self.dataset_cfg)
|
|
||||||
cp_cfg.pop('test_mode')
|
|
||||||
dataset = build_dataset(
|
|
||||||
dict(
|
|
||||||
type='ClassBalancedDataset',
|
|
||||||
dataset=cp_cfg,
|
|
||||||
oversample_thr=1.,
|
|
||||||
), {'test_mode': True})
|
|
||||||
assert dataset.dataset.test_mode
|
|
||||||
|
|
||||||
def test_kfold_dataset(self):
|
|
||||||
# Test build
|
|
||||||
dataset = build_dataset(
|
|
||||||
dict(
|
|
||||||
type='KFoldDataset',
|
|
||||||
dataset=self.dataset_cfg,
|
|
||||||
fold=0,
|
|
||||||
num_splits=5,
|
|
||||||
test_mode=False,
|
|
||||||
))
|
|
||||||
assert isinstance(dataset, KFoldDataset)
|
|
||||||
assert not dataset.test_mode
|
|
||||||
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']
|
|
||||||
|
|
||||||
# Test default_args
|
|
||||||
dataset = build_dataset(
|
|
||||||
dict(
|
|
||||||
type='KFoldDataset',
|
|
||||||
dataset=self.dataset_cfg,
|
|
||||||
fold=0,
|
|
||||||
num_splits=5,
|
|
||||||
test_mode=False,
|
|
||||||
),
|
|
||||||
default_args={
|
|
||||||
'test_mode': True,
|
|
||||||
'classes': [1, 2, 3]
|
|
||||||
})
|
|
||||||
assert not dataset.test_mode
|
|
||||||
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']
|
|
||||||
assert dataset.dataset.CLASSES == [1, 2, 3]
|
|
||||||
|
|
||||||
cp_cfg = deepcopy(self.dataset_cfg)
|
|
||||||
cp_cfg.pop('test_mode')
|
|
||||||
dataset = build_dataset(
|
|
||||||
dict(
|
|
||||||
type='KFoldDataset',
|
|
||||||
dataset=self.dataset_cfg,
|
|
||||||
fold=0,
|
|
||||||
num_splits=5,
|
|
||||||
),
|
|
||||||
default_args={
|
|
||||||
'test_mode': True,
|
|
||||||
'classes': [1, 2, 3]
|
|
||||||
})
|
|
||||||
# The test_mode in default_args will be passed to KFoldDataset
|
|
||||||
assert dataset.test_mode
|
|
||||||
assert not dataset.dataset.test_mode
|
|
||||||
# Other default_args will be passed to child dataset.
|
|
||||||
assert dataset.dataset.CLASSES == [1, 2, 3]
|
|
|
@ -1,192 +0,0 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import bisect
|
|
||||||
import math
|
|
||||||
from collections import defaultdict
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from mmcls.datasets import (BaseDataset, ClassBalancedDataset, ConcatDataset,
|
|
||||||
KFoldDataset, RepeatDataset)
|
|
||||||
|
|
||||||
|
|
||||||
def mock_evaluate(results,
|
|
||||||
metric='accuracy',
|
|
||||||
metric_options=None,
|
|
||||||
indices=None,
|
|
||||||
logger=None):
|
|
||||||
return dict(
|
|
||||||
results=results,
|
|
||||||
metric=metric,
|
|
||||||
metric_options=metric_options,
|
|
||||||
indices=indices,
|
|
||||||
logger=logger)
|
|
||||||
|
|
||||||
|
|
||||||
@patch.multiple(BaseDataset, __abstractmethods__=set())
|
|
||||||
def construct_toy_multi_label_dataset(length):
|
|
||||||
BaseDataset.CLASSES = ('foo', 'bar')
|
|
||||||
BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
|
|
||||||
dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
|
|
||||||
cat_ids_list = [
|
|
||||||
np.random.randint(0, 80, num).tolist()
|
|
||||||
for num in np.random.randint(1, 20, length)
|
|
||||||
]
|
|
||||||
dataset.data_infos = MagicMock()
|
|
||||||
dataset.data_infos.__len__.return_value = length
|
|
||||||
dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])
|
|
||||||
dataset.get_gt_labels = \
|
|
||||||
MagicMock(side_effect=lambda: np.array(cat_ids_list))
|
|
||||||
dataset.evaluate = MagicMock(side_effect=mock_evaluate)
|
|
||||||
return dataset, cat_ids_list
|
|
||||||
|
|
||||||
|
|
||||||
@patch.multiple(BaseDataset, __abstractmethods__=set())
|
|
||||||
def construct_toy_single_label_dataset(length):
|
|
||||||
BaseDataset.CLASSES = ('foo', 'bar')
|
|
||||||
BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
|
|
||||||
dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
|
|
||||||
cat_ids_list = [[np.random.randint(0, 80)] for _ in range(length)]
|
|
||||||
dataset.data_infos = MagicMock()
|
|
||||||
dataset.data_infos.__len__.return_value = length
|
|
||||||
dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])
|
|
||||||
dataset.get_gt_labels = \
|
|
||||||
MagicMock(side_effect=lambda: np.array(cat_ids_list))
|
|
||||||
dataset.evaluate = MagicMock(side_effect=mock_evaluate)
|
|
||||||
return dataset, cat_ids_list
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('construct_dataset', [
|
|
||||||
'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
|
|
||||||
])
|
|
||||||
def test_concat_dataset(construct_dataset):
|
|
||||||
construct_toy_dataset = eval(construct_dataset)
|
|
||||||
dataset_a, cat_ids_list_a = construct_toy_dataset(10)
|
|
||||||
dataset_b, cat_ids_list_b = construct_toy_dataset(20)
|
|
||||||
|
|
||||||
concat_dataset = ConcatDataset([dataset_a, dataset_b])
|
|
||||||
assert concat_dataset[5] == 5
|
|
||||||
assert concat_dataset[25] == 15
|
|
||||||
assert concat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
|
|
||||||
assert concat_dataset.get_cat_ids(25) == cat_ids_list_b[15]
|
|
||||||
assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
|
|
||||||
assert concat_dataset.CLASSES == BaseDataset.CLASSES
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('construct_dataset', [
|
|
||||||
'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
|
|
||||||
])
|
|
||||||
def test_repeat_dataset(construct_dataset):
|
|
||||||
construct_toy_dataset = eval(construct_dataset)
|
|
||||||
dataset, cat_ids_list = construct_toy_dataset(10)
|
|
||||||
repeat_dataset = RepeatDataset(dataset, 10)
|
|
||||||
assert repeat_dataset[5] == 5
|
|
||||||
assert repeat_dataset[15] == 5
|
|
||||||
assert repeat_dataset[27] == 7
|
|
||||||
assert repeat_dataset.get_cat_ids(5) == cat_ids_list[5]
|
|
||||||
assert repeat_dataset.get_cat_ids(15) == cat_ids_list[5]
|
|
||||||
assert repeat_dataset.get_cat_ids(27) == cat_ids_list[7]
|
|
||||||
assert len(repeat_dataset) == 10 * len(dataset)
|
|
||||||
assert repeat_dataset.CLASSES == BaseDataset.CLASSES
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('construct_dataset', [
    'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_class_balanced_dataset(construct_dataset):
    """Re-derive the expected repeat factors and compare with the wrapper.

    The test recomputes per-category frequencies and repeat factors by hand
    and checks that ``ClassBalancedDataset`` produces a matching length and
    index mapping.
    """
    # Resolve the factory function by name; avoids eval() on a string.
    construct_toy_dataset = globals()[construct_dataset]
    dataset, cat_ids_list = construct_toy_dataset(10)

    # Empirical frequency of each category over the toy dataset.
    category_freq = defaultdict(int)
    for cat_ids in cat_ids_list:
        cat_ids = set(cat_ids)
        for cat_id in cat_ids:
            category_freq[cat_id] += 1
    for k, v in category_freq.items():
        category_freq[k] = v / len(cat_ids_list)

    mean_freq = np.mean(list(category_freq.values()))
    repeat_thr = mean_freq

    # Per-category repeat factor; categories rarer than the threshold get
    # repeated more (sqrt(thr / freq)), others stay at 1.0.
    category_repeat = {
        cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
        for cat_id, cat_freq in category_freq.items()
    }

    # Per-sample repeat factor is the max over the sample's categories.
    repeat_factors = []
    for cat_ids in cat_ids_list:
        cat_ids = set(cat_ids)
        repeat_factor = max({category_repeat[cat_id] for cat_id in cat_ids})
        repeat_factors.append(math.ceil(repeat_factor))
    repeat_factors_cumsum = np.cumsum(repeat_factors)
    repeat_factor_dataset = ClassBalancedDataset(dataset, repeat_thr)
    assert repeat_factor_dataset.CLASSES == BaseDataset.CLASSES
    assert len(repeat_factor_dataset) == repeat_factors_cumsum[-1]
    # Spot-check a few random indices against the cumulative-sum mapping.
    for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
        assert repeat_factor_dataset[idx] == bisect.bisect_right(
            repeat_factors_cumsum, idx)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('construct_dataset', [
    'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_kfold_dataset(construct_dataset):
    """``KFoldDataset``: fold partitioning, lookup proxying and evaluate."""
    # Resolve the factory function by name; avoids eval() on a string.
    construct_toy_dataset = globals()[construct_dataset]
    dataset, cat_ids_list = construct_toy_dataset(10)

    # test without random seed
    # NOTE(review): folds 3 and 4 exceed num_splits=3 — presumably this
    # exercises out-of-range folds on purpose; confirm against KFoldDataset.
    train_datasets = [
        KFoldDataset(dataset, fold=i, num_splits=3, test_mode=False)
        for i in range(5)
    ]
    test_datasets = [
        KFoldDataset(dataset, fold=i, num_splits=3, test_mode=True)
        for i in range(5)
    ]

    # Without a seed the test folds tile the dataset in order.
    assert sum([i.indices for i in test_datasets], []) == list(range(10))
    # Each train/test pair must cover the full dataset exactly.
    for train_set, test_set in zip(train_datasets, test_datasets):
        train_samples = [train_set[i] for i in range(len(train_set))]
        test_samples = [test_set[i] for i in range(len(test_set))]
        assert set(train_samples + test_samples) == set(range(10))

    # test with random seed
    train_datasets = [
        KFoldDataset(dataset, fold=i, num_splits=3, test_mode=False, seed=1)
        for i in range(5)
    ]
    test_datasets = [
        KFoldDataset(dataset, fold=i, num_splits=3, test_mode=True, seed=1)
        for i in range(5)
    ]

    # Seeding shuffles the indices but still covers every sample once.
    assert sum([i.indices for i in test_datasets], []) != list(range(10))
    assert set(sum([i.indices for i in test_datasets], [])) == set(range(10))
    for train_set, test_set in zip(train_datasets, test_datasets):
        train_samples = [train_set[i] for i in range(len(train_set))]
        test_samples = [test_set[i] for i in range(len(test_set))]
        assert set(train_samples + test_samples) == set(range(10))

    # test behavior of get_cat_ids method: indices must be translated
    # through the fold's index list before hitting the wrapped dataset.
    for train_set, test_set in zip(train_datasets, test_datasets):
        for i in range(len(train_set)):
            cat_ids = train_set.get_cat_ids(i)
            assert cat_ids == cat_ids_list[train_set.indices[i]]
        for i in range(len(test_set)):
            cat_ids = test_set.get_cat_ids(i)
            assert cat_ids == cat_ids_list[test_set.indices[i]]

    # test behavior of get_gt_labels method, same translation as above
    for train_set, test_set in zip(train_datasets, test_datasets):
        for i in range(len(train_set)):
            gt_label = train_set.get_gt_labels()[i]
            assert gt_label == cat_ids_list[train_set.indices[i]]
        for i in range(len(test_set)):
            gt_label = test_set.get_gt_labels()[i]
            assert gt_label == cat_ids_list[test_set.indices[i]]

    # test evaluate: the fold indices are forwarded to the wrapped dataset
    for test_set in test_datasets:
        eval_inputs = test_set.evaluate(None)
        assert eval_inputs['indices'] == test_set.indices
|
|
|
@ -1,53 +0,0 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from mmcls.datasets import BaseDataset, RepeatAugSampler, build_sampler
|
|
||||||
|
|
||||||
|
|
||||||
@patch.multiple(BaseDataset, __abstractmethods__=set())
def construct_toy_single_label_dataset(length):
    """Build a mocked single-label ``BaseDataset`` of ``length`` samples.

    Returns the dataset together with the list of single-element
    category-id lists backing the mocked ``get_cat_ids``.
    """
    BaseDataset.CLASSES = ('foo', 'bar')
    BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
    toy_dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
    labels = [[np.random.randint(0, 80)] for _ in range(length)]
    toy_dataset.data_infos = MagicMock()
    toy_dataset.data_infos.__len__.return_value = length
    toy_dataset.get_cat_ids = MagicMock(side_effect=lambda idx: labels[idx])
    return toy_dataset, labels
|
|
||||||
|
|
||||||
|
|
||||||
@patch('mmcls.datasets.samplers.repeat_aug.get_dist_info', return_value=(0, 1))
def test_sampler_builder(_):
    """``build_sampler``: ``None`` cfg returns ``None``; dict cfg builds."""
    assert build_sampler(None) is None
    toy_dataset, _labels = construct_toy_single_label_dataset(1000)
    build_sampler(dict(type='RepeatAugSampler', dataset=toy_dataset))
|
|
||||||
|
|
||||||
|
|
||||||
@patch('mmcls.datasets.samplers.repeat_aug.get_dist_info', return_value=(0, 1))
def test_rep_aug(_):
    """Single-process ``RepeatAugSampler``: lengths and repetition pattern."""
    toy_dataset, _labels = construct_toy_single_label_dataset(1000)
    sampler = RepeatAugSampler(toy_dataset, selected_round=0, shuffle=False)
    sampler.set_epoch(0)
    # selected_round=0 keeps the full dataset length
    assert len(sampler) == 1000
    # default arguments yield 768 samples per epoch
    sampler = RepeatAugSampler(toy_dataset)
    assert len(sampler) == 768
    # each index must be emitted three times in a row
    current = None
    for position, index in enumerate(sampler):
        if position % 3 == 0:
            current = index
        else:
            assert current is not None
            assert index == current
|
|
||||||
|
|
||||||
|
|
||||||
@patch('mmcls.datasets.samplers.repeat_aug.get_dist_info', return_value=(0, 2))
def test_rep_aug_dist(_):
    """With a world size of 2, each rank receives half of the epoch."""
    toy_dataset, _labels = construct_toy_single_label_dataset(1000)
    sampler = RepeatAugSampler(toy_dataset, selected_round=0, shuffle=False)
    sampler.set_epoch(0)
    assert len(sampler) == 1000 // 2
    sampler = RepeatAugSampler(toy_dataset)
    assert len(sampler) == 768 // 2
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import copy
|
||||||
|
import os.path as osp
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from mmengine.data import LabelData
|
||||||
|
|
||||||
|
from mmcls.core import ClsDataSample
|
||||||
|
from mmcls.datasets.pipelines import PackClsInputs
|
||||||
|
|
||||||
|
|
||||||
|
class TestPackClsInputs(unittest.TestCase):
    """Unit tests for the ``PackClsInputs`` formatting transform."""

    def setUp(self):
        """Prepare a fake pipeline ``results`` dict shared by all tests.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> cleanUp()
        """
        data_prefix = osp.join(osp.dirname(__file__), '../../data')
        img_path = osp.join(data_prefix, 'color.jpg')
        rng = np.random.RandomState(0)
        self.results1 = dict(
            sample_idx=1,
            img_path=img_path,
            ori_height=300,
            ori_width=400,
            height=600,
            width=800,
            scale_factor=2.0,
            flip=False,
            img=rng.rand(300, 400),
            gt_label=rng.randint(3))
        self.meta_keys = ('sample_idx', 'img_path', 'ori_shape', 'img_shape',
                          'scale_factor', 'flip')

    def test_transform(self):
        """Packing yields a tensor input and a ``ClsDataSample``."""
        pack = PackClsInputs(meta_keys=self.meta_keys)
        packed = pack(copy.deepcopy(self.results1))
        self.assertIn('inputs', packed)
        self.assertIsInstance(packed['inputs'], torch.Tensor)
        self.assertIn('data_sample', packed)
        self.assertIsInstance(packed['data_sample'], ClsDataSample)

        # the ground-truth label must be wrapped in LabelData
        data_sample = packed['data_sample']
        self.assertIsInstance(data_sample.gt_label, LabelData)

    def test_repr(self):
        pack = PackClsInputs(meta_keys=self.meta_keys)
        expected = f'PackClsInputs(meta_keys={self.meta_keys})'
        self.assertEqual(repr(pack), expected)
|
|
@ -1,59 +0,0 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import copy
|
|
||||||
import os.path as osp
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from mmcls.datasets.pipelines import LoadImageFromFile
|
|
||||||
|
|
||||||
|
|
||||||
class TestLoading(object):
    """Tests for the old-style ``LoadImageFromFile`` transform."""

    @classmethod
    def setup_class(cls):
        # directory that holds the color/gray test images
        cls.data_prefix = osp.join(osp.dirname(__file__), '../../data')

    def test_load_img(self):
        # load with an img_prefix
        base_results = dict(
            img_prefix=self.data_prefix, img_info=dict(filename='color.jpg'))
        loader = LoadImageFromFile()
        loaded = loader(copy.deepcopy(base_results))
        assert loaded['filename'] == osp.join(self.data_prefix, 'color.jpg')
        assert loaded['ori_filename'] == 'color.jpg'
        assert loaded['img'].shape == (300, 400, 3)
        assert loaded['img'].dtype == np.uint8
        assert loaded['img_shape'] == (300, 400, 3)
        assert loaded['ori_shape'] == (300, 400, 3)
        np.testing.assert_equal(loaded['img_norm_cfg']['mean'],
                                np.zeros(3, dtype=np.float32))
        expected_repr = (f'{type(loader).__name__}'
                         "(to_float32=False, color_type='color', "
                         "file_client_args={'backend': 'disk'})")
        assert repr(loader) == expected_repr

        # no img_prefix
        base_results = dict(
            img_prefix=None, img_info=dict(filename='tests/data/color.jpg'))
        loader = LoadImageFromFile()
        loaded = loader(copy.deepcopy(base_results))
        assert loaded['filename'] == 'tests/data/color.jpg'
        assert loaded['img'].shape == (300, 400, 3)

        # to_float32
        loader = LoadImageFromFile(to_float32=True)
        loaded = loader(copy.deepcopy(loaded))
        assert loaded['img'].dtype == np.float32

        # gray image is loaded with 3 channels by default
        base_results = dict(
            img_prefix=self.data_prefix, img_info=dict(filename='gray.jpg'))
        loader = LoadImageFromFile()
        loaded = loader(copy.deepcopy(base_results))
        assert loaded['img'].shape == (288, 512, 3)
        assert loaded['img'].dtype == np.uint8

        # color_type='unchanged' keeps the single channel
        loader = LoadImageFromFile(color_type='unchanged')
        loaded = loader(copy.deepcopy(base_results))
        assert loaded['img'].shape == (288, 512)
        assert loaded['img'].dtype == np.uint8
        np.testing.assert_equal(loaded['img_norm_cfg']['mean'],
                                np.zeros(1, dtype=np.float32))
|
|
|
@ -31,204 +31,6 @@ def construct_toy_data():
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def test_resize():
    """Old-style ``Resize`` pipeline: config validation and output shapes."""
    # Every invalid configuration must be rejected at build time.
    invalid_cfgs = [
        dict(type='Resize', size=-1),                    # negative int size
        dict(type='Resize', size=(224, -2)),             # bad second value
        dict(type='Resize', size=(-1, 224)),             # bad first value
        dict(type='Resize', size=(224, )),               # tuple too short
        dict(type='Resize', size=(224, 224, 3)),         # tuple too long
        dict(type='Resize', size=224, interpolation='2333'),
        dict(type='Resize', size=224, adaptive_side='False'),
    ]
    for bad_cfg in invalid_cfgs:
        with pytest.raises(AssertionError):
            build_from_cfg(bad_cfg, PIPELINES)

    # test repr
    resize_module = build_from_cfg(dict(type='Resize', size=224), PIPELINES)
    assert isinstance(repr(resize_module), str)

    # read test image and set up a two-field results dict; 'img' and 'img2'
    # start identical so the transform can be checked on both fields
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../../data/color.jpg'), 'color')
    original_img = copy.deepcopy(img)
    results = dict(
        img=img,
        img2=copy.deepcopy(img),
        img_shape=img.shape,
        ori_shape=img.shape,
        img_fields=['img', 'img2'])

    def reset_results(results, original_img):
        results['img'] = copy.deepcopy(original_img)
        results['img2'] = copy.deepcopy(original_img)
        results['img_shape'] = original_img.shape
        results['ori_shape'] = original_img.shape
        results['img_fields'] = ['img', 'img2']
        return results

    def apply_resize(cfg, results):
        # build a Resize module from cfg and run it on ``results``
        return build_from_cfg(cfg, PIPELINES)(results)

    # int size, (h, -1) shorthand, square and rectangular tuple sizes
    shape_cases = [
        (dict(type='Resize', size=224, interpolation='bilinear'),
         (224, 224, 3)),
        (dict(type='Resize', size=(224, -1), interpolation='bilinear'),
         (224, 298, 3)),
        (dict(type='Resize', size=(224, 224), interpolation='bilinear'),
         (224, 224, 3)),
        (dict(type='Resize', size=(224, 256), interpolation='bilinear'),
         (224, 256, 3)),
    ]
    for cfg, expected_shape in shape_cases:
        results = reset_results(results, original_img)
        results = apply_resize(cfg, results)
        assert np.equal(results['img'], results['img2']).all()
        assert results['img_shape'] == expected_shape

    # upscaling beyond the source resolution must also work
    img_height, img_width, _ = original_img.shape
    results = reset_results(results, original_img)
    results = apply_resize(
        dict(
            type='Resize',
            size=(img_height * 2, img_width * 2),
            interpolation='bilinear'), results)
    assert np.equal(results['img'], results['img2']).all()
    assert results['img_shape'] == (img_height * 2, img_width * 2, 3)

    # cv2 and pillow backends must roughly agree
    resize_module_cv2 = build_from_cfg(
        dict(
            type='Resize',
            size=(224, 256),
            interpolation='bilinear',
            backend='cv2'), PIPELINES)
    resize_module_pil = build_from_cfg(
        dict(
            type='Resize',
            size=(224, 256),
            interpolation='bilinear',
            backend='pillow'), PIPELINES)
    results = reset_results(results, original_img)
    results['img_fields'] = ['img']
    results_cv2 = resize_module_cv2(results)
    results['img_fields'] = ['img2']
    results_pil = resize_module_pil(results)
    assert np.allclose(results_cv2['img'], results_pil['img2'], atol=45)

    # compare results with torchvision
    results = reset_results(results, original_img)
    results = apply_resize(
        dict(type='Resize', size=(224, 224), interpolation='area'), results)
    baseline_module = transforms.Resize(
        size=(224, 224), interpolation=Image.BILINEAR)
    resized_img = np.array(baseline_module(Image.fromarray(original_img)))
    assert np.equal(results['img'], results['img2']).all()
    assert results['img_shape'] == (224, 224, 3)
    assert np.allclose(results['img'], resized_img, atol=30)

    # size=(224, -1) with adaptive_side='long'
    results = reset_results(results, original_img)
    results = apply_resize(
        dict(
            type='Resize',
            size=(224, -1),
            adaptive_side='long',
            interpolation='bilinear'), results)
    assert np.equal(results['img'], results['img2']).all()
    assert results['img_shape'] == (168, 224, 3)

    # adaptive_side='long' on a portrait (h > w) intermediate image
    results = reset_results(results, original_img)
    results = apply_resize(
        dict(type='Resize', size=(300, 200), interpolation='bilinear'),
        results)
    results = apply_resize(
        dict(
            type='Resize',
            size=(224, -1),
            adaptive_side='long',
            interpolation='bilinear'), results)
    assert np.equal(results['img'], results['img2']).all()
    assert results['img_shape'] == (224, 149, 3)

    # adaptive_side='short' on a portrait (h > w) intermediate image
    results = reset_results(results, original_img)
    results = apply_resize(
        dict(type='Resize', size=(300, 200), interpolation='bilinear'),
        results)
    results = apply_resize(
        dict(
            type='Resize',
            size=(224, -1),
            adaptive_side='short',
            interpolation='bilinear'), results)
    assert np.equal(results['img'], results['img2']).all()
    assert results['img_shape'] == (336, 224, 3)

    # backend/interpolation combinations that are not supported
    with pytest.raises(AssertionError):
        build_from_cfg(
            dict(
                type='Resize',
                size=(300, 200),
                backend='cv2',
                interpolation='box'), PIPELINES)
    with pytest.raises(AssertionError):
        build_from_cfg(
            dict(
                type='Resize',
                size=(300, 200),
                backend='pillow',
                interpolation='area'), PIPELINES)
|
|
||||||
|
|
||||||
|
|
||||||
def test_pad():
|
def test_pad():
|
||||||
results = dict()
|
results = dict()
|
||||||
img = mmcv.imread(
|
img = mmcv.imread(
|
||||||
|
@ -361,7 +163,7 @@ def test_center_crop():
|
||||||
short_edge = min(*results['ori_shape'][:2])
|
short_edge = min(*results['ori_shape'][:2])
|
||||||
transform = dict(type='CenterCrop', crop_size=short_edge)
|
transform = dict(type='CenterCrop', crop_size=short_edge)
|
||||||
baseline_center_crop_module = build_from_cfg(transform, PIPELINES)
|
baseline_center_crop_module = build_from_cfg(transform, PIPELINES)
|
||||||
transform = dict(type='Resize', size=224)
|
transform = dict(type='Resize', scale=224)
|
||||||
baseline_resize_module = build_from_cfg(transform, PIPELINES)
|
baseline_resize_module = build_from_cfg(transform, PIPELINES)
|
||||||
results = reset_results(results, original_img)
|
results = reset_results(results, original_img)
|
||||||
results = baseline_center_crop_module(results)
|
results = baseline_center_crop_module(results)
|
||||||
|
@ -617,310 +419,6 @@ def test_randomcrop():
|
||||||
assert nonzero == nonzero_transform
|
assert nonzero == nonzero_transform
|
||||||
|
|
||||||
|
|
||||||
def test_randomresizedcrop():
    """Old-style ``RandomResizedCrop``: bad configs, torchvision parity,
    efficientnet-style behavior and interpolation modes."""
    ori_img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../../data/color.jpg'), 'color')
    ori_img_pil = Image.open(
        osp.join(osp.dirname(__file__), '../../data/color.jpg'))

    seed = random.randint(0, 100)

    def cls_crop(kwargs, show=False):
        """Seed both RNGs, run the mmcls transform and return the image."""
        random.seed(seed)
        np.random.seed(seed)
        pipeline = Compose([mmcls_transforms.RandomResizedCrop(**kwargs)])
        if show:  # exercise __repr__()
            print(pipeline)
        return pipeline(dict(img=ori_img))['img']

    def tv_crop(kwargs):
        """Seed both RNGs and run the torchvision reference transform."""
        random.seed(seed)
        np.random.seed(seed)
        pipeline = Compose([torchvision.transforms.RandomResizedCrop(**kwargs)])
        return pipeline(ori_img_pil)

    def assert_matches_baseline(img, baseline):
        # mmcv loads BGR while PIL is RGB, so compare the count of
        # differing entries against the untouched source images.
        ref = len((ori_img - np.array(ori_img_pil)[:, :, ::-1]).nonzero())
        got = len((img - np.array(baseline)[:, :, ::-1]).nonzero())
        assert ref == got

    # invalid configurations must raise on build/apply
    bad_cases = [
        # scale is not of kind (min, max)
        (dict(size=(200, 300), scale=(1.0, 0.08), ratio=(3. / 4., 4. / 3.)),
         ValueError),
        # ratio is not of kind (min, max)
        (dict(size=(200, 300), scale=(0.08, 1.0), ratio=(4. / 3., 3. / 4.)),
         ValueError),
        # efficientnet_style with a negative crop_padding
        (dict(size=200, efficientnet_style=True, crop_padding=-1),
         AssertionError),
    ]
    for bad_kwargs, err in bad_cases:
        with pytest.raises(err):
            Compose([mmcls_transforms.RandomResizedCrop(**bad_kwargs)])(
                dict(img=ori_img))['img']

    # int crop size
    kwargs = dict(size=200, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.))
    baseline = tv_crop(kwargs)
    img = cls_crop(kwargs, show=True)
    assert np.array(img).shape == (200, 200, 3)
    assert np.array(baseline).shape == (200, 200, 3)
    assert_matches_baseline(img, baseline)

    # tuple crop size smaller than the image
    kwargs = dict(size=(200, 300), scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.))
    baseline = tv_crop(kwargs)
    img = cls_crop(kwargs)
    assert np.array(img).shape == (200, 300, 3)
    assert np.array(baseline).shape == (200, 300, 3)
    assert_matches_baseline(img, baseline)

    # crop size < image size with efficientnet_style=True
    img = cls_crop(
        dict(
            size=200,
            scale=(0.08, 1.0),
            ratio=(3. / 4., 4. / 3.),
            efficientnet_style=True))
    assert img.shape == (200, 200, 3)

    # crop size larger than the image
    kwargs = dict(size=(600, 700), scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.))
    baseline = tv_crop(kwargs)
    img = cls_crop(kwargs)
    assert np.array(img).shape == (600, 700, 3)
    assert np.array(baseline).shape == (600, 700, 3)
    assert_matches_baseline(img, baseline)

    # crop size > image size with efficientnet_style=True
    img = cls_crop(
        dict(
            size=600,
            scale=(0.08, 1.0),
            ratio=(3. / 4., 4. / 3.),
            efficientnet_style=True))
    assert img.shape == (600, 600, 3)

    # cropping the whole image
    full_size = (ori_img.shape[0], ori_img.shape[1])
    kwargs = dict(size=full_size, scale=(1.0, 2.0), ratio=(1.0, 2.0))
    baseline = tv_crop(kwargs)
    img = cls_crop(kwargs)
    assert np.array(img).shape == (*full_size, 3)
    assert np.array(baseline).shape == (*full_size, 3)
    assert_matches_baseline(img, baseline)

    # central crop when in_ratio < min(ratio)
    kwargs = dict(size=full_size, scale=(1.0, 2.0), ratio=(2., 3.))
    baseline = tv_crop(kwargs)
    img = cls_crop(kwargs)
    assert np.array(img).shape == (*full_size, 3)
    assert np.array(baseline).shape == (*full_size, 3)
    assert_matches_baseline(img, baseline)

    # central crop when in_ratio > max(ratio)
    kwargs = dict(size=full_size, scale=(1.0, 2.0), ratio=(3. / 4., 1))
    baseline = tv_crop(kwargs)
    img = cls_crop(kwargs)
    assert np.array(img).shape == (*full_size, 3)
    assert np.array(baseline).shape == (*full_size, 3)
    assert_matches_baseline(img, baseline)

    def centercrop_resize(crop_kwargs, resize_kwargs):
        """Seeded CenterCrop followed by Resize — the efficientnet
        reference pipeline for the two comparisons below."""
        random.seed(seed)
        np.random.seed(seed)
        pipeline = Compose([
            mmcls_transforms.CenterCrop(**crop_kwargs),
            mmcls_transforms.Resize(**resize_kwargs),
        ])
        return pipeline(dict(img=ori_img))['img']

    # efficientnet_style with max_attempts=0 vs CenterCrop + Resize
    img = cls_crop(
        dict(
            size=200,
            scale=(0.08, 1.0),
            ratio=(3. / 4., 4. / 3.),
            efficientnet_style=True,
            max_attempts=0,
            crop_padding=32))
    baseline = centercrop_resize(
        dict(crop_size=200, efficientnet_style=True, crop_padding=32),
        dict(size=200))
    assert img.shape == baseline.shape
    assert np.equal(img, baseline).all()

    # efficientnet_style with min_covered=1 vs CenterCrop + Resize
    img = cls_crop(
        dict(
            size=200,
            scale=(0.08, 1.0),
            ratio=(3. / 4., 4. / 3.),
            efficientnet_style=True,
            max_attempts=100,
            min_covered=1))
    baseline = centercrop_resize(
        dict(crop_size=200, efficientnet_style=True, crop_padding=32),
        dict(size=200))
    assert img.shape == baseline.shape
    assert np.equal(img, baseline).all()

    # every supported interpolation mode must produce the requested shape
    for mode in ['nearest', 'bilinear', 'bicubic', 'area', 'lanczos']:
        pipeline = Compose([
            mmcls_transforms.RandomResizedCrop(
                size=(600, 700),
                scale=(0.08, 1.0),
                ratio=(3. / 4., 4. / 3.),
                interpolation=mode)
        ])
        img = pipeline(dict(img=ori_img))['img']
        assert img.shape == (600, 700, 3)
|
|
||||||
|
|
||||||
|
|
||||||
def test_randomgrayscale():
|
def test_randomgrayscale():
|
||||||
|
|
||||||
# test rgb2gray, return the grayscale image with p>1
|
# test rgb2gray, return the grayscale image with p>1
|
||||||
|
@ -978,83 +476,6 @@ def test_randomgrayscale():
|
||||||
assert np.array(img_pil).shape == (10, 10)
|
assert np.array(img_pil).shape == (10, 10)
|
||||||
|
|
||||||
|
|
||||||
def test_randomflip():
    """Test the ``RandomFlip`` pipeline transform.

    Covers build-time config validation (out-of-range probabilities,
    unknown or non-lowercase directions) and compares the flipped output
    against torchvision's reference horizontal/vertical flips.
    """
    # Every invalid config must be rejected with an AssertionError when
    # the transform is built: prob outside [0, 1], an unsupported
    # direction, or a direction that is not lowercase.
    invalid_cfgs = [
        dict(type='RandomFlip', flip_prob=-1),
        dict(type='RandomFlip', flip_prob=2),
        dict(type='RandomFlip', direction='random'),
        dict(type='RandomFlip', direction='Horizontal'),
    ]
    for cfg in invalid_cfgs:
        with pytest.raises(AssertionError):
            build_from_cfg(cfg, PIPELINES)

    # Load the shared test image and populate a pipeline-style results
    # dict with two registered image fields so we can verify both are
    # flipped consistently.
    src_img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../../data/color.jpg'), 'color')
    original_img = copy.deepcopy(src_img)
    results = dict(
        img=src_img,
        img2=copy.deepcopy(src_img),
        img_shape=src_img.shape,
        ori_shape=src_img.shape,
        img_fields=['img', 'img2'])

    def reset_results(results, original_img):
        # Restore both image fields and the shape entries to the
        # untouched original so each sub-test starts from a clean state.
        results['img'] = copy.deepcopy(original_img)
        results['img2'] = copy.deepcopy(original_img)
        results['img_shape'] = original_img.shape
        results['ori_shape'] = original_img.shape
        return results

    # flip_prob=0: the transform must be a no-op on every image field.
    flip = build_from_cfg(dict(type='RandomFlip', flip_prob=0), PIPELINES)
    results = flip(results)
    assert np.equal(results['img'], original_img).all()
    assert np.equal(results['img'], results['img2']).all()

    # flip_prob=1: both registered image fields must be flipped the same
    # way (applied on top of the previous results, as in the original).
    flip = build_from_cfg(dict(type='RandomFlip', flip_prob=1), PIPELINES)
    results = flip(results)
    assert np.equal(results['img'], results['img2']).all()

    # Horizontal flip must match torchvision's RandomHorizontalFlip(p=1).
    flip = build_from_cfg(
        dict(type='RandomFlip', flip_prob=1, direction='horizontal'),
        PIPELINES)
    results = flip(reset_results(results, original_img))
    reference = np.array(
        transforms.RandomHorizontalFlip(p=1)(Image.fromarray(original_img)))
    assert np.equal(results['img'], results['img2']).all()
    assert np.equal(results['img'], reference).all()

    # Vertical flip must match torchvision's RandomVerticalFlip(p=1).
    flip = build_from_cfg(
        dict(type='RandomFlip', flip_prob=1, direction='vertical'),
        PIPELINES)
    results = flip(reset_results(results, original_img))
    reference = np.array(
        transforms.RandomVerticalFlip(p=1)(Image.fromarray(original_img)))
    assert np.equal(results['img'], results['img2']).all()
    assert np.equal(results['img'], reference).all()
|
|
||||||
|
|
||||||
|
|
||||||
def test_random_erasing():
|
def test_random_erasing():
|
||||||
# test erase_prob assertion
|
# test erase_prob assertion
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
|
@ -1269,6 +690,8 @@ def test_albu_transform():
|
||||||
results = dict(
|
results = dict(
|
||||||
img_prefix=osp.join(osp.dirname(__file__), '../../data'),
|
img_prefix=osp.join(osp.dirname(__file__), '../../data'),
|
||||||
img_info=dict(filename='color.jpg'))
|
img_info=dict(filename='color.jpg'))
|
||||||
|
results['img_path'] = osp.join(results['img_prefix'],
|
||||||
|
results['img_info']['filename'])
|
||||||
|
|
||||||
# Define simple pipeline
|
# Define simple pipeline
|
||||||
load = dict(type='LoadImageFromFile')
|
load = dict(type='LoadImageFromFile')
|
||||||
|
|
Loading…
Reference in New Issue