72 lines
2.4 KiB
Python
Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Union
import numpy as np
import torch
from mmcls.registry import BATCH_AUGMENTS
class RandomBatchAugment:
    """Randomly choose one batch augmentation to apply.

    Args:
        augments (Callable | dict | list): configs of batch
            augmentations. A single callable or config dict is treated as a
            one-element list. Dict entries are built with
            ``BATCH_AUGMENTS.build``; any other entry is used as-is.
        probs (float | List[float] | None): The probabilities of each batch
            augmentation. If None, choose evenly among ``augments``.
            Otherwise ``1 - sum(probs)`` is the probability of applying no
            augmentation at all. The caller's sequence is never modified.
            Defaults to None.

    Example:
        >>> augments_cfg = [
        ...     dict(type='CutMix', alpha=1., num_classes=10),
        ...     dict(type='Mixup', alpha=1., num_classes=10)
        ... ]
        >>> batch_augment = RandomBatchAugment(augments_cfg, probs=[0.5, 0.3])
        >>> imgs = torch.randn(16, 3, 32, 32)
        >>> label = torch.randint(0, 10, (16, ))
        >>> imgs, label = batch_augment(imgs, label)

    .. note ::

        To decide which batch augmentation will be used, it picks one of
        ``augments`` based on the probabilities. In the example above, the
        probability to use CutMix is 0.5, to use Mixup is 0.3, and to do
        nothing is 0.2.
    """

    def __init__(self, augments: Union[Callable, dict, list], probs=None):
        if not isinstance(augments, (tuple, list)):
            augments = [augments]

        self.augments = []
        for aug in augments:
            if isinstance(aug, dict):
                self.augments.append(BATCH_AUGMENTS.build(aug))
            else:
                self.augments.append(aug)

        # A bare number is shorthand for a one-element probability list.
        if isinstance(probs, (int, float)):
            probs = [probs]

        if probs is not None:
            # Work on a private copy: the "no augmentation" slot appended
            # below must not leak into the caller's (possibly shared) list,
            # and copying also lets tuples be passed.
            probs = list(probs)
            assert len(augments) == len(probs), \
                '``augments`` and ``probs`` must have same lengths. ' \
                f'Got {len(augments)} vs {len(probs)}.'
            assert all(p >= 0 for p in probs), \
                'The probabilities of batch augments must be non-negative.'
            assert sum(probs) <= 1, \
                'The total probability of batch augments exceeds 1.'
            # The leftover probability mass selects "do nothing".
            self.augments.append(None)
            probs.append(1 - sum(probs))

        # None means np.random.choice picks uniformly among self.augments.
        self.probs = probs

    def __call__(self, inputs: torch.Tensor, data_samples: Union[list, None]):
        """Randomly apply batch augmentations to the batch inputs and batch
        data samples.

        One entry of ``self.augments`` is drawn with ``np.random.choice``
        (NumPy's global RNG) using ``self.probs``.

        Args:
            inputs (torch.Tensor): The batch inputs.
            data_samples (list | None): The batch data samples.

        Returns:
            Whatever the chosen augmentation returns when called as
            ``aug(inputs, data_samples)``, or ``(inputs, data_samples)``
            unchanged when the "no augmentation" slot is drawn.
        """
        aug_index = np.random.choice(len(self.augments), p=self.probs)
        aug = self.augments[aug_index]
        if aug is not None:
            return aug(inputs, data_samples)
        else:
            return inputs, data_samples