"""Variable-size Mixup / CutMix utilities for NaFlex data loaders.

This module provides:

* ``mix_batch_variable_size`` - pixel-level Mixup/CutMix that operates on a
  list of images whose spatial sizes differ, mixing only their central overlap
  so no resizing is required.
* ``pairwise_mixup_target`` - builds soft-label targets that exactly match the
  per-sample pixel provenance produced by the mixer.
* ``NaFlexMixup`` - a callable functor that wraps the two helpers and stores
  all augmentation hyper-parameters in one place, making it easy to plug into
  different dataset wrappers.
"""
import math
import random
from typing import Dict, List, Sequence, Tuple, Union

import torch


def _pair_by_aspect_ratio(
        imgs: List[torch.Tensor],
        local_shuffle: int,
) -> Dict[int, int]:
    """Pair each image with a neighbour of similar aspect ratio.

    Args:
        imgs: Images shaped (C, H, W).
        local_shuffle: Window size for shuffling the aspect-sorted order; values
            <= 1 disable shuffling.

    Returns:
        Symmetric mapping i -> j. When the batch size is odd, the last index of
        the (shuffled) order is left out of the mapping.
    """
    # Sort by width / height so paired images share as large an overlap as possible.
    order = sorted(range(len(imgs)), key=lambda i: imgs[i].shape[2] / imgs[i].shape[1])

    if local_shuffle > 1:
        for start in range(0, len(order), local_shuffle):
            # List slicing returns a copy, so ``random.shuffle(order[a:b])`` would be
            # a silent no-op. Shuffle the window explicitly and write it back.
            window = order[start:start + local_shuffle]
            random.shuffle(window)
            order[start:start + local_shuffle] = window

    pair_to: Dict[int, int] = {}
    # zip() stops at the shorter slice, so an odd trailing element stays unpaired.
    for a, b in zip(order[::2], order[1::2]):
        pair_to[a] = b
        pair_to[b] = a
    return pair_to


def mix_batch_variable_size(
        imgs: List[torch.Tensor],
        *,
        mixup_alpha: float = 0.8,
        cutmix_alpha: float = 1.0,
        switch_prob: float = 0.5,
        local_shuffle: int = 4,
) -> Tuple[List[torch.Tensor], List[float], Dict[int, int], bool]:
    """Apply Mixup or CutMix on a batch of variable-sized images.

    The function first sorts images by aspect ratio and pairs neighbouring
    samples (optionally shuffling within small windows so pairs vary between
    epochs). Only the mutual central-overlap region of each pair is mixed.

    One mode (Mixup or CutMix) and one raw lambda are drawn per call and shared
    by every pair. The returned per-sample lambdas are corrected for how much of
    each destination image was actually replaced.

    Args:
        imgs: List of transformed images shaped (C, H, W). Heights and widths
            may differ between samples. Input tensors are not modified.
        mixup_alpha: Beta-distribution alpha for Mixup. Set to 0 to disable Mixup.
        cutmix_alpha: Beta-distribution alpha for CutMix. Set to 0 to disable CutMix.
        switch_prob: Probability of using CutMix when both Mixup and CutMix are
            enabled.
        local_shuffle: Size of local windows that are randomly shuffled after
            aspect sorting. A value of 0 (or 1) turns shuffling off.

    Returns:
        mixed_imgs: List of mixed images, same order and shapes as ``imgs``.
        lam_list: Per-sample fraction of the image's own pixels (1.0 for an
            unmatched sample).
        pair_to: Mapping i -> j describing which sample was mixed with which
            (absent for the unmatched odd sample).
        use_cutmix: True if CutMix was used for this call, False if Mixup was used.

    Raises:
        ValueError: If fewer than two images are given, or both alphas are zero.
    """
    if len(imgs) < 2:
        raise ValueError("Need at least two images to perform Mixup/CutMix.")

    # Decide augmentation mode and draw the raw (uncorrected) lambda.
    if mixup_alpha > 0.0 and cutmix_alpha > 0.0:
        use_cutmix = torch.rand(()).item() < switch_prob
        alpha = cutmix_alpha if use_cutmix else mixup_alpha
    elif mixup_alpha > 0.0:
        use_cutmix = False
        alpha = mixup_alpha
    elif cutmix_alpha > 0.0:
        use_cutmix = True
        alpha = cutmix_alpha
    else:
        raise ValueError("Both mixup_alpha and cutmix_alpha are zero – nothing to do.")

    lam_raw = torch.distributions.Beta(alpha, alpha).sample().item()
    lam_raw = max(0.0, min(1.0, lam_raw))  # numerical safety

    pair_to = _pair_by_aspect_ratio(imgs, local_shuffle)

    mixed_imgs: List[torch.Tensor] = []
    lam_list: List[float] = []
    for i, xi in enumerate(imgs):
        j = pair_to.get(i)
        if j is None:
            # Unmatched sample (odd batch size): passed through untouched.
            mixed_imgs.append(xi)
            lam_list.append(1.0)
            continue

        xj = imgs[j]
        _, hi, wi = xi.shape
        _, hj, wj = xj.shape
        dest_area = hi * wi

        # Central overlap common to both images and its top-left offset in each.
        oh, ow = min(hi, hj), min(wi, wj)
        overlap_area = oh * ow
        top_i, left_i = (hi - oh) // 2, (wi - ow) // 2
        top_j, left_j = (hj - oh) // 2, (wj - ow) // 2

        # Clone: the original xi is still needed unmodified as j's mixing source.
        xi = xi.clone()
        if use_cutmix:
            # CutMix: paste a random rectangle of the partner, placed inside the overlap.
            cut_ratio = math.sqrt(1.0 - lam_raw)
            ch, cw = int(oh * cut_ratio), int(ow * cut_ratio)
            y_off = random.randint(0, oh - ch)
            x_off = random.randint(0, ow - cw)
            yl_i, xl_i = top_i + y_off, left_i + x_off
            yl_j, xl_j = top_j + y_off, left_j + x_off
            xi[:, yl_i:yl_i + ch, xl_i:xl_i + cw] = xj[:, yl_j:yl_j + ch, xl_j:xl_j + cw]
            # Fraction of the destination image that still holds its own pixels.
            lam = 1.0 - (ch * cw) / float(dest_area)
        else:
            # Mixup: blend the entire overlap region.
            patch_i = xi[:, top_i:top_i + oh, left_i:left_i + ow]
            patch_j = xj[:, top_j:top_j + oh, left_j:left_j + ow]
            blended = patch_i.mul(lam_raw).add_(patch_j, alpha=1.0 - lam_raw)
            xi[:, top_i:top_i + oh, left_i:left_i + ow] = blended
            # Pixels outside the overlap are fully self; inside they are lam_raw self.
            lam = (dest_area - overlap_area) / dest_area + lam_raw * overlap_area / dest_area

        mixed_imgs.append(xi)
        lam_list.append(lam)

    return mixed_imgs, lam_list, pair_to, use_cutmix


