Initial impl of dynamic resize for existing vit models (incl vit-resnet hybrids)

Ross Wightman 2023-08-25 12:45:17 -07:00 committed by Ross Wightman
parent 38c474e3de
commit fdd8c7c2da
3 changed files with 39 additions and 7 deletions
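
For context, a minimal usage sketch of the new flag (not part of the commit): the model name is arbitrary, and it assumes timm.create_model forwards the extra dynamic_size kwarg through to the VisionTransformer constructor.

import torch
import timm

# Hypothetical usage: dynamic_size=True relaxes the patch embed's strict image size
# check and resamples the absolute pos embed to the runtime grid in _pos_embed.
model = timm.create_model('vit_base_patch16_224', pretrained=False, dynamic_size=True)
model.eval()

x = torch.randn(1, 3, 336, 336)  # differs from the 224x224 the model was configured with
with torch.no_grad():
    out = model(x)
print(out.shape)  # expected: torch.Size([1, 1000])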

View File

@@ -29,7 +29,7 @@ def resample_abs_pos_embed(
     if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
         return posemb

-    if not old_size:
+    if old_size is None:
         hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
         old_size = hw, hw
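
For reference, a small shape sketch of resample_abs_pos_embed as it is wired into _pos_embed further down; shapes are illustrative and assume a single class-token prefix entry.

import torch
from timm.layers import resample_abs_pos_embed

posemb = torch.randn(1, 14 * 14 + 1, 768)  # 14x14 grid plus one class-token entry

# Interpolate the spatial grid to 21x21; the prefix (class) entry is carried over unchanged.
posemb_new = resample_abs_pos_embed(posemb, new_size=[21, 21], num_prefix_tokens=1)
print(posemb_new.shape)  # expected: torch.Size([1, 442, 768])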

View File

@@ -38,7 +38,7 @@ from torch.jit import Final
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
+    resample_abs_pos_embed, resample_abs_pos_embed_nhwc, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -383,6 +383,7 @@ class VisionTransformer(nn.Module):
     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
         - https://arxiv.org/abs/2010.11929
     """
+    dynamic_size: Final[bool]

     def __init__(
             self,
@@ -400,6 +401,7 @@ class VisionTransformer(nn.Module):
             init_values: Optional[float] = None,
             class_token: bool = True,
             no_embed_class: bool = False,
+            dynamic_size: bool = False,
             pre_norm: bool = False,
             fc_norm: Optional[bool] = None,
             drop_rate: float = 0.,
@@ -452,14 +454,23 @@ class VisionTransformer(nn.Module):
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
         self.no_embed_class = no_embed_class
+        self.dynamic_size = dynamic_size
         self.grad_checkpointing = False

+        embed_args = {}
+        if dynamic_size:
+            embed_args.update(dict(
+                strict_img_size=False,
+                flatten=False,  # flatten deferred until after pos embed
+                output_fmt='NHWC',
+            ))
         self.patch_embed = embed_layer(
             img_size=img_size,
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
             bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
+            **embed_args,
         )
         num_patches = self.patch_embed.num_patches
@@ -546,10 +557,16 @@ class VisionTransformer(nn.Module):
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

     def _pos_embed(self, x):
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            pos_embed = resample_abs_pos_embed(self.pos_embed, (H, W))
+            x = x.view(B, -1, C)
+        else:
+            pos_embed = self.pos_embed
         if self.no_embed_class:
             # deit-3, updated JAX (big vision)
             # position embedding does not overlap with class token, add then concat
-            x = x + self.pos_embed
+            x = x + pos_embed
             if self.cls_token is not None:
                 x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
         else:
@@ -557,7 +574,7 @@ class VisionTransformer(nn.Module):
             # pos_embed has entry for class token, concat then add
             if self.cls_token is not None:
                 x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
-            x = x + self.pos_embed
+            x = x + pos_embed

         return self.pos_drop(x)

     def _intermediate_layers(
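
A standalone sketch of the dynamic path added to _pos_embed above: the patch embed keeps NHWC output so the runtime grid is known, the pos embed is resampled to that grid, then tokens are flattened. Shapes are illustrative, and it assumes PatchEmbed already accepts the strict_img_size / flatten / output_fmt arguments this commit passes to it.

import torch
from timm.layers import PatchEmbed, resample_abs_pos_embed

# Patch embed configured the way the commit does when dynamic_size=True.
embed = PatchEmbed(
    img_size=224, patch_size=16, embed_dim=384,
    strict_img_size=False, flatten=False, output_fmt='NHWC',
)
pos_embed = torch.randn(1, 14 * 14, 384)  # stored 14x14 grid, no prefix entry in this sketch

x = embed(torch.randn(1, 3, 288, 288))    # -> (1, 18, 18, 384) at the runtime grid
B, H, W, C = x.shape
pos = resample_abs_pos_embed(pos_embed, (H, W), num_prefix_tokens=0)
x = x.view(B, -1, C) + pos                # flatten to tokens, then add resampled pos embed
print(x.shape)  # expected: torch.Size([1, 324, 384])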

View File

@@ -14,13 +14,13 @@ They were moved here to keep file sizes sane.
 Hacked together by / Copyright 2020, Ross Wightman
 """
 from functools import partial
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import StdConv2dSame, StdConv2d, to_2tuple
+from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
@@ -40,6 +40,9 @@ class HybridEmbed(nn.Module):
             in_chans=3,
             embed_dim=768,
             bias=True,
+            flatten: bool = True,
+            output_fmt: Optional[str] = None,
+            strict_img_size: bool = True,
     ):
         super().__init__()
         assert isinstance(backbone, nn.Module)
@@ -69,6 +72,15 @@ class HybridEmbed(nn.Module):
         assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
         self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
         self.num_patches = self.grid_size[0] * self.grid_size[1]
+        if output_fmt is not None:
+            self.flatten = False
+            self.output_fmt = Format(output_fmt)
+        else:
+            # flatten spatial dim and transpose to channels last, kept for bwd compat
+            self.flatten = flatten
+            self.output_fmt = Format.NCHW
+        self.strict_img_size = strict_img_size
+
         self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

     def forward(self, x):
@@ -76,7 +88,10 @@ class HybridEmbed(nn.Module):
         if isinstance(x, (list, tuple)):
             x = x[-1]  # last feature if backbone outputs list/tuple of features
         x = self.proj(x)
-        x = x.flatten(2).transpose(1, 2)
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
+        elif self.output_fmt != Format.NCHW:
+            x = nchw_to(x, self.output_fmt)
         return x
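
Finally, a sketch of the HybridEmbed change in isolation: with output_fmt set, the projected feature map is returned channels-last instead of flattened to tokens. The backbone choice and shapes are illustrative only.

import torch
import timm
from timm.models.vision_transformer_hybrid import HybridEmbed

# Any conv backbone works; a features_only resnet at stride 16 keeps the sketch small.
backbone = timm.create_model('resnet26d', features_only=True, out_indices=(3,), pretrained=False)
embed = HybridEmbed(backbone, img_size=224, patch_size=1, embed_dim=384, output_fmt='NHWC')

x = embed(torch.randn(1, 3, 224, 224))
print(x.shape)  # expected: torch.Size([1, 14, 14, 384]) for a stride-16 feature map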