Initial impl of dynamic resize for existing vit models (incl vit-resnet hybrids)

Ross Wightman, 2023-08-25 12:45:17 -07:00 (committed by Ross Wightman)
parent 38c474e3de
commit fdd8c7c2da
3 changed files with 39 additions and 7 deletions
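A minimal usage sketch of what this commit enables (the model name is illustrative, and it assumes `timm.create_model` forwards the new `dynamic_size` flag to the `VisionTransformer` constructor, as timm does with unknown kwargs):

import torch
import timm

# Sketch only: enable the dynamic-size behavior added in this commit.
model = timm.create_model('vit_base_patch16_224', dynamic_size=True)
model.eval()

# With the flag on, the same weights accept any input whose sides are
# multiples of the patch size; the pos embed is resampled on the fly.
for s in (224, 256, 320):
    with torch.no_grad():
        print(s, model(torch.randn(1, 3, s, s)).shape)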

timm/layers/pos_embed.py

@@ -29,7 +29,7 @@ def resample_abs_pos_embed(
     if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
         return posemb
-    if not old_size:
+    if old_size is None:
         hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
         old_size = hw, hw
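For context, `resample_abs_pos_embed` interpolates a trained position-embedding grid to a new spatial size, inferring a square `old_size` when none is given, which is what the `if old_size is None` guard above feeds. A simplified standalone sketch of the idea, not the library code itself:

import math
import torch
import torch.nn.functional as F

def resample_sketch(posemb, new_size, num_prefix_tokens=1):
    # Split off prefix (class) tokens, bicubic-resample the square grid,
    # then re-concat. The real function also handles old_size, dtype, etc.
    num_pos_tokens = posemb.shape[1]
    hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
    prefix, grid = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    grid = grid.reshape(1, hw, hw, -1).permute(0, 3, 1, 2)   # NLC -> NCHW
    grid = F.interpolate(grid, size=new_size, mode='bicubic', antialias=True)
    grid = grid.permute(0, 2, 3, 1).reshape(1, -1, posemb.shape[-1])  # -> NLC
    return torch.cat([prefix, grid], dim=1)

posemb = torch.randn(1, 1 + 14 * 14, 768)  # 224/16 = 14x14 grid + class token
print(resample_sketch(posemb, (20, 20)).shape)  # torch.Size([1, 401, 768])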

timm/models/vision_transformer.py

@@ -38,7 +38,7 @@ from torch.jit import Final
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
+    resample_abs_pos_embed, resample_abs_pos_embed_nhwc, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -383,6 +383,7 @@ class VisionTransformer(nn.Module):
     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
         - https://arxiv.org/abs/2010.11929
     """
+    dynamic_size: Final[bool]

     def __init__(
             self,
@@ -400,6 +401,7 @@ class VisionTransformer(nn.Module):
             init_values: Optional[float] = None,
             class_token: bool = True,
             no_embed_class: bool = False,
+            dynamic_size: bool = False,
             pre_norm: bool = False,
             fc_norm: Optional[bool] = None,
             drop_rate: float = 0.,
@@ -452,14 +454,23 @@ class VisionTransformer(nn.Module):
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
         self.no_embed_class = no_embed_class
+        self.dynamic_size = dynamic_size
         self.grad_checkpointing = False

+        embed_args = {}
+        if dynamic_size:
+            embed_args.update(dict(
+                strict_img_size=False,
+                flatten=False,  # flatten deferred until after pos embed
+                output_fmt='NHWC',
+            ))
         self.patch_embed = embed_layer(
             img_size=img_size,
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
             bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
+            **embed_args,
         )
         num_patches = self.patch_embed.num_patches
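The `embed_args` above rely on `PatchEmbed` already accepting these keywords. Roughly, with dynamic sizing enabled the embed layer lifts its fixed-size check and keeps the spatial grid (a sketch; shapes assume 16px patches):

import torch
from timm.layers import PatchEmbed

pe = PatchEmbed(
    img_size=224, patch_size=16, in_chans=3, embed_dim=768,
    strict_img_size=False,  # accept inputs other than 224x224
    flatten=False, output_fmt='NHWC',  # keep the grid for pos embed resampling
)
print(pe(torch.randn(1, 3, 320, 256)).shape)  # torch.Size([1, 20, 16, 768])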
@@ -546,10 +557,16 @@
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

     def _pos_embed(self, x):
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            pos_embed = resample_abs_pos_embed(self.pos_embed, (H, W))
+            x = x.view(B, -1, C)
+        else:
+            pos_embed = self.pos_embed
         if self.no_embed_class:
             # deit-3, updated JAX (big vision)
             # position embedding does not overlap with class token, add then concat
-            x = x + self.pos_embed
+            x = x + pos_embed
             if self.cls_token is not None:
                 x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
         else:
@@ -557,7 +574,7 @@
             # pos_embed has entry for class token, concat then add
             if self.cls_token is not None:
                 x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
-            x = x + self.pos_embed
+            x = x + pos_embed
         return self.pos_drop(x)

     def _intermediate_layers(
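A shape walk-through of the new dynamic branch in `_pos_embed` (values are illustrative; the stored table is assumed to come from a 14x14 training grid, and stand-ins replace the module's parameters):

import torch

B, C = 2, 768
H, W = 20, 16                        # patch grid of a 320x256 input, 16px patches
x = torch.randn(B, H, W, C)          # patch_embed output is NHWC, flatten deferred

# resample_abs_pos_embed(self.pos_embed, (H, W)) maps the stored
# (1, 1 + 14*14, C) table to (1, 1 + H*W, C) for this input.
pos_embed = torch.randn(1, 1 + H * W, C)   # stand-in for the resampled table

x = x.view(B, -1, C)                 # NHWC -> NLC: (2, 320, 768)
cls_token = torch.zeros(B, 1, C)     # stand-in for self.cls_token.expand(...)
x = torch.cat((cls_token, x), dim=1) + pos_embed  # concat then add, as above
print(x.shape)                       # torch.Size([2, 321, 768])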

timm/models/vision_transformer_hybrid.py

@@ -14,13 +14,13 @@ They were moved here to keep file sizes sane.
 Hacked together by / Copyright 2020, Ross Wightman
 """
 from functools import partial
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import StdConv2dSame, StdConv2d, to_2tuple
+from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
@@ -40,6 +40,9 @@ class HybridEmbed(nn.Module):
             in_chans=3,
             embed_dim=768,
             bias=True,
+            flatten: bool = True,
+            output_fmt: Optional[str] = None,
+            strict_img_size: bool = True,
     ):
         super().__init__()
         assert isinstance(backbone, nn.Module)
@@ -69,6 +72,15 @@
         assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
         self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
         self.num_patches = self.grid_size[0] * self.grid_size[1]
+        if output_fmt is not None:
+            self.flatten = False
+            self.output_fmt = Format(output_fmt)
+        else:
+            # flatten spatial dim and transpose to channels last, kept for bwd compat
+            self.flatten = flatten
+            self.output_fmt = Format.NCHW
+        self.strict_img_size = strict_img_size
+
         self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

     def forward(self, x):
@@ -76,7 +88,10 @@
         if isinstance(x, (list, tuple)):
             x = x[-1]  # last feature if backbone outputs list/tuple of features
         x = self.proj(x)
-        x = x.flatten(2).transpose(1, 2)
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
+        elif self.output_fmt != Format.NCHW:
+            x = nchw_to(x, self.output_fmt)
         return x
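For reference, the `nchw_to` helper used in the new branch is essentially a layout permute (the real helper also covers the other `Format` targets such as NLC/NCL); a sketch of the two output paths:

import torch

x = torch.randn(2, 768, 14, 14)      # projected features, NCHW

# Dynamic path: keep spatial dims so the pos embed can be resampled later.
nhwc = x.permute(0, 2, 3, 1)         # what nchw_to(x, Format.NHWC) does
print(nhwc.shape)                    # torch.Size([2, 14, 14, 768])

# Legacy path (self.flatten=True): straight to token sequence.
nlc = x.flatten(2).transpose(1, 2)   # NCHW -> NLC
print(nlc.shape)                     # torch.Size([2, 196, 768])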