NaFlex random erasing performance improvements, python loops were slow. Remove subregion mode, not going to be worth it.

This commit is contained in:
Ross Wightman 2025-05-20 17:03:46 -07:00
parent 7624389fc9
commit f001b15ed3

View File

@ -11,17 +11,15 @@ class PatchRandomErasing:
Supports three modes:
1. 'patch': Simple mode that erases randomly selected valid patches
2. 'region': Erases spatial regions at patch granularity
3. 'subregion': Most sophisticated mode that erases spatial regions at sub-patch granularity,
partially erasing patches that are on the boundary of the erased region
2. 'region': Erases rectangular regions at patch granularity
Args:
erase_prob: Probability that the Random Erasing operation will be performed.
patch_drop_prob: Patch dropout probability. Remove random patches instead of erasing.
min_area: Minimum percentage of valid patches/area to erase.
max_area: Maximum percentage of valid patches/area to erase.
min_aspect: Minimum aspect ratio of erased area (only used in 'region'/'subregion' mode).
max_aspect: Maximum aspect ratio of erased area (only used in 'region'/'subregion' mode).
min_aspect: Minimum aspect ratio of erased area (only used in 'region' mode).
max_aspect: Maximum aspect ratio of erased area (only used in 'region' mode).
mode: Patch content mode, one of 'const', 'rand', or 'pixel'
'const' - erase patch is constant color of 0 for all channels
'rand' - erase patch has same random (normal) value across all elements
@ -45,7 +43,6 @@ class PatchRandomErasing:
mode: str = 'const',
value: float = 0.,
spatial_mode: str = 'region',
patch_size: Optional[Union[int, Tuple[int, int]]] = 16,
num_splits: int = 0,
device: Union[str, torch.device] = 'cuda',
):
@ -66,14 +63,13 @@ class PatchRandomErasing:
# Strategy mode
self.spatial_mode = spatial_mode
# Patch size (needed for subregion mode)
self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
assert self.spatial_mode in ('patch', 'region')
# Value generation mode flags
self.erase_mode = mode.lower()
assert self.erase_mode in ('rand', 'pixel', 'const')
self.const_value = value
self.unique_noise_per_patch = True
def _get_values(
self,
@ -156,27 +152,27 @@ class PatchRandomErasing:
return
# Get indices of valid patches
valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0].tolist()
if not valid_indices:
# Skip if no valid patches
valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0]
num_valid = len(valid_indices)
if num_valid == 0:
return
num_valid = len(valid_indices)
count = random.randint(self.min_count, self.max_count)
# Determine how many valid patches to erase from RE min/max count and area args
max_erase = max(1, int(num_valid * count * self.max_area))
max_erase = min(num_valid, max(1, int(num_valid * count * self.max_area)))
min_erase = max(1, int(num_valid * count * self.min_area))
num_erase = random.randint(min_erase, max_erase)
# Randomly select valid patches to erase
indices_to_erase = random.sample(valid_indices, min(num_erase, num_valid))
erase_idx = valid_indices[torch.randperm(num_valid, device=patches.device)[:num_erase]]
random_value = None
if self.erase_mode == 'rand':
random_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
if self.unique_noise_per_patch and self.erase_mode == 'pixel':
# generate unique noise for the whole selection of patches
fill_shape = (num_erase,) + patch_shape
else:
fill_shape = patch_shape
for idx in indices_to_erase:
patches[idx].copy_(self._get_values(patch_shape, dtype=dtype, value=random_value))
patches[erase_idx] = self._get_values(fill_shape, dtype=dtype)
def _erase_region(
self,
@ -195,20 +191,14 @@ class PatchRandomErasing:
return
# Determine grid dimensions from coordinates
if patch_valid is not None:
valid_coord = patch_coord[patch_valid]
if len(valid_coord) == 0:
return # No valid patches
max_y = valid_coord[:, 0].max().item() + 1
max_x = valid_coord[:, 1].max().item() + 1
else:
max_y = patch_coord[:, 0].max().item() + 1
max_x = patch_coord[:, 1].max().item() + 1
valid_coord = patch_coord[patch_valid]
if len(valid_coord) == 0:
return # No valid patches
max_y = valid_coord[:, 0].max().item() + 1
max_x = valid_coord[:, 1].max().item() + 1
grid_h, grid_w = max_y, max_x
# Calculate total area
total_area = grid_h * grid_w
ys, xs = patch_coord[:, 0], patch_coord[:, 1]
count = random.randint(self.min_count, self.max_count)
for _ in range(count):
@ -222,132 +212,33 @@ class PatchRandomErasing:
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
# Ensure region fits within grid
if w <= grid_w and h <= grid_h:
# Select random top-left corner
top = random.randint(0, grid_h - h)
left = random.randint(0, grid_w - w)
if h > grid_h or w > grid_w:
continue # try again
# Define region bounds
bottom = top + h
right = left + w
# Calculate region patch bounds
top = random.randint(0, grid_h - h)
left = random.randint(0, grid_w - w)
bottom, right = top + h, left + w
# Create a single random value for all affected patches if using 'rand' mode
if self.erase_mode == 'rand':
random_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
else:
random_value = None
# Region test
region_mask = (
(ys >= top) & (ys < bottom) &
(xs >= left) & (xs < right) &
patch_valid
)
num_selected = int(region_mask.sum().item())
if not num_selected:
continue # no patch actually falls inside try again
# Find and erase all patches that fall within the region
for i in range(len(patches)):
if patch_valid is None or patch_valid[i]:
y, x = patch_coord[i]
if top <= y < bottom and left <= x < right:
patches[i] = self._get_values(patch_shape, dtype=dtype, value=random_value)
if self.unique_noise_per_patch and self.erase_mode == 'pixel':
# generate unique noise for the whole region
fill_shape = (num_selected,) + patch_shape
else:
fill_shape = patch_shape
# Successfully applied erasing, exit the loop
break
def _erase_subregion(
self,
patches: torch.Tensor,
patch_coord: torch.Tensor,
patch_valid: torch.Tensor,
patch_shape: torch.Size,
patch_size: Tuple[int, int],
dtype: torch.dtype = torch.float32,
):
"""Apply erasing by selecting rectangular regions ignoring patch boundaries.
Matches or original RandomErasing implementation. Erases spatially contiguous rectangular
regions that are not aligned to patches (erase regions boundaries cut within patches).
FIXME complexity probably not worth it, may remove.
"""
if random.random() > self.erase_prob:
return
# Get patch dimensions
patch_h, patch_w = patch_size
channels = patch_shape[-1]
# Determine grid dimensions in patch coordinates
if patch_valid is not None:
valid_coord = patch_coord[patch_valid]
if len(valid_coord) == 0:
return # No valid patches
max_y = valid_coord[:, 0].max().item() + 1
max_x = valid_coord[:, 1].max().item() + 1
else:
max_y = patch_coord[:, 0].max().item() + 1
max_x = patch_coord[:, 1].max().item() + 1
grid_h, grid_w = max_y, max_x
# Calculate total area in pixel space
total_area = (grid_h * patch_h) * (grid_w * patch_w)
count = random.randint(self.min_count, self.max_count)
for _ in range(count):
# Try to select a valid region to erase (multiple attempts)
for attempt in range(10):
# Sample random area and aspect ratio
target_area = random.uniform(self.min_area, self.max_area) * total_area
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
# Calculate region height and width in pixel space
pixel_h = int(round(math.sqrt(target_area * aspect_ratio)))
pixel_w = int(round(math.sqrt(target_area / aspect_ratio)))
# Ensure region fits within total pixel grid
if pixel_w <= grid_w * patch_w and pixel_h <= grid_h * patch_h:
# Select random top-left corner in pixel space
pixel_top = random.randint(0, grid_h * patch_h - pixel_h)
pixel_left = random.randint(0, grid_w * patch_w - pixel_w)
# Define region bounds in pixel space
pixel_bottom = pixel_top + pixel_h
pixel_right = pixel_left + pixel_w
# Create a single random value for the entire region if using 'rand' mode
rand_value = None
if self.erase_mode == 'rand':
rand_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
# For each valid patch, determine if and how it overlaps with the erase region
for i in range(len(patches)):
if patch_valid is None or patch_valid[i]:
# Convert patch coordinates to pixel space (top-left corner)
y, x = patch_coord[i]
patch_pixel_top = y * patch_h
patch_pixel_left = x * patch_w
patch_pixel_bottom = patch_pixel_top + patch_h
patch_pixel_right = patch_pixel_left + patch_w
# Check if this patch overlaps with the erase region
if not (patch_pixel_right <= pixel_left or patch_pixel_left >= pixel_right or
patch_pixel_bottom <= pixel_top or patch_pixel_top >= pixel_bottom):
# Calculate the overlap region in patch-local coordinates
local_top = max(0, pixel_top - patch_pixel_top)
local_left = max(0, pixel_left - patch_pixel_left)
local_bottom = min(patch_h, pixel_bottom - patch_pixel_top)
local_right = min(patch_w, pixel_right - patch_pixel_left)
# Reshape the patch to [patch_h, patch_w, chans]
patch_data = patches[i].reshape(patch_h, patch_w, channels)
erase_shape = (local_bottom - local_top, local_right - local_left, channels)
erase_value = self._get_values(erase_shape, dtype=dtype, value=rand_value)
patch_data[local_top:local_bottom, local_left:local_right, :] = erase_value
# Flatten the patch back to [patch_h*patch_w, chans]
if len(patch_shape) == 2:
patch_data = patch_data.reshape(-1, channels)
patches[i] = patch_data
# Successfully applied erasing, exit the loop
break
patches[region_mask] = self._get_values(fill_shape, dtype=dtype)
# Successfully applied erasing, exit the loop
break
def __call__(
self,
@ -369,18 +260,12 @@ class PatchRandomErasing:
"""
if patches.ndim == 4:
batch_size, num_patches, patch_dim, channels = patches.shape
if self.patch_size is not None:
patch_size = self.patch_size
else:
patch_size = None
elif patches.ndim == 5:
batch_size, num_patches, patch_h, patch_w, channels = patches.shape
patch_size = (patch_h, patch_w)
else:
assert False
patch_shape = patches.shape[2:]
# patch_shape ==> shape of patches to fill (h, w, c) or (h * w, c)
# patch_size ==> patch h, w (if available, must be avail for subregion mode)
# Create default valid mask if not provided
if patch_valid is None:
@ -399,6 +284,7 @@ class PatchRandomErasing:
patch_valid[i],
)
elif self.spatial_mode == 'patch':
# FIXME we could vectorize patch mode across batch, worth the effort?
self._erase_patches(
patches[i],
patch_coord[i],
@ -414,15 +300,8 @@ class PatchRandomErasing:
patch_shape,
patches.dtype
)
elif self.spatial_mode == 'subregion':
self._erase_subregion(
patches[i],
patch_coord[i],
patch_valid[i],
patch_shape,
patch_size,
patches.dtype
)
else:
assert False
return patches