# pytorch-image-models/timm/data/naflex_transforms.py
""" NaFlex (NaViT + FlexiViT) Transforms and Collation
Implements PyTorch versions of the transforms described in the NaViT and FlexiViT papers:
- NaViT: https://arxiv.org/abs/2307.06304
- FlexiViT: https://arxiv.org/abs/2212.08013
Enables variable resolution/aspect ratio image handling with efficient patching.
"""
import math
import random
import warnings
from typing import List, Optional, Sequence, Tuple, Union
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode
from .transforms import str_to_interp_mode, crop_or_pad, center_crop_or_pad
def get_image_size_for_seq(
        image_hw: Tuple[int, int],
        patch_size: Union[int, Tuple[int, int]] = 16,
        max_seq_len: int = 1024,
        divisible_by_patch: bool = True,
        max_ratio: Optional[float] = None,
        eps: float = 1e-5,
):
"""
    Determine a scaling ratio and target image size such that, when `image_hw` is
    scaled by `ratio`, the total number of resulting patches does not exceed
    `max_seq_len`.
- Patch size can be an integer (square patch) or a tuple (patch_h, patch_w).
- Optionally cap the ratio at `max_ratio` to prevent upsampling beyond
a certain multiple of the original size.
Args:
image_hw (tuple or list of int): (height, width) of the original image.
patch_size (int or tuple[int, int]): If int, patch is square. If tuple,
patch is rectangular (patch_h, patch_w).
max_seq_len (int): Maximum allowed sequence length for the resulting image.
divisible_by_patch (bool): If True, the resulting image height and width
must be multiples of patch_size.
        max_ratio (float or None): If provided, the scaling ratio found by the
            binary search will be clamped to min(found_ratio, max_ratio). Set
            max_ratio=1.0 to ensure no upsampling beyond the original size.
        eps (float): Small number for binary search convergence.
Returns:
ratio (float): Found scaling ratio (capped by `max_ratio` if provided).
target_hw (tuple of int): Target (height, width) after scaling.
"""
# Handle patch size input, extract patch_h, patch_w
if isinstance(patch_size, int):
patch_h, patch_w = patch_size, patch_size
else:
# Assume it's a tuple/list: (patch_h, patch_w)
if len(patch_size) != 2:
raise ValueError("patch_size tuple must have exactly two elements (patch_h, patch_w).")
patch_h, patch_w = patch_size
# Safety checks
if patch_h <= 0 or patch_w <= 0:
raise ValueError("patch_size dimensions must be positive.")
def prepare_target_hw(ratio):
"""Scale image_hw by ratio and optionally round dimensions to multiples of patch_h, patch_w."""
scaled_h = image_hw[0] * ratio
scaled_w = image_hw[1] * ratio
# If we need the result to be divisible by patch_size
if divisible_by_patch:
scaled_h = patch_h * math.ceil(scaled_h / patch_h)
scaled_w = patch_w * math.ceil(scaled_w / patch_w)
# Ensure at least one patch in each dimension
scaled_h = int(max(scaled_h, patch_h))
scaled_w = int(max(scaled_w, patch_w))
return scaled_h, scaled_w
def is_feasible(ratio):
"""Check if scaling by 'ratio' keeps patch count within max_seq_len."""
t_h, t_w = prepare_target_hw(ratio)
# Each dimension is already a multiple of patch_h, patch_w if divisible_by_patch=True.
# Use integer division to count patches.
num_patches_h = t_h // patch_h
num_patches_w = t_w // patch_w
seq_len = num_patches_h * num_patches_w
return seq_len <= max_seq_len
# Binary search boundaries
lb = eps / 10.0
rb = 100.0
# Standard binary search loop
while (rb - lb) >= eps:
mid = (lb + rb) / 2.0
if is_feasible(mid):
lb = mid
else:
rb = mid
# The final ratio from the binary search
ratio = lb
# If max_ratio is provided, clamp it to prevent upsampling beyond that threshold
if max_ratio is not None:
ratio = min(ratio, max_ratio)
# Final checks
if ratio <= eps:
raise ValueError("Binary search failed - image might be too large?")
if ratio >= 100.0:
raise ValueError("Binary search failed - image might be too small?")
# Prepare the final target dimensions with the possibly clamped ratio
target_hw = prepare_target_hw(ratio)
return ratio, target_hw
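
# Illustrative example (assumed values, not library defaults): a 480x640 image with
# 16x16 patches and a 256-token budget converges to ratio ~0.45 and target (224, 288),
# i.e. (224 // 16) * (288 // 16) = 14 * 18 = 252 <= 256 patches:
#   >>> ratio, target_hw = get_image_size_for_seq((480, 640), patch_size=16, max_seq_len=256)
#   >>> round(ratio, 2), target_hw
#   (0.45, (224, 288))
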
_RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
class ResizeToSequence(torch.nn.Module):
"""Resize image to fit within a maximum sequence length constraint when patchified.
This maintains aspect ratio while ensuring the resulting image, when divided into patches,
will not exceed the specified maximum sequence length.
"""
def __init__(
self,
patch_size: int,
max_seq_len: int = 1024,
divisible_by_patch: bool = True,
max_ratio: Optional[float] = None,
interpolation='bicubic',
):
super().__init__()
self.patch_size = patch_size
self.max_seq_len = max_seq_len
self.divisible_by_patch = divisible_by_patch
self.max_ratio = max_ratio
if isinstance(interpolation, str):
if interpolation == 'random':
self.interpolation = _RANDOM_INTERPOLATION
else:
self.interpolation = str_to_interp_mode(interpolation)
else:
self.interpolation = interpolation
def forward(self, img):
"""Resize image to maintain aspect ratio and fit sequence constraint."""
_, h, w = transforms.functional.get_dimensions(img)
_, target_hw = get_image_size_for_seq(
(h, w),
self.patch_size,
self.max_seq_len,
divisible_by_patch=self.divisible_by_patch,
max_ratio=self.max_ratio,
)
if isinstance(self.interpolation, (tuple, list)):
interpolation = random.choice(self.interpolation)
else:
interpolation = self.interpolation
resized_img = transforms.functional.resize(img, target_hw, interpolation=interpolation, antialias=True)
return resized_img
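
# Illustrative usage (assumed values, a minimal sketch rather than a tested config):
#   >>> tfm = ResizeToSequence(patch_size=16, max_seq_len=256, max_ratio=1.0)
#   >>> out = tfm(Image.new('RGB', (640, 480)))  # PIL size is (W, H)
#   >>> # out keeps the 4:3 aspect and satisfies (H // 16) * (W // 16) <= 256
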
class ResizeKeepRatioToSequence(torch.nn.Module):
"""
Resize and Keep Aspect Ratio, adapted to fit sequence length constraints.
"""
def __init__(
self,
patch_size=16,
max_sequence_len=1024,
divisible_by_patch=True,
longest=0.,
interpolation='bilinear',
random_scale_prob=0.,
random_scale_range=(0.85, 1.05),
random_scale_area=False,
random_aspect_prob=0.,
random_aspect_range=(0.9, 1.11),
max_ratio=None,
):
"""
Args:
patch_size: Size of patches (int or tuple of (patch_h, patch_w))
max_sequence_len: Maximum allowed sequence length for the resulting image
divisible_by_patch: If True, ensure dimensions are divisible by patch_size
            longest: Float in [0, 1]; 0 means the shortest-side ratio determines the scale, 1 the longest-side ratio, with values in between blending the two
interpolation: Interpolation method for resizing
random_scale_prob: Probability of applying random scaling
random_scale_range: Range for random scaling factor (min, max)
            random_scale_area: If True, the sampled scale factor is applied as 1/sqrt(factor), making it behave like an area-based (RandomResizedCrop-style) zoom
random_aspect_prob: Probability of applying random aspect ratio jittering
random_aspect_range: Range for random aspect ratio (min, max)
max_ratio: Maximum allowed scaling ratio
"""
super().__init__()
self.patch_size = patch_size
self.max_sequence_len = max_sequence_len
self.divisible_by_patch = divisible_by_patch
self.longest = float(longest)
if interpolation == 'random':
self.interpolation = _RANDOM_INTERPOLATION
else:
self.interpolation = str_to_interp_mode(interpolation)
self.random_scale_prob = random_scale_prob
self.random_scale_range = random_scale_range
self.random_scale_area = random_scale_area
self.random_aspect_prob = random_aspect_prob
self.random_aspect_range = random_aspect_range
self.max_ratio = max_ratio
@staticmethod
def get_params(
img,
patch_size,
max_sequence_len,
divisible_by_patch,
longest,
random_scale_prob=0.,
random_scale_range=(1.0, 1.33),
random_scale_area=False,
random_aspect_prob=0.,
random_aspect_range=(0.9, 1.11),
max_ratio=None,
):
"""Get parameters for resizing."""
# Get image dimensions
img_h, img_w = F.get_dimensions(img)[1:]
# Step 1: Get the maximum allowed dimensions from sequence length constraint
_, target_hw = get_image_size_for_seq(
(img_h, img_w),
patch_size,
max_sequence_len,
divisible_by_patch,
max_ratio,
)
target_h, target_w = target_hw
# Calculate ratio based on sequence constraint
ratio_h = target_h / img_h
ratio_w = target_w / img_w
# Apply longest blending
ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
# Apply random scaling
if random_scale_prob > 0 and random.random() < random_scale_prob:
ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
if random_scale_area:
# Make ratio factor equivalent to area change
ratio_factor = 1. / math.sqrt(ratio_factor)
ratio_factor = (ratio_factor, ratio_factor)
else:
ratio_factor = (1., 1.)
# Apply random aspect
if random_aspect_prob > 0 and random.random() < random_aspect_prob:
log_aspect = (math.log(random_aspect_range[0]), math.log(random_aspect_range[1]))
aspect_factor = math.exp(random.uniform(*log_aspect))
aspect_factor = math.sqrt(aspect_factor)
# Apply aspect ratio jittering
ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
# Calculate final dimensions
size = [round(dim * ratio * f) for dim, f in zip((img_h, img_w), ratio_factor)]
# Ensure dimensions satisfy sequence constraint and are divisible by patch size
if isinstance(patch_size, int):
ph, pw = patch_size, patch_size
else:
ph, pw = patch_size
# Ensure dimensions are at least one patch
size[0] = max(size[0], ph)
size[1] = max(size[1], pw)
# Make divisible by patch size if needed
if divisible_by_patch:
size[0] = ph * math.ceil(size[0] / ph)
size[1] = pw * math.ceil(size[1] / pw)
# Verify we haven't exceeded sequence length
num_patches_h = size[0] // ph
num_patches_w = size[1] // pw
seq_len = num_patches_h * num_patches_w
        if seq_len > max_sequence_len:
            # Scale back down to fit sequence constraint. Round *down* to patch
            # multiples here; rounding up could push the count back over the limit.
            scale_back = math.sqrt(max_sequence_len / seq_len)
            size[0] = int(size[0] * scale_back)
            size[1] = int(size[1] * scale_back)
            if divisible_by_patch:
                size[0] = max(ph, ph * (size[0] // ph))
                size[1] = max(pw, pw * (size[1] // pw))
return size
def forward(self, img):
"""
Resize the image with aspect ratio preservation and sequence length constraints.
"""
size = self.get_params(
img,
self.patch_size,
self.max_sequence_len,
self.divisible_by_patch,
self.longest,
self.random_scale_prob,
self.random_scale_range,
self.random_scale_area,
self.random_aspect_prob,
self.random_aspect_range,
self.max_ratio,
)
if isinstance(self.interpolation, (tuple, list)):
interpolation = random.choice(self.interpolation)
else:
interpolation = self.interpolation
return F.resize(img, size, interpolation)
    def __repr__(self):
        interpolate_str = "random" if isinstance(self.interpolation, (tuple, list)) else str(self.interpolation)
        return (f"{self.__class__.__name__}(patch_size={self.patch_size}, "
                f"max_sequence_len={self.max_sequence_len}, "
                f"longest={self.longest:.3f}, "
                f"interpolation={interpolate_str}, "
                f"random_scale_prob={self.random_scale_prob:.3f}, "
                f"random_aspect_prob={self.random_aspect_prob:.3f})")
class CenterCropToSequence(torch.nn.Module):
"""Center crop the image such that the resulting patch sequence length meets constraints."""
def __init__(
self,
patch_size: int,
max_seq_len: int,
divisible_by_patch: bool = True,
fill: Union[int, Tuple[int, int, int]] = 0,
padding_mode: str = 'constant'
):
super().__init__()
self.patch_size = patch_size
self.max_seq_len = max_seq_len
self.divisible_by_patch = divisible_by_patch
self.fill = fill
self.padding_mode = padding_mode
def forward(self, img):
"""Center crop the image to maintain aspect ratio and fit sequence constraint."""
_, h, w = transforms.functional.get_dimensions(img)
_, target_hw = get_image_size_for_seq(
(h, w),
self.patch_size,
self.max_seq_len,
self.divisible_by_patch
)
# Use center crop
return center_crop_or_pad(img, target_hw, fill=self.fill, padding_mode=self.padding_mode)
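
# Illustrative usage (assumed values): deterministic center crop/pad to the
# sequence-feasible size, zero-padding where the source is smaller than the target:
#   >>> tfm = CenterCropToSequence(patch_size=16, max_seq_len=256)
#   >>> out = tfm(Image.new('RGB', (640, 480)))
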
class RandomCropToSequence(torch.nn.Module):
"""Randomly crop and/or pad the image to fit sequence length constraints.
This maintains aspect ratio while ensuring the resulting image, when divided into patches,
    will not exceed the specified maximum sequence length. Similar to CenterCropToSequence
but with randomized positioning.
"""
def __init__(
self,
patch_size: int,
max_sequence_len: int,
divisible_by_patch: bool = True,
fill: Union[int, Tuple[int, int, int]] = 0,
padding_mode: str = 'constant'
):
"""
Args:
patch_size: Size of patches (int or tuple of (patch_h, patch_w))
max_sequence_len: Maximum allowed sequence length for the resulting image
divisible_by_patch: If True, resulting image dimensions will be multiples of patch_size
fill: Fill value for padding
padding_mode: Padding mode ('constant', 'edge', 'reflect', 'symmetric')
"""
super().__init__()
self.patch_size = patch_size
self.max_sequence_len = max_sequence_len
self.divisible_by_patch = divisible_by_patch
self.fill = fill
self.padding_mode = padding_mode
@staticmethod
def get_params(img, target_size):
"""Get random position for crop/pad."""
_, image_height, image_width = transforms.functional.get_dimensions(img)
delta_height = image_height - target_size[0]
delta_width = image_width - target_size[1]
# Handle both positive (crop) and negative (pad) deltas
if delta_height == 0:
top = 0
else:
top = int(math.copysign(random.randint(0, abs(delta_height)), delta_height))
if delta_width == 0:
left = 0
else:
left = int(math.copysign(random.randint(0, abs(delta_width)), delta_width))
return top, left
def forward(self, img):
"""Randomly crop or pad the image to maintain aspect ratio and fit sequence constraint."""
# Get current dimensions
_, img_h, img_w = transforms.functional.get_dimensions(img)
# Calculate target dimensions that satisfy sequence length
# We use max_ratio=1.0 to prevent upscaling - we only want to crop or maintain current size
_, target_hw = get_image_size_for_seq(
(img_h, img_w),
self.patch_size,
self.max_sequence_len,
self.divisible_by_patch,
max_ratio=1.0 # Prevent upscaling
)
# Get random position for crop/pad
top, left = self.get_params(img, target_hw)
# Apply crop or pad
return crop_or_pad(
img,
top=top,
left=left,
height=target_hw[0],
width=target_hw[1],
fill=self.fill,
padding_mode=self.padding_mode,
)
def __repr__(self) -> str:
return (f"{self.__class__.__name__}(patch_size={self.patch_size}, "
f"max_sequence_len={self.max_sequence_len}, "
f"divisible_by_patch={self.divisible_by_patch})")
def _validate_range(value, name, length=2):
# Validate type and length
if not isinstance(value, Sequence) or len(value) != length:
raise ValueError(f"{name} should be a sequence of length {length}.")
# Validate order
if value[0] > value[1]:
warnings.warn(f"{name.capitalize()} range reversed. Swapping.")
return value[1], value[0]
return value
class RandomResizedCropToSequence(torch.nn.Module):
"""
Randomly crop the input image to a subregion with varying area and aspect ratio
(relative to the original), then resize that crop to a target size. The target size
is determined such that patchifying the resized image (with `patch_size`)
does not exceed `max_seq_len` patches, while maintaining the aspect ratio of the crop.
This combines aspects of torchvision's RandomResizedCrop with sequence length constraints.
Args:
patch_size (int or tuple[int, int]):
Patch dimensions (patch_h, patch_w) for sequence length calculation.
max_seq_len (int):
Maximum number of patches allowed in the final image.
scale (tuple[float, float]):
Range (min, max) of area fraction of the original image to crop.
ratio (tuple[float, float]):
Range (min, max) of aspect ratio *multipliers* for the crop, relative
to the original image's aspect ratio. E.g., (0.75, 1.333) means the
crop's aspect ratio will be sampled between 0.75*orig_ar and 1.333*orig_ar.
Uses log-uniform sampling.
interpolation (str or InterpolationMode):
Interpolation mode for resizing. Can be 'bilinear', 'bicubic', 'nearest',
or 'random' (chooses between bilinear and bicubic).
Defaults to 'bicubic'.
divisible_by_patch (bool):
If True, the final image height and width will be multiples of the
respective patch dimensions. Defaults to True.
max_ratio (float, optional):
An optional upper limit on the scaling ratio applied during resizing.
Prevents excessive upsampling of the initial crop. `max_ratio=1.0`
prevents any upsampling beyond the cropped size. Defaults to None (no limit).
final_scale_range (tuple[float, float], optional):
If provided, applies an *additional* random scaling factor to the
final target size. The factor is sampled uniformly from this range,
and multiplied by the size determined by `get_image_size_for_seq`.
E.g., (0.8, 1.0) means the final size will be between 80% and 100%
of the maximum feasible size. Defaults to None (use maximum feasible size).
attempts (int):
Number of attempts to sample a valid crop geometry before falling back
to a center crop strategy. Defaults to 10.
"""
def __init__(
self,
patch_size: Union[int, Tuple[int, int]] = 16,
max_seq_len: int = 1024,
scale: Tuple[float, float] = (0.08, 1.0),
ratio: Tuple[float, float] = (.8, 1.25),
interpolation: Union[str, InterpolationMode] = 'bicubic',
divisible_by_patch: bool = True,
max_ratio: Optional[float] = None,
final_scale_range: Optional[Tuple[float, float]] = None,
attempts: int = 10,
):
super().__init__()
if isinstance(patch_size, int):
self.patch_h, self.patch_w = patch_size, patch_size
else:
# Assume it's a tuple/list: (patch_h, patch_w)
if len(patch_size) != 2:
raise ValueError("patch_size tuple must have exactly two elements (patch_h, patch_w).")
self.patch_h, self.patch_w = patch_size
self.max_seq_len = max_seq_len
self.scale = scale
self.ratio = ratio
self.divisible_by_patch = divisible_by_patch
self.max_ratio = max_ratio
self.final_scale_range = final_scale_range
self.attempts = attempts
if isinstance(interpolation, str):
if interpolation == 'random':
self.interpolation = _RANDOM_INTERPOLATION
else:
self.interpolation = str_to_interp_mode(interpolation)
else:
self.interpolation = interpolation
# Validate scale and ratio
self.scale = _validate_range(self.scale, "scale")
self.ratio = _validate_range(self.ratio, "ratio")
# Validate final_scale_range if provided
if self.final_scale_range is not None:
self.final_scale_range = _validate_range(self.final_scale_range, "final_scale_range")
# Additional validation for final_scale_range values
if not (0.0 <= self.final_scale_range[0] <= self.final_scale_range[1] <= 1.0):
warnings.warn("final_scale_range values should ideally be between 0.0 and 1.0.")
@staticmethod
def get_params(
img: torch.Tensor,
scale: Tuple[float, float],
ratio: Tuple[float, float],
crop_attempts: int = 10,
patch_h: int = 16,
patch_w: int = 16,
max_seq_len: int = 1024,
divisible_by_patch: bool = True,
max_ratio: Optional[float] = None,
final_scale_range: Optional[Tuple[float, float]] = None,
interpolation: Union[List[InterpolationMode], InterpolationMode] = _RANDOM_INTERPOLATION,
) -> Tuple[Tuple[int, int, int, int], Tuple[int, int], InterpolationMode]:
""" Get parameters for a random sized crop relative to image aspect ratio.
"""
_, height, width = F.get_dimensions(img)
if height <= 0 or width <= 0:
raise ValueError(f"Input image must have positive dimensions, got H={height}, W={width}")
area = height * width
orig_aspect = width / height
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
for _ in range(crop_attempts):
target_area = area * random.uniform(scale[0], scale[1])
aspect_ratio_factor = math.exp(random.uniform(log_ratio[0], log_ratio[1]))
aspect_ratio = orig_aspect * aspect_ratio_factor
# Calculate target dimensions for the crop
# target_area = crop_w * crop_h, aspect_ratio = crop_w / crop_h
# => crop_h = sqrt(target_area / aspect_ratio)
# => crop_w = sqrt(target_area * aspect_ratio)
crop_h = int(round(math.sqrt(target_area / aspect_ratio)))
crop_w = int(round(math.sqrt(target_area * aspect_ratio)))
if 0 < crop_w <= width and 0 < crop_h <= height:
top = random.randint(0, height - crop_h)
left = random.randint(0, width - crop_w)
break
else:
            # Fallback strategy: center crop, trying to respect the requested ratio range
min_aspect_ratio = orig_aspect * ratio[0]
max_aspect_ratio = orig_aspect * ratio[1]
if orig_aspect < min_aspect_ratio:
# Original is narrower than target min, clamp width
crop_w = width
crop_h = min(int(round(crop_w / min_aspect_ratio)), height)
elif orig_aspect > max_aspect_ratio:
# Original is wider than target max, clamp height
crop_h = height
crop_w = min(int(round(crop_h * max_aspect_ratio)), width)
else:
# Aspect ratio is within range, take the largest possible crop (full image)
crop_w = width
crop_h = height
# Ensure valid dimensions after fallback calculation
crop_h = max(1, crop_h)
crop_w = max(1, crop_w)
top = (height - crop_h) // 2
left = (width - crop_w) // 2
# Determine max feasible size for scaling of the *cropped* region
feasible_ratio, feasible_size = get_image_size_for_seq(
(crop_h, crop_w),
patch_size=(patch_h, patch_w), # Pass as tuple
max_seq_len=max_seq_len,
divisible_by_patch=divisible_by_patch,
max_ratio=max_ratio,
)
# Optionally apply final scale randomization
final_size = feasible_size
if final_scale_range is not None:
min_sc, max_sc = final_scale_range
scale_factor = random.uniform(min_sc, max_sc)
scale_factor = min(max(scale_factor, 0.0), 1.0) # Clamp factor just in case
# Calculate raw scaled size
# Note: feasible_ratio already accounts for max_ratio clamp if any
raw_h = crop_h * feasible_ratio * scale_factor
raw_w = crop_w * feasible_ratio * scale_factor
# Re-apply divisibility constraint if needed
if divisible_by_patch:
# Use ceil to avoid going under minimum patch size
target_h = patch_h * math.ceil(raw_h / patch_h)
target_w = patch_w * math.ceil(raw_w / patch_w)
else:
target_h = int(round(raw_h))
target_w = int(round(raw_w))
# Ensure final size is at least one patch dimension
target_h = max(target_h, patch_h)
target_w = max(target_w, patch_w)
final_size = (target_h, target_w)
# Final check: Ensure this randomized size still fits max_seq_len
# (It should, as we scaled down, but rounding might theoretically push it over)
num_patches_h = final_size[0] // patch_h
num_patches_w = final_size[1] // patch_w
            if (num_patches_h * num_patches_w) > max_seq_len:
                # If it exceeds, warn with the offending size, then revert to feasible_size (safest)
                warnings.warn(
                    f"Final scale randomization ({scale_factor:.2f}) resulted in size {final_size} "
                    f"exceeding max_seq_len={max_seq_len} after rounding. Reverting to feasible size {feasible_size}."
                )
                final_size = feasible_size
        # Select interpolation mode (choose randomly if a tuple/list was provided)
        if isinstance(interpolation, (tuple, list)):
            interpolation = random.choice(interpolation)
return (top, left, crop_h, crop_w), final_size, interpolation
def forward(self, img: torch.Tensor) -> torch.Tensor:
# Sample crop, resize, and interpolation parameters
crop_params, final_size, interpolation = self.get_params(
img,
scale=self.scale,
ratio=self.ratio,
crop_attempts=self.attempts,
patch_h=self.patch_h,
patch_w=self.patch_w,
divisible_by_patch=self.divisible_by_patch,
max_seq_len=self.max_seq_len,
final_scale_range=self.final_scale_range,
interpolation=self.interpolation,
)
top, left, crop_h, crop_w = crop_params
output = F.resized_crop(
img,
top=top,
left=left,
height=crop_h,
width=crop_w,
size=final_size,
interpolation=interpolation,
antialias=True,
)
return output
def __repr__(self) -> str:
if isinstance(self.interpolation, (tuple, list)):
interpolate_str = ', '.join(str(m).split('.')[-1] for m in self.interpolation)
else:
interpolate_str = str(self.interpolation)
format_string = self.__class__.__name__ + '('
format_string += f"patch_size=({self.patch_h}, {self.patch_w})"
format_string += f", max_seq_len={self.max_seq_len}"
format_string += f", scale={self.scale}"
format_string += f", ratio={self.ratio}"
format_string += f", interpolation=[{interpolate_str}]"
format_string += f", divisible_by_patch={self.divisible_by_patch}"
format_string += f", max_ratio={self.max_ratio}"
format_string += f", final_scale_range={self.final_scale_range}"
format_string += f", attempts={self.attempts}"
format_string += ')'
return format_string
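
# Illustrative usage (assumed hyper-parameters in the spirit of a train-time RRC):
#   >>> tfm = RandomResizedCropToSequence(
#   ...     patch_size=16,
#   ...     max_seq_len=256,
#   ...     scale=(0.35, 1.0),
#   ...     final_scale_range=(0.8, 1.0),
#   ... )
#   >>> out = tfm(Image.new('RGB', (640, 480)))  # output size varies per call, <= 256 patches
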
def patchify(
img: torch.Tensor,
patch_size: Tuple[int, int],
pad: bool = True,
include_info: bool = True,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Flatten an image tensor [C, H, W] into a sequence of flattened patches.

    Args:
        img: Image tensor of shape [C, H, W].
        patch_size: Patch dimensions (patch_h, patch_w).
        pad: If True, zero-pad H and W up to the next patch multiple before patching.
        include_info: If True, also return per-patch (y, x) grid coordinates and a
            validity mask alongside the patches.
    """
    c, h, w = img.shape
    ph, pw = patch_size
    # Ensure the image is divisible by patch size
    if pad and (h % ph != 0 or w % pw != 0):
        new_h = math.ceil(h / ph) * ph
        new_w = math.ceil(w / pw) * pw
        # Pad on the same device/dtype as the input to avoid device mismatch errors
        padded_img = torch.zeros(c, new_h, new_w, dtype=img.dtype, device=img.device)
        padded_img[:, :h, :w] = img
        img = padded_img
c, h, w = img.shape
# Calculate number of patches in each dimension
nh, nw = h // ph, w // pw
# Reshape image to patches [nh, nw, ph, pw, c]
patches = img.view(c, nh, ph, nw, pw).permute(1, 3, 2, 4, 0).reshape(nh * nw, ph * pw * c)
if include_info:
# Create coordinate indices
y_idx, x_idx = torch.meshgrid(torch.arange(nh), torch.arange(nw), indexing='ij')
# Stack into a single coords tensor [N, 2] with (y, x) order
coord = torch.stack([y_idx.reshape(-1), x_idx.reshape(-1)], dim=1)
        # Create validity mask (all True for real, non-padding patches)
        valid = torch.ones(nh * nw, dtype=torch.bool)
return patches, coord, valid
return patches
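
# Shape sketch (illustrative): a (3, 224, 288) tensor with 16x16 patches gives a
# 14x18 grid, so patches is [252, 768] (768 = 16*16*3), coord [252, 2], valid [252]:
#   >>> x = torch.zeros(3, 224, 288)
#   >>> patches, coord, valid = patchify(x, (16, 16))
#   >>> patches.shape, coord.shape, valid.shape
#   (torch.Size([252, 768]), torch.Size([252, 2]), torch.Size([252]))
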
class Patchify(torch.nn.Module):
"""Transform an image into patches with corresponding coordinates and type indicators."""
def __init__(self, patch_size):
super().__init__()
self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
def forward(self, img):
"""
Args:
img: A PIL Image or tensor of shape [C, H, W]
        Returns:
            A dictionary containing:
                - patches: Tensor of shape [N, P*P*C] where N is the number of patches
                - patch_coord: Tensor of shape [N, 2] with (y, x) coordinates
                - patch_valid: Boolean tensor of shape [N], True for valid (non-padding) patches
"""
if isinstance(img, Image.Image):
# Convert PIL Image to tensor [C, H, W]
img = transforms.functional.to_tensor(img)
patches, coord, valid = patchify(img, self.patch_size)
return {
'patches': patches,
'patch_coord': coord,
'patch_valid': valid,
}
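
# Putting it together (illustrative, assumed values): a NaFlex-style pipeline that
# resizes/crops to a sequence-feasible size, then emits a patch dict ready for
# padding-based collation:
#   >>> pipeline = transforms.Compose([
#   ...     RandomResizedCropToSequence(patch_size=16, max_seq_len=256),
#   ...     Patchify(16),
#   ... ])
#   >>> sample = pipeline(Image.new('RGB', (640, 480)))
#   >>> sorted(sample.keys())
#   ['patch_coord', 'patch_valid', 'patches']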