diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py
index c739291b..dab7acc9 100644
--- a/timm/layers/patch_embed.py
+++ b/timm/layers/patch_embed.py
@@ -10,7 +10,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 """
 import logging
 import math
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn as nn
@@ -180,7 +180,8 @@ class PatchEmbedWithSize(PatchEmbed):
         return x, feat_size
 
 
-def resample_patch_embed(
+# FIXME to remove, keeping for comparison for now
+def resample_patch_embed_old(
         patch_embed,
         new_size: List[int],
         interpolation: str = 'bicubic',
@@ -250,6 +251,191 @@ def resample_patch_embed(
     return patch_embed
 
 
+DTYPE_INTERMEDIATE = torch.float32
+
+
+def _compute_resize_matrix(
+        old_size: Tuple[int, int],
+        new_size: Tuple[int, int],
+        interpolation: str,
+        antialias: bool,
+        device: torch.device,
+        dtype: torch.dtype = DTYPE_INTERMEDIATE
+) -> torch.Tensor:
+    """Computes the resize matrix basis vectors and interpolates them to new_size."""
+    old_h, old_w = old_size
+    new_h, new_w = new_size
+    old_total = old_h * old_w
+    new_total = new_h * new_w
+
+    eye_matrix = torch.eye(old_total, device=device, dtype=dtype)
+    basis_vectors_batch = eye_matrix.reshape(old_total, 1, old_h, old_w)
+
+    resized_basis_vectors_batch = F.interpolate(
+        basis_vectors_batch,
+        size=new_size,
+        mode=interpolation,
+        antialias=antialias,
+        align_corners=False
+    )  # Output shape: (old_total, 1, new_h, new_w)
+
+    resize_matrix = resized_basis_vectors_batch.squeeze(1).reshape(old_total, new_total).T
+    return resize_matrix  # Shape: (new_total, old_total)
+
+
+def _compute_pinv_for_resampling(resize_matrix: torch.Tensor) -> torch.Tensor:
+    """Calculates the pseudoinverse matrix used for the resampling operation."""
+    pinv_matrix = torch.linalg.pinv(resize_matrix.T)  # Shape: (new_total, old_total)
+    return pinv_matrix
+
+
+def _apply_resampling(
+        patch_embed: torch.Tensor,
+        pinv_matrix: torch.Tensor,
+        new_size_tuple: Tuple[int, int],
+        orig_dtype: torch.dtype,
+        intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE
+) -> torch.Tensor:
+    """Applies the precomputed pinv_matrix to resample the patch_embed tensor."""
+    try:
+        from torch import vmap
+    except ImportError:
+        # fall back to functorch for older PyTorch versions
+        from functorch import vmap
+
+    def resample_kernel(kernel: torch.Tensor) -> torch.Tensor:
+        kernel_flat = kernel.reshape(-1).to(intermediate_dtype)
+        resampled_kernel_flat = pinv_matrix @ kernel_flat
+        return resampled_kernel_flat.reshape(new_size_tuple)
+
+    # vmap over out_ch then in_ch so each 2D kernel is resampled independently
+    resample_kernel_vmap = vmap(vmap(resample_kernel, in_dims=0, out_dims=0), in_dims=0, out_dims=0)
+    patch_embed_float = patch_embed.to(intermediate_dtype)
+    resampled_patch_embed = resample_kernel_vmap(patch_embed_float)
+    return resampled_patch_embed.to(orig_dtype)
+
+
+def resample_patch_embed(
+        patch_embed: torch.Tensor,
+        new_size: List[int],
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+        verbose: bool = False,
+):
+    """ Resample patch embedding weights to new_size (standalone; computes the resize matrix on each call).
+    """
+    assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_c, in_c, h, w)"
+    assert len(new_size) == 2, "New shape should only be hw (height, width)"
+
+    old_size_tuple: Tuple[int, int] = tuple(patch_embed.shape[-2:])
+    new_size_tuple: Tuple[int, int] = tuple(new_size)
+
+    if old_size_tuple == new_size_tuple:
+        return patch_embed
+
+    device = patch_embed.device
+    orig_dtype = patch_embed.dtype
+
+    resize_mat = _compute_resize_matrix(
+        old_size_tuple, new_size_tuple, interpolation, antialias, device, DTYPE_INTERMEDIATE
+    )
+    pinv_matrix = _compute_pinv_for_resampling(resize_mat)
+    resampled_patch_embed = _apply_resampling(
+        patch_embed, pinv_matrix, new_size_tuple, orig_dtype, DTYPE_INTERMEDIATE
+    )
+    return resampled_patch_embed
+
+
+class PatchEmbedResamplerFixedOrigSize(nn.Module):
+    """
+    Resample patch embedding weights from a fixed original size,
+    caching the pseudoinverse matrix based on the target size.
+    """
+    def __init__(
+            self,
+            orig_size: Tuple[int, int],
+            interpolation: str = 'bicubic',
+            antialias: bool = True
+    ):
+        """
+        Args:
+            orig_size (Tuple[int, int]): The expected original (height, width) of input patch_embed tensors.
+            interpolation (str): Interpolation mode.
+            antialias (bool): Use anti-aliasing filter in resize.
+        """
+        super().__init__()
+        assert isinstance(orig_size, tuple) and len(orig_size) == 2, \
+            "`orig_size` must be a tuple of (height, width)"
+        self.orig_size = orig_size  # expected original size
+        self.interpolation = interpolation
+        self.antialias = antialias
+        # Cache map key is the target new_size tuple
+        self._pinv_cache_map: Dict[Tuple[int, int], str] = {}
+
+    def _get_or_create_pinv_matrix(
+            self,
+            new_size: Tuple[int, int],
+            device: torch.device,
+            dtype: torch.dtype = DTYPE_INTERMEDIATE
+    ) -> torch.Tensor:
+        """Retrieves the cached pinv matrix or computes and caches it for the given new_size."""
+        cache_key = new_size
+        buffer_name = self._pinv_cache_map.get(cache_key)
+
+        if buffer_name and hasattr(self, buffer_name):
+            pinv_matrix = getattr(self, buffer_name)
+            if pinv_matrix.device == device and pinv_matrix.dtype == dtype:
+                return pinv_matrix
+
+        # Calculate the matrix if not cached or needs update
+        resize_mat = _compute_resize_matrix(
+            self.orig_size, new_size, self.interpolation, self.antialias, device, dtype
+        )
+        pinv_matrix = _compute_pinv_for_resampling(resize_mat)
+
+        # Cache using register_buffer
+        buffer_name = f"pinv_{new_size[0]}x{new_size[1]}"
+        if hasattr(self, buffer_name):
+            delattr(self, buffer_name)
+        self.register_buffer(buffer_name, pinv_matrix)
+        self._pinv_cache_map[cache_key] = buffer_name  # Map new_size key to buffer name
+
+        return pinv_matrix
+
+    def forward(self, patch_embed: torch.Tensor, new_size: List[int]) -> torch.Tensor:
+        """ Resamples the patch embedding weights to new_size.
+
+        Args:
+            patch_embed (torch.Tensor): Original weights (out_ch, in_ch, H_orig, W_orig).
+            new_size (List[int]): Target [height, width].
+
+        Returns:
+            torch.Tensor: Resampled weights.
+        """
+        assert len(patch_embed.shape) == 4
+        assert len(new_size) == 2
+
+        # Input Validation
+        input_size = tuple(patch_embed.shape[-2:])
+        assert input_size == self.orig_size, \
+            f"Input patch_embed spatial size {input_size} does not match " \
+            f"module's expected original size {self.orig_size}"
+
+        new_size_tuple: Tuple[int, int] = tuple(new_size)
+
+        # Check no-op case against self.orig_size
+        if self.orig_size == new_size_tuple:
+            return patch_embed
+
+        device = patch_embed.device
+        orig_dtype = patch_embed.dtype
+
+        # Get or compute the required pseudoinverse matrix
+        pinv_matrix = self._get_or_create_pinv_matrix(new_size_tuple, device)
+
+        # Apply the resampling
+        resampled_patch_embed = _apply_resampling(patch_embed, pinv_matrix, new_size_tuple, orig_dtype)
+
+        return resampled_patch_embed
+
+
 # def divs(n, m=None):
 #     m = m or n // 2
 #     if m == 1: