Mixup cleanup, add prob support and train script integration. Add working loader-based, patch-compatible RandomErasing for NaFlex mode.

Ross Wightman 2025-05-20 14:38:03 -07:00
parent 8fcbceb609
commit 7624389fc9
7 changed files with 591 additions and 74 deletions

View File

@@ -10,6 +10,7 @@ from .loader import create_loader
from .mixup import Mixup, FastCollateMixup
from .naflex_dataset import VariableSeqMapWrapper
from .naflex_loader import create_naflex_loader
from .naflex_mixup import NaFlexMixup, pairwise_mixup_target, mix_batch_variable_size
from .naflex_transforms import (
ResizeToSequence,
CenterCropToSequence,

View File

@@ -83,8 +83,11 @@ class NaFlexCollator:
batch_size = len(batch)
# Extract targets
# FIXME need to handle dense (float) targets or always done downstream of this?
targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
targets = [item[1] for item in batch]
if isinstance(targets[0], torch.Tensor):
targets = torch.stack(targets)
else:
targets = torch.tensor(targets, dtype=torch.int64)
# Get patch dictionaries
patch_dicts = [item[0] for item in batch]
@@ -139,6 +142,7 @@ class VariableSeqMapWrapper(IterableDataset):
seq_lens: List[int] = (128, 256, 576, 784, 1024),
max_tokens_per_batch: int = 4096 * 4, # Example: 16k tokens
transform_factory: Optional[Callable] = None,
mixup_fn: Optional[Callable] = None,
seed: int = 42,
shuffle: bool = True,
distributed: bool = False,
@@ -172,6 +176,7 @@ class VariableSeqMapWrapper(IterableDataset):
else:
self.transforms[seq_len] = None # No transform
self.collate_fns[seq_len] = NaFlexCollator(seq_len)
self.mixup_fn = mixup_fn
self.patchifier = Patchify(self.patch_size)
# --- Canonical Schedule Calculation (Done Once) ---
@@ -393,6 +398,8 @@
transform = self.transforms.get(seq_len)
batch_samples = []
batch_imgs = []
batch_targets = []
for idx in indices:
try:
# Get original image and label from map-style dataset
@@ -405,9 +412,8 @@
warnings.warn(f"Transform returned None for index {idx}. Skipping sample.")
continue
# Apply patching
patch_data = self.patchifier(processed_img)
batch_samples.append((patch_data, label))
batch_imgs.append(processed_img)
batch_targets.append(label)
except IndexError:
warnings.warn(f"IndexError encountered for index {idx} (possibly due to padding/repeated indices). Skipping sample.")
@@ -417,8 +423,13 @@
warnings.warn(f"Error processing sample index {idx}. Error: {e}. Skipping sample.")
continue # Skip problematic sample
# Collate the processed samples into a batch
if self.mixup_fn is not None:
batch_imgs, batch_targets = self.mixup_fn(batch_imgs, batch_targets)
batch_imgs = [self.patchifier(img) for img in batch_imgs]
batch_samples = list(zip(batch_imgs, batch_targets))
if batch_samples: # Only yield if we successfully processed samples
# Collate the processed samples into a batch
yield self.collate_fns[seq_len](batch_samples)
# If batch_samples is empty after processing 'indices', an empty batch is skipped.
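For orientation, a minimal sketch of the per-batch flow this hunk implements (not the exact library code; `dataset`, `transform`, `mixup_fn`, `patchifier`, and `collate_fn` stand in for the wrapper's corresponding attributes):

def make_batch(dataset, indices, transform, mixup_fn, patchifier, collate_fn):
    # Sketch: per-sample transforms first, image-level mixing second,
    # patchify/collate last.
    batch_imgs, batch_targets = [], []
    for idx in indices:
        img, label = dataset[idx]            # map-style dataset access
        batch_imgs.append(transform(img))    # seq-len-specific resize/augment
        batch_targets.append(label)
    if mixup_fn is not None:
        # NaFlexMixup mixes whole images so the overlap region can be computed
        # in pixel space, and returns matching per-sample soft targets
        batch_imgs, batch_targets = mixup_fn(batch_imgs, batch_targets)
    patch_dicts = [patchifier(img) for img in batch_imgs]
    return collate_fn(list(zip(patch_dicts, batch_targets)))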

View File

@@ -3,11 +3,13 @@ from contextlib import suppress
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .loader import _worker_init
from .loader import _worker_init, adapt_to_chs
from .naflex_dataset import VariableSeqMapWrapper, NaFlexCollator
from .naflex_random_erasing import PatchRandomErasing
from .transforms_factory import create_transform
@@ -16,19 +18,41 @@ class NaFlexPrefetchLoader:
def __init__(
self,
loader,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
img_dtype=torch.float32,
device=torch.device('cuda')
loader: torch.utils.data.DataLoader,
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
channels: int = 3,
device: torch.device = torch.device('cuda'),
img_dtype: Optional[torch.dtype] = None,
re_prob: float = 0.,
re_mode: str = 'const',
re_count: int = 1,
re_num_splits: int = 0,
):
self.loader = loader
self.device = device
self.img_dtype = img_dtype or torch.float32
# Create mean/std tensors for normalization (will be applied to patches)
self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=self.img_dtype).view(1, 1, 3)
self.std = torch.tensor([x * 255 for x in std], device=device, dtype=self.img_dtype).view(1, 1, 3)
mean = adapt_to_chs(mean, channels)
std = adapt_to_chs(std, channels)
normalization_shape = (1, 1, channels)
self.channels = channels
self.mean = torch.tensor(
[x * 255 for x in mean], device=device, dtype=self.img_dtype).view(normalization_shape)
self.std = torch.tensor(
[x * 255 for x in std], device=device, dtype=self.img_dtype).view(normalization_shape)
if re_prob > 0.:
self.random_erasing = PatchRandomErasing(
erase_prob=re_prob,
mode=re_mode,
max_count=re_count,
num_splits=re_num_splits,
device=device,
)
else:
self.random_erasing = None
# Check for CUDA/NPU availability
self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
@@ -62,9 +86,18 @@ class NaFlexPrefetchLoader:
# Normalize patch values (assuming patches are in format [B, N, P*P*C])
batch_size, num_patches, patch_pixels = next_input_dict['patches'].shape
patches = next_input_dict['patches'].view(batch_size, -1, 3) # to [B*N, P*P, C] for normalization
# To [B, N, P*P, C] for normalization and erasing
patches = next_input_dict['patches'].view(batch_size, num_patches, -1, self.channels)
patches = patches.sub(self.mean).div(self.std)
if self.random_erasing is not None:
patches = self.random_erasing(
patches,
patch_coord=next_input_dict['patch_coord'],
patch_valid=next_input_dict.get('patch_valid', None),
)
# Reshape back
next_input_dict['patches'] = patches.reshape(batch_size, num_patches, patch_pixels)
@@ -103,6 +136,7 @@ def create_naflex_loader(
max_seq_len: int = 576, # Fixed sequence length for validation
batch_size: int = 32, # Used for max_seq_len and max(train_seq_lens)
is_training: bool = False,
mixup_fn: Optional[Callable] = None,
no_aug: bool = False,
re_prob: float = 0.,
@@ -141,7 +175,8 @@
persistent_workers: bool = True,
worker_seeding: str = 'all',
):
"""Create a data loader with dynamic sequence length sampling for training."""
"""Create a data loader with dynamic sequence length sampling for training.
"""
if is_training:
# For training, use the dynamic sequence length mechanism
@@ -186,6 +221,7 @@
patch_size=patch_size,
seq_lens=train_seq_lens,
max_tokens_per_batch=max_tokens_per_batch,
mixup_fn=mixup_fn,
seed=seed,
distributed=distributed,
rank=rank,
@@ -219,6 +255,9 @@
std=std,
img_dtype=img_dtype,
device=device,
re_prob=re_prob,
re_mode=re_mode,
re_count=re_count,
)
else:
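As a reference for the normalization path above, a small self-contained sketch of the math (CPU, illustrative shapes; in the real loader the mean/std tensors live on the target device, and random erasing runs between the sub/div and the reshape):

import torch

# Sketch of NaFlexPrefetchLoader patch normalization; shapes are illustrative.
B, N, P, C = 2, 196, 16, 3
patches = torch.randint(0, 256, (B, N, P * P * C), dtype=torch.uint8)
mean = torch.tensor([m * 255 for m in (0.485, 0.456, 0.406)]).view(1, 1, C)
std = torch.tensor([s * 255 for s in (0.229, 0.224, 0.225)]).view(1, 1, C)

x = patches.float().view(B, N, -1, C)  # [B, N, P*P, C]: channel-last broadcast
x = x.sub(mean).div(std)               # PatchRandomErasing would run here
x = x.reshape(B, N, P * P * C)         # back to the collated [B, N, P*P*C] layout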

View File

@@ -26,7 +26,7 @@ def mix_batch_variable_size(
cutmix_alpha: float = 1.0,
switch_prob: float = 0.5,
local_shuffle: int = 4,
) -> Tuple[List[torch.Tensor], List[float], Dict[int, int], bool]:
) -> Tuple[List[torch.Tensor], List[float], Dict[int, int]]:
"""Apply Mixup or CutMix on a batch of variable-sized images.
The function first sorts images by aspect ratio and pairs neighbouring
@@ -34,19 +34,16 @@
epochs). Only the mutual central-overlap region of each pair is mixed
Args:
imgs: List of transformed images shaped (C, H, W). Heights and
widths may differ between samples.
mixup_alpha: Beta-distribution *α* for Mixup. Set to 0 to disable Mixup.
cutmix_alpha: Beta-distribution *α* for CutMix. Set to 0 to disable CutMix.
imgs: List of transformed images shaped (C, H, W). Heights and widths may differ between samples.
mixup_alpha: Beta-distribution alpha for Mixup. Set to 0 to disable Mixup.
cutmix_alpha: Beta-distribution alpha for CutMix. Set to 0 to disable CutMix.
switch_prob: Probability of using CutMix when both Mixup and CutMix are enabled.
local_shuffle: Size of local windows that are randomly shuffled after aspect sorting.
A value of 0 turns shuffling off.
local_shuffle: Size of local windows that are randomly shuffled after aspect sorting. Off if <= 1.
Returns:
mixed_imgs: List of mixed images.
lam_list: Per-sample lambda values representing the degree of mixing.
pair_to: Mapping i -> j describing which sample was mixed with which (absent for unmatched odd sample).
use_cutmix: True if CutMix was used for this call, False if Mixup was used.
"""
if len(imgs) < 2:
raise ValueError("Need at least two images to perform Mixup/CutMix.")
@@ -71,7 +68,7 @@
order = sorted(range(len(imgs)), key=lambda i: imgs[i].shape[2] / imgs[i].shape[1])
if local_shuffle > 1:
for start in range(0, len(order), local_shuffle):
random.shuffle(order[start: start + local_shuffle])
random.shuffle(order[start:start + local_shuffle])
pair_to: Dict[int, int] = {}
for a, b in zip(order[::2], order[1::2]):
@@ -119,22 +116,41 @@
#print(i, 'Doing cutmix', yl_i, xl_i, yl_j, xl_j, ch, cw, lam_raw, corrected_lam)
else:
# Mixup: blend the entire overlap region
patch_i = xi[:, top_i: top_i + oh, left_i: left_i + ow]
patch_j = xj[:, top_j: top_j + oh, left_j: left_j + ow]
patch_i = xi[:, top_i:top_i + oh, left_i:left_i + ow]
patch_j = xj[:, top_j:top_j + oh, left_j:left_j + ow]
blended = patch_i.mul(lam_raw).add_(patch_j, alpha=1.0 - lam_raw)
xi[:, top_i: top_i + oh, left_i: left_i + ow] = blended
xi[:, top_i:top_i + oh, left_i:left_i + ow] = blended
mixed_imgs[i] = xi
corrected_lam = (dest_area - overlap_area) / dest_area + lam_raw * overlap_area / dest_area
lam_list[i] = corrected_lam
#print(i, 'Doing mixup', top_i, left_i, top_j, left_j, (oh, ow), (hi, wi), (hj, wj), lam_raw, corrected_lam)
return mixed_imgs, lam_list, pair_to, use_cutmix
return mixed_imgs, lam_list, pair_to
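A quick worked example of the corrected_lam formula above, with made-up numbers: since only the mutual overlap is blended, the destination keeps all of its pixels outside the overlap, so the effective "own pixel" fraction exceeds the raw Beta draw.

dest_area = 224 * 224      # destination image area
overlap_area = 160 * 160   # mutual central-overlap area
lam_raw = 0.6              # raw Beta(alpha, alpha) draw
corrected_lam = (dest_area - overlap_area) / dest_area + lam_raw * overlap_area / dest_area
# (50176 - 25600)/50176 + 0.6 * 25600/50176 ≈ 0.490 + 0.306 ≈ 0.796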
def smoothed_sparse_target(
targets: torch.Tensor,
*,
num_classes: int,
smoothing: float = 0.0,
) -> torch.Tensor:
off_val = smoothing / num_classes
on_val = 1.0 - smoothing + off_val
y_onehot = torch.full(
(targets.size(0), num_classes),
off_val,
dtype=torch.float32,
device=targets.device
)
y_onehot.scatter_(1, targets.unsqueeze(1), on_val)
return y_onehot
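For reference, a standalone sketch of the smoothing math this helper encodes, with illustrative values (num_classes=4, smoothing=0.1):

import torch

targets = torch.tensor([2, 0])
num_classes, smoothing = 4, 0.1
off_val = smoothing / num_classes     # 0.025 on every class
on_val = 1.0 - smoothing + off_val    # 0.925 on the true class
y = torch.full((targets.size(0), num_classes), off_val)
y.scatter_(1, targets.unsqueeze(1), on_val)
# tensor([[0.0250, 0.0250, 0.9250, 0.0250],
#         [0.9250, 0.0250, 0.0250, 0.0250]])  # rows sum to 1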
def pairwise_mixup_target(
labels: torch.Tensor,
targets: torch.Tensor,
pair_to: Dict[int, int],
lam_list: List[float],
*,
@@ -144,21 +160,16 @@ def pairwise_mixup_target(
"""Create soft targets that match the pixel-level mixing performed.
Args:
labels: (B,) tensor of integer class indices.
targets: (B,) tensor of integer class indices.
pair_to: Mapping of sample index to its mixed partner as returned by mix_batch_variable_size().
lam_list: Per-sample fractions of self pixels, also from the mixer.
lam_list: Per-sample fractions of own pixels, also from the mixer.
num_classes: Total number of classes in the dataset.
smoothing: Label-smoothing value in the range [0, 1).
Returns:
Tensor of shape (B, num_classes) whose rows sum to 1.
"""
off_val = smoothing / num_classes
on_val = 1.0 - smoothing + off_val
y_onehot = torch.full((labels.size(0), num_classes), off_val, dtype=torch.float32, device=labels.device)
y_onehot.scatter_(1, labels.unsqueeze(1), on_val)
y_onehot = smoothed_sparse_target(targets, num_classes=num_classes, smoothing=smoothing)
targets = y_onehot.clone()
for i, j in pair_to.items():
lam = lam_list[i]
@@ -177,8 +188,9 @@ class NaFlexMixup:
mixup_alpha: float = 0.8,
cutmix_alpha: float = 1.0,
switch_prob: float = 0.5,
prob: float = 1.0,
local_shuffle: int = 4,
smoothing: float = 0.0,
label_smoothing: float = 0.0,
) -> None:
"""Configure the augmentation.
@@ -187,6 +199,7 @@
mixup_alpha: Beta α for Mixup. 0 disables Mixup.
cutmix_alpha: Beta α for CutMix. 0 disables CutMix.
switch_prob: Probability of selecting CutMix when both modes are enabled.
prob: Probability of applying any mixing per batch.
local_shuffle: Window size used to shuffle images after aspect sorting so pairings vary between epochs.
label_smoothing: Label-smoothing value. 0 disables smoothing.
"""
@@ -194,28 +207,33 @@
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.switch_prob = switch_prob
self.prob = prob
self.local_shuffle = local_shuffle
self.smoothing = smoothing
self.smoothing = label_smoothing
def __call__(
self,
imgs: List[torch.Tensor],
labels: torch.Tensor,
) -> Tuple[List[torch.Tensor], torch.Tensor]:
targets: torch.Tensor,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Apply the augmentation and generate matching targets.
Args:
imgs: List of already-transformed images shaped (C, H, W).
labels: Hard labels with shape (B,).
imgs: List of already transformed images shaped (C, H, W).
targets: Hard labels with shape (B,).
Returns:
mixed_imgs: List of mixed images in the same order and shapes as the input.
targets: Soft-label tensor shaped (B, num_classes) suitable for cross-entropy with soft targets.
"""
if isinstance(labels, (list, tuple)):
labels = torch.tensor(labels)
if not isinstance(targets, torch.Tensor):
targets = torch.tensor(targets)
mixed_imgs, lam_list, pair_to, _ = mix_batch_variable_size(
if random.random() > self.prob:
targets = smoothed_sparse_target(targets, num_classes=self.num_classes, smoothing=self.smoothing)
return imgs, targets.unbind(0)
mixed_imgs, lam_list, pair_to = mix_batch_variable_size(
imgs,
mixup_alpha=self.mixup_alpha,
cutmix_alpha=self.cutmix_alpha,
@@ -224,7 +242,7 @@
)
targets = pairwise_mixup_target(
labels,
targets,
pair_to,
lam_list,
num_classes=self.num_classes,
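End to end, a hedged usage sketch of NaFlexMixup on a toy variable-size batch (hyperparameters and image shapes are illustrative, and num_classes is assumed to be an existing constructor keyword, matching its use above):

import torch
from timm.data import NaFlexMixup  # exported in this commit's __init__.py change

mixup_fn = NaFlexMixup(
    num_classes=1000,
    mixup_alpha=0.8,
    cutmix_alpha=1.0,
    switch_prob=0.5,
    prob=1.0,            # new in this commit: per-batch chance of mixing at all
    label_smoothing=0.1,
)
imgs = [torch.rand(3, 192, 256), torch.rand(3, 224, 224),
        torch.rand(3, 256, 160), torch.rand(3, 288, 224)]
mixed_imgs, soft_targets = mixup_fn(imgs, [1, 2, 3, 4])
# images keep their per-sample shapes; targets come back as soft labels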

View File

@@ -0,0 +1,433 @@
import random
import math
from typing import Optional, Union, Tuple
import torch
class PatchRandomErasing:
"""
Random erasing for patchified images in NaFlex format.
Supports three modes:
1. 'patch': Simple mode that erases randomly selected valid patches
2. 'region': Erases spatial regions at patch granularity
3. 'subregion': Most sophisticated mode that erases spatial regions at sub-patch granularity,
partially erasing patches that are on the boundary of the erased region
Args:
erase_prob: Probability that the Random Erasing operation will be performed.
patch_drop_prob: Patch dropout probability. Remove random patches instead of erasing.
min_area: Minimum percentage of valid patches/area to erase.
max_area: Maximum percentage of valid patches/area to erase.
min_aspect: Minimum aspect ratio of erased area (only used in 'region'/'subregion' mode).
max_aspect: Maximum aspect ratio of erased area (only used in 'region'/'subregion' mode).
mode: Patch content mode, one of 'const', 'rand', or 'pixel'
'const' - erase patch is constant color of 0 for all channels
'rand' - erase patch has same random (normal) value across all elements
'pixel' - erase patch has per-element random (normal) values
spatial_mode: Erasing strategy, one of 'patch', 'region', or 'subregion'
patch_size: Size of each patch (required for 'subregion' mode)
num_splits: Number of splits to apply erasing to (0 for all)
device: Computation device
"""
def __init__(
self,
erase_prob: float = 0.5,
patch_drop_prob: float = 0.0,
min_count: int = 1,
max_count: Optional[int] = None,
min_area: float = 0.02,
max_area: float = 1 / 3,
min_aspect: float = 0.3,
max_aspect: Optional[float] = None,
mode: str = 'const',
value: float = 0.,
spatial_mode: str = 'region',
patch_size: Optional[Union[int, Tuple[int, int]]] = 16,
num_splits: int = 0,
device: Union[str, torch.device] = 'cuda',
):
self.erase_prob = erase_prob
self.patch_drop_prob = patch_drop_prob
self.min_count = min_count
self.max_count = max_count or min_count
self.min_area = min_area
self.max_area = max_area
# Aspect ratio params (for region mode)
max_aspect = max_aspect or 1 / min_aspect
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
# Number of splits
self.num_splits = num_splits
self.device = device
# Strategy mode
self.spatial_mode = spatial_mode
# Patch size (needed for subregion mode)
self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
# Value generation mode flags
self.erase_mode = mode.lower()
assert self.erase_mode in ('rand', 'pixel', 'const')
self.const_value = value
def _get_values(
self,
shape: Union[Tuple[int,...], torch.Size],
value: Optional[torch.Tensor] = None,
dtype: torch.dtype = torch.float32,
device: Optional[Union[str, torch.device]] = None
):
"""Generate values for erased patches based on the specified mode.
Args:
shape: Shape of patches to erase.
value: Value to use in const (or rand) mode.
dtype: Data type to use.
device: Device to use.
"""
device = device or self.device
if self.erase_mode == 'pixel':
# only mode with erase shape that includes pixels
return torch.empty(shape, dtype=dtype, device=device).normal_()
else:
shape = (1, 1, shape[-1]) if len(shape) == 3 else (1, shape[-1])
if self.erase_mode == 'const' or value is not None:
erase_value = value if value is not None else self.const_value
if isinstance(erase_value, (int, float)):
values = torch.full(shape, erase_value, dtype=dtype, device=device)
else:
erase_value = torch.tensor(erase_value, dtype=dtype, device=device)
values = torch.expand_copy(erase_value, shape)
else:
values = torch.empty(shape, dtype=dtype, device=device).normal_()
return values
def _drop_patches(
self,
patches: torch.Tensor,
patch_coord: torch.Tensor,
patch_valid: torch.Tensor,
):
""" Patch Dropout
Fully drops patches from datastream. Only mode that saves compute BUT requires support
for non-contiguous patches and associated patch coordinate and valid handling.
"""
# FIXME WIP, not completed. Downstream support in model needed for non-contiguous valid patches
if random.random() > self.erase_prob:
return
# Get indices of valid patches
valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0].tolist()
# Skip if no valid patches
if not valid_indices:
return patches, patch_coord, patch_valid
num_valid = len(valid_indices)
if self.patch_drop_prob:
# patch dropout mode, completely remove dropped patches (FIXME needs downstream support in model)
num_keep = max(1, int(num_valid * (1. - self.patch_drop_prob)))
keep_indices = torch.argsort(torch.randn(1, num_valid, device=self.device), dim=-1)[:, :num_keep]
# maintain patch order, possibly useful for debug / visualization
keep_indices = keep_indices.sort(dim=-1)[0]
patches = patches.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + patches.shape[2:]))
return patches, patch_coord, patch_valid
def _erase_patches(
self,
patches: torch.Tensor,
patch_coord: torch.Tensor,
patch_valid: torch.Tensor,
patch_shape: torch.Size,
dtype: torch.dtype = torch.float32,
):
"""Apply erasing by selecting individual patches randomly.
The simplest mode, aligned on patch boundaries. Behaves similarly to speckle or 'sprinkles'
noise augmentation at patch size.
"""
if random.random() > self.erase_prob:
return
# Get indices of valid patches
valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0].tolist()
if not valid_indices:
# Skip if no valid patches
return
num_valid = len(valid_indices)
count = random.randint(self.min_count, self.max_count)
# Determine how many valid patches to erase from RE min/max count and area args
max_erase = max(1, int(num_valid * count * self.max_area))
min_erase = max(1, int(num_valid * count * self.min_area))
num_erase = random.randint(min_erase, max_erase)
# Randomly select valid patches to erase
indices_to_erase = random.sample(valid_indices, min(num_erase, num_valid))
random_value = None
if self.erase_mode == 'rand':
random_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
for idx in indices_to_erase:
patches[idx].copy_(self._get_values(patch_shape, dtype=dtype, value=random_value))
def _erase_region(
self,
patches: torch.Tensor,
patch_coord: torch.Tensor,
patch_valid: torch.Tensor,
patch_shape: torch.Size,
dtype: torch.dtype = torch.float32,
):
"""Apply erasing by selecting rectangular regions of patches randomly
Closer to the original RandomErasing implementation. Erases
spatially contiguous rectangular regions of patches (aligned with patches).
"""
if random.random() > self.erase_prob:
return
# Determine grid dimensions from coordinates
if patch_valid is not None:
valid_coord = patch_coord[patch_valid]
if len(valid_coord) == 0:
return # No valid patches
max_y = valid_coord[:, 0].max().item() + 1
max_x = valid_coord[:, 1].max().item() + 1
else:
max_y = patch_coord[:, 0].max().item() + 1
max_x = patch_coord[:, 1].max().item() + 1
grid_h, grid_w = max_y, max_x
# Calculate total area
total_area = grid_h * grid_w
count = random.randint(self.min_count, self.max_count)
for _ in range(count):
# Try to select a valid region to erase (multiple attempts)
for attempt in range(10):
# Sample random area and aspect ratio
target_area = random.uniform(self.min_area, self.max_area) * total_area
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
# Calculate region height and width
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
# Ensure region fits within grid
if w <= grid_w and h <= grid_h:
# Select random top-left corner
top = random.randint(0, grid_h - h)
left = random.randint(0, grid_w - w)
# Define region bounds
bottom = top + h
right = left + w
# Create a single random value for all affected patches if using 'rand' mode
if self.erase_mode == 'rand':
random_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
else:
random_value = None
# Find and erase all patches that fall within the region
for i in range(len(patches)):
if patch_valid is None or patch_valid[i]:
y, x = patch_coord[i]
if top <= y < bottom and left <= x < right:
patches[i] = self._get_values(patch_shape, dtype=dtype, value=random_value)
# Successfully applied erasing, exit the loop
break
def _erase_subregion(
self,
patches: torch.Tensor,
patch_coord: torch.Tensor,
patch_valid: torch.Tensor,
patch_shape: torch.Size,
patch_size: Tuple[int, int],
dtype: torch.dtype = torch.float32,
):
"""Apply erasing by selecting rectangular regions ignoring patch boundaries.
Matches the original RandomErasing implementation. Erases spatially contiguous rectangular
regions that are not aligned to patches (erase region boundaries cut within patches).
FIXME complexity probably not worth it, may remove.
"""
if random.random() > self.erase_prob:
return
# Get patch dimensions
patch_h, patch_w = patch_size
channels = patch_shape[-1]
# Determine grid dimensions in patch coordinates
if patch_valid is not None:
valid_coord = patch_coord[patch_valid]
if len(valid_coord) == 0:
return # No valid patches
max_y = valid_coord[:, 0].max().item() + 1
max_x = valid_coord[:, 1].max().item() + 1
else:
max_y = patch_coord[:, 0].max().item() + 1
max_x = patch_coord[:, 1].max().item() + 1
grid_h, grid_w = max_y, max_x
# Calculate total area in pixel space
total_area = (grid_h * patch_h) * (grid_w * patch_w)
count = random.randint(self.min_count, self.max_count)
for _ in range(count):
# Try to select a valid region to erase (multiple attempts)
for attempt in range(10):
# Sample random area and aspect ratio
target_area = random.uniform(self.min_area, self.max_area) * total_area
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
# Calculate region height and width in pixel space
pixel_h = int(round(math.sqrt(target_area * aspect_ratio)))
pixel_w = int(round(math.sqrt(target_area / aspect_ratio)))
# Ensure region fits within total pixel grid
if pixel_w <= grid_w * patch_w and pixel_h <= grid_h * patch_h:
# Select random top-left corner in pixel space
pixel_top = random.randint(0, grid_h * patch_h - pixel_h)
pixel_left = random.randint(0, grid_w * patch_w - pixel_w)
# Define region bounds in pixel space
pixel_bottom = pixel_top + pixel_h
pixel_right = pixel_left + pixel_w
# Create a single random value for the entire region if using 'rand' mode
rand_value = None
if self.erase_mode == 'rand':
rand_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
# For each valid patch, determine if and how it overlaps with the erase region
for i in range(len(patches)):
if patch_valid is None or patch_valid[i]:
# Convert patch coordinates to pixel space (top-left corner)
y, x = patch_coord[i]
patch_pixel_top = y * patch_h
patch_pixel_left = x * patch_w
patch_pixel_bottom = patch_pixel_top + patch_h
patch_pixel_right = patch_pixel_left + patch_w
# Check if this patch overlaps with the erase region
if not (patch_pixel_right <= pixel_left or patch_pixel_left >= pixel_right or
patch_pixel_bottom <= pixel_top or patch_pixel_top >= pixel_bottom):
# Calculate the overlap region in patch-local coordinates
local_top = max(0, pixel_top - patch_pixel_top)
local_left = max(0, pixel_left - patch_pixel_left)
local_bottom = min(patch_h, pixel_bottom - patch_pixel_top)
local_right = min(patch_w, pixel_right - patch_pixel_left)
# Reshape the patch to [patch_h, patch_w, chans]
patch_data = patches[i].reshape(patch_h, patch_w, channels)
erase_shape = (local_bottom - local_top, local_right - local_left, channels)
erase_value = self._get_values(erase_shape, dtype=dtype, value=rand_value)
patch_data[local_top:local_bottom, local_left:local_right, :] = erase_value
# Flatten the patch back to [patch_h*patch_w, chans]
if len(patch_shape) == 2:
patch_data = patch_data.reshape(-1, channels)
patches[i] = patch_data
# Successfully applied erasing, exit the loop
break
def __call__(
self,
patches: torch.Tensor,
patch_coord: torch.Tensor,
patch_valid: Optional[torch.Tensor] = None,
):
"""
Apply random patch erasing.
Args:
patches: Tensor of shape [B, N, P*P, C]
patch_coord: Tensor of shape [B, N, 2] with (y, x) coordinates
patch_valid: Boolean tensor of shape [B, N] indicating which patches are valid
If None, all patches are considered valid
Returns:
Erased patches tensor of same shape
"""
if patches.ndim == 4:
batch_size, num_patches, patch_dim, channels = patches.shape
if self.patch_size is not None:
patch_size = self.patch_size
else:
patch_size = None
elif patches.ndim == 5:
batch_size, num_patches, patch_h, patch_w, channels = patches.shape
patch_size = (patch_h, patch_w)
else:
assert False
patch_shape = patches.shape[2:]
# patch_shape ==> shape of patches to fill (h, w, c) or (h * w, c)
# patch_size ==> patch h, w (if available, must be avail for subregion mode)
# Create default valid mask if not provided
if patch_valid is None:
patch_valid = torch.ones((batch_size, num_patches), dtype=torch.bool, device=patches.device)
# Skip the first part of the batch if num_splits is set
batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
# Apply erasing to each batch element
for i in range(batch_start, batch_size):
if self.patch_drop_prob:
assert False, "WIP, not completed"
self._drop_patches(
patches[i],
patch_coord[i],
patch_valid[i],
)
elif self.spatial_mode == 'patch':
self._erase_patches(
patches[i],
patch_coord[i],
patch_valid[i],
patch_shape,
patches.dtype
)
elif self.spatial_mode == 'region':
self._erase_region(
patches[i],
patch_coord[i],
patch_valid[i],
patch_shape,
patches.dtype
)
elif self.spatial_mode == 'subregion':
self._erase_subregion(
patches[i],
patch_coord[i],
patch_valid[i],
patch_shape,
patch_size,
patches.dtype
)
return patches
def __repr__(self):
fs = self.__class__.__name__ + f'(p={self.erase_prob}, mode={self.erase_mode}'
fs += f', spatial={self.spatial_mode}, area=({self.min_area}, {self.max_area})'
fs += f', count=({self.min_count}, {self.max_count}))'
return fs
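A hedged usage sketch of the new class on a CPU batch (coordinate layout follows the __call__ docstring; the grid size and shapes are illustrative):

import torch
from timm.data.naflex_random_erasing import PatchRandomErasing  # path per this commit

erasing = PatchRandomErasing(erase_prob=1.0, mode='pixel', spatial_mode='region', device='cpu')
B, N, P, C = 2, 64, 16, 3                                  # 8x8 patch grid per sample
patches = torch.randn(B, N, P * P, C)                      # [B, N, P*P, C]
ys, xs = torch.meshgrid(torch.arange(8), torch.arange(8), indexing='ij')
patch_coord = torch.stack([ys, xs], dim=-1).view(1, N, 2).expand(B, N, 2)  # (y, x)
patch_valid = torch.ones(B, N, dtype=torch.bool)
out = erasing(patches, patch_coord, patch_valid)           # same shape, regions erased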

View File

@@ -132,9 +132,13 @@ def transforms_imagenet_train(
primary_tfl = []
if naflex:
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range
primary_tfl += [RandomResizedCropToSequence(
patch_size=patch_size,
max_seq_len=max_seq_len,
scale=scale,
ratio=ratio,
interpolation=interpolation
)]
else:

View File

@@ -697,32 +697,6 @@ def main():
trust_remote_code=args.dataset_trust_remote_code,
)
# setup mixup / cutmix
collate_fn = None
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
assert not args.naflex_loader, "Mixup/Cutmix not currently supported for NaFlex loading."
mixup_args = dict(
mixup_alpha=args.mixup,
cutmix_alpha=args.cutmix,
cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob,
switch_prob=args.mixup_switch_prob,
mode=args.mixup_mode,
label_smoothing=args.smoothing,
num_classes=args.num_classes
)
if args.prefetcher:
assert not num_aug_splits # collate conflict (need to support de-interleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
else:
mixup_fn = Mixup(**mixup_args)
# wrap dataset in AugMix helper
if num_aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
# create data loaders w/ augmentation pipeline
train_interpolation = args.train_interpolation
if args.no_aug or not train_interpolation:
@@ -764,22 +738,59 @@
worker_seeding=args.worker_seeding,
)
mixup_fn = None
mixup_args = {}
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_args = dict(
mixup_alpha=args.mixup,
cutmix_alpha=args.cutmix,
cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob,
switch_prob=args.mixup_switch_prob,
mode=args.mixup_mode,
label_smoothing=args.smoothing,
num_classes=args.num_classes
)
naflex_mode = False
if args.naflex_loader:
if utils.is_primary(args):
_logger.info('Using NaFlex loader')
assert num_aug_splits <= 1, 'Augmentation splits not supported in NaFlex mode'
naflex_mixup_fn = None
if mixup_active:
from timm.data import NaFlexMixup
mixup_args.pop('mode') # not supported
mixup_args.pop('cutmix_minmax') # not supported
naflex_mixup_fn = NaFlexMixup(**mixup_args)
naflex_mode = True
loader_train = create_naflex_loader(
dataset=dataset_train,
patch_size=16, # Could be derived from model config
train_seq_lens=args.naflex_train_seq_lens,
mixup_fn=naflex_mixup_fn,
rank=args.rank,
world_size=args.world_size,
**common_loader_kwargs,
**train_loader_kwargs,
)
else:
# setup mixup / cutmix
collate_fn = None
if mixup_active:
if args.prefetcher:
assert not num_aug_splits # collate conflict (need to support de-interleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
else:
mixup_fn = Mixup(**mixup_args)
# wrap dataset in AugMix helper
if num_aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
# Use standard loader
loader_train = create_loader(
dataset_train,
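With the integration above, Mixup/CutMix and random erasing can be enabled alongside the NaFlex loader from the train script. An illustrative invocation (dataset path and model name are placeholders; flag spellings are inferred from the args referenced above, e.g. args.naflex_loader):

python train.py /data/imagenet --model vit_base_patch16_224 \
    --naflex-loader --mixup 0.8 --cutmix 1.0 --mixup-prob 1.0 \
    --smoothing 0.1 --reprob 0.25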