Mirror of https://github.com/huggingface/pytorch-image-models.git
Mixup cleanup: add prob support and train script integration. Add a working, loader-based, patch-compatible RandomErasing for NaFlex mode.
This commit is contained in:
parent 8fcbceb609
commit 7624389fc9
timm/data/__init__.py
@@ -10,6 +10,7 @@ from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
 from .naflex_dataset import VariableSeqMapWrapper
 from .naflex_loader import create_naflex_loader
+from .naflex_mixup import NaFlexMixup, pairwise_mixup_target, mix_batch_variable_size
 from .naflex_transforms import (
     ResizeToSequence,
     CenterCropToSequence,
timm/data/naflex_dataset.py
@@ -83,8 +83,11 @@ class NaFlexCollator:
         batch_size = len(batch)
 
         # Extract targets
-        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
+        # FIXME need to handle dense (float) targets or always done downstream of this?
+        targets = [item[1] for item in batch]
+        if isinstance(targets[0], torch.Tensor):
+            targets = torch.stack(targets)
+        else:
+            targets = torch.tensor(targets, dtype=torch.int64)
 
         # Get patch dictionaries
         patch_dicts = [item[0] for item in batch]
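The collator change above lets targets arrive either as plain integer labels or as per-sample soft-label tensors (the form NaFlexMixup produces). A minimal standalone sketch of just that branch, with hypothetical toy inputs:

import torch

def collate_targets(targets):
    # Mirrors the new NaFlexCollator branch: stack tensor targets (e.g. soft
    # labels from mixup), otherwise build an int64 tensor from hard labels.
    if isinstance(targets[0], torch.Tensor):
        return torch.stack(targets)                    # (B, num_classes) float
    return torch.tensor(targets, dtype=torch.int64)    # (B,) int64

soft = [torch.rand(10) for _ in range(4)]  # hypothetical soft labels after mixup
hard = [3, 7, 1, 9]                        # plain integer class labels
assert collate_targets(soft).shape == (4, 10)
assert collate_targets(hard).dtype == torch.int64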
@@ -139,6 +142,7 @@ class VariableSeqMapWrapper(IterableDataset):
         seq_lens: List[int] = (128, 256, 576, 784, 1024),
         max_tokens_per_batch: int = 4096 * 4,  # Example: 16k tokens
         transform_factory: Optional[Callable] = None,
+        mixup_fn: Optional[Callable] = None,
         seed: int = 42,
         shuffle: bool = True,
         distributed: bool = False,
@@ -172,6 +176,7 @@ class VariableSeqMapWrapper(IterableDataset):
             else:
                 self.transforms[seq_len] = None  # No transform
             self.collate_fns[seq_len] = NaFlexCollator(seq_len)
+        self.mixup_fn = mixup_fn
         self.patchifier = Patchify(self.patch_size)
 
         # --- Canonical Schedule Calculation (Done Once) ---
@@ -393,6 +398,8 @@ class VariableSeqMapWrapper(IterableDataset):
         transform = self.transforms.get(seq_len)
 
         batch_samples = []
+        batch_imgs = []
+        batch_targets = []
         for idx in indices:
             try:
                 # Get original image and label from map-style dataset
@@ -405,9 +412,8 @@ class VariableSeqMapWrapper(IterableDataset):
                     warnings.warn(f"Transform returned None for index {idx}. Skipping sample.")
                     continue
 
-                # Apply patching
-                patch_data = self.patchifier(processed_img)
-                batch_samples.append((patch_data, label))
+                batch_imgs.append(processed_img)
+                batch_targets.append(label)
 
             except IndexError:
                 warnings.warn(f"IndexError encountered for index {idx} (possibly due to padding/repeated indices). Skipping sample.")
@@ -417,8 +423,13 @@ class VariableSeqMapWrapper(IterableDataset):
                 warnings.warn(f"Error processing sample index {idx}. Error: {e}. Skipping sample.")
                 continue  # Skip problematic sample
 
-        # Collate the processed samples into a batch
+        if self.mixup_fn is not None:
+            batch_imgs, batch_targets = self.mixup_fn(batch_imgs, batch_targets)
+
+        batch_imgs = [self.patchifier(img) for img in batch_imgs]
+        batch_samples = list(zip(batch_imgs, batch_targets))
         if batch_samples:  # Only yield if we successfully processed samples
+            # Collate the processed samples into a batch
             yield self.collate_fns[seq_len](batch_samples)
 
         # If batch_samples is empty after processing 'indices', an empty batch is skipped.
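The ordering in the new iteration body matters: mixing now happens on whole, variable-sized images before patchification, so CutMix boxes and Mixup blends operate in pixel space rather than on flattened patches. A condensed sketch of the per-batch flow, using hypothetical stand-ins for self.mixup_fn, self.patchifier, and self.collate_fns[seq_len]:

def make_batch(samples, mixup_fn, patchifier, collate_fn):
    # samples: list of (transformed_image, label); images may differ in size
    batch_imgs = [img for img, _ in samples]
    batch_targets = [label for _, label in samples]
    if mixup_fn is not None:
        # mix in pixel space first; returns mixed images + per-sample soft labels
        batch_imgs, batch_targets = mixup_fn(batch_imgs, batch_targets)
    batch_imgs = [patchifier(img) for img in batch_imgs]   # patchify afterwards
    return collate_fn(list(zip(batch_imgs, batch_targets)))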
timm/data/naflex_loader.py
@@ -3,11 +3,13 @@ from contextlib import suppress
 from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 
 from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .loader import _worker_init
+from .loader import _worker_init, adapt_to_chs
 from .naflex_dataset import VariableSeqMapWrapper, NaFlexCollator
+from .naflex_random_erasing import PatchRandomErasing
 from .transforms_factory import create_transform
@@ -16,19 +18,41 @@ class NaFlexPrefetchLoader:
 
     def __init__(
             self,
-            loader,
-            mean=(0.485, 0.456, 0.406),
-            std=(0.229, 0.224, 0.225),
-            img_dtype=torch.float32,
-            device=torch.device('cuda')
+            loader: torch.utils.data.DataLoader,
+            mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
+            std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
+            channels: int = 3,
+            device: torch.device = torch.device('cuda'),
+            img_dtype: Optional[torch.dtype] = None,
+            re_prob: float = 0.,
+            re_mode: str = 'const',
+            re_count: int = 1,
+            re_num_splits: int = 0,
     ):
         self.loader = loader
         self.device = device
         self.img_dtype = img_dtype or torch.float32
 
         # Create mean/std tensors for normalization (will be applied to patches)
-        self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=self.img_dtype).view(1, 1, 3)
-        self.std = torch.tensor([x * 255 for x in std], device=device, dtype=self.img_dtype).view(1, 1, 3)
+        mean = adapt_to_chs(mean, channels)
+        std = adapt_to_chs(std, channels)
+        normalization_shape = (1, 1, channels)
+        self.channels = channels
+        self.mean = torch.tensor(
+            [x * 255 for x in mean], device=device, dtype=self.img_dtype).view(normalization_shape)
+        self.std = torch.tensor(
+            [x * 255 for x in std], device=device, dtype=self.img_dtype).view(normalization_shape)
+
+        if re_prob > 0.:
+            self.random_erasing = PatchRandomErasing(
+                erase_prob=re_prob,
+                mode=re_mode,
+                max_count=re_count,
+                num_splits=re_num_splits,
+                device=device,
+            )
+        else:
+            self.random_erasing = None
 
         # Check for CUDA/NPU availability
         self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
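The constructor now builds a PatchRandomErasing instance only when re_prob > 0. A small sketch of that gating, assuming the module path introduced by this commit (device set to 'cpu' here just so the sketch runs anywhere):

import torch
from timm.data.naflex_random_erasing import PatchRandomErasing

re_prob, re_mode, re_count, re_num_splits = 0.25, 'pixel', 1, 0
random_erasing = PatchRandomErasing(
    erase_prob=re_prob,      # probability of erasing per call
    mode=re_mode,            # 'const', 'rand', or 'pixel' fill values
    max_count=re_count,
    num_splits=re_num_splits,
    device='cpu',            # the loader passes its (usually CUDA) device
) if re_prob > 0. else None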
@@ -62,9 +86,18 @@ class NaFlexPrefetchLoader:
 
             # Normalize patch values (assuming patches are in format [B, N, P*P*C])
             batch_size, num_patches, patch_pixels = next_input_dict['patches'].shape
-            patches = next_input_dict['patches'].view(batch_size, -1, 3)  # to [B*N, P*P, C] for normalization
+
+            # To [B*N, P*P, C] for normalization and erasing
+            patches = next_input_dict['patches'].view(batch_size, num_patches, -1, self.channels)
             patches = patches.sub(self.mean).div(self.std)
+
+            if self.random_erasing is not None:
+                patches = self.random_erasing(
+                    patches,
+                    patch_coord=next_input_dict['patch_coord'],
+                    patch_valid=next_input_dict.get('patch_valid', None),
+                )
 
             # Reshape back
             next_input_dict['patches'] = patches.reshape(batch_size, num_patches, patch_pixels)
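A quick shape check of the new normalization path: viewing [B, N, P*P*C] patches as [B, N, P*P, C] lets the (1, 1, C) mean/std tensors broadcast over the batch, sequence, and pixel dimensions, and random erasing can then address whole pixels. A self-contained sketch with made-up sizes:

import torch

B, N, P, C = 2, 8, 16, 3
patches = torch.rand(B, N, P * P * C) * 255
mean = torch.tensor([0.485, 0.456, 0.406]).mul(255).view(1, 1, C)
std = torch.tensor([0.229, 0.224, 0.225]).mul(255).view(1, 1, C)

x = patches.view(B, N, -1, C)      # [B, N, P*P, C]; channels become last dim
x = x.sub(mean).div(std)           # (1, 1, C) broadcasts across B, N and P*P
out = x.reshape(B, N, P * P * C)   # back to the collated layout
assert out.shape == patches.shape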
@@ -103,6 +136,7 @@ def create_naflex_loader(
         max_seq_len: int = 576,  # Fixed sequence length for validation
         batch_size: int = 32,  # Used for max_seq_len and max(train_seq_lens)
         is_training: bool = False,
+        mixup_fn: Optional[Callable] = None,
 
         no_aug: bool = False,
         re_prob: float = 0.,
@@ -141,7 +175,8 @@ def create_naflex_loader(
         persistent_workers: bool = True,
         worker_seeding: str = 'all',
 ):
-    """Create a data loader with dynamic sequence length sampling for training."""
+    """Create a data loader with dynamic sequence length sampling for training.
+    """
 
     if is_training:
         # For training, use the dynamic sequence length mechanism
@@ -186,6 +221,7 @@ def create_naflex_loader(
             patch_size=patch_size,
             seq_lens=train_seq_lens,
             max_tokens_per_batch=max_tokens_per_batch,
+            mixup_fn=mixup_fn,
             seed=seed,
             distributed=distributed,
             rank=rank,
@@ -219,6 +255,9 @@ def create_naflex_loader(
             std=std,
             img_dtype=img_dtype,
             device=device,
+            re_prob=re_prob,
+            re_mode=re_mode,
+            re_count=re_count,
         )
 
     else:
timm/data/naflex_mixup.py
@@ -26,7 +26,7 @@ def mix_batch_variable_size(
         cutmix_alpha: float = 1.0,
         switch_prob: float = 0.5,
         local_shuffle: int = 4,
-) -> Tuple[List[torch.Tensor], List[float], Dict[int, int], bool]:
+) -> Tuple[List[torch.Tensor], List[float], Dict[int, int]]:
     """Apply Mixup or CutMix on a batch of variable‑sized images.
 
     The function first sorts images by aspect ratio and pairs neighbouring
@@ -34,19 +34,16 @@ def mix_batch_variable_size(
     epochs). Only the mutual central‑overlap region of each pair is mixed
 
     Args:
-        imgs: List of transformed images shaped (C, H, W). Heights and
-            widths may differ between samples.
-        mixup_alpha: Beta‑distribution *α* for Mixup. Set to 0 to disable Mixup.
-        cutmix_alpha: Beta‑distribution *α* for CutMix. Set to 0 to disable CutMix.
+        imgs: List of transformed images shaped (C, H, W). Heights and widths may differ between samples.
+        mixup_alpha: Beta‑distribution alpha for Mixup. Set to 0 to disable Mixup.
+        cutmix_alpha: Beta‑distribution alpha for CutMix. Set to 0 to disable CutMix.
         switch_prob: Probability of using CutMix when both Mixup and CutMix are enabled.
-        local_shuffle: Size of local windows that are randomly shuffled after aspect sorting.
-            A value of 0 turns shuffling off.
+        local_shuffle: Size of local windows that are randomly shuffled after aspect sorting. Off if <= 1.
 
     Returns:
         mixed_imgs: List of mixed images.
         lam_list: Per‑sample lambda values representing the degree of mixing.
         pair_to: Mapping i -> j describing which sample was mixed with which (absent for unmatched odd sample).
-        use_cutmix: True if CutMix was used for this call, False if Mixup was used.
     """
     if len(imgs) < 2:
         raise ValueError("Need at least two images to perform Mixup/CutMix.")
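A usage sketch for the updated function (three-element return, per the diff); the image sizes below are arbitrary and the mixing is random, so only shapes are asserted:

import torch
from timm.data.naflex_mixup import mix_batch_variable_size

imgs = [torch.rand(3, h, w) for h, w in [(224, 224), (192, 256), (256, 160), (224, 288)]]
mixed, lam_list, pair_to = mix_batch_variable_size(
    imgs, mixup_alpha=0.8, cutmix_alpha=1.0, switch_prob=0.5, local_shuffle=4,
)
assert [m.shape for m in mixed] == [i.shape for i in imgs]  # sizes preserved
assert len(lam_list) == len(imgs)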
@@ -71,7 +68,7 @@ def mix_batch_variable_size(
     order = sorted(range(len(imgs)), key=lambda i: imgs[i].shape[2] / imgs[i].shape[1])
     if local_shuffle > 1:
         for start in range(0, len(order), local_shuffle):
-            random.shuffle(order[start: start + local_shuffle])
+            random.shuffle(order[start:start + local_shuffle])
 
     pair_to: Dict[int, int] = {}
     for a, b in zip(order[::2], order[1::2]):
@@ -119,22 +116,41 @@ def mix_batch_variable_size(
             #print(i, 'Doing cutmix', yl_i, xl_i, yl_j, xl_j, ch, cw, lam_raw, corrected_lam)
         else:
             # Mixup: blend the entire overlap region
-            patch_i = xi[:, top_i: top_i + oh, left_i: left_i + ow]
-            patch_j = xj[:, top_j: top_j + oh, left_j: left_j + ow]
+            patch_i = xi[:, top_i:top_i + oh, left_i:left_i + ow]
+            patch_j = xj[:, top_j:top_j + oh, left_j:left_j + ow]
 
             blended = patch_i.mul(lam_raw).add_(patch_j, alpha=1.0 - lam_raw)
-            xi[:, top_i: top_i + oh, left_i: left_i + ow] = blended
+            xi[:, top_i:top_i + oh, left_i:left_i + ow] = blended
             mixed_imgs[i] = xi
 
             corrected_lam = (dest_area - overlap_area) / dest_area + lam_raw * overlap_area / dest_area
             lam_list[i] = corrected_lam
             #print(i, 'Doing mixup', top_i, left_i, top_j, left_j, (oh, ow), (hi, wi), (hj, wj), lam_raw, corrected_lam)
 
-    return mixed_imgs, lam_list, pair_to, use_cutmix
+    return mixed_imgs, lam_list, pair_to
+
+
+def smoothed_sparse_target(
+        targets: torch.Tensor,
+        *,
+        num_classes: int,
+        smoothing: float = 0.0,
+) -> torch.Tensor:
+    off_val = smoothing / num_classes
+    on_val = 1.0 - smoothing + off_val
+
+    y_onehot = torch.full(
+        (targets.size(0), num_classes),
+        off_val,
+        dtype=torch.float32,
+        device=targets.device
+    )
+    y_onehot.scatter_(1, targets.unsqueeze(1), on_val)
+    return y_onehot
 
 
 def pairwise_mixup_target(
-        labels: torch.Tensor,
+        targets: torch.Tensor,
         pair_to: Dict[int, int],
         lam_list: List[float],
         *,
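What the new smoothed_sparse_target helper computes, worked through with small numbers: for num_classes=5 and smoothing=0.1, off_val = 0.1/5 = 0.02 and on_val = 1 - 0.1 + 0.02 = 0.92, so each row is a one-hot smoothed to [0.02, ..., 0.92, ..., 0.02] and sums to 1:

import torch

targets = torch.tensor([2, 0])
num_classes, smoothing = 5, 0.1
off_val = smoothing / num_classes                      # 0.02
on_val = 1.0 - smoothing + off_val                     # 0.92
y = torch.full((targets.size(0), num_classes), off_val)
y.scatter_(1, targets.unsqueeze(1), on_val)            # place on_val at the label
assert torch.allclose(y[0], torch.tensor([0.02, 0.02, 0.92, 0.02, 0.02]))
assert torch.allclose(y.sum(1), torch.ones(2))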
@@ -144,21 +160,16 @@ def pairwise_mixup_target(
     """Create soft targets that match the pixel‑level mixing performed.
 
     Args:
-        labels: (B,) tensor of integer class indices.
+        targets: (B,) tensor of integer class indices.
         pair_to: Mapping of sample index to its mixed partner as returned by mix_batch_variable_size().
-        lam_list: Per‑sample fractions of self pixels, also from the mixer.
+        lam_list: Per‑sample fractions of own pixels, also from the mixer.
         num_classes: Total number of classes in the dataset.
         smoothing: Label‑smoothing value in the range [0, 1).
 
     Returns:
         Tensor of shape (B, num_classes) whose rows sum to 1.
     """
-    off_val = smoothing / num_classes
-    on_val = 1.0 - smoothing + off_val
-
-    y_onehot = torch.full((labels.size(0), num_classes), off_val, dtype=torch.float32, device=labels.device)
-    y_onehot.scatter_(1, labels.unsqueeze(1), on_val)
+    y_onehot = smoothed_sparse_target(targets, num_classes=num_classes, smoothing=smoothing)
 
     targets = y_onehot.clone()
     for i, j in pair_to.items():
         lam = lam_list[i]
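The lam_list values consumed here are the corrected lambdas from mix_batch_variable_size: since only the mutual overlap region is blended, the effective fraction of 'own' pixels is higher than the raw Beta draw. A worked numeric check of that correction formula:

dest_area = 224 * 224      # pixels in the destination image
overlap_area = 160 * 160   # central-overlap region actually blended
lam_raw = 0.7              # raw Beta-distribution draw
corrected_lam = (dest_area - overlap_area) / dest_area + lam_raw * overlap_area / dest_area
# equivalently 1 - (1 - lam_raw) * overlap_area / dest_area, which is about 0.847 here
assert abs(corrected_lam - (1 - (1 - lam_raw) * overlap_area / dest_area)) < 1e-12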
@@ -177,8 +188,9 @@ class NaFlexMixup:
         mixup_alpha: float = 0.8,
         cutmix_alpha: float = 1.0,
         switch_prob: float = 0.5,
+        prob: float = 1.0,
         local_shuffle: int = 4,
-        smoothing: float = 0.0,
+        label_smoothing: float = 0.0,
     ) -> None:
         """Configure the augmentation.
 
@@ -187,6 +199,7 @@ class NaFlexMixup:
             mixup_alpha: Beta α for Mixup. 0 disables Mixup.
             cutmix_alpha: Beta α for CutMix. 0 disables CutMix.
             switch_prob: Probability of selecting CutMix when both modes are enabled.
+            prob: Probability of applying any mixing per batch.
             local_shuffle: Window size used to shuffle images after aspect sorting so pairings vary between epochs.
             smoothing: Label‑smoothing value. 0 disables smoothing.
         """
@@ -194,28 +207,33 @@ class NaFlexMixup:
         self.mixup_alpha = mixup_alpha
         self.cutmix_alpha = cutmix_alpha
         self.switch_prob = switch_prob
+        self.prob = prob
         self.local_shuffle = local_shuffle
-        self.smoothing = smoothing
+        self.smoothing = label_smoothing
 
     def __call__(
             self,
             imgs: List[torch.Tensor],
-            labels: torch.Tensor,
-    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
+            targets: torch.Tensor,
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         """Apply the augmentation and generate matching targets.
 
         Args:
-            imgs: List of already‑transformed images shaped (C, H, W).
-            labels: Hard labels with shape (B,).
+            imgs: List of already transformed images shaped (C, H, W).
+            targets: Hard labels with shape (B,).
 
         Returns:
             mixed_imgs: List of mixed images in the same order and shapes as the input.
             targets: Soft‑label tensor shaped (B, num_classes) suitable for cross‑entropy with soft targets.
         """
-        if isinstance(labels, (list, tuple)):
-            labels = torch.tensor(labels)
+        if not isinstance(targets, torch.Tensor):
+            targets = torch.tensor(targets)
+
+        if random.random() > self.prob:
+            targets = smoothed_sparse_target(targets, num_classes=self.num_classes, smoothing=self.smoothing)
+            return imgs, targets.unbind(0)
 
-        mixed_imgs, lam_list, pair_to, _ = mix_batch_variable_size(
+        mixed_imgs, lam_list, pair_to = mix_batch_variable_size(
             imgs,
             mixup_alpha=self.mixup_alpha,
             cutmix_alpha=self.cutmix_alpha,
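End-to-end usage of the class with the new arguments (num_classes is assumed to be a constructor argument, consistent with self.num_classes above; the returned targets are per-sample soft labels):

import torch
from timm.data.naflex_mixup import NaFlexMixup

mixup = NaFlexMixup(
    num_classes=1000,
    mixup_alpha=0.8,
    cutmix_alpha=1.0,
    prob=0.9,             # new: chance of applying any mixing for a batch
    label_smoothing=0.1,  # renamed from `smoothing`
)
imgs = [torch.rand(3, 224, 224) for _ in range(4)]
mixed_imgs, soft_targets = mixup(imgs, torch.tensor([1, 2, 3, 4]))
assert len(mixed_imgs) == 4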
@@ -224,7 +242,7 @@ class NaFlexMixup:
         )
 
         targets = pairwise_mixup_target(
-            labels,
+            targets,
             pair_to,
             lam_list,
             num_classes=self.num_classes,
timm/data/naflex_random_erasing.py (new file, 433 lines)
@@ -0,0 +1,433 @@
import random
import math
from typing import Optional, Union, Tuple

import torch


class PatchRandomErasing:
    """
    Random erasing for patchified images in NaFlex format.

    Supports three modes:
    1. 'patch': Simple mode that erases randomly selected valid patches
    2. 'region': Erases spatial regions at patch granularity
    3. 'subregion': Most sophisticated mode that erases spatial regions at sub-patch granularity,
       partially erasing patches that are on the boundary of the erased region

    Args:
        erase_prob: Probability that the Random Erasing operation will be performed.
        patch_drop_prob: Patch dropout probability. Remove random patches instead of erasing.
        min_area: Minimum percentage of valid patches/area to erase.
        max_area: Maximum percentage of valid patches/area to erase.
        min_aspect: Minimum aspect ratio of erased area (only used in 'region'/'subregion' mode).
        max_aspect: Maximum aspect ratio of erased area (only used in 'region'/'subregion' mode).
        mode: Patch content mode, one of 'const', 'rand', or 'pixel'
            'const' - erase patch is constant color of 0 for all channels
            'rand' - erase patch has same random (normal) value across all elements
            'pixel' - erase patch has per-element random (normal) values
        spatial_mode: Erasing strategy, one of 'patch', 'region', or 'subregion'
        patch_size: Size of each patch (required for 'subregion' mode)
        num_splits: Number of splits to apply erasing to (0 for all)
        device: Computation device
    """

    def __init__(
            self,
            erase_prob: float = 0.5,
            patch_drop_prob: float = 0.0,
            min_count: int = 1,
            max_count: Optional[int] = None,
            min_area: float = 0.02,
            max_area: float = 1 / 3,
            min_aspect: float = 0.3,
            max_aspect: Optional[float] = None,
            mode: str = 'const',
            value: float = 0.,
            spatial_mode: str = 'region',
            patch_size: Optional[Union[int, Tuple[int, int]]] = 16,
            num_splits: int = 0,
            device: Union[str, torch.device] = 'cuda',
    ):
        self.erase_prob = erase_prob
        self.patch_drop_prob = patch_drop_prob
        self.min_count = min_count
        self.max_count = max_count or min_count
        self.min_area = min_area
        self.max_area = max_area

        # Aspect ratio params (for region mode)
        max_aspect = max_aspect or 1 / min_aspect
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

        # Number of splits
        self.num_splits = num_splits
        self.device = device

        # Strategy mode
        self.spatial_mode = spatial_mode

        # Patch size (needed for subregion mode)
        self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)

        # Value generation mode flags
        self.erase_mode = mode.lower()
        assert self.erase_mode in ('rand', 'pixel', 'const')
        self.const_value = value

    def _get_values(
            self,
            shape: Union[Tuple[int, ...], torch.Size],
            value: Optional[torch.Tensor] = None,
            dtype: torch.dtype = torch.float32,
            device: Optional[Union[str, torch.device]] = None
    ):
        """Generate values for erased patches based on the specified mode.

        Args:
            shape: Shape of patches to erase.
            value: Value to use in const (or rand) mode.
            dtype: Data type to use.
            device: Device to use.
        """
        device = device or self.device
        if self.erase_mode == 'pixel':
            # only mode with erase shape that includes pixels
            return torch.empty(shape, dtype=dtype, device=device).normal_()
        else:
            shape = (1, 1, shape[-1]) if len(shape) == 3 else (1, shape[-1])
            if self.erase_mode == 'const' or value is not None:
                erase_value = value or self.const_value
                if isinstance(erase_value, (int, float)):
                    values = torch.full(shape, erase_value, dtype=dtype, device=device)
                else:
                    erase_value = torch.tensor(erase_value, dtype=dtype, device=device)
                    values = torch.expand_copy(erase_value, shape)
            else:
                values = torch.empty(shape, dtype=dtype, device=device).normal_()
            return values

    def _drop_patches(
            self,
            patches: torch.Tensor,
            patch_coord: torch.Tensor,
            patch_valid: torch.Tensor,
    ):
        """ Patch Dropout

        Fully drops patches from datastream. Only mode that saves compute BUT requires support
        for non-contiguous patches and associated patch coordinate and valid handling.
        """
        # FIXME WIP, not completed. Downstream support in model needed for non-contiguous valid patches
        if random.random() > self.erase_prob:
            return

        # Get indices of valid patches
        valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0].tolist()

        # Skip if no valid patches
        if not valid_indices:
            return patches, patch_coord, patch_valid

        num_valid = len(valid_indices)
        if self.patch_drop_prob:
            # patch dropout mode, completely remove dropped patches (FIXME needs downstream support in model)
            num_keep = max(1, int(num_valid * (1. - self.patch_drop_prob)))
            keep_indices = torch.argsort(torch.randn(1, num_valid, device=self.device), dim=-1)[:, :num_keep]
            # maintain patch order, possibly useful for debug / visualization
            keep_indices = keep_indices.sort(dim=-1)[0]
            patches = patches.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + patches.shape[2:]))

        return patches, patch_coord, patch_valid

    def _erase_patches(
            self,
            patches: torch.Tensor,
            patch_coord: torch.Tensor,
            patch_valid: torch.Tensor,
            patch_shape: torch.Size,
            dtype: torch.dtype = torch.float32,
    ):
        """Apply erasing by selecting individual patches randomly.

        The simplest mode, aligned on patch boundaries. Behaves similarly to speckle or 'sprinkles'
        noise augmentation at patch size.
        """
        if random.random() > self.erase_prob:
            return

        # Get indices of valid patches
        valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0].tolist()
        if not valid_indices:
            # Skip if no valid patches
            return

        num_valid = len(valid_indices)
        count = random.randint(self.min_count, self.max_count)
        # Determine how many valid patches to erase from RE min/max count and area args
        max_erase = max(1, int(num_valid * count * self.max_area))
        min_erase = max(1, int(num_valid * count * self.min_area))
        num_erase = random.randint(min_erase, max_erase)

        # Randomly select valid patches to erase
        indices_to_erase = random.sample(valid_indices, min(num_erase, num_valid))

        random_value = None
        if self.erase_mode == 'rand':
            random_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()

        for idx in indices_to_erase:
            patches[idx].copy_(self._get_values(patch_shape, dtype=dtype, value=random_value))

    def _erase_region(
            self,
            patches: torch.Tensor,
            patch_coord: torch.Tensor,
            patch_valid: torch.Tensor,
            patch_shape: torch.Size,
            dtype: torch.dtype = torch.float32,
    ):
        """Apply erasing by selecting rectangular regions of patches randomly

        Closer to the original RandomErasing implementation. Erases
        spatially contiguous rectangular regions of patches (aligned with patches).
        """
        if random.random() > self.erase_prob:
            return

        # Determine grid dimensions from coordinates
        if patch_valid is not None:
            valid_coord = patch_coord[patch_valid]
            if len(valid_coord) == 0:
                return  # No valid patches
            max_y = valid_coord[:, 0].max().item() + 1
            max_x = valid_coord[:, 1].max().item() + 1
        else:
            max_y = patch_coord[:, 0].max().item() + 1
            max_x = patch_coord[:, 1].max().item() + 1

        grid_h, grid_w = max_y, max_x

        # Calculate total area
        total_area = grid_h * grid_w

        count = random.randint(self.min_count, self.max_count)
        for _ in range(count):
            # Try to select a valid region to erase (multiple attempts)
            for attempt in range(10):
                # Sample random area and aspect ratio
                target_area = random.uniform(self.min_area, self.max_area) * total_area
                aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))

                # Calculate region height and width
                h = int(round(math.sqrt(target_area * aspect_ratio)))
                w = int(round(math.sqrt(target_area / aspect_ratio)))

                # Ensure region fits within grid
                if w <= grid_w and h <= grid_h:
                    # Select random top-left corner
                    top = random.randint(0, grid_h - h)
                    left = random.randint(0, grid_w - w)

                    # Define region bounds
                    bottom = top + h
                    right = left + w

                    # Create a single random value for all affected patches if using 'rand' mode
                    if self.erase_mode == 'rand':
                        random_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
                    else:
                        random_value = None

                    # Find and erase all patches that fall within the region
                    for i in range(len(patches)):
                        if patch_valid is None or patch_valid[i]:
                            y, x = patch_coord[i]
                            if top <= y < bottom and left <= x < right:
                                patches[i] = self._get_values(patch_shape, dtype=dtype, value=random_value)

                    # Successfully applied erasing, exit the loop
                    break

    def _erase_subregion(
            self,
            patches: torch.Tensor,
            patch_coord: torch.Tensor,
            patch_valid: torch.Tensor,
            patch_shape: torch.Size,
            patch_size: Tuple[int, int],
            dtype: torch.dtype = torch.float32,
    ):
        """Apply erasing by selecting rectangular regions ignoring patch boundaries.

        Matches the original RandomErasing implementation. Erases spatially contiguous rectangular
        regions that are not aligned to patches (erase region boundaries cut within patches).

        FIXME complexity probably not worth it, may remove.
        """
        if random.random() > self.erase_prob:
            return

        # Get patch dimensions
        patch_h, patch_w = patch_size
        channels = patch_shape[-1]

        # Determine grid dimensions in patch coordinates
        if patch_valid is not None:
            valid_coord = patch_coord[patch_valid]
            if len(valid_coord) == 0:
                return  # No valid patches
            max_y = valid_coord[:, 0].max().item() + 1
            max_x = valid_coord[:, 1].max().item() + 1
        else:
            max_y = patch_coord[:, 0].max().item() + 1
            max_x = patch_coord[:, 1].max().item() + 1

        grid_h, grid_w = max_y, max_x

        # Calculate total area in pixel space
        total_area = (grid_h * patch_h) * (grid_w * patch_w)

        count = random.randint(self.min_count, self.max_count)
        for _ in range(count):
            # Try to select a valid region to erase (multiple attempts)
            for attempt in range(10):
                # Sample random area and aspect ratio
                target_area = random.uniform(self.min_area, self.max_area) * total_area
                aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))

                # Calculate region height and width in pixel space
                pixel_h = int(round(math.sqrt(target_area * aspect_ratio)))
                pixel_w = int(round(math.sqrt(target_area / aspect_ratio)))

                # Ensure region fits within total pixel grid
                if pixel_w <= grid_w * patch_w and pixel_h <= grid_h * patch_h:
                    # Select random top-left corner in pixel space
                    pixel_top = random.randint(0, grid_h * patch_h - pixel_h)
                    pixel_left = random.randint(0, grid_w * patch_w - pixel_w)

                    # Define region bounds in pixel space
                    pixel_bottom = pixel_top + pixel_h
                    pixel_right = pixel_left + pixel_w

                    # Create a single random value for the entire region if using 'rand' mode
                    rand_value = None
                    if self.erase_mode == 'rand':
                        rand_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()

                    # For each valid patch, determine if and how it overlaps with the erase region
                    for i in range(len(patches)):
                        if patch_valid is None or patch_valid[i]:
                            # Convert patch coordinates to pixel space (top-left corner)
                            y, x = patch_coord[i]
                            patch_pixel_top = y * patch_h
                            patch_pixel_left = x * patch_w
                            patch_pixel_bottom = patch_pixel_top + patch_h
                            patch_pixel_right = patch_pixel_left + patch_w

                            # Check if this patch overlaps with the erase region
                            if not (patch_pixel_right <= pixel_left or patch_pixel_left >= pixel_right or
                                    patch_pixel_bottom <= pixel_top or patch_pixel_top >= pixel_bottom):

                                # Calculate the overlap region in patch-local coordinates
                                local_top = max(0, pixel_top - patch_pixel_top)
                                local_left = max(0, pixel_left - patch_pixel_left)
                                local_bottom = min(patch_h, pixel_bottom - patch_pixel_top)
                                local_right = min(patch_w, pixel_right - patch_pixel_left)

                                # Reshape the patch to [patch_h, patch_w, chans]
                                patch_data = patches[i].reshape(patch_h, patch_w, channels)

                                erase_shape = (local_bottom - local_top, local_right - local_left, channels)
                                erase_value = self._get_values(erase_shape, dtype=dtype, value=rand_value)
                                patch_data[local_top:local_bottom, local_left:local_right, :] = erase_value

                                # Flatten the patch back to [patch_h*patch_w, chans]
                                if len(patch_shape) == 2:
                                    patch_data = patch_data.reshape(-1, channels)
                                patches[i] = patch_data

                    # Successfully applied erasing, exit the loop
                    break

    def __call__(
            self,
            patches: torch.Tensor,
            patch_coord: torch.Tensor,
            patch_valid: Optional[torch.Tensor] = None,
    ):
        """
        Apply random patch erasing.

        Args:
            patches: Tensor of shape [B, N, P*P, C]
            patch_coord: Tensor of shape [B, N, 2] with (y, x) coordinates
            patch_valid: Boolean tensor of shape [B, N] indicating which patches are valid
                If None, all patches are considered valid

        Returns:
            Erased patches tensor of same shape
        """
        if patches.ndim == 4:
            batch_size, num_patches, patch_dim, channels = patches.shape
            if self.patch_size is not None:
                patch_size = self.patch_size
            else:
                patch_size = None
        elif patches.ndim == 5:
            batch_size, num_patches, patch_h, patch_w, channels = patches.shape
            patch_size = (patch_h, patch_w)
        else:
            assert False
        patch_shape = patches.shape[2:]
        # patch_shape ==> shape of patches to fill (h, w, c) or (h * w, c)
        # patch_size ==> patch h, w (if available, must be avail for subregion mode)

        # Create default valid mask if not provided
        if patch_valid is None:
            patch_valid = torch.ones((batch_size, num_patches), dtype=torch.bool, device=patches.device)

        # Skip the first part of the batch if num_splits is set
        batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0

        # Apply erasing to each batch element
        for i in range(batch_start, batch_size):
            if self.patch_drop_prob:
                assert False, "WIP, not completed"
                self._drop_patches(
                    patches[i],
                    patch_coord[i],
                    patch_valid[i],
                )
            elif self.spatial_mode == 'patch':
                self._erase_patches(
                    patches[i],
                    patch_coord[i],
                    patch_valid[i],
                    patch_shape,
                    patches.dtype
                )
            elif self.spatial_mode == 'region':
                self._erase_region(
                    patches[i],
                    patch_coord[i],
                    patch_valid[i],
                    patch_shape,
                    patches.dtype
                )
            elif self.spatial_mode == 'subregion':
                self._erase_subregion(
                    patches[i],
                    patch_coord[i],
                    patch_valid[i],
                    patch_shape,
                    patch_size,
                    patches.dtype
                )

        return patches

    def __repr__(self):
        fs = self.__class__.__name__ + f'(p={self.erase_prob}, mode={self.erase_mode}'
        fs += f', spatial={self.spatial_mode}, area=({self.min_area}, {self.max_area})'
        fs += f', count=({self.min_count}, {self.max_count}))'
        return fs
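A small CPU usage sketch for the new class in its simplest ('patch') mode, building a 2x2 coordinate grid by hand; erase_prob=1.0 forces an erase so the effect is deterministic in occurrence (not in content):

import torch
from timm.data.naflex_random_erasing import PatchRandomErasing

B, N, C = 2, 4, 3            # batch, patches per sample, channels
P = 16                       # patch edge; patches stored flattened as P*P
patches = torch.rand(B, N, P * P, C)
ys, xs = torch.meshgrid(torch.arange(2), torch.arange(2), indexing='ij')
patch_coord = torch.stack([ys.flatten(), xs.flatten()], dim=-1).repeat(B, 1, 1)  # [B, N, 2] as (y, x)

eraser = PatchRandomErasing(erase_prob=1.0, mode='pixel', spatial_mode='patch', device='cpu')
erased = eraser(patches, patch_coord)   # patch_valid=None -> all patches treated as valid
assert erased.shape == patches.shape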
timm/data/transforms_factory.py
@@ -132,9 +132,13 @@ def transforms_imagenet_train(
 
     primary_tfl = []
     if naflex:
+        scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
+        ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range
         primary_tfl += [RandomResizedCropToSequence(
             patch_size=patch_size,
             max_seq_len=max_seq_len,
+            scale=scale,
+            ratio=ratio,
             interpolation=interpolation
         )]
     else:
train.py (63 changed lines)
@@ -697,32 +697,6 @@ def main():
         trust_remote_code=args.dataset_trust_remote_code,
     )
 
-    # setup mixup / cutmix
-    collate_fn = None
-    mixup_fn = None
-    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
-    if mixup_active:
-        assert not args.naflex_loader, "Mixup/Cutmix not currently supported for NaFlex loading."
-        mixup_args = dict(
-            mixup_alpha=args.mixup,
-            cutmix_alpha=args.cutmix,
-            cutmix_minmax=args.cutmix_minmax,
-            prob=args.mixup_prob,
-            switch_prob=args.mixup_switch_prob,
-            mode=args.mixup_mode,
-            label_smoothing=args.smoothing,
-            num_classes=args.num_classes
-        )
-        if args.prefetcher:
-            assert not num_aug_splits  # collate conflict (need to support de-interleaving in collate mixup)
-            collate_fn = FastCollateMixup(**mixup_args)
-        else:
-            mixup_fn = Mixup(**mixup_args)
-
-    # wrap dataset in AugMix helper
-    if num_aug_splits > 1:
-        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
-
     # create data loaders w/ augmentation pipeline
     train_interpolation = args.train_interpolation
     if args.no_aug or not train_interpolation:
@@ -764,22 +738,59 @@ def main():
         worker_seeding=args.worker_seeding,
     )
 
+    mixup_fn = None
+    mixup_args = {}
+    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
+    if mixup_active:
+        mixup_args = dict(
+            mixup_alpha=args.mixup,
+            cutmix_alpha=args.cutmix,
+            cutmix_minmax=args.cutmix_minmax,
+            prob=args.mixup_prob,
+            switch_prob=args.mixup_switch_prob,
+            mode=args.mixup_mode,
+            label_smoothing=args.smoothing,
+            num_classes=args.num_classes
+        )
+
     naflex_mode = False
     if args.naflex_loader:
         if utils.is_primary(args):
             _logger.info('Using NaFlex loader')
+
+        assert num_aug_splits <= 1, 'Augmentation splits not supported in NaFlex mode'
+        naflex_mixup_fn = None
+        if mixup_active:
+            from timm.data import NaFlexMixup
+            mixup_args.pop('mode')  # not supported
+            mixup_args.pop('cutmix_minmax')  # not supported
+            naflex_mixup_fn = NaFlexMixup(**mixup_args)
+
         naflex_mode = True
         loader_train = create_naflex_loader(
             dataset=dataset_train,
             patch_size=16,  # Could be derived from model config
             train_seq_lens=args.naflex_train_seq_lens,
+            mixup_fn=naflex_mixup_fn,
             rank=args.rank,
             world_size=args.world_size,
             **common_loader_kwargs,
             **train_loader_kwargs,
         )
     else:
+        # setup mixup / cutmix
+        collate_fn = None
+        if mixup_active:
+            if args.prefetcher:
+                assert not num_aug_splits  # collate conflict (need to support de-interleaving in collate mixup)
+                collate_fn = FastCollateMixup(**mixup_args)
+            else:
+                mixup_fn = Mixup(**mixup_args)
+
+        # wrap dataset in AugMix helper
+        if num_aug_splits > 1:
+            dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
+
         # Use standard loader
         loader_train = create_loader(
             dataset_train,
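On the consuming side, NaFlex batches arrive as a dict of patch tensors rather than a plain image tensor; a hedged sketch of what a train step sees (the model is assumed to be NaFlex-aware, i.e. accepting the patch dict):

def train_step(model, loss_fn, input_dict, targets):
    # input_dict: 'patches' [B, N, P*P*C], 'patch_coord' [B, N, 2],
    # 'patch_valid' [B, N]; targets are soft labels when NaFlexMixup is active.
    output = model(input_dict)
    loss = loss_fn(output, targets)
    return loss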