mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
A much faster resample_patch_embed, can be used at train/validation time
This commit is contained in:
parent
b4bb0f452a
commit
97341fec51
@ -10,7 +10,7 @@ Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
@ -180,7 +180,8 @@ class PatchEmbedWithSize(PatchEmbed):
|
||||
return x, feat_size
|
||||
|
||||
|
||||
def resample_patch_embed(
|
||||
# FIXME to remove, keeping for comparison for now
|
||||
def resample_patch_embed_old(
|
||||
patch_embed,
|
||||
new_size: List[int],
|
||||
interpolation: str = 'bicubic',
|
||||
@ -250,6 +251,191 @@ def resample_patch_embed(
|
||||
return patch_embed
|
||||
|
||||
|
||||
DTYPE_INTERMEDIATE = torch.float32
|
||||
|
||||
|
||||
def _compute_resize_matrix(
|
||||
old_size: Tuple[int, int],
|
||||
new_size: Tuple[int, int],
|
||||
interpolation: str,
|
||||
antialias: bool,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype = DTYPE_INTERMEDIATE
|
||||
) -> torch.Tensor:
|
||||
"""Computes the resize matrix basis vectors and interpolates them to new_size."""
|
||||
old_h, old_w = old_size
|
||||
new_h, new_w = new_size
|
||||
old_total = old_h * old_w
|
||||
new_total = new_h * new_w
|
||||
|
||||
eye_matrix = torch.eye(old_total, device=device, dtype=dtype)
|
||||
basis_vectors_batch = eye_matrix.reshape(old_total, 1, old_h, old_w)
|
||||
|
||||
resized_basis_vectors_batch = F.interpolate(
|
||||
basis_vectors_batch,
|
||||
size=new_size,
|
||||
mode=interpolation,
|
||||
antialias=antialias,
|
||||
align_corners=False
|
||||
) # Output shape: (old_total, 1, new_h, new_w)
|
||||
|
||||
resize_matrix = resized_basis_vectors_batch.squeeze(1).reshape(old_total, new_total).T
|
||||
return resize_matrix # Shape: (new_total, old_total)
|
||||
|
||||
|
||||
def _compute_pinv_for_resampling(resize_matrix: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculates the pseudoinverse matrix used for the resampling operation."""
|
||||
pinv_matrix = torch.linalg.pinv(resize_matrix.T) # Shape: (new_total, old_total)
|
||||
return pinv_matrix
|
||||
|
||||
|
||||
def _apply_resampling(
|
||||
patch_embed: torch.Tensor,
|
||||
pinv_matrix: torch.Tensor,
|
||||
new_size_tuple: Tuple[int, int],
|
||||
orig_dtype: torch.dtype,
|
||||
intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE
|
||||
) -> torch.Tensor:
|
||||
"""Applies the precomputed pinv_matrix to resample the patch_embed tensor."""
|
||||
try:
|
||||
from torch import vmap
|
||||
except ImportError:
|
||||
from functorch import vmap
|
||||
|
||||
def resample_kernel(kernel: torch.Tensor) -> torch.Tensor:
|
||||
kernel_flat = kernel.reshape(-1).to(intermediate_dtype)
|
||||
resampled_kernel_flat = pinv_matrix @ kernel_flat
|
||||
return resampled_kernel_flat.reshape(new_size_tuple)
|
||||
|
||||
resample_kernel_vmap = vmap(vmap(resample_kernel, in_dims=0, out_dims=0), in_dims=0, out_dims=0)
|
||||
patch_embed_float = patch_embed.to(intermediate_dtype)
|
||||
resampled_patch_embed = resample_kernel_vmap(patch_embed_float)
|
||||
return resampled_patch_embed.to(orig_dtype)
|
||||
|
||||
|
||||
def resample_patch_embed(
|
||||
patch_embed: torch.Tensor,
|
||||
new_size: List[int],
|
||||
interpolation: str = 'bicubic',
|
||||
antialias: bool = True,
|
||||
verbose: bool = False,
|
||||
):
|
||||
""" Standalone function (computes matrix on each call). """
|
||||
assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_c, in_c, h, w)"
|
||||
assert len(new_size) == 2, "New shape should only be hw (height, width)"
|
||||
|
||||
old_size_tuple: Tuple[int, int] = tuple(patch_embed.shape[-2:])
|
||||
new_size_tuple: Tuple[int, int] = tuple(new_size)
|
||||
|
||||
if old_size_tuple == new_size_tuple:
|
||||
return patch_embed
|
||||
|
||||
device = patch_embed.device
|
||||
orig_dtype = patch_embed.dtype
|
||||
|
||||
resize_mat = _compute_resize_matrix(
|
||||
old_size_tuple, new_size_tuple, interpolation, antialias, device, DTYPE_INTERMEDIATE
|
||||
)
|
||||
pinv_matrix = _compute_pinv_for_resampling(resize_mat)
|
||||
resampled_patch_embed = _apply_resampling(
|
||||
patch_embed, pinv_matrix, new_size_tuple, orig_dtype, DTYPE_INTERMEDIATE
|
||||
)
|
||||
return resampled_patch_embed
|
||||
|
||||
|
||||
class PatchEmbedResamplerFixedOrigSize(nn.Module):
|
||||
"""
|
||||
Resample patch embedding weights from a fixed original size,
|
||||
caching the pseudoinverse matrix based on the target size.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
orig_size: Tuple[int, int],
|
||||
interpolation: str = 'bicubic',
|
||||
antialias: bool = True
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
orig_size (Tuple[int, int]): The expected original (height, width) of input patch_embed tensors.
|
||||
interpolation (str): Interpolation mode.
|
||||
antialias (bool): Use anti-aliasing filter in resize.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(orig_size, tuple) and len(orig_size) == 2, \
|
||||
"`orig_size` must be a tuple of (height, width)"
|
||||
self.orig_size = orig_size # expected original size
|
||||
self.interpolation = interpolation
|
||||
self.antialias = antialias
|
||||
# Cache map key is the target new_size tuple
|
||||
self._pinv_cache_map: Dict[Tuple[int, int], str] = {}
|
||||
|
||||
def _get_or_create_pinv_matrix(
|
||||
self,
|
||||
new_size: Tuple[int, int],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype = DTYPE_INTERMEDIATE
|
||||
) -> torch.Tensor:
|
||||
"""Retrieves the cached pinv matrix or computes and caches it for the given new_size."""
|
||||
cache_key = new_size
|
||||
buffer_name = self._pinv_cache_map.get(cache_key)
|
||||
|
||||
if buffer_name and hasattr(self, buffer_name):
|
||||
pinv_matrix = getattr(self, buffer_name)
|
||||
if pinv_matrix.device == device and pinv_matrix.dtype == dtype:
|
||||
return pinv_matrix
|
||||
|
||||
# Calculate the matrix if not cached or needs update
|
||||
resize_mat = _compute_resize_matrix(
|
||||
self.orig_size, new_size, self.interpolation, self.antialias, device, dtype
|
||||
)
|
||||
pinv_matrix = _compute_pinv_for_resampling(resize_mat)
|
||||
|
||||
# Cache using register_buffer
|
||||
buffer_name = f"pinv_{new_size[0]}x{new_size[1]}"
|
||||
if hasattr(self, buffer_name):
|
||||
delattr(self, buffer_name)
|
||||
self.register_buffer(buffer_name, pinv_matrix)
|
||||
self._pinv_cache_map[cache_key] = buffer_name # Map new_size key to buffer name
|
||||
|
||||
return pinv_matrix
|
||||
|
||||
def forward(self, patch_embed: torch.Tensor, new_size: List[int]) -> torch.Tensor:
|
||||
""" Resamples the patch embedding weights to new_size.
|
||||
|
||||
Args:
|
||||
patch_embed (torch.Tensor): Original weights (out_ch, in_ch, H_orig, W_orig).
|
||||
new_size (List[int]): Target [height, width].
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Resampled weights.
|
||||
"""
|
||||
assert len(patch_embed.shape) == 4
|
||||
assert len(new_size) == 2
|
||||
|
||||
# Input Validation
|
||||
input_size = tuple(patch_embed.shape[-2:])
|
||||
assert input_size == self.orig_size, \
|
||||
f"Input patch_embed spatial size {input_size} does not match " \
|
||||
f"module's expected original size {self.orig_size}"
|
||||
|
||||
new_size_tuple: Tuple[int, int] = tuple(new_size)
|
||||
|
||||
# Check no-op case against self.orig_size
|
||||
if self.orig_size == new_size_tuple:
|
||||
return patch_embed
|
||||
|
||||
device = patch_embed.device
|
||||
orig_dtype = patch_embed.dtype
|
||||
|
||||
# Get or compute the required pseudoinverse matrix
|
||||
pinv_matrix = self._get_or_create_pinv_matrix(new_size_tuple, device)
|
||||
|
||||
# Apply the resampling
|
||||
resampled_patch_embed = _apply_resampling(patch_embed, pinv_matrix, new_size_tuple, orig_dtype)
|
||||
|
||||
return resampled_patch_embed
|
||||
|
||||
|
||||
# def divs(n, m=None):
|
||||
# m = m or n // 2
|
||||
# if m == 1:
|
||||
|
Loading…
x
Reference in New Issue
Block a user