pytorch-image-models/timm/data/naflex_random_erasing.py

312 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import random
import math
from typing import Optional, Union, Tuple
import torch
class PatchRandomErasing:
    """Random erasing for patchified images in NaFlex format.

    Supports two spatial modes:
      1. 'patch': erases randomly selected valid patches
      2. 'region': erases rectangular regions at patch granularity

    Erasing is applied in-place on the tensor passed to ``__call__``.

    Args:
        erase_prob: Probability that the Random Erasing operation will be performed
            on a given sample.
        patch_drop_prob: Patch dropout probability. Remove random patches instead of
            erasing. Work in progress: ``__call__`` rejects a non-zero value.
        min_count: Minimum value of the per-sample erase count.
        max_count: Maximum value of the per-sample erase count (defaults to ``min_count``).
            In 'region' mode the count is the number of regions erased; in 'patch' mode
            it scales the number of patches erased.
        min_area: Minimum percentage of valid patches/area to erase.
        max_area: Maximum percentage of valid patches/area to erase.
        min_aspect: Minimum aspect ratio of erased area (only used in 'region' mode).
        max_aspect: Maximum aspect ratio of erased area (only used in 'region' mode,
            defaults to ``1 / min_aspect``).
        mode: Patch content mode, one of 'const', 'rand', or 'pixel'
            'const' - erased patches are filled with ``value``
            'rand'  - erased patches share one random (normal) value per channel
            'pixel' - erased patches have per-element random (normal) values
        value: Fill value used in 'const' mode.
        spatial_mode: Erasing strategy, one of 'patch' or 'region'.
        num_splits: If greater than 1, the first ``batch_size // num_splits`` samples
            of each batch are left untouched.
        device: Fallback device for generated fill values when no device is given
            to ``_get_values``. The erase methods generate fill values on the device
            of the patches they modify.
    """

    def __init__(
            self,
            erase_prob: float = 0.5,
            patch_drop_prob: float = 0.0,
            min_count: int = 1,
            max_count: Optional[int] = None,
            min_area: float = 0.02,
            max_area: float = 1 / 3,
            min_aspect: float = 0.3,
            max_aspect: Optional[float] = None,
            mode: str = 'const',
            value: float = 0.,
            spatial_mode: str = 'region',
            num_splits: int = 0,
            device: Union[str, torch.device] = 'cuda',
    ):
        self.erase_prob = erase_prob
        self.patch_drop_prob = patch_drop_prob
        self.min_count = min_count
        self.max_count = max_count or min_count
        self.min_area = min_area
        self.max_area = max_area

        # Aspect ratio params (for region mode); sampled uniformly in log space
        max_aspect = max_aspect or 1 / min_aspect
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

        # Number of splits
        self.num_splits = num_splits
        self.device = device

        # Strategy mode
        self.spatial_mode = spatial_mode
        assert self.spatial_mode in ('patch', 'region')

        # Value generation mode flags
        self.erase_mode = mode.lower()
        assert self.erase_mode in ('rand', 'pixel', 'const')
        self.const_value = value
        self.unique_noise_per_patch = True

    def _get_values(
            self,
            shape: Union[Tuple[int, ...], torch.Size],
            value: Optional[Union[float, torch.Tensor]] = None,
            dtype: torch.dtype = torch.float32,
            device: Optional[Union[str, torch.device]] = None,
    ):
        """Generate values for erased patches based on the specified mode.

        Args:
            shape: Shape of patches to erase.
            value: Value to use in const (or rand) mode. Ignored in 'pixel' mode.
            dtype: Data type to use.
            device: Device to use, falls back to ``self.device`` when None.

        Returns:
            In 'pixel' mode a normal-noise tensor of ``shape``. Otherwise a tensor of
            shape ``(1, 1, C)`` (3-dim ``shape``) or ``(1, C)`` with ``C = shape[-1]``,
            intended to be broadcast over the erased patches.
        """
        if device is None:
            device = self.device

        if self.erase_mode == 'pixel':
            # only mode with erase shape that includes pixels
            return torch.empty(shape, dtype=dtype, device=device).normal_()

        # All other modes produce one value per channel, broadcast on assignment
        shape = (1, 1, shape[-1]) if len(shape) == 3 else (1, shape[-1])
        if self.erase_mode == 'const' or value is not None:
            # Explicit None test: truthiness of a tensor is ambiguous for more than
            # one element and would wrongly discard an all-zero value.
            erase_value = self.const_value if value is None else value
            if isinstance(erase_value, (int, float)):
                return torch.full(shape, erase_value, dtype=dtype, device=device)
            erase_value = torch.as_tensor(erase_value, dtype=dtype, device=device)
            return erase_value.expand(shape).clone()

        return torch.empty(shape, dtype=dtype, device=device).normal_()

    def _make_fill(
            self,
            num_selected: int,
            patch_shape: Union[Tuple[int, ...], torch.Size],
            dtype: torch.dtype,
            device: Union[str, torch.device],
    ):
        """Build replacement values for ``num_selected`` erased patches of ``patch_shape``."""
        if self.unique_noise_per_patch and self.erase_mode == 'pixel':
            # generate unique noise for the whole selection of patches
            fill_shape = (num_selected, *patch_shape)
        else:
            fill_shape = patch_shape
        return self._get_values(fill_shape, dtype=dtype, device=device)

    def _drop_patches(
            self,
            patches: torch.Tensor,
            patch_coord: torch.Tensor,
            patch_valid: torch.Tensor,
    ):
        """Patch Dropout for a single sample.

        Fully drops patches from datastream. Only mode that saves compute BUT requires support
        for non-contiguous patches and associated patch coordinate and valid handling.

        Args:
            patches: Patches of one sample, indexed by patch along dim 0.
            patch_coord: Patch coordinates of the sample, indexed by patch along dim 0.
            patch_valid: Boolean validity mask of the sample.

        Returns:
            Tuple of ``(patches, patch_coord, patch_valid)``. When dropping is applied the
            tensors contain only the kept valid patches, in their original order; otherwise
            the inputs are returned unchanged.
        """
        # FIXME WIP, not completed. Downstream support in model needed for non-contiguous valid patches
        if random.random() > self.erase_prob:
            return patches, patch_coord, patch_valid

        # Get indices of valid patches
        valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0]
        num_valid = valid_indices.numel()

        # Skip if no valid patches or dropout is disabled
        if num_valid == 0 or not self.patch_drop_prob:
            return patches, patch_coord, patch_valid

        # Always keep at least one patch
        num_keep = max(1, int(num_valid * (1. - self.patch_drop_prob)))
        perm = torch.randperm(num_valid, device=valid_indices.device)[:num_keep]
        # maintain patch order, possibly useful for debug / visualization
        keep_indices = valid_indices[perm].sort()[0]

        # Coordinates and validity must follow the kept patches to stay aligned
        return patches[keep_indices], patch_coord[keep_indices], patch_valid[keep_indices]

    def _erase_patches(
            self,
            patches: torch.Tensor,
            patch_coord: torch.Tensor,
            patch_valid: torch.Tensor,
            patch_shape: torch.Size,
            dtype: torch.dtype = torch.float32,
    ):
        """Apply erasing by selecting individual patches randomly.

        The simplest mode, aligned on patch boundaries. Behaves similarly to speckle or 'sprinkles'
        noise augmentation at patch size. Modifies ``patches`` in-place and returns None.

        Args:
            patches: Patches of one sample, indexed by patch along dim 0.
            patch_coord: Patch coordinates of the sample (unused in this mode).
            patch_valid: Boolean validity mask of the sample.
            patch_shape: Shape of a single patch.
            dtype: Data type of the generated fill values.
        """
        if random.random() > self.erase_prob:
            return

        # Get indices of valid patches
        valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0]
        num_valid = len(valid_indices)
        if num_valid == 0:
            return

        count = random.randint(self.min_count, self.max_count)
        # Determine how many valid patches to erase from RE min/max count and area args
        max_erase = min(num_valid, max(1, int(num_valid * count * self.max_area)))
        # Clamp the lower bound so randint never receives an empty range when
        # count * min_area asks for more patches than can be erased.
        min_erase = min(max_erase, max(1, int(num_valid * count * self.min_area)))
        num_erase = random.randint(min_erase, max_erase)

        # Randomly select valid patches to erase
        erase_idx = valid_indices[torch.randperm(num_valid, device=patches.device)[:num_erase]]

        # Fill values are created on the patches' own device, not self.device
        patches[erase_idx] = self._make_fill(num_erase, patch_shape, dtype, patches.device)

    def _erase_region(
            self,
            patches: torch.Tensor,
            patch_coord: torch.Tensor,
            patch_valid: torch.Tensor,
            patch_shape: torch.Size,
            dtype: torch.dtype = torch.float32,
    ):
        """Apply erasing by selecting rectangular regions of patches randomly.

        Closer to the original RandomErasing implementation. Erases
        spatially contiguous rectangular regions of patches (aligned with patches).
        Modifies ``patches`` in-place and returns None.

        Args:
            patches: Patches of one sample, indexed by patch along dim 0.
            patch_coord: Patch (y, x) grid coordinates of the sample.
            patch_valid: Boolean validity mask of the sample.
            patch_shape: Shape of a single patch.
            dtype: Data type of the generated fill values.
        """
        if random.random() > self.erase_prob:
            return

        # Determine grid dimensions from coordinates
        valid_coord = patch_coord[patch_valid]
        if len(valid_coord) == 0:
            return  # No valid patches

        grid_h = valid_coord[:, 0].max().item() + 1
        grid_w = valid_coord[:, 1].max().item() + 1
        total_area = grid_h * grid_w
        ys, xs = patch_coord[:, 0], patch_coord[:, 1]

        count = random.randint(self.min_count, self.max_count)
        for _ in range(count):
            # Try to select a valid region to erase (multiple attempts)
            for _ in range(10):
                # Sample random area and aspect ratio
                target_area = random.uniform(self.min_area, self.max_area) * total_area
                aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))

                # Calculate region height and width
                h = int(round(math.sqrt(target_area * aspect_ratio)))
                w = int(round(math.sqrt(target_area / aspect_ratio)))
                if h > grid_h or w > grid_w:
                    continue  # try again

                # Calculate region patch bounds
                top = random.randint(0, grid_h - h)
                left = random.randint(0, grid_w - w)
                bottom, right = top + h, left + w

                # Region test
                region_mask = (
                    (ys >= top) & (ys < bottom) &
                    (xs >= left) & (xs < right) &
                    patch_valid
                )
                num_selected = int(region_mask.sum().item())
                if not num_selected:
                    continue  # no patch actually falls inside try again

                # Fill values are created on the patches' own device, not self.device
                patches[region_mask] = self._make_fill(num_selected, patch_shape, dtype, patches.device)
                # Successfully applied erasing, exit the loop
                break

    def __call__(
            self,
            patches: torch.Tensor,
            patch_coord: torch.Tensor,
            patch_valid: Optional[torch.Tensor] = None,
    ):
        """Apply random patch erasing.

        Args:
            patches: Tensor of shape [B, N, P*P, C] or [B, N, Ph, Pw, C]
            patch_coord: Tensor of shape [B, N, 2] with (y, x) coordinates
            patch_valid: Boolean tensor of shape [B, N] indicating which patches are valid
                If None, all patches are considered valid

        Returns:
            The input ``patches`` tensor, erased in-place.

        Raises:
            AssertionError: If ``patches`` is not 4- or 5-dimensional, if patch dropout
                is enabled (not completed), or if ``spatial_mode`` is unknown.
        """
        if patches.ndim not in (4, 5):
            # Explicit raise (same exception type as the former assert) so the
            # check is not stripped when Python runs with -O.
            raise AssertionError(f'Expected 4 or 5 dim patches, got {patches.ndim}')
        batch_size, num_patches = patches.shape[:2]

        # patch_shape ==> shape of patches to fill (h, w, c) or (h * w, c)
        patch_shape = patches.shape[2:]

        # Create default valid mask if not provided
        if patch_valid is None:
            patch_valid = torch.ones((batch_size, num_patches), dtype=torch.bool, device=patches.device)

        # Skip the first part of the batch if num_splits is set
        batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0

        # Apply erasing to each batch element
        for i in range(batch_start, batch_size):
            if self.patch_drop_prob:
                # Patch dropout changes the number of patches per sample and
                # needs downstream model support before it can be enabled here.
                raise AssertionError("WIP, not completed")
            elif self.spatial_mode == 'patch':
                # FIXME we could vectorize patch mode across batch, worth the effort?
                self._erase_patches(
                    patches[i],
                    patch_coord[i],
                    patch_valid[i],
                    patch_shape,
                    patches.dtype,
                )
            elif self.spatial_mode == 'region':
                self._erase_region(
                    patches[i],
                    patch_coord[i],
                    patch_valid[i],
                    patch_shape,
                    patches.dtype,
                )
            else:
                raise AssertionError(f'Unknown spatial_mode: {self.spatial_mode}')

        return patches

    def __repr__(self):
        fs = self.__class__.__name__ + f'(p={self.erase_prob}, mode={self.erase_mode}'
        fs += f', spatial={self.spatial_mode}, area=({self.min_area}, {self.max_area})'
        fs += f', count=({self.min_count}, {self.max_count}))'
        return fs