# mmclassification/tests/test_models/test_utils/test_batch_augments.py
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import MagicMock, patch

import numpy as np
import torch

from mmcls.core import ClsDataSample
from mmcls.models import Mixup, RandomBatchAugment
from mmcls.registry import BATCH_AUGMENTS
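
# Shared batch augment configs (CutMix, Mixup, Identity and ResizeMix, here
# under their ``Batch*``-prefixed names).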
augment_cfgs = [
    dict(type='BatchCutMix', alpha=1., prob=1.),
    dict(type='BatchMixup', alpha=1., prob=1.),
    dict(type='Identity', prob=1.),
    dict(type='BatchResizeMix', alpha=1., prob=1.)
]


class TestRandomBatchAugment(TestCase):

    def test_initialize(self):
        # test single augmentation
        augments = dict(type='Mixup', alpha=1.)
        batch_augments = RandomBatchAugment(augments)
        self.assertIsInstance(batch_augments.augments, list)
        self.assertEqual(len(batch_augments.augments), 1)

        # test specifying augments with an object
        augments = Mixup(alpha=1.)
        batch_augments = RandomBatchAugment(augments)
        self.assertIsInstance(batch_augments.augments, list)
        self.assertEqual(len(batch_augments.augments), 1)

        # test multiple augmentations
        augments = [
            dict(type='Mixup', alpha=1.),
            dict(type='CutMix', alpha=0.8),
        ]
        batch_augments = RandomBatchAugment(augments)
        # mixup, cutmix
        self.assertEqual(len(batch_augments.augments), 2)
        self.assertIsNone(batch_augments.probs)

        # test specifying probs
        augments = [
            dict(type='Mixup', alpha=1.),
            dict(type='CutMix', alpha=0.8),
        ]
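        # When the probs don't sum to 1, RandomBatchAugment appends a "no
        # augmentation" option that takes the leftover probability mass.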
        batch_augments = RandomBatchAugment(augments, probs=[0.5, 0.3])
        # mixup, cutmix and None
        self.assertEqual(len(batch_augments.augments), 3)
        self.assertAlmostEqual(batch_augments.probs[-1], 0.2)

        # test assertion
        with self.assertRaisesRegex(AssertionError, 'Got 2 vs 1'):
            RandomBatchAugment(augments, probs=0.5)
        with self.assertRaisesRegex(AssertionError, 'exceeds 1.'):
            RandomBatchAugment(augments, probs=[0.5, 0.6])

    def test_call(self):
        inputs = torch.rand(2, 3, 224, 224)
        data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]

        augments = [
            dict(type='Mixup', alpha=1.),
            dict(type='CutMix', alpha=0.8),
        ]
        batch_augments = RandomBatchAugment(augments, probs=[0.5, 0.3])
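        # Seed numpy's global RNG so the weighted choice of augment is
        # deterministic, and swap the second augment (CutMix) for a mock to
        # verify that it is the one selected and called.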
        with patch('numpy.random', np.random.RandomState(0)):
            batch_augments.augments[1] = MagicMock()
            batch_augments(inputs, data_samples)
            batch_augments.augments[1].assert_called_once_with(
                inputs, data_samples)

        augments = [
            dict(type='Mixup', alpha=1.),
            dict(type='CutMix', alpha=0.8),
        ]
        batch_augments = RandomBatchAugment(augments, probs=[0.0, 0.0])
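        # With zero probabilities no augment is applied, so the call should
        # return the very same input tensor and data samples.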
        mixed_inputs, mixed_samples = batch_augments(inputs, data_samples)
        self.assertIs(mixed_inputs, inputs)
        self.assertIs(mixed_samples, data_samples)


class TestMixup(TestCase):
    DEFAULT_ARGS = dict(type='Mixup', alpha=1.)

    def test_initialize(self):
        with self.assertRaises(AssertionError):
            cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'}
            BATCH_AUGMENTS.build(cfg)

        with self.assertRaises(AssertionError):
            cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'}
            BATCH_AUGMENTS.build(cfg)

    def test_call(self):
        inputs = torch.rand(2, 3, 224, 224)
        data_samples = [
            ClsDataSample(metainfo={
                'num_classes': 10
            }).set_gt_label(1) for _ in range(2)
        ]

        # test get num_classes from data_samples
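        # Mixup blends pairs of images in the batch and turns hard labels
        # into soft scores of length ``num_classes``, read here from the
        # samples' metainfo.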
        mixup = BATCH_AUGMENTS.build(self.DEFAULT_ARGS)
        mixed_inputs, mixed_samples = mixup(inputs, data_samples)
        self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
        self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))
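
        # Without ``num_classes`` in the init args or the samples' metainfo,
        # Mixup cannot construct the soft labels and should raise.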
        with self.assertRaisesRegex(RuntimeError, 'Not specify'):
            data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
            mixup(inputs, data_samples)

        # test binary classification
        cfg = {**self.DEFAULT_ARGS, 'num_classes': 1}
        mixup = BATCH_AUGMENTS.build(cfg)
        data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)]
        mixed_inputs, mixed_samples = mixup(inputs, data_samples)
        self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
        self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, ))

        # test multi-label classification
        cfg = {**self.DEFAULT_ARGS, 'num_classes': 5}
        mixup = BATCH_AUGMENTS.build(cfg)
        data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)]
        mixed_inputs, mixed_samples = mixup(inputs, data_samples)
        self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
        self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, ))


class TestCutMix(TestCase):
    DEFAULT_ARGS = dict(type='CutMix', alpha=1.)

    def test_initialize(self):
        with self.assertRaises(AssertionError):
            cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'}
            BATCH_AUGMENTS.build(cfg)

        with self.assertRaises(AssertionError):
            cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'}
            BATCH_AUGMENTS.build(cfg)

    def test_call(self):
        inputs = torch.rand(2, 3, 224, 224)
        data_samples = [
            ClsDataSample(metainfo={
                'num_classes': 10
            }).set_gt_label(1) for _ in range(2)
        ]

        # test with cutmix_minmax
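        # (``cutmix_minmax`` bounds the size of the cut region relative to
        # the image instead of deriving it from the Beta-sampled mixing
        # ratio.)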
        cfg = {**self.DEFAULT_ARGS, 'cutmix_minmax': (0.1, 0.2)}
        cutmix = BATCH_AUGMENTS.build(cfg)
        mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
        self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
        self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))

        # test without correct_lam
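        # (with ``correct_lam=False`` the mixing ratio is not recomputed from
        # the actual box area after the box is clipped to the image borders.)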
        cfg = {**self.DEFAULT_ARGS, 'correct_lam': False}
        cutmix = BATCH_AUGMENTS.build(cfg)
        mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
        self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
        self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))

        # test get num_classes from data_samples
        cutmix = BATCH_AUGMENTS.build(self.DEFAULT_ARGS)
        mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
        self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
        self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))

        with self.assertRaisesRegex(RuntimeError, 'Not specify'):
            data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
            cutmix(inputs, data_samples)

        # test binary classification
        cfg = {**self.DEFAULT_ARGS, 'num_classes': 1}
        cutmix = BATCH_AUGMENTS.build(cfg)
        data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)]
        mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
        self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
        self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, ))

        # test multi-label classification
        cfg = {**self.DEFAULT_ARGS, 'num_classes': 5}
        cutmix = BATCH_AUGMENTS.build(cfg)
        data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)]
        mixed_inputs, mixed_samples = cutmix(inputs, data_samples)
        self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
        self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, ))


class TestResizeMix(TestCase):
    DEFAULT_ARGS = dict(type='ResizeMix', alpha=1.)

    def test_initialize(self):
        with self.assertRaises(AssertionError):
            cfg = {**self.DEFAULT_ARGS, 'alpha': 'unknown'}
            BATCH_AUGMENTS.build(cfg)

        with self.assertRaises(AssertionError):
            cfg = {**self.DEFAULT_ARGS, 'num_classes': 'unknown'}
            BATCH_AUGMENTS.build(cfg)

    def test_call(self):
        inputs = torch.rand(2, 3, 224, 224)
        data_samples = [
            ClsDataSample(metainfo={
                'num_classes': 10
            }).set_gt_label(1) for _ in range(2)
        ]
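
        # ResizeMix pastes a resized copy of one image onto a region of
        # another, so the output keeps the input shape while labels become
        # soft scores, as with CutMix.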

        # test get num_classes from data_samples
        resizemix = BATCH_AUGMENTS.build(self.DEFAULT_ARGS)
        mixed_inputs, mixed_samples = resizemix(inputs, data_samples)
        self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
        self.assertEqual(mixed_samples[0].gt_label.score.shape, (10, ))

        with self.assertRaisesRegex(RuntimeError, 'Not specify'):
            data_samples = [ClsDataSample().set_gt_label(1) for _ in range(2)]
            resizemix(inputs, data_samples)

        # test binary classification
        cfg = {**self.DEFAULT_ARGS, 'num_classes': 1}
        resizemix = BATCH_AUGMENTS.build(cfg)
        data_samples = [ClsDataSample().set_gt_label([]) for _ in range(2)]
        mixed_inputs, mixed_samples = resizemix(inputs, data_samples)
        self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
        self.assertEqual(mixed_samples[0].gt_label.score.shape, (1, ))

        # test multi-label classification
        cfg = {**self.DEFAULT_ARGS, 'num_classes': 5}
        resizemix = BATCH_AUGMENTS.build(cfg)
        data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(2)]
        mixed_inputs, mixed_samples = resizemix(inputs, data_samples)
        self.assertEqual(mixed_inputs.shape, (2, 3, 224, 224))
        self.assertEqual(mixed_samples[0].gt_label.score.shape, (5, ))
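

# Entry point for running this module directly; the suite is normally
# collected by pytest, so this guard is only a convenience.
if __name__ == '__main__':
    import unittest
    unittest.main()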