""" NaFlex (NaViT + FlexiViT) Transforms and Collation

Implements PyTorch versions of the transforms described in the NaViT and FlexiViT papers:
- NaViT: https://arxiv.org/abs/2307.14995
- FlexiViT: https://arxiv.org/abs/2212.08013

Enables variable resolution/aspect ratio image handling with efficient patching.
"""
import math
import random
import warnings
from typing import List, Optional, Sequence, Tuple, Union

import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode

from .transforms import str_to_interp_mode, crop_or_pad, center_crop_or_pad


def _patch_hw(patch_size: Union[int, Sequence[int]]) -> Tuple[int, int]:
    """Normalize `patch_size` (an int or a (patch_h, patch_w) pair) into a (patch_h, patch_w) tuple.

    Raises:
        ValueError: If a sequence is given that does not have exactly two elements.
    """
    if isinstance(patch_size, int):
        return patch_size, patch_size
    if len(patch_size) != 2:
        raise ValueError("patch_size tuple must have exactly two elements (patch_h, patch_w).")
    patch_h, patch_w = patch_size
    return patch_h, patch_w


def _num_patches(height: int, width: int, patch_h: int, patch_w: int) -> int:
    """Return how many patches an image of (height, width) produces once padded to whole patches.

    Partial patches count as whole ones because `patchify` (with its default ``pad=True``)
    zero-pads the image up to a multiple of the patch size. For sizes that are already
    patch multiples this equals the plain floor-division count.
    """
    return math.ceil(height / patch_h) * math.ceil(width / patch_w)


def get_image_size_for_seq(
        image_hw,
        patch_size=16,
        max_seq_len=1024,
        divisible_by_patch=True,
        max_ratio=None,
        eps=1e-5,
):
    """
    Determine scaling ratio and image size so that when `image_hw` is scaled by 'ratio',
    the total number of resulting patches does not exceed 'max_seq_len'.

    - Patch size can be an integer (square patch) or a tuple (patch_h, patch_w).
    - Optionally cap the ratio at `max_ratio` to prevent upsampling beyond
      a certain multiple of the original size.
    - When `divisible_by_patch` is False, partial patches are counted as whole patches,
      matching the token count `patchify(..., pad=True)` produces.

    Args:
        image_hw (tuple or list of int): (height, width) of the original image.
        patch_size (int or tuple[int, int]): If int, patch is square. If tuple,
            patch is rectangular (patch_h, patch_w).
        max_seq_len (int): Maximum allowed sequence length for the resulting image.
        divisible_by_patch (bool): If True, the resulting image height and width
            must be multiples of patch_size.
        max_ratio (float or None): If provided, the scaling ratio found by the
            binary search will be clamped to min(found_ratio, max_ratio). Set
            max_ratio=1.0 to ensure no upsampling beyond original size.
        eps (float): Small number for binary search convergence.

    Returns:
        ratio (float): Found scaling ratio (capped by `max_ratio` if provided).
        target_hw (tuple of int): Target (height, width) after scaling.

    Raises:
        ValueError: On a malformed or non-positive `patch_size`, a `max_seq_len` below 1,
            or if the search cannot find a usable ratio.
    """
    patch_h, patch_w = _patch_hw(patch_size)

    # Safety checks
    if patch_h <= 0 or patch_w <= 0:
        raise ValueError("patch_size dimensions must be positive.")
    if max_seq_len < 1:
        # Every target spans at least one patch, so nothing could ever fit.
        raise ValueError("max_seq_len must be at least 1.")

    def prepare_target_hw(ratio):
        """Scale image_hw by ratio and optionally round dimensions up to patch multiples."""
        scaled_h = image_hw[0] * ratio
        scaled_w = image_hw[1] * ratio

        if divisible_by_patch:
            scaled_h = patch_h * math.ceil(scaled_h / patch_h)
            scaled_w = patch_w * math.ceil(scaled_w / patch_w)

        # Ensure at least one patch in each dimension
        scaled_h = int(max(scaled_h, patch_h))
        scaled_w = int(max(scaled_w, patch_w))
        return scaled_h, scaled_w

    def is_feasible(ratio):
        """Check if scaling by 'ratio' keeps the (padded) patch count within max_seq_len."""
        t_h, t_w = prepare_target_hw(ratio)
        return _num_patches(t_h, t_w, patch_h, patch_w) <= max_seq_len

    # Binary search for the largest feasible ratio. Feasibility is monotone in the ratio
    # because both the rounding and the patch count are non-decreasing in it.
    lb = eps / 10.0
    rb = 100.0
    while (rb - lb) >= eps:
        mid = (lb + rb) / 2.0
        if is_feasible(mid):
            lb = mid
        else:
            rb = mid

    ratio = lb

    # If max_ratio is provided, clamp it to prevent upsampling beyond that threshold
    if max_ratio is not None:
        ratio = min(ratio, max_ratio)

    # Final checks. lb only ever moves to midpoints strictly below rb (<= 100.0), so the
    # second check is defensive.
    if ratio <= eps:
        raise ValueError("Binary search failed - image might be too large?")
    if ratio >= 100.0:
        raise ValueError("Binary search failed - image might be too small?")

    # Prepare the final target dimensions with the possibly clamped ratio
    target_hw = prepare_target_hw(ratio)
    return ratio, target_hw


