Improve according to comments

pull/913/head
mzr1996 2022-06-13 11:05:15 +08:00
parent f0cab33e09
commit 375fe68f12
4 changed files with 43 additions and 20 deletions

View File

@ -21,7 +21,7 @@ class CutMix(Mixup):
Args:
alpha (float): Parameters for Beta distribution to generate the
mixing ratio. It should be a positive number. More details
can be found in :class:`BatchMixupLayer`.
can be found in :class:`Mixup`.
num_classes (int, optional): The number of classes. If not specified,
will try to get it from data samples during training.
Defaults to None.
@ -124,8 +124,18 @@ class CutMix(Mixup):
lam = 1. - bbox_area / float(img_shape[0] * img_shape[1])
return (yl, yu, xl, xu), lam
def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor):
"""Mix the batch inputs and batch one-hot format ground truth."""
def mix(self, batch_inputs: torch.Tensor, batch_scores: torch.Tensor):
"""Mix the batch inputs and batch one-hot format ground truth.
Args:
batch_inputs (Tensor): A batch of images as a tensor in the shape of
``(N, C, H, W)``.
batch_scores (Tensor): A batch of one-hot format labels in the
shape of ``(N, num_classes)``.
Returns:
Tuple[Tensor, Tensor]: The mixed inputs and labels.
"""
lam = np.random.beta(self.alpha, self.alpha)
batch_size = batch_inputs.size(0)
img_shape = batch_inputs.shape[-2:]
@ -133,6 +143,6 @@ class CutMix(Mixup):
(y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam)
batch_inputs[:, :, y1:y2, x1:x2] = batch_inputs[index, :, y1:y2, x1:x2]
mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :]
mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :]
return batch_inputs, mixed_score
return batch_inputs, mixed_scores

View File

@ -40,16 +40,26 @@ class Mixup:
self.alpha = alpha
self.num_classes = num_classes
def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor):
"""Mix the batch inputs and batch one-hot format ground truth."""
def mix(self, batch_inputs: torch.Tensor, batch_scores: torch.Tensor):
"""Mix the batch inputs and batch one-hot format ground truth.
Args:
batch_inputs (Tensor): A batch of images as a tensor in the shape of
``(N, C, H, W)``.
batch_scores (Tensor): A batch of one-hot format labels in the
shape of ``(N, num_classes)``.
Returns:
Tuple[Tensor, Tensor]: The mixed inputs and labels.
"""
lam = np.random.beta(self.alpha, self.alpha)
batch_size = batch_inputs.size(0)
index = torch.randperm(batch_size)
mixed_inputs = lam * batch_inputs + (1 - lam) * batch_inputs[index, :]
mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :]
mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :]
return mixed_inputs, mixed_score
return mixed_inputs, mixed_scores
def __call__(self, batch_inputs: torch.Tensor,
data_samples: List[ClsDataSample]):

View File

@ -70,8 +70,18 @@ class ResizeMix(CutMix):
self.lam_max = lam_max
self.interpolation = interpolation
def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor):
"""Mix the batch inputs and batch one-hot format ground truth."""
def mix(self, batch_inputs: torch.Tensor, batch_scores: torch.Tensor):
"""Mix the batch inputs and batch one-hot format ground truth.
Args:
batch_inputs (Tensor): A batch of images as a tensor in the shape of
``(N, C, H, W)``.
batch_scores (Tensor): A batch of one-hot format labels in the
shape of ``(N, num_classes)``.
Returns:
Tuple[Tensor, Tensor]: The mixed inputs and labels.
"""
lam = np.random.beta(self.alpha, self.alpha)
lam = lam * (self.lam_max - self.lam_min) + self.lam_min
img_shape = batch_inputs.shape[-2:]
@ -84,6 +94,6 @@ class ResizeMix(CutMix):
size=(y2 - y1, x2 - x1),
mode=self.interpolation,
align_corners=False)
mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :]
mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :]
return batch_inputs, mixed_score
return batch_inputs, mixed_scores

View File

@ -9,13 +9,6 @@ from mmcls.core import ClsDataSample
from mmcls.models import Mixup, RandomBatchAugment
from mmcls.registry import BATCH_AUGMENTS
augment_cfgs = [
dict(type='BatchCutMix', alpha=1., prob=1.),
dict(type='BatchMixup', alpha=1., prob=1.),
dict(type='Identity', prob=1.),
dict(type='BatchResizeMix', alpha=1., prob=1.)
]
class TestRandomBatchAugment(TestCase):