Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Initial impl of dynamic resize for existing vit models (incl vit-resnet hybrids)

commit fdd8c7c2da, parent 38c474e3de
timm/layers/pos_embed.py

@@ -29,7 +29,7 @@ def resample_abs_pos_embed(
     if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
         return posemb

-    if not old_size:
+    if old_size is None:
         hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
         old_size = hw, hw

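The `old_size is None` check makes the intent explicit: only infer a square grid when no old size was passed. For orientation, absolute pos embed resampling reduces to reshape, interpolate, flatten; a minimal sketch of that idea (not the timm implementation, function name illustrative):

import math
import torch
import torch.nn.functional as F

def resample_pos_embed_sketch(posemb, new_size, num_prefix_tokens=1):
    """Resample a (1, N, C) absolute pos embed to a new_size=(H, W) grid."""
    prefix, grid = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    hw = int(math.sqrt(grid.shape[1]))  # assume the old grid was square
    grid = grid.reshape(1, hw, hw, -1).permute(0, 3, 1, 2)  # NLC -> NCHW
    grid = F.interpolate(grid, size=new_size, mode='bicubic', antialias=True)
    grid = grid.permute(0, 2, 3, 1).reshape(1, new_size[0] * new_size[1], -1)  # NCHW -> NLC
    return torch.cat([prefix, grid], dim=1)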
timm/models/vision_transformer.py

@@ -38,7 +38,7 @@ from torch.jit import Final
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
+    resample_abs_pos_embed, resample_abs_pos_embed_nhwc, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -383,6 +383,7 @@ class VisionTransformer(nn.Module):
     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
         - https://arxiv.org/abs/2010.11929
     """
+    dynamic_size: Final[bool]

     def __init__(
             self,
@@ -400,6 +401,7 @@ class VisionTransformer(nn.Module):
             init_values: Optional[float] = None,
             class_token: bool = True,
             no_embed_class: bool = False,
+            dynamic_size: bool = False,
             pre_norm: bool = False,
             fc_norm: Optional[bool] = None,
             drop_rate: float = 0.,
@@ -452,14 +454,23 @@ class VisionTransformer(nn.Module):
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
         self.no_embed_class = no_embed_class
+        self.dynamic_size = dynamic_size
         self.grad_checkpointing = False

+        embed_args = {}
+        if dynamic_size:
+            embed_args.update(dict(
+                strict_img_size=False,
+                flatten=False,  # flatten deferred until after pos embed
+                output_fmt='NHWC',
+            ))
         self.patch_embed = embed_layer(
             img_size=img_size,
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
             bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
+            **embed_args,
         )
         num_patches = self.patch_embed.num_patches

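With dynamic_size=True the patch embed is built with strict_img_size=False and NHWC output, so flattening waits until after pos embed resampling. A hedged usage sketch of the new flag (assuming this commit's constructor args; weights here are random):

import torch
from timm.models.vision_transformer import VisionTransformer

# Small randomly-initialized ViT; dynamic_size is the flag added in this commit.
model = VisionTransformer(
    img_size=224,
    patch_size=16,
    embed_dim=384,
    depth=12,
    num_heads=6,
    dynamic_size=True,
).eval()

with torch.no_grad():
    out_224 = model(torch.randn(1, 3, 224, 224))  # native size
    out_288 = model(torch.randn(1, 3, 288, 288))  # larger input, pos embed resampled on the fly
print(out_224.shape, out_288.shape)  # both (1, 1000) with default num_classes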
@@ -546,10 +557,16 @@ class VisionTransformer(nn.Module):
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

     def _pos_embed(self, x):
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            pos_embed = resample_abs_pos_embed(self.pos_embed, (H, W))
+            x = x.view(B, -1, C)
+        else:
+            pos_embed = self.pos_embed
         if self.no_embed_class:
             # deit-3, updated JAX (big vision)
             # position embedding does not overlap with class token, add then concat
-            x = x + self.pos_embed
+            x = x + pos_embed
             if self.cls_token is not None:
                 x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
         else:
@@ -557,7 +574,7 @@ class VisionTransformer(nn.Module):
             # pos_embed has entry for class token, concat then add
             if self.cls_token is not None:
                 x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
-            x = x + self.pos_embed
+            x = x + pos_embed
         return self.pos_drop(x)

     def _intermediate_layers(
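In the dynamic path, _pos_embed now receives an NHWC grid instead of pre-flattened tokens. An illustrative shape trace (assuming the default prefix-token handling visible in the pos_embed.py hunk above; values for a 224-trained, patch-16 model fed a 288x288 image):

import torch
from timm.layers import resample_abs_pos_embed

B, C = 2, 384
pos_embed = torch.randn(1, 14 * 14 + 1, C)  # learned 224/16 grid plus one class token row
x = torch.randn(B, 18, 18, C)               # NHWC patch grid from a 288x288 input
pos_embed = resample_abs_pos_embed(pos_embed, (18, 18))  # -> (1, 18*18 + 1, C)
x = x.view(B, -1, C)                        # -> (B, 324, C)
# from here the existing no_embed_class / cls_token logic applies unchanged,
# just adding the resampled `pos_embed` rather than `self.pos_embed`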
timm/models/vision_transformer_hybrid.py

@@ -14,13 +14,13 @@ They were moved here to keep file sizes sane.
 Hacked together by / Copyright 2020, Ross Wightman
 """
 from functools import partial
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import StdConv2dSame, StdConv2d, to_2tuple
+from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
@@ -40,6 +40,9 @@ class HybridEmbed(nn.Module):
             in_chans=3,
             embed_dim=768,
             bias=True,
+            flatten: bool = True,
+            output_fmt: Optional[str] = None,
+            strict_img_size: bool = True,
     ):
         super().__init__()
         assert isinstance(backbone, nn.Module)
@@ -69,6 +72,15 @@ class HybridEmbed(nn.Module):
         assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
         self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
         self.num_patches = self.grid_size[0] * self.grid_size[1]
+        if output_fmt is not None:
+            self.flatten = False
+            self.output_fmt = Format(output_fmt)
+        else:
+            # flatten spatial dim and transpose to channels last, kept for bwd compat
+            self.flatten = flatten
+            self.output_fmt = Format.NCHW
+        self.strict_img_size = strict_img_size
+
         self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

     def forward(self, x):
@@ -76,7 +88,10 @@ class HybridEmbed(nn.Module):
         if isinstance(x, (list, tuple)):
             x = x[-1]  # last feature if backbone outputs list/tuple of features
         x = self.proj(x)
-        x = x.flatten(2).transpose(1, 2)
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
+        elif self.output_fmt != Format.NCHW:
+            x = nchw_to(x, self.output_fmt)
         return x

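HybridEmbed.forward now mirrors PatchEmbed: flatten stays the backward-compatible default, while an explicit output_fmt routes through nchw_to. A small sketch of those timm.layers helpers (behavior as implied by the imports above):

import torch
from timm.layers import Format, nchw_to

x = torch.randn(2, 768, 14, 14)  # projected patch features, NCHW
y = nchw_to(x, Format.NHWC)      # (2, 14, 14, 768): channels-last grid, flatten deferred
z = nchw_to(x, Format.NLC)       # (2, 196, 768): equivalent to x.flatten(2).transpose(1, 2)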