From 392b78aee7c30c7e8cdd321aeb7a0d9818bfedee Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 17 Jul 2024 15:25:48 -0700 Subject: [PATCH] set_input_size initial impl for vit & swin v1. Move HybridEmbed to own location in timm/layers --- timm/layers/__init__.py | 1 + timm/layers/hybrid_embed.py | 253 +++++++++++++++++++++++ timm/layers/patch_embed.py | 52 +++-- timm/models/convit.py | 3 +- timm/models/swin_transformer.py | 128 +++++++++++- timm/models/swin_transformer_v2_cr.py | 72 ++++--- timm/models/vision_transformer.py | 25 +++ timm/models/vision_transformer_hybrid.py | 169 +-------------- timm/models/vitamin.py | 3 +- 9 files changed, 476 insertions(+), 230 deletions(-) create mode 100644 timm/layers/hybrid_embed.py diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 3f023572..38c82407 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -27,6 +27,7 @@ from .gather_excite import GatherExcite from .global_context import GlobalContext from .grid import ndgrid, meshgrid from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple +from .hybrid_embed import HybridEmbed, HybridEmbedWithSize from .inplace_abn import InplaceAbn from .linear import Linear from .mixed_conv2d import MixedConv2d diff --git a/timm/layers/hybrid_embed.py b/timm/layers/hybrid_embed.py new file mode 100644 index 00000000..de57a2e9 --- /dev/null +++ b/timm/layers/hybrid_embed.py @@ -0,0 +1,253 @@ +""" Image to Patch Hybird Embedding Layer + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import math +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn as nn +import torch.nn.functional as F + +from .format import Format, nchw_to +from .helpers import to_2tuple +from .patch_embed import resample_patch_embed + + +_logger = logging.getLogger(__name__) + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. 
+ """ + output_fmt: Format + dynamic_img_pad: torch.jit.Final[bool] + + def __init__( + self, + backbone: nn.Module, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 1, + feature_size: Optional[Union[int, Tuple[int, int]]] = None, + feature_ratio: Optional[Union[int, Tuple[int, int]]] = None, + in_chans: int = 3, + embed_dim: int = 768, + bias: bool = True, + proj: bool = True, + flatten: bool = True, + output_fmt: Optional[str] = None, + strict_img_size: bool = True, + dynamic_img_pad: bool = False, + ): + super().__init__() + assert isinstance(backbone, nn.Module) + self.backbone = backbone + self.in_chans = in_chans + ( + self.img_size, + self.patch_size, + self.feature_size, + self.feature_ratio, + self.feature_dim, + self.grid_size, + self.num_patches, + ) = self._init_backbone( + img_size=img_size, + patch_size=patch_size, + feature_size=feature_size, + feature_ratio=feature_ratio, + ) + + if output_fmt is not None: + self.flatten = False + self.output_fmt = Format(output_fmt) + else: + # flatten spatial dim and transpose to channels last, kept for bwd compat + self.flatten = flatten + self.output_fmt = Format.NCHW + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad + if not dynamic_img_pad: + assert self.feature_size[0] % self.patch_size[0] == 0 and self.feature_size[1] % self.patch_size[1] == 0 + + if proj: + self.proj = nn.Conv2d( + self.feature_dim, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=bias, + ) + else: + assert self.feature_dim == embed_dim, \ + f'The feature dim ({self.feature_dim} must match embed dim ({embed_dim}) when projection disabled.' + self.proj = nn.Identity() + + def _init_backbone( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 1, + feature_size: Optional[Union[int, Tuple[int, int]]] = None, + feature_ratio: Optional[Union[int, Tuple[int, int]]] = None, + feature_dim: Optional[int] = None, + ): + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + if feature_size is None: + with torch.no_grad(): + # NOTE Most reliable way of determining output dims is to run forward pass + training = self.backbone.training + if training: + self.backbone.eval() + o = self.backbone(torch.zeros(1, self.in_chans, img_size[0], img_size[1])) + if isinstance(o, (list, tuple)): + o = o[-1] # last feature if backbone outputs list/tuple of features + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + self.backbone.train(training) + feature_ratio = tuple([s // f for s, f in zip(img_size, feature_size)]) + else: + feature_size = to_2tuple(feature_size) + feature_ratio = to_2tuple(feature_ratio or 16) + if feature_dim is None: + if hasattr(self.backbone, 'feature_info'): + feature_dim = self.backbone.feature_info.channels()[-1] + else: + feature_dim = self.backbone.num_features + grid_size = tuple([f // p for f, p in zip(feature_size, patch_size)]) + num_patches = grid_size[0] * grid_size[1] + return img_size, patch_size, feature_size, feature_ratio, feature_dim, grid_size, num_patches + + def set_input_size( + self, + img_size: Optional[Union[int, Tuple[int, int]]] = None, + patch_size: Optional[Union[int, Tuple[int, int]]] = None, + feature_size: Optional[Union[int, Tuple[int, int]]] = None, + feature_ratio: Optional[Union[int, Tuple[int, int]]] = None, + feature_dim: Optional[int] = None, + ): + assert img_size is not None or patch_size is not None + img_size = img_size or self.img_size + new_patch_size = None + if 
patch_size is not None: + new_patch_size = to_2tuple(patch_size) + if new_patch_size is not None and new_patch_size != self.patch_size: + assert isinstance(self.proj, nn.Conv2d), 'HybridEmbed must have a projection layer to change patch size.' + with torch.no_grad(): + new_proj = nn.Conv2d( + self.proj.in_channels, + self.proj.out_channels, + kernel_size=new_patch_size, + stride=new_patch_size, + bias=self.proj.bias is not None, + ) + new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True)) + if self.proj.bias is not None: + new_proj.bias.copy_(self.proj.bias) + self.proj = new_proj + patch_size = new_patch_size + patch_size = patch_size or self.patch_size + + if img_size != self.img_size or patch_size != self.patch_size: + ( + self.img_size, + self.patch_size, + self.feature_size, + self.feature_ratio, + self.feature_dim, + self.grid_size, + self.num_patches, + ) = self._init_backbone( + img_size=img_size, + patch_size=patch_size, + feature_size=feature_size, + feature_ratio=feature_ratio, + feature_dim=feature_dim, + ) + + def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]: + total_reduction = ( + self.feature_ratio[0] * self.patch_size[0], + self.feature_ratio[1] * self.patch_size[1] + ) + if as_scalar: + return max(total_reduction) + else: + return total_reduction + + def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]: + """ Get feature grid size taking account dynamic padding and backbone network feat reduction + """ + feat_size = (img_size[0] // self.feature_ratio[0], img_size[1] // self.feature_ratio[1]) + if self.dynamic_img_pad: + return math.ceil(feat_size[0] / self.patch_size[0]), math.ceil(feat_size[1] / self.patch_size[1]) + else: + return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1] + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True): + if hasattr(self.backbone, 'set_grad_checkpointing'): + self.backbone.set_grad_checkpointing(enable=enable) + elif hasattr(self.backbone, 'grad_checkpointing'): + self.backbone.grad_checkpointing = enable + + def forward(self, x): + x = self.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + _, _, H, W = x.shape + if self.dynamic_img_pad: + pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] + pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] + x = F.pad(x, (0, pad_w, 0, pad_h)) + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # NCHW -> NLC + elif self.output_fmt != Format.NCHW: + x = nchw_to(x, self.output_fmt) + return x + + +class HybridEmbedWithSize(HybridEmbed): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. 
+ """ + def __init__( + self, + backbone: nn.Module, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 1, + feature_size: Optional[Union[int, Tuple[int, int]]] = None, + feature_ratio: Optional[Union[int, Tuple[int, int]]] = None, + in_chans: int = 3, + embed_dim: int = 768, + bias=True, + proj=True, + ): + super().__init__( + backbone=backbone, + img_size=img_size, + patch_size=patch_size, + feature_size=feature_size, + feature_ratio=feature_ratio, + in_chans=in_chans, + embed_dim=embed_dim, + bias=bias, + proj=proj, + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True): + if hasattr(self.backbone, 'set_grad_checkpointing'): + self.backbone.set_grad_checkpointing(enable=enable) + elif hasattr(self.backbone, 'grad_checkpointing'): + self.backbone.grad_checkpointing = enable + + def forward(self, x) -> Tuple[torch.Tensor, List[int]]: + x = self.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + x = self.proj(x) + return x.flatten(2).transpose(1, 2), x.shape[-2:] \ No newline at end of file diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py index 3f148944..c739291b 100644 --- a/timm/layers/patch_embed.py +++ b/timm/layers/patch_embed.py @@ -44,14 +44,7 @@ class PatchEmbed(nn.Module): ): super().__init__() self.patch_size = to_2tuple(patch_size) - if img_size is not None: - self.img_size = to_2tuple(img_size) - self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - else: - self.img_size = None - self.grid_size = None - self.num_patches = None + self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size) if output_fmt is not None: self.flatten = False @@ -66,6 +59,41 @@ class PatchEmbed(nn.Module): self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def _init_img_size(self, img_size: Union[int, Tuple[int, int]]): + assert self.patch_size + if img_size is None: + return None, None, None + img_size = to_2tuple(img_size) + grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)]) + num_patches = grid_size[0] * grid_size[1] + return img_size, grid_size, num_patches + + def set_input_size( + self, + img_size: Optional[Union[int, Tuple[int, int]]] = None, + patch_size: Optional[Union[int, Tuple[int, int]]] = None, + ): + new_patch_size = None + if patch_size is not None: + new_patch_size = to_2tuple(patch_size) + if new_patch_size is not None and new_patch_size != self.patch_size: + with torch.no_grad(): + new_proj = nn.Conv2d( + self.proj.in_channels, + self.proj.out_channels, + kernel_size=new_patch_size, + stride=new_patch_size, + bias=self.proj.bias is not None, + ) + new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True)) + if self.proj.bias is not None: + new_proj.bias.copy_(self.proj.bias) + self.proj = new_proj + self.patch_size = new_patch_size + img_size = img_size or self.img_size + if img_size != self.img_size or new_patch_size is not None: + self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size) + def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]: if as_scalar: return max(self.patch_size) @@ -180,13 +208,9 @@ def resample_patch_embed( """ import numpy as np try: - import functorch - vmap = functorch.vmap + from torch import 
vmap except ImportError: - if hasattr(torch, 'vmap'): - vmap = torch.vmap - else: - assert False, "functorch or a version of torch with vmap is required for FlexiViT resizing." + from functorch import vmap assert len(patch_embed.shape) == 4, "Four dimensions expected" assert len(new_size) == 2, "New shape should only be hw" diff --git a/timm/models/convit.py b/timm/models/convit.py index dadc41b8..cbe3b51e 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -27,11 +27,10 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp, LayerNorm +from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp, LayerNorm, HybridEmbed from ._builder import build_model_with_cfg from ._features_fx import register_notrace_module from ._registry import register_model, generate_default_cfgs -from .vision_transformer_hybrid import HybridEmbed __all__ = ['ConVit'] diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index a5800937..71c4e639 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -140,6 +140,27 @@ class WindowAttention(nn.Module): trunc_normal_(self.relative_position_bias_table, std=.02) self.softmax = nn.Softmax(dim=-1) + def set_window_size(self, window_size: Tuple[int, int]) -> None: + """Update window size & interpolate position embeddings + Args: + window_size (int): New window size + """ + window_size = to_2tuple(window_size) + if window_size == self.window_size: + return + self.window_size = window_size + win_h, win_w = self.window_size + self.window_area = win_h * win_w + with torch.no_grad(): + new_bias_shape = (2 * win_h - 1) * (2 * win_w - 1), self.num_heads + self.relative_position_bias_table = nn.Parameter( + resize_rel_pos_bias_table( + self.relative_position_bias_table, + new_window_size=self.window_size, + new_bias_shape=new_bias_shape, + )) + self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False) + def _get_rel_pos_bias(self) -> torch.Tensor: relative_position_bias = self.relative_position_bias_table[ self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1) # Wh*Ww,Wh*Ww,nH @@ -197,6 +218,7 @@ class SwinTransformerBlock(nn.Module): head_dim: Optional[int] = None, window_size: _int_or_tuple_2_t = 7, shift_size: int = 0, + always_partition: bool = False, mlp_ratio: float = 4., qkv_bias: bool = True, proj_drop: float = 0., @@ -224,9 +246,9 @@ class SwinTransformerBlock(nn.Module): super().__init__() self.dim = dim self.input_resolution = input_resolution - ws, ss = self._calc_window_shift(window_size, shift_size) - self.window_size: Tuple[int, int] = ws - self.shift_size: Tuple[int, int] = ss + self.target_shift_size = to_2tuple(shift_size) + self.always_partition = always_partition + self.window_size, self.shift_size = self._calc_window_shift(window_size, target_shift_size=shift_size) self.window_area = self.window_size[0] * self.window_size[1] self.mlp_ratio = mlp_ratio @@ -251,6 +273,9 @@ class SwinTransformerBlock(nn.Module): ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self._make_attention_mask() + + def _make_attention_mask(self): if any(self.shift_size): # calculate attention mask for SW-MSA H, W = self.input_resolution @@ -274,16 +299,47 @@ class SwinTransformerBlock(nn.Module): attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None - self.register_buffer("attn_mask", attn_mask, persistent=False) - def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]: + def _calc_window_shift( + self, + target_window_size: Union[int, Tuple[int, int]], + target_shift_size: Optional[Union[int, Tuple[int, int]]] = None, + ) -> Tuple[Tuple[int, int], Tuple[int, int]]: target_window_size = to_2tuple(target_window_size) - target_shift_size = to_2tuple(target_shift_size) + if target_shift_size is None: + # if passed value is None, recalculate from default window_size // 2 if it was active + target_shift_size = self.target_shift_size + if any(target_shift_size): + target_shift_size = [target_window_size[0] // 2, target_window_size[1] // 2] + else: + target_shift_size = to_2tuple(target_shift_size) + if self.always_partition: + return target_window_size, target_shift_size window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)] shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)] return tuple(window_size), tuple(shift_size) + def set_input_size( + self, + feat_size: Tuple[int, int], + window_size: Tuple[int, int], + always_partition: Optional[bool] = None, + ): + """ + Args: + feat_size: New input resolution + window_size: New window size + always_partition: Change always_partition attribute if not None + """ + self.input_resolution = feat_size + if always_partition is not None: + self.always_partition = always_partition + self.window_size, self.shift_size = self._calc_window_shift(window_size) + self.window_area = self.window_size[0] * self.window_size[1] + self.attn.set_window_size(self.window_size) + self._make_attention_mask() + def _attn(self, x): B, H, W, C = x.shape @@ -374,6 +430,7 @@ class SwinTransformerStage(nn.Module): num_heads: int = 4, head_dim: Optional[int] = None, window_size: _int_or_tuple_2_t = 7, + always_partition: bool = False, mlp_ratio: float = 4., qkv_bias: bool = True, proj_drop: float = 0., @@ -427,6 +484,7 @@ class SwinTransformerStage(nn.Module): head_dim=head_dim, window_size=window_size, shift_size=0 if (i % 2 == 0) else shift_size, + always_partition=always_partition, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_drop=proj_drop, @@ -436,6 +494,30 @@ class SwinTransformerStage(nn.Module): ) for i in range(depth)]) + def set_input_size( + self, + feat_size: Tuple[int, int], + window_size: int, + always_partition: Optional[bool] = None, + ): + """Method updates the resolution to utilize and the window size and so the pair-wise relative positions. 
+ + Args: + feat_size (Tuple[int, int]): New input resolution + window_size (int): New window size + """ + self.input_resolution = feat_size + if isinstance(self.downsample, nn.Identity): + self.output_resolution = feat_size + else: + self.output_resolution = tuple(i // 2 for i in feat_size) + for block in self.blocks: + block.set_input_size( + feat_size=self.output_resolution, + window_size=window_size, + always_partition=always_partition, + ) + def forward(self, x): x = self.downsample(x) @@ -465,6 +547,7 @@ class SwinTransformer(nn.Module): num_heads: Tuple[int, ...] = (3, 6, 12, 24), head_dim: Optional[int] = None, window_size: _int_or_tuple_2_t = 7, + always_partition: bool = False, mlp_ratio: float = 4., qkv_bias: bool = True, drop_rate: float = 0., @@ -546,6 +629,7 @@ class SwinTransformer(nn.Module): num_heads=num_heads[i], head_dim=head_dim[i], window_size=window_size[i], + always_partition=always_partition, mlp_ratio=mlp_ratio[i], qkv_bias=qkv_bias, proj_drop=proj_drop_rate, @@ -556,7 +640,7 @@ class SwinTransformer(nn.Module): in_dim = out_dim if i > 0: scale *= 2 - self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')] + self.feature_info += [dict(num_chs=out_dim, reduction=patch_size * scale, module=f'layers.{i}')] self.layers = nn.Sequential(*layers) self.norm = norm_layer(self.num_features) @@ -584,6 +668,36 @@ class SwinTransformer(nn.Module): nwd.add(n) return nwd + def set_input_size( + self, + img_size: Optional[Tuple[int, int]] = None, + patch_size: Optional[Tuple[int, int]] = None, + window_size: Optional[Tuple[int, int]] = None, + window_ratio: int = 32, + always_partition: Optional[bool] = None, + ) -> None: + """ Updates the image resolution and window size. + + Args: + img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used + window_size (Optional[int]): New window size, if None based on new_img_size // window_div + window_ratio (int): divisor for calculating window size from image size + """ + if img_size is not None or patch_size is not None: + self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size) + self.patch_grid = self.patch_embed.grid_size + if window_size is None: + img_size = self.patch_embed.img_size + window_size = tuple([s // window_ratio for s in img_size]) + for index, stage in enumerate(self.layers): + stage_scale = 2 ** max(index - 1, 0) + print(self.patch_grid, stage_scale) + stage.set_input_size( + feat_size=(self.patch_grid[0] // stage_scale, self.patch_grid[1] // stage_scale), + window_size=window_size, + always_partition=always_partition, + ) + @torch.jit.ignore def group_matcher(self, coarse=False): return dict( diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index d5fcbadc..d7c5f672 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -119,7 +119,7 @@ class WindowMultiHeadAttention(nn.Module): assert dim % num_heads == 0, \ "The number of input features (in_features) are not divisible by the number of heads (num_heads)." 
self.in_features: int = dim - self.window_size: Tuple[int, int] = window_size + self.window_size: Tuple[int, int] = to_2tuple(window_size) self.num_heads: int = num_heads self.sequential_attn: bool = sequential_attn @@ -152,16 +152,15 @@ class WindowMultiHeadAttention(nn.Module): 1.0 + relative_coordinates.abs()) self.register_buffer("relative_coordinates_log", relative_coordinates_log, persistent=False) - def update_input_size(self, new_window_size: int, **kwargs: Any) -> None: - """Method updates the window size and so the pair-wise relative positions - + def set_window_size(self, window_size: Tuple[int, int]) -> None: + """Update window size & interpolate position embeddings Args: - new_window_size (int): New window size - kwargs (Any): Unused + window_size (int): New window size """ - # Set new window size and new pair-wise relative positions - self.window_size: int = new_window_size - self._make_pair_wise_relative_positions() + window_size = to_2tuple(window_size) + if window_size != self.window_size: + self.window_size = window_size + self._make_pair_wise_relative_positions() def _relative_positional_encodings(self) -> torch.Tensor: """Method computes the relative positional encodings @@ -321,18 +320,18 @@ class SwinTransformerV2CrBlock(nn.Module): nn.init.constant_(self.norm1.weight, self.init_values) nn.init.constant_(self.norm2.weight, self.init_values) - def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None: + def set_input_size(self, feat_size: Tuple[int, int], window_size: Tuple[int, int]) -> None: """Method updates the image resolution to be processed and window size and so the pair-wise relative positions. Args: - new_window_size (int): New window size - new_feat_size (Tuple[int, int]): New input resolution + feat_size (Tuple[int, int]): New input resolution + window_size (int): New window size """ # Update input resolution - self.feat_size: Tuple[int, int] = new_feat_size - self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(new_window_size)) + self.feat_size: Tuple[int, int] = feat_size + self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size)) self.window_area = self.window_size[0] * self.window_size[1] - self.attn.update_input_size(new_window_size=self.window_size) + self.attn.set_window_size(self.window_size) self._make_attention_mask() def _shifted_window_attn(self, x): @@ -510,18 +509,18 @@ class SwinTransformerV2CrStage(nn.Module): for index in range(depth)] ) - def update_input_size(self, new_window_size: int, new_feat_size: Tuple[int, int]) -> None: + def set_input_size(self, feat_size: Tuple[int, int], window_size: int) -> None: """Method updates the resolution to utilize and the window size and so the pair-wise relative positions. Args: - new_window_size (int): New window size - new_feat_size (Tuple[int, int]): New input resolution + window_size (int): New window size + feat_size (Tuple[int, int]): New input resolution """ self.feat_size: Tuple[int, int] = ( - (new_feat_size[0] // 2, new_feat_size[1] // 2) if self.downscale else new_feat_size + (feat_size[0] // 2, feat_size[1] // 2) if self.downscale else feat_size ) for block in self.blocks: - block.update_input_size(new_window_size=new_window_size, new_feat_size=self.feat_size) + block.set_input_size(feat_size=self.feat_size, window_size=window_size) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass. 
@@ -657,33 +656,32 @@ class SwinTransformerV2Cr(nn.Module): if weight_init != 'skip': named_apply(init_weights, self) - def update_input_size( + def set_input_size( self, - new_img_size: Optional[Tuple[int, int]] = None, - new_window_size: Optional[int] = None, - img_window_ratio: int = 32, + img_size: Optional[Tuple[int, int]] = None, + window_size: Optional[Tuple[int, int]] = None, + window_ratio: int = 32, ) -> None: """Method updates the image resolution to be processed and window size and so the pair-wise relative positions. Args: - new_window_size (Optional[int]): New window size, if None based on new_img_size // window_div - new_img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used - img_window_ratio (int): divisor for calculating window size from image size + img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used + window_size (Optional[int]): New window size, if None based on new_img_size // window_div + window_ratio (int): divisor for calculating window size from image size """ - # Check parameters - if new_img_size is None: - new_img_size = self.img_size + if img_size is None: + img_size = self.img_size else: - new_img_size = to_2tuple(new_img_size) - if new_window_size is None: - new_window_size = tuple([s // img_window_ratio for s in new_img_size]) + img_size = to_2tuple(img_size) + if window_size is None: + window_size = tuple([s // window_ratio for s in img_size]) # Compute new patch resolution & update resolution of each stage - new_patch_grid_size = (new_img_size[0] // self.patch_size, new_img_size[1] // self.patch_size) + patch_grid_size = (img_size[0] // self.patch_size, img_size[1] // self.patch_size) for index, stage in enumerate(self.stages): stage_scale = 2 ** max(index - 1, 0) - stage.update_input_size( - new_window_size=new_window_size, - new_img_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale), + stage.set_input_size( + feat_size=(patch_grid_size[0] // stage_scale, patch_grid_size[1] // stage_scale), + window_size=window_size, ) @torch.jit.ignore diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 24a74b84..08c04d54 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -632,6 +632,31 @@ class VisionTransformer(nn.Module): self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + def set_input_size( + self, + img_size: Optional[Tuple[int, int]] = None, + patch_size: Optional[Tuple[int, int]] = None, + ): + """Method updates the input image resolution, patch size + + Args: + img_size: New input resolution, if None current resolution is used + patch_size: New patch size, if None existing patch size is used + """ + prev_grid_size = self.patch_embed.grid_size + self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size) + if self.pos_embed is not None: + num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens + num_new_tokens = self.patch_embed.num_patches + num_prefix_tokens + if num_new_tokens != self.pos_embed.shape[1]: + self.pos_embed = nn.Parameter(resample_abs_pos_embed( + self.pos_embed, + new_size=self.patch_embed.grid_size, + old_size=prev_grid_size, + num_prefix_tokens=num_prefix_tokens, + verbose=True, + )) + def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: if self.pos_embed is None: return x.view(x.shape[0], -1, x.shape[-1]) diff --git 
a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index ffeabc4a..4cf3a766 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -19,10 +19,9 @@ from typing import Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to +from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, HybridEmbed from ._builder import build_model_with_cfg from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -31,172 +30,6 @@ from .resnetv2 import ResNetV2, create_resnetv2_stem from .vision_transformer import VisionTransformer -class HybridEmbed(nn.Module): - """ CNN Feature Map Embedding - Extract feature map from CNN, flatten, project to embedding dim. - """ - output_fmt: Format - dynamic_img_pad: torch.jit.Final[bool] - - def __init__( - self, - backbone: nn.Module, - img_size: Union[int, Tuple[int, int]] = 224, - patch_size: Union[int, Tuple[int, int]] = 1, - feature_size: Optional[Union[int, Tuple[int, int]]] = None, - feature_ratio: Optional[Union[int, Tuple[int, int]]] = None, - in_chans: int = 3, - embed_dim: int = 768, - bias: bool = True, - proj: bool = True, - flatten: bool = True, - output_fmt: Optional[str] = None, - strict_img_size: bool = True, - dynamic_img_pad: bool = False, - ): - super().__init__() - assert isinstance(backbone, nn.Module) - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.backbone = backbone - if feature_size is None: - with torch.no_grad(): - # NOTE Most reliable way of determining output dims is to run forward pass - training = backbone.training - if training: - backbone.eval() - o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) - if isinstance(o, (list, tuple)): - o = o[-1] # last feature if backbone outputs list/tuple of features - feature_size = o.shape[-2:] - feature_dim = o.shape[1] - backbone.train(training) - feature_ratio = tuple([s // f for s, f in zip(img_size, feature_size)]) - else: - - feature_size = to_2tuple(feature_size) - feature_ratio = to_2tuple(feature_ratio or 16) - if hasattr(self.backbone, 'feature_info'): - feature_dim = self.backbone.feature_info.channels()[-1] - else: - feature_dim = self.backbone.num_features - if not dynamic_img_pad: - assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 - self.feature_size = feature_size - self.feature_ratio = feature_ratio - self.grid_size = tuple([f // p for f, p in zip(self.feature_size, self.patch_size)]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - if output_fmt is not None: - self.flatten = False - self.output_fmt = Format(output_fmt) - else: - # flatten spatial dim and transpose to channels last, kept for bwd compat - self.flatten = flatten - self.output_fmt = Format.NCHW - self.strict_img_size = strict_img_size - self.dynamic_img_pad = dynamic_img_pad - - if proj: - self.proj = nn.Conv2d( - feature_dim, - embed_dim, - kernel_size=patch_size, - stride=patch_size, - bias=bias, - ) - else: - assert feature_dim == embed_dim,\ - f'The feature dim ({feature_dim} must match embed dim ({embed_dim}) when projection disabled.' 
- self.proj = nn.Identity() - - def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]: - total_reduction = ( - self.feature_ratio[0] * self.patch_size[0], - self.feature_ratio[1] * self.patch_size[1] - ) - if as_scalar: - return max(total_reduction) - else: - return total_reduction - - def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]: - """ Get feature grid size taking account dynamic padding and backbone network feat reduction - """ - feat_size = (img_size[0] // self.feature_ratio[0], img_size[1] // self.feature_ratio[1]) - if self.dynamic_img_pad: - return math.ceil(feat_size[0] / self.patch_size[0]), math.ceil(feat_size[1] / self.patch_size[1]) - else: - return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1] - - @torch.jit.ignore - def set_grad_checkpointing(self, enable: bool = True): - if hasattr(self.backbone, 'set_grad_checkpointing'): - self.backbone.set_grad_checkpointing(enable=enable) - elif hasattr(self.backbone, 'grad_checkpointing'): - self.backbone.grad_checkpointing = enable - - def forward(self, x): - x = self.backbone(x) - if isinstance(x, (list, tuple)): - x = x[-1] # last feature if backbone outputs list/tuple of features - _, _, H, W = x.shape - if self.dynamic_img_pad: - pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] - pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] - x = F.pad(x, (0, pad_w, 0, pad_h)) - x = self.proj(x) - if self.flatten: - x = x.flatten(2).transpose(1, 2) # NCHW -> NLC - elif self.output_fmt != Format.NCHW: - x = nchw_to(x, self.output_fmt) - return x - - -class HybridEmbedWithSize(nn.Module): - """ CNN Feature Map Embedding - Extract feature map from CNN, flatten, project to embedding dim. - """ - def __init__( - self, - backbone: nn.Module, - img_size: Union[int, Tuple[int, int]] = 224, - patch_size: Union[int, Tuple[int, int]] = 1, - feature_size: Optional[Union[int, Tuple[int, int]]] = None, - feature_ratio: Optional[Union[int, Tuple[int, int]]] = None, - in_chans: int = 3, - embed_dim: int = 768, - bias=True, - proj=True, - ): - super().__init__( - backbone=backbone, - img_size=img_size, - patch_size=patch_size, - feature_size=feature_size, - feature_ratio=feature_ratio, - in_chans=in_chans, - embed_dim=embed_dim, - bias=bias, - proj=proj, - ) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable: bool = True): - if hasattr(self.backbone, 'set_grad_checkpointing'): - self.backbone.set_grad_checkpointing(enable=enable) - elif hasattr(self.backbone, 'grad_checkpointing'): - self.backbone.grad_checkpointing = enable - - def forward(self, x) -> Tuple[torch.Tensor, List[int]]: - x = self.backbone(x) - if isinstance(x, (list, tuple)): - x = x[-1] # last feature if backbone outputs list/tuple of features - x = self.proj(x) - return x.flatten(2).transpose(1, 2), x.shape[-2:] - - class ConvStem(nn.Sequential): def __init__( self, diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index db1f2669..18635f60 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -29,12 +29,11 @@ import torch.nn as nn from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import create_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, \ - make_divisible, DropPath + make_divisible, DropPath, HybridEmbed from ._builder import build_model_with_cfg from ._manipulate import named_apply, checkpoint_seq from ._registry import register_model, generate_default_cfgs from .vision_transformer import 
VisionTransformer, checkpoint_filter_fn -from .vision_transformer_hybrid import HybridEmbed @dataclass
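
Usage sketch (not part of the commit): the snippet below shows how the new set_input_size() API added in this patch might be exercised once it is applied. The model names and input shapes are illustrative examples only; the calls mirror the patched VisionTransformer.set_input_size(img_size, patch_size) and SwinTransformer.set_input_size(img_size, patch_size, window_size, window_ratio, always_partition) signatures above.

import torch
import timm

# ViT: change input resolution (and optionally patch size) after model creation.
# set_input_size() resamples the absolute position embedding to the new patch grid and,
# when patch_size is given, resamples the patch projection weights as well.
vit = timm.create_model('vit_base_patch16_224', pretrained=False)
vit.set_input_size(img_size=(384, 384))
print(vit(torch.randn(1, 3, 384, 384)).shape)

# Swin V1: update image size and window size together. If window_size is omitted,
# it is derived from the new image size via window_ratio (img_size // window_ratio).
swin = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False)
swin.set_input_size(img_size=(256, 256), window_size=(8, 8))
print(swin(torch.randn(1, 3, 256, 256)).shape)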
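
HybridEmbed is now importable from timm.layers (previously defined in timm.models.vision_transformer_hybrid). A rough sketch of standalone use follows; the resnet26d backbone and the argument values are assumptions for illustration, not something the patch itself prescribes.

import torch
import timm
from timm.layers import HybridEmbed  # new import location introduced by this patch

# Wrap a CNN feature extractor as a ViT-style patch embedding front end.
# The backbone choice (resnet26d, stride-16 feature map) is just an example.
backbone = timm.create_model('resnet26d', features_only=True, out_indices=(3,), pretrained=False)
embed = HybridEmbed(backbone, img_size=224, patch_size=1, embed_dim=768)

tokens = embed(torch.randn(1, 3, 224, 224))   # (1, num_patches, 768) with default flatten=True
print(tokens.shape, embed.grid_size)

# The embed module also gains set_input_size(); changing img_size re-runs the backbone
# probe in _init_backbone() to recompute the feature and grid sizes.
embed.set_input_size(img_size=(256, 256))
print(embed(torch.randn(1, 3, 256, 256)).shape, embed.grid_size)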