diff --git a/mmcls/models/utils/batch_augments/cutmix.py b/mmcls/models/utils/batch_augments/cutmix.py index e4b04bdd7..5d5b1ac4f 100644 --- a/mmcls/models/utils/batch_augments/cutmix.py +++ b/mmcls/models/utils/batch_augments/cutmix.py @@ -21,7 +21,7 @@ class CutMix(Mixup): Args: alpha (float): Parameters for Beta distribution to generate the mixing ratio. It should be a positive number. More details - can be found in :class:`BatchMixupLayer`. + can be found in :class:`Mixup`. num_classes (int, optional): The number of classes. If not specified, will try to get it from data samples during training. Defaults to None. @@ -124,8 +124,18 @@ class CutMix(Mixup): lam = 1. - bbox_area / float(img_shape[0] * img_shape[1]) return (yl, yu, xl, xu), lam - def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor): - """Mix the batch inputs and batch one-hot format ground truth.""" + def mix(self, batch_inputs: torch.Tensor, batch_scores: torch.Tensor): + """Mix the batch inputs and batch one-hot format ground truth. + + Args: + batch_inputs (Tensor): A batch of images tensor in the shape of + ``(N, C, H, W)``. + batch_scores (Tensor): A batch of one-hot format labels in the + shape of ``(N, num_classes)``. + + Returns: + Tuple[Tensor, Tensor]: The mixed inputs and labels. 
+ """ lam = np.random.beta(self.alpha, self.alpha) batch_size = batch_inputs.size(0) img_shape = batch_inputs.shape[-2:] @@ -133,6 +143,6 @@ class CutMix(Mixup): (y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam) batch_inputs[:, :, y1:y2, x1:x2] = batch_inputs[index, :, y1:y2, x1:x2] - mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :] + mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :] - return batch_inputs, mixed_score + return batch_inputs, mixed_scores diff --git a/mmcls/models/utils/batch_augments/mixup.py b/mmcls/models/utils/batch_augments/mixup.py index 4884849bd..a759adba9 100644 --- a/mmcls/models/utils/batch_augments/mixup.py +++ b/mmcls/models/utils/batch_augments/mixup.py @@ -40,16 +40,26 @@ class Mixup: self.alpha = alpha self.num_classes = num_classes - def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor): - """Mix the batch inputs and batch one-hot format ground truth.""" + def mix(self, batch_inputs: torch.Tensor, batch_scores: torch.Tensor): + """Mix the batch inputs and batch one-hot format ground truth. + + Args: + batch_inputs (Tensor): A batch of images tensor in the shape of + ``(N, C, H, W)``. + batch_scores (Tensor): A batch of one-hot format labels in the + shape of ``(N, num_classes)``. + + Returns: + Tuple[Tensor, Tensor]: The mixed inputs and labels. 
+ """ lam = np.random.beta(self.alpha, self.alpha) batch_size = batch_inputs.size(0) index = torch.randperm(batch_size) mixed_inputs = lam * batch_inputs + (1 - lam) * batch_inputs[index, :] - mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :] + mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :] - return mixed_inputs, mixed_score + return mixed_inputs, mixed_scores def __call__(self, batch_inputs: torch.Tensor, data_samples: List[ClsDataSample]): diff --git a/mmcls/models/utils/batch_augments/resizemix.py b/mmcls/models/utils/batch_augments/resizemix.py index ae7596872..d130c80dd 100644 --- a/mmcls/models/utils/batch_augments/resizemix.py +++ b/mmcls/models/utils/batch_augments/resizemix.py @@ -70,8 +70,18 @@ class ResizeMix(CutMix): self.lam_max = lam_max self.interpolation = interpolation - def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor): - """Mix the batch inputs and batch one-hot format ground truth.""" + def mix(self, batch_inputs: torch.Tensor, batch_scores: torch.Tensor): + """Mix the batch inputs and batch one-hot format ground truth. + + Args: + batch_inputs (Tensor): A batch of images tensor in the shape of + ``(N, C, H, W)``. + batch_scores (Tensor): A batch of one-hot format labels in the + shape of ``(N, num_classes)``. + + Returns: + Tuple[Tensor, Tensor]: The mixed inputs and labels. 
+ """ lam = np.random.beta(self.alpha, self.alpha) lam = lam * (self.lam_max - self.lam_min) + self.lam_min img_shape = batch_inputs.shape[-2:] @@ -84,6 +94,6 @@ class ResizeMix(CutMix): size=(y2 - y1, x2 - x1), mode=self.interpolation, align_corners=False) - mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :] + mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :] - return batch_inputs, mixed_score + return batch_inputs, mixed_scores diff --git a/tests/test_models/test_utils/test_batch_augments.py b/tests/test_models/test_utils/test_batch_augments.py index 2a84b99c9..19f9ea5b3 100644 --- a/tests/test_models/test_utils/test_batch_augments.py +++ b/tests/test_models/test_utils/test_batch_augments.py @@ -9,13 +9,6 @@ from mmcls.core import ClsDataSample from mmcls.models import Mixup, RandomBatchAugment from mmcls.registry import BATCH_AUGMENTS -augment_cfgs = [ - dict(type='BatchCutMix', alpha=1., prob=1.), - dict(type='BatchMixup', alpha=1., prob=1.), - dict(type='Identity', prob=1.), - dict(type='BatchResizeMix', alpha=1., prob=1.) -] - class TestRandomBatchAugment(TestCase):