75 lines
2.5 KiB
Python
75 lines
2.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import Callable, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from mmpretrain.registry import BATCH_AUGMENTS
|
|
|
|
|
|
class RandomBatchAugment:
    """Randomly choose one batch augmentation to apply.

    Args:
        augments (Callable | dict | list): configs of batch
            augmentations. Dict entries are built through
            ``BATCH_AUGMENTS``; any other entry is used as-is and is called
            as ``aug(batch_input, batch_score)``.
        probs (float | List[float] | None): The probabilities of each batch
            augmentations. If None, choose evenly. A scalar is treated as a
            one-element list. The given sequence is copied and never
            modified in place. Defaults to None.

    Example:
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from mmpretrain.models import RandomBatchAugment
        >>> augments_cfg = [
        ...     dict(type='CutMix', alpha=1.),
        ...     dict(type='Mixup', alpha=1.)
        ... ]
        >>> batch_augment = RandomBatchAugment(augments_cfg, probs=[0.5, 0.3])
        >>> imgs = torch.rand(16, 3, 32, 32)
        >>> label = F.one_hot(torch.randint(0, 10, (16, )), num_classes=10)
        >>> imgs, label = batch_augment(imgs, label)

    .. note ::

        To decide which batch augmentation will be used, it picks one of
        ``augments`` based on the probabilities. In the example above, the
        probability to use CutMix is 0.5, to use Mixup is 0.3, and to do
        nothing is 0.2.
    """

    def __init__(self, augments: Union[Callable, dict, list], probs=None):
        if not isinstance(augments, (tuple, list)):
            augments = [augments]

        self.augments = []
        for aug in augments:
            if isinstance(aug, dict):
                self.augments.append(BATCH_AUGMENTS.build(aug))
            else:
                self.augments.append(aug)

        # A single scalar probability (float or int, e.g. ``probs=1``) is
        # treated as a one-element list.
        if isinstance(probs, (int, float)):
            probs = [probs]

        if probs is not None:
            # Copy into a fresh list. This keeps the caller's list (e.g. a
            # shared config) from being mutated by the ``append`` below, and
            # lets tuples be accepted as well.
            probs = [float(p) for p in probs]
            assert len(augments) == len(probs), \
                '``augments`` and ``probs`` must have same lengths. ' \
                f'Got {len(augments)} vs {len(probs)}.'
            assert sum(probs) <= 1, \
                'The total probability of batch augments exceeds 1.'
            # The leftover probability mass selects "no augmentation",
            # represented by a ``None`` entry.
            self.augments.append(None)
            probs.append(1 - sum(probs))

        # ``None`` means np.random.choice picks uniformly among augments.
        self.probs = probs

    def __call__(self, batch_input: torch.Tensor, batch_score: torch.Tensor):
        """Randomly apply one batch augmentation to the batch inputs and
        batch scores.

        Returns:
            tuple: Whatever the chosen augmentation returns. When the
            "no augmentation" slot is drawn, returns
            ``(batch_input, batch_score.float())``.
        """
        aug_index = np.random.choice(len(self.augments), p=self.probs)
        aug = self.augments[aug_index]

        if aug is not None:
            return aug(batch_input, batch_score)
        else:
            return batch_input, batch_score.float()
|