From 7624389fc9e4e2b457081461a6ff24cc51284793 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 20 May 2025 14:38:03 -0700
Subject: [PATCH] Mixup cleanup, add prob support and train script integration.
 Add working loader based patch compatible RandomErasing for NaFlex mode.

---
 timm/data/__init__.py              |   1 +
 timm/data/naflex_dataset.py        |  23 +-
 timm/data/naflex_loader.py         |  59 +++-
 timm/data/naflex_mixup.py          |  82 +++---
 timm/data/naflex_random_erasing.py | 433 +++++++++++++++++++++++++++++
 timm/data/transforms_factory.py    |   4 +
 train.py                           |  63 +++--
 7 files changed, 591 insertions(+), 74 deletions(-)
 create mode 100644 timm/data/naflex_random_erasing.py

diff --git a/timm/data/__init__.py b/timm/data/__init__.py
index 3eba2193..51357619 100644
--- a/timm/data/__init__.py
+++ b/timm/data/__init__.py
@@ -10,6 +10,7 @@ from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
 from .naflex_dataset import VariableSeqMapWrapper
 from .naflex_loader import create_naflex_loader
+from .naflex_mixup import NaFlexMixup, pairwise_mixup_target, mix_batch_variable_size
 from .naflex_transforms import (
     ResizeToSequence,
     CenterCropToSequence,
diff --git a/timm/data/naflex_dataset.py b/timm/data/naflex_dataset.py
index 858a182f..8269bc78 100644
--- a/timm/data/naflex_dataset.py
+++ b/timm/data/naflex_dataset.py
@@ -83,8 +83,11 @@ class NaFlexCollator:
         batch_size = len(batch)
 
         # Extract targets
-        # FIXME need to handle dense (float) targets or always done downstream of this?
-        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
+        targets = [item[1] for item in batch]
+        if isinstance(targets[0], torch.Tensor):
+            targets = torch.stack(targets)
+        else:
+            targets = torch.tensor(targets, dtype=torch.int64)
 
         # Get patch dictionaries
         patch_dicts = [item[0] for item in batch]
@@ -139,6 +142,7 @@ class VariableSeqMapWrapper(IterableDataset):
             seq_lens: List[int] = (128, 256, 576, 784, 1024),
             max_tokens_per_batch: int = 4096 * 4,  # Example: 16k tokens
             transform_factory: Optional[Callable] = None,
+            mixup_fn: Optional[Callable] = None,
             seed: int = 42,
             shuffle: bool = True,
             distributed: bool = False,
@@ -172,6 +176,7 @@ class VariableSeqMapWrapper(IterableDataset):
             else:
                 self.transforms[seq_len] = None  # No transform
             self.collate_fns[seq_len] = NaFlexCollator(seq_len)
+        self.mixup_fn = mixup_fn
         self.patchifier = Patchify(self.patch_size)
 
         # --- Canonical Schedule Calculation (Done Once) ---
@@ -393,6 +398,8 @@ class VariableSeqMapWrapper(IterableDataset):
             transform = self.transforms.get(seq_len)
 
             batch_samples = []
+            batch_imgs = []
+            batch_targets = []
             for idx in indices:
                 try:
                     # Get original image and label from map-style dataset
@@ -405,9 +412,8 @@ class VariableSeqMapWrapper(IterableDataset):
                         warnings.warn(f"Transform returned None for index {idx}. Skipping sample.")
                         continue
 
-                    # Apply patching
-                    patch_data = self.patchifier(processed_img)
-                    batch_samples.append((patch_data, label))
+                    batch_imgs.append(processed_img)
+                    batch_targets.append(label)
 
                 except IndexError:
                     warnings.warn(f"IndexError encountered for index {idx} (possibly due to padding/repeated indices). Skipping sample.")
@@ -417,8 +423,13 @@ class VariableSeqMapWrapper(IterableDataset):
                     warnings.warn(f"Error processing sample index {idx}. Error: {e}. Skipping sample.")
                     continue  # Skip problematic sample
 
-            # Collate the processed samples into a batch
+            if self.mixup_fn is not None:
+                batch_imgs, batch_targets = self.mixup_fn(batch_imgs, batch_targets)
+
+            batch_imgs = [self.patchifier(img) for img in batch_imgs]
+            batch_samples = list(zip(batch_imgs, batch_targets))
             if batch_samples:  # Only yield if we successfully processed samples
+                # Collate the processed samples into a batch
                 yield self.collate_fns[seq_len](batch_samples)
 
             # If batch_samples is empty after processing 'indices', an empty batch is skipped.
diff --git a/timm/data/naflex_loader.py b/timm/data/naflex_loader.py
index bb96d07d..ec620828 100644
--- a/timm/data/naflex_loader.py
+++ b/timm/data/naflex_loader.py
@@ -3,11 +3,13 @@ from contextlib import suppress
 from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
+
 import torch
 
 from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .loader import _worker_init
+from .loader import _worker_init, adapt_to_chs
 from .naflex_dataset import VariableSeqMapWrapper, NaFlexCollator
+from .naflex_random_erasing import PatchRandomErasing
 from .transforms_factory import create_transform
 
 
@@ -16,19 +18,41 @@ class NaFlexPrefetchLoader:
 
     def __init__(
             self,
-            loader,
-            mean=(0.485, 0.456, 0.406),
-            std=(0.229, 0.224, 0.225),
-            img_dtype=torch.float32,
-            device=torch.device('cuda')
+            loader: torch.utils.data.DataLoader,
+            mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
+            std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
+            channels: int = 3,
+            device: torch.device = torch.device('cuda'),
+            img_dtype: Optional[torch.dtype] = None,
+            re_prob: float = 0.,
+            re_mode: str = 'const',
+            re_count: int = 1,
+            re_num_splits: int = 0,
     ):
         self.loader = loader
         self.device = device
         self.img_dtype = img_dtype or torch.float32
 
         # Create mean/std tensors for normalization (will be applied to patches)
-        self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=self.img_dtype).view(1, 1, 3)
-        self.std = torch.tensor([x * 255 for x in std], device=device, dtype=self.img_dtype).view(1, 1, 3)
+        mean = adapt_to_chs(mean, channels)
+        std = adapt_to_chs(std, channels)
+        normalization_shape = (1, 1, channels)
+        self.channels = channels
+        self.mean = torch.tensor(
+            [x * 255 for x in mean], device=device, dtype=self.img_dtype).view(normalization_shape)
+        self.std = torch.tensor(
+            [x * 255 for x in std], device=device, dtype=self.img_dtype).view(normalization_shape)
+
+        if re_prob > 0.:
+            self.random_erasing = PatchRandomErasing(
+                erase_prob=re_prob,
+                mode=re_mode,
+                max_count=re_count,
+                num_splits=re_num_splits,
+                device=device,
+            )
+        else:
+            self.random_erasing = None
 
         # Check for CUDA/NPU availability
         self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
