set_input_size initial impl for vit & swin v1. Move HybridEmbed to own location in timm/layers

Ross Wightman 2024-07-17 15:25:48 -07:00
parent f920119f3b
commit 392b78aee7
9 changed files with 476 additions and 230 deletions
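As a quick orientation, a minimal usage sketch of the set_input_size() API this commit introduces; the model names and pretrained flag below are illustrative and not part of the diff:

import timm

vit = timm.create_model('vit_base_patch16_224', pretrained=False)
vit.set_input_size(img_size=(384, 384), patch_size=(32, 32))  # resamples pos_embed and the patch projection

swin = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False)
swin.set_input_size(img_size=(256, 256))  # window size re-derived as img_size // window_ratio (default 32)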

View File

@@ -27,6 +27,7 @@ from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .grid import ndgrid, meshgrid
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from .hybrid_embed import HybridEmbed, HybridEmbedWithSize
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d

timm/layers/hybrid_embed.py (new file, 253 additions)
View File

@@ -0,0 +1,253 @@
""" Image to Patch Hybird Embedding Layer
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import math
from typing import List, Optional, Tuple, Union
import torch
from torch import nn as nn
import torch.nn.functional as F
from .format import Format, nchw_to
from .helpers import to_2tuple
from .patch_embed import resample_patch_embed
_logger = logging.getLogger(__name__)
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
output_fmt: Format
dynamic_img_pad: torch.jit.Final[bool]
def __init__(
self,
backbone: nn.Module,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 1,
feature_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
in_chans: int = 3,
embed_dim: int = 768,
bias: bool = True,
proj: bool = True,
flatten: bool = True,
output_fmt: Optional[str] = None,
strict_img_size: bool = True,
dynamic_img_pad: bool = False,
):
super().__init__()
assert isinstance(backbone, nn.Module)
self.backbone = backbone
self.in_chans = in_chans
(
self.img_size,
self.patch_size,
self.feature_size,
self.feature_ratio,
self.feature_dim,
self.grid_size,
self.num_patches,
) = self._init_backbone(
img_size=img_size,
patch_size=patch_size,
feature_size=feature_size,
feature_ratio=feature_ratio,
)
if output_fmt is not None:
self.flatten = False
self.output_fmt = Format(output_fmt)
else:
# flatten spatial dim and transpose to channels last, kept for bwd compat
self.flatten = flatten
self.output_fmt = Format.NCHW
self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
if not dynamic_img_pad:
assert self.feature_size[0] % self.patch_size[0] == 0 and self.feature_size[1] % self.patch_size[1] == 0
if proj:
self.proj = nn.Conv2d(
self.feature_dim,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
)
else:
assert self.feature_dim == embed_dim, \
f'The feature dim ({self.feature_dim}) must match the embed dim ({embed_dim}) when the projection is disabled.'
self.proj = nn.Identity()
def _init_backbone(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 1,
feature_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
feature_dim: Optional[int] = None,
):
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
if feature_size is None:
with torch.no_grad():
# NOTE Most reliable way of determining output dims is to run forward pass
training = self.backbone.training
if training:
self.backbone.eval()
o = self.backbone(torch.zeros(1, self.in_chans, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)):
o = o[-1] # last feature if backbone outputs list/tuple of features
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
self.backbone.train(training)
feature_ratio = tuple([s // f for s, f in zip(img_size, feature_size)])
else:
feature_size = to_2tuple(feature_size)
feature_ratio = to_2tuple(feature_ratio or 16)
if feature_dim is None:
if hasattr(self.backbone, 'feature_info'):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
grid_size = tuple([f // p for f, p in zip(feature_size, patch_size)])
num_patches = grid_size[0] * grid_size[1]
return img_size, patch_size, feature_size, feature_ratio, feature_dim, grid_size, num_patches
def set_input_size(
self,
img_size: Optional[Union[int, Tuple[int, int]]] = None,
patch_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
feature_dim: Optional[int] = None,
):
assert img_size is not None or patch_size is not None
img_size = img_size or self.img_size
new_patch_size = None
if patch_size is not None:
new_patch_size = to_2tuple(patch_size)
if new_patch_size is not None and new_patch_size != self.patch_size:
assert isinstance(self.proj, nn.Conv2d), 'HybridEmbed must have a projection layer to change patch size.'
with torch.no_grad():
new_proj = nn.Conv2d(
self.proj.in_channels,
self.proj.out_channels,
kernel_size=new_patch_size,
stride=new_patch_size,
bias=self.proj.bias is not None,
)
new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True))
if self.proj.bias is not None:
new_proj.bias.copy_(self.proj.bias)
self.proj = new_proj
patch_size = new_patch_size
patch_size = patch_size or self.patch_size
if img_size != self.img_size or patch_size != self.patch_size:
(
self.img_size,
self.patch_size,
self.feature_size,
self.feature_ratio,
self.feature_dim,
self.grid_size,
self.num_patches,
) = self._init_backbone(
img_size=img_size,
patch_size=patch_size,
feature_size=feature_size,
feature_ratio=feature_ratio,
feature_dim=feature_dim,
)
def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
total_reduction = (
self.feature_ratio[0] * self.patch_size[0],
self.feature_ratio[1] * self.patch_size[1]
)
if as_scalar:
return max(total_reduction)
else:
return total_reduction
def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
""" Get feature grid size taking account dynamic padding and backbone network feat reduction
"""
feat_size = (img_size[0] // self.feature_ratio[0], img_size[1] // self.feature_ratio[1])
if self.dynamic_img_pad:
return math.ceil(feat_size[0] / self.patch_size[0]), math.ceil(feat_size[1] / self.patch_size[1])
else:
return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
if hasattr(self.backbone, 'set_grad_checkpointing'):
self.backbone.set_grad_checkpointing(enable=enable)
elif hasattr(self.backbone, 'grad_checkpointing'):
self.backbone.grad_checkpointing = enable
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
_, _, H, W = x.shape
if self.dynamic_img_pad:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = F.pad(x, (0, pad_w, 0, pad_h))
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
elif self.output_fmt != Format.NCHW:
x = nchw_to(x, self.output_fmt)
return x
class HybridEmbedWithSize(HybridEmbed):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(
self,
backbone: nn.Module,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 1,
feature_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
in_chans: int = 3,
embed_dim: int = 768,
bias=True,
proj=True,
):
super().__init__(
backbone=backbone,
img_size=img_size,
patch_size=patch_size,
feature_size=feature_size,
feature_ratio=feature_ratio,
in_chans=in_chans,
embed_dim=embed_dim,
bias=bias,
proj=proj,
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
if hasattr(self.backbone, 'set_grad_checkpointing'):
self.backbone.set_grad_checkpointing(enable=enable)
elif hasattr(self.backbone, 'grad_checkpointing'):
self.backbone.grad_checkpointing = enable
def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.proj(x)
return x.flatten(2).transpose(1, 2), x.shape[-2:]
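A brief sketch of driving the relocated HybridEmbed, including the new set_input_size() path; the resnet backbone name is an assumption chosen for illustration:

import timm
import torch
from timm.layers import HybridEmbed

backbone = timm.create_model('resnet26d', features_only=True, out_indices=(3,), pretrained=False)
embed = HybridEmbed(backbone, img_size=224, patch_size=1, embed_dim=768)
tokens = embed(torch.randn(1, 3, 224, 224))  # (1, 196, 768): 14x14 feature grid at reduction 16, patch_size 1
embed.set_input_size(img_size=(320, 320))    # re-probes the backbone to refresh feature_size / grid_size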

View File

@@ -44,14 +44,7 @@ class PatchEmbed(nn.Module):
):
super().__init__()
self.patch_size = to_2tuple(patch_size)
if img_size is not None:
self.img_size = to_2tuple(img_size)
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
self.num_patches = self.grid_size[0] * self.grid_size[1]
else:
self.img_size = None
self.grid_size = None
self.num_patches = None
self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
if output_fmt is not None:
self.flatten = False
@@ -66,6 +59,41 @@ class PatchEmbed(nn.Module):
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def _init_img_size(self, img_size: Union[int, Tuple[int, int]]):
assert self.patch_size
if img_size is None:
return None, None, None
img_size = to_2tuple(img_size)
grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)])
num_patches = grid_size[0] * grid_size[1]
return img_size, grid_size, num_patches
def set_input_size(
self,
img_size: Optional[Union[int, Tuple[int, int]]] = None,
patch_size: Optional[Union[int, Tuple[int, int]]] = None,
):
new_patch_size = None
if patch_size is not None:
new_patch_size = to_2tuple(patch_size)
if new_patch_size is not None and new_patch_size != self.patch_size:
with torch.no_grad():
new_proj = nn.Conv2d(
self.proj.in_channels,
self.proj.out_channels,
kernel_size=new_patch_size,
stride=new_patch_size,
bias=self.proj.bias is not None,
)
new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True))
if self.proj.bias is not None:
new_proj.bias.copy_(self.proj.bias)
self.proj = new_proj
self.patch_size = new_patch_size
img_size = img_size or self.img_size
if img_size != self.img_size or new_patch_size is not None:
self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
if as_scalar:
return max(self.patch_size)
@@ -180,13 +208,9 @@ def resample_patch_embed(
"""
import numpy as np
try:
import functorch
vmap = functorch.vmap
from torch import vmap
except ImportError:
if hasattr(torch, 'vmap'):
vmap = torch.vmap
else:
assert False, "functorch or a version of torch with vmap is required for FlexiViT resizing."
from functorch import vmap
assert len(patch_embed.shape) == 4, "Four dimensions expected"
assert len(new_size) == 2, "New shape should only be hw"
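To illustrate the PatchEmbed changes above, a small sketch of resizing a standalone patch embed; resample_patch_embed (the FlexiViT-style resize used by set_input_size) requires a torch build with vmap or functorch installed:

import torch
from timm.layers import PatchEmbed

pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
pe.set_input_size(img_size=256, patch_size=32)  # proj weights resampled 16x16 -> 32x32, grid becomes 8x8
out = pe(torch.randn(1, 3, 256, 256))           # (1, 64, 768)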

View File

@@ -27,11 +27,10 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp, LayerNorm
from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp, LayerNorm, HybridEmbed
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module
from ._registry import register_model, generate_default_cfgs
from .vision_transformer_hybrid import HybridEmbed
__all__ = ['ConVit']

View File

@@ -140,6 +140,27 @@ class WindowAttention(nn.Module):
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def set_window_size(self, window_size: Tuple[int, int]) -> None:
"""Update window size & interpolate position embeddings
Args:
window_size (int): New window size
"""
window_size = to_2tuple(window_size)
if window_size == self.window_size:
return
self.window_size = window_size
win_h, win_w = self.window_size
self.window_area = win_h * win_w
with torch.no_grad():
new_bias_shape = (2 * win_h - 1) * (2 * win_w - 1), self.num_heads
self.relative_position_bias_table = nn.Parameter(
resize_rel_pos_bias_table(
self.relative_position_bias_table,
new_window_size=self.window_size,
new_bias_shape=new_bias_shape,
))
self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False)
def _get_rel_pos_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1) # Wh*Ww,Wh*Ww,nH
@@ -197,6 +218,7 @@ class SwinTransformerBlock(nn.Module):
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
shift_size: int = 0,
always_partition: bool = False,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
proj_drop: float = 0.,
@@ -224,9 +246,9 @@
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
ws, ss = self._calc_window_shift(window_size, shift_size)
self.window_size: Tuple[int, int] = ws
self.shift_size: Tuple[int, int] = ss
self.target_shift_size = to_2tuple(shift_size)
self.always_partition = always_partition
self.window_size, self.shift_size = self._calc_window_shift(window_size, target_shift_size=shift_size)
self.window_area = self.window_size[0] * self.window_size[1]
self.mlp_ratio = mlp_ratio
@@ -251,6 +273,9 @@
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self._make_attention_mask()
def _make_attention_mask(self):
if any(self.shift_size):
# calculate attention mask for SW-MSA
H, W = self.input_resolution
@@ -274,16 +299,47 @@
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask, persistent=False)
def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
def _calc_window_shift(
self,
target_window_size: Union[int, Tuple[int, int]],
target_shift_size: Optional[Union[int, Tuple[int, int]]] = None,
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
target_window_size = to_2tuple(target_window_size)
target_shift_size = to_2tuple(target_shift_size)
if target_shift_size is None:
# if passed value is None, recalculate from default window_size // 2 if it was active
target_shift_size = self.target_shift_size
if any(target_shift_size):
target_shift_size = [target_window_size[0] // 2, target_window_size[1] // 2]
else:
target_shift_size = to_2tuple(target_shift_size)
if self.always_partition:
return target_window_size, target_shift_size
window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
return tuple(window_size), tuple(shift_size)
def set_input_size(
self,
feat_size: Tuple[int, int],
window_size: Tuple[int, int],
always_partition: Optional[bool] = None,
):
"""
Args:
feat_size: New input resolution
window_size: New window size
always_partition: Change always_partition attribute if not None
"""
self.input_resolution = feat_size
if always_partition is not None:
self.always_partition = always_partition
self.window_size, self.shift_size = self._calc_window_shift(window_size)
self.window_area = self.window_size[0] * self.window_size[1]
self.attn.set_window_size(self.window_size)
self._make_attention_mask()
def _attn(self, x):
B, H, W, C = x.shape
@@ -374,6 +430,7 @@ class SwinTransformerStage(nn.Module):
num_heads: int = 4,
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
always_partition: bool = False,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
proj_drop: float = 0.,
@@ -427,6 +484,7 @@
head_dim=head_dim,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else shift_size,
always_partition=always_partition,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_drop=proj_drop,
@@ -436,6 +494,30 @@
)
for i in range(depth)])
def set_input_size(
self,
feat_size: Tuple[int, int],
window_size: int,
always_partition: Optional[bool] = None,
):
"""Method updates the resolution to utilize and the window size and so the pair-wise relative positions.
Args:
feat_size (Tuple[int, int]): New input resolution
window_size (int): New window size
"""
self.input_resolution = feat_size
if isinstance(self.downsample, nn.Identity):
self.output_resolution = feat_size
else:
self.output_resolution = tuple(i // 2 for i in feat_size)
for block in self.blocks:
block.set_input_size(
feat_size=self.output_resolution,
window_size=window_size,
always_partition=always_partition,
)
def forward(self, x):
x = self.downsample(x)
@@ -465,6 +547,7 @@
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
always_partition: bool = False,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop_rate: float = 0.,
@@ -546,6 +629,7 @@
num_heads=num_heads[i],
head_dim=head_dim[i],
window_size=window_size[i],
always_partition=always_partition,
mlp_ratio=mlp_ratio[i],
qkv_bias=qkv_bias,
proj_drop=proj_drop_rate,
@@ -556,7 +640,7 @@
in_dim = out_dim
if i > 0:
scale *= 2
self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')]
self.feature_info += [dict(num_chs=out_dim, reduction=patch_size * scale, module=f'layers.{i}')]
self.layers = nn.Sequential(*layers)
self.norm = norm_layer(self.num_features)
@@ -584,6 +668,36 @@
nwd.add(n)
return nwd
def set_input_size(
self,
img_size: Optional[Tuple[int, int]] = None,
patch_size: Optional[Tuple[int, int]] = None,
window_size: Optional[Tuple[int, int]] = None,
window_ratio: int = 32,
always_partition: Optional[bool] = None,
) -> None:
""" Updates the image resolution and window size.
Args:
img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
window_size (Optional[int]): New window size, if None based on new_img_size // window_div
window_ratio (int): divisor for calculating window size from image size
"""
if img_size is not None or patch_size is not None:
self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
self.patch_grid = self.patch_embed.grid_size
if window_size is None:
img_size = self.patch_embed.img_size
window_size = tuple([s // window_ratio for s in img_size])
for index, stage in enumerate(self.layers):
stage_scale = 2 ** max(index - 1, 0)
stage.set_input_size(
feat_size=(self.patch_grid[0] // stage_scale, self.patch_grid[1] // stage_scale),
window_size=window_size,
always_partition=always_partition,
)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
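A sketch of the Swin V1 entry point added above; an explicit window_size overrides the window_ratio derivation, and always_partition keeps full-size windows even when a stage's resolution is smaller (model name illustrative):

import timm

model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False)
model.set_input_size(img_size=(256, 256), window_size=(8, 8), always_partition=True)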

View File

@@ -119,7 +119,7 @@ class WindowMultiHeadAttention(nn.Module):
assert dim % num_heads == 0, \
"The number of input features (in_features) are not divisible by the number of heads (num_heads)."
self.in_features: int = dim
self.window_size: Tuple[int, int] = window_size
self.window_size: Tuple[int, int] = to_2tuple(window_size)
self.num_heads: int = num_heads
self.sequential_attn: bool = sequential_attn
@@ -152,16 +152,15 @@
1.0 + relative_coordinates.abs())
self.register_buffer("relative_coordinates_log", relative_coordinates_log, persistent=False)
def update_input_size(self, new_window_size: int, **kwargs: Any) -> None:
"""Method updates the window size and so the pair-wise relative positions
def set_window_size(self, window_size: Tuple[int, int]) -> None:
"""Update window size & interpolate position embeddings
Args:
new_window_size (int): New window size
kwargs (Any): Unused
window_size (int): New window size
"""
# Set new window size and new pair-wise relative positions
self.window_size: int = new_window_size
self._make_pair_wise_relative_positions()
window_size = to_2tuple(window_size)
if window_size != self.window_size:
self.window_size = window_size
self._make_pair_wise_relative_positions()
def _relative_positional_encodings(self) -> torch.Tensor:
"""Method computes the relative positional encodings
@@ -321,18 +320,18 @@ class SwinTransformerV2CrBlock(nn.Module):
nn.init.constant_(self.norm1.weight, self.init_values)
nn.init.constant_(self.norm2.weight, self.init_values)
def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None:
def set_input_size(self, feat_size: Tuple[int, int], window_size: Tuple[int, int]) -> None:
"""Method updates the image resolution to be processed and window size and so the pair-wise relative positions.
Args:
new_window_size (int): New window size
new_feat_size (Tuple[int, int]): New input resolution
feat_size (Tuple[int, int]): New input resolution
window_size (int): New window size
"""
# Update input resolution
self.feat_size: Tuple[int, int] = new_feat_size
self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(new_window_size))
self.feat_size: Tuple[int, int] = feat_size
self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
self.window_area = self.window_size[0] * self.window_size[1]
self.attn.update_input_size(new_window_size=self.window_size)
self.attn.set_window_size(self.window_size)
self._make_attention_mask()
def _shifted_window_attn(self, x):
@@ -510,18 +509,18 @@ class SwinTransformerV2CrStage(nn.Module):
for index in range(depth)]
)
def update_input_size(self, new_window_size: int, new_feat_size: Tuple[int, int]) -> None:
def set_input_size(self, feat_size: Tuple[int, int], window_size: int) -> None:
"""Method updates the resolution to utilize and the window size and so the pair-wise relative positions.
Args:
new_window_size (int): New window size
new_feat_size (Tuple[int, int]): New input resolution
window_size (int): New window size
feat_size (Tuple[int, int]): New input resolution
"""
self.feat_size: Tuple[int, int] = (
(new_feat_size[0] // 2, new_feat_size[1] // 2) if self.downscale else new_feat_size
(feat_size[0] // 2, feat_size[1] // 2) if self.downscale else feat_size
)
for block in self.blocks:
block.update_input_size(new_window_size=new_window_size, new_feat_size=self.feat_size)
block.set_input_size(feat_size=self.feat_size, window_size=window_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass.
@@ -657,33 +656,32 @@ class SwinTransformerV2Cr(nn.Module):
if weight_init != 'skip':
named_apply(init_weights, self)
def update_input_size(
def set_input_size(
self,
new_img_size: Optional[Tuple[int, int]] = None,
new_window_size: Optional[int] = None,
img_window_ratio: int = 32,
img_size: Optional[Tuple[int, int]] = None,
window_size: Optional[Tuple[int, int]] = None,
window_ratio: int = 32,
) -> None:
"""Method updates the image resolution to be processed and window size and so the pair-wise relative positions.
Args:
new_window_size (Optional[int]): New window size, if None based on new_img_size // window_div
new_img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
img_window_ratio (int): divisor for calculating window size from image size
img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
window_size (Optional[int]): New window size, if None based on new_img_size // window_div
window_ratio (int): divisor for calculating window size from image size
"""
# Check parameters
if new_img_size is None:
new_img_size = self.img_size
if img_size is None:
img_size = self.img_size
else:
new_img_size = to_2tuple(new_img_size)
if new_window_size is None:
new_window_size = tuple([s // img_window_ratio for s in new_img_size])
img_size = to_2tuple(img_size)
if window_size is None:
window_size = tuple([s // window_ratio for s in img_size])
# Compute new patch resolution & update resolution of each stage
new_patch_grid_size = (new_img_size[0] // self.patch_size, new_img_size[1] // self.patch_size)
patch_grid_size = (img_size[0] // self.patch_size, img_size[1] // self.patch_size)
for index, stage in enumerate(self.stages):
stage_scale = 2 ** max(index - 1, 0)
stage.update_input_size(
new_window_size=new_window_size,
new_img_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale),
stage.set_input_size(
feat_size=(patch_grid_size[0] // stage_scale, patch_grid_size[1] // stage_scale),
window_size=window_size,
)
@torch.jit.ignore
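The V2-CR model keeps the same behaviour, but the public method is renamed from update_input_size to set_input_size with img_size / window_size / window_ratio arguments; a minimal sketch (model name illustrative):

import timm

model = timm.create_model('swinv2_cr_tiny_ns_224', pretrained=False)
model.set_input_size(img_size=(256, 256))  # window size defaults to img_size // window_ratio = (8, 8)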

View File

@@ -632,6 +632,31 @@ class VisionTransformer(nn.Module):
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def set_input_size(
self,
img_size: Optional[Tuple[int, int]] = None,
patch_size: Optional[Tuple[int, int]] = None,
):
"""Method updates the input image resolution, patch size
Args:
img_size: New input resolution, if None current resolution is used
patch_size: New patch size, if None existing patch size is used
"""
prev_grid_size = self.patch_embed.grid_size
self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
if self.pos_embed is not None:
num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
num_new_tokens = self.patch_embed.num_patches + num_prefix_tokens
if num_new_tokens != self.pos_embed.shape[1]:
self.pos_embed = nn.Parameter(resample_abs_pos_embed(
self.pos_embed,
new_size=self.patch_embed.grid_size,
old_size=prev_grid_size,
num_prefix_tokens=num_prefix_tokens,
verbose=True,
))
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
if self.pos_embed is None:
return x.view(x.shape[0], -1, x.shape[-1])
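For the ViT method above, a short sketch of the position embedding resample when only the resolution changes: 224/16 gives a 14x14 grid (196 tokens plus a class token), and after the call the grid is 24x24, so pos_embed grows from 197 to 577 positions (model name illustrative):

import timm
import torch

model = timm.create_model('vit_small_patch16_224', pretrained=False)
model.set_input_size(img_size=(384, 384))  # grid 14x14 -> 24x24, pos_embed resampled via resample_abs_pos_embed
out = model(torch.randn(1, 3, 384, 384))   # (1, 1000) with the default classifier head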

View File

@@ -19,10 +19,9 @@ from typing import Dict, List, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to
from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, HybridEmbed
from ._builder import build_model_with_cfg
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -31,172 +30,6 @@ from .resnetv2 import ResNetV2, create_resnetv2_stem
from .vision_transformer import VisionTransformer
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
output_fmt: Format
dynamic_img_pad: torch.jit.Final[bool]
def __init__(
self,
backbone: nn.Module,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 1,
feature_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
in_chans: int = 3,
embed_dim: int = 768,
bias: bool = True,
proj: bool = True,
flatten: bool = True,
output_fmt: Optional[str] = None,
strict_img_size: bool = True,
dynamic_img_pad: bool = False,
):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
# NOTE Most reliable way of determining output dims is to run forward pass
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)):
o = o[-1] # last feature if backbone outputs list/tuple of features
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
feature_ratio = tuple([s // f for s, f in zip(img_size, feature_size)])
else:
feature_size = to_2tuple(feature_size)
feature_ratio = to_2tuple(feature_ratio or 16)
if hasattr(self.backbone, 'feature_info'):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
if not dynamic_img_pad:
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
self.feature_size = feature_size
self.feature_ratio = feature_ratio
self.grid_size = tuple([f // p for f, p in zip(self.feature_size, self.patch_size)])
self.num_patches = self.grid_size[0] * self.grid_size[1]
if output_fmt is not None:
self.flatten = False
self.output_fmt = Format(output_fmt)
else:
# flatten spatial dim and transpose to channels last, kept for bwd compat
self.flatten = flatten
self.output_fmt = Format.NCHW
self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
if proj:
self.proj = nn.Conv2d(
feature_dim,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
)
else:
assert feature_dim == embed_dim,\
f'The feature dim ({feature_dim} must match embed dim ({embed_dim}) when projection disabled.'
self.proj = nn.Identity()
def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
total_reduction = (
self.feature_ratio[0] * self.patch_size[0],
self.feature_ratio[1] * self.patch_size[1]
)
if as_scalar:
return max(total_reduction)
else:
return total_reduction
def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
""" Get feature grid size taking account dynamic padding and backbone network feat reduction
"""
feat_size = (img_size[0] // self.feature_ratio[0], img_size[1] // self.feature_ratio[1])
if self.dynamic_img_pad:
return math.ceil(feat_size[0] / self.patch_size[0]), math.ceil(feat_size[1] / self.patch_size[1])
else:
return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
if hasattr(self.backbone, 'set_grad_checkpointing'):
self.backbone.set_grad_checkpointing(enable=enable)
elif hasattr(self.backbone, 'grad_checkpointing'):
self.backbone.grad_checkpointing = enable
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
_, _, H, W = x.shape
if self.dynamic_img_pad:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = F.pad(x, (0, pad_w, 0, pad_h))
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
elif self.output_fmt != Format.NCHW:
x = nchw_to(x, self.output_fmt)
return x
class HybridEmbedWithSize(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(
self,
backbone: nn.Module,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 1,
feature_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
in_chans: int = 3,
embed_dim: int = 768,
bias=True,
proj=True,
):
super().__init__(
backbone=backbone,
img_size=img_size,
patch_size=patch_size,
feature_size=feature_size,
feature_ratio=feature_ratio,
in_chans=in_chans,
embed_dim=embed_dim,
bias=bias,
proj=proj,
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
if hasattr(self.backbone, 'set_grad_checkpointing'):
self.backbone.set_grad_checkpointing(enable=enable)
elif hasattr(self.backbone, 'grad_checkpointing'):
self.backbone.grad_checkpointing = enable
def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.proj(x)
return x.flatten(2).transpose(1, 2), x.shape[-2:]
class ConvStem(nn.Sequential):
def __init__(
self,

View File

@@ -29,12 +29,11 @@ import torch.nn as nn
from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import create_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, \
make_divisible, DropPath
make_divisible, DropPath, HybridEmbed
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._registry import register_model, generate_default_cfgs
from .vision_transformer import VisionTransformer, checkpoint_filter_fn
from .vision_transformer_hybrid import HybridEmbed
@dataclass