mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
* add resizemix
* skip torch.__version__ < 1.7.0
* Update mmcls/models/utils/augment/resizemix.py
  Co-authored-by: Ma Zerun <mzr1996@163.com>
* Update mmcls/models/utils/augment/resizemix.py
  Co-authored-by: Ma Zerun <mzr1996@163.com>
* resize -> F.interpolate
* fix docs
* fix test
* add Copyright
* add argument interpolation
  Co-authored-by: Ma Zerun <mzr1996@163.com>
97 lines
3.2 KiB
Python
97 lines
3.2 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
from mmcls.models.utils import Augments
|
|
|
|
# Augmentation configs shared by the parametrized tests in this module.
# ``num_classes`` is intentionally left out: each test adds its own value
# (1 for the binary case, 10 for the multi-label case) on a copy of the entry.
augment_cfgs = [
    dict(type='BatchCutMix', alpha=1., prob=1.),
    dict(type='BatchMixup', alpha=1., prob=1.),
    dict(type='Identity', prob=1.),
    dict(type='BatchResizeMix', alpha=1., prob=1.),
]
|
|
|
|
|
|
def test_augments():
    """Check output shapes of ``Augments`` for single and combined configs.

    A batch of 4 random 3x32x32 images with integer labels in ``[0, 10)`` is
    passed through every configuration below. For each one the images must
    keep the shape ``(4, 3, 32, 32)`` and the labels must come back with the
    shape ``(4, 10)``.
    """
    imgs = torch.randn(4, 3, 32, 32)
    labels = torch.randint(0, 10, (4, ))

    # Every entry is built from fresh dict literals so that no config object
    # is shared between two ``Augments`` instances.
    cases = [
        # Test cutmix
        dict(type='BatchCutMix', alpha=1., num_classes=10, prob=1.),
        # Test mixup
        dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.),
        # Test resizemix
        dict(type='BatchResizeMix', alpha=1., num_classes=10, prob=1.),
        # Test cutmixup, probabilities summing to 0.8
        # NOTE(review): presumably the remaining 0.2 leaves the batch
        # un-augmented -- confirm against the ``Augments`` implementation.
        [
            dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
            dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3),
        ],
        # Test cutmixup, probabilities summing to 1.0
        [
            dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
            dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.5),
        ],
        # Test cutmixup with an explicit Identity entry (sum is 1.0)
        [
            dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
            dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3),
            dict(type='Identity', num_classes=10, prob=0.2),
        ],
    ]

    for augments_cfg in cases:
        augs = Augments(augments_cfg)
        mixed_imgs, mixed_labels = augs(imgs, labels)
        assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
        assert mixed_labels.shape == torch.Size((4, 10))
|
|
|
|
|
|
@pytest.mark.parametrize('cfg', augment_cfgs)
def test_binary_augment(cfg):
    """Run one augmentation config on single-column float labels.

    Args:
        cfg (dict): One entry of ``augment_cfgs``. ``num_classes=1`` is
            merged into a new dict, so the shared entry is not modified here.
    """
    cfg_ = dict(num_classes=1, **cfg)
    augs = Augments(cfg_)

    imgs = torch.randn(4, 3, 32, 32)
    # Labels are float values drawn from {0, 1} with shape (4, 1).
    labels = torch.randint(0, 2, (4, 1)).float()

    mixed_imgs, mixed_labels = augs(imgs, labels)
    assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
    assert mixed_labels.shape == torch.Size((4, 1))
|
|
|
|
|
|
@pytest.mark.parametrize('cfg', augment_cfgs)
def test_multilabel_augment(cfg):
    """Run one augmentation config on multi-hot float labels.

    Args:
        cfg (dict): One entry of ``augment_cfgs``. ``num_classes=10`` is
            merged into a new dict, so the shared entry is not modified here.
    """
    cfg_ = dict(num_classes=10, **cfg)
    augs = Augments(cfg_)

    imgs = torch.randn(4, 3, 32, 32)
    # Labels are float values drawn from {0, 1} with shape (4, 10), i.e. one
    # column per class.
    labels = torch.randint(0, 2, (4, 10)).float()

    mixed_imgs, mixed_labels = augs(imgs, labels)
    assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
    assert mixed_labels.shape == torch.Size((4, 10))
|