@@ -62,9 +86,18 @@ class NaFlexPrefetchLoader:
 
                     # Normalize patch values (assuming patches are in format [B, N, P*P*C])
                     batch_size, num_patches, patch_pixels = next_input_dict['patches'].shape
-                    patches = next_input_dict['patches'].view(batch_size, -1, 3)  # to [B*N, P*P, C] for normalization
+
+                    # To [B, N, P*P, C] for normalization and erasing
+                    patches = next_input_dict['patches'].view(batch_size, num_patches, -1, self.channels)
                     patches = patches.sub(self.mean).div(self.std)
 
+                    if self.random_erasing is not None:
+                        patches = self.random_erasing(
+                            patches,
+                            patch_coord=next_input_dict['patch_coord'],
+                            patch_valid=next_input_dict.get('patch_valid', None),
+                        )
+
                     # Reshape back
                     next_input_dict['patches'] = patches.reshape(batch_size, num_patches, patch_pixels)
@@ -103,6 +136,7 @@ def create_naflex_loader(
         max_seq_len: int = 576,  # Fixed sequence length for validation
         batch_size: int = 32,  # Used for max_seq_len and max(train_seq_lens)
         is_training: bool = False,
+        mixup_fn: Optional[Callable] = None,
 
         no_aug: bool = False,
         re_prob: float = 0.,
@@ -141,7 +175,8 @@ def create_naflex_loader(
         persistent_workers: bool = True,
         worker_seeding: str = 'all',
 ):
-    """Create a data loader with dynamic sequence length sampling for training."""
+    """Create a data loader with dynamic sequence length sampling for training.
+    """
 
     if is_training:
         # For training, use the dynamic sequence length mechanism
@@ -186,6 +221,7 @@ def create_naflex_loader(
             patch_size=patch_size,
             seq_lens=train_seq_lens,
             max_tokens_per_batch=max_tokens_per_batch,
+            mixup_fn=mixup_fn,
             seed=seed,
             distributed=distributed,
             rank=rank,
@@ -219,6 +255,9 @@ def create_naflex_loader(
                 std=std,
                 img_dtype=img_dtype,
                 device=device,
+                re_prob=re_prob,
+                re_mode=re_mode,
+                re_count=re_count,
             )
 
     else:
diff --git a/timm/data/naflex_mixup.py b/timm/data/naflex_mixup.py
index 28bf7a60..427cd2e3 100644
--- a/timm/data/naflex_mixup.py
+++ b/timm/data/naflex_mixup.py
@@ -26,7 +26,7 @@ def mix_batch_variable_size(
         cutmix_alpha: float = 1.0,
         switch_prob: float = 0.5,
         local_shuffle: int = 4,
-) -> Tuple[List[torch.Tensor], List[float], Dict[int, int], bool]:
+) -> Tuple[List[torch.Tensor], List[float], Dict[int, int]]:
     """Apply Mixup or CutMix on a batch of variable‑sized images.
 
     The function first sorts images by aspect ratio and pairs neighbouring
     epochs). Only the mutual central‑overlap region of each pair is mixed.
 
     Args:
-        imgs: List of transformed images shaped (C, H, W). Heights and
-            widths may differ between samples.
-        mixup_alpha: Beta‑distribution *α* for Mixup. Set to 0 to disable Mixup.
-        cutmix_alpha: Beta‑distribution *α* for CutMix. Set to 0 to disable CutMix.
+        imgs: List of transformed images shaped (C, H, W). Heights and widths may differ between samples.
+        mixup_alpha: Beta‑distribution alpha for Mixup. Set to 0 to disable Mixup.
+        cutmix_alpha: Beta‑distribution alpha for CutMix. Set to 0 to disable CutMix.
         switch_prob: Probability of using CutMix when both Mixup and CutMix are enabled.
-        local_shuffle: Size of local windows that are randomly shuffled after aspect sorting.
-            A value of 0 turns shuffling off.
+        local_shuffle: Size of local windows that are randomly shuffled after aspect sorting. Off if <= 1.
 
     Returns:
         mixed_imgs: List of mixed images.
         lam_list: Per‑sample lambda values representing the degree of mixing.
         pair_to: Mapping i -> j describing which sample was mixed with which (absent for unmatched odd sample).
-        use_cutmix: True if CutMix was used for this call, False if Mixup was used.
""" if len(imgs) < 2: raise ValueError("Need at least two images to perform Mixup/CutMix.") @@ -71,7 +68,7 @@ def mix_batch_variable_size( order = sorted(range(len(imgs)), key=lambda i: imgs[i].shape[2] / imgs[i].shape[1]) if local_shuffle > 1: for start in range(0, len(order), local_shuffle): - random.shuffle(order[start: start + local_shuffle]) + random.shuffle(order[start:start + local_shuffle]) pair_to: Dict[int, int] = {} for a, b in zip(order[::2], order[1::2]): @@ -119,22 +116,41 @@ def mix_batch_variable_size( #print(i, 'Doing cutmix', yl_i, xl_i, yl_j, xl_j, ch, cw, lam_raw, corrected_lam) else: # Mixup: blend the entire overlap region - patch_i = xi[:, top_i: top_i + oh, left_i: left_i + ow] - patch_j = xj[:, top_j: top_j + oh, left_j: left_j + ow] + patch_i = xi[:, top_i:top_i + oh, left_i:left_i + ow] + patch_j = xj[:, top_j:top_j + oh, left_j:left_j + ow] blended = patch_i.mul(lam_raw).add_(patch_j, alpha=1.0 - lam_raw) - xi[:, top_i: top_i + oh, left_i: left_i + ow] = blended + xi[:, top_i:top_i + oh, left_i:left_i + ow] = blended mixed_imgs[i] = xi corrected_lam = (dest_area - overlap_area) / dest_area + lam_raw * overlap_area / dest_area lam_list[i] = corrected_lam #print(i, 'Doing mixup', top_i, left_i, top_j, left_j, (oh, ow), (hi, wi), (hj, wj), lam_raw, corrected_lam) - return mixed_imgs, lam_list, pair_to, use_cutmix + return mixed_imgs, lam_list, pair_to + + +def smoothed_sparse_target( + targets: torch.Tensor, + *, + num_classes: int, + smoothing: float = 0.0, +) -> torch.Tensor: + off_val = smoothing / num_classes + on_val = 1.0 - smoothing + off_val + + y_onehot = torch.full( + (targets.size(0), num_classes), + off_val, + dtype=torch.float32, + device=targets.device + ) + y_onehot.scatter_(1, targets.unsqueeze(1), on_val) + return y_onehot def pairwise_mixup_target( - labels: torch.Tensor, + targets: torch.Tensor, pair_to: Dict[int, int], lam_list: List[float], *, @@ -144,21 +160,16 @@ def pairwise_mixup_target( """Create soft targets that match the pixel‑level mixing performed. Args: - labels: (B,) tensor of integer class indices. + targets: (B,) tensor of integer class indices. pair_to: Mapping of sample index to its mixed partner as returned by mix_batch_variable_size(). - lam_list: Per‑sample fractions of self pixels, also from the mixer. + lam_list: Per‑sample fractions of own pixels, also from the mixer. num_classes: Total number of classes in the dataset. smoothing: Label‑smoothing value in the range [0, 1). Returns: Tensor of shape (B, num_classes) whose rows sum to 1. """ - off_val = smoothing / num_classes - on_val = 1.0 - smoothing + off_val - - y_onehot = torch.full((labels.size(0), num_classes), off_val, dtype=torch.float32, device=labels.device) - y_onehot.scatter_(1, labels.unsqueeze(1), on_val) - + y_onehot = smoothed_sparse_target(targets, num_classes=num_classes, smoothing=smoothing) targets = y_onehot.clone() for i, j in pair_to.items(): lam = lam_list[i] @@ -177,8 +188,9 @@ class NaFlexMixup: mixup_alpha: float = 0.8, cutmix_alpha: float = 1.0, switch_prob: float = 0.5, + prob: float = 1.0, local_shuffle: int = 4, - smoothing: float = 0.0, + label_smoothing: float = 0.0, ) -> None: """Configure the augmentation. @@ -187,6 +199,7 @@ class NaFlexMixup: mixup_alpha: Beta α for Mixup. 0 disables Mixup. cutmix_alpha: Beta α for CutMix. 0 disables CutMix. switch_prob: Probability of selecting CutMix when both modes are enabled. + prob: Probability of applying any mixing per batch. 
diff --git a/timm/data/naflex_random_erasing.py b/timm/data/naflex_random_erasing.py
new file mode 100644
index 00000000..ac2aaf60
--- /dev/null
+++ b/timm/data/naflex_random_erasing.py
@@ -0,0 +1,433 @@
+import random
+import math
+from typing import Optional, Union, Tuple
+
+import torch
+
+
+class PatchRandomErasing:
+    """
+    Random erasing for patchified images in NaFlex format.
+
+    Supports three modes:
+    1. 'patch': Simple mode that erases randomly selected valid patches
+    2. 'region': Erases spatial regions at patch granularity
+    3. 'subregion': Most sophisticated mode that erases spatial regions at sub-patch granularity,
+       partially erasing patches that are on the boundary of the erased region
+
+    Args:
+        erase_prob: Probability that the Random Erasing operation will be performed.
+        patch_drop_prob: Patch dropout probability. Remove random patches instead of erasing.
+        min_count: Minimum number of erasing operations per image.
+        max_count: Maximum number of erasing operations per image. Defaults to min_count.
+        min_area: Minimum percentage of valid patches/area to erase.
+        max_area: Maximum percentage of valid patches/area to erase.
+        min_aspect: Minimum aspect ratio of erased area (only used in 'region'/'subregion' mode).
+        max_aspect: Maximum aspect ratio of erased area (only used in 'region'/'subregion' mode).
+        mode: Patch content mode, one of 'const', 'rand', or 'pixel'
+            'const' - erase patch is a constant value (default 0) for all channels
+            'rand' - erase patch has same random (normal) value across all elements
+            'pixel' - erase patch has per-element random (normal) values
+        spatial_mode: Erasing strategy, one of 'patch', 'region', or 'subregion'
+        patch_size: Size of each patch (required for 'subregion' mode)
+        num_splits: Number of splits to apply erasing to (0 for all)
+        device: Computation device
+    """
+
+    def __init__(
+            self,
+            erase_prob: float = 0.5,
+            patch_drop_prob: float = 0.0,
+            min_count: int = 1,
+            max_count: Optional[int] = None,
+            min_area: float = 0.02,
+            max_area: float = 1 / 3,
+            min_aspect: float = 0.3,
+            max_aspect: Optional[float] = None,
+            mode: str = 'const',
+            value: float = 0.,
+            spatial_mode: str = 'region',
+            patch_size: Optional[Union[int, Tuple[int, int]]] = 16,
+            num_splits: int = 0,
+            device: Union[str, torch.device] = 'cuda',
+    ):
+        self.erase_prob = erase_prob
+        self.patch_drop_prob = patch_drop_prob
+        self.min_count = min_count
+        self.max_count = max_count or min_count
+        self.min_area = min_area
+        self.max_area = max_area
+
+        # Aspect ratio params (for region mode)
+        max_aspect = max_aspect or 1 / min_aspect
+        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
+
+        # Number of splits
+        self.num_splits = num_splits
+        self.device = device
+
+        # Strategy mode
+        self.spatial_mode = spatial_mode
+
+        # Patch size (needed for subregion mode)
+        self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
+
+        # Value generation mode flags
+        self.erase_mode = mode.lower()
+        assert self.erase_mode in ('rand', 'pixel', 'const')
+        self.const_value = value
+
+    def _get_values(
+            self,
+            shape: Union[Tuple[int, ...], torch.Size],
+            value: Optional[torch.Tensor] = None,
+            dtype: torch.dtype = torch.float32,
+            device: Optional[Union[str, torch.device]] = None,
+    ):
+        """Generate values for erased patches based on the specified mode.
+
+        Args:
+            shape: Shape of patches to erase.
+            value: Value to use in const (or rand) mode.
+            dtype: Data type to use.
+            device: Device to use.
+        """
+        device = device or self.device
+        if self.erase_mode == 'pixel':
+            # only mode with erase shape that includes pixels
+            return torch.empty(shape, dtype=dtype, device=device).normal_()
+        else:
+            shape = (1, 1, shape[-1]) if len(shape) == 3 else (1, shape[-1])
+            if self.erase_mode == 'const' or value is not None:
+                # NOTE use 'is not None' rather than truthiness: a multi-element tensor
+                # passed in 'rand' mode has no unambiguous boolean value
+                erase_value = value if value is not None else self.const_value
+                if isinstance(erase_value, (int, float)):
+                    values = torch.full(shape, erase_value, dtype=dtype, device=device)
+                else:
+                    erase_value = torch.tensor(erase_value, dtype=dtype, device=device)
+                    values = torch.expand_copy(erase_value, shape)
+            else:
+                values = torch.empty(shape, dtype=dtype, device=device).normal_()
+            return values
+    def _drop_patches(
+            self,
+            patches: torch.Tensor,
+            patch_coord: torch.Tensor,
+            patch_valid: torch.Tensor,
+    ):
+        """Patch Dropout.
+
+        Fully drops patches from the datastream. The only mode that saves compute, BUT it requires
+        support for non-contiguous patches and the associated patch coordinate and valid handling.
+        """
+        # FIXME WIP, not completed. Downstream support in model needed for non-contiguous valid patches
+        if random.random() > self.erase_prob:
+            return
+
+        # Get indices of valid patches
+        valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0].tolist()
+
+        # Skip if no valid patches
+        if not valid_indices:
+            return patches, patch_coord, patch_valid
+
+        num_valid = len(valid_indices)
+        if self.patch_drop_prob:
+            # patch dropout mode, completely remove dropped patches (FIXME needs downstream support in model)
+            num_keep = max(1, int(num_valid * (1. - self.patch_drop_prob)))
+            keep_indices = torch.argsort(torch.randn(1, num_valid, device=self.device), dim=-1)[:, :num_keep]
+            # maintain patch order, possibly useful for debug / visualization
+            keep_indices = keep_indices.sort(dim=-1)[0]
+            patches = patches.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + patches.shape[2:]))
+
+        return patches, patch_coord, patch_valid
+
+    def _erase_patches(
+            self,
+            patches: torch.Tensor,
+            patch_coord: torch.Tensor,
+            patch_valid: torch.Tensor,
+            patch_shape: torch.Size,
+            dtype: torch.dtype = torch.float32,
+    ):
+        """Apply erasing by selecting individual patches randomly.
+
+        The simplest mode, aligned on patch boundaries. Behaves similarly to speckle or 'sprinkles'
+        noise augmentation at patch size.
+        """
+        if random.random() > self.erase_prob:
+            return
+
+        # Get indices of valid patches
+        valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0].tolist()
+        if not valid_indices:
+            # Skip if no valid patches
+            return
+
+        num_valid = len(valid_indices)
+        count = random.randint(self.min_count, self.max_count)
+        # Determine how many valid patches to erase from RE min/max count and area args
+        max_erase = max(1, int(num_valid * count * self.max_area))
+        min_erase = max(1, int(num_valid * count * self.min_area))
+        num_erase = random.randint(min_erase, max_erase)
+
+        # Randomly select valid patches to erase
+        indices_to_erase = random.sample(valid_indices, min(num_erase, num_valid))
+
+        random_value = None
+        if self.erase_mode == 'rand':
+            random_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
+
+        for idx in indices_to_erase:
+            patches[idx].copy_(self._get_values(patch_shape, dtype=dtype, value=random_value))
+ """ + if random.random() > self.erase_prob: + return + + # Determine grid dimensions from coordinates + if patch_valid is not None: + valid_coord = patch_coord[patch_valid] + if len(valid_coord) == 0: + return # No valid patches + max_y = valid_coord[:, 0].max().item() + 1 + max_x = valid_coord[:, 1].max().item() + 1 + else: + max_y = patch_coord[:, 0].max().item() + 1 + max_x = patch_coord[:, 1].max().item() + 1 + + grid_h, grid_w = max_y, max_x + + # Calculate total area + total_area = grid_h * grid_w + + count = random.randint(self.min_count, self.max_count) + for _ in range(count): + # Try to select a valid region to erase (multiple attempts) + for attempt in range(10): + # Sample random area and aspect ratio + target_area = random.uniform(self.min_area, self.max_area) * total_area + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + + # Calculate region height and width + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + + # Ensure region fits within grid + if w <= grid_w and h <= grid_h: + # Select random top-left corner + top = random.randint(0, grid_h - h) + left = random.randint(0, grid_w - w) + + # Define region bounds + bottom = top + h + right = left + w + + # Create a single random value for all affected patches if using 'rand' mode + if self.erase_mode == 'rand': + random_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_() + else: + random_value = None + + # Find and erase all patches that fall within the region + for i in range(len(patches)): + if patch_valid is None or patch_valid[i]: + y, x = patch_coord[i] + if top <= y < bottom and left <= x < right: + patches[i] = self._get_values(patch_shape, dtype=dtype, value=random_value) + + # Successfully applied erasing, exit the loop + break + + def _erase_subregion( + self, + patches: torch.Tensor, + patch_coord: torch.Tensor, + patch_valid: torch.Tensor, + patch_shape: torch.Size, + patch_size: Tuple[int, int], + dtype: torch.dtype = torch.float32, + ): + """Apply erasing by selecting rectangular regions ignoring patch boundaries. + + Matches or original RandomErasing implementation. Erases spatially contiguous rectangular + regions that are not aligned to patches (erase regions boundaries cut within patches). + + FIXME complexity probably not worth it, may remove. 
+ """ + if random.random() > self.erase_prob: + return + + # Get patch dimensions + patch_h, patch_w = patch_size + channels = patch_shape[-1] + + # Determine grid dimensions in patch coordinates + if patch_valid is not None: + valid_coord = patch_coord[patch_valid] + if len(valid_coord) == 0: + return # No valid patches + max_y = valid_coord[:, 0].max().item() + 1 + max_x = valid_coord[:, 1].max().item() + 1 + else: + max_y = patch_coord[:, 0].max().item() + 1 + max_x = patch_coord[:, 1].max().item() + 1 + + grid_h, grid_w = max_y, max_x + + # Calculate total area in pixel space + total_area = (grid_h * patch_h) * (grid_w * patch_w) + + count = random.randint(self.min_count, self.max_count) + for _ in range(count): + # Try to select a valid region to erase (multiple attempts) + for attempt in range(10): + # Sample random area and aspect ratio + target_area = random.uniform(self.min_area, self.max_area) * total_area + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + + # Calculate region height and width in pixel space + pixel_h = int(round(math.sqrt(target_area * aspect_ratio))) + pixel_w = int(round(math.sqrt(target_area / aspect_ratio))) + + # Ensure region fits within total pixel grid + if pixel_w <= grid_w * patch_w and pixel_h <= grid_h * patch_h: + # Select random top-left corner in pixel space + pixel_top = random.randint(0, grid_h * patch_h - pixel_h) + pixel_left = random.randint(0, grid_w * patch_w - pixel_w) + + # Define region bounds in pixel space + pixel_bottom = pixel_top + pixel_h + pixel_right = pixel_left + pixel_w + + # Create a single random value for the entire region if using 'rand' mode + rand_value = None + if self.erase_mode == 'rand': + rand_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_() + + # For each valid patch, determine if and how it overlaps with the erase region + for i in range(len(patches)): + if patch_valid is None or patch_valid[i]: + # Convert patch coordinates to pixel space (top-left corner) + y, x = patch_coord[i] + patch_pixel_top = y * patch_h + patch_pixel_left = x * patch_w + patch_pixel_bottom = patch_pixel_top + patch_h + patch_pixel_right = patch_pixel_left + patch_w + + # Check if this patch overlaps with the erase region + if not (patch_pixel_right <= pixel_left or patch_pixel_left >= pixel_right or + patch_pixel_bottom <= pixel_top or patch_pixel_top >= pixel_bottom): + + # Calculate the overlap region in patch-local coordinates + local_top = max(0, pixel_top - patch_pixel_top) + local_left = max(0, pixel_left - patch_pixel_left) + local_bottom = min(patch_h, pixel_bottom - patch_pixel_top) + local_right = min(patch_w, pixel_right - patch_pixel_left) + + # Reshape the patch to [patch_h, patch_w, chans] + patch_data = patches[i].reshape(patch_h, patch_w, channels) + + erase_shape = (local_bottom - local_top, local_right - local_left, channels) + erase_value = self._get_values(erase_shape, dtype=dtype, value=rand_value) + patch_data[local_top:local_bottom, local_left:local_right, :] = erase_value + + # Flatten the patch back to [patch_h*patch_w, chans] + if len(patch_shape) == 2: + patch_data = patch_data.reshape(-1, channels) + patches[i] = patch_data + + # Successfully applied erasing, exit the loop + break + + def __call__( + self, + patches: torch.Tensor, + patch_coord: torch.Tensor, + patch_valid: Optional[torch.Tensor] = None, + ): + """ + Apply random patch erasing. 
diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py
index ed427456..904017ee 100644
--- a/timm/data/transforms_factory.py
+++ b/timm/data/transforms_factory.py
@@ -132,9 +132,13 @@ def transforms_imagenet_train(
 
     primary_tfl = []
     if naflex:
+        scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
+        ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range
         primary_tfl += [RandomResizedCropToSequence(
             patch_size=patch_size,
             max_seq_len=max_seq_len,
+            scale=scale,
+            ratio=ratio,
             interpolation=interpolation
         )]
     else:
diff --git a/train.py b/train.py
index efbe2892..224cc397 100755
--- a/train.py
+++ b/train.py
@@ -697,32 +697,6 @@ def main():
             trust_remote_code=args.dataset_trust_remote_code,
         )
 
