mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Some cutmix/mixup cleanup/fixes
This commit is contained in:
parent
b3cb5f3275
commit
670c61b28f
@ -14,6 +14,7 @@ Hacked together by Ross Wightman
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
|
import numbers
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
|
|
||||||
|
|
||||||
@ -49,9 +50,17 @@ def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disab
|
|||||||
return input, target
|
return input, target
|
||||||
|
|
||||||
|
|
||||||
|
def calc_ratio(lam, minmax=None):
|
||||||
|
ratio = math.sqrt(1 - lam)
|
||||||
|
if minmax is not None:
|
||||||
|
if isinstance(minmax, numbers.Number):
|
||||||
|
minmax = (minmax, 1 - minmax)
|
||||||
|
ratio = np.clip(ratio, minmax[0], minmax[1])
|
||||||
|
return ratio
|
||||||
|
|
||||||
|
|
||||||
def rand_bbox(size, ratio):
|
def rand_bbox(size, ratio):
|
||||||
H, W = size[-2:]
|
H, W = size[-2:]
|
||||||
ratio = max(min(ratio, 0.8), 0.2)
|
|
||||||
cut_h, cut_w = int(H * ratio), int(W * ratio)
|
cut_h, cut_w = int(H * ratio), int(W * ratio)
|
||||||
cy, cx = np.random.randint(H), np.random.randint(W)
|
cy, cx = np.random.randint(H), np.random.randint(W)
|
||||||
yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
|
yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
|
||||||
@ -59,14 +68,15 @@ def rand_bbox(size, ratio):
|
|||||||
return yl, yh, xl, xh
|
return yl, yh, xl, xh
|
||||||
|
|
||||||
|
|
||||||
def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
|
def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, correct_lam=False):
|
||||||
lam = 1.
|
lam = 1.
|
||||||
if not disable:
|
if not disable:
|
||||||
lam = np.random.beta(alpha, alpha)
|
lam = np.random.beta(alpha, alpha)
|
||||||
if lam != 1:
|
if lam != 1:
|
||||||
ratio = math.sqrt(1. - lam)
|
yl, yh, xl, xh = rand_bbox(input.size(), calc_ratio(lam))
|
||||||
yl, yh, xl, xh = rand_bbox(input.size(), ratio)
|
|
||||||
input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
|
input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
|
||||||
|
if correct_lam:
|
||||||
|
lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1])
|
||||||
target = mixup_target(target, num_classes, lam, smoothing)
|
target = mixup_target(target, num_classes, lam, smoothing)
|
||||||
return input, target
|
return input, target
|
||||||
|
|
||||||
@ -82,9 +92,9 @@ def mix_batch(
|
|||||||
input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP):
|
input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP):
|
||||||
mode = _resolve_mode(mode)
|
mode = _resolve_mode(mode)
|
||||||
if mode == MixupMode.CUTMIX:
|
if mode == MixupMode.CUTMIX:
|
||||||
return mixup_batch(input, target, alpha, num_classes, smoothing, disable)
|
|
||||||
else:
|
|
||||||
return cutmix_batch(input, target, alpha, num_classes, smoothing, disable)
|
return cutmix_batch(input, target, alpha, num_classes, smoothing, disable)
|
||||||
|
else:
|
||||||
|
return mixup_batch(input, target, alpha, num_classes, smoothing, disable)
|
||||||
|
|
||||||
|
|
||||||
class FastCollateMixup:
|
class FastCollateMixup:
|
||||||
@ -99,6 +109,7 @@ class FastCollateMixup:
|
|||||||
self.mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
|
self.mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
|
||||||
self.mixup_enabled = True
|
self.mixup_enabled = True
|
||||||
self.correct_lam = False # correct lambda based on clipped area for cutmix
|
self.correct_lam = False # correct lambda based on clipped area for cutmix
|
||||||
|
self.ratio_minmax = None # (0.2, 0.8)
|
||||||
|
|
||||||
def _do_mix(self, tensor, batch):
|
def _do_mix(self, tensor, batch):
|
||||||
batch_size = len(batch)
|
batch_size = len(batch)
|
||||||
@ -111,7 +122,7 @@ class FastCollateMixup:
|
|||||||
|
|
||||||
if _resolve_mode(self.mode) == MixupMode.CUTMIX:
|
if _resolve_mode(self.mode) == MixupMode.CUTMIX:
|
||||||
mixed_i, mixed_j = batch[i][0].astype(np.float32), batch[j][0].astype(np.float32)
|
mixed_i, mixed_j = batch[i][0].astype(np.float32), batch[j][0].astype(np.float32)
|
||||||
ratio = math.sqrt(1. - lam)
|
ratio = calc_ratio(lam, self.ratio_minmax)
|
||||||
if lam != 1:
|
if lam != 1:
|
||||||
yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
|
yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
|
||||||
mixed_i[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
|
mixed_i[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
|
||||||
@ -132,7 +143,7 @@ class FastCollateMixup:
|
|||||||
np.round(mixed_j, out=mixed_j)
|
np.round(mixed_j, out=mixed_j)
|
||||||
tensor[i] += torch.from_numpy(mixed_i.astype(np.uint8))
|
tensor[i] += torch.from_numpy(mixed_i.astype(np.uint8))
|
||||||
tensor[j] += torch.from_numpy(mixed_j.astype(np.uint8))
|
tensor[j] += torch.from_numpy(mixed_j.astype(np.uint8))
|
||||||
return lam_out
|
return lam_out.unsqueeze(1)
|
||||||
|
|
||||||
def __call__(self, batch):
|
def __call__(self, batch):
|
||||||
batch_size = len(batch)
|
batch_size = len(batch)
|
||||||
@ -140,7 +151,7 @@ class FastCollateMixup:
|
|||||||
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
||||||
lam = self._do_mix(tensor, batch)
|
lam = self._do_mix(tensor, batch)
|
||||||
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
||||||
target = mixup_target(target, self.num_classes, lam.unsqueeze(1), self.label_smoothing, device='cpu')
|
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
|
||||||
|
|
||||||
return tensor, target
|
return tensor, target
|
||||||
|
|
||||||
@ -157,27 +168,27 @@ class FastCollateMixupElementwise(FastCollateMixup):
|
|||||||
batch_size = len(batch)
|
batch_size = len(batch)
|
||||||
lam_out = torch.ones(batch_size)
|
lam_out = torch.ones(batch_size)
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
|
j = batch_size - i - 1
|
||||||
lam = 1.
|
lam = 1.
|
||||||
if self.mixup_enabled:
|
if self.mixup_enabled:
|
||||||
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
|
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
|
||||||
|
|
||||||
if _resolve_mode(self.mode) == MixupMode.CUTMIX:
|
if _resolve_mode(self.mode) == MixupMode.CUTMIX:
|
||||||
mixed = batch[i][0].astype(np.float32)
|
mixed = batch[i][0].astype(np.float32)
|
||||||
ratio = math.sqrt(1. - lam)
|
|
||||||
if lam != 1:
|
if lam != 1:
|
||||||
|
ratio = calc_ratio(lam)
|
||||||
yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
|
yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
|
||||||
mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32)
|
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
|
||||||
if self.correct_lam:
|
if self.correct_lam:
|
||||||
lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
|
lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
|
||||||
else:
|
else:
|
||||||
lam_out[i] = lam
|
lam_out[i] = lam
|
||||||
else:
|
else:
|
||||||
mixed = batch[i][0].astype(np.float32) * lam + \
|
mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
|
||||||
batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
|
|
||||||
lam_out[i] = lam
|
lam_out[i] = lam
|
||||||
np.round(mixed, out=mixed)
|
np.round(mixed, out=mixed)
|
||||||
tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
|
tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
|
||||||
return lam_out
|
return lam_out.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
class FastCollateMixupBatchwise(FastCollateMixup):
|
class FastCollateMixupBatchwise(FastCollateMixup):
|
||||||
@ -191,25 +202,23 @@ class FastCollateMixupBatchwise(FastCollateMixup):
|
|||||||
|
|
||||||
def _do_mix(self, tensor, batch):
|
def _do_mix(self, tensor, batch):
|
||||||
batch_size = len(batch)
|
batch_size = len(batch)
|
||||||
lam_out = torch.ones(batch_size)
|
|
||||||
lam = 1.
|
lam = 1.
|
||||||
cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX
|
cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX
|
||||||
if self.mixup_enabled:
|
if self.mixup_enabled:
|
||||||
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
|
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
|
||||||
if cutmix and self.correct_lam:
|
if cutmix:
|
||||||
ratio = math.sqrt(1. - lam)
|
yl, yh, xl, xh = rand_bbox(batch[0][0].shape, calc_ratio(lam))
|
||||||
yl, yh, xl, xh = rand_bbox(batch[0][0].shape, ratio)
|
if self.correct_lam:
|
||||||
lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
|
lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
|
j = batch_size - i - 1
|
||||||
if cutmix:
|
if cutmix:
|
||||||
mixed = batch[i][0].astype(np.float32)
|
mixed = batch[i][0].astype(np.float32)
|
||||||
if lam != 1:
|
if lam != 1:
|
||||||
mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32)
|
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
|
||||||
lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
|
|
||||||
else:
|
else:
|
||||||
mixed = batch[i][0].astype(np.float32) * lam + \
|
mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
|
||||||
batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
|
|
||||||
np.round(mixed, out=mixed)
|
np.round(mixed, out=mixed)
|
||||||
tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
|
tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
|
||||||
return lam
|
return lam
|
||||||
|
Loading…
x
Reference in New Issue
Block a user