def _shrink_to_seq_len(height, width, patch_h, patch_w, max_seq_len, divisible_by_patch):
    """Uniformly shrink (height, width) so its padded patch count fits within `max_seq_len`.

    The size is returned unchanged when it already fits. Otherwise the result spans at
    least one patch per side, respects `max_seq_len`, and is a multiple of the patch
    size when `divisible_by_patch` is set. Assumes `max_seq_len >= 1`.
    """
    grid_h = math.ceil(height / patch_h)
    grid_w = math.ceil(width / patch_w)
    if grid_h * grid_w <= max_seq_len:
        return height, width

    scale = math.sqrt(max_seq_len / (grid_h * grid_w))
    # Floor the scaled grid so the patch count can only shrink; rounding up here could
    # push it back over the limit.
    grid_h = max(1, math.floor(grid_h * scale))
    grid_w = max(1, math.floor(grid_w * scale))
    # Raising a vanishing side to one patch can still overflow for extreme aspect
    # ratios (or float rounding); trim the longer grid side to restore the limit.
    if grid_h * grid_w > max_seq_len:
        if grid_h <= grid_w:
            grid_w = max(1, max_seq_len // grid_h)
        else:
            grid_h = max(1, max_seq_len // grid_w)

    if divisible_by_patch:
        return grid_h * patch_h, grid_w * patch_w

    # Scale the pixel size directly, capped by the grid so padding cannot add patches.
    new_h = min(max(patch_h, int(height * scale)), grid_h * patch_h)
    new_w = min(max(patch_w, int(width * scale)), grid_w * patch_w)
    return new_h, new_w


_RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))


def _resolve_interpolation(interpolation):
    """Convert an interpolation spec into a mode or a tuple of modes.

    'random' maps to the bilinear/bicubic pair, any other string goes through
    `str_to_interp_mode`, and non-string values (e.g. `InterpolationMode`) pass through.
    """
    if isinstance(interpolation, str):
        if interpolation == 'random':
            return _RANDOM_INTERPOLATION
        return str_to_interp_mode(interpolation)
    return interpolation


def _choose_interpolation(interpolation):
    """Pick one mode at random from a tuple/list of modes; return a single mode unchanged."""
    if isinstance(interpolation, (tuple, list)):
        return random.choice(interpolation)
    return interpolation


