From fdd8c7c2dabaf50cef3f04ddaaa4183bb849a36f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 25 Aug 2023 12:45:17 -0700 Subject: [PATCH] Initial impl of dynamic resize for existing vit models (incl vit-resnet hybrids) --- timm/layers/pos_embed.py | 2 +- timm/models/vision_transformer.py | 23 ++++++++++++++++++++--- timm/models/vision_transformer_hybrid.py | 21 ++++++++++++++++++--- 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/timm/layers/pos_embed.py b/timm/layers/pos_embed.py index 6be0017f..3e67be00 100644 --- a/timm/layers/pos_embed.py +++ b/timm/layers/pos_embed.py @@ -29,7 +29,7 @@ def resample_abs_pos_embed( if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]: return posemb - if not old_size: + if old_size is None: hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens)) old_size = hw, hw diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 025a01a8..b63fe79f 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -38,7 +38,7 @@ from torch.jit import Final from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \ - resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked + resample_abs_pos_embed, resample_abs_pos_embed_nhwc, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked from ._builder import build_model_with_cfg from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -383,6 +383,7 @@ class VisionTransformer(nn.Module): A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 """ + dynamic_size: Final[bool] def __init__( self, @@ -400,6 +401,7 @@ class VisionTransformer(nn.Module): init_values: Optional[float] = None, class_token: bool = True, no_embed_class: bool = False, + dynamic_size: bool = False, pre_norm: bool = False, fc_norm: Optional[bool] = None, drop_rate: float = 0., @@ -452,14 +454,23 @@ class VisionTransformer(nn.Module): self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_prefix_tokens = 1 if class_token else 0 self.no_embed_class = no_embed_class + self.dynamic_size = dynamic_size self.grad_checkpointing = False + embed_args = {} + if dynamic_size: + embed_args.update(dict( + strict_img_size=False, + flatten=False, # flatten deferred until after pos embed + output_fmt='NHWC', + )) self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm, # disable bias if pre-norm is used (e.g. 
CLIP) + **embed_args, ) num_patches = self.patch_embed.num_patches @@ -546,10 +557,16 @@ class VisionTransformer(nn.Module): self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def _pos_embed(self, x): + if self.dynamic_size: + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed(self.pos_embed, (H, W)) + x = x.view(B, -1, C) + else: + pos_embed = self.pos_embed if self.no_embed_class: # deit-3, updated JAX (big vision) # position embedding does not overlap with class token, add then concat - x = x + self.pos_embed + x = x + pos_embed if self.cls_token is not None: x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) else: @@ -557,7 +574,7 @@ class VisionTransformer(nn.Module): # pos_embed has entry for class token, concat then add if self.cls_token is not None: x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - x = x + self.pos_embed + x = x + pos_embed return self.pos_drop(x) def _intermediate_layers( diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 8cf7bec1..c9b0e7ac 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -14,13 +14,13 @@ They were moved here to keep file sizes sane. Hacked together by / Copyright 2020, Ross Wightman """ from functools import partial -from typing import List, Tuple +from typing import List, Optional, Tuple import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import StdConv2dSame, StdConv2d, to_2tuple +from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to from ._registry import generate_default_cfgs, register_model, register_model_deprecations from .resnet import resnet26d, resnet50d from .resnetv2 import ResNetV2, create_resnetv2_stem @@ -40,6 +40,9 @@ class HybridEmbed(nn.Module): in_chans=3, embed_dim=768, bias=True, + flatten: bool = True, + output_fmt: Optional[str] = None, + strict_img_size: bool = True, ): super().__init__() assert isinstance(backbone, nn.Module) @@ -69,6 +72,15 @@ class HybridEmbed(nn.Module): assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] + if output_fmt is not None: + self.flatten = False + self.output_fmt = Format(output_fmt) + else: + # flatten spatial dim and transpose to channels last, kept for bwd compat + self.flatten = flatten + self.output_fmt = Format.NCHW + self.strict_img_size = strict_img_size + self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) def forward(self, x): @@ -76,7 +88,10 @@ class HybridEmbed(nn.Module): if isinstance(x, (list, tuple)): x = x[-1] # last feature if backbone outputs list/tuple of features x = self.proj(x) - x = x.flatten(2).transpose(1, 2) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # NCHW -> NLC + elif self.output_fmt != Format.NCHW: + x = nchw_to(x, self.output_fmt) return x
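
Note (not part of the patch): a small standalone sketch of what the resampling helper touched in the first hunk does, with shapes assumed for a 192-dim, 16-pixel-patch model. The argument names new_size / num_prefix_tokens are the ones visible in the hunk context above.

import torch
from timm.layers import resample_abs_pos_embed

# pos_embed as stored for a 224/16 model: 1 class token + 14*14 patch tokens
pos_embed = torch.randn(1, 1 + 14 * 14, 192)

# Interpolate to a 20x16 grid (e.g. a 320x256 input); the prefix (class-token)
# entry is split off, the spatial part is resized, then they are re-concatenated.
resized = resample_abs_pos_embed(pos_embed, new_size=(20, 16), num_prefix_tokens=1)
print(resized.shape)  # torch.Size([1, 321, 192])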
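
And a minimal end-to-end usage sketch (again not part of the patch, assuming the constructor defaults shown in the diff): build a small VisionTransformer directly from the modified module with the new dynamic_size flag and feed it inputs of two different sizes. img_size only fixes the grid the stored pos_embed is created for; at runtime _pos_embed() resamples it to whatever grid the input produces.

import torch
from timm.models.vision_transformer import VisionTransformer

model = VisionTransformer(
    img_size=224,       # pos_embed is built for a 14x14 grid (224 / 16)
    patch_size=16,
    embed_dim=192,
    depth=2,
    num_heads=3,
    num_classes=10,
    dynamic_size=True,  # new flag added by this patch
)
model.eval()

with torch.no_grad():
    # Any H/W divisible by patch_size should work: patch_embed runs with
    # strict_img_size=False and returns NHWC, then _pos_embed() resamples
    # the position embedding to the (H, W) token grid before flattening.
    out_224 = model(torch.randn(1, 3, 224, 224))
    out_320 = model(torch.randn(1, 3, 320, 256))

print(out_224.shape, out_320.shape)  # both torch.Size([1, 10])

The same flatten / output_fmt / strict_img_size arguments are threaded through HybridEmbed in vision_transformer_hybrid.py, so the ResNet-hybrid ViTs can follow the identical path.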