A much faster resample_patch_embed, can be used at train/validation time

This commit is contained in:
Ross Wightman 2025-04-10 15:58:24 -07:00
parent b4bb0f452a
commit 97341fec51

View File

@ -10,7 +10,7 @@ Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import math
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import nn as nn
@ -180,7 +180,8 @@ class PatchEmbedWithSize(PatchEmbed):
return x, feat_size
def resample_patch_embed(
# FIXME to remove, keeping for comparison for now
def resample_patch_embed_old(
patch_embed,
new_size: List[int],
interpolation: str = 'bicubic',
@ -250,6 +251,191 @@ def resample_patch_embed(
return patch_embed
DTYPE_INTERMEDIATE = torch.float32
def _compute_resize_matrix(
old_size: Tuple[int, int],
new_size: Tuple[int, int],
interpolation: str,
antialias: bool,
device: torch.device,
dtype: torch.dtype = DTYPE_INTERMEDIATE
) -> torch.Tensor:
"""Computes the resize matrix basis vectors and interpolates them to new_size."""
old_h, old_w = old_size
new_h, new_w = new_size
old_total = old_h * old_w
new_total = new_h * new_w
eye_matrix = torch.eye(old_total, device=device, dtype=dtype)
basis_vectors_batch = eye_matrix.reshape(old_total, 1, old_h, old_w)
resized_basis_vectors_batch = F.interpolate(
basis_vectors_batch,
size=new_size,
mode=interpolation,
antialias=antialias,
align_corners=False
) # Output shape: (old_total, 1, new_h, new_w)
resize_matrix = resized_basis_vectors_batch.squeeze(1).reshape(old_total, new_total).T
return resize_matrix # Shape: (new_total, old_total)
def _compute_pinv_for_resampling(resize_matrix: torch.Tensor) -> torch.Tensor:
"""Calculates the pseudoinverse matrix used for the resampling operation."""
pinv_matrix = torch.linalg.pinv(resize_matrix.T) # Shape: (new_total, old_total)
return pinv_matrix
def _apply_resampling(
patch_embed: torch.Tensor,
pinv_matrix: torch.Tensor,
new_size_tuple: Tuple[int, int],
orig_dtype: torch.dtype,
intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE
) -> torch.Tensor:
"""Applies the precomputed pinv_matrix to resample the patch_embed tensor."""
try:
from torch import vmap
except ImportError:
from functorch import vmap
def resample_kernel(kernel: torch.Tensor) -> torch.Tensor:
kernel_flat = kernel.reshape(-1).to(intermediate_dtype)
resampled_kernel_flat = pinv_matrix @ kernel_flat
return resampled_kernel_flat.reshape(new_size_tuple)
resample_kernel_vmap = vmap(vmap(resample_kernel, in_dims=0, out_dims=0), in_dims=0, out_dims=0)
patch_embed_float = patch_embed.to(intermediate_dtype)
resampled_patch_embed = resample_kernel_vmap(patch_embed_float)
return resampled_patch_embed.to(orig_dtype)
def resample_patch_embed(
patch_embed: torch.Tensor,
new_size: List[int],
interpolation: str = 'bicubic',
antialias: bool = True,
verbose: bool = False,
):
""" Standalone function (computes matrix on each call). """
assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_c, in_c, h, w)"
assert len(new_size) == 2, "New shape should only be hw (height, width)"
old_size_tuple: Tuple[int, int] = tuple(patch_embed.shape[-2:])
new_size_tuple: Tuple[int, int] = tuple(new_size)
if old_size_tuple == new_size_tuple:
return patch_embed
device = patch_embed.device
orig_dtype = patch_embed.dtype
resize_mat = _compute_resize_matrix(
old_size_tuple, new_size_tuple, interpolation, antialias, device, DTYPE_INTERMEDIATE
)
pinv_matrix = _compute_pinv_for_resampling(resize_mat)
resampled_patch_embed = _apply_resampling(
patch_embed, pinv_matrix, new_size_tuple, orig_dtype, DTYPE_INTERMEDIATE
)
return resampled_patch_embed
class PatchEmbedResamplerFixedOrigSize(nn.Module):
"""
Resample patch embedding weights from a fixed original size,
caching the pseudoinverse matrix based on the target size.
"""
def __init__(
self,
orig_size: Tuple[int, int],
interpolation: str = 'bicubic',
antialias: bool = True
):
"""
Args:
orig_size (Tuple[int, int]): The expected original (height, width) of input patch_embed tensors.
interpolation (str): Interpolation mode.
antialias (bool): Use anti-aliasing filter in resize.
"""
super().__init__()
assert isinstance(orig_size, tuple) and len(orig_size) == 2, \
"`orig_size` must be a tuple of (height, width)"
self.orig_size = orig_size # expected original size
self.interpolation = interpolation
self.antialias = antialias
# Cache map key is the target new_size tuple
self._pinv_cache_map: Dict[Tuple[int, int], str] = {}
def _get_or_create_pinv_matrix(
self,
new_size: Tuple[int, int],
device: torch.device,
dtype: torch.dtype = DTYPE_INTERMEDIATE
) -> torch.Tensor:
"""Retrieves the cached pinv matrix or computes and caches it for the given new_size."""
cache_key = new_size
buffer_name = self._pinv_cache_map.get(cache_key)
if buffer_name and hasattr(self, buffer_name):
pinv_matrix = getattr(self, buffer_name)
if pinv_matrix.device == device and pinv_matrix.dtype == dtype:
return pinv_matrix
# Calculate the matrix if not cached or needs update
resize_mat = _compute_resize_matrix(
self.orig_size, new_size, self.interpolation, self.antialias, device, dtype
)
pinv_matrix = _compute_pinv_for_resampling(resize_mat)
# Cache using register_buffer
buffer_name = f"pinv_{new_size[0]}x{new_size[1]}"
if hasattr(self, buffer_name):
delattr(self, buffer_name)
self.register_buffer(buffer_name, pinv_matrix)
self._pinv_cache_map[cache_key] = buffer_name # Map new_size key to buffer name
return pinv_matrix
def forward(self, patch_embed: torch.Tensor, new_size: List[int]) -> torch.Tensor:
""" Resamples the patch embedding weights to new_size.
Args:
patch_embed (torch.Tensor): Original weights (out_ch, in_ch, H_orig, W_orig).
new_size (List[int]): Target [height, width].
Returns:
torch.Tensor: Resampled weights.
"""
assert len(patch_embed.shape) == 4
assert len(new_size) == 2
# Input Validation
input_size = tuple(patch_embed.shape[-2:])
assert input_size == self.orig_size, \
f"Input patch_embed spatial size {input_size} does not match " \
f"module's expected original size {self.orig_size}"
new_size_tuple: Tuple[int, int] = tuple(new_size)
# Check no-op case against self.orig_size
if self.orig_size == new_size_tuple:
return patch_embed
device = patch_embed.device
orig_dtype = patch_embed.dtype
# Get or compute the required pseudoinverse matrix
pinv_matrix = self._get_or_create_pinv_matrix(new_size_tuple, device)
# Apply the resampling
resampled_patch_embed = _apply_resampling(patch_embed, pinv_matrix, new_size_tuple, orig_dtype)
return resampled_patch_embed
# def divs(n, m=None):
# m = m or n // 2
# if m == 1: