dynamic_size -> dynamic_img_size, add dynamic_img_pad for padding option

This commit is contained in:
Ross Wightman 2023-08-27 10:07:01 -07:00 committed by Ross Wightman
parent 1f4512fca3
commit fc5d705b83
5 changed files with 34 additions and 14 deletions

View File

@ -26,6 +26,7 @@ class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding """ 2D Image to Patch Embedding
""" """
output_fmt: Format output_fmt: Format
dynamic_img_pad: torch.jit.Final[bool]
def __init__( def __init__(
self, self,
@ -38,6 +39,7 @@ class PatchEmbed(nn.Module):
output_fmt: Optional[str] = None, output_fmt: Optional[str] = None,
bias: bool = True, bias: bool = True,
strict_img_size: bool = True, strict_img_size: bool = True,
dynamic_img_pad: bool = False,
): ):
super().__init__() super().__init__()
self.patch_size = to_2tuple(patch_size) self.patch_size = to_2tuple(patch_size)
@ -58,6 +60,7 @@ class PatchEmbed(nn.Module):
self.flatten = flatten self.flatten = flatten
self.output_fmt = Format.NCHW self.output_fmt = Format.NCHW
self.strict_img_size = strict_img_size self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
@ -68,7 +71,7 @@ class PatchEmbed(nn.Module):
if self.strict_img_size: if self.strict_img_size:
_assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).") _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
_assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).") _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
else: elif not self.dynamic_img_pad:
_assert( _assert(
H % self.patch_size[0] == 0, H % self.patch_size[0] == 0,
f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
@ -77,7 +80,10 @@ class PatchEmbed(nn.Module):
W % self.patch_size[1] == 0, W % self.patch_size[1] == 0,
f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
) )
if self.dynamic_img_pad:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = F.pad(x, (0, pad_w, 0, pad_h))
x = self.proj(x) x = self.proj(x)
if self.flatten: if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC x = x.flatten(2).transpose(1, 2) # NCHW -> NLC

View File

@ -74,7 +74,7 @@ class VisionTransformerDistilled(VisionTransformer):
self.distilled_training = enable self.distilled_training = enable
def _pos_embed(self, x): def _pos_embed(self, x):
if self.dynamic_size: if self.dynamic_img_size:
B, H, W, C = x.shape B, H, W, C = x.shape
pos_embed = resample_abs_pos_embed( pos_embed = resample_abs_pos_embed(
self.pos_embed, self.pos_embed,

View File

@ -367,7 +367,8 @@ class Eva(nn.Module):
use_abs_pos_emb: bool = True, use_abs_pos_emb: bool = True,
use_rot_pos_emb: bool = False, use_rot_pos_emb: bool = False,
use_post_norm: bool = False, use_post_norm: bool = False,
dynamic_size: bool = False, dynamic_img_size: bool = False,
dynamic_img_pad: bool = False,
ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None, ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
head_init_scale: float = 0.001, head_init_scale: float = 0.001,
): ):
@ -407,11 +408,11 @@ class Eva(nn.Module):
self.global_pool = global_pool self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1 if class_token else 0 self.num_prefix_tokens = 1 if class_token else 0
self.dynamic_size = dynamic_size self.dynamic_img_size = dynamic_img_size
self.grad_checkpointing = False self.grad_checkpointing = False
embed_args = {} embed_args = {}
if dynamic_size: if dynamic_img_size:
# flatten deferred until after pos embed # flatten deferred until after pos embed
embed_args.update(dict(strict_img_size=False, output_fmt='NHWC')) embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
self.patch_embed = PatchEmbed( self.patch_embed = PatchEmbed(
@ -419,6 +420,7 @@ class Eva(nn.Module):
patch_size=patch_size, patch_size=patch_size,
in_chans=in_chans, in_chans=in_chans,
embed_dim=embed_dim, embed_dim=embed_dim,
dynamic_img_pad=dynamic_img_pad,
**embed_args, **embed_args,
) )
num_patches = self.patch_embed.num_patches num_patches = self.patch_embed.num_patches
@ -442,7 +444,7 @@ class Eva(nn.Module):
self.rope = RotaryEmbeddingCat( self.rope = RotaryEmbeddingCat(
embed_dim // num_heads, embed_dim // num_heads,
in_pixels=False, in_pixels=False,
feat_shape=None if dynamic_size else self.patch_embed.grid_size, feat_shape=None if dynamic_img_size else self.patch_embed.grid_size,
ref_feat_shape=ref_feat_shape, ref_feat_shape=ref_feat_shape,
) )
else: else:
@ -527,7 +529,7 @@ class Eva(nn.Module):
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if self.dynamic_size: if self.dynamic_img_size:
B, H, W, C = x.shape B, H, W, C = x.shape
if self.pos_embed is not None: if self.pos_embed is not None:
pos_embed = resample_abs_pos_embed( pos_embed = resample_abs_pos_embed(

View File

@ -383,7 +383,7 @@ class VisionTransformer(nn.Module):
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929 - https://arxiv.org/abs/2010.11929
""" """
dynamic_size: Final[bool] dynamic_img_size: Final[bool]
def __init__( def __init__(
self, self,
@ -401,9 +401,10 @@ class VisionTransformer(nn.Module):
init_values: Optional[float] = None, init_values: Optional[float] = None,
class_token: bool = True, class_token: bool = True,
no_embed_class: bool = False, no_embed_class: bool = False,
dynamic_size: bool = False,
pre_norm: bool = False, pre_norm: bool = False,
fc_norm: Optional[bool] = None, fc_norm: Optional[bool] = None,
dynamic_img_size: bool = False,
dynamic_img_pad: bool = False,
drop_rate: float = 0., drop_rate: float = 0.,
pos_drop_rate: float = 0., pos_drop_rate: float = 0.,
patch_drop_rate: float = 0., patch_drop_rate: float = 0.,
@ -454,11 +455,11 @@ class VisionTransformer(nn.Module):
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1 if class_token else 0 self.num_prefix_tokens = 1 if class_token else 0
self.no_embed_class = no_embed_class self.no_embed_class = no_embed_class
self.dynamic_size = dynamic_size self.dynamic_img_size = dynamic_img_size
self.grad_checkpointing = False self.grad_checkpointing = False
embed_args = {} embed_args = {}
if dynamic_size: if dynamic_img_size:
# flatten deferred until after pos embed # flatten deferred until after pos embed
embed_args.update(dict(strict_img_size=False, output_fmt='NHWC')) embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
self.patch_embed = embed_layer( self.patch_embed = embed_layer(
@ -467,6 +468,7 @@ class VisionTransformer(nn.Module):
in_chans=in_chans, in_chans=in_chans,
embed_dim=embed_dim, embed_dim=embed_dim,
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
dynamic_img_pad=dynamic_img_pad,
**embed_args, **embed_args,
) )
num_patches = self.patch_embed.num_patches num_patches = self.patch_embed.num_patches
@ -554,7 +556,7 @@ class VisionTransformer(nn.Module):
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def _pos_embed(self, x): def _pos_embed(self, x):
if self.dynamic_size: if self.dynamic_img_size:
B, H, W, C = x.shape B, H, W, C = x.shape
pos_embed = resample_abs_pos_embed( pos_embed = resample_abs_pos_embed(
self.pos_embed, self.pos_embed,

View File

@ -18,6 +18,7 @@ from typing import List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
@ -32,6 +33,7 @@ class HybridEmbed(nn.Module):
Extract feature map from CNN, flatten, project to embedding dim. Extract feature map from CNN, flatten, project to embedding dim.
""" """
output_fmt: Format output_fmt: Format
dynamic_img_pad: torch.jit.Final[bool]
def __init__( def __init__(
self, self,
@ -45,6 +47,7 @@ class HybridEmbed(nn.Module):
flatten: bool = True, flatten: bool = True,
output_fmt: Optional[str] = None, output_fmt: Optional[str] = None,
strict_img_size: bool = True, strict_img_size: bool = True,
dynamic_img_pad: bool = False,
): ):
super().__init__() super().__init__()
assert isinstance(backbone, nn.Module) assert isinstance(backbone, nn.Module)
@ -71,6 +74,7 @@ class HybridEmbed(nn.Module):
feature_dim = self.backbone.feature_info.channels()[-1] feature_dim = self.backbone.feature_info.channels()[-1]
else: else:
feature_dim = self.backbone.num_features feature_dim = self.backbone.num_features
if not dynamic_img_pad:
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1] self.num_patches = self.grid_size[0] * self.grid_size[1]
@ -82,6 +86,7 @@ class HybridEmbed(nn.Module):
self.flatten = flatten self.flatten = flatten
self.output_fmt = Format.NCHW self.output_fmt = Format.NCHW
self.strict_img_size = strict_img_size self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
@ -89,6 +94,11 @@ class HybridEmbed(nn.Module):
x = self.backbone(x) x = self.backbone(x)
if isinstance(x, (list, tuple)): if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features x = x[-1] # last feature if backbone outputs list/tuple of features
_, _, H, W = x.shape
if self.dynamic_img_pad:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = F.pad(x, (0, pad_w, 0, pad_h))
x = self.proj(x) x = self.proj(x)
if self.flatten: if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC x = x.flatten(2).transpose(1, 2) # NCHW -> NLC