dynamic_size -> dynamic_img_size, add dynamic_img_pad for padding option

Ross Wightman authored and committed, 2023-08-27 10:07:01 -07:00
parent 1f4512fca3
commit fc5d705b83
5 changed files with 34 additions and 14 deletions

timm/layers/patch_embed.py

@@ -26,6 +26,7 @@ class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
output_fmt: Format
+dynamic_img_pad: torch.jit.Final[bool]
def __init__(
self,
@@ -38,6 +39,7 @@ class PatchEmbed(nn.Module):
output_fmt: Optional[str] = None,
bias: bool = True,
strict_img_size: bool = True,
+dynamic_img_pad: bool = False,
):
super().__init__()
self.patch_size = to_2tuple(patch_size)
@@ -58,6 +60,7 @@ class PatchEmbed(nn.Module):
self.flatten = flatten
self.output_fmt = Format.NCHW
self.strict_img_size = strict_img_size
+self.dynamic_img_pad = dynamic_img_pad
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
@@ -68,7 +71,7 @@ class PatchEmbed(nn.Module):
if self.strict_img_size:
_assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
_assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
-else:
+elif not self.dynamic_img_pad:
_assert(
H % self.patch_size[0] == 0,
f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
@@ -77,7 +80,10 @@ class PatchEmbed(nn.Module):
W % self.patch_size[1] == 0,
f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
)
+if self.dynamic_img_pad:
+pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
+pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
+x = F.pad(x, (0, pad_w, 0, pad_h))
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
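For reference, a minimal sketch of how the new dynamic_img_pad path behaves when PatchEmbed is used on its own (constructor defaults such as in_chans=3 and the embed_dim value are assumed for illustration, not taken from this diff):

    import torch
    from timm.layers import PatchEmbed

    # strict_img_size=False skips the fixed-size check; dynamic_img_pad=True pads the
    # input up to the next multiple of the patch size instead of asserting divisibility
    embed = PatchEmbed(patch_size=16, embed_dim=768, strict_img_size=False, dynamic_img_pad=True)
    x = torch.randn(1, 3, 250, 250)   # 250 is not divisible by 16
    out = embed(x)                    # padded to 256x256 -> a 16x16 patch grid
    print(out.shape)                  # torch.Size([1, 256, 768])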

timm/models/deit.py

@@ -74,7 +74,7 @@ class VisionTransformerDistilled(VisionTransformer):
self.distilled_training = enable
def _pos_embed(self, x):
-if self.dynamic_size:
+if self.dynamic_img_size:
B, H, W, C = x.shape
pos_embed = resample_abs_pos_embed(
self.pos_embed,
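The dynamic branch above relies on resample_abs_pos_embed to interpolate the learned position embedding to whatever grid the current input produces. A standalone sketch with illustrative shapes; the keyword arguments are assumptions, since the call is truncated in the hunk:

    import torch
    from timm.layers import resample_abs_pos_embed

    # a 14x14 grid plus one class token, resampled to a 20x12 grid
    pos_embed = torch.randn(1, 14 * 14 + 1, 384)
    resampled = resample_abs_pos_embed(pos_embed, new_size=(20, 12), num_prefix_tokens=1)
    print(resampled.shape)   # torch.Size([1, 241, 384]) -> 20 * 12 tokens + 1 prefix token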

timm/models/eva.py

@@ -367,7 +367,8 @@ class Eva(nn.Module):
use_abs_pos_emb: bool = True,
use_rot_pos_emb: bool = False,
use_post_norm: bool = False,
-dynamic_size: bool = False,
+dynamic_img_size: bool = False,
+dynamic_img_pad: bool = False,
ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
head_init_scale: float = 0.001,
):
@@ -407,11 +408,11 @@ class Eva(nn.Module):
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1 if class_token else 0
-self.dynamic_size = dynamic_size
+self.dynamic_img_size = dynamic_img_size
self.grad_checkpointing = False
embed_args = {}
-if dynamic_size:
+if dynamic_img_size:
# flatten deferred until after pos embed
embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
self.patch_embed = PatchEmbed(
@@ -419,6 +420,7 @@ class Eva(nn.Module):
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
+dynamic_img_pad=dynamic_img_pad,
**embed_args,
)
num_patches = self.patch_embed.num_patches
@@ -442,7 +444,7 @@ class Eva(nn.Module):
self.rope = RotaryEmbeddingCat(
embed_dim // num_heads,
in_pixels=False,
-feat_shape=None if dynamic_size else self.patch_embed.grid_size,
+feat_shape=None if dynamic_img_size else self.patch_embed.grid_size,
ref_feat_shape=ref_feat_shape,
)
else:
@@ -527,7 +529,7 @@ class Eva(nn.Module):
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-if self.dynamic_size:
+if self.dynamic_img_size:
B, H, W, C = x.shape
if self.pos_embed is not None:
pos_embed = resample_abs_pos_embed(
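Taken together, the Eva changes let a model accept inputs that differ from the pretraining resolution, with the rope feat_shape left dynamic and the position embedding resampled per forward. A hedged usage sketch; the entrypoint name and the kwargs pass-through via create_model are assumptions, not part of this diff:

    import timm
    import torch

    model = timm.create_model(
        'eva02_tiny_patch14_224',   # assumed entrypoint name
        dynamic_img_size=True,      # resample pos embed (and compute rope embeds) for the actual grid
        dynamic_img_pad=True,       # pad inputs that are not a multiple of the 14-pixel patch size
    )
    out = model(torch.randn(1, 3, 250, 200))   # neither side matches 224 or divides by 14
    print(out.shape)                           # (1, num_classes)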

timm/models/vision_transformer.py

@@ -383,7 +383,7 @@ class VisionTransformer(nn.Module):
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
"""
-dynamic_size: Final[bool]
+dynamic_img_size: Final[bool]
def __init__(
self,
@@ -401,9 +401,10 @@ class VisionTransformer(nn.Module):
init_values: Optional[float] = None,
class_token: bool = True,
no_embed_class: bool = False,
-dynamic_size: bool = False,
pre_norm: bool = False,
fc_norm: Optional[bool] = None,
+dynamic_img_size: bool = False,
+dynamic_img_pad: bool = False,
drop_rate: float = 0.,
pos_drop_rate: float = 0.,
patch_drop_rate: float = 0.,
@@ -454,11 +455,11 @@ class VisionTransformer(nn.Module):
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1 if class_token else 0
self.no_embed_class = no_embed_class
-self.dynamic_size = dynamic_size
+self.dynamic_img_size = dynamic_img_size
self.grad_checkpointing = False
embed_args = {}
-if dynamic_size:
+if dynamic_img_size:
# flatten deferred until after pos embed
embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
self.patch_embed = embed_layer(
@@ -467,6 +468,7 @@ class VisionTransformer(nn.Module):
in_chans=in_chans,
embed_dim=embed_dim,
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
+dynamic_img_pad=dynamic_img_pad,
**embed_args,
)
num_patches = self.patch_embed.num_patches
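When dynamic_img_size is enabled, the embed_args set above switch the patch embed to strict_img_size=False and output_fmt='NHWC', so the token grid keeps its spatial shape until the position embedding has been applied. A rough illustration of that intermediate NHWC output (constructor defaults assumed):

    import torch
    from timm.layers import PatchEmbed

    embed = PatchEmbed(
        patch_size=16,
        embed_dim=192,
        strict_img_size=False,   # what dynamic_img_size selects via embed_args
        output_fmt='NHWC',       # keep the grid; flattening happens after _pos_embed
        dynamic_img_pad=True,
    )
    x = torch.randn(1, 3, 250, 330)
    print(embed(x).shape)        # torch.Size([1, 16, 21, 192]) -- (B, H', W', C), not yet flattened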
@@ -554,7 +556,7 @@ class VisionTransformer(nn.Module):
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def _pos_embed(self, x):
-if self.dynamic_size:
+if self.dynamic_img_size:
B, H, W, C = x.shape
pos_embed = resample_abs_pos_embed(
self.pos_embed,
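End to end, a hedged sketch of what the two options allow for a stock ViT; the entrypoint name and the kwargs pass-through via create_model are assumptions here:

    import timm
    import torch

    model = timm.create_model(
        'vit_small_patch16_224',   # assumed entrypoint name
        dynamic_img_size=True,
        dynamic_img_pad=True,
    )
    x = torch.randn(2, 3, 250, 330)        # neither side is a multiple of 16
    tokens = model.forward_features(x)
    print(tokens.shape)                    # padded to 256x336 -> 16 * 21 patch tokens + 1 class token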

timm/models/vision_transformer_hybrid.py

@@ -18,6 +18,7 @@ from typing import List, Optional, Tuple
import torch
import torch.nn as nn
+import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
@@ -32,6 +33,7 @@ class HybridEmbed(nn.Module):
Extract feature map from CNN, flatten, project to embedding dim.
"""
output_fmt: Format
+dynamic_img_pad: torch.jit.Final[bool]
def __init__(
self,
@@ -45,6 +47,7 @@ class HybridEmbed(nn.Module):
flatten: bool = True,
output_fmt: Optional[str] = None,
strict_img_size: bool = True,
+dynamic_img_pad: bool = False,
):
super().__init__()
assert isinstance(backbone, nn.Module)
@@ -71,7 +74,8 @@ class HybridEmbed(nn.Module):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
-assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
+if not dynamic_img_pad:
+assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
if output_fmt is not None:
@@ -82,6 +86,7 @@ class HybridEmbed(nn.Module):
self.flatten = flatten
self.output_fmt = Format.NCHW
self.strict_img_size = strict_img_size
+self.dynamic_img_pad = dynamic_img_pad
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
@@ -89,6 +94,11 @@ class HybridEmbed(nn.Module):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
+_, _, H, W = x.shape
+if self.dynamic_img_pad:
+pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
+pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
+x = F.pad(x, (0, pad_w, 0, pad_h))
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
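The padding rule is the same in PatchEmbed and HybridEmbed: each spatial dimension is padded up to the next multiple of the patch size, and the outer modulo keeps already-divisible sizes untouched. A quick numeric check with arbitrary values:

    patch = 4
    for H in (28, 29, 30, 31, 32):
        pad_h = (patch - H % patch) % patch
        print(H, '->', H + pad_h)   # 28 -> 28, 29 -> 32, 30 -> 32, 31 -> 32, 32 -> 32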