class ResizeToSequence(torch.nn.Module):
    """Resize image to fit within a maximum sequence length constraint when patchified.

    This maintains aspect ratio while ensuring the resulting image, when divided into patches,
    will not exceed the specified maximum sequence length.
    """

    def __init__(
            self,
            patch_size: Union[int, Tuple[int, int]],
            max_seq_len: int = 1024,
            divisible_by_patch: bool = True,
            max_ratio: Optional[float] = None,
            interpolation: Union[str, InterpolationMode] = 'bicubic',
    ):
        """
        Args:
            patch_size: Size of patches (int or tuple of (patch_h, patch_w)).
            max_seq_len: Maximum allowed number of patches in the resized image.
            divisible_by_patch: If True, output dimensions are multiples of the patch size.
            max_ratio: Optional cap on the scaling ratio (1.0 disables upsampling).
            interpolation: Interpolation mode, mode name, or 'random' (bilinear/bicubic).
        """
        super().__init__()
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len
        self.divisible_by_patch = divisible_by_patch
        self.max_ratio = max_ratio
        self.interpolation = _resolve_interpolation(interpolation)

    def forward(self, img):
        """Resize image to maintain aspect ratio and fit sequence constraint."""
        _, h, w = F.get_dimensions(img)

        _, target_hw = get_image_size_for_seq(
            (h, w),
            self.patch_size,
            self.max_seq_len,
            divisible_by_patch=self.divisible_by_patch,
            max_ratio=self.max_ratio,
        )

        interpolation = _choose_interpolation(self.interpolation)
        return F.resize(img, target_hw, interpolation=interpolation, antialias=True)

    def __repr__(self) -> str:
        return (f"{self.__class__.__name__}(patch_size={self.patch_size}, "
                f"max_seq_len={self.max_seq_len}, "
                f"divisible_by_patch={self.divisible_by_patch}, "
                f"max_ratio={self.max_ratio})")


