mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
NaFlex random erasing performance improvements, python loops were slow. Remove subregion mode, not going to be worth it.
This commit is contained in:
parent
7624389fc9
commit
f001b15ed3
@ -11,17 +11,15 @@ class PatchRandomErasing:
|
|||||||
|
|
||||||
Supports three modes:
|
Supports three modes:
|
||||||
1. 'patch': Simple mode that erases randomly selected valid patches
|
1. 'patch': Simple mode that erases randomly selected valid patches
|
||||||
2. 'region': Erases spatial regions at patch granularity
|
2. 'region': Erases rectangular regions at patch granularity
|
||||||
3. 'subregion': Most sophisticated mode that erases spatial regions at sub-patch granularity,
|
|
||||||
partially erasing patches that are on the boundary of the erased region
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
erase_prob: Probability that the Random Erasing operation will be performed.
|
erase_prob: Probability that the Random Erasing operation will be performed.
|
||||||
patch_drop_prob: Patch dropout probability. Remove random patches instead of erasing.
|
patch_drop_prob: Patch dropout probability. Remove random patches instead of erasing.
|
||||||
min_area: Minimum percentage of valid patches/area to erase.
|
min_area: Minimum percentage of valid patches/area to erase.
|
||||||
max_area: Maximum percentage of valid patches/area to erase.
|
max_area: Maximum percentage of valid patches/area to erase.
|
||||||
min_aspect: Minimum aspect ratio of erased area (only used in 'region'/'subregion' mode).
|
min_aspect: Minimum aspect ratio of erased area (only used in 'region' mode).
|
||||||
max_aspect: Maximum aspect ratio of erased area (only used in 'region'/'subregion' mode).
|
max_aspect: Maximum aspect ratio of erased area (only used in 'region' mode).
|
||||||
mode: Patch content mode, one of 'const', 'rand', or 'pixel'
|
mode: Patch content mode, one of 'const', 'rand', or 'pixel'
|
||||||
'const' - erase patch is constant color of 0 for all channels
|
'const' - erase patch is constant color of 0 for all channels
|
||||||
'rand' - erase patch has same random (normal) value across all elements
|
'rand' - erase patch has same random (normal) value across all elements
|
||||||
@ -45,7 +43,6 @@ class PatchRandomErasing:
|
|||||||
mode: str = 'const',
|
mode: str = 'const',
|
||||||
value: float = 0.,
|
value: float = 0.,
|
||||||
spatial_mode: str = 'region',
|
spatial_mode: str = 'region',
|
||||||
patch_size: Optional[Union[int, Tuple[int, int]]] = 16,
|
|
||||||
num_splits: int = 0,
|
num_splits: int = 0,
|
||||||
device: Union[str, torch.device] = 'cuda',
|
device: Union[str, torch.device] = 'cuda',
|
||||||
):
|
):
|
||||||
@ -66,14 +63,13 @@ class PatchRandomErasing:
|
|||||||
|
|
||||||
# Strategy mode
|
# Strategy mode
|
||||||
self.spatial_mode = spatial_mode
|
self.spatial_mode = spatial_mode
|
||||||
|
assert self.spatial_mode in ('patch', 'region')
|
||||||
# Patch size (needed for subregion mode)
|
|
||||||
self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
|
|
||||||
|
|
||||||
# Value generation mode flags
|
# Value generation mode flags
|
||||||
self.erase_mode = mode.lower()
|
self.erase_mode = mode.lower()
|
||||||
assert self.erase_mode in ('rand', 'pixel', 'const')
|
assert self.erase_mode in ('rand', 'pixel', 'const')
|
||||||
self.const_value = value
|
self.const_value = value
|
||||||
|
self.unique_noise_per_patch = True
|
||||||
|
|
||||||
def _get_values(
|
def _get_values(
|
||||||
self,
|
self,
|
||||||
@ -156,27 +152,27 @@ class PatchRandomErasing:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Get indices of valid patches
|
# Get indices of valid patches
|
||||||
valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0].tolist()
|
valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0]
|
||||||
if not valid_indices:
|
num_valid = len(valid_indices)
|
||||||
# Skip if no valid patches
|
if num_valid == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
num_valid = len(valid_indices)
|
|
||||||
count = random.randint(self.min_count, self.max_count)
|
count = random.randint(self.min_count, self.max_count)
|
||||||
# Determine how many valid patches to erase from RE min/max count and area args
|
# Determine how many valid patches to erase from RE min/max count and area args
|
||||||
max_erase = max(1, int(num_valid * count * self.max_area))
|
max_erase = min(num_valid, max(1, int(num_valid * count * self.max_area)))
|
||||||
min_erase = max(1, int(num_valid * count * self.min_area))
|
min_erase = max(1, int(num_valid * count * self.min_area))
|
||||||
num_erase = random.randint(min_erase, max_erase)
|
num_erase = random.randint(min_erase, max_erase)
|
||||||
|
|
||||||
# Randomly select valid patches to erase
|
# Randomly select valid patches to erase
|
||||||
indices_to_erase = random.sample(valid_indices, min(num_erase, num_valid))
|
erase_idx = valid_indices[torch.randperm(num_valid, device=patches.device)[:num_erase]]
|
||||||
|
|
||||||
random_value = None
|
if self.unique_noise_per_patch and self.erase_mode == 'pixel':
|
||||||
if self.erase_mode == 'rand':
|
# generate unique noise for the whole selection of patches
|
||||||
random_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
|
fill_shape = (num_erase,) + patch_shape
|
||||||
|
else:
|
||||||
|
fill_shape = patch_shape
|
||||||
|
|
||||||
for idx in indices_to_erase:
|
patches[erase_idx] = self._get_values(fill_shape, dtype=dtype)
|
||||||
patches[idx].copy_(self._get_values(patch_shape, dtype=dtype, value=random_value))
|
|
||||||
|
|
||||||
def _erase_region(
|
def _erase_region(
|
||||||
self,
|
self,
|
||||||
@ -195,20 +191,14 @@ class PatchRandomErasing:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Determine grid dimensions from coordinates
|
# Determine grid dimensions from coordinates
|
||||||
if patch_valid is not None:
|
valid_coord = patch_coord[patch_valid]
|
||||||
valid_coord = patch_coord[patch_valid]
|
if len(valid_coord) == 0:
|
||||||
if len(valid_coord) == 0:
|
return # No valid patches
|
||||||
return # No valid patches
|
max_y = valid_coord[:, 0].max().item() + 1
|
||||||
max_y = valid_coord[:, 0].max().item() + 1
|
max_x = valid_coord[:, 1].max().item() + 1
|
||||||
max_x = valid_coord[:, 1].max().item() + 1
|
|
||||||
else:
|
|
||||||
max_y = patch_coord[:, 0].max().item() + 1
|
|
||||||
max_x = patch_coord[:, 1].max().item() + 1
|
|
||||||
|
|
||||||
grid_h, grid_w = max_y, max_x
|
grid_h, grid_w = max_y, max_x
|
||||||
|
|
||||||
# Calculate total area
|
|
||||||
total_area = grid_h * grid_w
|
total_area = grid_h * grid_w
|
||||||
|
ys, xs = patch_coord[:, 0], patch_coord[:, 1]
|
||||||
|
|
||||||
count = random.randint(self.min_count, self.max_count)
|
count = random.randint(self.min_count, self.max_count)
|
||||||
for _ in range(count):
|
for _ in range(count):
|
||||||
@ -222,132 +212,33 @@ class PatchRandomErasing:
|
|||||||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||||
|
|
||||||
# Ensure region fits within grid
|
if h > grid_h or w > grid_w:
|
||||||
if w <= grid_w and h <= grid_h:
|
continue # try again
|
||||||
# Select random top-left corner
|
|
||||||
top = random.randint(0, grid_h - h)
|
|
||||||
left = random.randint(0, grid_w - w)
|
|
||||||
|
|
||||||
# Define region bounds
|
# Calculate region patch bounds
|
||||||
bottom = top + h
|
top = random.randint(0, grid_h - h)
|
||||||
right = left + w
|
left = random.randint(0, grid_w - w)
|
||||||
|
bottom, right = top + h, left + w
|
||||||
|
|
||||||
# Create a single random value for all affected patches if using 'rand' mode
|
# Region test
|
||||||
if self.erase_mode == 'rand':
|
region_mask = (
|
||||||
random_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
|
(ys >= top) & (ys < bottom) &
|
||||||
else:
|
(xs >= left) & (xs < right) &
|
||||||
random_value = None
|
patch_valid
|
||||||
|
)
|
||||||
|
num_selected = int(region_mask.sum().item())
|
||||||
|
if not num_selected:
|
||||||
|
continue # no patch actually falls inside – try again
|
||||||
|
|
||||||
# Find and erase all patches that fall within the region
|
if self.unique_noise_per_patch and self.erase_mode == 'pixel':
|
||||||
for i in range(len(patches)):
|
# generate unique noise for the whole region
|
||||||
if patch_valid is None or patch_valid[i]:
|
fill_shape = (num_selected,) + patch_shape
|
||||||
y, x = patch_coord[i]
|
else:
|
||||||
if top <= y < bottom and left <= x < right:
|
fill_shape = patch_shape
|
||||||
patches[i] = self._get_values(patch_shape, dtype=dtype, value=random_value)
|
|
||||||
|
|
||||||
# Successfully applied erasing, exit the loop
|
patches[region_mask] = self._get_values(fill_shape, dtype=dtype)
|
||||||
break
|
# Successfully applied erasing, exit the loop
|
||||||
|
break
|
||||||
def _erase_subregion(
|
|
||||||
self,
|
|
||||||
patches: torch.Tensor,
|
|
||||||
patch_coord: torch.Tensor,
|
|
||||||
patch_valid: torch.Tensor,
|
|
||||||
patch_shape: torch.Size,
|
|
||||||
patch_size: Tuple[int, int],
|
|
||||||
dtype: torch.dtype = torch.float32,
|
|
||||||
):
|
|
||||||
"""Apply erasing by selecting rectangular regions ignoring patch boundaries.
|
|
||||||
|
|
||||||
Matches or original RandomErasing implementation. Erases spatially contiguous rectangular
|
|
||||||
regions that are not aligned to patches (erase regions boundaries cut within patches).
|
|
||||||
|
|
||||||
FIXME complexity probably not worth it, may remove.
|
|
||||||
"""
|
|
||||||
if random.random() > self.erase_prob:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Get patch dimensions
|
|
||||||
patch_h, patch_w = patch_size
|
|
||||||
channels = patch_shape[-1]
|
|
||||||
|
|
||||||
# Determine grid dimensions in patch coordinates
|
|
||||||
if patch_valid is not None:
|
|
||||||
valid_coord = patch_coord[patch_valid]
|
|
||||||
if len(valid_coord) == 0:
|
|
||||||
return # No valid patches
|
|
||||||
max_y = valid_coord[:, 0].max().item() + 1
|
|
||||||
max_x = valid_coord[:, 1].max().item() + 1
|
|
||||||
else:
|
|
||||||
max_y = patch_coord[:, 0].max().item() + 1
|
|
||||||
max_x = patch_coord[:, 1].max().item() + 1
|
|
||||||
|
|
||||||
grid_h, grid_w = max_y, max_x
|
|
||||||
|
|
||||||
# Calculate total area in pixel space
|
|
||||||
total_area = (grid_h * patch_h) * (grid_w * patch_w)
|
|
||||||
|
|
||||||
count = random.randint(self.min_count, self.max_count)
|
|
||||||
for _ in range(count):
|
|
||||||
# Try to select a valid region to erase (multiple attempts)
|
|
||||||
for attempt in range(10):
|
|
||||||
# Sample random area and aspect ratio
|
|
||||||
target_area = random.uniform(self.min_area, self.max_area) * total_area
|
|
||||||
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
|
|
||||||
|
|
||||||
# Calculate region height and width in pixel space
|
|
||||||
pixel_h = int(round(math.sqrt(target_area * aspect_ratio)))
|
|
||||||
pixel_w = int(round(math.sqrt(target_area / aspect_ratio)))
|
|
||||||
|
|
||||||
# Ensure region fits within total pixel grid
|
|
||||||
if pixel_w <= grid_w * patch_w and pixel_h <= grid_h * patch_h:
|
|
||||||
# Select random top-left corner in pixel space
|
|
||||||
pixel_top = random.randint(0, grid_h * patch_h - pixel_h)
|
|
||||||
pixel_left = random.randint(0, grid_w * patch_w - pixel_w)
|
|
||||||
|
|
||||||
# Define region bounds in pixel space
|
|
||||||
pixel_bottom = pixel_top + pixel_h
|
|
||||||
pixel_right = pixel_left + pixel_w
|
|
||||||
|
|
||||||
# Create a single random value for the entire region if using 'rand' mode
|
|
||||||
rand_value = None
|
|
||||||
if self.erase_mode == 'rand':
|
|
||||||
rand_value = torch.empty(patch_shape[-1], dtype=dtype, device=self.device).normal_()
|
|
||||||
|
|
||||||
# For each valid patch, determine if and how it overlaps with the erase region
|
|
||||||
for i in range(len(patches)):
|
|
||||||
if patch_valid is None or patch_valid[i]:
|
|
||||||
# Convert patch coordinates to pixel space (top-left corner)
|
|
||||||
y, x = patch_coord[i]
|
|
||||||
patch_pixel_top = y * patch_h
|
|
||||||
patch_pixel_left = x * patch_w
|
|
||||||
patch_pixel_bottom = patch_pixel_top + patch_h
|
|
||||||
patch_pixel_right = patch_pixel_left + patch_w
|
|
||||||
|
|
||||||
# Check if this patch overlaps with the erase region
|
|
||||||
if not (patch_pixel_right <= pixel_left or patch_pixel_left >= pixel_right or
|
|
||||||
patch_pixel_bottom <= pixel_top or patch_pixel_top >= pixel_bottom):
|
|
||||||
|
|
||||||
# Calculate the overlap region in patch-local coordinates
|
|
||||||
local_top = max(0, pixel_top - patch_pixel_top)
|
|
||||||
local_left = max(0, pixel_left - patch_pixel_left)
|
|
||||||
local_bottom = min(patch_h, pixel_bottom - patch_pixel_top)
|
|
||||||
local_right = min(patch_w, pixel_right - patch_pixel_left)
|
|
||||||
|
|
||||||
# Reshape the patch to [patch_h, patch_w, chans]
|
|
||||||
patch_data = patches[i].reshape(patch_h, patch_w, channels)
|
|
||||||
|
|
||||||
erase_shape = (local_bottom - local_top, local_right - local_left, channels)
|
|
||||||
erase_value = self._get_values(erase_shape, dtype=dtype, value=rand_value)
|
|
||||||
patch_data[local_top:local_bottom, local_left:local_right, :] = erase_value
|
|
||||||
|
|
||||||
# Flatten the patch back to [patch_h*patch_w, chans]
|
|
||||||
if len(patch_shape) == 2:
|
|
||||||
patch_data = patch_data.reshape(-1, channels)
|
|
||||||
patches[i] = patch_data
|
|
||||||
|
|
||||||
# Successfully applied erasing, exit the loop
|
|
||||||
break
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@ -369,18 +260,12 @@ class PatchRandomErasing:
|
|||||||
"""
|
"""
|
||||||
if patches.ndim == 4:
|
if patches.ndim == 4:
|
||||||
batch_size, num_patches, patch_dim, channels = patches.shape
|
batch_size, num_patches, patch_dim, channels = patches.shape
|
||||||
if self.patch_size is not None:
|
|
||||||
patch_size = self.patch_size
|
|
||||||
else:
|
|
||||||
patch_size = None
|
|
||||||
elif patches.ndim == 5:
|
elif patches.ndim == 5:
|
||||||
batch_size, num_patches, patch_h, patch_w, channels = patches.shape
|
batch_size, num_patches, patch_h, patch_w, channels = patches.shape
|
||||||
patch_size = (patch_h, patch_w)
|
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
patch_shape = patches.shape[2:]
|
patch_shape = patches.shape[2:]
|
||||||
# patch_shape ==> shape of patches to fill (h, w, c) or (h * w, c)
|
# patch_shape ==> shape of patches to fill (h, w, c) or (h * w, c)
|
||||||
# patch_size ==> patch h, w (if available, must be avail for subregion mode)
|
|
||||||
|
|
||||||
# Create default valid mask if not provided
|
# Create default valid mask if not provided
|
||||||
if patch_valid is None:
|
if patch_valid is None:
|
||||||
@ -399,6 +284,7 @@ class PatchRandomErasing:
|
|||||||
patch_valid[i],
|
patch_valid[i],
|
||||||
)
|
)
|
||||||
elif self.spatial_mode == 'patch':
|
elif self.spatial_mode == 'patch':
|
||||||
|
# FIXME we could vectorize patch mode across batch, worth the effort?
|
||||||
self._erase_patches(
|
self._erase_patches(
|
||||||
patches[i],
|
patches[i],
|
||||||
patch_coord[i],
|
patch_coord[i],
|
||||||
@ -414,15 +300,8 @@ class PatchRandomErasing:
|
|||||||
patch_shape,
|
patch_shape,
|
||||||
patches.dtype
|
patches.dtype
|
||||||
)
|
)
|
||||||
elif self.spatial_mode == 'subregion':
|
else:
|
||||||
self._erase_subregion(
|
assert False
|
||||||
patches[i],
|
|
||||||
patch_coord[i],
|
|
||||||
patch_valid[i],
|
|
||||||
patch_shape,
|
|
||||||
patch_size,
|
|
||||||
patches.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
return patches
|
return patches
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user