NaFlex random erasing performance improvements: replace slow per-patch Python loops with vectorized tensor operations. Remove subregion mode; its complexity is not worth the cost.

This commit is contained in:
Ross Wightman 2025-05-20 17:03:46 -07:00
parent 7624389fc9
commit f001b15ed3

View File

@ -11,17 +11,15 @@ class PatchRandomErasing:
Supports three modes: Supports two modes:
1. 'patch': Simple mode that erases randomly selected valid patches 1. 'patch': Simple mode that erases randomly selected valid patches
2. 'region': Erases spatial regions at patch granularity 2. 'region': Erases rectangular regions at patch granularity
3. 'subregion': Most sophisticated mode that erases spatial regions at sub-patch granularity,
partially erasing patches that are on the boundary of the erased region
Args: Args:
erase_prob: Probability that the Random Erasing operation will be performed. erase_prob: Probability that the Random Erasing operation will be performed.
patch_drop_prob: Patch dropout probability. Remove random patches instead of erasing. patch_drop_prob: Patch dropout probability. Remove random patches instead of erasing.
min_area: Minimum percentage of valid patches/area to erase. min_area: Minimum percentage of valid patches/area to erase.
max_area: Maximum percentage of valid patches/area to erase. max_area: Maximum percentage of valid patches/area to erase.
min_aspect: Minimum aspect ratio of erased area (only used in 'region'/'subregion' mode). min_aspect: Minimum aspect ratio of erased area (only used in 'region' mode).
max_aspect: Maximum aspect ratio of erased area (only used in 'region'/'subregion' mode). max_aspect: Maximum aspect ratio of erased area (only used in 'region' mode).
mode: Patch content mode, one of 'const', 'rand', or 'pixel' mode: Patch content mode, one of 'const', 'rand', or 'pixel'
'const' - erase patch is constant color of 0 for all channels 'const' - erase patch is constant color of 0 for all channels
'rand' - erase patch has same random (normal) value across all elements 'rand' - erase patch has same random (normal) value across all elements
@ -45,7 +43,6 @@ class PatchRandomErasing:
mode: str = 'const', mode: str = 'const',
value: float = 0., value: float = 0.,
spatial_mode: str = 'region', spatial_mode: str = 'region',
patch_size: Optional[Union[int, Tuple[int, int]]] = 16,
num_splits: int = 0, num_splits: int = 0,
device: Union[str, torch.device] = 'cuda', device: Union[str, torch.device] = 'cuda',
): ):
@ -66,14 +63,13 @@ class PatchRandomErasing:
# Strategy mode # Strategy mode
self.spatial_mode = spatial_mode self.spatial_mode = spatial_mode
assert self.spatial_mode in ('patch', 'region')
# Patch size (needed for subregion mode)
self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
# Value generation mode flags # Value generation mode flags
self.erase_mode = mode.lower() self.erase_mode = mode.lower()
assert self.erase_mode in ('rand', 'pixel', 'const') assert self.erase_mode in ('rand', 'pixel', 'const')
self.const_value = value self.const_value = value
self.unique_noise_per_patch = True
def _get_values( def _get_values(
self, self,
@ -156,27 +152,27 @@ class PatchRandomErasing:
return return
# Get indices of valid patches # Get indices of valid patches
valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0].tolist() valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0]
if not valid_indices: num_valid = len(valid_indices)
# Skip if no valid patches if num_valid == 0:
return return
num_valid = len(valid_indices)
count = random.randint(self.min_count, self.max_count) count = random.randint(self.min_count, self.max_count)
# Determine how many valid patches to erase from RE min/max count and area args # Determine how many valid patches to erase from RE min/max count and area args
max_erase = max(1, int(num_valid * count * self.max_area)) max_erase = min(num_valid, max(1, int(num_valid * count * self.max_area)))
min_erase = max(1, int(num_valid * count * self.min_area)) min_erase = max(1, int(num_valid * count * self.min_area))
num_erase = random.randint(min_erase, max_erase) num_erase = random.randint(min_erase, max_erase)
# Randomly select valid patches to erase # Randomly select valid patches to erase
indices_to_erase = random.sample(valid_indices, min(num_erase, num_valid)) erase_idx = valid_indices[torch.randperm(num_valid, device=patches.device)[:num_erase]]
random_value = None if self.unique_noise_per_patch and self.erase_mode == 'pixel':
if self.erase_mode == 'rand': # generate unique noise for the whole selection of patches
random_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_() fill_shape = (num_erase,) + patch_shape
else:
fill_shape = patch_shape
for idx in indices_to_erase: patches[erase_idx] = self._get_values(fill_shape, dtype=dtype)
patches[idx].copy_(self._get_values(patch_shape, dtype=dtype, value=random_value))
def _erase_region( def _erase_region(
self, self,
@ -195,20 +191,14 @@ class PatchRandomErasing:
return return
# Determine grid dimensions from coordinates # Determine grid dimensions from coordinates
if patch_valid is not None: valid_coord = patch_coord[patch_valid]
valid_coord = patch_coord[patch_valid] if len(valid_coord) == 0:
if len(valid_coord) == 0: return # No valid patches
return # No valid patches max_y = valid_coord[:, 0].max().item() + 1
max_y = valid_coord[:, 0].max().item() + 1 max_x = valid_coord[:, 1].max().item() + 1
max_x = valid_coord[:, 1].max().item() + 1
else:
max_y = patch_coord[:, 0].max().item() + 1
max_x = patch_coord[:, 1].max().item() + 1
grid_h, grid_w = max_y, max_x grid_h, grid_w = max_y, max_x
# Calculate total area
total_area = grid_h * grid_w total_area = grid_h * grid_w
ys, xs = patch_coord[:, 0], patch_coord[:, 1]
count = random.randint(self.min_count, self.max_count) count = random.randint(self.min_count, self.max_count)
for _ in range(count): for _ in range(count):
@ -222,132 +212,33 @@ class PatchRandomErasing:
h = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio))) w = int(round(math.sqrt(target_area / aspect_ratio)))
# Ensure region fits within grid if h > grid_h or w > grid_w:
if w <= grid_w and h <= grid_h: continue # try again
# Select random top-left corner
top = random.randint(0, grid_h - h)
left = random.randint(0, grid_w - w)
# Define region bounds # Calculate region patch bounds
bottom = top + h top = random.randint(0, grid_h - h)
right = left + w left = random.randint(0, grid_w - w)
bottom, right = top + h, left + w
# Create a single random value for all affected patches if using 'rand' mode # Region test
if self.erase_mode == 'rand': region_mask = (
random_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_() (ys >= top) & (ys < bottom) &
else: (xs >= left) & (xs < right) &
random_value = None patch_valid
)
num_selected = int(region_mask.sum().item())
if not num_selected:
continue # no patch actually falls inside try again
# Find and erase all patches that fall within the region if self.unique_noise_per_patch and self.erase_mode == 'pixel':
for i in range(len(patches)): # generate unique noise for the whole region
if patch_valid is None or patch_valid[i]: fill_shape = (num_selected,) + patch_shape
y, x = patch_coord[i] else:
if top <= y < bottom and left <= x < right: fill_shape = patch_shape
patches[i] = self._get_values(patch_shape, dtype=dtype, value=random_value)
# Successfully applied erasing, exit the loop patches[region_mask] = self._get_values(fill_shape, dtype=dtype)
break # Successfully applied erasing, exit the loop
break
def _erase_subregion(
    self,
    patches: torch.Tensor,
    patch_coord: torch.Tensor,
    patch_valid: torch.Tensor,
    patch_shape: torch.Size,
    patch_size: Tuple[int, int],
    dtype: torch.dtype = torch.float32,
):
    """Erase rectangular pixel-space regions that ignore patch boundaries.

    Matches our original RandomErasing implementation: the erased rectangle is
    sampled in pixel space, so its edges may cut through patches and patches on
    the boundary are only partially overwritten. `patches` is modified in place.

    FIXME complexity probably not worth it, may remove.

    Args:
        patches: Patches of one sample, indexed by patch along dim 0. Each patch
            must be reshapeable to (patch_h, patch_w, channels).
        patch_coord: Per-patch (y, x) grid coordinates, one row per patch.
            Assumed to hold integer values and to be aligned with `patches`.
        patch_valid: Per-patch validity mask, or None to treat every patch as valid.
        patch_shape: Shape of a single patch, (h, w, c) or (h * w, c).
        patch_size: Patch (height, width) in pixels.
        dtype: Dtype passed through to the fill value generation.

    Raises:
        ValueError: If `patch_size` is None; the pixel-space math cannot run
            without the patch height and width.
    """
    if random.random() > self.erase_prob:
        return

    if patch_size is None:
        raise ValueError("patch_size is required for 'subregion' random erasing")
    patch_h, patch_w = patch_size
    channels = patch_shape[-1]

    # Grid extent (in patches) is taken from the valid coordinates only.
    if patch_valid is not None:
        extent_coord = patch_coord[patch_valid]
        if len(extent_coord) == 0:
            return  # No valid patches
    else:
        extent_coord = patch_coord
    grid_h = extent_coord[:, 0].max().item() + 1
    grid_w = extent_coord[:, 1].max().item() + 1

    # Same extent expressed in pixels.
    pixel_grid_h = grid_h * patch_h
    pixel_grid_w = grid_w * patch_w
    total_area = pixel_grid_h * pixel_grid_w

    # Pull coordinates and validity to plain Python values once, so the inner
    # loops below do integer math instead of one tiny tensor op per patch.
    coord_list = patch_coord.tolist()
    valid_list = None if patch_valid is None else patch_valid.tolist()
    candidates = [
        (idx, int(y) * patch_h, int(x) * patch_w)
        for idx, (y, x) in enumerate(coord_list)
        if valid_list is None or valid_list[idx]
    ]

    count = random.randint(self.min_count, self.max_count)
    for _ in range(count):
        # Try to select a region that fits (multiple attempts).
        for _attempt in range(10):
            # Sample random area and aspect ratio
            target_area = random.uniform(self.min_area, self.max_area) * total_area
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))

            # Region height and width in pixel space
            pixel_h = int(round(math.sqrt(target_area * aspect_ratio)))
            pixel_w = int(round(math.sqrt(target_area / aspect_ratio)))
            if pixel_h > pixel_grid_h or pixel_w > pixel_grid_w:
                continue  # does not fit, try again

            # Random top-left corner, then the region bounds in pixel space
            pixel_top = random.randint(0, pixel_grid_h - pixel_h)
            pixel_left = random.randint(0, pixel_grid_w - pixel_w)
            pixel_bottom = pixel_top + pixel_h
            pixel_right = pixel_left + pixel_w

            # One shared random value for the whole region in 'rand' mode.
            # NOTE(review): allocated on self.device, as before; confirm this
            # always matches the device of `patches`.
            rand_value = None
            if self.erase_mode == 'rand':
                rand_value = torch.empty(channels, dtype=dtype, device=self.device).normal_()

            for idx, patch_top, patch_left in candidates:
                # Clip the erase rectangle to this patch, in patch-local pixels.
                local_top = max(0, pixel_top - patch_top)
                local_left = max(0, pixel_left - patch_left)
                local_bottom = min(patch_h, pixel_bottom - patch_top)
                local_right = min(patch_w, pixel_right - patch_left)
                if local_bottom <= local_top or local_right <= local_left:
                    continue  # no overlap (or a degenerate zero-size region)

                # View the patch as [patch_h, patch_w, chans] and overwrite the overlap.
                patch_data = patches[idx].reshape(patch_h, patch_w, channels)
                erase_shape = (local_bottom - local_top, local_right - local_left, channels)
                patch_data[local_top:local_bottom, local_left:local_right, :] = self._get_values(
                    erase_shape,
                    dtype=dtype,
                    value=rand_value,
                )
                # Write back in the caller's patch layout; needed when reshape
                # above produced a copy rather than a view.
                patches[idx] = patch_data.reshape(patch_shape)

            # Successfully applied erasing, exit the attempt loop
            break
