pytorch-image-models/timm/data/naflex_random_erasing.py

433 lines
18 KiB
Python
Raw Normal View History

import random
import math
from typing import Optional, Union, Tuple
import torch
class PatchRandomErasing:
"""
Random erasing for patchified images in NaFlex format.
Supports three modes:
1. 'patch': Simple mode that erases randomly selected valid patches
2. 'region': Erases spatial regions at patch granularity
3. 'subregion': Most sophisticated mode that erases spatial regions at sub-patch granularity,
partially erasing patches that are on the boundary of the erased region
Args:
erase_prob: Probability that the Random Erasing operation will be performed.
patch_drop_prob: Patch dropout probability. Remove random patches instead of erasing.
min_area: Minimum percentage of valid patches/area to erase.
max_area: Maximum percentage of valid patches/area to erase.
min_aspect: Minimum aspect ratio of erased area (only used in 'region'/'subregion' mode).
max_aspect: Maximum aspect ratio of erased area (only used in 'region'/'subregion' mode).
mode: Patch content mode, one of 'const', 'rand', or 'pixel'
'const' - erase patch is constant color of 0 for all channels
'rand' - erase patch has same random (normal) value across all elements
'pixel' - erase patch has per-element random (normal) values
spatial_mode: Erasing strategy, one of 'patch', 'region', or 'subregion'
patch_size: Size of each patch (required for 'subregion' mode)
num_splits: Number of splits to apply erasing to (0 for all)
device: Computation device
"""
def __init__(
self,
erase_prob: float = 0.5,
patch_drop_prob: float = 0.0,
min_count: int = 1,
max_count: Optional[int] = None,
min_area: float = 0.02,
max_area: float = 1 / 3,
min_aspect: float = 0.3,
max_aspect: Optional[float] = None,
mode: str = 'const',
value: float = 0.,
spatial_mode: str = 'region',
patch_size: Optional[Union[int, Tuple[int, int]]] = 16,
num_splits: int = 0,
device: Union[str, torch.device] = 'cuda',
):
self.erase_prob = erase_prob
self.patch_drop_prob = patch_drop_prob
self.min_count = min_count
self.max_count = max_count or min_count
self.min_area = min_area
self.max_area = max_area
# Aspect ratio params (for region mode)
max_aspect = max_aspect or 1 / min_aspect
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
# Number of splits
self.num_splits = num_splits
self.device = device
# Strategy mode
self.spatial_mode = spatial_mode
# Patch size (needed for subregion mode)
self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
# Value generation mode flags
self.erase_mode = mode.lower()
assert self.erase_mode in ('rand', 'pixel', 'const')
self.const_value = value
def _get_values(
self,
shape: Union[Tuple[int,...], torch.Size],
value: Optional[torch.Tensor] = None,
dtype: torch.dtype = torch.float32,
device: Optional[Union[str, torch.device]] = None
):
"""Generate values for erased patches based on the specified mode.
Args:
shape: Shape of patches to erase.
value: Value to use in const (or rand) mode.
dtype: Data type to use.
device: Device to use.
"""
device = device or self.device
if self.erase_mode == 'pixel':
# only mode with erase shape that includes pixels
return torch.empty(shape, dtype=dtype, device=device).normal_()
else:
shape = (1, 1, shape[-1]) if len(shape) == 3 else (1, shape[-1])
if self.erase_mode == 'const' or value is not None:
erase_value = value or self.const_value
if isinstance(erase_value, (int, float)):
values = torch.full(shape, erase_value, dtype=dtype, device=device)
else:
erase_value = torch.tensor(erase_value, dtype=dtype, device=device)
values = torch.expand_copy(erase_value, shape)
else:
values = torch.empty(shape, dtype=dtype, device=device).normal_()
return values
def _drop_patches(
self,
patches: torch.Tensor,
patch_coord: torch.Tensor,
patch_valid: torch.Tensor,
):
""" Patch Dropout
Fully drops patches from datastream. Only mode that saves compute BUT requires support
for non-contiguous patches and associated patch coordinate and valid handling.
"""
# FIXME WIP, not completed. Downstream support in model needed for non-contiguous valid patches
if random.random() > self.erase_prob:
return
# Get indices of valid patches
valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0].tolist()
# Skip if no valid patches
if not valid_indices:
return patches, patch_coord, patch_valid
num_valid = len(valid_indices)
if self.patch_drop_prob:
# patch dropout mode, completely remove dropped patches (FIXME needs downstream support in model)
num_keep = max(1, int(num_valid * (1. - self.patch_drop_prob)))
keep_indices = torch.argsort(torch.randn(1, num_valid, device=self.device), dim=-1)[:, :num_keep]
# maintain patch order, possibly useful for debug / visualization
keep_indices = keep_indices.sort(dim=-1)[0]
patches = patches.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + patches.shape[2:]))
return patches, patch_coord, patch_valid
def _erase_patches(
self,
patches: torch.Tensor,
patch_coord: torch.Tensor,
patch_valid: torch.Tensor,
patch_shape: torch.Size,
dtype: torch.dtype = torch.float32,
):
"""Apply erasing by selecting individual patches randomly.
The simplest mode, aligned on patch boundaries. Behaves similarly to speckle or 'sprinkles'
noise augmentation at patch size.
"""
if random.random() > self.erase_prob:
return
# Get indices of valid patches
valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0].tolist()
if not valid_indices:
# Skip if no valid patches
return
num_valid = len(valid_indices)
count = random.randint(self.min_count, self.max_count)
# Determine how many valid patches to erase from RE min/max count and area args
max_erase = max(1, int(num_valid * count * self.max_area))
min_erase = max(1, int(num_valid * count * self.min_area))
num_erase = random.randint(min_erase, max_erase)
# Randomly select valid patches to erase
indices_to_erase = random.sample(valid_indices, min(num_erase, num_valid))
random_value = None
if self.erase_mode == 'rand':
random_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
for idx in indices_to_erase:
patches[idx].copy_(self._get_values(patch_shape, dtype=dtype, value=random_value))
def _erase_region(
self,
patches: torch.Tensor,
patch_coord: torch.Tensor,
patch_valid: torch.Tensor,
patch_shape: torch.Size,
dtype: torch.dtype = torch.float32,
):
"""Apply erasing by selecting rectangular regions of patches randomly
Closer to the original RandomErasing implementation. Erases
spatially contiguous rectangular regions of patches (aligned with patches).
"""
if random.random() > self.erase_prob:
return
# Determine grid dimensions from coordinates
if patch_valid is not None:
valid_coord = patch_coord[patch_valid]
if len(valid_coord) == 0:
return # No valid patches
max_y = valid_coord[:, 0].max().item() + 1
max_x = valid_coord[:, 1].max().item() + 1
else:
max_y = patch_coord[:, 0].max().item() + 1
max_x = patch_coord[:, 1].max().item() + 1
grid_h, grid_w = max_y, max_x
# Calculate total area
total_area = grid_h * grid_w
count = random.randint(self.min_count, self.max_count)
for _ in range(count):
# Try to select a valid region to erase (multiple attempts)
for attempt in range(10):
# Sample random area and aspect ratio
target_area = random.uniform(self.min_area, self.max_area) * total_area
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
# Calculate region height and width
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
# Ensure region fits within grid
if w <= grid_w and h <= grid_h:
# Select random top-left corner
top = random.randint(0, grid_h - h)
left = random.randint(0, grid_w - w)
# Define region bounds
bottom = top + h
right = left + w
# Create a single random value for all affected patches if using 'rand' mode
if self.erase_mode == 'rand':
random_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
else:
random_value = None
# Find and erase all patches that fall within the region
for i in range(len(patches)):
if patch_valid is None or patch_valid[i]:
y, x = patch_coord[i]
if top <= y < bottom and left <= x < right:
patches[i] = self._get_values(patch_shape, dtype=dtype, value=random_value)
# Successfully applied erasing, exit the loop
break
def _erase_subregion(
self,
patches: torch.Tensor,
patch_coord: torch.Tensor,
patch_valid: torch.Tensor,
patch_shape: torch.Size,
patch_size: Tuple[int, int],
dtype: torch.dtype = torch.float32,
):
"""Apply erasing by selecting rectangular regions ignoring patch boundaries.
Matches or original RandomErasing implementation. Erases spatially contiguous rectangular
regions that are not aligned to patches (erase regions boundaries cut within patches).
FIXME complexity probably not worth it, may remove.
"""
if random.random() > self.erase_prob:
return
# Get patch dimensions
patch_h, patch_w = patch_size
channels = patch_shape[-1]
# Determine grid dimensions in patch coordinates
if patch_valid is not None:
valid_coord = patch_coord[patch_valid]
if len(valid_coord) == 0:
return # No valid patches
max_y = valid_coord[:, 0].max().item() + 1
max_x = valid_coord[:, 1].max().item() + 1
else:
max_y = patch_coord[:, 0].max().item() + 1
max_x = patch_coord[:, 1].max().item() + 1
grid_h, grid_w = max_y, max_x
# Calculate total area in pixel space
total_area = (grid_h * patch_h) * (grid_w * patch_w)
count = random.randint(self.min_count, self.max_count)
for _ in range(count):
# Try to select a valid region to erase (multiple attempts)
for attempt in range(10):
# Sample random area and aspect ratio
target_area = random.uniform(self.min_area, self.max_area) * total_area
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
# Calculate region height and width in pixel space
pixel_h = int(round(math.sqrt(target_area * aspect_ratio)))
pixel_w = int(round(math.sqrt(target_area / aspect_ratio)))
# Ensure region fits within total pixel grid
if pixel_w <= grid_w * patch_w and pixel_h <= grid_h * patch_h:
# Select random top-left corner in pixel space
pixel_top = random.randint(0, grid_h * patch_h - pixel_h)
pixel_left = random.randint(0, grid_w * patch_w - pixel_w)
# Define region bounds in pixel space
pixel_bottom = pixel_top + pixel_h
pixel_right = pixel_left + pixel_w
# Create a single random value for the entire region if using 'rand' mode
rand_value = None
if self.erase_mode == 'rand':
rand_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
# For each valid patch, determine if and how it overlaps with the erase region
for i in range(len(patches)):
if patch_valid is None or patch_valid[i]:
# Convert patch coordinates to pixel space (top-left corner)
y, x = patch_coord[i]
patch_pixel_top = y * patch_h
patch_pixel_left = x * patch_w
patch_pixel_bottom = patch_pixel_top + patch_h
patch_pixel_right = patch_pixel_left + patch_w
# Check if this patch overlaps with the erase region
if not (patch_pixel_right <= pixel_left or patch_pixel_left >= pixel_right or
patch_pixel_bottom <= pixel_top or patch_pixel_top >= pixel_bottom):
# Calculate the overlap region in patch-local coordinates
local_top = max(0, pixel_top - patch_pixel_top)
local_left = max(0, pixel_left - patch_pixel_left)
local_bottom = min(patch_h, pixel_bottom - patch_pixel_top)
local_right = min(patch_w, pixel_right - patch_pixel_left)
# Reshape the patch to [patch_h, patch_w, chans]
patch_data = patches[i].reshape(patch_h, patch_w, channels)
erase_shape = (local_bottom - local_top, local_right - local_left, channels)
erase_value = self._get_values(erase_shape, dtype=dtype, value=rand_value)
patch_data[local_top:local_bottom, local_left:local_right, :] = erase_value
# Flatten the patch back to [patch_h*patch_w, chans]
if len(patch_shape) == 2:
patch_data = patch_data.reshape(-1, channels)
patches[i] = patch_data
# Successfully applied erasing, exit the loop
break
def __call__(
self,
patches: torch.Tensor,
patch_coord: torch.Tensor,
patch_valid: Optional[torch.Tensor] = None,
):
"""
Apply random patch erasing.
Args:
patches: Tensor of shape [B, N, P*P, C]
patch_coord: Tensor of shape [B, N, 2] with (y, x) coordinates
patch_valid: Boolean tensor of shape [B, N] indicating which patches are valid
If None, all patches are considered valid
Returns:
Erased patches tensor of same shape
"""
if patches.ndim == 4:
batch_size, num_patches, patch_dim, channels = patches.shape
if self.patch_size is not None:
patch_size = self.patch_size
else:
patch_size = None
elif patches.ndim == 5:
batch_size, num_patches, patch_h, patch_w, channels = patches.shape
patch_size = (patch_h, patch_w)
else:
assert False
patch_shape = patches.shape[2:]
# patch_shape ==> shape of patches to fill (h, w, c) or (h * w, c)
# patch_size ==> patch h, w (if available, must be avail for subregion mode)
# Create default valid mask if not provided
if patch_valid is None:
patch_valid = torch.ones((batch_size, num_patches), dtype=torch.bool, device=patches.device)
# Skip the first part of the batch if num_splits is set
batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
# Apply erasing to each batch element
for i in range(batch_start, batch_size):
if self.patch_drop_prob:
assert False, "WIP, not completed"
self._drop_patches(
patches[i],
patch_coord[i],
patch_valid[i],
)
elif self.spatial_mode == 'patch':
self._erase_patches(
patches[i],
patch_coord[i],
patch_valid[i],
patch_shape,
patches.dtype
)
elif self.spatial_mode == 'region':
self._erase_region(
patches[i],
patch_coord[i],
patch_valid[i],
patch_shape,
patches.dtype
)
elif self.spatial_mode == 'subregion':
self._erase_subregion(
patches[i],
patch_coord[i],
patch_valid[i],
patch_shape,
patch_size,
patches.dtype
)
return patches
def __repr__(self):
fs = self.__class__.__name__ + f'(p={self.erase_prob}, mode={self.erase_mode}'
fs += f', spatial={self.spatial_mode}, area=({self.min_area}, {self.max_area}))'
fs += f', count=({self.min_count}, {self.max_count}))'
return fs