More flexible mixup mode, add 'half' mode.
parent
532e3b417d
commit
47a7b3b5b1
|
@ -96,13 +96,13 @@ class Mixup:
|
|||
cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
|
||||
prob (float): probability of applying mixup or cutmix per batch or element
|
||||
switch_prob (float): probability of switching to cutmix instead of mixup when both are active
|
||||
elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
|
||||
mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
|
||||
correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
|
||||
label_smoothing (float): apply label smoothing to the mixed target tensor
|
||||
num_classes (int): number of classes for target
|
||||
"""
|
||||
def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
|
||||
elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
|
||||
mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
|
||||
self.mixup_alpha = mixup_alpha
|
||||
self.cutmix_alpha = cutmix_alpha
|
||||
self.cutmix_minmax = cutmix_minmax
|
||||
|
@ -114,7 +114,7 @@ class Mixup:
|
|||
self.switch_prob = switch_prob
|
||||
self.label_smoothing = label_smoothing
|
||||
self.num_classes = num_classes
|
||||
self.elementwise = elementwise
|
||||
self.mode = mode
|
||||
self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
|
||||
self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop)
|
||||
|
||||
|
@ -173,6 +173,26 @@ class Mixup:
|
|||
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
|
||||
return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
|
||||
def _mix_pair(self, x):
|
||||
batch_size = len(x)
|
||||
lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
|
||||
x_orig = x.clone() # need to keep an unmodified original for mixing source
|
||||
for i in range(batch_size // 2):
|
||||
j = batch_size - i - 1
|
||||
lam = lam_batch[i]
|
||||
if lam != 1.:
|
||||
if use_cutmix[i]:
|
||||
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
|
||||
x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
|
||||
x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
|
||||
x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
|
||||
lam_batch[i] = lam
|
||||
else:
|
||||
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
|
||||
x[j] = x[j] * lam + x_orig[i] * (1 - lam)
|
||||
lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
|
||||
return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
|
||||
def _mix_batch(self, x):
|
||||
lam, use_cutmix = self._params_per_batch()
|
||||
if lam == 1.:
|
||||
|
@ -188,7 +208,12 @@ class Mixup:
|
|||
|
||||
def __call__(self, x, target):
|
||||
assert len(x) % 2 == 0, 'Batch size should be even when using this'
|
||||
lam = self._mix_elem(x) if self.elementwise else self._mix_batch(x)
|
||||
if self.mode == 'elem':
|
||||
lam = self._mix_elem(x)
|
||||
elif self.mode == 'pair':
|
||||
lam = self._mix_pair(x)
|
||||
else:
|
||||
lam = self._mix_batch(x)
|
||||
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
|
||||
return x, target
|
||||
|
||||
|
@ -199,25 +224,57 @@ class FastCollateMixup(Mixup):
|
|||
A Mixup impl that's performed while collating the batches.
|
||||
"""
|
||||
|
||||
def _mix_elem_collate(self, output, batch):
|
||||
def _mix_elem_collate(self, output, batch, half=False):
|
||||
batch_size = len(batch)
|
||||
lam_batch, use_cutmix = self._params_per_elem(batch_size)
|
||||
for i in range(batch_size):
|
||||
num_elem = batch_size // 2 if half else batch_size
|
||||
assert len(output) == num_elem
|
||||
lam_batch, use_cutmix = self._params_per_elem(num_elem)
|
||||
for i in range(num_elem):
|
||||
j = batch_size - i - 1
|
||||
lam = lam_batch[i]
|
||||
mixed = batch[i][0]
|
||||
if lam != 1.:
|
||||
if use_cutmix[i]:
|
||||
mixed = mixed.copy()
|
||||
if not half:
|
||||
mixed = mixed.copy()
|
||||
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
|
||||
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
|
||||
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
|
||||
lam_batch[i] = lam
|
||||
else:
|
||||
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
|
||||
lam_batch[i] = lam
|
||||
np.round(mixed, out=mixed)
|
||||
np.rint(mixed, out=mixed)
|
||||
output[i] += torch.from_numpy(mixed.astype(np.uint8))
|
||||
if half:
|
||||
lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
|
||||
return torch.tensor(lam_batch).unsqueeze(1)
|
||||
|
||||
def _mix_pair_collate(self, output, batch):
|
||||
batch_size = len(batch)
|
||||
lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
|
||||
for i in range(batch_size // 2):
|
||||
j = batch_size - i - 1
|
||||
lam = lam_batch[i]
|
||||
mixed_i = batch[i][0]
|
||||
mixed_j = batch[j][0]
|
||||
assert 0 <= lam <= 1.0
|
||||
if lam < 1.:
|
||||
if use_cutmix[i]:
|
||||
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
|
||||
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
|
||||
patch_i = mixed_i[:, yl:yh, xl:xh].copy()
|
||||
mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
|
||||
mixed_j[:, yl:yh, xl:xh] = patch_i
|
||||
lam_batch[i] = lam
|
||||
else:
|
||||
mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
|
||||
mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
|
||||
mixed_i = mixed_temp
|
||||
np.rint(mixed_j, out=mixed_j)
|
||||
np.rint(mixed_i, out=mixed_i)
|
||||
output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
|
||||
output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
|
||||
lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
|
||||
return torch.tensor(lam_batch).unsqueeze(1)
|
||||
|
||||
def _mix_batch_collate(self, output, batch):
|
||||
|
@ -235,19 +292,25 @@ class FastCollateMixup(Mixup):
|
|||
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
|
||||
else:
|
||||
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
|
||||
np.round(mixed, out=mixed)
|
||||
np.rint(mixed, out=mixed)
|
||||
output[i] += torch.from_numpy(mixed.astype(np.uint8))
|
||||
return lam
|
||||
|
||||
def __call__(self, batch, _=None):
|
||||
batch_size = len(batch)
|
||||
assert batch_size % 2 == 0, 'Batch size should be even when using this'
|
||||
half = 'half' in self.mode
|
||||
if half:
|
||||
batch_size //= 2
|
||||
output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
||||
if self.elementwise:
|
||||
lam = self._mix_elem_collate(output, batch)
|
||||
if self.mode == 'elem' or self.mode == 'half':
|
||||
lam = self._mix_elem_collate(output, batch, half=half)
|
||||
elif self.mode == 'pair':
|
||||
lam = self._mix_pair_collate(output, batch)
|
||||
else:
|
||||
lam = self._mix_batch_collate(output, batch)
|
||||
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
||||
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
|
||||
target = target[:batch_size]
|
||||
return output, target
|
||||
|
||||
|
|
6
train.py
6
train.py
|
@ -176,8 +176,8 @@ parser.add_argument('--mixup-prob', type=float, default=1.0,
|
|||
help='Probability of performing mixup or cutmix when either/both is enabled')
|
||||
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
|
||||
help='Probability of switching to cutmix when both mixup and cutmix enabled')
|
||||
parser.add_argument('--mixup-elem', action='store_true', default=False,
|
||||
help='Apply mixup/cutmix params uniquely per batch element instead of per batch.')
|
||||
parser.add_argument('--mixup-mode', type=str, default='batch',
|
||||
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
|
||||
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
|
||||
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
|
||||
parser.add_argument('--smoothing', type=float, default=0.1,
|
||||
|
@ -444,7 +444,7 @@ def main():
|
|||
if mixup_active:
|
||||
mixup_args = dict(
|
||||
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
|
||||
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, elementwise=args.mixup_elem,
|
||||
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
|
||||
label_smoothing=args.smoothing, num_classes=args.num_classes)
|
||||
if args.prefetcher:
|
||||
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
|
||||
|
|
Loading…
Reference in New Issue