def __call__( def __call__(
self, self,
@ -369,18 +260,12 @@ class PatchRandomErasing:
""" """
if patches.ndim == 4: if patches.ndim == 4:
batch_size, num_patches, patch_dim, channels = patches.shape batch_size, num_patches, patch_dim, channels = patches.shape
if self.patch_size is not None:
patch_size = self.patch_size
else:
patch_size = None
elif patches.ndim == 5: elif patches.ndim == 5:
batch_size, num_patches, patch_h, patch_w, channels = patches.shape batch_size, num_patches, patch_h, patch_w, channels = patches.shape
patch_size = (patch_h, patch_w)
else: else:
assert False assert False
patch_shape = patches.shape[2:] patch_shape = patches.shape[2:]
# patch_shape ==> shape of patches to fill (h, w, c) or (h * w, c) # patch_shape ==> shape of patches to fill (h, w, c) or (h * w, c)
# patch_size ==> patch h, w (if available, must be avail for subregion mode)
# Create default valid mask if not provided # Create default valid mask if not provided
if patch_valid is None: if patch_valid is None:
@ -399,6 +284,7 @@ class PatchRandomErasing:
patch_valid[i], patch_valid[i],
) )
elif self.spatial_mode == 'patch': elif self.spatial_mode == 'patch':
# FIXME we could vectorize patch mode across batch, worth the effort?
self._erase_patches( self._erase_patches(
patches[i], patches[i],
patch_coord[i], patch_coord[i],
@ -414,15 +300,8 @@ class PatchRandomErasing:
patch_shape, patch_shape,
patches.dtype patches.dtype
) )
elif self.spatial_mode == 'subregion': else:
self._erase_subregion( assert False
patches[i],
patch_coord[i],
patch_valid[i],
patch_shape,
patch_size,
patches.dtype
)
return patches return patches