From 8fcbceb609083f11abf6646fdf567308dedf184e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 10 May 2025 14:59:37 -0700 Subject: [PATCH] Add a WIP NaFlex compatible mixup/cutmix for testing --- timm/data/naflex_mixup.py | 233 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 233 insertions(+) create mode 100644 timm/data/naflex_mixup.py diff --git a/timm/data/naflex_mixup.py b/timm/data/naflex_mixup.py new file mode 100644 index 00000000..28bf7a60 --- /dev/null +++ b/timm/data/naflex_mixup.py @@ -0,0 +1,233 @@ +"""Variable‑size Mixup / CutMix utilities for NaFlex data loaders. + +This module provides: + +* `mix_batch_variable_size` – pixel‑level Mixup/CutMix that operates on a + list of images whose spatial sizes differ, mixing only their central overlap + so no resizing is required. +* `pairwise_mixup_target` – builds soft‑label targets that exactly match the + per‑sample pixel provenance produced by the mixer. +* `NaFlexMixup` – a callable functor that wraps the two helpers and stores + all augmentation hyper‑parameters in one place, making it easy to plug into + different dataset wrappers. + +""" +import math +import random +from typing import Dict, List, Tuple + +import torch + + +def mix_batch_variable_size( + imgs: List[torch.Tensor], + *, + mixup_alpha: float = 0.8, + cutmix_alpha: float = 1.0, + switch_prob: float = 0.5, + local_shuffle: int = 4, +) -> Tuple[List[torch.Tensor], List[float], Dict[int, int], bool]: + """Apply Mixup or CutMix on a batch of variable‑sized images. + + The function first sorts images by aspect ratio and pairs neighbouring + samples (optionally shuffling within small windows so pairs vary between + epochs). Only the mutual central‑overlap region of each pair is mixed + + Args: + imgs: List of transformed images shaped (C, H, W). Heights and + widths may differ between samples. + mixup_alpha: Beta‑distribution *α* for Mixup. Set to 0 to disable Mixup. + cutmix_alpha: Beta‑distribution *α* for CutMix. Set to 0 to disable CutMix. + switch_prob: Probability of using CutMix when both Mixup and CutMix are enabled. + local_shuffle: Size of local windows that are randomly shuffled after aspect sorting. + A value of 0 turns shuffling off. + + Returns: + mixed_imgs: List of mixed images. + lam_list: Per‑sample lambda values representing the degree of mixing. + pair_to: Mapping i -> j describing which sample was mixed with which (absent for unmatched odd sample). + use_cutmix: True if CutMix was used for this call, False if Mixup was used. + """ + if len(imgs) < 2: + raise ValueError("Need at least two images to perform Mixup/CutMix.") + + # Decide augmentation mode and raw λ + if mixup_alpha > 0.0 and cutmix_alpha > 0.0: + use_cutmix = torch.rand(()).item() < switch_prob + alpha = cutmix_alpha if use_cutmix else mixup_alpha + elif mixup_alpha > 0.0: + use_cutmix = False + alpha = mixup_alpha + elif cutmix_alpha > 0.0: + use_cutmix = True + alpha = cutmix_alpha + else: + raise ValueError("Both mixup_alpha and cutmix_alpha are zero – nothing to do.") + + lam_raw = torch.distributions.Beta(alpha, alpha).sample().item() + lam_raw = max(0.0, min(1.0, lam_raw)) # numerical safety + + # Pair images by nearest aspect ratio + order = sorted(range(len(imgs)), key=lambda i: imgs[i].shape[2] / imgs[i].shape[1]) + if local_shuffle > 1: + for start in range(0, len(order), local_shuffle): + random.shuffle(order[start: start + local_shuffle]) + + pair_to: Dict[int, int] = {} + for a, b in zip(order[::2], order[1::2]): + pair_to[a] = b + pair_to[b] = a + + odd_one = order[-1] if len(imgs) % 2 else None + + mixed_imgs: List[torch.Tensor] = [None] * len(imgs) + lam_list: List[float] = [1.0] * len(imgs) + + for i in range(len(imgs)): + if i == odd_one: + mixed_imgs[i] = imgs[i] + continue + + j = pair_to[i] + xi, xj = imgs[i], imgs[j] + _, hi, wi = xi.shape + _, hj, wj = xj.shape + dest_area = hi * wi + + # Central overlap common to both images + oh, ow = min(hi, hj), min(wi, wj) + overlap_area = oh * ow + top_i, left_i = (hi - oh) // 2, (wi - ow) // 2 + top_j, left_j = (hj - oh) // 2, (wj - ow) // 2 + + xi = xi.clone() + if use_cutmix: + # CutMix: random rectangle inside the overlap + cut_ratio = math.sqrt(1.0 - lam_raw) + ch, cw = int(oh * cut_ratio), int(ow * cut_ratio) + cut_area = ch * cw + y_off = random.randint(0, oh - ch) + x_off = random.randint(0, ow - cw) + + yl_i, xl_i = top_i + y_off, left_i + x_off + yl_j, xl_j = top_j + y_off, left_j + x_off + xi[:, yl_i: yl_i + ch, xl_i: xl_i + cw] = xj[:, yl_j: yl_j + ch, xl_j: xl_j + cw] + mixed_imgs[i] = xi + + corrected_lam = 1.0 - cut_area / float(dest_area) + lam_list[i] = corrected_lam + #print(i, 'Doing cutmix', yl_i, xl_i, yl_j, xl_j, ch, cw, lam_raw, corrected_lam) + else: + # Mixup: blend the entire overlap region + patch_i = xi[:, top_i: top_i + oh, left_i: left_i + ow] + patch_j = xj[:, top_j: top_j + oh, left_j: left_j + ow] + + blended = patch_i.mul(lam_raw).add_(patch_j, alpha=1.0 - lam_raw) + xi[:, top_i: top_i + oh, left_i: left_i + ow] = blended + mixed_imgs[i] = xi + + corrected_lam = (dest_area - overlap_area) / dest_area + lam_raw * overlap_area / dest_area + lam_list[i] = corrected_lam + #print(i, 'Doing mixup', top_i, left_i, top_j, left_j, (oh, ow), (hi, wi), (hj, wj), lam_raw, corrected_lam) + + return mixed_imgs, lam_list, pair_to, use_cutmix + + +def pairwise_mixup_target( + labels: torch.Tensor, + pair_to: Dict[int, int], + lam_list: List[float], + *, + num_classes: int, + smoothing: float = 0.0, +) -> torch.Tensor: + """Create soft targets that match the pixel‑level mixing performed. + + Args: + labels: (B,) tensor of integer class indices. + pair_to: Mapping of sample index to its mixed partner as returned by mix_batch_variable_size(). + lam_list: Per‑sample fractions of self pixels, also from the mixer. + num_classes: Total number of classes in the dataset. + smoothing: Label‑smoothing value in the range [0, 1). + + Returns: + Tensor of shape (B, num_classes) whose rows sum to 1. + """ + off_val = smoothing / num_classes + on_val = 1.0 - smoothing + off_val + + y_onehot = torch.full((labels.size(0), num_classes), off_val, dtype=torch.float32, device=labels.device) + y_onehot.scatter_(1, labels.unsqueeze(1), on_val) + + targets = y_onehot.clone() + for i, j in pair_to.items(): + lam = lam_list[i] + targets[i].mul_(lam).add_(y_onehot[j], alpha=1.0 - lam) + + return targets + + +class NaFlexMixup: + """Callable wrapper that combines mixing and target generation.""" + + def __init__( + self, + *, + num_classes: int, + mixup_alpha: float = 0.8, + cutmix_alpha: float = 1.0, + switch_prob: float = 0.5, + local_shuffle: int = 4, + smoothing: float = 0.0, + ) -> None: + """Configure the augmentation. + + Args: + num_classes: Total number of classes. + mixup_alpha: Beta α for Mixup. 0 disables Mixup. + cutmix_alpha: Beta α for CutMix. 0 disables CutMix. + switch_prob: Probability of selecting CutMix when both modes are enabled. + local_shuffle: Window size used to shuffle images after aspect sorting so pairings vary between epochs. + smoothing: Label‑smoothing value. 0 disables smoothing. + """ + self.num_classes = num_classes + self.mixup_alpha = mixup_alpha + self.cutmix_alpha = cutmix_alpha + self.switch_prob = switch_prob + self.local_shuffle = local_shuffle + self.smoothing = smoothing + + def __call__( + self, + imgs: List[torch.Tensor], + labels: torch.Tensor, + ) -> Tuple[List[torch.Tensor], torch.Tensor]: + """Apply the augmentation and generate matching targets. + + Args: + imgs: List of already‑transformed images shaped (C, H, W). + labels: Hard labels with shape (B,). + + Returns: + mixed_imgs: List of mixed images in the same order and shapes as the input. + targets: Soft‑label tensor shaped (B, num_classes) suitable for cross‑entropy with soft targets. + """ + if isinstance(labels, (list, tuple)): + labels = torch.tensor(labels) + + mixed_imgs, lam_list, pair_to, _ = mix_batch_variable_size( + imgs, + mixup_alpha=self.mixup_alpha, + cutmix_alpha=self.cutmix_alpha, + switch_prob=self.switch_prob, + local_shuffle=self.local_shuffle, + ) + + targets = pairwise_mixup_target( + labels, + pair_to, + lam_list, + num_classes=self.num_classes, + smoothing=self.smoothing, + ) + return mixed_imgs, targets.unbind(0)