class ResizeKeepRatioToSequence(torch.nn.Module):
    """
    Resize and Keep Aspect Ratio, adapted to fit sequence length constraints.
    """

    def __init__(
            self,
            patch_size=16,
            max_sequence_len=1024,
            divisible_by_patch=True,
            longest=0.,
            interpolation='bilinear',
            random_scale_prob=0.,
            random_scale_range=(0.85, 1.05),
            random_scale_area=False,
            random_aspect_prob=0.,
            random_aspect_range=(0.9, 1.11),
            max_ratio=None,
    ):
        """
        Args:
            patch_size: Size of patches (int or tuple of (patch_h, patch_w))
            max_sequence_len: Maximum allowed sequence length for the resulting image
            divisible_by_patch: If True, ensure dimensions are divisible by patch_size
            longest: Blend weight in [0, 1] between the smaller (0) and larger (1) of the
                per-axis ratios implied by the sequence-constrained target size
            interpolation: Interpolation mode, mode name, or 'random' (bilinear/bicubic)
            random_scale_prob: Probability of applying random scaling
            random_scale_range: Range for random scaling factor (min, max)
            random_scale_area: If True, a sampled factor f is applied as 1/sqrt(f), i.e. it
                acts like an area fraction (f < 1 enlarges the image)
            random_aspect_prob: Probability of applying random aspect ratio jittering
            random_aspect_range: Range for random aspect ratio (min, max), sampled log-uniformly
            max_ratio: Maximum allowed scaling ratio
        """
        super().__init__()
        self.patch_size = patch_size
        self.max_sequence_len = max_sequence_len
        self.divisible_by_patch = divisible_by_patch
        self.longest = float(longest)
        self.interpolation = _resolve_interpolation(interpolation)
        self.random_scale_prob = random_scale_prob
        self.random_scale_range = random_scale_range
        self.random_scale_area = random_scale_area
        self.random_aspect_prob = random_aspect_prob
        self.random_aspect_range = random_aspect_range
        self.max_ratio = max_ratio

    @staticmethod
    def get_params(
            img,
            patch_size,
            max_sequence_len,
            divisible_by_patch,
            longest,
            random_scale_prob=0.,
            random_scale_range=(1.0, 1.33),
            random_scale_area=False,
            random_aspect_prob=0.,
            random_aspect_range=(0.9, 1.11),
            max_ratio=None,
    ):
        """Compute the output [height, width] for `img`.

        Returns:
            A two-element list [height, width] that spans at least one patch per side,
            whose padded patch count does not exceed `max_sequence_len`, and that is a
            multiple of the patch size when `divisible_by_patch` is set.
        """
        img_h, img_w = F.get_dimensions(img)[1:]

        # Step 1: maximum dimensions allowed by the sequence length constraint
        _, target_hw = get_image_size_for_seq(
            (img_h, img_w),
            patch_size,
            max_sequence_len,
            divisible_by_patch,
            max_ratio,
        )
        target_h, target_w = target_hw
        ph, pw = _patch_hw(patch_size)

        # Per-axis ratios implied by the constraint, blended by `longest`
        ratio_h = target_h / img_h
        ratio_w = target_w / img_w
        ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)

        # Random scaling
        if random_scale_prob > 0 and random.random() < random_scale_prob:
            ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
            if random_scale_area:
                # Treat the factor as an area fraction: < 1.0 zooms in (like an RRC area crop)
                ratio_factor = 1. / math.sqrt(ratio_factor)
            ratio_factor = (ratio_factor, ratio_factor)
        else:
            ratio_factor = (1., 1.)

        # Random aspect jitter, split evenly between the two axes
        if random_aspect_prob > 0 and random.random() < random_aspect_prob:
            log_aspect = (math.log(random_aspect_range[0]), math.log(random_aspect_range[1]))
            aspect_factor = math.sqrt(math.exp(random.uniform(*log_aspect)))
            ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)

        size = [round(dim * ratio * f) for dim, f in zip((img_h, img_w), ratio_factor)]

        # At least one patch per side
        size[0] = max(size[0], ph)
        size[1] = max(size[1], pw)

        if divisible_by_patch:
            size[0] = ph * math.ceil(size[0] / ph)
            size[1] = pw * math.ceil(size[1] / pw)

        # Random scaling/aspect (and the ceil above) can overshoot the limit; shrink
        # back so that the limit is strictly respected.
        new_h, new_w = _shrink_to_seq_len(size[0], size[1], ph, pw, max_sequence_len, divisible_by_patch)
        return [new_h, new_w]

    def forward(self, img):
        """
        Resize the image with aspect ratio preservation and sequence length constraints.
        """
        size = self.get_params(
            img,
            self.patch_size,
            self.max_sequence_len,
            self.divisible_by_patch,
            self.longest,
            self.random_scale_prob,
            self.random_scale_range,
            self.random_scale_area,
            self.random_aspect_prob,
            self.random_aspect_range,
            self.max_ratio,
        )
        interpolation = _choose_interpolation(self.interpolation)
        # antialias=True for consistency with the other resize paths in this module
        return F.resize(img, size, interpolation, antialias=True)

    def __repr__(self):
        if isinstance(self.interpolation, (tuple, list)):
            interpolate_str = "random"
        else:
            interpolate_str = str(self.interpolation)
        return (f"{self.__class__.__name__}(patch_size={self.patch_size}, "
                f"max_sequence_len={self.max_sequence_len}, "
                f"longest={self.longest:.3f}, "
                f"interpolation={interpolate_str}, "
                f"random_scale_prob={self.random_scale_prob:.3f}, "
                f"random_aspect_prob={self.random_aspect_prob:.3f}, "
                f"max_ratio={self.max_ratio})")


class CenterCropToSequence(torch.nn.Module):
    """Center crop the image such that the resulting patch sequence length meets constraints."""

    def __init__(
            self,
            patch_size: int,
            max_seq_len: int,
            divisible_by_patch: bool = True,
            fill: Union[int, Tuple[int, int, int]] = 0,
            padding_mode: str = 'constant'
    ):
        """
        Args:
            patch_size: Size of patches (int or tuple of (patch_h, patch_w))
            max_seq_len: Maximum allowed number of patches in the output
            divisible_by_patch: If True, output dimensions are multiples of patch_size
            fill: Fill value forwarded to `center_crop_or_pad`
            padding_mode: Padding mode forwarded to `center_crop_or_pad`
        """
        super().__init__()
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len
        self.divisible_by_patch = divisible_by_patch
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """Center crop the image to maintain aspect ratio and fit sequence constraint."""
        _, h, w = F.get_dimensions(img)
        # No max_ratio here, so the target may exceed the input size; center_crop_or_pad
        # is expected to pad in that case (NOTE(review): confirm in .transforms).
        _, target_hw = get_image_size_for_seq(
            (h, w),
            self.patch_size,
            self.max_seq_len,
            divisible_by_patch=self.divisible_by_patch,
        )
        return center_crop_or_pad(img, target_hw, fill=self.fill, padding_mode=self.padding_mode)

    def __repr__(self) -> str:
        return (f"{self.__class__.__name__}(patch_size={self.patch_size}, "
                f"max_seq_len={self.max_seq_len}, "
                f"divisible_by_patch={self.divisible_by_patch})")