def pairwise_mixup_target(
        labels: torch.Tensor,
        pair_to: Dict[int, int],
        lam_list: List[float],
        *,
        num_classes: int,
        smoothing: float = 0.0,
) -> torch.Tensor:
    """Create soft targets that match the pixel-level mixing performed.

    Args:
        labels: (B,) tensor of integer class indices (any integer dtype).
        pair_to: Mapping of sample index to its mixed partner as returned by
            mix_batch_variable_size().
        lam_list: Per-sample fractions of self pixels, also from the mixer.
        num_classes: Total number of classes in the dataset.
        smoothing: Label-smoothing value in the range [0, 1).

    Returns:
        Float32 tensor of shape (B, num_classes) whose rows sum to 1.
    """
    # scatter_ requires int64 indices; .long() returns the same tensor for int64 input.
    labels = labels.long()

    off_val = smoothing / num_classes
    on_val = 1.0 - smoothing + off_val
    y_onehot = torch.full(
        (labels.size(0), num_classes),
        off_val,
        dtype=torch.float32,
        device=labels.device,
    )
    y_onehot.scatter_(1, labels.unsqueeze(1), on_val)

    # Blend from the unmodified one-hot rows so each target only mixes hard labels.
    targets = y_onehot.clone()
    for i, j in pair_to.items():
        lam = lam_list[i]
        targets[i].mul_(lam).add_(y_onehot[j], alpha=1.0 - lam)
    return targets


class NaFlexMixup:
    """Callable wrapper that combines mixing and target generation."""

    def __init__(
            self,
            *,
            num_classes: int,
            mixup_alpha: float = 0.8,
            cutmix_alpha: float = 1.0,
            switch_prob: float = 0.5,
            local_shuffle: int = 4,
            smoothing: float = 0.0,
    ) -> None:
        """Configure the augmentation.

        Args:
            num_classes: Total number of classes.
            mixup_alpha: Beta alpha for Mixup. 0 disables Mixup.
            cutmix_alpha: Beta alpha for CutMix. 0 disables CutMix.
            switch_prob: Probability of selecting CutMix when both modes are enabled.
            local_shuffle: Window size used to shuffle images after aspect
                sorting so pairings vary between epochs.
            smoothing: Label-smoothing value. 0 disables smoothing.
        """
        self.num_classes = num_classes
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.switch_prob = switch_prob
        self.local_shuffle = local_shuffle
        self.smoothing = smoothing

    def __call__(
            self,
            imgs: List[torch.Tensor],
            labels: Union[torch.Tensor, Sequence[int]],
    ) -> Tuple[List[torch.Tensor], Tuple[torch.Tensor, ...]]:
        """Apply the augmentation and generate matching targets.

        Args:
            imgs: List of already-transformed images shaped (C, H, W).
            labels: Hard labels with shape (B,), as a tensor, list or tuple.

        Returns:
            mixed_imgs: List of mixed images in the same order and shapes as the input.
            targets: Tuple of B per-sample soft-label tensors, each shaped
                (num_classes,), suitable for cross-entropy with soft targets.
        """
        if isinstance(labels, (list, tuple)):
            labels = torch.tensor(labels)

        mixed_imgs, lam_list, pair_to, _ = mix_batch_variable_size(
            imgs,
            mixup_alpha=self.mixup_alpha,
            cutmix_alpha=self.cutmix_alpha,
            switch_prob=self.switch_prob,
            local_shuffle=self.local_shuffle,
        )

        targets = pairwise_mixup_target(
            labels,
            pair_to,
            lam_list,
            num_classes=self.num_classes,
            smoothing=self.smoothing,
        )
        # Split into per-sample rows so callers can collate them alongside images.
        return mixed_imgs, targets.unbind(0)