-    # setup mixup / cutmix
-    collate_fn = None
-    mixup_fn = None
-    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
-    if mixup_active:
-        assert not args.naflex_loader, "Mixup/Cutmix not currently supported for NaFlex loading."
-        mixup_args = dict(
-            mixup_alpha=args.mixup,
-            cutmix_alpha=args.cutmix,
-            cutmix_minmax=args.cutmix_minmax,
-            prob=args.mixup_prob,
-            switch_prob=args.mixup_switch_prob,
-            mode=args.mixup_mode,
-            label_smoothing=args.smoothing,
-            num_classes=args.num_classes
-        )
-        if args.prefetcher:
-            assert not num_aug_splits  # collate conflict (need to support de-interleaving in collate mixup)
-            collate_fn = FastCollateMixup(**mixup_args)
-        else:
-            mixup_fn = Mixup(**mixup_args)
-
-    # wrap dataset in AugMix helper
-    if num_aug_splits > 1:
-        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
-
     # create data loaders w/ augmentation pipeline
     train_interpolation = args.train_interpolation
     if args.no_aug or not train_interpolation:
@@ -764,22 +738,59 @@ def main():
         worker_seeding=args.worker_seeding,
     )
 
+    mixup_fn = None
+    mixup_args = {}
+    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
+    if mixup_active:
+        mixup_args = dict(
+            mixup_alpha=args.mixup,
+            cutmix_alpha=args.cutmix,
+            cutmix_minmax=args.cutmix_minmax,
+            prob=args.mixup_prob,
+            switch_prob=args.mixup_switch_prob,
+            mode=args.mixup_mode,
+            label_smoothing=args.smoothing,
+            num_classes=args.num_classes
+        )
+
+    naflex_mode = False
     if args.naflex_loader:
         if utils.is_primary(args):
             _logger.info('Using NaFlex loader')