class RandomCropToSequence(torch.nn.Module):
    """Randomly crop and/or pad the image to fit sequence length constraints.

    This maintains aspect ratio while ensuring the resulting image, when divided into patches,
    will not exceed the specified maximum sequence length. Similar to CenterCropToSequence
    but with randomized positioning.
    """

    def __init__(
            self,
            patch_size: int,
            max_sequence_len: int,
            divisible_by_patch: bool = True,
            fill: Union[int, Tuple[int, int, int]] = 0,
            padding_mode: str = 'constant'
    ):
        """
        Args:
            patch_size: Size of patches (int or tuple of (patch_h, patch_w))
            max_sequence_len: Maximum allowed sequence length for the resulting image
            divisible_by_patch: If True, resulting image dimensions will be multiples of patch_size
            fill: Fill value for padding
            padding_mode: Padding mode ('constant', 'edge', 'reflect', 'symmetric')
        """
        super().__init__()
        self.patch_size = patch_size
        self.max_sequence_len = max_sequence_len
        self.divisible_by_patch = divisible_by_patch
        self.fill = fill
        self.padding_mode = padding_mode

    @staticmethod
    def get_params(img, target_size):
        """Get a random (top, left) offset for crop/pad.

        Positive offsets locate a crop inside the image; negative offsets (when the target
        is larger than the image) are passed to `crop_or_pad` to indicate padding.
        """
        _, image_height, image_width = F.get_dimensions(img)
        delta_height = image_height - target_size[0]
        delta_width = image_width - target_size[1]

        # Handle both positive (crop) and negative (pad) deltas
        if delta_height == 0:
            top = 0
        else:
            top = int(math.copysign(random.randint(0, abs(delta_height)), delta_height))

        if delta_width == 0:
            left = 0
        else:
            left = int(math.copysign(random.randint(0, abs(delta_width)), delta_width))

        return top, left

    def forward(self, img):
        """Randomly crop or pad the image to maintain aspect ratio and fit sequence constraint."""
        _, img_h, img_w = F.get_dimensions(img)

        # max_ratio=1.0 prevents upscaling - we only want to crop or keep the current size
        _, target_hw = get_image_size_for_seq(
            (img_h, img_w),
            self.patch_size,
            self.max_sequence_len,
            self.divisible_by_patch,
            max_ratio=1.0,
        )

        top, left = self.get_params(img, target_hw)

        return crop_or_pad(
            img,
            top=top,
            left=left,
            height=target_hw[0],
            width=target_hw[1],
            fill=self.fill,
            padding_mode=self.padding_mode,
        )

    def __repr__(self) -> str:
        return (f"{self.__class__.__name__}(patch_size={self.patch_size}, "
                f"max_sequence_len={self.max_sequence_len}, "
                f"divisible_by_patch={self.divisible_by_patch})")


def _validate_range(value, name, length=2):
    """Check that `value` is a sequence of `length` items; swap (with a warning) if reversed.

    Returns the value unchanged when ordered, otherwise a swapped (value[1], value[0]) tuple.

    Raises:
        ValueError: If `value` is not a sequence of the expected length.
    """
    if not isinstance(value, Sequence) or len(value) != length:
        raise ValueError(f"{name} should be a sequence of length {length}.")

    if value[0] > value[1]:
        warnings.warn(f"{name.capitalize()} range reversed. Swapping.")
        return value[1], value[0]

    return value


