Mirror of
https://github.com/open-mmlab/mmclassification.git
(synced 2025-06-03 21:53:55 +08:00)
* Refactor the unit-test folder structure. * Remove the label-smoothing and ViT tests from `test_classifiers.py`. * Rename `test_utils` in the dataset tests to `test_dataset_utils`. * Split `test_models/test_utils/test_utils.py` into multiple files. * Add unit tests for classifiers and heads. * Use the `patch` context manager. * Add a unit test for `is_tracing`, and make `is_tracing` emit a warning when the torch version is lower than 1.6.0.
52 lines
1.9 KiB
Python
52 lines
1.9 KiB
Python
import torch
|
|
|
|
from mmcls.models.utils import Augments
|
|
|
|
|
|
def test_augments():
    """Run every supported ``Augments`` configuration and check output shapes.

    One random batch of 4 images (3x32x32) with integer labels in [0, 10) is
    passed through each configuration in turn: pure CutMix, pure Mixup, two
    CutMix/Mixup mixtures, and a CutMix/Mixup/Identity mixture. For every
    configuration the images must keep their input shape and the labels must
    come back with shape (4, 10).
    """
    batch_size, num_classes = 4, 10
    imgs = torch.randn(batch_size, 3, 32, 32)
    labels = torch.randint(0, num_classes, (batch_size, ))

    augment_cfgs = [
        # Test cutmix
        dict(type='BatchCutMix', alpha=1., num_classes=10, prob=1.),
        # Test mixup
        dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.),
        # Test cutmixup
        [
            dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
            dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3)
        ],
        [
            dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
            dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.5)
        ],
        [
            dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
            dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3),
            dict(type='Identity', num_classes=10, prob=0.2)
        ],
    ]

    expected_img_shape = torch.Size((batch_size, 3, 32, 32))
    expected_label_shape = torch.Size((batch_size, num_classes))

    # Same order of Augments construction/calls as a sequential write-out,
    # so the sequence of random draws is unchanged.
    for cfg in augment_cfgs:
        augmenter = Augments(cfg)
        out_imgs, out_labels = augmenter(imgs, labels)
        assert out_imgs.shape == expected_img_shape
        assert out_labels.shape == expected_label_shape