mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
dynamic_size -> dynamic_img_size, add dynamic_img_pad for padding option
This commit is contained in:
parent
1f4512fca3
commit
fc5d705b83
@ -26,6 +26,7 @@ class PatchEmbed(nn.Module):
|
||||
""" 2D Image to Patch Embedding
|
||||
"""
|
||||
output_fmt: Format
|
||||
dynamic_img_pad: torch.jit.Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -38,6 +39,7 @@ class PatchEmbed(nn.Module):
|
||||
output_fmt: Optional[str] = None,
|
||||
bias: bool = True,
|
||||
strict_img_size: bool = True,
|
||||
dynamic_img_pad: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = to_2tuple(patch_size)
|
||||
@ -58,6 +60,7 @@ class PatchEmbed(nn.Module):
|
||||
self.flatten = flatten
|
||||
self.output_fmt = Format.NCHW
|
||||
self.strict_img_size = strict_img_size
|
||||
self.dynamic_img_pad = dynamic_img_pad
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
@ -68,7 +71,7 @@ class PatchEmbed(nn.Module):
|
||||
if self.strict_img_size:
|
||||
_assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
|
||||
_assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
|
||||
else:
|
||||
elif not self.dynamic_img_pad:
|
||||
_assert(
|
||||
H % self.patch_size[0] == 0,
|
||||
f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
|
||||
@ -77,7 +80,10 @@ class PatchEmbed(nn.Module):
|
||||
W % self.patch_size[1] == 0,
|
||||
f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
|
||||
)
|
||||
|
||||
if self.dynamic_img_pad:
|
||||
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
|
||||
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
|
||||
x = F.pad(x, (0, pad_w, 0, pad_h))
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||
|
@ -74,7 +74,7 @@ class VisionTransformerDistilled(VisionTransformer):
|
||||
self.distilled_training = enable
|
||||
|
||||
def _pos_embed(self, x):
|
||||
if self.dynamic_size:
|
||||
if self.dynamic_img_size:
|
||||
B, H, W, C = x.shape
|
||||
pos_embed = resample_abs_pos_embed(
|
||||
self.pos_embed,
|
||||
|
@ -367,7 +367,8 @@ class Eva(nn.Module):
|
||||
use_abs_pos_emb: bool = True,
|
||||
use_rot_pos_emb: bool = False,
|
||||
use_post_norm: bool = False,
|
||||
dynamic_size: bool = False,
|
||||
dynamic_img_size: bool = False,
|
||||
dynamic_img_pad: bool = False,
|
||||
ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
|
||||
head_init_scale: float = 0.001,
|
||||
):
|
||||
@ -407,11 +408,11 @@ class Eva(nn.Module):
|
||||
self.global_pool = global_pool
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
self.num_prefix_tokens = 1 if class_token else 0
|
||||
self.dynamic_size = dynamic_size
|
||||
self.dynamic_img_size = dynamic_img_size
|
||||
self.grad_checkpointing = False
|
||||
|
||||
embed_args = {}
|
||||
if dynamic_size:
|
||||
if dynamic_img_size:
|
||||
# flatten deferred until after pos embed
|
||||
embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
|
||||
self.patch_embed = PatchEmbed(
|
||||
@ -419,6 +420,7 @@ class Eva(nn.Module):
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
dynamic_img_pad=dynamic_img_pad,
|
||||
**embed_args,
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
@ -442,7 +444,7 @@ class Eva(nn.Module):
|
||||
self.rope = RotaryEmbeddingCat(
|
||||
embed_dim // num_heads,
|
||||
in_pixels=False,
|
||||
feat_shape=None if dynamic_size else self.patch_embed.grid_size,
|
||||
feat_shape=None if dynamic_img_size else self.patch_embed.grid_size,
|
||||
ref_feat_shape=ref_feat_shape,
|
||||
)
|
||||
else:
|
||||
@ -527,7 +529,7 @@ class Eva(nn.Module):
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if self.dynamic_size:
|
||||
if self.dynamic_img_size:
|
||||
B, H, W, C = x.shape
|
||||
if self.pos_embed is not None:
|
||||
pos_embed = resample_abs_pos_embed(
|
||||
|
@ -383,7 +383,7 @@ class VisionTransformer(nn.Module):
|
||||
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
|
||||
- https://arxiv.org/abs/2010.11929
|
||||
"""
|
||||
dynamic_size: Final[bool]
|
||||
dynamic_img_size: Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -401,9 +401,10 @@ class VisionTransformer(nn.Module):
|
||||
init_values: Optional[float] = None,
|
||||
class_token: bool = True,
|
||||
no_embed_class: bool = False,
|
||||
dynamic_size: bool = False,
|
||||
pre_norm: bool = False,
|
||||
fc_norm: Optional[bool] = None,
|
||||
dynamic_img_size: bool = False,
|
||||
dynamic_img_pad: bool = False,
|
||||
drop_rate: float = 0.,
|
||||
pos_drop_rate: float = 0.,
|
||||
patch_drop_rate: float = 0.,
|
||||
@ -454,11 +455,11 @@ class VisionTransformer(nn.Module):
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
self.num_prefix_tokens = 1 if class_token else 0
|
||||
self.no_embed_class = no_embed_class
|
||||
self.dynamic_size = dynamic_size
|
||||
self.dynamic_img_size = dynamic_img_size
|
||||
self.grad_checkpointing = False
|
||||
|
||||
embed_args = {}
|
||||
if dynamic_size:
|
||||
if dynamic_img_size:
|
||||
# flatten deferred until after pos embed
|
||||
embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
|
||||
self.patch_embed = embed_layer(
|
||||
@ -467,6 +468,7 @@ class VisionTransformer(nn.Module):
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
|
||||
dynamic_img_pad=dynamic_img_pad,
|
||||
**embed_args,
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
@ -554,7 +556,7 @@ class VisionTransformer(nn.Module):
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def _pos_embed(self, x):
|
||||
if self.dynamic_size:
|
||||
if self.dynamic_img_size:
|
||||
B, H, W, C = x.shape
|
||||
pos_embed = resample_abs_pos_embed(
|
||||
self.pos_embed,
|
||||
|
@ -18,6 +18,7 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
|
||||
@ -32,6 +33,7 @@ class HybridEmbed(nn.Module):
|
||||
Extract feature map from CNN, flatten, project to embedding dim.
|
||||
"""
|
||||
output_fmt: Format
|
||||
dynamic_img_pad: torch.jit.Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -45,6 +47,7 @@ class HybridEmbed(nn.Module):
|
||||
flatten: bool = True,
|
||||
output_fmt: Optional[str] = None,
|
||||
strict_img_size: bool = True,
|
||||
dynamic_img_pad: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(backbone, nn.Module)
|
||||
@ -71,7 +74,8 @@ class HybridEmbed(nn.Module):
|
||||
feature_dim = self.backbone.feature_info.channels()[-1]
|
||||
else:
|
||||
feature_dim = self.backbone.num_features
|
||||
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
|
||||
if not dynamic_img_pad:
|
||||
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
|
||||
self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
if output_fmt is not None:
|
||||
@ -82,6 +86,7 @@ class HybridEmbed(nn.Module):
|
||||
self.flatten = flatten
|
||||
self.output_fmt = Format.NCHW
|
||||
self.strict_img_size = strict_img_size
|
||||
self.dynamic_img_pad = dynamic_img_pad
|
||||
|
||||
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
|
||||
@ -89,6 +94,11 @@ class HybridEmbed(nn.Module):
|
||||
x = self.backbone(x)
|
||||
if isinstance(x, (list, tuple)):
|
||||
x = x[-1] # last feature if backbone outputs list/tuple of features
|
||||
_, _, H, W = x.shape
|
||||
if self.dynamic_img_pad:
|
||||
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
|
||||
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
|
||||
x = F.pad(x, (0, pad_w, 0, pad_h))
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||
|
Loading…
x
Reference in New Issue
Block a user