+        assert num_aug_splits <= 1, 'Augmentation splits not supported in NaFlex mode'
+
+        naflex_mixup_fn = None
+        if mixup_active:
+            from timm.data import NaFlexMixup
+            mixup_args.pop('mode')  # not supported
+            mixup_args.pop('cutmix_minmax')  # not supported
+            naflex_mixup_fn = NaFlexMixup(**mixup_args)
+
+        naflex_mode = True
         loader_train = create_naflex_loader(
             dataset=dataset_train,
             patch_size=16,  # Could be derived from model config
             train_seq_lens=args.naflex_train_seq_lens,
+            mixup_fn=naflex_mixup_fn,
             rank=args.rank,
             world_size=args.world_size,
             **common_loader_kwargs,
             **train_loader_kwargs,
         )
     else:
+        # setup mixup / cutmix
+        collate_fn = None
+        if mixup_active:
+            if args.prefetcher:
+                assert not num_aug_splits  # collate conflict (need to support de-interleaving in collate mixup)
+                collate_fn = FastCollateMixup(**mixup_args)
+            else:
+                mixup_fn = Mixup(**mixup_args)
+
+        # wrap dataset in AugMix helper
+        if num_aug_splits > 1:
+            dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
+
         # Use standard loader
         loader_train = create_loader(
             dataset_train,