class RandomResizedCropToSequence(torch.nn.Module):
    """
    Randomly crop the input image to a subregion with varying area and aspect ratio
    (relative to the original), then resize that crop to a target size. The target size
    is determined such that patchifying the resized image (with `patch_size`)
    does not exceed `max_seq_len` patches, while maintaining the aspect ratio of the crop.

    This combines aspects of torchvision's RandomResizedCrop with sequence length constraints.

    Args:
        patch_size (int or tuple[int, int]):
            Patch dimensions (patch_h, patch_w) for sequence length calculation.
        max_seq_len (int):
            Maximum number of patches allowed in the final image.
        scale (tuple[float, float]):
            Range (min, max) of area fraction of the original image to crop.
        ratio (tuple[float, float]):
            Range (min, max) of aspect ratio *multipliers* for the crop, relative
            to the original image's aspect ratio. E.g., (0.75, 1.333) means the
            crop's aspect ratio will be sampled between 0.75*orig_ar and 1.333*orig_ar.
            Uses log-uniform sampling.
        interpolation (str or InterpolationMode):
            Interpolation mode for resizing. Can be 'bilinear', 'bicubic', 'nearest',
            or 'random' (chooses between bilinear and bicubic).
            Defaults to 'bicubic'.
        divisible_by_patch (bool):
            If True, the final image height and width will be multiples of the
            respective patch dimensions. Defaults to True.
        max_ratio (float, optional):
            An optional upper limit on the scaling ratio applied during resizing.
            Prevents excessive upsampling of the initial crop. `max_ratio=1.0`
            prevents any upsampling beyond the cropped size. Defaults to None (no limit).
        final_scale_range (tuple[float, float], optional):
            If provided, applies an *additional* random scaling factor to the
            final target size. The factor is sampled uniformly from this range,
            and multiplied by the size determined by `get_image_size_for_seq`.
            E.g., (0.8, 1.0) means the final size will be between 80% and 100%
            of the maximum feasible size. Defaults to None (use maximum feasible size).
        attempts (int):
            Number of attempts to sample a valid crop geometry before falling back
            to a center crop strategy. Defaults to 10.
    """

    def __init__(
            self,
            patch_size: Union[int, Tuple[int, int]] = 16,
            max_seq_len: int = 1024,
            scale: Tuple[float, float] = (0.08, 1.0),
            ratio: Tuple[float, float] = (.8, 1.25),
            interpolation: Union[str, InterpolationMode] = 'bicubic',
            divisible_by_patch: bool = True,
            max_ratio: Optional[float] = None,
            final_scale_range: Optional[Tuple[float, float]] = None,
            attempts: int = 10,
    ):
        super().__init__()
        self.patch_h, self.patch_w = _patch_hw(patch_size)
        self.max_seq_len = max_seq_len
        self.scale = _validate_range(scale, "scale")
        self.ratio = _validate_range(ratio, "ratio")
        self.divisible_by_patch = divisible_by_patch
        self.max_ratio = max_ratio
        self.final_scale_range = final_scale_range
        self.attempts = attempts
        self.interpolation = _resolve_interpolation(interpolation)

        if self.final_scale_range is not None:
            self.final_scale_range = _validate_range(self.final_scale_range, "final_scale_range")
            if not (0.0 <= self.final_scale_range[0] <= self.final_scale_range[1] <= 1.0):
                warnings.warn("final_scale_range values should ideally be between 0.0 and 1.0.")

    @staticmethod
    def get_params(
            img: torch.Tensor,
            scale: Tuple[float, float],
            ratio: Tuple[float, float],
            crop_attempts: int = 10,
            patch_h: int = 16,
            patch_w: int = 16,
            max_seq_len: int = 1024,
            divisible_by_patch: bool = True,
            max_ratio: Optional[float] = None,
            final_scale_range: Optional[Tuple[float, float]] = None,
            interpolation: Union[List[InterpolationMode], InterpolationMode] = _RANDOM_INTERPOLATION,
    ) -> Tuple[Tuple[int, int, int, int], Tuple[int, int], InterpolationMode]:
        """
        Get parameters for a random sized crop relative to image aspect ratio.

        Returns:
            (top, left, crop_h, crop_w): the crop box in input pixels.
            final_size: (height, width) to resize the crop to.
            interpolation: the single interpolation mode to use.

        Raises:
            ValueError: If the image has a non-positive dimension.
        """
        _, height, width = F.get_dimensions(img)
        if height <= 0 or width <= 0:
            raise ValueError(f"Input image must have positive dimensions, got H={height}, W={width}")

        area = height * width
        orig_aspect = width / height
        log_ratio = (math.log(ratio[0]), math.log(ratio[1]))

        for _ in range(crop_attempts):
            target_area = area * random.uniform(scale[0], scale[1])
            aspect_ratio = orig_aspect * math.exp(random.uniform(log_ratio[0], log_ratio[1]))

            # target_area = crop_w * crop_h, aspect_ratio = crop_w / crop_h
            crop_h = int(round(math.sqrt(target_area / aspect_ratio)))
            crop_w = int(round(math.sqrt(target_area * aspect_ratio)))

            if 0 < crop_w <= width and 0 < crop_h <= height:
                top = random.randint(0, height - crop_h)
                left = random.randint(0, width - crop_w)
                break
        else:
            # Fallback: center crop, clamping the aspect ratio into the allowed range
            min_aspect_ratio = orig_aspect * ratio[0]
            max_aspect_ratio = orig_aspect * ratio[1]

            if orig_aspect < min_aspect_ratio:
                # Original is narrower than target min, clamp width
                crop_w = width
                crop_h = min(int(round(crop_w / min_aspect_ratio)), height)
            elif orig_aspect > max_aspect_ratio:
                # Original is wider than target max, clamp height
                crop_h = height
                crop_w = min(int(round(crop_h * max_aspect_ratio)), width)
            else:
                # Aspect ratio is within range, take the full image
                crop_w = width
                crop_h = height

            crop_h = max(1, crop_h)
            crop_w = max(1, crop_w)
            top = (height - crop_h) // 2
            left = (width - crop_w) // 2

        # Max feasible size for the *cropped* region
        feasible_ratio, feasible_size = get_image_size_for_seq(
            (crop_h, crop_w),
            patch_size=(patch_h, patch_w),
            max_seq_len=max_seq_len,
            divisible_by_patch=divisible_by_patch,
            max_ratio=max_ratio,
        )

        final_size = feasible_size
        if final_scale_range is not None:
            min_sc, max_sc = final_scale_range
            scale_factor = random.uniform(min_sc, max_sc)
            scale_factor = min(max(scale_factor, 0.0), 1.0)  # Clamp factor just in case

            # feasible_ratio already accounts for the max_ratio clamp, if any
            raw_h = crop_h * feasible_ratio * scale_factor
            raw_w = crop_w * feasible_ratio * scale_factor

            if divisible_by_patch:
                # ceil keeps at least the scaled size while staying on patch multiples
                target_h = patch_h * math.ceil(raw_h / patch_h)
                target_w = patch_w * math.ceil(raw_w / patch_w)
            else:
                target_h = int(round(raw_h))
                target_w = int(round(raw_w))

            target_h = max(target_h, patch_h)
            target_w = max(target_w, patch_w)
            scaled_size = (target_h, target_w)

            # Rounding can, in principle, push the scaled size past the limit
            if _num_patches(target_h, target_w, patch_h, patch_w) > max_seq_len:
                warnings.warn(
                    f"Final scale randomization ({scale_factor:.2f}) resulted in size {scaled_size} "
                    f"exceeding max_seq_len={max_seq_len} after rounding. "
                    f"Reverting to feasible size {feasible_size}."
                )
            else:
                final_size = scaled_size

        interpolation = _choose_interpolation(interpolation)

        return (top, left, crop_h, crop_w), final_size, interpolation

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Randomly crop `img` and resize the crop to a size that fits `max_seq_len`."""
        crop_params, final_size, interpolation = self.get_params(
            img,
            scale=self.scale,
            ratio=self.ratio,
            crop_attempts=self.attempts,
            patch_h=self.patch_h,
            patch_w=self.patch_w,
            max_seq_len=self.max_seq_len,
            divisible_by_patch=self.divisible_by_patch,
            max_ratio=self.max_ratio,
            final_scale_range=self.final_scale_range,
            interpolation=self.interpolation,
        )
        top, left, crop_h, crop_w = crop_params

        return F.resized_crop(
            img,
            top=top,
            left=left,
            height=crop_h,
            width=crop_w,
            size=final_size,
            interpolation=interpolation,
            antialias=True,
        )

    def __repr__(self) -> str:
        if isinstance(self.interpolation, (tuple, list)):
            interpolate_str = ', '.join(str(m).split('.')[-1] for m in self.interpolation)
        else:
            interpolate_str = str(self.interpolation)
        format_string = self.__class__.__name__ + '('
        format_string += f"patch_size=({self.patch_h}, {self.patch_w})"
        format_string += f", max_seq_len={self.max_seq_len}"
        format_string += f", scale={self.scale}"
        format_string += f", ratio={self.ratio}"
        format_string += f", interpolation=[{interpolate_str}]"
        format_string += f", divisible_by_patch={self.divisible_by_patch}"
        format_string += f", max_ratio={self.max_ratio}"
        format_string += f", final_scale_range={self.final_scale_range}"
        format_string += f", attempts={self.attempts}"
        format_string += ')'
        return format_string


def patchify(
        img: torch.Tensor,
        patch_size: Tuple[int, int],
        pad: bool = True,
        include_info: bool = True,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Split a [C, H, W] tensor into flattened, non-overlapping patches.

    Args:
        img: Image tensor of shape [C, H, W].
        patch_size: (patch_h, patch_w).
        pad: If True, zero-pad H and W up to patch multiples. If False, H and W must
            already be multiples of the patch size (the reshape fails otherwise).
        include_info: If True, also return patch coordinates and a validity mask.

    Returns:
        patches of shape [N, patch_h * patch_w * C] (row-major over the patch grid,
        each patch flattened as (ph, pw, C)); plus, when `include_info`, coords [N, 2]
        in (y, x) order and a bool valid mask [N] of all True. All outputs live on
        `img.device`.
    """
    c, h, w = img.shape
    ph, pw = patch_size

    if pad and (h % ph != 0 or w % pw != 0):
        new_h = math.ceil(h / ph) * ph
        new_w = math.ceil(w / pw) * pw
        # new_zeros keeps both dtype and device of the input
        padded_img = img.new_zeros(c, new_h, new_w)
        padded_img[:, :h, :w] = img
        img = padded_img
        c, h, w = img.shape

    nh, nw = h // ph, w // pw
    # [C, nh, ph, nw, pw] -> [nh, nw, ph, pw, C] -> [nh * nw, ph * pw * C]
    patches = img.view(c, nh, ph, nw, pw).permute(1, 3, 2, 4, 0).reshape(nh * nw, ph * pw * c)

    if include_info:
        y_idx, x_idx = torch.meshgrid(
            torch.arange(nh, device=img.device),
            torch.arange(nw, device=img.device),
            indexing='ij',
        )
        coord = torch.stack([y_idx.reshape(-1), x_idx.reshape(-1)], dim=1)
        valid = torch.ones(nh * nw, dtype=torch.bool, device=img.device)
        return patches, coord, valid

    return patches


class Patchify(torch.nn.Module):
    """Transform an image into patches with corresponding coordinates and type indicators."""

    def __init__(self, patch_size):
        """
        Args:
            patch_size: int (square) or a (patch_h, patch_w) tuple/list.
        """
        super().__init__()
        if isinstance(patch_size, (tuple, list)):
            self.patch_size = tuple(patch_size)
        else:
            self.patch_size = (patch_size, patch_size)

    def forward(self, img):
        """
        Args:
            img: A PIL Image or tensor of shape [C, H, W]

        Returns:
            A dictionary containing:
                - patches: Tensor of shape [N, P*P*C] where N is the number of patches
                - patch_coord: Tensor of shape [N, 2] with (y, x) coordinates
                - patch_valid: Valid indicator (all 1s for non-padding patches)
        """
        if isinstance(img, Image.Image):
            # Convert PIL Image to tensor [C, H, W]
            img = transforms.functional.to_tensor(img)

        patches, coord, valid = patchify(img, self.patch_size)

        return {
            'patches': patches,
            'patch_coord': coord,
            'patch_valid': valid,
        }