Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Initial impl of dynamic resize for existing vit models (incl vit-resnet hybrids)
parent 38c474e3de
commit fdd8c7c2da
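In rough terms, the change lets an existing ViT accept input sizes other than the one its position embedding was built for, by resampling the absolute position embedding on the fly from the patch grid of the current input. Below is a minimal usage sketch of the new flag, assuming the dynamic_size argument and the constructor defaults touched in this commit; the model configuration and input sizes are illustrative, not taken from the diff.

import torch
from timm.models.vision_transformer import VisionTransformer

# Build a small ViT with the new dynamic_size flag (config is illustrative).
model = VisionTransformer(
    img_size=224,
    patch_size=16,
    embed_dim=384,
    depth=12,
    num_heads=6,
    dynamic_size=True,  # enables on-the-fly pos embed resampling
)
model.eval()

# Inference at a resolution other than the 224x224 the pos embed was built for.
# The input side should still be divisible by the patch size (16 here).
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    out = model(x)
print(out.shape)  # (1, num_classes)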
@@ -29,7 +29,7 @@ def resample_abs_pos_embed(
     if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
         return posemb
 
-    if not old_size:
+    if old_size is None:
         hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
         old_size = hw, hw
 
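For context, resampling an absolute position embedding amounts to splitting off any prefix (class) tokens, reshaping the remaining tokens to their 2D grid, interpolating to the new grid size, and flattening back. The sketch below illustrates that idea only; it is an assumption-level approximation, not the actual resample_abs_pos_embed body.

import math
import torch
import torch.nn.functional as F

def resample_abs_pos_embed_sketch(posemb, new_size, num_prefix_tokens=1):
    # posemb: (1, num_prefix_tokens + old_h * old_w, dim)
    num_pos_tokens = posemb.shape[1]
    hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
    prefix, grid = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]

    # reshape the token grid to 2D, interpolate to the new grid, flatten back
    dim = grid.shape[-1]
    grid = grid.reshape(1, hw, hw, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=new_size, mode='bicubic', antialias=True)
    grid = grid.permute(0, 2, 3, 1).reshape(1, -1, dim)
    return torch.cat([prefix, grid], dim=1)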
@@ -38,7 +38,7 @@ from torch.jit import Final
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
+    resample_abs_pos_embed, resample_abs_pos_embed_nhwc, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -383,6 +383,7 @@ class VisionTransformer(nn.Module):
     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
         - https://arxiv.org/abs/2010.11929
     """
+    dynamic_size: Final[bool]
 
     def __init__(
             self,
@@ -400,6 +401,7 @@ class VisionTransformer(nn.Module):
             init_values: Optional[float] = None,
             class_token: bool = True,
             no_embed_class: bool = False,
+            dynamic_size: bool = False,
             pre_norm: bool = False,
             fc_norm: Optional[bool] = None,
             drop_rate: float = 0.,
@@ -452,14 +454,23 @@ class VisionTransformer(nn.Module):
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
         self.no_embed_class = no_embed_class
+        self.dynamic_size = dynamic_size
         self.grad_checkpointing = False
 
+        embed_args = {}
+        if dynamic_size:
+            embed_args.update(dict(
+                strict_img_size=False,
+                flatten=False,  # flatten deferred until after pos embed
+                output_fmt='NHWC',
+            ))
         self.patch_embed = embed_layer(
             img_size=img_size,
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
             bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
+            **embed_args,
         )
         num_patches = self.patch_embed.num_patches
 
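When dynamic_size is set, these embed_args ask the patch embed to skip its strict image-size check and return an unflattened NHWC grid, so _pos_embed can read the actual grid shape of the current input. A small sketch of the effect, assuming PatchEmbed in timm.layers already accepts these keyword arguments (as the forwarding via embed_args implies); sizes are illustrative.

import torch
from timm.layers import PatchEmbed

# Same arguments the constructor forwards via embed_args when dynamic_size=True.
pe = PatchEmbed(
    img_size=224,
    patch_size=16,
    embed_dim=384,
    strict_img_size=False,  # don't assert on the 224x224 configured size
    flatten=False,          # keep the spatial grid
    output_fmt='NHWC',
)
print(pe(torch.randn(2, 3, 224, 224)).shape)  # (2, 14, 14, 384)
print(pe(torch.randn(2, 3, 256, 256)).shape)  # (2, 16, 16, 384) -- grid follows the input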
@@ -546,10 +557,16 @@ class VisionTransformer(nn.Module):
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
     def _pos_embed(self, x):
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            pos_embed = resample_abs_pos_embed(self.pos_embed, (H, W))
+            x = x.view(B, -1, C)
+        else:
+            pos_embed = self.pos_embed
         if self.no_embed_class:
             # deit-3, updated JAX (big vision)
             # position embedding does not overlap with class token, add then concat
-            x = x + self.pos_embed
+            x = x + pos_embed
             if self.cls_token is not None:
                 x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
         else:
@@ -557,7 +574,7 @@ class VisionTransformer(nn.Module):
             # pos_embed has entry for class token, concat then add
             if self.cls_token is not None:
                 x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
-            x = x + self.pos_embed
+            x = x + pos_embed
         return self.pos_drop(x)
 
     def _intermediate_layers(
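Taken together, the two _pos_embed hunks make the dynamic path consume the NHWC grid from the patch embed, resample the stored position embedding to that grid, and only then flatten back to the usual token sequence before the class token is handled as before. An illustrative shape trace (dimensions assumed, not from the diff):

# dynamic_size path inside _pos_embed, for a 384-dim model trained at 224x224:
# x from patch embed:          (B, 16, 16, 384)   NHWC grid for a 256x256 input
# stored self.pos_embed grid:  14x14 entries       built for the 224x224 training size
# resample_abs_pos_embed(...): pos embed interpolated to the 16x16 grid
# x.view(B, -1, 384):          (B, 256, 384)       flattened to the token sequence
# the existing add + class-token concat then runs unchanged on the resampled pos_embed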
@@ -14,13 +14,13 @@ They were moved here to keep file sizes sane.
 Hacked together by / Copyright 2020, Ross Wightman
 """
 from functools import partial
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import StdConv2dSame, StdConv2d, to_2tuple
+from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
@@ -40,6 +40,9 @@ class HybridEmbed(nn.Module):
             in_chans=3,
             embed_dim=768,
             bias=True,
+            flatten: bool = True,
+            output_fmt: Optional[str] = None,
+            strict_img_size: bool = True,
     ):
         super().__init__()
         assert isinstance(backbone, nn.Module)
@@ -69,6 +72,15 @@ class HybridEmbed(nn.Module):
             assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
             self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
         self.num_patches = self.grid_size[0] * self.grid_size[1]
+        if output_fmt is not None:
+            self.flatten = False
+            self.output_fmt = Format(output_fmt)
+        else:
+            # flatten spatial dim and transpose to channels last, kept for bwd compat
+            self.flatten = flatten
+            self.output_fmt = Format.NCHW
+        self.strict_img_size = strict_img_size
 
         self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
 
     def forward(self, x):
@@ -76,7 +88,10 @@ class HybridEmbed(nn.Module):
         if isinstance(x, (list, tuple)):
             x = x[-1]  # last feature if backbone outputs list/tuple of features
         x = self.proj(x)
-        x = x.flatten(2).transpose(1, 2)
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
+        elif self.output_fmt != Format.NCHW:
+            x = nchw_to(x, self.output_fmt)
         return x
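The hybrid embed gains the same interface as the plain patch embed: when an output_fmt is requested, the projected feature map is left unflattened and converted with nchw_to instead of the legacy flatten-and-transpose. A minimal sketch of what that conversion amounts to (plain tensor ops, not the timm nchw_to helper itself; sizes illustrative):

import torch

def nchw_to_nhwc(x: torch.Tensor) -> torch.Tensor:
    # (B, C, H, W) -> (B, H, W, C), i.e. the Format.NHWC layout
    return x.permute(0, 2, 3, 1)

feat = torch.randn(2, 768, 16, 16)            # projected hybrid features
tokens_nlc = feat.flatten(2).transpose(1, 2)  # legacy path: (2, 256, 768) NLC tokens
grid_nhwc = nchw_to_nhwc(feat)                # dynamic path: (2, 16, 16, 768) NHWC grid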