Imporve according to comments
parent
f0cab33e09
commit
375fe68f12
|
@ -21,7 +21,7 @@ class CutMix(Mixup):
|
|||
Args:
|
||||
alpha (float): Parameters for Beta distribution to generate the
|
||||
mixing ratio. It should be a positive number. More details
|
||||
can be found in :class:`BatchMixupLayer`.
|
||||
can be found in :class:`Mixup`.
|
||||
num_classes (int, optional): The number of classes. If not specified,
|
||||
will try to get it from data samples during training.
|
||||
Defaults to None.
|
||||
|
@ -124,8 +124,18 @@ class CutMix(Mixup):
|
|||
lam = 1. - bbox_area / float(img_shape[0] * img_shape[1])
|
||||
return (yl, yu, xl, xu), lam
|
||||
|
||||
def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor):
|
||||
"""Mix the batch inputs and batch one-hot format ground truth."""
|
||||
def mix(self, batch_inputs: torch.Tensor, batch_scores: torch.Tensor):
|
||||
"""Mix the batch inputs and batch one-hot format ground truth.
|
||||
|
||||
Args:
|
||||
batch_inputs (Tensor): A batch of images tensor in the shape of
|
||||
``(N, C, H, W)``.
|
||||
batch_scores (Tensor): A batch of one-hot format labels in the
|
||||
shape of ``(N, num_classes)``.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor): The mixed inputs and labels.
|
||||
"""
|
||||
lam = np.random.beta(self.alpha, self.alpha)
|
||||
batch_size = batch_inputs.size(0)
|
||||
img_shape = batch_inputs.shape[-2:]
|
||||
|
@ -133,6 +143,6 @@ class CutMix(Mixup):
|
|||
|
||||
(y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam)
|
||||
batch_inputs[:, :, y1:y2, x1:x2] = batch_inputs[index, :, y1:y2, x1:x2]
|
||||
mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :]
|
||||
mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :]
|
||||
|
||||
return batch_inputs, mixed_score
|
||||
return batch_inputs, mixed_scores
|
||||
|
|
|
@ -40,16 +40,26 @@ class Mixup:
|
|||
self.alpha = alpha
|
||||
self.num_classes = num_classes
|
||||
|
||||
def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor):
|
||||
"""Mix the batch inputs and batch one-hot format ground truth."""
|
||||
def mix(self, batch_inputs: torch.Tensor, batch_scores: torch.Tensor):
|
||||
"""Mix the batch inputs and batch one-hot format ground truth.
|
||||
|
||||
Args:
|
||||
batch_inputs (Tensor): A batch of images tensor in the shape of
|
||||
``(N, C, H, W)``.
|
||||
batch_scores (Tensor): A batch of one-hot format labels in the
|
||||
shape of ``(N, num_classes)``.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor): The mixed inputs and labels.
|
||||
"""
|
||||
lam = np.random.beta(self.alpha, self.alpha)
|
||||
batch_size = batch_inputs.size(0)
|
||||
index = torch.randperm(batch_size)
|
||||
|
||||
mixed_inputs = lam * batch_inputs + (1 - lam) * batch_inputs[index, :]
|
||||
mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :]
|
||||
mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :]
|
||||
|
||||
return mixed_inputs, mixed_score
|
||||
return mixed_inputs, mixed_scores
|
||||
|
||||
def __call__(self, batch_inputs: torch.Tensor,
|
||||
data_samples: List[ClsDataSample]):
|
||||
|
|
|
@ -70,8 +70,18 @@ class ResizeMix(CutMix):
|
|||
self.lam_max = lam_max
|
||||
self.interpolation = interpolation
|
||||
|
||||
def mix(self, batch_inputs: torch.Tensor, batch_score: torch.Tensor):
|
||||
"""Mix the batch inputs and batch one-hot format ground truth."""
|
||||
def mix(self, batch_inputs: torch.Tensor, batch_scores: torch.Tensor):
|
||||
"""Mix the batch inputs and batch one-hot format ground truth.
|
||||
|
||||
Args:
|
||||
batch_inputs (Tensor): A batch of images tensor in the shape of
|
||||
``(N, C, H, W)``.
|
||||
batch_scores (Tensor): A batch of one-hot format labels in the
|
||||
shape of ``(N, num_classes)``.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor): The mixed inputs and labels.
|
||||
"""
|
||||
lam = np.random.beta(self.alpha, self.alpha)
|
||||
lam = lam * (self.lam_max - self.lam_min) + self.lam_min
|
||||
img_shape = batch_inputs.shape[-2:]
|
||||
|
@ -84,6 +94,6 @@ class ResizeMix(CutMix):
|
|||
size=(y2 - y1, x2 - x1),
|
||||
mode=self.interpolation,
|
||||
align_corners=False)
|
||||
mixed_score = lam * batch_score + (1 - lam) * batch_score[index, :]
|
||||
mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :]
|
||||
|
||||
return batch_inputs, mixed_score
|
||||
return batch_inputs, mixed_scores
|
||||
|
|
|
@ -9,13 +9,6 @@ from mmcls.core import ClsDataSample
|
|||
from mmcls.models import Mixup, RandomBatchAugment
|
||||
from mmcls.registry import BATCH_AUGMENTS
|
||||
|
||||
augment_cfgs = [
|
||||
dict(type='BatchCutMix', alpha=1., prob=1.),
|
||||
dict(type='BatchMixup', alpha=1., prob=1.),
|
||||
dict(type='Identity', prob=1.),
|
||||
dict(type='BatchResizeMix', alpha=1., prob=1.)
|
||||
]
|
||||
|
||||
|
||||
class TestRandomBatchAugment(TestCase):
|
||||
|
||||
|
|
Loading…
Reference in New Issue