pytorch-image-models/timm/data/naflex_mixup.py

234 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Variablesize Mixup / CutMix utilities for NaFlex data loaders.
This module provides:
* `mix_batch_variable_size` pixellevel Mixup/CutMix that operates on a
list of images whose spatial sizes differ, mixing only their central overlap
so no resizing is required.
* `pairwise_mixup_target` builds softlabel targets that exactly match the
persample pixel provenance produced by the mixer.
* `NaFlexMixup` a callable functor that wraps the two helpers and stores
all augmentation hyperparameters in one place, making it easy to plug into
different dataset wrappers.
"""
import math
import random
from typing import Dict, List, Tuple
import torch
def mix_batch_variable_size(
        imgs: List[torch.Tensor],
        *,
        mixup_alpha: float = 0.8,
        cutmix_alpha: float = 1.0,
        switch_prob: float = 0.5,
        local_shuffle: int = 4,
) -> Tuple[List[torch.Tensor], List[float], Dict[int, int], bool]:
    """Apply Mixup or CutMix on a batch of variable-sized images.

    The function first sorts images by aspect ratio (width / height) and pairs
    neighbouring samples, optionally shuffling within small windows so pairs
    vary between calls. Only the mutual central-overlap region of each pair is
    mixed, so no resizing is required. One mode (Mixup or CutMix) and one raw
    lambda are drawn per call and shared by every pair.

    Args:
        imgs: List of transformed images shaped (C, H, W). Heights and
            widths may differ between samples. Inputs are not modified.
        mixup_alpha: Beta-distribution alpha for Mixup. Set to 0 to disable Mixup.
        cutmix_alpha: Beta-distribution alpha for CutMix. Set to 0 to disable CutMix.
        switch_prob: Probability of using CutMix when both Mixup and CutMix are enabled.
        local_shuffle: Size of the local windows that are randomly shuffled after
            aspect sorting. Values of 0 or 1 turn shuffling off.

    Returns:
        mixed_imgs: List of mixed images, same order and shapes as ``imgs``.
            An unmatched odd sample is returned as the original tensor object.
        lam_list: Per-sample fraction of the sample's own pixels that remain
            after mixing (1.0 for an unmatched odd sample).
        pair_to: Mapping i -> j describing which sample was mixed with which
            (absent for an unmatched odd sample).
        use_cutmix: True if CutMix was used for this call, False if Mixup was used.

    Raises:
        ValueError: If fewer than two images are given, or if both
            ``mixup_alpha`` and ``cutmix_alpha`` are zero.
    """
    if len(imgs) < 2:
        raise ValueError("Need at least two images to perform Mixup/CutMix.")

    # Decide augmentation mode and the Beta concentration used for raw lambda.
    if mixup_alpha > 0.0 and cutmix_alpha > 0.0:
        use_cutmix = torch.rand(()).item() < switch_prob
        alpha = cutmix_alpha if use_cutmix else mixup_alpha
    elif mixup_alpha > 0.0:
        use_cutmix = False
        alpha = mixup_alpha
    elif cutmix_alpha > 0.0:
        use_cutmix = True
        alpha = cutmix_alpha
    else:
        raise ValueError("Both mixup_alpha and cutmix_alpha are zero nothing to do.")

    lam_raw = torch.distributions.Beta(alpha, alpha).sample().item()
    lam_raw = max(0.0, min(1.0, lam_raw))  # numerical safety

    # Pair images by nearest aspect ratio (W / H).
    order = sorted(range(len(imgs)), key=lambda idx: imgs[idx].shape[2] / imgs[idx].shape[1])
    if local_shuffle > 1:
        for start in range(0, len(order), local_shuffle):
            # Slicing a list yields a copy, so the shuffled window has to be
            # written back; shuffling the temporary slice alone is a no-op.
            window = order[start:start + local_shuffle]
            random.shuffle(window)
            order[start:start + local_shuffle] = window

    pair_to: Dict[int, int] = {}
    for a, b in zip(order[::2], order[1::2]):
        pair_to[a] = b
        pair_to[b] = a

    # Start from the originals; an unmatched odd sample (the last entry of
    # `order` for odd batch sizes) simply keeps its input tensor and lambda 1.
    mixed_imgs: List[torch.Tensor] = list(imgs)
    lam_list: List[float] = [1.0] * len(imgs)

    for i, src in enumerate(imgs):
        j = pair_to.get(i)
        if j is None:
            continue

        xj = imgs[j]
        _, hi, wi = src.shape
        _, hj, wj = xj.shape
        dest_area = hi * wi

        # Central overlap common to both images.
        oh, ow = min(hi, hj), min(wi, wj)
        overlap_area = oh * ow
        top_i, left_i = (hi - oh) // 2, (wi - ow) // 2
        top_j, left_j = (hj - oh) // 2, (wj - ow) // 2

        xi = src.clone()  # never write into the caller's tensor
        if use_cutmix:
            # CutMix: paste a random rectangle that lies inside the overlap.
            cut_ratio = math.sqrt(1.0 - lam_raw)
            ch, cw = int(oh * cut_ratio), int(ow * cut_ratio)
            cut_area = ch * cw
            y_off = random.randint(0, oh - ch)
            x_off = random.randint(0, ow - cw)
            yl_i, xl_i = top_i + y_off, left_i + x_off
            yl_j, xl_j = top_j + y_off, left_j + x_off
            xi[:, yl_i:yl_i + ch, xl_i:xl_i + cw] = xj[:, yl_j:yl_j + ch, xl_j:xl_j + cw]
            # Integer truncation of the box means the realised ratio differs
            # from lam_raw; report the fraction of pixels actually kept.
            corrected_lam = 1.0 - cut_area / float(dest_area)
        else:
            # Mixup: blend the entire overlap region, leave the border untouched.
            patch_i = xi[:, top_i:top_i + oh, left_i:left_i + ow]
            patch_j = xj[:, top_j:top_j + oh, left_j:left_j + ow]
            blended = patch_i.mul(lam_raw).add_(patch_j, alpha=1.0 - lam_raw)
            xi[:, top_i:top_i + oh, left_i:left_i + ow] = blended
            # Pixels outside the overlap are 100% self; inside they are lam_raw self.
            corrected_lam = (dest_area - overlap_area) / dest_area + lam_raw * overlap_area / dest_area

        mixed_imgs[i] = xi
        lam_list[i] = corrected_lam

    return mixed_imgs, lam_list, pair_to, use_cutmix
def pairwise_mixup_target(
        labels: torch.Tensor,
        pair_to: Dict[int, int],
        lam_list: List[float],
        *,
        num_classes: int,
        smoothing: float = 0.0,
) -> torch.Tensor:
    """Build soft targets consistent with the pixel-level mixing.

    Each row starts as the (optionally label-smoothed) one-hot encoding of its
    own label. Rows listed in ``pair_to`` are then blended with the partner's
    encoding, weighted by the sample's own-pixel fraction.

    Args:
        labels: (B,) tensor of integer class indices.
        pair_to: Mapping of sample index to its mixed partner, as returned by
            mix_batch_variable_size().
        lam_list: Per-sample fractions of self pixels, also from the mixer.
        num_classes: Total number of classes in the dataset.
        smoothing: Label-smoothing value in the range [0, 1).

    Returns:
        Float32 tensor of shape (B, num_classes) on the device of ``labels``.
    """
    batch_size = labels.size(0)

    # Smoothed one-hot values: every class gets `low`, the true class `high`.
    low = smoothing / num_classes
    high = 1.0 - smoothing + low

    base = torch.full(
        (batch_size, num_classes),
        low,
        dtype=torch.float32,
        device=labels.device,
    )
    base.scatter_(1, labels[:, None], high)

    # Blend into a copy so partner rows are always read in their unmixed form.
    mixed = base.clone()
    for idx, partner in pair_to.items():
        keep = lam_list[idx]
        row = mixed[idx]  # view into `mixed`; in-place ops update it directly
        row.mul_(keep)
        row.add_(base[partner], alpha=1.0 - keep)
    return mixed
class NaFlexMixup:
    """Callable wrapper that combines mixing and target generation.

    Stores the augmentation hyperparameters and, when called, runs
    mix_batch_variable_size() followed by pairwise_mixup_target().
    """

    def __init__(
            self,
            *,
            num_classes: int,
            mixup_alpha: float = 0.8,
            cutmix_alpha: float = 1.0,
            switch_prob: float = 0.5,
            local_shuffle: int = 4,
            smoothing: float = 0.0,
    ) -> None:
        """Configure the augmentation.

        Args:
            num_classes: Total number of classes.
            mixup_alpha: Beta alpha for Mixup. 0 disables Mixup.
            cutmix_alpha: Beta alpha for CutMix. 0 disables CutMix.
            switch_prob: Probability of selecting CutMix when both modes are enabled.
            local_shuffle: Window size used to shuffle images after aspect sorting
                so pairings vary between calls.
            smoothing: Label-smoothing value. 0 disables smoothing.
        """
        self.num_classes = num_classes
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.switch_prob = switch_prob
        self.local_shuffle = local_shuffle
        self.smoothing = smoothing

    def __call__(
            self,
            imgs: List[torch.Tensor],
            labels: torch.Tensor,
    ) -> Tuple[List[torch.Tensor], Tuple[torch.Tensor, ...]]:
        """Apply the augmentation and generate matching targets.

        Args:
            imgs: List of already-transformed images shaped (C, H, W).
            labels: Hard labels with shape (B,). A plain list or tuple of
                class indices is also accepted and converted to a tensor.

        Returns:
            mixed_imgs: List of mixed images in the same order and shapes as the input.
            targets: Tuple of B soft-label tensors, one per sample, each of
                shape (num_classes,). These are the rows of the
                (B, num_classes) target matrix, split with ``unbind(0)``.

        Raises:
            ValueError: Propagated from mix_batch_variable_size() when fewer
                than two images are given or both alphas are zero.
        """
        if isinstance(labels, (list, tuple)):
            labels = torch.tensor(labels)

        # The chosen mode (Mixup vs CutMix) is not needed for target building.
        mixed_imgs, lam_list, pair_to, _ = mix_batch_variable_size(
            imgs,
            mixup_alpha=self.mixup_alpha,
            cutmix_alpha=self.cutmix_alpha,
            switch_prob=self.switch_prob,
            local_shuffle=self.local_shuffle,
        )

        targets = pairwise_mixup_target(
            labels,
            pair_to,
            lam_list,
            num_classes=self.num_classes,
            smoothing=self.smoothing,
        )

        # Hand back per-sample rows rather than one stacked tensor, mirroring
        # the per-sample list structure of `mixed_imgs`.
        return mixed_imgs, targets.unbind(0)