Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
dynamic_size -> dynamic_img_size, add dynamic_img_pad for padding option
parent: 1f4512fca3
commit: fc5d705b83
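In short, dynamic_img_pad pads an input whose height/width are not multiples of the patch size up to the next multiple (bottom/right only) before the patch projection. A minimal standalone sketch of that arithmetic, using an illustrative 250x250 input and 16x16 patches (the sizes are hypothetical, not from the commit):

    import torch
    import torch.nn.functional as F

    # Hypothetical 3-channel 250x250 input with 16x16 patches.
    x = torch.randn(1, 3, 250, 250)
    patch_h = patch_w = 16

    # Pad bottom/right to the next multiple of the patch size, mirroring the diff below.
    pad_h = (patch_h - x.shape[-2] % patch_h) % patch_h  # 6
    pad_w = (patch_w - x.shape[-1] % patch_w) % patch_w  # 6
    x = F.pad(x, (0, pad_w, 0, pad_h))
    print(x.shape)  # torch.Size([1, 3, 256, 256]) -> a 16x16 grid of patches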
@@ -26,6 +26,7 @@ class PatchEmbed(nn.Module):
     """ 2D Image to Patch Embedding
     """
     output_fmt: Format
+    dynamic_img_pad: torch.jit.Final[bool]
 
     def __init__(
             self,
@@ -38,6 +39,7 @@ class PatchEmbed(nn.Module):
             output_fmt: Optional[str] = None,
             bias: bool = True,
             strict_img_size: bool = True,
+            dynamic_img_pad: bool = False,
     ):
         super().__init__()
         self.patch_size = to_2tuple(patch_size)
@@ -58,6 +60,7 @@ class PatchEmbed(nn.Module):
         self.flatten = flatten
         self.output_fmt = Format.NCHW
         self.strict_img_size = strict_img_size
+        self.dynamic_img_pad = dynamic_img_pad
 
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
@@ -68,7 +71,7 @@ class PatchEmbed(nn.Module):
         if self.strict_img_size:
             _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
             _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
-        else:
+        elif not self.dynamic_img_pad:
             _assert(
                 H % self.patch_size[0] == 0,
                 f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
@@ -77,7 +80,10 @@ class PatchEmbed(nn.Module):
                 W % self.patch_size[1] == 0,
                 f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
             )
+        if self.dynamic_img_pad:
+            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
+            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
+            x = F.pad(x, (0, pad_w, 0, pad_h))
         x = self.proj(x)
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
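A hedged usage sketch of the patch embed layer with the new option, assuming this commit is applied and that the layer is importable as timm.layers.PatchEmbed with the keyword arguments shown in the hunks above; the 250x250 input is an arbitrary, non-divisible size chosen for illustration:

    import torch
    from timm.layers import PatchEmbed  # assumes a timm version containing this change

    embed = PatchEmbed(
        img_size=224,
        patch_size=16,
        embed_dim=768,
        strict_img_size=False,   # don't require the input to match img_size exactly
        dynamic_img_pad=True,    # pad odd sizes up to a multiple of the patch size
    )
    tokens = embed(torch.randn(1, 3, 250, 250))
    print(tokens.shape)  # expected: torch.Size([1, 256, 768]) after padding to 256x256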
@@ -74,7 +74,7 @@ class VisionTransformerDistilled(VisionTransformer):
         self.distilled_training = enable
 
     def _pos_embed(self, x):
-        if self.dynamic_size:
+        if self.dynamic_img_size:
             B, H, W, C = x.shape
             pos_embed = resample_abs_pos_embed(
                 self.pos_embed,
@@ -367,7 +367,8 @@ class Eva(nn.Module):
             use_abs_pos_emb: bool = True,
             use_rot_pos_emb: bool = False,
             use_post_norm: bool = False,
-            dynamic_size: bool = False,
+            dynamic_img_size: bool = False,
+            dynamic_img_pad: bool = False,
             ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
             head_init_scale: float = 0.001,
     ):
@@ -407,11 +408,11 @@ class Eva(nn.Module):
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
-        self.dynamic_size = dynamic_size
+        self.dynamic_img_size = dynamic_img_size
         self.grad_checkpointing = False
 
         embed_args = {}
-        if dynamic_size:
+        if dynamic_img_size:
             # flatten deferred until after pos embed
             embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
         self.patch_embed = PatchEmbed(
@@ -419,6 +420,7 @@ class Eva(nn.Module):
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
+            dynamic_img_pad=dynamic_img_pad,
             **embed_args,
         )
         num_patches = self.patch_embed.num_patches
@@ -442,7 +444,7 @@ class Eva(nn.Module):
             self.rope = RotaryEmbeddingCat(
                 embed_dim // num_heads,
                 in_pixels=False,
-                feat_shape=None if dynamic_size else self.patch_embed.grid_size,
+                feat_shape=None if dynamic_img_size else self.patch_embed.grid_size,
                 ref_feat_shape=ref_feat_shape,
             )
         else:
@@ -527,7 +529,7 @@ class Eva(nn.Module):
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
     def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if self.dynamic_size:
+        if self.dynamic_img_size:
             B, H, W, C = x.shape
             if self.pos_embed is not None:
                 pos_embed = resample_abs_pos_embed(
@@ -383,7 +383,7 @@ class VisionTransformer(nn.Module):
     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
         - https://arxiv.org/abs/2010.11929
     """
-    dynamic_size: Final[bool]
+    dynamic_img_size: Final[bool]
 
     def __init__(
             self,
@@ -401,9 +401,10 @@ class VisionTransformer(nn.Module):
             init_values: Optional[float] = None,
             class_token: bool = True,
             no_embed_class: bool = False,
-            dynamic_size: bool = False,
             pre_norm: bool = False,
             fc_norm: Optional[bool] = None,
+            dynamic_img_size: bool = False,
+            dynamic_img_pad: bool = False,
             drop_rate: float = 0.,
             pos_drop_rate: float = 0.,
             patch_drop_rate: float = 0.,
@@ -454,11 +455,11 @@ class VisionTransformer(nn.Module):
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
         self.no_embed_class = no_embed_class
-        self.dynamic_size = dynamic_size
+        self.dynamic_img_size = dynamic_img_size
         self.grad_checkpointing = False
 
         embed_args = {}
-        if dynamic_size:
+        if dynamic_img_size:
             # flatten deferred until after pos embed
             embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
         self.patch_embed = embed_layer(
@@ -467,6 +468,7 @@ class VisionTransformer(nn.Module):
             in_chans=in_chans,
             embed_dim=embed_dim,
             bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
+            dynamic_img_pad=dynamic_img_pad,
             **embed_args,
         )
         num_patches = self.patch_embed.num_patches
@@ -554,7 +556,7 @@ class VisionTransformer(nn.Module):
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
     def _pos_embed(self, x):
-        if self.dynamic_size:
+        if self.dynamic_img_size:
             B, H, W, C = x.shape
             pos_embed = resample_abs_pos_embed(
                 self.pos_embed,
@@ -18,6 +18,7 @@ from typing import List, Optional, Tuple
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
@@ -32,6 +33,7 @@ class HybridEmbed(nn.Module):
     Extract feature map from CNN, flatten, project to embedding dim.
     """
     output_fmt: Format
+    dynamic_img_pad: torch.jit.Final[bool]
 
     def __init__(
             self,
@@ -45,6 +47,7 @@ class HybridEmbed(nn.Module):
             flatten: bool = True,
             output_fmt: Optional[str] = None,
             strict_img_size: bool = True,
+            dynamic_img_pad: bool = False,
     ):
         super().__init__()
         assert isinstance(backbone, nn.Module)
@@ -71,7 +74,8 @@ class HybridEmbed(nn.Module):
                 feature_dim = self.backbone.feature_info.channels()[-1]
             else:
                 feature_dim = self.backbone.num_features
-        assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
+        if not dynamic_img_pad:
+            assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
         self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
         self.num_patches = self.grid_size[0] * self.grid_size[1]
         if output_fmt is not None:
@@ -82,6 +86,7 @@ class HybridEmbed(nn.Module):
         self.flatten = flatten
         self.output_fmt = Format.NCHW
         self.strict_img_size = strict_img_size
+        self.dynamic_img_pad = dynamic_img_pad
 
         self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
 
@@ -89,6 +94,11 @@ class HybridEmbed(nn.Module):
         x = self.backbone(x)
         if isinstance(x, (list, tuple)):
             x = x[-1]  # last feature if backbone outputs list/tuple of features
+        _, _, H, W = x.shape
+        if self.dynamic_img_pad:
+            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
+            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
+            x = F.pad(x, (0, pad_w, 0, pad_h))
         x = self.proj(x)
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
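At the model level, a hedged usage sketch assuming timm.create_model forwards the dynamic_img_size / dynamic_img_pad kwargs to the VisionTransformer constructor as added above; the model name and the 250x250 input size are illustrative only:

    import timm
    import torch

    # 'vit_base_patch16_224' is just an example entrypoint name.
    model = timm.create_model(
        'vit_base_patch16_224',
        pretrained=False,
        dynamic_img_size=True,  # resample the position embedding to the input's patch grid
        dynamic_img_pad=True,   # pad H/W up to a multiple of the patch size
    )
    out = model(torch.randn(1, 3, 250, 250))
    print(out.shape)  # expected: torch.Size([1, 1000])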